diff --git a/internal/shader/expr.go b/internal/shader/expr.go index c55a52810..a5b4ee2fa 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -34,34 +34,6 @@ func canTruncateToFloat(v gconstant.Value) bool { return gconstant.ToFloat(v).Kind() != gconstant.Unknown } -func isValidForModOp(lhs, rhs *shaderir.Expr, lhst, rhst shaderir.Type) bool { - isInt := func(s *shaderir.Expr, t shaderir.Type) bool { - if t.Main == shaderir.Int { - return true - } - if s.Const == nil { - return false - } - if s.Const.Kind() == gconstant.Int { - return true - } - if canTruncateToInteger(s.Const) { - return true - } - return false - } - - if isInt(lhs, lhst) { - return isInt(rhs, rhst) - } - - if lhst.Main == shaderir.IVec2 || lhst.Main == shaderir.IVec3 || lhst.Main == shaderir.IVec4 { - return lhst.Equal(&rhst) || isInt(rhs, rhst) - } - - return false -} - var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) { @@ -227,32 +199,12 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar t = rhst case rhst.Main == shaderir.Float: t = lhst - case op2 == shaderir.MatrixMul && (lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 || - lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2): - t = shaderir.Type{Main: shaderir.Vec2} - case op2 == shaderir.MatrixMul && (lhst.Main == shaderir.Vec3 && rhst.Main == shaderir.Mat3 || - lhst.Main == shaderir.Mat3 && rhst.Main == shaderir.Vec3): - t = shaderir.Type{Main: shaderir.Vec3} - case op2 == shaderir.MatrixMul && (lhst.Main == shaderir.Vec4 && rhst.Main == shaderir.Mat4 || - lhst.Main == shaderir.Mat4 && rhst.Main == shaderir.Vec4): - t = shaderir.Type{Main: shaderir.Vec4} + case op2 == shaderir.MatrixMul && lhst.IsVector() && rhst.IsMatrix(): + t = lhst + case op2 == shaderir.MatrixMul && lhst.IsMatrix() && rhst.IsVector(): + t = rhst default: - cs.addError(e.Pos(), fmt.Sprintf("invalid expression: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } - - // For `%`, both types must be deducible to integers. - if op2 == shaderir.ModOp { - if !isValidForModOp(&lhs[0], &rhs[0], lhst, rhst) { - var wrongType shaderir.Type - if lhst.Main != shaderir.Int { - wrongType = lhst - } else { - wrongType = rhst - } - cs.addError(e.Pos(), fmt.Sprintf("invalid operation: operator %% not defined on %s", wrongType.String())) - return nil, nil, nil, false - } + panic(fmt.Sprintf("shaderir: invalid expression: %s %s %s", lhst.String(), e.Op, rhst.String())) } return []shaderir.Expr{ diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index 66c4d4f60..ebb899264 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -69,25 +69,6 @@ func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) ( } func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { - if op == AndAnd || op == OrOr { - return lhst.Main == Bool && rhst.Main == Bool - } - - if op == VectorEqualOp || op == VectorNotEqualOp { - return lhst.IsVector() && rhst.IsVector() && lhst.Equal(&rhst) - } - - // Comparing matrices are forbidden (#2187). - if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp || op == EqualOp || op == NotEqualOp { - if lhst.IsMatrix() || rhst.IsMatrix() { - return false - } - } - - if op == Div && rhst.IsMatrix() { - return false - } - // If both are untyped consts, compare the constants and try to truncate them if necessary. if lhst.Main == None && rhst.Main == None { // Assume that the constant types are already adjusted. @@ -95,6 +76,10 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { panic("shaderir: const types for a binary op must be adjusted") } + if op == AndAnd || op == OrOr { + return lhs.Const.Kind() == constant.Bool && rhs.Const.Kind() == constant.Bool + } + // For %, both operands must be integers if both are constants. Truncatable to an integer is not enough. if op == ModOp { return lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int @@ -107,25 +92,73 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { panic("shaderir: cannot resolve untyped values") } + if op == AndAnd || op == OrOr { + return lhst.Main == Bool && rhst.Main == Bool + } + + if op == VectorEqualOp || op == VectorNotEqualOp { + return lhst.IsVector() && rhst.IsVector() && lhst.Equal(&rhst) + } + + // Comparing matrices are forbidden (#2187). + if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp || op == EqualOp || op == NotEqualOp { + if lhst.IsMatrix() || rhst.IsMatrix() { + return false + } + return lhst.Equal(&rhst) + } + + if op == Div && rhst.IsMatrix() { + return false + } + + if op == ModOp { + if lhst.Main == IVec2 && rhst.Main == IVec2 { + return true + } + if lhst.Main == IVec3 && rhst.Main == IVec3 { + return true + } + if lhst.Main == IVec4 && rhst.Main == IVec4 { + return true + } + return (lhst.Main == Int || lhst.isIntVector()) && rhst.Main == Int + } + if lhst.Equal(&rhst) { return true } if op == MatrixMul { - if lhst.IsMatrix() && (rhst.isFloatVector() || rhst.Main == Float) { - // TODO: Check dimensions + if lhst.IsMatrix() && rhst.Main == Float { return true } - if rhst.IsMatrix() && (lhst.isFloatVector() || lhst.Main == Float) { - // TODO: Check dimensions + if lhst.Main == Mat2 && rhst.Main == Vec2 { + return true + } + if lhst.Main == Mat3 && rhst.Main == Vec3 { + return true + } + if lhst.Main == Mat4 && rhst.Main == Vec4 { + return true + } + if lhst.Main == Float && rhst.IsMatrix() { + return true + } + if lhst.Main == Vec2 && rhst.Main == Mat2 { + return true + } + if lhst.Main == Vec3 && rhst.Main == Mat3 { + return true + } + if lhst.Main == Vec4 && rhst.Main == Mat4 { return true } return false } if op == Div { - if lhst.IsMatrix() && (rhst.isFloatVector() || rhst.Main == Float) { - // TODO: Check dimensions + if lhst.IsMatrix() && rhst.Main == Float { return true } // fallback