add shift type checks

This commit is contained in:
aoyako 2024-03-02 15:32:31 +09:00
parent f44640778d
commit f02e9fd4d0
7 changed files with 211 additions and 191 deletions

View File

@ -15,110 +15,151 @@
package shader package shader
import ( import (
"go/ast" "fmt"
gconstant "go/constant" gconstant "go/constant"
"go/token"
"github.com/hajimehoshi/ebiten/v2/internal/shaderir" "github.com/hajimehoshi/ebiten/v2/internal/shaderir"
) )
type resolveTypeStatus int type delayedTypeValidator interface {
Validate(t shaderir.Type) (shaderir.Type, bool)
const ( IsValidated() (shaderir.Type, bool)
resolveUnsure resolveTypeStatus = iota
resolveOk
resolveFail
)
type delayedValidator interface {
Validate(expr ast.Expr) resolveTypeStatus
Pos() token.Pos
Error() string Error() string
} }
func (cs *compileState) tryValidateDelayed(cexpr ast.Expr) (ok bool) {
valExprs := make([]ast.Expr, 0, len(cs.delayedTypeCheks))
for k := range cs.delayedTypeCheks {
valExprs = append(valExprs, k)
}
for _, expr := range valExprs {
if cexpr == expr {
continue
}
// Check if delayed validation can be done by adding current context
cres := cs.delayedTypeCheks[expr].Validate(cexpr)
switch cres {
case resolveFail:
cs.addError(cs.delayedTypeCheks[expr].Pos(), cs.delayedTypeCheks[expr].Error())
return false
case resolveOk:
delete(cs.delayedTypeCheks, expr)
}
}
return true
}
type delayedShiftValidator struct {
value gconstant.Value
pos token.Pos
last ast.Expr
}
func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool { func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool {
return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F
} }
func (d *delayedShiftValidator) Validate(cexpr ast.Expr) (rs resolveTypeStatus) { func isIntType(t shaderir.Type) bool {
switch cexpr.(type) { return t.Main == shaderir.Int || t.IsIntVector()
case *ast.Ident: }
ident := cexpr.(*ast.Ident)
// For BuiltinFunc, only int* are allowed
if fname, ok := shaderir.ParseBuiltinFunc(ident.Name); ok {
if isArgDefaultTypeInt(fname) {
return resolveOk
}
return resolveFail
}
// Untyped constant must represent int
if ident.Name == "_" {
if d.value != nil && d.value.Kind() == gconstant.Int {
return resolveOk
}
return resolveFail
}
if ident.Obj != nil {
if t, ok := ident.Obj.Type.(*ast.Ident); ok {
return d.Validate(t)
}
if decl, ok := ident.Obj.Decl.(*ast.ValueSpec); ok {
return d.Validate(decl.Type)
}
if _, ok := ident.Obj.Decl.(*ast.AssignStmt); ok {
if d.value != nil && d.value.Kind() == gconstant.Int {
return resolveOk
}
return resolveFail
}
}
case *ast.BinaryExpr:
bs := cexpr.(*ast.BinaryExpr)
left, right := bs.X, bs.Y
if bs.Y == d.last {
left, right = right, left
}
rightCheck := d.Validate(right) func (cs *compileState) ValidateDefaultTypesForExpr(block *block, expr shaderir.Expr, t shaderir.Type) shaderir.Type {
d.last = cexpr if check, ok := cs.delayedTypeCheks[expr.Ast]; ok {
return rightCheck if resT, ok := check.IsValidated(); ok {
return resT
}
resT, ok := check.Validate(t)
if !ok {
return shaderir.Type{Main: shaderir.None}
}
return resT
} }
return resolveUnsure
switch expr.Type {
case shaderir.LocalVariable:
return block.vars[expr.Index].typ
case shaderir.Binary:
left := expr.Exprs[0]
right := expr.Exprs[1]
leftType := cs.ValidateDefaultTypesForExpr(block, left, t)
rightType := cs.ValidateDefaultTypesForExpr(block, right, t)
// Usure about top-level type, try to validate by neighbour type
// The same work is done twice. Can it be optimized?
if t.Main == shaderir.None {
cs.ValidateDefaultTypesForExpr(block, left, rightType)
cs.ValidateDefaultTypesForExpr(block, right, leftType)
}
case shaderir.Call:
fun := expr.Exprs[0]
if fun.Type == shaderir.BuiltinFuncExpr {
if isArgDefaultTypeInt(fun.BuiltinFunc) {
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Int})
}
return shaderir.Type{Main: shaderir.Int}
}
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Float})
}
return shaderir.Type{Main: shaderir.Float}
}
if fun.Type == shaderir.FunctionExpr {
args := cs.funcs[fun.Index].ir.InParams
for i, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, args[i])
}
retT := cs.funcs[fun.Index].ir.Return
return retT
}
}
return shaderir.Type{Main: shaderir.None}
} }
func (d delayedShiftValidator) Pos() token.Pos { func (cs *compileState) ValidateDefaultTypes(block *block, stmt shaderir.Stmt) {
return d.pos switch stmt.Type {
case shaderir.Assign:
left := stmt.Exprs[0]
right := stmt.Exprs[1]
if left.Type == shaderir.LocalVariable {
varType := block.vars[left.Index].typ
// Type is not explicitly specified
if stmt.IsTypeGuessed {
varType = shaderir.Type{Main: shaderir.None}
}
cs.ValidateDefaultTypesForExpr(block, right, varType)
}
case shaderir.ExprStmt:
for _, e := range stmt.Exprs {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.None})
}
}
} }
func (d delayedShiftValidator) Error() string { type delayedShiftValidator struct {
return "left shift operand should be int" shiftType shaderir.Op
value gconstant.Value
validated bool
closestUnknown bool
failed bool
}
func (d *delayedShiftValidator) IsValidated() (shaderir.Type, bool) {
if d.failed {
return shaderir.Type{}, false
}
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
// If only matched with None
if d.closestUnknown {
// Was it originally represented by an int constant?
if d.value.Kind() == gconstant.Int {
return shaderir.Type{Main: shaderir.Int}, true
}
}
return shaderir.Type{}, false
}
func (d *delayedShiftValidator) Validate(t shaderir.Type) (shaderir.Type, bool) {
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
if isIntType(t) {
d.validated = true
return shaderir.Type{Main: shaderir.Int}, true
}
if t.Main == shaderir.None {
d.closestUnknown = true
return t, true
}
return shaderir.Type{Main: shaderir.None}, false
}
func (d *delayedShiftValidator) Error() string {
st := "left shift"
if d.shiftType == shaderir.RightShift {
st = "right shift"
}
return fmt.Sprintf("left operand for %s should be int", st)
} }

