diff --git a/internal/shader/expr.go b/internal/shader/expr.go index a5b4ee2fa..753ea348e 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -195,14 +195,14 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar t = lhst case lhst.Equal(&rhst): t = lhst - case lhst.Main == shaderir.Float: - t = rhst - case rhst.Main == shaderir.Float: + case op2 == shaderir.MatrixMul && (lhst.Main == shaderir.Float || lhst.IsVector()) && rhst.IsMatrix(): t = lhst - case op2 == shaderir.MatrixMul && lhst.IsVector() && rhst.IsMatrix(): - t = lhst - case op2 == shaderir.MatrixMul && lhst.IsMatrix() && rhst.IsVector(): + case op2 == shaderir.MatrixMul && lhst.IsMatrix() && (rhst.Main == shaderir.Float || rhst.IsVector()): t = rhst + case (lhst.Main == shaderir.Float || lhst.Main == shaderir.Int) && rhst.IsVector(): + t = rhst + case lhst.IsVector() && (rhst.Main == shaderir.Float || rhst.Main == shaderir.Int): + t = lhst default: panic(fmt.Sprintf("shaderir: invalid expression: %s %s %s", lhst.String(), e.Op, rhst.String())) } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 029338bf3..11e76ce4a 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -3291,3 +3291,53 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +// Issue #2706 +func TestSyntaxScalarAndVector(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := vec2(1) + 1; var b vec2 = a; _ = b", err: false}, + {stmt: "a := 1 + vec2(1); var b vec2 = a; _ = b", err: false}, + {stmt: "a := vec2(1); b := 1; var c vec2 = a + b; _ = c", err: true}, + {stmt: "a := vec2(1); b := 1; var c vec2 = b + a; _ = c", err: true}, + {stmt: "a := vec2(1) + 1.0; var b vec2 = a; _ = b", err: false}, + {stmt: "a := 1.0 + vec2(1); var b vec2 = a; _ = b", err: false}, + {stmt: "a := vec2(1); b := 1.0; var c vec2 = a + b; _ = c", err: false}, + {stmt: "a := vec2(1); b := 1.0; var c vec2 = b + a; _ = c", err: false}, + {stmt: "a := vec2(1) + 1.1; var b vec2 = a; _ = b", err: false}, + {stmt: "a := 1.1 + vec2(1); var b vec2 = a; _ = b", err: false}, + {stmt: "a := vec2(1); b := 1.1; var c vec2 = a + b; _ = c", err: false}, + {stmt: "a := vec2(1); b := 1.1; var c vec2 = b + a; _ = c", err: false}, + + {stmt: "a := ivec2(1) + 1; var b ivec2 = a; _ = b", err: false}, + {stmt: "a := 1 + ivec2(1); var b ivec2 = a; _ = b", err: false}, + {stmt: "a := ivec2(1); b := 1; var c ivec2 = a + b; _ = c", err: false}, + {stmt: "a := ivec2(1); b := 1; var c ivec2 = b + a; _ = c", err: false}, + {stmt: "a := ivec2(1) + 1.0; var b ivec2 = a; _ = b", err: false}, + {stmt: "a := 1.0 + ivec2(1); var b ivec2 = a; _ = b", err: false}, + {stmt: "a := ivec2(1); b := 1.0; var c ivec2 = a + b; _ = c", err: true}, + {stmt: "a := ivec2(1); b := 1.0; var c ivec2 = b + a; _ = c", err: true}, + {stmt: "a := ivec2(1) + 1.1; var b ivec2 = a; _ = b", err: true}, + {stmt: "a := 1.1 + ivec2(1); var b ivec2 = a; _ = b", err: true}, + {stmt: "a := ivec2(1); b := 1.1; var c ivec2 = a + b; _ = c", err: true}, + {stmt: "a := ivec2(1); b := 1.1; var c ivec2 = b + a; _ = c", err: true}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +}