diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index af4a17bb4..658f8a75c 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -92,19 +92,16 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP op = shaderir.ModOp } - // Treat an integer literal as an integer constant value. - wasTypedConstInt := rhs[0].ConstType == shaderir.ConstTypeInt - if rhs[0].Type == shaderir.NumberExpr && rts[0].Main == shaderir.Int { - if !cs.forceToInt(stmt, &rhs[0]) { - return nil, false - } - } - if lts[0].Main == rts[0].Main { - if op == shaderir.Div && (rts[0].Main == shaderir.Mat2 || rts[0].Main == shaderir.Mat3 || rts[0].Main == shaderir.Mat4) { + if op == shaderir.Div && rts[0].IsMatrix() { cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator / not defined on %s", rts[0].String())) return nil, false } + if lts[0].Main == shaderir.Int && rhs[0].Const != nil { + if !cs.forceToInt(stmt, &rhs[0]) { + return nil, false + } + } } else { switch lts[0].Main { case shaderir.Int: @@ -112,7 +109,9 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return nil, false } case shaderir.Float: - if rhs[0].Const != nil && rhs[0].Const.Kind() == gconstant.Int && !wasTypedConstInt { + if rhs[0].Const != nil && + rhs[0].ConstType != shaderir.ConstTypeInt && + gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown { rhs[0].Const = gconstant.ToFloat(rhs[0].Const) rhs[0].ConstType = shaderir.ConstTypeFloat } else { @@ -122,16 +121,16 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: if (op == shaderir.MatrixMul || op == shaderir.Div) && rts[0].Main == shaderir.Float { // OK - } else if (lts[0].Main == shaderir.Vec2 || - lts[0].Main == shaderir.Vec3 || - lts[0].Main == shaderir.Vec4) && - rts[0].Main == shaderir.Float { + } else if lts[0].IsVector() && rts[0].Main == shaderir.Float { // OK } else if op == shaderir.MatrixMul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) || (lts[0].Main == shaderir.Vec3 && rts[0].Main == shaderir.Mat3) || (lts[0].Main == shaderir.Vec4 && rts[0].Main == shaderir.Mat4)) { // OK - } else if rhs[0].Const != nil && rhs[0].Const.Kind() == gconstant.Int && !wasTypedConstInt { + } else if (op == shaderir.MatrixMul || op == shaderir.ComponentWiseMul || lts[0].IsVector()) && + rhs[0].Const != nil && + rhs[0].ConstType != shaderir.ConstTypeInt && + gconstant.ToFloat(rhs[0].Const).Kind() != gconstant.Unknown { rhs[0].Const = gconstant.ToFloat(rhs[0].Const) rhs[0].ConstType = shaderir.ConstTypeFloat } else { diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index c0c3c44a3..759367b0d 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1124,6 +1124,8 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) { }{ {stmt: "a := 1.0; a *= 2", err: false}, {stmt: "a := 1.0; a *= 2.0", err: false}, + {stmt: "const c = 2; a := 1.0; a *= c", err: false}, + {stmt: "const c = 2.0; a := 1.0; a *= c", err: false}, {stmt: "a := 1.0; a *= int(2)", err: true}, {stmt: "a := 1.0; a *= vec2(2)", err: true}, {stmt: "a := 1.0; a *= vec3(2)", err: true}, @@ -1133,6 +1135,8 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) { {stmt: "a := 1.0; a *= mat4(2)", err: true}, {stmt: "a := vec2(1); a *= 2", err: false}, {stmt: "a := vec2(1); a *= 2.0", err: false}, + {stmt: "const c = 2; a := vec2(1); a *= c", err: false}, + {stmt: "const c = 2.0; a := vec2(1); a *= c", err: false}, {stmt: "a := vec2(1); a /= 2.0", err: false}, {stmt: "a := vec2(1); a += 2.0", err: false}, {stmt: "a := vec2(1); a *= int(2)", err: true}, @@ -1149,6 +1153,8 @@ func TestSyntaxOperatorMultiplyAssign(t *testing.T) { {stmt: "a := vec2(1); a *= mat4(2)", err: true}, {stmt: "a := mat2(1); a *= 2", err: false}, {stmt: "a := mat2(1); a *= 2.0", err: false}, + {stmt: "const c = 2; a := mat2(1); a *= c", err: false}, + {stmt: "const c = 2.0; a := mat2(1); a *= c", err: false}, {stmt: "a := mat2(1); a /= 2.0", err: false}, {stmt: "a := mat2(1); a += 2.0", err: true}, {stmt: "a := mat2(1); a *= int(2)", err: true},