View File

@ -37,11 +37,6 @@ func canTruncateToFloat(v gconstant.Value) bool {
var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`)
func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) { func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) {
defer func() {
// Due to use of early return in the parsing, delayed checks are conducted in defer
ok = ok && cs.tryValidateDelayed(expr)
}()
switch e := expr.(type) { switch e := expr.(type) {
case *ast.BasicLit: case *ast.BasicLit:
switch e.Kind { switch e.Kind {
@ -133,7 +128,11 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
lhst = shaderir.Type{Main: shaderir.Int} lhst = shaderir.Type{Main: shaderir.Int}
// Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone. // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone.
if rhs[0].Const == nil { if rhs[0].Const == nil {
cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue, pos: e.Pos(), last: expr}) defer func() {
if ok {
cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue})
}
}()
} }
} }
} }
@ -202,6 +201,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
{ {
Type: shaderir.Binary, Type: shaderir.Binary,
Op: op2, Op: op2,
Ast: expr,
Exprs: []shaderir.Expr{lhs[0], rhs[0]}, Exprs: []shaderir.Expr{lhs[0], rhs[0]},
}, },
}, []shaderir.Type{t}, stmts, true }, []shaderir.Type{t}, stmts, true

View File

@ -61,7 +61,7 @@ type compileState struct {
varyingParsed bool varyingParsed bool
delayedTypeCheks map[ast.Expr]delayedValidator delayedTypeCheks map[ast.Expr]delayedTypeValidator
errs []string errs []string
} }
@ -84,9 +84,9 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) {
return 0, false return 0, false
} }
func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedValidator) { func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedTypeValidator) {
if cs.delayedTypeCheks == nil { if cs.delayedTypeCheks == nil {
cs.delayedTypeCheks = make(map[ast.Expr]delayedValidator, 1) cs.delayedTypeCheks = make(map[ast.Expr]delayedTypeValidator, 1)
} }
cs.delayedTypeCheks[at] = check cs.delayedTypeCheks[at] = check
} }

View File

@ -49,6 +49,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
if !ok { if !ok {
return nil, false return nil, false
} }
for i := range ss {
ss[i].IsTypeGuessed = true
}
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
case token.ASSIGN: case token.ASSIGN:
if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 { if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 {
@ -473,6 +477,25 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt)) cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt))
return nil, false return nil, false
} }
// Need to run delayed checks
if len(cs.delayedTypeCheks) != 0 {
for _, st := range stmts {
cs.ValidateDefaultTypes(block, st)
}
// Collect all errors first
foundErr := false
for s, v := range cs.delayedTypeCheks {
if _, ok := v.IsValidated(); !ok {
foundErr = true
cs.addError(s.Pos(), v.Error())
}
}
if foundErr {
return nil, false
}
}
return stmts, true return stmts, true
} }

View File

@ -1320,9 +1320,35 @@ func TestSyntaxOperatorShift(t *testing.T) {
stmt string stmt string
err bool err bool
}{ }{
{stmt: "s := 1; t := 2; _ = t + 1.0 << s", err: false}, {stmt: "s := 1; a := 1.0<<s + vec2(1); _ = a", err: true},
{stmt: "s := 1; _ = 1 << s", err: false}, {stmt: "s := 1; a := 1.0<<s + ivec2(1); _ = a", err: false},
{stmt: "s := 1; _ = 1.0 << s", err: true}, {stmt: "s := 1; a := 1.0<<s + foo_int_int(1.0<<s); _ = a", err: false},
{stmt: "s := 1; a := 1.0<<s + 1.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 1<<s + 1.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 1.0<<s + 1<<s; _ = a", err: true},
{stmt: "s := 1; a := 1<<s + 1<<s; _ = a", err: false},
{stmt: "s := 1; foo_multivar(0, 0, 1<<s)", err: false},
{stmt: "s := 1; foo_multivar(0, 1.0<<s, 0)", err: true},
{stmt: "s := 1; foo_multivar(1.0<<s, 0, 0)", err: false},
{stmt: "s := 12; a := foo_multivar(1.0<<s, 0, 0); _ = a", err: false},
{stmt: "s := 12; a := foo_multivar(0, 1.0<<s, 0); _ = a", err: true},
{stmt: "s := 12; a := foo_multivar(0, 0, 1.0<<s); _ = a", err: false},
{stmt: "a := foo_multivar(0, 0, 1.0<<2.0); _ = a", err: false},
{stmt: "a := foo_multivar(0, 1.0<<2.0, 0); _ = a", err: false},
{stmt: "s := 12; var a int = 1.0 << 2.0 << 3.0 << 4.0 << s; _ = a", err: false},
{stmt: "s := 12; var a float = 1 << 1 << 1 << 1 << s; _ = a", err: true},
{stmt: "s := 12; var a float = 1 << s + 1.2; _ = a", err: true},
{stmt: "s := 12; a := 1.0 << s + 1.2; _ = a", err: true},
{stmt: "s := 12; a := 1.0 << s + foo_float_float(2); _ = a", err: true},
{stmt: "s := 12; a := 1.0 << s + foo_float_int(2); _ = a", err: false},
{stmt: "s := 12; a := foo_float_int(1.0<<s) + foo_float_int(2); _ = a", err: true},
{stmt: "s := 12; a := foo_int_float(1<<s) + foo_int_float(2); _ = a", err: false},
{stmt: "s := 12; a := foo_int_int(1<<s) + foo_int_int(2); _ = a", err: false},
{stmt: "s := 1; t := 2.0; a := t + 1.0 << s; _ = a", err: true},
{stmt: "s := 1; t := 2; a := t + 1.0 << s; _ = a", err: false},
{stmt: "s := 1; b := 1 << s; _ = b", err: false},
{stmt: "s := 1; a = 1.0 << s; _ = a", err: true},
{stmt: "var a = 1; b := a << 2.0; _ = b", err: false}, {stmt: "var a = 1; b := a << 2.0; _ = b", err: false},
{stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true}, {stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, {stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
@ -1370,7 +1396,11 @@ func TestSyntaxOperatorShift(t *testing.T) {
for _, c := range cases { for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main _, err := compileToIR([]byte(fmt.Sprintf(`package main
func foo_multivar(x int, y float, z int) int {return x}
func foo_int_int(x int) int {return x}
func foo_float_int(x float) int {return int(x)}
func foo_int_float(x int) float {return float(x)}
func foo_float_float(x float) float {return x}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s %s
return dstPos return dstPos
@ -1381,80 +1411,6 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
t.Errorf("%s must not return nil but returned %v", c.stmt, err) t.Errorf("%s must not return nil but returned %v", c.stmt, err)
} }
} }
casesFunc := []struct {
prog string
err bool
}{
{
prog: `package main
func Foo(x int) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
s := 1
Foo(1 << s)
return dstPos
}`,
err: false,
},
{
prog: `package main
func Foo(x int) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
s := 1
Foo(1.0 << s)
return dstPos
}`,
err: false,
},
{
prog: `package main
func Foo(x float) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
s := 1
Foo(1 << s)
return dstPos
}`,
err: true,
},
{
prog: `package main
func Foo(x float) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
s := 1
Foo(1 << s)
return dstPos
}`,
err: true,
},
{
prog: `package main
func Foo(x float) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
Foo(1.0 << 2.0)
return dstPos
}`,
err: false,
},
}
for _, c := range casesFunc {
_, err := compileToIR([]byte(c.prog))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.prog)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.prog, err)
}
}
} }
func TestSyntaxOperatorShiftAssign(t *testing.T) { func TestSyntaxOperatorShiftAssign(t *testing.T) {

View File

@ -145,14 +145,11 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
} }
if op == LeftShift || op == RightShift { if op == LeftShift || op == RightShift {
if lhst.Main == Int && rhst.Main == Int { if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int {
return Type{Main: lhst.Main}, true return lhst, true
}
if lhst.IsIntVector() && rhst.Main == Int {
return Type{Main: lhst.Main}, true
} }
if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() { if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() {
return Type{Main: lhst.Main}, true return lhst, true
} }
return Type{}, false return Type{}, false
} }

View File

@ -16,6 +16,7 @@
package shaderir package shaderir
import ( import (
"go/ast"
"go/constant" "go/constant"
"go/token" "go/token"
"sort" "sort"
@ -74,16 +75,17 @@ type Block struct {
} }
type Stmt struct { type Stmt struct {
Type StmtType Type StmtType
Exprs []Expr Exprs []Expr
Blocks []*Block Blocks []*Block
ForVarType Type ForVarType Type
ForVarIndex int ForVarIndex int
ForInit constant.Value ForInit constant.Value
ForEnd constant.Value ForEnd constant.Value
ForOp Op ForOp Op
ForDelta constant.Value ForDelta constant.Value
InitIndex int InitIndex int
IsTypeGuessed bool
} }
type StmtType int type StmtType int
@ -109,6 +111,7 @@ type Expr struct {
Swizzling string Swizzling string
Index int Index int
Op Op Op Op
Ast ast.Expr
} }
type ExprType int type ExprType int