From f02e9fd4d060d1d050fe7254efe49906705a3792 Mon Sep 17 00:00:00 2001 From: aoyako Date: Sat, 2 Mar 2024 15:32:31 +0900 Subject: [PATCH] add shift type checks --- internal/shader/delayed.go | 217 ++++++++++++++++++++------------- internal/shader/expr.go | 12 +- internal/shader/shader.go | 6 +- internal/shader/stmt.go | 23 ++++ internal/shader/syntax_test.go | 112 ++++++----------- internal/shaderir/check.go | 9 +- internal/shaderir/program.go | 23 ++-- 7 files changed, 211 insertions(+), 191 deletions(-) diff --git a/internal/shader/delayed.go b/internal/shader/delayed.go index 935d5797e..abf675697 100644 --- a/internal/shader/delayed.go +++ b/internal/shader/delayed.go @@ -15,110 +15,151 @@ package shader import ( - "go/ast" + "fmt" gconstant "go/constant" - "go/token" "github.com/hajimehoshi/ebiten/v2/internal/shaderir" ) -type resolveTypeStatus int - -const ( - resolveUnsure resolveTypeStatus = iota - resolveOk - resolveFail -) - -type delayedValidator interface { - Validate(expr ast.Expr) resolveTypeStatus - Pos() token.Pos +type delayedTypeValidator interface { + Validate(t shaderir.Type) (shaderir.Type, bool) + IsValidated() (shaderir.Type, bool) Error() string } -func (cs *compileState) tryValidateDelayed(cexpr ast.Expr) (ok bool) { - valExprs := make([]ast.Expr, 0, len(cs.delayedTypeCheks)) - for k := range cs.delayedTypeCheks { - valExprs = append(valExprs, k) - } - for _, expr := range valExprs { - if cexpr == expr { - continue - } - // Check if delayed validation can be done by adding current context - cres := cs.delayedTypeCheks[expr].Validate(cexpr) - switch cres { - case resolveFail: - cs.addError(cs.delayedTypeCheks[expr].Pos(), cs.delayedTypeCheks[expr].Error()) - return false - case resolveOk: - delete(cs.delayedTypeCheks, expr) - } - } - - return true -} - -type delayedShiftValidator struct { - value gconstant.Value - pos token.Pos - last ast.Expr -} - func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool { return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F } -func (d *delayedShiftValidator) Validate(cexpr ast.Expr) (rs resolveTypeStatus) { - switch cexpr.(type) { - case *ast.Ident: - ident := cexpr.(*ast.Ident) - // For BuiltinFunc, only int* are allowed - if fname, ok := shaderir.ParseBuiltinFunc(ident.Name); ok { - if isArgDefaultTypeInt(fname) { - return resolveOk - } - return resolveFail - } - // Untyped constant must represent int - if ident.Name == "_" { - if d.value != nil && d.value.Kind() == gconstant.Int { - return resolveOk - } - return resolveFail - } - if ident.Obj != nil { - if t, ok := ident.Obj.Type.(*ast.Ident); ok { - return d.Validate(t) - } - if decl, ok := ident.Obj.Decl.(*ast.ValueSpec); ok { - return d.Validate(decl.Type) - } - if _, ok := ident.Obj.Decl.(*ast.AssignStmt); ok { - if d.value != nil && d.value.Kind() == gconstant.Int { - return resolveOk - } - return resolveFail - } - } - case *ast.BinaryExpr: - bs := cexpr.(*ast.BinaryExpr) - left, right := bs.X, bs.Y - if bs.Y == d.last { - left, right = right, left - } +func isIntType(t shaderir.Type) bool { + return t.Main == shaderir.Int || t.IsIntVector() +} - rightCheck := d.Validate(right) - d.last = cexpr - return rightCheck +func (cs *compileState) ValidateDefaultTypesForExpr(block *block, expr shaderir.Expr, t shaderir.Type) shaderir.Type { + if check, ok := cs.delayedTypeCheks[expr.Ast]; ok { + if resT, ok := check.IsValidated(); ok { + return resT + } + resT, ok := check.Validate(t) + if !ok { + return shaderir.Type{Main: shaderir.None} + } + return resT } - return resolveUnsure + + switch expr.Type { + case shaderir.LocalVariable: + return block.vars[expr.Index].typ + + case shaderir.Binary: + left := expr.Exprs[0] + right := expr.Exprs[1] + + leftType := cs.ValidateDefaultTypesForExpr(block, left, t) + rightType := cs.ValidateDefaultTypesForExpr(block, right, t) + + // Usure about top-level type, try to validate by neighbour type + // The same work is done twice. Can it be optimized? + if t.Main == shaderir.None { + cs.ValidateDefaultTypesForExpr(block, left, rightType) + cs.ValidateDefaultTypesForExpr(block, right, leftType) + } + case shaderir.Call: + fun := expr.Exprs[0] + if fun.Type == shaderir.BuiltinFuncExpr { + if isArgDefaultTypeInt(fun.BuiltinFunc) { + for _, e := range expr.Exprs[1:] { + cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Int}) + } + return shaderir.Type{Main: shaderir.Int} + } + + for _, e := range expr.Exprs[1:] { + cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Float}) + } + return shaderir.Type{Main: shaderir.Float} + } + + if fun.Type == shaderir.FunctionExpr { + args := cs.funcs[fun.Index].ir.InParams + + for i, e := range expr.Exprs[1:] { + cs.ValidateDefaultTypesForExpr(block, e, args[i]) + } + + retT := cs.funcs[fun.Index].ir.Return + + return retT + } + } + + return shaderir.Type{Main: shaderir.None} } -func (d delayedShiftValidator) Pos() token.Pos { - return d.pos +func (cs *compileState) ValidateDefaultTypes(block *block, stmt shaderir.Stmt) { + switch stmt.Type { + case shaderir.Assign: + left := stmt.Exprs[0] + right := stmt.Exprs[1] + if left.Type == shaderir.LocalVariable { + varType := block.vars[left.Index].typ + // Type is not explicitly specified + if stmt.IsTypeGuessed { + varType = shaderir.Type{Main: shaderir.None} + } + cs.ValidateDefaultTypesForExpr(block, right, varType) + } + case shaderir.ExprStmt: + for _, e := range stmt.Exprs { + cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.None}) + } + } } -func (d delayedShiftValidator) Error() string { - return "left shift operand should be int" +type delayedShiftValidator struct { + shiftType shaderir.Op + value gconstant.Value + validated bool + closestUnknown bool + failed bool +} + +func (d *delayedShiftValidator) IsValidated() (shaderir.Type, bool) { + if d.failed { + return shaderir.Type{}, false + } + if d.validated { + return shaderir.Type{Main: shaderir.Int}, true + } + // If only matched with None + if d.closestUnknown { + // Was it originally represented by an int constant? + if d.value.Kind() == gconstant.Int { + return shaderir.Type{Main: shaderir.Int}, true + } + } + return shaderir.Type{}, false +} + +func (d *delayedShiftValidator) Validate(t shaderir.Type) (shaderir.Type, bool) { + if d.validated { + return shaderir.Type{Main: shaderir.Int}, true + } + if isIntType(t) { + d.validated = true + return shaderir.Type{Main: shaderir.Int}, true + } + if t.Main == shaderir.None { + d.closestUnknown = true + return t, true + } + return shaderir.Type{Main: shaderir.None}, false +} + +func (d *delayedShiftValidator) Error() string { + st := "left shift" + if d.shiftType == shaderir.RightShift { + st = "right shift" + } + return fmt.Sprintf("left operand for %s should be int", st) } diff --git a/internal/shader/expr.go b/internal/shader/expr.go index ca11d4e66..53d328b4e 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -37,11 +37,6 @@ func canTruncateToFloat(v gconstant.Value) bool { var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) { - defer func() { - // Due to use of early return in the parsing, delayed checks are conducted in defer - ok = ok && cs.tryValidateDelayed(expr) - }() - switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { @@ -133,7 +128,11 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar lhst = shaderir.Type{Main: shaderir.Int} // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone. if rhs[0].Const == nil { - cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue, pos: e.Pos(), last: expr}) + defer func() { + if ok { + cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue}) + } + }() } } } @@ -202,6 +201,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar { Type: shaderir.Binary, Op: op2, + Ast: expr, Exprs: []shaderir.Expr{lhs[0], rhs[0]}, }, }, []shaderir.Type{t}, stmts, true diff --git a/internal/shader/shader.go b/internal/shader/shader.go index c89d43cde..017a7a81d 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -61,7 +61,7 @@ type compileState struct { varyingParsed bool - delayedTypeCheks map[ast.Expr]delayedValidator + delayedTypeCheks map[ast.Expr]delayedTypeValidator errs []string } @@ -84,9 +84,9 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) { return 0, false } -func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedValidator) { +func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedTypeValidator) { if cs.delayedTypeCheks == nil { - cs.delayedTypeCheks = make(map[ast.Expr]delayedValidator, 1) + cs.delayedTypeCheks = make(map[ast.Expr]delayedTypeValidator, 1) } cs.delayedTypeCheks[at] = check } diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 6c07057da..a794a749c 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -49,6 +49,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP if !ok { return nil, false } + for i := range ss { + ss[i].IsTypeGuessed = true + } + stmts = append(stmts, ss...) case token.ASSIGN: if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 { @@ -473,6 +477,25 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt)) return nil, false } + + // Need to run delayed checks + if len(cs.delayedTypeCheks) != 0 { + for _, st := range stmts { + cs.ValidateDefaultTypes(block, st) + } + + // Collect all errors first + foundErr := false + for s, v := range cs.delayedTypeCheks { + if _, ok := v.IsValidated(); !ok { + foundErr = true + cs.addError(s.Pos(), v.Error()) + } + } + if foundErr { + return nil, false + } + } return stmts, true } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 2d78dd5a9..a355dfa79 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,9 +1320,35 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - {stmt: "s := 1; t := 2; _ = t + 1.0 << s", err: false}, - {stmt: "s := 1; _ = 1 << s", err: false}, - {stmt: "s := 1; _ = 1.0 << s", err: true}, + {stmt: "s := 1; a := 1.0<