internal/shader: refactoring: move type deduction to shaderir package

Updates #2754
This commit is contained in:
Hajime Hoshi 2023-09-13 00:15:17 +09:00
parent 5e30e1ee1d
commit 19413c2805
2 changed files with 84 additions and 91 deletions

View File

@ -134,7 +134,8 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
} }
if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) { t, ok := shaderir.TypeFromBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst)
if !ok {
// TODO: Show a better type name for untyped constants. // TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
return nil, nil, nil, false return nil, nil, nil, false
@ -142,32 +143,14 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
if lhs[0].Const != nil && rhs[0].Const != nil { if lhs[0].Const != nil && rhs[0].Const != nil {
var v gconstant.Value var v gconstant.Value
var t shaderir.Type
switch op { switch op {
case token.LAND, token.LOR: case token.LAND, token.LOR:
b := gconstant.BoolVal(gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)) b := gconstant.BoolVal(gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const))
v = gconstant.MakeBool(b) v = gconstant.MakeBool(b)
if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
t = shaderir.Type{Main: shaderir.Bool}
}
case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ: case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ:
v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const)) v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const))
if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
t = shaderir.Type{Main: shaderir.Bool}
}
default: default:
v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const) v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)
switch {
case lhst.Main == shaderir.Float || rhst.Main == shaderir.Float:
t = shaderir.Type{Main: shaderir.Float}
case lhst.Main == shaderir.Int || rhst.Main == shaderir.Int:
t = shaderir.Type{Main: shaderir.Int}
case lhst.Main == shaderir.Bool || rhst.Main == shaderir.Bool:
t = shaderir.Type{Main: shaderir.Bool}
default:
// If both operands are untyped, keep untyped.
t = shaderir.Type{}
}
} }
return []shaderir.Expr{ return []shaderir.Expr{
@ -178,38 +161,6 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}, []shaderir.Type{t}, stmts, true }, []shaderir.Type{t}, stmts, true
} }
var t shaderir.Type
switch {
case op2 == shaderir.LessThanOp || op2 == shaderir.LessThanEqualOp || op2 == shaderir.GreaterThanOp || op2 == shaderir.GreaterThanEqualOp || op2 == shaderir.EqualOp || op2 == shaderir.NotEqualOp || op2 == shaderir.VectorEqualOp || op2 == shaderir.VectorNotEqualOp || op2 == shaderir.AndAnd || op2 == shaderir.OrOr:
t = shaderir.Type{Main: shaderir.Bool}
case lhs[0].Const != nil && rhs[0].Const == nil:
t = rhst
case lhs[0].Const == nil && rhs[0].Const != nil:
t = lhst
case lhst.Equal(&rhst):
t = lhst
case op2 == shaderir.MatrixMul && lhst.Main == shaderir.Float && rhst.IsMatrix():
t = rhst
case op2 == shaderir.MatrixMul && lhst.IsFloatVector() && rhst.IsMatrix():
t = lhst
case op2 == shaderir.MatrixMul && lhst.IsMatrix() && rhst.Main == shaderir.Float:
t = lhst
case op2 == shaderir.MatrixMul && lhst.IsMatrix() && rhst.IsFloatVector():
t = rhst
case op2 == shaderir.Div && lhst.IsMatrix() && rhst.Main == shaderir.Float:
t = lhst
case lhst.Main == shaderir.Float && rhst.IsFloatVector():
t = rhst
case lhst.Main == shaderir.Int && rhst.IsIntVector():
t = rhst
case lhst.IsFloatVector() && rhst.Main == shaderir.Float:
t = lhst
case lhst.IsIntVector() && rhst.Main == shaderir.Int:
t = lhst
default:
panic(fmt.Sprintf("shaderir: invalid expression: %s %s %s", lhst.String(), e.Op, rhst.String()))
}
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.Binary, Type: shaderir.Binary,

View File

