From 63eee0600eba32de3a6e5f24593875edc66d313a Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Fri, 19 Aug 2022 01:57:13 +0900 Subject: [PATCH] internal/shader: add type checks for the builtin function transpose Updates #2184 --- internal/shader/expr.go | 14 ++++++++++--- internal/shader/syntax_test.go | 37 ++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 258895ca0..d2c36777d 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -572,9 +572,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar args[0].ConstType = shaderir.ConstTypeFloat argts[0] = shaderir.Type{Main: shaderir.Float} } - if argts[0].Main != shaderir.Float && argts[0].Main != shaderir.Vec2 && argts[0].Main != shaderir.Vec3 && argts[0].Main != shaderir.Vec4 { - cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as float, vec2, vec3, or vec4 value in argument to %s", argts[0].String(), callee.BuiltinFunc)) - return nil, nil, nil, false + switch callee.BuiltinFunc { + case shaderir.Transpose: + if argts[0].Main != shaderir.Mat2 && argts[0].Main != shaderir.Mat3 && argts[0].Main != shaderir.Mat4 { + cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as mat2, mat3, or mat4 value in argument to %s", argts[0].String(), callee.BuiltinFunc)) + return nil, nil, nil, false + } + default: + if argts[0].Main != shaderir.Float && argts[0].Main != shaderir.Vec2 && argts[0].Main != shaderir.Vec3 && argts[0].Main != shaderir.Vec4 { + cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as float, vec2, vec3, or vec4 value in argument to %s", argts[0].String(), callee.BuiltinFunc)) + return nil, nil, nil, false + } } if callee.BuiltinFunc == shaderir.Length { t = shaderir.Type{Main: shaderir.Float} diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 8f7c52728..2d892d38e 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -2152,3 +2152,40 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +// Issue #2184 +func TestSyntaxBuiltinFuncTransposeType(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := transpose(); _ = a", err: true}, + {stmt: "a := transpose(false); _ = a", err: true}, + {stmt: "a := transpose(1); _ = a", err: true}, + {stmt: "a := transpose(1.0); _ = a", err: true}, + {stmt: "a := transpose(int(1)); _ = a", err: true}, + {stmt: "a := transpose(vec2(1)); _ = a", err: true}, + {stmt: "a := transpose(vec3(1)); _ = a", err: true}, + {stmt: "a := transpose(vec4(1)); _ = a", err: true}, + {stmt: "a := transpose(mat2(1)); _ = a", err: false}, + {stmt: "a := transpose(mat3(1)); _ = a", err: false}, + {stmt: "a := transpose(mat4(1)); _ = a", err: false}, + {stmt: "a := transpose(1, 1); _ = a", err: true}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +}