diff --git a/internal/shader/delayed.go b/internal/shader/delayed.go new file mode 100644 index 000000000..935d5797e --- /dev/null +++ b/internal/shader/delayed.go @@ -0,0 +1,124 @@ +// Copyright 2024 The Ebiten Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shader + +import ( + "go/ast" + gconstant "go/constant" + "go/token" + + "github.com/hajimehoshi/ebiten/v2/internal/shaderir" +) + +type resolveTypeStatus int + +const ( + resolveUnsure resolveTypeStatus = iota + resolveOk + resolveFail +) + +type delayedValidator interface { + Validate(expr ast.Expr) resolveTypeStatus + Pos() token.Pos + Error() string +} + +func (cs *compileState) tryValidateDelayed(cexpr ast.Expr) (ok bool) { + valExprs := make([]ast.Expr, 0, len(cs.delayedTypeCheks)) + for k := range cs.delayedTypeCheks { + valExprs = append(valExprs, k) + } + for _, expr := range valExprs { + if cexpr == expr { + continue + } + // Check if delayed validation can be done by adding current context + cres := cs.delayedTypeCheks[expr].Validate(cexpr) + switch cres { + case resolveFail: + cs.addError(cs.delayedTypeCheks[expr].Pos(), cs.delayedTypeCheks[expr].Error()) + return false + case resolveOk: + delete(cs.delayedTypeCheks, expr) + } + } + + return true +} + +type delayedShiftValidator struct { + value gconstant.Value + pos token.Pos + last ast.Expr +} + +func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool { + return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F +} + +func (d *delayedShiftValidator) Validate(cexpr ast.Expr) (rs resolveTypeStatus) { + switch cexpr.(type) { + case *ast.Ident: + ident := cexpr.(*ast.Ident) + // For BuiltinFunc, only int* are allowed + if fname, ok := shaderir.ParseBuiltinFunc(ident.Name); ok { + if isArgDefaultTypeInt(fname) { + return resolveOk + } + return resolveFail + } + // Untyped constant must represent int + if ident.Name == "_" { + if d.value != nil && d.value.Kind() == gconstant.Int { + return resolveOk + } + return resolveFail + } + if ident.Obj != nil { + if t, ok := ident.Obj.Type.(*ast.Ident); ok { + return d.Validate(t) + } + if decl, ok := ident.Obj.Decl.(*ast.ValueSpec); ok { + return d.Validate(decl.Type) + } + if _, ok := ident.Obj.Decl.(*ast.AssignStmt); ok { + if d.value != nil && d.value.Kind() == gconstant.Int { + return resolveOk + } + return resolveFail + } + } + case *ast.BinaryExpr: + bs := cexpr.(*ast.BinaryExpr) + left, right := bs.X, bs.Y + if bs.Y == d.last { + left, right = right, left + } + + rightCheck := d.Validate(right) + d.last = cexpr + return rightCheck + } + return resolveUnsure +} + +func (d delayedShiftValidator) Pos() token.Pos { + return d.pos +} + +func (d delayedShiftValidator) Error() string { + return "left shift operand should be int" +} diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 1d06be676..ca11d4e66 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -36,7 +36,12 @@ func canTruncateToFloat(v gconstant.Value) bool { var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) -func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) { +func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) { + defer func() { + // Due to use of early return in the parsing, delayed checks are conducted in defer + ok = ok && cs.tryValidateDelayed(expr) + }() + switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { @@ -103,6 +108,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar // Resolve untyped constants. var l gconstant.Value var r gconstant.Value + origLvalue := lhs[0].Const if op2 == shaderir.LeftShift || op2 == shaderir.RightShift { l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst) } else { @@ -126,6 +132,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar if lhst.Main == shaderir.None && lhs[0].Const != nil { lhst = shaderir.Type{Main: shaderir.Int} // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone. + if rhs[0].Const == nil { + cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue, pos: e.Pos(), last: expr}) + } } } } else { diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 90ed2d611..c89d43cde 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -61,6 +61,8 @@ type compileState struct { varyingParsed bool + delayedTypeCheks map[ast.Expr]delayedValidator + errs []string } @@ -82,6 +84,13 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) { return 0, false } +func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedValidator) { + if cs.delayedTypeCheks == nil { + cs.delayedTypeCheks = make(map[ast.Expr]delayedValidator, 1) + } + cs.delayedTypeCheks[at] = check +} + type typ struct { name string ir shaderir.Type @@ -350,6 +359,12 @@ func (cs *compileState) parse(f *ast.File) { for _, f := range cs.funcs { cs.ir.Funcs = append(cs.ir.Funcs, f.ir) } + + // if len(cs.delayedTypeCheks) != 0 { + // for _, check := range cs.delayedTypeCheks { + // cs.addError(check.Pos(), check.Error()) + // } + // } } func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) { diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 1c2920cc4..2d78dd5a9 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,23 +1320,27 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - // {stmt: "s := 1; var a float = float(1 << s); _ = a", err: true}, - // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - // {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, - // {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, - // {stmt: "s := 1; a := 1 << s; _ = a", err: false}, - // {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, - // {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, - // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - // {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, - // {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, - // {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, - // {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, + {stmt: "s := 1; t := 2; _ = t + 1.0 << s", err: false}, + {stmt: "s := 1; _ = 1 << s", err: false}, + {stmt: "s := 1; _ = 1.0 << s", err: true}, + {stmt: "var a = 1; b := a << 2.0; _ = b", err: false}, + {stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, + {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, + {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, + {stmt: "s := 1; a := 1 << s; _ = a", err: false}, + {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, + {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, + {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, + {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, + {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, {stmt: "var a float = 1.0 << 2.0; _ = a", err: false}, {stmt: "var a int = 1.0 << 2; _ = a", err: false}, {stmt: "var a float = 1.0 << 2; _ = a", err: false}, - {stmt: "var a = 1.0 << 2; _ = a", err: false}, {stmt: "a := 1 << 2.0; _ = a", err: false}, {stmt: "a := 1.0 << 2; _ = a", err: false}, {stmt: "a := 1.0 << 2.0; _ = a", err: false}, @@ -1362,36 +1366,6 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := ivec2(1) << vec2(2); _ = a", err: true}, {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, - - {stmt: "var a float = 1.0 >> 2.0; _ = a", err: false}, - {stmt: "var a int = 1.0 >> 2; _ = a", err: false}, - {stmt: "var a float = 1.0 >> 2; _ = a", err: false}, - {stmt: "var a = 1.0 >> 2; _ = a", err: false}, - {stmt: "a := 1 >> 2.0; _ = a", err: false}, - {stmt: "a := 1.0 >> 2; _ = a", err: false}, - {stmt: "a := 1.0 >> 2.0; _ = a", err: false}, - {stmt: "a := 1 >> 2; _ = a", err: false}, - {stmt: "a := float(1.0) >> 2; _ = a", err: true}, - {stmt: "a := 1 >> float(2.0); _ = a", err: false}, - {stmt: "a := 1.0 >> float(2.0); _ = a", err: false}, - {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, - {stmt: "a := 1 >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, - {stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false}, - {stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true}, - {stmt: "a := 1 >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> 2; _ = a", err: true}, - {stmt: "a := float(1.0) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> float(2.0); _ = a", err: true}, - {stmt: "a := vec2(1) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> vec3(2); _ = a", err: true}, - {stmt: "a := vec3(1) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true}, } for _, c := range cases { @@ -1407,6 +1381,80 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { t.Errorf("%s must not return nil but returned %v", c.stmt, err) } } + + casesFunc := []struct { + prog string + err bool + }{ + { + prog: `package main + func Foo(x int) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + s := 1 + Foo(1 << s) + return dstPos + }`, + err: false, + }, + { + prog: `package main + func Foo(x int) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + s := 1 + Foo(1.0 << s) + return dstPos + }`, + err: false, + }, + { + prog: `package main + func Foo(x float) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + s := 1 + Foo(1 << s) + return dstPos + }`, + err: true, + }, + { + prog: `package main + func Foo(x float) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + s := 1 + Foo(1 << s) + return dstPos + }`, + err: true, + }, + { + prog: `package main + func Foo(x float) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + Foo(1.0 << 2.0) + return dstPos + }`, + err: false, + }, + } + + for _, c := range casesFunc { + _, err := compileToIR([]byte(c.prog)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.prog) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.prog, err) + } + } } func TestSyntaxOperatorShiftAssign(t *testing.T) {