diff --git a/internal/shader/expr.go b/internal/shader/expr.go index fc118dc7c..a72b63520 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -128,32 +128,44 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar return nil, nil, nil, false } - // If both are consts, adjust the types. - if lhs[0].Const != nil && rhs[0].Const != nil && lhs[0].Const.Kind() != rhs[0].Const.Kind() { - l, r, ok := shaderir.AdjustConstTypesForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) - if !ok { - // TODO: Show a better type name for untyped constants. - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) - return nil, nil, nil, false - } - lhs[0].Const, rhs[0].Const = l, r + // Resolve untyped constants. + l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) + if !ok { + // TODO: Show a better type name for untyped constants. + cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) + return nil, nil, nil, false + } + lhs[0].Const, rhs[0].Const = l, r - // TODO: Remove this (#2550) - switch lhs[0].Const.Kind() { - case gconstant.Float: - lhs[0].ConstType = shaderir.ConstTypeFloat - case gconstant.Int: - lhs[0].ConstType = shaderir.ConstTypeInt - case gconstant.Bool: - lhs[0].ConstType = shaderir.ConstTypeBool + // If either is typed, resolve the other type. + // If both are untyped, keep them untyped. + if lhst.Main != shaderir.None || rhst.Main != shaderir.None { + // TODO: Remove ConstType (#2550) + if lhs[0].Const != nil { + switch lhs[0].Const.Kind() { + case gconstant.Float: + lhst = shaderir.Type{Main: shaderir.Float} + lhs[0].ConstType = shaderir.ConstTypeFloat + case gconstant.Int: + lhst = shaderir.Type{Main: shaderir.Int} + lhs[0].ConstType = shaderir.ConstTypeInt + case gconstant.Bool: + lhst = shaderir.Type{Main: shaderir.Bool} + lhs[0].ConstType = shaderir.ConstTypeBool + } } - switch rhs[0].Const.Kind() { - case gconstant.Float: - rhs[0].ConstType = shaderir.ConstTypeFloat - case gconstant.Int: - rhs[0].ConstType = shaderir.ConstTypeInt - case gconstant.Bool: - rhs[0].ConstType = shaderir.ConstTypeBool + if rhs[0].Const != nil { + switch rhs[0].Const.Kind() { + case gconstant.Float: + rhst = shaderir.Type{Main: shaderir.Float} + rhs[0].ConstType = shaderir.ConstTypeFloat + case gconstant.Int: + rhst = shaderir.Type{Main: shaderir.Int} + rhs[0].ConstType = shaderir.ConstTypeInt + case gconstant.Bool: + rhst = shaderir.Type{Main: shaderir.Bool} + rhs[0].ConstType = shaderir.ConstTypeBool + } } } @@ -227,17 +239,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } fallthrough case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: - if lhs[0].ConstType == shaderir.ConstTypeInt { + if lhst.Main != shaderir.Float { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } case shaderir.IVec2, shaderir.IVec3, shaderir.IVec4: - if !canTruncateToInteger(lhs[0].Const) { + if lhst.Main != shaderir.Int { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } case shaderir.Int: - if !canTruncateToInteger(lhs[0].Const) { + if lhst.Main != shaderir.Int { cs.addError(e.Pos(), fmt.Sprintf("constant %s truncated to integer", lhs[0].Const.String())) return nil, nil, nil, false } @@ -253,12 +265,12 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } fallthrough case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: - if rhs[0].ConstType == shaderir.ConstTypeInt { + if rhst.Main != shaderir.Float { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } case shaderir.IVec2, shaderir.IVec3, shaderir.IVec4: - if !canTruncateToInteger(rhs[0].Const) { + if rhst.Main != shaderir.Int { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } @@ -266,7 +278,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar return nil, nil, nil, false } case shaderir.Int: - if !canTruncateToInteger(rhs[0].Const) { + if rhst.Main != shaderir.Int { cs.addError(e.Pos(), fmt.Sprintf("constant %s truncated to integer", rhs[0].Const.String())) return nil, nil, nil, false } diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index 2133e7ff5..1772749f5 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -18,7 +18,7 @@ import ( "go/constant" ) -func AdjustConstTypesForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (constant.Value, constant.Value, bool) { +func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { if lhst.Main == None && rhst.Main == None { if lhs.Kind() == rhs.Kind() { return lhs, rhs, true @@ -38,19 +38,11 @@ func AdjustConstTypesForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (cons return nil, nil, false } - if lhst.Equal(&rhst) { - if lhs.Kind() == rhs.Kind() { - return lhs, rhs, true - } - // TODO: When to reach this? - return nil, nil, false - } - if lhst.Main == None { - if rhst.Main == Float && constant.ToFloat(lhs).Kind() != constant.Unknown { + if (rhst.Main == Float || rhst.isFloatVector() || rhst.IsMatrix()) && constant.ToFloat(lhs).Kind() != constant.Unknown { return constant.ToFloat(lhs), rhs, true } - if rhst.Main == Int && constant.ToInt(lhs).Kind() != constant.Unknown { + if (rhst.Main == Int || rhst.isIntVector()) && constant.ToInt(lhs).Kind() != constant.Unknown { return constant.ToInt(lhs), rhs, true } if rhst.Main == Bool && lhs.Kind() == constant.Bool { @@ -60,10 +52,10 @@ func AdjustConstTypesForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (cons } if rhst.Main == None { - if lhst.Main == Float && constant.ToFloat(rhs).Kind() != constant.Unknown { + if (lhst.Main == Float || lhst.isFloatVector() || lhst.IsMatrix()) && constant.ToFloat(rhs).Kind() != constant.Unknown { return lhs, constant.ToFloat(rhs), true } - if lhst.Main == Int && constant.ToInt(rhs).Kind() != constant.Unknown { + if (lhst.Main == Int || lhst.isIntVector()) && constant.ToInt(rhs).Kind() != constant.Unknown { return lhs, constant.ToInt(rhs), true } if lhst.Main == Bool && rhs.Kind() == constant.Bool { @@ -72,7 +64,8 @@ func AdjustConstTypesForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (cons return nil, nil, false } - return nil, nil, false + // lhst and rhst might not match, but this has nothing to do with resolving untyped consts. + return lhs, rhs, true } func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { @@ -105,45 +98,10 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { return true } - // If the types match, that's fine. - if lhst.Equal(&rhst) { - return true + // Both types must not be untyped. + if lhst.Main == None || rhst.Main == None { + panic("shaderir: cannot resolve untyped values") } - // If lhs is untyped and rhs is not, compare the constant and the type and try to truncate the constant if necessary. - if lhst.Main == None { - // For %, if only one of the operands is a constant, try to truncate it. - if op == ModOp { - return constant.ToInt(lhs.Const).Kind() != constant.Unknown && rhst.Main == Int - } - if rhst.Main == Float { - return constant.ToFloat(lhs.Const).Kind() != constant.Unknown - } - if rhst.Main == Int { - return constant.ToInt(lhs.Const).Kind() != constant.Unknown - } - if rhst.Main == Bool { - return lhs.Const.Kind() == constant.Bool - } - return false - } - - // Ditto. - if rhst.Main == None { - if op == ModOp { - return constant.ToInt(rhs.Const).Kind() != constant.Unknown && lhst.Main == Int - } - if lhst.Main == Float { - return constant.ToFloat(rhs.Const).Kind() != constant.Unknown - } - if lhst.Main == Int { - return constant.ToInt(rhs.Const).Kind() != constant.Unknown - } - if lhst.Main == Bool { - return rhs.Const.Kind() == constant.Bool - } - return false - } - - return false + return lhst.Equal(&rhst) } diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index c13b0bfb3..76733a70b 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -126,6 +126,22 @@ func (t *Type) IsVector() bool { return false } +func (t *Type) isFloatVector() bool { + switch t.Main { + case Vec2, Vec3, Vec4: + return true + } + return false +} + +func (t *Type) isIntVector() bool { + switch t.Main { + case IVec2, IVec3, IVec4: + return true + } + return false +} + func (t *Type) VectorElementCount() int { switch t.Main { case Vec2: