diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 258895ca0..d2c36777d 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -572,9 +572,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar args[0].ConstType = shaderir.ConstTypeFloat argts[0] = shaderir.Type{Main: shaderir.Float} } - if argts[0].Main != shaderir.Float && argts[0].Main != shaderir.Vec2 && argts[0].Main != shaderir.Vec3 && argts[0].Main != shaderir.Vec4 { - cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as float, vec2, vec3, or vec4 value in argument to %s", argts[0].String(), callee.BuiltinFunc)) - return nil, nil, nil, false + switch callee.BuiltinFunc { + case shaderir.Transpose: + if argts[0].Main != shaderir.Mat2 && argts[0].Main != shaderir.Mat3 && argts[0].Main != shaderir.Mat4 { + cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as mat2, mat3, or mat4 value in argument to %s", argts[0].String(), callee.BuiltinFunc)) + return nil, nil, nil, false + } + default: + if argts[0].Main != shaderir.Float && argts[0].Main != shaderir.Vec2 && argts[0].Main != shaderir.Vec3 && argts[0].Main != shaderir.Vec4 { + cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as float, vec2, vec3, or vec4 value in argument to %s", argts[0].String(), callee.BuiltinFunc)) + return nil, nil, nil, false + } } if callee.BuiltinFunc == shaderir.Length { t = shaderir.Type{Main: shaderir.Float} diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 8f7c52728..2d892d38e 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -2152,3 +2152,40 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +// Issue #2184 +func TestSyntaxBuiltinFuncTransposeType(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := transpose(); _ = a", err: true}, + {stmt: "a := transpose(false); _ = a", err: true}, + {stmt: "a := transpose(1); _ = a", err: true}, + {stmt: "a := transpose(1.0); _ = a", err: true}, + {stmt: "a := transpose(int(1)); _ = a", err: true}, + {stmt: "a := transpose(vec2(1)); _ = a", err: true}, + {stmt: "a := transpose(vec3(1)); _ = a", err: true}, + {stmt: "a := transpose(vec4(1)); _ = a", err: true}, + {stmt: "a := transpose(mat2(1)); _ = a", err: false}, + {stmt: "a := transpose(mat3(1)); _ = a", err: false}, + {stmt: "a := transpose(mat4(1)); _ = a", err: false}, + {stmt: "a := transpose(1, 1); _ = a", err: true}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +}