From f89277fd85ee39a0baf599f7a9c62791d1611d7d Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Wed, 6 Jul 2022 01:25:40 +0900 Subject: [PATCH] internal/shader: add type checks to vec2/vec3/vec4 Updates #2184 --- internal/shader/expr.go | 24 +++++++ internal/shader/syntax_test.go | 80 ++++++++++++++++++++++ internal/shader/type.go | 121 +++++++++++++++++++++++++++++++++ 3 files changed, 225 insertions(+) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index d79ea28b6..971760ba5 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -379,34 +379,58 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable var t shaderir.Type switch callee.BuiltinFunc { case shaderir.BoolF: + // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Bool} case shaderir.IntF: + // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Int} case shaderir.FloatF: + // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Float} case shaderir.Vec2F: + if err := checkArgsForVec2BuiltinFunc(args, argts); err != nil { + cs.addError(e.Pos(), err.Error()) + return nil, nil, nil, false + } t = shaderir.Type{Main: shaderir.Vec2} case shaderir.Vec3F: + if err := checkArgsForVec3BuiltinFunc(args, argts); err != nil { + cs.addError(e.Pos(), err.Error()) + return nil, nil, nil, false + } t = shaderir.Type{Main: shaderir.Vec3} case shaderir.Vec4F: + if err := checkArgsForVec4BuiltinFunc(args, argts); err != nil { + cs.addError(e.Pos(), err.Error()) + return nil, nil, nil, false + } t = shaderir.Type{Main: shaderir.Vec4} case shaderir.Mat2F: + // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Mat2} case shaderir.Mat3F: + // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Mat3} case shaderir.Mat4F: + // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Mat4} case shaderir.Step: + // TODO: Check arg types. t = argts[1] case shaderir.Smoothstep: + // TODO: Check arg types. t = argts[2] case shaderir.Length, shaderir.Distance, shaderir.Dot: + // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Float} case shaderir.Cross: + // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Vec3} case shaderir.Texture2DF: + // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Vec4} default: + // TODO: Check arg types? // If the argument is a non-typed constant value, treat is as a float value (#1874). if args[0].Type == shaderir.NumberExpr && args[0].ConstType == shaderir.ConstTypeNone { args[0].ConstType = shaderir.ConstTypeFloat diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index be4265ed0..bc396977b 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1351,3 +1351,83 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { t.Errorf("error must be non-nil but was nil") } } + +// Issue #2184 +func TestSyntaxBuiltinFuncType(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := vec2(1); _ = a", err: false}, + {stmt: "a := vec2(1.0); _ = a", err: false}, + {stmt: "i := 1; a := vec2(i); _ = a", err: false}, + {stmt: "i := 1.0; a := vec2(i); _ = a", err: false}, + {stmt: "a := vec2(vec2(1)); _ = a", err: false}, + {stmt: "a := vec2(vec3(1)); _ = a", err: true}, + + {stmt: "a := vec2(1, 1); _ = a", err: false}, + {stmt: "a := vec2(1.0, 1.0); _ = a", err: false}, + {stmt: "i := 1; a := vec2(i, i); _ = a", err: false}, + {stmt: "i := 1.0; a := vec2(i, i); _ = a", err: false}, + {stmt: "a := vec2(vec2(1), 1); _ = a", err: true}, + {stmt: "a := vec2(1, vec2(1)); _ = a", err: true}, + {stmt: "a := vec2(vec2(1), vec2(1)); _ = a", err: true}, + {stmt: "a := vec2(1, 1, 1); _ = a", err: true}, + + {stmt: "a := vec3(1); _ = a", err: false}, + {stmt: "a := vec3(1.0); _ = a", err: false}, + {stmt: "i := 1; a := vec3(i); _ = a", err: false}, + {stmt: "i := 1.0; a := vec3(i); _ = a", err: false}, + {stmt: "a := vec3(vec3(1)); _ = a", err: false}, + {stmt: "a := vec3(vec2(1)); _ = a", err: true}, + {stmt: "a := vec3(vec4(1)); _ = a", err: true}, + + {stmt: "a := vec3(1, 1, 1); _ = a", err: false}, + {stmt: "a := vec3(1.0, 1.0, 1.0); _ = a", err: false}, + {stmt: "i := 1; a := vec3(i, i, i); _ = a", err: false}, + {stmt: "i := 1.0; a := vec3(i, i, i); _ = a", err: false}, + {stmt: "a := vec3(vec2(1), 1); _ = a", err: false}, + {stmt: "a := vec3(1, vec2(1)); _ = a", err: false}, + {stmt: "a := vec3(vec3(1), 1); _ = a", err: true}, + {stmt: "a := vec3(1, vec3(1)); _ = a", err: true}, + {stmt: "a := vec3(vec3(1), vec3(1), vec3(1)); _ = a", err: true}, + {stmt: "a := vec3(1, 1, 1, 1); _ = a", err: true}, + + {stmt: "a := vec4(1); _ = a", err: false}, + {stmt: "a := vec4(1.0); _ = a", err: false}, + {stmt: "i := 1; a := vec4(i); _ = a", err: false}, + {stmt: "i := 1.0; a := vec4(i); _ = a", err: false}, + {stmt: "a := vec4(vec4(1)); _ = a", err: false}, + {stmt: "a := vec4(vec2(1)); _ = a", err: true}, + {stmt: "a := vec4(vec3(1)); _ = a", err: true}, + + {stmt: "a := vec4(1, 1, 1, 1); _ = a", err: false}, + {stmt: "a := vec4(1.0, 1.0, 1.0, 1.0); _ = a", err: false}, + {stmt: "i := 1; a := vec4(i, i, i, i); _ = a", err: false}, + {stmt: "i := 1.0; a := vec4(i, i, i, i); _ = a", err: false}, + {stmt: "a := vec4(vec2(1), 1, 1); _ = a", err: false}, + {stmt: "a := vec4(1, vec2(1), 1); _ = a", err: false}, + {stmt: "a := vec4(1, 1, vec2(1)); _ = a", err: false}, + {stmt: "a := vec4(vec2(1), vec2(1)); _ = a", err: false}, + {stmt: "a := vec4(vec3(1), 1); _ = a", err: false}, + {stmt: "a := vec4(1, vec3(1)); _ = a", err: false}, + {stmt: "a := vec4(vec4(1), 1); _ = a", err: true}, + {stmt: "a := vec4(1, vec4(1)); _ = a", err: true}, + {stmt: "a := vec4(vec4(1), vec4(1), vec4(1), vec4(1)); _ = a", err: true}, + {stmt: "a := vec4(1, 1, 1, 1, 1); _ = a", err: true}, + } + + for _, c := range cases { + _, err := compileToIR([]byte(fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, c.stmt))) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.stmt, err) + } + } +} diff --git a/internal/shader/type.go b/internal/shader/type.go index 4e99d83d3..9b1ca0ee7 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -98,3 +98,124 @@ func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, b return shaderir.Type{}, false } } + +func canBeFloatImplicitly(expr shaderir.Expr, t shaderir.Type) bool { + // TODO: For integers, should only constants be allowed? + if t.Main == shaderir.Int { + return true + } + if t.Main == shaderir.Float { + return true + } + if expr.Const != nil { + if expr.Const.Kind() == gconstant.Int { + return true + } + if expr.Const.Kind() == gconstant.Float { + return true + } + } + return false +} + +func checkArgsForVec2BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { + if len(args) != len(argts) { + return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts)) + } + + switch len(args) { + case 1: + if canBeFloatImplicitly(args[0], argts[0]) { + return nil + } + if argts[0].Main == shaderir.Vec2 { + return nil + } + return fmt.Errorf("invalid arguments for vec2: (%s)", argts[0].String()) + case 2: + if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) { + return nil + } + return fmt.Errorf("invalid arguments for vec2: (%s, %s)", argts[0].String(), argts[1].String()) + default: + return fmt.Errorf("too many arguments for vec2") + } +} + +func checkArgsForVec3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { + if len(args) != len(argts) { + return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts)) + } + + switch len(args) { + case 1: + if canBeFloatImplicitly(args[0], argts[0]) { + return nil + } + if argts[0].Main == shaderir.Vec3 { + return nil + } + return fmt.Errorf("invalid arguments for vec3: (%s)", argts[0].String()) + case 2: + if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec2 { + return nil + } + if argts[0].Main == shaderir.Vec2 && canBeFloatImplicitly(args[1], argts[1]) { + return nil + } + return fmt.Errorf("invalid arguments for vec3: (%s, %s)", argts[0].String(), argts[1].String()) + case 3: + if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) { + return nil + } + return fmt.Errorf("invalid arguments for vec3: (%s, %s, %s)", argts[0].String(), argts[1].String(), argts[2].String()) + default: + return fmt.Errorf("too many arguments for vec3") + } +} + +func checkArgsForVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { + if len(args) != len(argts) { + return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts)) + } + + switch len(args) { + case 1: + if canBeFloatImplicitly(args[0], argts[0]) { + return nil + } + if argts[0].Main == shaderir.Vec4 { + return nil + } + return fmt.Errorf("invalid arguments for vec4: (%s)", argts[0].String()) + case 2: + if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec3 { + return nil + } + if argts[0].Main == shaderir.Vec2 && argts[1].Main == shaderir.Vec2 { + return nil + } + if argts[0].Main == shaderir.Vec3 && canBeFloatImplicitly(args[1], argts[1]) { + return nil + } + return fmt.Errorf("invalid arguments for vec4: (%s, %s)", argts[0].String(), argts[1].String()) + case 3: + if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && argts[2].Main == shaderir.Vec2 { + return nil + } + if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec2 && canBeFloatImplicitly(args[2], argts[2]) { + return nil + } + if argts[0].Main == shaderir.Vec2 && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) { + return nil + } + return fmt.Errorf("invalid arguments for vec4: (%s, %s, %s)", argts[0].String(), argts[1].String(), argts[2].String()) + case 4: + if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) && canBeFloatImplicitly(args[3], argts[3]) { + return nil + } + return fmt.Errorf("invalid arguments for vec4: (%s, %s, %s, %s)", argts[0].String(), argts[1].String(), argts[2].String(), argts[3].String()) + default: + return fmt.Errorf("too many arguments for vec4") + } +}