add shift type checks

This commit is contained in:
aoyako 2024-03-02 15:32:31 +09:00
parent f44640778d
commit f02e9fd4d0
7 changed files with 211 additions and 191 deletions

View File

@ -15,110 +15,151 @@
package shader
import (
"go/ast"
"fmt"
gconstant "go/constant"
"go/token"
"github.com/hajimehoshi/ebiten/v2/internal/shaderir"
)
type resolveTypeStatus int
const (
resolveUnsure resolveTypeStatus = iota
resolveOk
resolveFail
)
type delayedValidator interface {
Validate(expr ast.Expr) resolveTypeStatus
Pos() token.Pos
type delayedTypeValidator interface {
Validate(t shaderir.Type) (shaderir.Type, bool)
IsValidated() (shaderir.Type, bool)
Error() string
}
func (cs *compileState) tryValidateDelayed(cexpr ast.Expr) (ok bool) {
valExprs := make([]ast.Expr, 0, len(cs.delayedTypeCheks))
for k := range cs.delayedTypeCheks {
valExprs = append(valExprs, k)
}
for _, expr := range valExprs {
if cexpr == expr {
continue
}
// Check if delayed validation can be done by adding current context
cres := cs.delayedTypeCheks[expr].Validate(cexpr)
switch cres {
case resolveFail:
cs.addError(cs.delayedTypeCheks[expr].Pos(), cs.delayedTypeCheks[expr].Error())
return false
case resolveOk:
delete(cs.delayedTypeCheks, expr)
}
}
return true
}
type delayedShiftValidator struct {
value gconstant.Value
pos token.Pos
last ast.Expr
}
func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool {
return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F
}
func (d *delayedShiftValidator) Validate(cexpr ast.Expr) (rs resolveTypeStatus) {
switch cexpr.(type) {
case *ast.Ident:
ident := cexpr.(*ast.Ident)
// For BuiltinFunc, only int* are allowed
if fname, ok := shaderir.ParseBuiltinFunc(ident.Name); ok {
if isArgDefaultTypeInt(fname) {
return resolveOk
}
return resolveFail
}
// Untyped constant must represent int
if ident.Name == "_" {
if d.value != nil && d.value.Kind() == gconstant.Int {
return resolveOk
}
return resolveFail
}
if ident.Obj != nil {
if t, ok := ident.Obj.Type.(*ast.Ident); ok {
return d.Validate(t)
}
if decl, ok := ident.Obj.Decl.(*ast.ValueSpec); ok {
return d.Validate(decl.Type)
}
if _, ok := ident.Obj.Decl.(*ast.AssignStmt); ok {
if d.value != nil && d.value.Kind() == gconstant.Int {
return resolveOk
}
return resolveFail
}
}
case *ast.BinaryExpr:
bs := cexpr.(*ast.BinaryExpr)
left, right := bs.X, bs.Y
if bs.Y == d.last {
left, right = right, left
}
func isIntType(t shaderir.Type) bool {
return t.Main == shaderir.Int || t.IsIntVector()
}
rightCheck := d.Validate(right)
d.last = cexpr
return rightCheck
func (cs *compileState) ValidateDefaultTypesForExpr(block *block, expr shaderir.Expr, t shaderir.Type) shaderir.Type {
if check, ok := cs.delayedTypeCheks[expr.Ast]; ok {
if resT, ok := check.IsValidated(); ok {
return resT
}
resT, ok := check.Validate(t)
if !ok {
return shaderir.Type{Main: shaderir.None}
}
return resT
}
return resolveUnsure
switch expr.Type {
case shaderir.LocalVariable:
return block.vars[expr.Index].typ
case shaderir.Binary:
left := expr.Exprs[0]
right := expr.Exprs[1]
leftType := cs.ValidateDefaultTypesForExpr(block, left, t)
rightType := cs.ValidateDefaultTypesForExpr(block, right, t)
// Usure about top-level type, try to validate by neighbour type
// The same work is done twice. Can it be optimized?
if t.Main == shaderir.None {
cs.ValidateDefaultTypesForExpr(block, left, rightType)
cs.ValidateDefaultTypesForExpr(block, right, leftType)
}
case shaderir.Call:
fun := expr.Exprs[0]
if fun.Type == shaderir.BuiltinFuncExpr {
if isArgDefaultTypeInt(fun.BuiltinFunc) {
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Int})
}
return shaderir.Type{Main: shaderir.Int}
}
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Float})
}
return shaderir.Type{Main: shaderir.Float}
}
if fun.Type == shaderir.FunctionExpr {
args := cs.funcs[fun.Index].ir.InParams
for i, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, args[i])
}
retT := cs.funcs[fun.Index].ir.Return
return retT
}
}
return shaderir.Type{Main: shaderir.None}
}
func (d delayedShiftValidator) Pos() token.Pos {
return d.pos
func (cs *compileState) ValidateDefaultTypes(block *block, stmt shaderir.Stmt) {
switch stmt.Type {
case shaderir.Assign:
left := stmt.Exprs[0]
right := stmt.Exprs[1]
if left.Type == shaderir.LocalVariable {
varType := block.vars[left.Index].typ
// Type is not explicitly specified
if stmt.IsTypeGuessed {
varType = shaderir.Type{Main: shaderir.None}
}
cs.ValidateDefaultTypesForExpr(block, right, varType)
}
case shaderir.ExprStmt:
for _, e := range stmt.Exprs {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.None})
}
}
}
func (d delayedShiftValidator) Error() string {
return "left shift operand should be int"
type delayedShiftValidator struct {
shiftType shaderir.Op
value gconstant.Value
validated bool
closestUnknown bool
failed bool
}
func (d *delayedShiftValidator) IsValidated() (shaderir.Type, bool) {
if d.failed {
return shaderir.Type{}, false
}
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
// If only matched with None
if d.closestUnknown {
// Was it originally represented by an int constant?
if d.value.Kind() == gconstant.Int {
return shaderir.Type{Main: shaderir.Int}, true
}
}
return shaderir.Type{}, false
}
func (d *delayedShiftValidator) Validate(t shaderir.Type) (shaderir.Type, bool) {
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
if isIntType(t) {
d.validated = true
return shaderir.Type{Main: shaderir.Int}, true
}
if t.Main == shaderir.None {
d.closestUnknown = true
return t, true
}
return shaderir.Type{Main: shaderir.None}, false
}
func (d *delayedShiftValidator) Error() string {
st := "left shift"
if d.shiftType == shaderir.RightShift {
st = "right shift"
}
return fmt.Sprintf("left operand for %s should be int", st)
}

View File

@ -37,11 +37,6 @@ func canTruncateToFloat(v gconstant.Value) bool {
var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`)
func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) {
defer func() {
// Due to use of early return in the parsing, delayed checks are conducted in defer
ok = ok && cs.tryValidateDelayed(expr)
}()
switch e := expr.(type) {
case *ast.BasicLit:
switch e.Kind {
@ -133,7 +128,11 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
lhst = shaderir.Type{Main: shaderir.Int}
// Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone.
if rhs[0].Const == nil {
cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue, pos: e.Pos(), last: expr})
defer func() {
if ok {
cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue})
}
}()
}
}
}
@ -202,6 +201,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
{
Type: shaderir.Binary,
Op: op2,
Ast: expr,
Exprs: []shaderir.Expr{lhs[0], rhs[0]},
},
}, []shaderir.Type{t}, stmts, true

View File

@ -61,7 +61,7 @@ type compileState struct {
varyingParsed bool
delayedTypeCheks map[ast.Expr]delayedValidator
delayedTypeCheks map[ast.Expr]delayedTypeValidator
errs []string
}
@ -84,9 +84,9 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) {
return 0, false
}
func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedValidator) {
func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedTypeValidator) {
if cs.delayedTypeCheks == nil {
cs.delayedTypeCheks = make(map[ast.Expr]delayedValidator, 1)
cs.delayedTypeCheks = make(map[ast.Expr]delayedTypeValidator, 1)
}
cs.delayedTypeCheks[at] = check
}

View File

@ -49,6 +49,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
if !ok {
return nil, false
}
for i := range ss {
ss[i].IsTypeGuessed = true
}
stmts = append(stmts, ss...)
case token.ASSIGN:
if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 {
@ -473,6 +477,25 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt))
return nil, false
}
// Need to run delayed checks
if len(cs.delayedTypeCheks) != 0 {
for _, st := range stmts {
cs.ValidateDefaultTypes(block, st)
}
// Collect all errors first
foundErr := false
for s, v := range cs.delayedTypeCheks {
if _, ok := v.IsValidated(); !ok {
foundErr = true
cs.addError(s.Pos(), v.Error())
}
}
if foundErr {
return nil, false
}
}
return stmts, true
}

View File

@ -1320,9 +1320,35 @@ func TestSyntaxOperatorShift(t *testing.T) {
stmt string
err bool
}{
{stmt: "s := 1; t := 2; _ = t + 1.0 << s", err: false},
{stmt: "s := 1; _ = 1 << s", err: false},
{stmt: "s := 1; _ = 1.0 << s", err: true},
{stmt: "s := 1; a := 1.0<<s + vec2(1); _ = a", err: true},
{stmt: "s := 1; a := 1.0<<s + ivec2(1); _ = a", err: false},
{stmt: "s := 1; a := 1.0<<s + foo_int_int(1.0<<s); _ = a", err: false},
{stmt: "s := 1; a := 1.0<<s + 1.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 1<<s + 1.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 1.0<<s + 1<<s; _ = a", err: true},
{stmt: "s := 1; a := 1<<s + 1<<s; _ = a", err: false},
{stmt: "s := 1; foo_multivar(0, 0, 1<<s)", err: false},
{stmt: "s := 1; foo_multivar(0, 1.0<<s, 0)", err: true},
{stmt: "s := 1; foo_multivar(1.0<<s, 0, 0)", err: false},
{stmt: "s := 12; a := foo_multivar(1.0<<s, 0, 0); _ = a", err: false},
{stmt: "s := 12; a := foo_multivar(0, 1.0<<s, 0); _ = a", err: true},
{stmt: "s := 12; a := foo_multivar(0, 0, 1.0<<s); _ = a", err: false},
{stmt: "a := foo_multivar(0, 0, 1.0<<2.0); _ = a", err: false},
{stmt: "a := foo_multivar(0, 1.0<<2.0, 0); _ = a", err: false},
{stmt: "s := 12; var a int = 1.0 << 2.0 << 3.0 << 4.0 << s; _ = a", err: false},
{stmt: "s := 12; var a float = 1 << 1 << 1 << 1 << s; _ = a", err: true},
{stmt: "s := 12; var a float = 1 << s + 1.2; _ = a", err: true},
{stmt: "s := 12; a := 1.0 << s + 1.2; _ = a", err: true},
{stmt: "s := 12; a := 1.0 << s + foo_float_float(2); _ = a", err: true},
{stmt: "s := 12; a := 1.0 << s + foo_float_int(2); _ = a", err: false},
{stmt: "s := 12; a := foo_float_int(1.0<<s) + foo_float_int(2); _ = a", err: true},
{stmt: "s := 12; a := foo_int_float(1<<s) + foo_int_float(2); _ = a", err: false},
{stmt: "s := 12; a := foo_int_int(1<<s) + foo_int_int(2); _ = a", err: false},
{stmt: "s := 1; t := 2.0; a := t + 1.0 << s; _ = a", err: true},
{stmt: "s := 1; t := 2; a := t + 1.0 << s; _ = a", err: false},
{stmt: "s := 1; b := 1 << s; _ = b", err: false},
{stmt: "s := 1; a = 1.0 << s; _ = a", err: true},
{stmt: "var a = 1; b := a << 2.0; _ = b", err: false},
{stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
@ -1370,7 +1396,11 @@ func TestSyntaxOperatorShift(t *testing.T) {
for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main
func foo_multivar(x int, y float, z int) int {return x}
func foo_int_int(x int) int {return x}
func foo_float_int(x float) int {return int(x)}
func foo_int_float(x int) float {return float(x)}
func foo_float_float(x float) float {return x}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s
return dstPos
@ -1381,80 +1411,6 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
}
}
casesFunc := []struct {
prog string
err bool
}{
{
prog: `package main
func Foo(x int) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
s := 1
Foo(1 << s)
return dstPos
}`,
err: false,
},
{
prog: `package main
func Foo(x int) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
s := 1
Foo(1.0 << s)
return dstPos
}`,
err: false,
},
{
prog: `package main
func Foo(x float) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
s := 1
Foo(1 << s)
return dstPos
}`,
err: true,
},
{
prog: `package main
func Foo(x float) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
s := 1
Foo(1 << s)
return dstPos
}`,
err: true,
},
{
prog: `package main
func Foo(x float) {
_ = x
}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
Foo(1.0 << 2.0)
return dstPos
}`,
err: false,
},
}
for _, c := range casesFunc {
_, err := compileToIR([]byte(c.prog))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.prog)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.prog, err)
}
}
}
func TestSyntaxOperatorShiftAssign(t *testing.T) {

View File

@ -145,14 +145,11 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
}
if op == LeftShift || op == RightShift {
if lhst.Main == Int && rhst.Main == Int {
return Type{Main: lhst.Main}, true
}
if lhst.IsIntVector() && rhst.Main == Int {
return Type{Main: lhst.Main}, true
if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int {
return lhst, true
}
if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() {
return Type{Main: lhst.Main}, true
return lhst, true
}
return Type{}, false
}

View File

@ -16,6 +16,7 @@
package shaderir
import (
"go/ast"
"go/constant"
"go/token"
"sort"
@ -74,16 +75,17 @@ type Block struct {
}
type Stmt struct {
Type StmtType
Exprs []Expr
Blocks []*Block
ForVarType Type
ForVarIndex int
ForInit constant.Value
ForEnd constant.Value
ForOp Op
ForDelta constant.Value
InitIndex int
Type StmtType
Exprs []Expr
Blocks []*Block
ForVarType Type
ForVarIndex int
ForInit constant.Value
ForEnd constant.Value
ForOp Op
ForDelta constant.Value
InitIndex int
IsTypeGuessed bool
}
type StmtType int
@ -109,6 +111,7 @@ type Expr struct {
Swizzling string
Index int
Op Op
Ast ast.Expr
}
type ExprType int