From 2e5b4954f34c20ee4a7b0fc8542a881743a8cba0 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Fri, 21 Jan 2022 02:39:28 +0900 Subject: [PATCH] internal/shader: bug fix: forbid mat + float mat + float doesn't work on Metal. --- internal/shader/expr.go | 30 ++++++++++++++++++++++++++++-- internal/shader/stmt.go | 7 ++++++- shader_test.go | 22 +++++++++++++++++++++- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 892b94c87..459fb04ce 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -150,6 +150,13 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable // TODO: Check types of the operands. t = shaderir.Type{Main: shaderir.Bool} case lhs[0].Type == shaderir.NumberExpr && rhs[0].Type != shaderir.NumberExpr: + switch rhst.Main { + case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: + if op != shaderir.Mul { + cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) + return nil, nil, nil, false + } + } if rhst.Main == shaderir.Int { if !canTruncateToInteger(lhs[0].Const) { cs.addError(e.Pos(), fmt.Sprintf("constant %s truncated to integer", lhs[0].Const.String())) @@ -159,6 +166,13 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable } t = rhst case lhs[0].Type != shaderir.NumberExpr && rhs[0].Type == shaderir.NumberExpr: + switch lhst.Main { + case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: + if op != shaderir.Mul && op != shaderir.Div { + cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) + return nil, nil, nil, false + } + } if lhst.Main == shaderir.Int { if !canTruncateToInteger(rhs[0].Const) { cs.addError(e.Pos(), fmt.Sprintf("constant %s truncated to integer", rhs[0].Const.String())) @@ -171,15 +185,27 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable t = lhst case lhst.Main == shaderir.Float: switch rhst.Main { - case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: + case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: t = rhst + case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: + if op != shaderir.Mul { + cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) + return nil, nil, nil, false + } + t = lhst default: cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } case rhst.Main == shaderir.Float: switch lhst.Main { - case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: + case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: + t = lhst + case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: + if op != shaderir.Mul && op != shaderir.Div { + cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) + return nil, nil, nil, false + } t = lhst default: cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 061364aa1..f87a2821f 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -109,7 +109,12 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return nil, false } case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if rts[0].Main == shaderir.Float { + if (op == shaderir.Mul || op == shaderir.Div) && rts[0].Main == shaderir.Float { + // OK + } else if (lts[0].Main == shaderir.Vec2 || + lts[0].Main == shaderir.Vec3 || + lts[0].Main == shaderir.Vec4) && + rts[0].Main == shaderir.Float { // OK } else if op == shaderir.Mul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) || (lts[0].Main == shaderir.Vec3 && rts[0].Main == shaderir.Mat3) || diff --git a/shader_test.go b/shader_test.go index 4c66351ab..c73651f79 100644 --- a/shader_test.go +++ b/shader_test.go @@ -1542,6 +1542,8 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { // Issue #1971 func TestShaderOperatorMultiply(t *testing.T) { + // Note: mat + float is allowed in GLSL but not in Metal. + cases := []struct { stmt string err bool @@ -1551,6 +1553,7 @@ func TestShaderOperatorMultiply(t *testing.T) { {stmt: "a := 1.0 * vec2(2); _ = a", err: false}, {stmt: "a := 1 + vec2(2); _ = a", err: false}, {stmt: "a := int(1) + vec2(2); _ = a", err: true}, + {stmt: "a := 1.0 / vec2(2); _ = a", err: false}, {stmt: "a := 1.0 + vec2(2); _ = a", err: false}, {stmt: "a := 1 * vec3(2); _ = a", err: false}, {stmt: "a := 1.0 * vec3(2); _ = a", err: false}, @@ -1558,12 +1561,18 @@ func TestShaderOperatorMultiply(t *testing.T) { {stmt: "a := 1.0 * vec4(2); _ = a", err: false}, {stmt: "a := 1 * mat2(2); _ = a", err: false}, {stmt: "a := 1.0 * mat2(2); _ = a", err: false}, + {stmt: "a := float(1.0) / mat2(2); _ = a", err: true}, + {stmt: "a := 1.0 / mat2(2); _ = a", err: true}, + {stmt: "a := float(1.0) + mat2(2); _ = a", err: true}, + {stmt: "a := 1.0 + mat2(2); _ = a", err: true}, {stmt: "a := 1 * mat3(2); _ = a", err: false}, {stmt: "a := 1.0 * mat3(2); _ = a", err: false}, {stmt: "a := 1 * mat4(2); _ = a", err: false}, {stmt: "a := 1.0 * mat4(2); _ = a", err: false}, {stmt: "a := vec2(1) * 2; _ = a", err: false}, {stmt: "a := vec2(1) * 2.0; _ = a", err: false}, + {stmt: "a := vec2(1) / 2.0; _ = a", err: false}, + {stmt: "a := vec2(1) + 2.0; _ = a", err: false}, {stmt: "a := vec2(1) * int(2); _ = a", err: true}, {stmt: "a := vec2(1) * vec2(2); _ = a", err: false}, {stmt: "a := vec2(1) + vec2(2); _ = a", err: false}, @@ -1575,8 +1584,11 @@ func TestShaderOperatorMultiply(t *testing.T) { {stmt: "a := vec2(1) * mat4(2); _ = a", err: true}, {stmt: "a := mat2(1) * 2; _ = a", err: false}, {stmt: "a := mat2(1) * 2.0; _ = a", err: false}, + {stmt: "a := mat2(1) / 2.0; _ = a", err: false}, + {stmt: "a := mat2(1) / float(2); _ = a", err: false}, {stmt: "a := mat2(1) * int(2); _ = a", err: true}, - {stmt: "a := mat2(1) + 2.0; _ = a", err: false}, + {stmt: "a := mat2(1) + 2.0; _ = a", err: true}, + {stmt: "a := mat2(1) + float(2); _ = a", err: true}, {stmt: "a := mat2(1) * vec2(2); _ = a", err: false}, {stmt: "a := mat2(1) + vec2(2); _ = a", err: true}, {stmt: "a := mat2(1) * vec3(2); _ = a", err: true}, @@ -1618,7 +1630,11 @@ func TestShaderOperatorMultiplyAssign(t *testing.T) { {stmt: "a := 1.0; a *= mat4(2)", err: true}, {stmt: "a := vec2(1); a *= 2", err: false}, {stmt: "a := vec2(1); a *= 2.0", err: false}, + {stmt: "a := vec2(1); a /= 2.0", err: false}, + {stmt: "a := vec2(1); a += 2.0", err: false}, {stmt: "a := vec2(1); a *= int(2)", err: true}, + {stmt: "a := vec2(1); a *= float(2)", err: false}, + {stmt: "a := vec2(1); a /= float(2)", err: false}, {stmt: "a := vec2(1); a *= vec2(2)", err: false}, {stmt: "a := vec2(1); a += vec2(2)", err: false}, {stmt: "a := vec2(1); a *= vec3(2)", err: true}, @@ -1629,7 +1645,11 @@ func TestShaderOperatorMultiplyAssign(t *testing.T) { {stmt: "a := vec2(1); a *= mat4(2)", err: true}, {stmt: "a := mat2(1); a *= 2", err: false}, {stmt: "a := mat2(1); a *= 2.0", err: false}, + {stmt: "a := mat2(1); a /= 2.0", err: false}, + {stmt: "a := mat2(1); a += 2.0", err: true}, {stmt: "a := mat2(1); a *= int(2)", err: true}, + {stmt: "a := mat2(1); a *= float(2)", err: false}, + {stmt: "a := mat2(1); a /= float(2)", err: false}, {stmt: "a := mat2(1); a *= vec2(2)", err: true}, {stmt: "a := mat2(1); a += vec2(2)", err: true}, {stmt: "a := mat2(1); a *= vec3(2)", err: true},