From 7af6c9995450d5ab2e7daa356a067d5f81fa545f Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Thu, 18 Aug 2022 17:30:12 +0900 Subject: [PATCH] internal/shader: add type checks for some builtin functions Updates #2184 --- internal/shader/expr.go | 85 ++++++++++------ internal/shader/syntax_test.go | 180 ++++++++++++++++++++++++++++++++- 2 files changed, 232 insertions(+), 33 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index c9cadcaaa..c4422e65f 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -361,7 +361,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } case shaderir.FloatF: if len(args) == 1 && args[0].Type == shaderir.NumberExpr { - if args[0].Const.Kind() == gconstant.Int || args[0].Const.Kind() == gconstant.Float { + if gconstant.ToFloat(args[0].Const).Kind() != gconstant.Unknown { return []shaderir.Expr{ { Type: shaderir.NumberExpr, @@ -430,39 +430,12 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar return nil, nil, nil, false } t = shaderir.Type{Main: shaderir.Mat4} - case shaderir.Atan: - if len(args) != 1 { - cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 1 but %d", callee.BuiltinFunc, len(args))) - return nil, nil, nil, false - } - // TODO: Check arg types. - // If the argument is a non-typed constant value, treat is as a float value (#1874). - if args[0].Type == shaderir.NumberExpr && args[0].ConstType == shaderir.ConstTypeNone { - args[0].ConstType = shaderir.ConstTypeFloat - argts[0] = shaderir.Type{Main: shaderir.Float} - } - t = argts[0] - case shaderir.Atan2: - if len(args) != 2 { - cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 2 but %d", callee.BuiltinFunc, len(args))) - return nil, nil, nil, false - } - // TODO: Check arg types. - // If the argument is a non-typed constant value, treat is as a float value (#1874). - if args[0].Type == shaderir.NumberExpr && args[0].ConstType == shaderir.ConstTypeNone { - args[0].ConstType = shaderir.ConstTypeFloat - argts[0] = shaderir.Type{Main: shaderir.Float} - } - t = argts[0] case shaderir.Step: // TODO: Check arg types. t = argts[1] case shaderir.Smoothstep: // TODO: Check arg types. t = argts[2] - case shaderir.Length, shaderir.Distance, shaderir.Dot: - // TODO: Check arg types. - t = shaderir.Type{Main: shaderir.Float} case shaderir.Cross: // TODO: Check arg types. t = shaderir.Type{Main: shaderir.Vec3} @@ -472,22 +445,70 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar case shaderir.DiscardF: if len(args) != 0 { cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 0 but %d", callee.BuiltinFunc, len(args))) + return nil, nil, nil, false } if fname != cs.fragmentEntry { cs.addError(e.Pos(), fmt.Sprintf("discard is available only in %s", cs.fragmentEntry)) + return nil, nil, nil, false } stmts = append(stmts, shaderir.Stmt{ Type: shaderir.Discard, }) return nil, nil, stmts, true + case shaderir.Atan2, shaderir.Mod, shaderir.Min, shaderir.Max, shaderir.Distance, shaderir.Dot, shaderir.Reflect: + if len(args) != 2 { + cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 2 but %d", callee.BuiltinFunc, len(args))) + return nil, nil, nil, false + } + for i := range args { + // If the argument is a non-typed constant value, treat this as a float value (#1874). + if args[i].Type == shaderir.NumberExpr && args[i].ConstType == shaderir.ConstTypeNone && gconstant.ToFloat(args[i].Const).Kind() != gconstant.Unknown { + args[i].Const = gconstant.ToFloat(args[i].Const) + args[i].ConstType = shaderir.ConstTypeFloat + argts[i] = shaderir.Type{Main: shaderir.Float} + } + if argts[i].Main != shaderir.Float && argts[i].Main != shaderir.Vec2 && argts[i].Main != shaderir.Vec3 && argts[i].Main != shaderir.Vec4 { + cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as float, vec2, vec3, or vec4 value in argument to %s", argts[i].String(), callee.BuiltinFunc)) + return nil, nil, nil, false + } + } + + if callee.BuiltinFunc == shaderir.Mod || callee.BuiltinFunc == shaderir.Min || callee.BuiltinFunc == shaderir.Max { + if !argts[0].Equal(&argts[1]) && argts[1].Main != shaderir.Float { + cs.addError(e.Pos(), fmt.Sprintf("the second argument for %s must equal to the first argument %s or float but %s", callee.BuiltinFunc, argts[0].String(), argts[1].String())) + return nil, nil, nil, false + } + } else { + if !argts[0].Equal(&argts[1]) { + cs.addError(e.Pos(), fmt.Sprintf("%s and %s don't match in argument to %s", argts[0].String(), argts[1].String(), callee.BuiltinFunc)) + return nil, nil, nil, false + } + } + if callee.BuiltinFunc == shaderir.Distance || callee.BuiltinFunc == shaderir.Dot { + t = shaderir.Type{Main: shaderir.Float} + } else { + t = argts[0] + } default: - // TODO: Check arg types. - // If the argument is a non-typed constant value, treat is as a float value (#1874). - if args[0].Type == shaderir.NumberExpr && args[0].ConstType == shaderir.ConstTypeNone { + if len(args) != 1 { + cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 1 but %d", callee.BuiltinFunc, len(args))) + return nil, nil, nil, false + } + // If the argument is a non-typed constant value, treat this as a float value (#1874). + if args[0].Type == shaderir.NumberExpr && args[0].ConstType == shaderir.ConstTypeNone && gconstant.ToFloat(args[0].Const).Kind() != gconstant.Unknown { + args[0].Const = gconstant.ToFloat(args[0].Const) args[0].ConstType = shaderir.ConstTypeFloat argts[0] = shaderir.Type{Main: shaderir.Float} } - t = argts[0] + if argts[0].Main != shaderir.Float && argts[0].Main != shaderir.Vec2 && argts[0].Main != shaderir.Vec3 && argts[0].Main != shaderir.Vec4 { + cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as float, vec2, vec3, or vec4 value in argument to %s", argts[0].String(), callee.BuiltinFunc)) + return nil, nil, nil, false + } + if callee.BuiltinFunc == shaderir.Length { + t = shaderir.Type{Main: shaderir.Float} + } else { + t = argts[0] + } } return []shaderir.Expr{ { diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 56e6ad13a..b9f98781c 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -18,6 +18,7 @@ import ( "fmt" "go/parser" "go/token" + "strings" "testing" "github.com/hajimehoshi/ebiten/v2/internal/shader" @@ -1392,7 +1393,7 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } // Issue #2184 -func TestSyntaxBuiltinFuncType(t *testing.T) { +func TestSyntaxConstructorFuncType(t *testing.T) { cases := []struct { stmt string err bool @@ -1614,3 +1615,180 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { t.Errorf("error must be non-nil but was nil") } } + +// Issue #2184 +func TestSyntaxBuiltinFuncSingleArgType(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := {{.Func}}(); _ = a", err: true}, + {stmt: "a := {{.Func}}(false); _ = a", err: true}, + {stmt: "a := {{.Func}}(1); _ = a", err: false}, + {stmt: "a := {{.Func}}(1.0); _ = a", err: false}, + {stmt: "a := {{.Func}}(int(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec2(1)); _ = a", err: false}, + {stmt: "a := {{.Func}}(vec3(1)); _ = a", err: false}, + {stmt: "a := {{.Func}}(vec4(1)); _ = a", err: false}, + {stmt: "a := {{.Func}}(mat2(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(mat3(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(mat4(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, 1); _ = a", err: true}, + } + + funcs := []string{ + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "pow", + "exp", + "log", + "exp2", + "log2", + "sqrt", + "inversesqrt", + "abs", + "sign", + "floor", + "ceil", + "fract", + "length", + "normalize", + "dfdx", + "dfdy", + "fwidth", + } + for _, c := range cases { + for _, f := range funcs { + stmt := strings.ReplaceAll(c.stmt, "{{.Func}}", f) + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } + } +} + +// Issue #2184 +func TestSyntaxBuiltinFuncDoubleArgsType(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := {{.Func}}(); _ = a", err: true}, + {stmt: "a := {{.Func}}(1); _ = a", err: true}, + {stmt: "a := {{.Func}}(false, false); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, 1); _ = a", err: false}, + {stmt: "a := {{.Func}}(1.0, 1); _ = a", err: false}, + {stmt: "a := {{.Func}}(1, 1.0); _ = a", err: false}, + {stmt: "a := {{.Func}}(int(1), int(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, vec2(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, vec3(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, vec4(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec2(1), 1); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec2(1), vec2(1)); _ = a", err: false}, + {stmt: "a := {{.Func}}(vec2(1), vec3(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec2(1), vec4(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec3(1), 1); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec3(1), vec2(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec3(1), vec3(1)); _ = a", err: false}, + {stmt: "a := {{.Func}}(vec3(1), vec4(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec4(1), 1); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec4(1), vec2(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec4(1), vec3(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec4(1), vec4(1)); _ = a", err: false}, + {stmt: "a := {{.Func}}(mat2(1), mat2(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, 1, 1); _ = a", err: true}, + } + + funcs := []string{ + "atan2", + "distance", + "dot", + "reflect", + } + for _, c := range cases { + for _, f := range funcs { + stmt := strings.ReplaceAll(c.stmt, "{{.Func}}", f) + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } + } +} + +// Issue #2184 +func TestSyntaxBuiltinFuncDoubleArgsType2(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := {{.Func}}(); _ = a", err: true}, + {stmt: "a := {{.Func}}(1); _ = a", err: true}, + {stmt: "a := {{.Func}}(false, false); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, 1); _ = a", err: false}, + {stmt: "a := {{.Func}}(1.0, 1); _ = a", err: false}, + {stmt: "a := {{.Func}}(1, 1.0); _ = a", err: false}, + {stmt: "a := {{.Func}}(int(1), int(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, vec2(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, vec3(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, vec4(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec2(1), 1); _ = a", err: false}, // The second argument can be a scalar. + {stmt: "a := {{.Func}}(vec2(1), vec2(1)); _ = a", err: false}, + {stmt: "a := {{.Func}}(vec2(1), vec3(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec2(1), vec4(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec3(1), 1); _ = a", err: false}, // The second argument can be a scalar. + {stmt: "a := {{.Func}}(vec3(1), vec2(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec3(1), vec3(1)); _ = a", err: false}, + {stmt: "a := {{.Func}}(vec3(1), vec4(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec4(1), 1); _ = a", err: false}, // The second argument can be a scalar. + {stmt: "a := {{.Func}}(vec4(1), vec2(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec4(1), vec3(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(vec4(1), vec4(1)); _ = a", err: false}, + {stmt: "a := {{.Func}}(mat2(1), mat2(1)); _ = a", err: true}, + {stmt: "a := {{.Func}}(1, 1, 1); _ = a", err: true}, + } + + funcs := []string{ + "mod", + "min", + "max", + } + for _, c := range cases { + for _, f := range funcs { + stmt := strings.ReplaceAll(c.stmt, "{{.Func}}", f) + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } + } +}