This commit is contained in:
Mykhailo Lohachov 2024-03-20 23:20:24 +09:00 committed by GitHub
commit 0c5bd5d810
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 565 additions and 34 deletions

165
internal/shader/delayed.go Normal file
View File

@ -0,0 +1,165 @@
// Copyright 2024 The Ebiten Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shader
import (
"fmt"
gconstant "go/constant"
"github.com/hajimehoshi/ebiten/v2/internal/shaderir"
)
type delayedTypeValidator interface {
Validate(t shaderir.Type) (shaderir.Type, bool)
IsValidated() (shaderir.Type, bool)
Error() string
}
func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool {
return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F
}
func isIntType(t shaderir.Type) bool {
return t.Main == shaderir.Int || t.IsIntVector()
}
func (cs *compileState) ValidateDefaultTypesForExpr(block *block, expr shaderir.Expr, t shaderir.Type) shaderir.Type {
if check, ok := cs.delayedTypeCheks[expr.Ast]; ok {
if resT, ok := check.IsValidated(); ok {
return resT
}
resT, ok := check.Validate(t)
if !ok {
return shaderir.Type{Main: shaderir.None}
}
return resT
}
switch expr.Type {
case shaderir.LocalVariable:
return block.vars[expr.Index].typ
case shaderir.Binary:
left := expr.Exprs[0]
right := expr.Exprs[1]
leftType := cs.ValidateDefaultTypesForExpr(block, left, t)
rightType := cs.ValidateDefaultTypesForExpr(block, right, t)
// Usure about top-level type, try to validate by neighbour type
// The same work is done twice. Can it be optimized?
if t.Main == shaderir.None {
cs.ValidateDefaultTypesForExpr(block, left, rightType)
cs.ValidateDefaultTypesForExpr(block, right, leftType)
}
case shaderir.Call:
fun := expr.Exprs[0]
if fun.Type == shaderir.BuiltinFuncExpr {
if isArgDefaultTypeInt(fun.BuiltinFunc) {
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Int})
}
return shaderir.Type{Main: shaderir.Int}
}
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Float})
}
return shaderir.Type{Main: shaderir.Float}
}
if fun.Type == shaderir.FunctionExpr {
args := cs.funcs[fun.Index].ir.InParams
for i, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, args[i])
}
retT := cs.funcs[fun.Index].ir.Return
return retT
}
}
return shaderir.Type{Main: shaderir.None}
}
func (cs *compileState) ValidateDefaultTypes(block *block, stmt shaderir.Stmt) {
switch stmt.Type {
case shaderir.Assign:
left := stmt.Exprs[0]
right := stmt.Exprs[1]
if left.Type == shaderir.LocalVariable {
varType := block.vars[left.Index].typ
// Type is not explicitly specified
if stmt.IsTypeGuessed {
varType = shaderir.Type{Main: shaderir.None}
}
cs.ValidateDefaultTypesForExpr(block, right, varType)
}
case shaderir.ExprStmt:
for _, e := range stmt.Exprs {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.None})
}
}
}
type delayedShiftValidator struct {
shiftType shaderir.Op
value gconstant.Value
validated bool
closestUnknown bool
failed bool
}
func (d *delayedShiftValidator) IsValidated() (shaderir.Type, bool) {
if d.failed {
return shaderir.Type{}, false
}
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
// If only matched with None
if d.closestUnknown {
// Was it originally represented by an int constant?
if d.value.Kind() == gconstant.Int {
return shaderir.Type{Main: shaderir.Int}, true
}
}
return shaderir.Type{}, false
}
func (d *delayedShiftValidator) Validate(t shaderir.Type) (shaderir.Type, bool) {
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
if isIntType(t) {
d.validated = true
return shaderir.Type{Main: shaderir.Int}, true
}
if t.Main == shaderir.None {
d.closestUnknown = true
return t, true
}
return shaderir.Type{Main: shaderir.None}, false
}
func (d *delayedShiftValidator) Error() string {
st := "left shift"
if d.shiftType == shaderir.RightShift {
st = "right shift"
}
return fmt.Sprintf("left operand for %s should be int", st)
}

View File

@ -36,7 +36,7 @@ func canTruncateToFloat(v gconstant.Value) bool {
var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`)
func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) { func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) {
switch e := expr.(type) { switch e := expr.(type) {
case *ast.BasicLit: case *ast.BasicLit:
switch e.Kind { switch e.Kind {
@ -105,7 +105,14 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
// Resolve untyped constants. // Resolve untyped constants.
l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) var l gconstant.Value
var r gconstant.Value
origLvalue := lhs[0].Const
if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
} else {
l, r, ok = shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
}
if !ok { if !ok {
// TODO: Show a better type name for untyped constants. // TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
@ -113,27 +120,49 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
lhs[0].Const, rhs[0].Const = l, r lhs[0].Const, rhs[0].Const = l, r
// If either is typed, resolve the other type. if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
// If both are untyped, keep them untyped. if !(lhst.Main == shaderir.None && rhst.Main == shaderir.None) {
if lhst.Main != shaderir.None || rhst.Main != shaderir.None { // If both are const
if lhs[0].Const != nil { if rhs[0].Const != nil && (rhst.Main == shaderir.None || lhs[0].Const != nil) {
switch lhs[0].Const.Kind() { rhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Float: }
lhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int: // If left is untyped const
if lhst.Main == shaderir.None && lhs[0].Const != nil {
lhst = shaderir.Type{Main: shaderir.Int} lhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool: // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone.
lhst = shaderir.Type{Main: shaderir.Bool} if rhs[0].Const == nil {
defer func() {
if ok {
cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue})
}
}()
}
} }
} }
if rhs[0].Const != nil { } else {
switch rhs[0].Const.Kind() { // If either is typed, resolve the other type.
case gconstant.Float: // If both are untyped, keep them untyped.
rhst = shaderir.Type{Main: shaderir.Float} if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
case gconstant.Int: if lhs[0].Const != nil {
rhst = shaderir.Type{Main: shaderir.Int} switch lhs[0].Const.Kind() {
case gconstant.Bool: case gconstant.Float:
rhst = shaderir.Type{Main: shaderir.Bool} lhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int:
lhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool:
lhst = shaderir.Type{Main: shaderir.Bool}
}
}
if rhs[0].Const != nil {
switch rhs[0].Const.Kind() {
case gconstant.Float:
rhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int:
rhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool:
rhst = shaderir.Type{Main: shaderir.Bool}
}
} }
} }
} }
@ -153,6 +182,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
v = gconstant.MakeBool(b) v = gconstant.MakeBool(b)
case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ: case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ:
v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const)) v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const))
case token.SHL, token.SHR:
shift, ok := gconstant.Int64Val(rhs[0].Const)
if !ok {
cs.addError(e.Pos(), fmt.Sprintf("unexpected %s type for: %s", rhs[0].Const.String(), e.Op))
} else {
v = gconstant.Shift(lhs[0].Const, op, uint(shift))
}
default: default:
v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const) v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)
} }
@ -169,6 +205,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
{ {
Type: shaderir.Binary, Type: shaderir.Binary,
Op: op2, Op: op2,
Ast: expr,
Exprs: []shaderir.Expr{lhs[0], rhs[0]}, Exprs: []shaderir.Expr{lhs[0], rhs[0]},
}, },
}, []shaderir.Type{t}, stmts, true }, []shaderir.Type{t}, stmts, true

View File

@ -61,6 +61,8 @@ type compileState struct {
varyingParsed bool varyingParsed bool
delayedTypeCheks map[ast.Expr]delayedTypeValidator
errs []string errs []string
} }
@ -82,6 +84,13 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) {
return 0, false return 0, false
} }
func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedTypeValidator) {
if cs.delayedTypeCheks == nil {
cs.delayedTypeCheks = make(map[ast.Expr]delayedTypeValidator, 1)
}
cs.delayedTypeCheks[at] = check
}
type typ struct { type typ struct {
name string name string
ir shaderir.Type ir shaderir.Type

View File

@ -49,6 +49,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
if !ok { if !ok {
return nil, false return nil, false
} }
for i := range ss {
ss[i].IsTypeGuessed = true
}
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
case token.ASSIGN: case token.ASSIGN:
if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 { if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 {
@ -60,7 +64,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false return nil, false
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN: case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN:
rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true) rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true)
if !ok { if !ok {
return nil, false return nil, false
@ -100,6 +104,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
op = shaderir.Or op = shaderir.Or
case token.XOR_ASSIGN: case token.XOR_ASSIGN:
op = shaderir.Xor op = shaderir.Xor
case token.SHL_ASSIGN:
op = shaderir.LeftShift
case token.SHR_ASSIGN:
op = shaderir.RightShift
default: default:
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok)) cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok))
return nil, false return nil, false
@ -110,7 +118,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator / not defined on %s", rts[0].String())) cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator / not defined on %s", rts[0].String()))
return nil, false return nil, false
} }
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift {
if lts[0].Main != shaderir.Int && !lts[0].IsIntVector() { if lts[0].Main != shaderir.Int && !lts[0].IsIntVector() {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String()))
} }
@ -137,7 +145,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
} }
} }
case shaderir.Float: case shaderir.Float:
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String()))
} else if rhs[0].Const != nil && } else if rhs[0].Const != nil &&
(rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) && (rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) &&
@ -148,7 +156,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false return nil, false
} }
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String()))
} else if (op == shaderir.MatrixMul || op == shaderir.Div) && } else if (op == shaderir.MatrixMul || op == shaderir.Div) &&
(rts[0].Main == shaderir.Float || (rts[0].Main == shaderir.Float ||
@ -469,6 +477,25 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt)) cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt))
return nil, false return nil, false
} }
// Need to run delayed checks
if len(cs.delayedTypeCheks) != 0 {
for _, st := range stmts {
cs.ValidateDefaultTypes(block, st)
}
// Collect all errors first
foundErr := false
for s, v := range cs.delayedTypeCheks {
if _, ok := v.IsValidated(); !ok {
foundErr = true
cs.addError(s.Pos(), v.Error())
}
}
if foundErr {
return nil, false
}
}
return stmts, true return stmts, true
} }

View File

@ -1314,6 +1314,263 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
} }
} }
// Issue: #2755
func TestSyntaxOperatorShift(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "b := 2.0; a := 1.0 << 2.0 == b; _ = a", err: false},
{stmt: "s := 1; b := 2.0; a := 1.0<<s == b; _ = a", err: true},
{stmt: "s := 1; b := 2; a := 1.0<<s == b; _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + ivec2(3.0<<s); _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + vec2(3); _ = a", err: true},
{stmt: "s := 1; a := 2.0<<s + ivec2(3); _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + foo_int_int(3.0<<s); _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + 3.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 2<<s + 3.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 2.0<<s + 3<<s; _ = a", err: true},
{stmt: "s := 1; a := 2<<s + 3<<s; _ = a", err: false},
{stmt: "s := 1; foo_multivar(0, 0, 2<<s)", err: false},
{stmt: "s := 1; foo_multivar(0, 2.0<<s, 0)", err: true},
{stmt: "s := 1; foo_multivar(2.0<<s, 0, 0)", err: false},
{stmt: "s := 1; a := foo_multivar(2.0<<s, 0, 0); _ = a", err: false},
{stmt: "s := 1; a := foo_multivar(0, 2.0<<s, 0); _ = a", err: true},
{stmt: "s := 1; a := foo_multivar(0, 0, 2.0<<s); _ = a", err: false},
{stmt: "a := foo_multivar(0, 0, 1.0<<2.0); _ = a", err: false},
{stmt: "a := foo_multivar(0, 1.0<<2.0, 0); _ = a", err: false},
{stmt: "s := 1; a := int(1) + 1.0<<s + int(float(1<<s)); _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 << 2.0 << 3.0 << 4.0 << s; _ = a", err: false},
{stmt: "s := 1; var a float = 1 << 1 << 1 << 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 << s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 << s + foo_float_float(2); _ = a", err: true},
{stmt: "s := 1; a := 1.0 << s + foo_float_int(2); _ = a", err: false},
{stmt: "s := 1; a := foo_float_int(1.0<<s) + foo_float_int(2); _ = a", err: true},
{stmt: "s := 1; a := foo_int_float(1<<s) + foo_int_float(2); _ = a", err: false},
{stmt: "s := 1; a := foo_int_int(1<<s) + foo_int_int(2); _ = a", err: false},
{stmt: "s := 1; t := 2.0; a := t + 1.0 << s; _ = a", err: true},
{stmt: "s := 1; t := 2; a := t + 1.0 << s; _ = a", err: false},
{stmt: "s := 1; b := 1 << s; _ = b", err: false},
{stmt: "var a = 1; b := a << 2.0; _ = b", err: false},
{stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
{stmt: "s := 1; var a int = int(1 << s); _ = a", err: false},
{stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false},
{stmt: "s := 1; a := 1 << s; _ = a", err: false},
{stmt: "s := 1; a := 1.0 << s; _ = a", err: true},
{stmt: "s := 1; a := int(1.0 << s); _ = a", err: false},
{stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
{stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false},
{stmt: "s := 1; var a int = 1 << s; _ = a", err: false},
{stmt: "var a float = 1.0 << 2.0; _ = a", err: false},
{stmt: "var a int = 1.0 << 2; _ = a", err: false},
{stmt: "var a float = 1.0 << 2; _ = a", err: false},
{stmt: "a := 1 << 2.0; _ = a", err: false},
{stmt: "a := 1.0 << 2; _ = a", err: false},
{stmt: "a := 1.0 << 2.0; _ = a", err: false},
{stmt: "a := 1 << 2; _ = a", err: false},
{stmt: "a := float(1.0) << 2; _ = a", err: true},
{stmt: "a := 1 << float(2.0); _ = a", err: false},
{stmt: "a := 1.0 << float(2.0); _ = a", err: false},
{stmt: "a := ivec2(1) << 2; _ = a", err: false},
{stmt: "a := 1 << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << float(2.0); _ = a", err: true},
{stmt: "a := float(1.0) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << ivec3(2); _ = a", err: true},
{stmt: "a := 1 << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << 2; _ = a", err: true},
{stmt: "a := float(1.0) << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << float(2.0); _ = a", err: true},
{stmt: "a := vec2(1) << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << vec3(2); _ = a", err: true},
{stmt: "a := vec3(1) << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << vec2(2); _ = a", err: true},
{stmt: "a := vec3(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << vec3(2); _ = a", err: true},
{stmt: "b := 2.0; a := 1.0 >> 2.0 == b; _ = a", err: false},
{stmt: "s := 1; b := 2.0; a := 1.0>>s == b; _ = a", err: true},
{stmt: "s := 1; b := 2; a := 1.0>>s == b; _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + ivec2(3.0>>s); _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + vec2(3); _ = a", err: true},
{stmt: "s := 1; a := 2.0>>s + ivec2(3); _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + foo_int_int(3.0>>s); _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + 3.0>>s; _ = a", err: true},
{stmt: "s := 1; a := 2>>s + 3.0>>s; _ = a", err: true},
{stmt: "s := 1; a := 2.0>>s + 3>>s; _ = a", err: true},
{stmt: "s := 1; a := 2>>s + 3>>s; _ = a", err: false},
{stmt: "s := 1; foo_multivar(0, 0, 2>>s)", err: false},
{stmt: "s := 1; foo_multivar(0, 2.0>>s, 0)", err: true},
{stmt: "s := 1; foo_multivar(2.0>>s, 0, 0)", err: false},
{stmt: "s := 1; a := foo_multivar(2.0>>s, 0, 0); _ = a", err: false},
{stmt: "s := 1; a := foo_multivar(0, 2.0>>s, 0); _ = a", err: true},
{stmt: "s := 1; a := foo_multivar(0, 0, 2.0>>s); _ = a", err: false},
{stmt: "a := foo_multivar(0, 0, 1.0>>2.0); _ = a", err: false},
{stmt: "a := foo_multivar(0, 1.0>>2.0, 0); _ = a", err: false},
{stmt: "s := 1; a := int(1) + 1.0>>s + int(float(1>>s)); _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 >> 2.0 >> 3.0 >> 4.0 >> s; _ = a", err: false},
{stmt: "s := 1; var a float = 1 >> 1 >> 1 >> 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 >> s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 >> s + foo_float_float(2); _ = a", err: true},
{stmt: "s := 1; a := 1.0 >> s + foo_float_int(2); _ = a", err: false},
{stmt: "s := 1; a := foo_float_int(1.0>>s) + foo_float_int(2); _ = a", err: true},
{stmt: "s := 1; a := foo_int_float(1>>s) + foo_int_float(2); _ = a", err: false},
{stmt: "s := 1; a := foo_int_int(1>>s) + foo_int_int(2); _ = a", err: false},
{stmt: "s := 1; t := 2.0; a := t + 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; t := 2; a := t + 1.0 >> s; _ = a", err: false},
{stmt: "s := 1; b := 1 >> s; _ = b", err: false},
{stmt: "var a = 1; b := a >> 2.0; _ = b", err: false},
{stmt: "s := 1; var a float; a = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true},
{stmt: "s := 1; var a int = int(1 >> s); _ = a", err: false},
{stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false},
{stmt: "s := 1; a := 1 >> s; _ = a", err: false},
{stmt: "s := 1; a := 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; a := int(1.0 >> s); _ = a", err: false},
{stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true},
{stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false},
{stmt: "s := 1; var a int = 1 >> s; _ = a", err: false},
{stmt: "var a float = 1.0 >> 2.0; _ = a", err: false},
{stmt: "var a int = 1.0 >> 2; _ = a", err: false},
{stmt: "var a float = 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1 >> 2.0; _ = a", err: false},
{stmt: "a := 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1.0 >> 2.0; _ = a", err: false},
{stmt: "a := 1 >> 2; _ = a", err: false},
{stmt: "a := float(1.0) >> 2; _ = a", err: true},
{stmt: "a := 1 >> float(2.0); _ = a", err: false},
{stmt: "a := 1.0 >> float(2.0); _ = a", err: false},
{stmt: "a := ivec2(1) >> 2; _ = a", err: false},
{stmt: "a := 1 >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true},
{stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true},
{stmt: "a := 1 >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> 2; _ = a", err: true},
{stmt: "a := float(1.0) >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> float(2.0); _ = a", err: true},
{stmt: "a := vec2(1) >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> vec3(2); _ = a", err: true},
{stmt: "a := vec3(1) >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true},
{stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true},
}
for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main
func foo_multivar(x int, y float, z int) int {return x}
func foo_int_int(x int) int {return x}
func foo_float_int(x float) int {return int(x)}
func foo_int_float(x int) float {return float(x)}
func foo_float_float(x float) float {return x}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s
return dstPos
}`, c.stmt)))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
}
}
}
func TestSyntaxOperatorShiftAssign(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "a := 1; a <<= 2; _ = a", err: false},
{stmt: "a := 1; a <<= 2.0; _ = a", err: false},
{stmt: "a := float(1.0); a <<= 2; _ = a", err: true},
{stmt: "a := 1; a <<= float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= 2; _ = a", err: false},
{stmt: "a := 1; a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= float(2.0); _ = a", err: true},
{stmt: "a := float(1.0); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= ivec3(2); _ = a", err: true},
{stmt: "a := 1; a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= 2; _ = a", err: true},
{stmt: "a := float(1.0); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= float(2.0); _ = a", err: true},
{stmt: "a := vec2(1); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= vec3(2); _ = a", err: true},
{stmt: "a := vec3(1); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec3(1); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= vec3(2); _ = a", err: true},
{stmt: "const c = 2; a := 1; a <<= c; _ = a", err: false},
{stmt: "const c = 2.0; a := 1; a <<= c; _ = a", err: false},
{stmt: "const c = 2; a := float(1.0); a <<= c; _ = a", err: true},
{stmt: "const c float = 2; a := 1; a <<= c; _ = a", err: true},
{stmt: "const c float = 2.0; a := 1; a <<= c; _ = a", err: true},
{stmt: "const c int = 2; a := ivec2(1); a <<= c; _ = a", err: false},
{stmt: "const c int = 2; a := vec2(1); a <<= c; _ = a", err: true},
{stmt: "a := 1; a >>= 2; _ = a", err: false},
{stmt: "a := 1; a >>= 2.0; _ = a", err: false},
{stmt: "a := float(1.0); a >>= 2; _ = a", err: true},
{stmt: "a := 1; a >>= float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= 2; _ = a", err: false},
{stmt: "a := 1; a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= float(2.0); _ = a", err: true},
{stmt: "a := float(1.0); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= ivec3(2); _ = a", err: true},
{stmt: "a := 1; a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= 2; _ = a", err: true},
{stmt: "a := float(1.0); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= float(2.0); _ = a", err: true},
{stmt: "a := vec2(1); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= vec3(2); _ = a", err: true},
{stmt: "a := vec3(1); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec3(1); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= vec3(2); _ = a", err: true},
{stmt: "const c = 2; a := 1; a >>= c; _ = a", err: false},
{stmt: "const c = 2.0; a := 1; a >>= c; _ = a", err: false},
{stmt: "const c = 2; a := float(1.0); a >>= c; _ = a", err: true},
{stmt: "const c float = 2; a := 1; a >>= c; _ = a", err: true},
{stmt: "const c float = 2.0; a := 1; a >>= c; _ = a", err: true},
{stmt: "const c int = 2; a := ivec2(1); a >>= c; _ = a", err: false},
{stmt: "const c int = 2; a := vec2(1); a >>= c; _ = a", err: true},
}
for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s
return dstPos
}`, c.stmt)))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
}
}
}
// Issue #1971 // Issue #1971
func TestSyntaxOperatorMultiplyAssign(t *testing.T) { func TestSyntaxOperatorMultiplyAssign(t *testing.T) {
cases := []struct { cases := []struct {

View File

@ -18,6 +18,29 @@ import (
"go/constant" "go/constant"
) )
func ResolveUntypedConstsForBitShiftOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
cLhs := lhs
cRhs := rhs
// Right is const -> int
if rhs != nil {
cRhs = constant.ToInt(rhs)
if cRhs.Kind() == constant.Unknown {
return nil, nil, false
}
}
// Left if untyped const -> int
if lhs != nil && lhst.Main == None {
cLhs = constant.ToInt(lhs)
if cLhs.Kind() == constant.Unknown {
return nil, nil, false
}
}
return cLhs, cRhs, true
}
func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
if lhst.Main == None && rhst.Main == None { if lhst.Main == None && rhst.Main == None {
if lhs.Kind() == rhs.Kind() { if lhs.Kind() == rhs.Kind() {
@ -121,6 +144,16 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
panic("shaderir: cannot resolve untyped values") panic("shaderir: cannot resolve untyped values")
} }
if op == LeftShift || op == RightShift {
if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int {
return lhst, true
}
if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() {
return lhst, true
}
return Type{}, false
}
if op == AndAnd || op == OrOr { if op == AndAnd || op == OrOr {
if lhst.Main == Bool && rhst.Main == Bool { if lhst.Main == Bool && rhst.Main == Bool {
return Type{Main: Bool}, true return Type{Main: Bool}, true

View File

@ -16,6 +16,7 @@
package shaderir package shaderir
import ( import (
"go/ast"
"go/constant" "go/constant"
"go/token" "go/token"
"sort" "sort"
@ -74,16 +75,17 @@ type Block struct {
} }
type Stmt struct { type Stmt struct {
Type StmtType Type StmtType
Exprs []Expr Exprs []Expr
Blocks []*Block Blocks []*Block
ForVarType Type ForVarType Type
ForVarIndex int ForVarIndex int
ForInit constant.Value ForInit constant.Value
ForEnd constant.Value ForEnd constant.Value
ForOp Op ForOp Op
ForDelta constant.Value ForDelta constant.Value
InitIndex int InitIndex int
IsTypeGuessed bool
} }
type StmtType int type StmtType int
@ -109,6 +111,7 @@ type Expr struct {
Swizzling string Swizzling string
Index int Index int
Op Op Op Op
Ast ast.Expr
} }
type ExprType int type ExprType int