@ -68,7 +68,7 @@ func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (
return lhs, rhs, true return lhs, rhs, true
} }
func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { func TypeFromBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) (Type, bool) {
// If both are untyped consts, compare the constants and try to truncate them if necessary. // If both are untyped consts, compare the constants and try to truncate them if necessary.
if lhst.Main == None && rhst.Main == None { if lhst.Main == None && rhst.Main == None {
// Assume that the constant types are already adjusted. // Assume that the constant types are already adjusted.
@ -77,19 +77,43 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
} }
if op == AndAnd || op == OrOr { if op == AndAnd || op == OrOr {
return lhs.Const.Kind() == constant.Bool && rhs.Const.Kind() == constant.Bool if lhs.Const.Kind() == constant.Bool && rhs.Const.Kind() == constant.Bool {
return Type{Main: Bool}, true
}
return Type{}, false
} }
// For %, both operands must be integers if both are constants. Truncatable to an integer is not enough. // For %, both operands must be integers if both are constants. Truncatable to an integer is not enough.
if op == ModOp { if op == ModOp {
return lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int if lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int {
return Type{Main: Int}, true
}
return Type{}, false
} }
if op == And || op == Or || op == Xor { if op == And || op == Or || op == Xor {
return lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int if lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int {
return Type{Main: Int}, true
}
return Type{}, false
} }
return true if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp {
return Type{Main: Bool}, true
}
if lhst.Main == Float || rhst.Main == Float {
return Type{Main: Float}, true
}
if lhst.Main == Int || rhst.Main == Int {
return Type{Main: Int}, true
}
if lhst.Main == Bool || rhst.Main == Bool {
return Type{Main: Bool}, true
}
// If both operands are untyped, keep untyped.
return Type{}, true
} }
// Both types must not be untyped. // Both types must not be untyped.
@ -98,115 +122,133 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
} }
if op == AndAnd || op == OrOr { if op == AndAnd || op == OrOr {
return lhst.Main == Bool && rhst.Main == Bool if lhst.Main == Bool && rhst.Main == Bool {
return Type{Main: Bool}, true
}
return Type{}, false
} }
if op == VectorEqualOp || op == VectorNotEqualOp { if op == VectorEqualOp || op == VectorNotEqualOp {
return (lhst.IsFloatVector() || lhst.IsIntVector()) && (rhst.IsFloatVector() || lhst.IsIntVector()) && lhst.Equal(&rhst) if (lhst.IsFloatVector() || lhst.IsIntVector()) && (rhst.IsFloatVector() || lhst.IsIntVector()) && lhst.Equal(&rhst) {
return Type{Main: Bool}, true
}
return Type{}, false
} }
if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp { if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp {
return (lhst.Main == Int && rhst.Main == Int) || (lhst.Main == Float && rhst.Main == Float) if (lhst.Main == Int && rhst.Main == Int) || (lhst.Main == Float && rhst.Main == Float) {
return Type{Main: Bool}, true
}
return Type{}, false
} }
// Comparing matrices are forbidden (#2187). // Comparing matrices are forbidden (#2187).
if op == EqualOp || op == NotEqualOp { if op == EqualOp || op == NotEqualOp {
if lhst.IsMatrix() || rhst.IsMatrix() { if lhst.IsMatrix() || rhst.IsMatrix() {
return false return Type{}, false
} }
return lhst.Equal(&rhst) if lhst.Equal(&rhst) {
return Type{Main: Bool}, true
}
return Type{}, false
} }
if op == Div && rhst.IsMatrix() { if op == Div && rhst.IsMatrix() {
return false return Type{}, false
} }
if op == ModOp { if op == ModOp {
if lhst.Main == IVec2 && rhst.Main == IVec2 { if lhst.Main == IVec2 && rhst.Main == IVec2 {
return true return Type{Main: IVec2}, true
} }
if lhst.Main == IVec3 && rhst.Main == IVec3 { if lhst.Main == IVec3 && rhst.Main == IVec3 {
return true return Type{Main: IVec3}, true
} }
if lhst.Main == IVec4 && rhst.Main == IVec4 { if lhst.Main == IVec4 && rhst.Main == IVec4 {
return true return Type{Main: IVec4}, true
} }
return (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int {
return lhst, true
}
return Type{}, false
} }
if op == And || op == Or || op == Xor { if op == And || op == Or || op == Xor {
if lhst.Main == Int && rhst.Main == Int { if lhst.Main == Int && rhst.Main == Int {
return true return Type{Main: Int}, true
} }
if lhst.Main == IVec2 && rhst.Main == IVec2 { if lhst.Main == IVec2 && rhst.Main == IVec2 {
return true return Type{Main: IVec2}, true
} }
if lhst.Main == IVec3 && rhst.Main == IVec3 { if lhst.Main == IVec3 && rhst.Main == IVec3 {
return true return Type{Main: IVec3}, true
} }
if lhst.Main == IVec4 && rhst.Main == IVec4 { if lhst.Main == IVec4 && rhst.Main == IVec4 {
return true return Type{Main: IVec4}, true
} }
if lhst.IsIntVector() && rhst.Main == Int { if lhst.IsIntVector() && rhst.Main == Int {
return true return lhst, true
} }
if lhst.Main == Int && rhst.IsIntVector() { if lhst.Main == Int && rhst.IsIntVector() {
return true return rhst, true
} }
return false return Type{}, false
} }
if lhst.Equal(&rhst) { if lhst.Equal(&rhst) {
return true if lhst.Main == None {
return rhst, true
}
return lhst, true
} }
if op == MatrixMul { if op == MatrixMul {
if lhst.IsMatrix() && rhst.Main == Float { if lhst.IsMatrix() && rhst.Main == Float {
return true return lhst, true
} }
if lhst.Main == Mat2 && rhst.Main == Vec2 { if lhst.Main == Mat2 && rhst.Main == Vec2 {
return true return rhst, true
} }
if lhst.Main == Mat3 && rhst.Main == Vec3 { if lhst.Main == Mat3 && rhst.Main == Vec3 {
return true return rhst, true
} }
if lhst.Main == Mat4 && rhst.Main == Vec4 { if lhst.Main == Mat4 && rhst.Main == Vec4 {
return true return rhst, true
} }
if lhst.Main == Float && rhst.IsMatrix() { if lhst.Main == Float && rhst.IsMatrix() {
return true return rhst, true
} }
if lhst.Main == Vec2 && rhst.Main == Mat2 { if lhst.Main == Vec2 && rhst.Main == Mat2 {
return true return lhst, true
} }
if lhst.Main == Vec3 && rhst.Main == Mat3 { if lhst.Main == Vec3 && rhst.Main == Mat3 {
return true return lhst, true
} }
if lhst.Main == Vec4 && rhst.Main == Mat4 { if lhst.Main == Vec4 && rhst.Main == Mat4 {
return true return lhst, true
} }
return false return Type{}, false
} }
if op == Div { if op == Div {
if lhst.IsMatrix() && rhst.Main == Float { if lhst.IsMatrix() && rhst.Main == Float {
return true return lhst, true
} }
// fallback // fallback
} }
if lhst.IsFloatVector() && rhst.Main == Float { if lhst.IsFloatVector() && rhst.Main == Float {
return true return lhst, true
} }
if rhst.IsFloatVector() && lhst.Main == Float { if lhst.Main == Float && rhst.IsFloatVector() {
return true return rhst, true
} }
if lhst.IsIntVector() && rhst.Main == Int { if lhst.IsIntVector() && rhst.Main == Int {
return true return lhst, true
} }
if rhst.IsIntVector() && lhst.Main == Int { if lhst.Main == Int && rhst.IsIntVector() {
return true return rhst, true
} }
return false return Type{}, false
} }