diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 4f5a2548b..892b94c87 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -147,6 +147,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable var t shaderir.Type switch { case op == shaderir.LessThanOp || op == shaderir.LessThanEqualOp || op == shaderir.GreaterThanOp || op == shaderir.GreaterThanEqualOp || op == shaderir.EqualOp || op == shaderir.NotEqualOp || op == shaderir.AndAnd || op == shaderir.OrOr: + // TODO: Check types of the operands. t = shaderir.Type{Main: shaderir.Bool} case lhs[0].Type == shaderir.NumberExpr && rhs[0].Type != shaderir.NumberExpr: if rhst.Main == shaderir.Int { @@ -168,34 +169,30 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable t = lhst case lhst.Equal(&rhst): t = lhst - case lhst.Main == shaderir.Float || lhst.Main == shaderir.Int: + case lhst.Main == shaderir.Float: switch rhst.Main { - case shaderir.Int: - t = lhst case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: t = rhst default: cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } - case rhst.Main == shaderir.Float || rhst.Main == shaderir.Int: + case rhst.Main == shaderir.Float: switch lhst.Main { - case shaderir.Int: - t = rhst case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: t = lhst default: cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } - case lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 || - lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2: + case op == shaderir.Mul && (lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 || + lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2): t = shaderir.Type{Main: shaderir.Vec2} - case lhst.Main == shaderir.Vec3 && rhst.Main == shaderir.Mat3 || - lhst.Main == shaderir.Mat3 && rhst.Main == shaderir.Vec3: + case op == shaderir.Mul && (lhst.Main == shaderir.Vec3 && rhst.Main == shaderir.Mat3 || + lhst.Main == shaderir.Mat3 && rhst.Main == shaderir.Vec3): t = shaderir.Type{Main: shaderir.Vec3} - case lhst.Main == shaderir.Vec4 && rhst.Main == shaderir.Mat4 || - lhst.Main == shaderir.Mat4 && rhst.Main == shaderir.Vec4: + case op == shaderir.Mul && (lhst.Main == shaderir.Vec4 && rhst.Main == shaderir.Mat4 || + lhst.Main == shaderir.Mat4 && rhst.Main == shaderir.Vec4): t = shaderir.Type{Main: shaderir.Vec4} default: cs.addError(e.Pos(), fmt.Sprintf("invalid expression: %s %s %s", lhst.String(), e.Op, rhst.String())) diff --git a/shader_test.go b/shader_test.go index 802ac80f0..5035ca9fc 100644 --- a/shader_test.go +++ b/shader_test.go @@ -15,6 +15,7 @@ package ebiten_test import ( + "fmt" "image" "image/color" "testing" @@ -1538,3 +1539,63 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { t.Errorf("error must be non-nil but was nil") } } + +// Issue #1971 +func TestShaderOperatorMultiply(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := 1 * vec2(2); _ = a", err: false}, + {stmt: "a := int(1) * vec2(2); _ = a", err: true}, + {stmt: "a := 1.0 * vec2(2); _ = a", err: false}, + {stmt: "a := 1 + vec2(2); _ = a", err: false}, + {stmt: "a := int(1) + vec2(2); _ = a", err: true}, + {stmt: "a := 1.0 + vec2(2); _ = a", err: false}, + {stmt: "a := 1 * vec3(2); _ = a", err: false}, + {stmt: "a := 1.0 * vec3(2); _ = a", err: false}, + {stmt: "a := 1 * vec4(2); _ = a", err: false}, + {stmt: "a := 1.0 * vec4(2); _ = a", err: false}, + {stmt: "a := 1 * mat2(2); _ = a", err: false}, + {stmt: "a := 1.0 * mat2(2); _ = a", err: false}, + {stmt: "a := 1 * mat3(2); _ = a", err: false}, + {stmt: "a := 1.0 * mat3(2); _ = a", err: false}, + {stmt: "a := 1 * mat4(2); _ = a", err: false}, + {stmt: "a := 1.0 * mat4(2); _ = a", err: false}, + {stmt: "a := vec2(1) * 2; _ = a", err: false}, + {stmt: "a := vec2(1) * 2.0; _ = a", err: false}, + {stmt: "a := vec2(1) * int(2); _ = a", err: true}, + {stmt: "a := vec2(1) * vec2(2); _ = a", err: false}, + {stmt: "a := vec2(1) + vec2(2); _ = a", err: false}, + {stmt: "a := vec2(1) * vec3(2); _ = a", err: true}, + {stmt: "a := vec2(1) * vec4(2); _ = a", err: true}, + {stmt: "a := vec2(1) * mat2(2); _ = a", err: false}, + {stmt: "a := vec2(1) * mat3(2); _ = a", err: true}, + {stmt: "a := vec2(1) * mat4(2); _ = a", err: true}, + {stmt: "a := mat2(1) * 2; _ = a", err: false}, + {stmt: "a := mat2(1) * 2.0; _ = a", err: false}, + {stmt: "a := mat2(1) * int(2); _ = a", err: true}, + {stmt: "a := mat2(1) + 2.0; _ = a", err: false}, + {stmt: "a := mat2(1) * vec2(2); _ = a", err: false}, + {stmt: "a := mat2(1) + vec2(2); _ = a", err: true}, + {stmt: "a := mat2(1) * vec3(2); _ = a", err: true}, + {stmt: "a := mat2(1) * vec4(2); _ = a", err: true}, + {stmt: "a := mat2(1) * mat2(2); _ = a", err: false}, + {stmt: "a := mat2(1) * mat3(2); _ = a", err: true}, + {stmt: "a := mat2(1) * mat4(2); _ = a", err: true}, + } + + for _, c := range cases { + _, err := ebiten.NewShader([]byte(fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, c.stmt))) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.stmt, err) + } + } +}