diff --git a/internal/shader/expr.go b/internal/shader/expr.go index f94af87f7..fc118dc7c 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -128,6 +128,37 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar return nil, nil, nil, false } + // If both are consts, adjust the types. + if lhs[0].Const != nil && rhs[0].Const != nil && lhs[0].Const.Kind() != rhs[0].Const.Kind() { + l, r, ok := shaderir.AdjustConstTypesForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) + if !ok { + // TODO: Show a better type name for untyped constants. + cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) + return nil, nil, nil, false + } + lhs[0].Const, rhs[0].Const = l, r + + // TODO: Remove this (#2550) + switch lhs[0].Const.Kind() { + case gconstant.Float: + lhs[0].ConstType = shaderir.ConstTypeFloat + case gconstant.Int: + lhs[0].ConstType = shaderir.ConstTypeInt + case gconstant.Bool: + lhs[0].ConstType = shaderir.ConstTypeBool + } + switch rhs[0].Const.Kind() { + case gconstant.Float: + rhs[0].ConstType = shaderir.ConstTypeFloat + case gconstant.Int: + rhs[0].ConstType = shaderir.ConstTypeInt + case gconstant.Bool: + rhs[0].ConstType = shaderir.ConstTypeBool + } + } + + // TODO: Integrate AreValidTypesForBinaryOp calls to here. + if lhs[0].Const != nil && rhs[0].Const != nil { if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) { // TODO: Show a better type name for untyped constants. diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index 37580a3a2..2133e7ff5 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -18,6 +18,63 @@ import ( "go/constant" ) +func AdjustConstTypesForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (constant.Value, constant.Value, bool) { + if lhst.Main == None && rhst.Main == None { + if lhs.Kind() == rhs.Kind() { + return lhs, rhs, true + } + if lhs.Kind() == constant.Float && constant.ToFloat(rhs).Kind() != constant.Unknown { + return lhs, constant.ToFloat(rhs), true + } + if rhs.Kind() == constant.Float && constant.ToFloat(lhs).Kind() != constant.Unknown { + return constant.ToFloat(lhs), rhs, true + } + if lhs.Kind() == constant.Int && constant.ToInt(rhs).Kind() != constant.Unknown { + return lhs, constant.ToInt(rhs), true + } + if rhs.Kind() == constant.Int && constant.ToInt(lhs).Kind() != constant.Unknown { + return constant.ToInt(lhs), rhs, true + } + return nil, nil, false + } + + if lhst.Equal(&rhst) { + if lhs.Kind() == rhs.Kind() { + return lhs, rhs, true + } + // TODO: When to reach this? + return nil, nil, false + } + + if lhst.Main == None { + if rhst.Main == Float && constant.ToFloat(lhs).Kind() != constant.Unknown { + return constant.ToFloat(lhs), rhs, true + } + if rhst.Main == Int && constant.ToInt(lhs).Kind() != constant.Unknown { + return constant.ToInt(lhs), rhs, true + } + if rhst.Main == Bool && lhs.Kind() == constant.Bool { + return lhs, rhs, true + } + return nil, nil, false + } + + if rhst.Main == None { + if lhst.Main == Float && constant.ToFloat(rhs).Kind() != constant.Unknown { + return lhs, constant.ToFloat(rhs), true + } + if lhst.Main == Int && constant.ToInt(rhs).Kind() != constant.Unknown { + return lhs, constant.ToInt(rhs), true + } + if lhst.Main == Bool && rhs.Kind() == constant.Bool { + return lhs, rhs, true + } + return nil, nil, false + } + + return nil, nil, false +} + func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { if op == AndAnd || op == OrOr { return lhst.Main == Bool && rhst.Main == Bool @@ -36,26 +93,16 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { // If both are untyped consts, compare the constants and try to truncate them if necessary. if lhst.Main == None && rhst.Main == None { + // Assume that the constant types are already adjusted. + if lhs.Const.Kind() != rhs.Const.Kind() { + panic("shaderir: const types for a binary op must be adjusted") + } + // For %, both operands must be integers if both are constants. Truncatable to an integer is not enough. if op == ModOp { return lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int } - if lhs.Const.Kind() == rhs.Const.Kind() { - return true - } - if lhs.Const.Kind() == constant.Float && constant.ToFloat(rhs.Const).Kind() != constant.Unknown { - return true - } - if rhs.Const.Kind() == constant.Float && constant.ToFloat(lhs.Const).Kind() != constant.Unknown { - return true - } - if lhs.Const.Kind() == constant.Int && constant.ToInt(rhs.Const).Kind() != constant.Unknown { - return true - } - if rhs.Const.Kind() == constant.Int && constant.ToInt(lhs.Const).Kind() != constant.Unknown { - return true - } - return false + return true } // If the types match, that's fine.