From 82a9aac6895deef5704e296e176d55ac07fd8bda Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Thu, 18 Aug 2022 23:44:58 +0900 Subject: [PATCH] internal/shader: add type checks for the builtin function cross Updates #2184 --- internal/shader/expr.go | 21 ++++++++++----- internal/shader/syntax_test.go | 49 ++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index c4422e65f..7611def3e 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -436,9 +436,6 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar case shaderir.Smoothstep: // TODO: Check arg types. t = argts[2] - case shaderir.Cross: - // TODO: Check arg types. - t = shaderir.Type{Main: shaderir.Vec3} case shaderir.Texture2DF: // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Vec4} @@ -455,7 +452,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar Type: shaderir.Discard, }) return nil, nil, stmts, true - case shaderir.Atan2, shaderir.Mod, shaderir.Min, shaderir.Max, shaderir.Distance, shaderir.Dot, shaderir.Reflect: + + case shaderir.Atan2, shaderir.Mod, shaderir.Min, shaderir.Max, shaderir.Distance, shaderir.Dot, shaderir.Cross, shaderir.Reflect: + // 2 arguments if len(args) != 2 { cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 2 but %d", callee.BuiltinFunc, len(args))) return nil, nil, nil, false @@ -473,12 +472,20 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } } - if callee.BuiltinFunc == shaderir.Mod || callee.BuiltinFunc == shaderir.Min || callee.BuiltinFunc == shaderir.Max { + switch callee.BuiltinFunc { + case shaderir.Mod, shaderir.Min, shaderir.Max: if !argts[0].Equal(&argts[1]) && argts[1].Main != shaderir.Float { cs.addError(e.Pos(), fmt.Sprintf("the second argument for %s must equal to the first argument %s or float but %s", callee.BuiltinFunc, argts[0].String(), argts[1].String())) return nil, nil, nil, false } - } else { + case shaderir.Cross: + for i := range argts { + if argts[i].Main != shaderir.Vec3 { + cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as vec3 value in argument to %s", argts[i].String(), callee.BuiltinFunc)) + return nil, nil, nil, false + } + } + default: if !argts[0].Equal(&argts[1]) { cs.addError(e.Pos(), fmt.Sprintf("%s and %s don't match in argument to %s", argts[0].String(), argts[1].String(), callee.BuiltinFunc)) return nil, nil, nil, false @@ -489,7 +496,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } else { t = argts[0] } + default: + // 1 argument if len(args) != 1 { cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 1 but %d", callee.BuiltinFunc, len(args))) return nil, nil, nil, false diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index b9f98781c..204d422bc 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1792,3 +1792,52 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +// Issue #2184 +func TestSyntaxBuiltinFuncCrossType(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := cross(); _ = a", err: true}, + {stmt: "a := cross(1); _ = a", err: true}, + {stmt: "a := cross(false, false); _ = a", err: true}, + {stmt: "a := cross(1, 1); _ = a", err: true}, + {stmt: "a := cross(1.0, 1); _ = a", err: true}, + {stmt: "a := cross(1, 1.0); _ = a", err: true}, + {stmt: "a := cross(int(1), int(1)); _ = a", err: true}, + {stmt: "a := cross(1, vec2(1)); _ = a", err: true}, + {stmt: "a := cross(1, vec3(1)); _ = a", err: true}, + {stmt: "a := cross(1, vec4(1)); _ = a", err: true}, + {stmt: "a := cross(vec2(1), 1); _ = a", err: true}, + {stmt: "a := cross(vec2(1), vec2(1)); _ = a", err: true}, + {stmt: "a := cross(vec2(1), vec3(1)); _ = a", err: true}, + {stmt: "a := cross(vec2(1), vec4(1)); _ = a", err: true}, + {stmt: "a := cross(vec3(1), 1); _ = a", err: true}, + {stmt: "a := cross(vec3(1), vec2(1)); _ = a", err: true}, + {stmt: "a := cross(vec3(1), vec3(1)); _ = a", err: false}, // Only two vec3s are allowed + {stmt: "a := cross(vec3(1), vec4(1)); _ = a", err: true}, + {stmt: "a := cross(vec4(1), 1); _ = a", err: true}, + {stmt: "a := cross(vec4(1), vec2(1)); _ = a", err: true}, + {stmt: "a := cross(vec4(1), vec3(1)); _ = a", err: true}, + {stmt: "a := cross(vec4(1), vec4(1)); _ = a", err: true}, + {stmt: "a := cross(mat2(1), mat2(1)); _ = a", err: true}, + {stmt: "a := cross(1, 1, 1); _ = a", err: true}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +}