Compare commits

...

8 Commits

Author SHA1 Message Date
aoyako
b0a7fb36c7 simplify type deduction in TypeFromBinaryOp for shift op 2024-03-21 17:28:32 +09:00
aoyako
706011c275 add: new binary shift operator rules 2024-03-21 17:21:11 +09:00
aoyako
49682f1097 Revert "add return type for type resolving"
This reverts commit 7f9d997175.
2024-03-21 17:14:32 +09:00
aoyako
8d15a459cf Revert "remove return type for deduced int"
This reverts commit 66a4b20bda.
2024-03-21 17:14:24 +09:00
aoyako
ced2d6ec8b Revert "add basic checks"
This reverts commit f44640778d.
2024-03-21 17:14:16 +09:00
aoyako
0651a09052 Revert "add shift type checks"
This reverts commit f02e9fd4d0.
2024-03-21 17:14:04 +09:00
aoyako
4f3e649bf0 Revert "update tests for right shift"
This reverts commit d1b9216ee1.
2024-03-21 17:13:34 +09:00
aoyako
29fb6c7f6f Revert "remove comment"
This reverts commit 359e7b8597.
2024-03-21 17:13:23 +09:00
7 changed files with 70 additions and 403 deletions

View File

@ -1,165 +0,0 @@
// Copyright 2024 The Ebiten Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shader
import (
"fmt"
gconstant "go/constant"
"github.com/hajimehoshi/ebiten/v2/internal/shaderir"
)
type delayedTypeValidator interface {
Validate(t shaderir.Type) (shaderir.Type, bool)
IsValidated() (shaderir.Type, bool)
Error() string
}
func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool {
return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F
}
func isIntType(t shaderir.Type) bool {
return t.Main == shaderir.Int || t.IsIntVector()
}
func (cs *compileState) ValidateDefaultTypesForExpr(block *block, expr shaderir.Expr, t shaderir.Type) shaderir.Type {
if check, ok := cs.delayedTypeCheks[expr.Ast]; ok {
if resT, ok := check.IsValidated(); ok {
return resT
}
resT, ok := check.Validate(t)
if !ok {
return shaderir.Type{Main: shaderir.None}
}
return resT
}
switch expr.Type {
case shaderir.LocalVariable:
return block.vars[expr.Index].typ
case shaderir.Binary:
left := expr.Exprs[0]
right := expr.Exprs[1]
leftType := cs.ValidateDefaultTypesForExpr(block, left, t)
rightType := cs.ValidateDefaultTypesForExpr(block, right, t)
// Usure about top-level type, try to validate by neighbour type
// The same work is done twice. Can it be optimized?
if t.Main == shaderir.None {
cs.ValidateDefaultTypesForExpr(block, left, rightType)
cs.ValidateDefaultTypesForExpr(block, right, leftType)
}
case shaderir.Call:
fun := expr.Exprs[0]
if fun.Type == shaderir.BuiltinFuncExpr {
if isArgDefaultTypeInt(fun.BuiltinFunc) {
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Int})
}
return shaderir.Type{Main: shaderir.Int}
}
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Float})
}
return shaderir.Type{Main: shaderir.Float}
}
if fun.Type == shaderir.FunctionExpr {
args := cs.funcs[fun.Index].ir.InParams
for i, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, args[i])
}
retT := cs.funcs[fun.Index].ir.Return
return retT
}
}
return shaderir.Type{Main: shaderir.None}
}
func (cs *compileState) ValidateDefaultTypes(block *block, stmt shaderir.Stmt) {
switch stmt.Type {
case shaderir.Assign:
left := stmt.Exprs[0]
right := stmt.Exprs[1]
if left.Type == shaderir.LocalVariable {
varType := block.vars[left.Index].typ
// Type is not explicitly specified
if stmt.IsTypeGuessed {
varType = shaderir.Type{Main: shaderir.None}
}
cs.ValidateDefaultTypesForExpr(block, right, varType)
}
case shaderir.ExprStmt:
for _, e := range stmt.Exprs {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.None})
}
}
}
type delayedShiftValidator struct {
shiftType shaderir.Op
value gconstant.Value
validated bool
closestUnknown bool
failed bool
}
func (d *delayedShiftValidator) IsValidated() (shaderir.Type, bool) {
if d.failed {
return shaderir.Type{}, false
}
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
// If only matched with None
if d.closestUnknown {
// Was it originally represented by an int constant?
if d.value.Kind() == gconstant.Int {
return shaderir.Type{Main: shaderir.Int}, true
}
}
return shaderir.Type{}, false
}
func (d *delayedShiftValidator) Validate(t shaderir.Type) (shaderir.Type, bool) {
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
if isIntType(t) {
d.validated = true
return shaderir.Type{Main: shaderir.Int}, true
}
if t.Main == shaderir.None {
d.closestUnknown = true
return t, true
}
return shaderir.Type{Main: shaderir.None}, false
}
func (d *delayedShiftValidator) Error() string {
st := "left shift"
if d.shiftType == shaderir.RightShift {
st = "right shift"
}
return fmt.Sprintf("left operand for %s should be int", st)
}

View File

@ -36,7 +36,7 @@ func canTruncateToFloat(v gconstant.Value) bool {
var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`)
func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) { func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) {
switch e := expr.(type) { switch e := expr.(type) {
case *ast.BasicLit: case *ast.BasicLit:
switch e.Kind { switch e.Kind {
@ -101,14 +101,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
// Resolve untyped constants. // Resolve untyped constants.
var l gconstant.Value l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(op2, lhs[0].Const, rhs[0].Const, lhst, rhst)
var r gconstant.Value
origLvalue := lhs[0].Const
if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
} else {
l, r, ok = shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
}
if !ok { if !ok {
// TODO: Show a better type name for untyped constants. // TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
@ -116,27 +109,6 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
lhs[0].Const, rhs[0].Const = l, r lhs[0].Const, rhs[0].Const = l, r
if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
if !(lhst.Main == shaderir.None && rhst.Main == shaderir.None) {
// If both are const
if rhs[0].Const != nil && (rhst.Main == shaderir.None || lhs[0].Const != nil) {
rhst = shaderir.Type{Main: shaderir.Int}
}
// If left is untyped const
if lhst.Main == shaderir.None && lhs[0].Const != nil {
lhst = shaderir.Type{Main: shaderir.Int}
// Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone.
if rhs[0].Const == nil {
defer func() {
if ok {
cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue})
}
}()
}
}
}
} else {
// If either is typed, resolve the other type. // If either is typed, resolve the other type.
// If both are untyped, keep them untyped. // If both are untyped, keep them untyped.
if lhst.Main != shaderir.None || rhst.Main != shaderir.None { if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
@ -161,7 +133,6 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
} }
} }
}
t, ok := shaderir.TypeFromBinaryOp(op2, lhst, rhst, lhs[0].Const, rhs[0].Const) t, ok := shaderir.TypeFromBinaryOp(op2, lhst, rhst, lhs[0].Const, rhs[0].Const)
if !ok { if !ok {
@ -201,7 +172,6 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
{ {
Type: shaderir.Binary, Type: shaderir.Binary,
Op: op2, Op: op2,
Ast: expr,
Exprs: []shaderir.Expr{lhs[0], rhs[0]}, Exprs: []shaderir.Expr{lhs[0], rhs[0]},
}, },
}, []shaderir.Type{t}, stmts, true }, []shaderir.Type{t}, stmts, true

View File

@ -61,8 +61,6 @@ type compileState struct {
varyingParsed bool varyingParsed bool
delayedTypeCheks map[ast.Expr]delayedTypeValidator
errs []string errs []string
} }
@ -84,13 +82,6 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) {
return 0, false return 0, false
} }
func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedTypeValidator) {
if cs.delayedTypeCheks == nil {
cs.delayedTypeCheks = make(map[ast.Expr]delayedTypeValidator, 1)
}
cs.delayedTypeCheks[at] = check
}
type typ struct { type typ struct {
name string name string
ir shaderir.Type ir shaderir.Type

View File

@ -49,10 +49,6 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
if !ok { if !ok {
return nil, false return nil, false
} }
for i := range ss {
ss[i].IsTypeGuessed = true
}
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
case token.ASSIGN: case token.ASSIGN:
if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 { if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 {
@ -477,25 +473,6 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt)) cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt))
return nil, false return nil, false
} }
// Need to run delayed checks
if len(cs.delayedTypeCheks) != 0 {
for _, st := range stmts {
cs.ValidateDefaultTypes(block, st)
}
// Collect all errors first
foundErr := false
for s, v := range cs.delayedTypeCheks {
if _, ok := v.IsValidated(); !ok {
foundErr = true
cs.addError(s.Pos(), v.Error())
}
}
if foundErr {
return nil, false
}
}
return stmts, true return stmts, true
} }

View File

@ -1320,62 +1320,14 @@ func TestSyntaxOperatorShift(t *testing.T) {
stmt string stmt string
err bool err bool
}{ }{
{stmt: "b := 2.0; a := 1.0 << 2.0 == b; _ = a", err: false}, {stmt: "a := 1 << 2; _ = a", err: false},
{stmt: "s := 1; b := 2.0; a := 1.0<<s == b; _ = a", err: true},
{stmt: "s := 1; b := 2; a := 1.0<<s == b; _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + ivec2(3.0<<s); _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + vec2(3); _ = a", err: true},
{stmt: "s := 1; a := 2.0<<s + ivec2(3); _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + foo_int_int(3.0<<s); _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + 3.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 2<<s + 3.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 2.0<<s + 3<<s; _ = a", err: true},
{stmt: "s := 1; a := 2<<s + 3<<s; _ = a", err: false},
{stmt: "s := 1; foo_multivar(0, 0, 2<<s)", err: false},
{stmt: "s := 1; foo_multivar(0, 2.0<<s, 0)", err: true},
{stmt: "s := 1; foo_multivar(2.0<<s, 0, 0)", err: false},
{stmt: "s := 1; a := foo_multivar(2.0<<s, 0, 0); _ = a", err: false},
{stmt: "s := 1; a := foo_multivar(0, 2.0<<s, 0); _ = a", err: true},
{stmt: "s := 1; a := foo_multivar(0, 0, 2.0<<s); _ = a", err: false},
{stmt: "a := foo_multivar(0, 0, 1.0<<2.0); _ = a", err: false},
{stmt: "a := foo_multivar(0, 1.0<<2.0, 0); _ = a", err: false},
{stmt: "s := 1; a := int(1) + 1.0<<s + int(float(1<<s)); _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 << 2.0 << 3.0 << 4.0 << s; _ = a", err: false},
{stmt: "s := 1; var a float = 1 << 1 << 1 << 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 << s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 << s + foo_float_float(2); _ = a", err: true},
{stmt: "s := 1; a := 1.0 << s + foo_float_int(2); _ = a", err: false},
{stmt: "s := 1; a := foo_float_int(1.0<<s) + foo_float_int(2); _ = a", err: true},
{stmt: "s := 1; a := foo_int_float(1<<s) + foo_int_float(2); _ = a", err: false},
{stmt: "s := 1; a := foo_int_int(1<<s) + foo_int_int(2); _ = a", err: false},
{stmt: "s := 1; t := 2.0; a := t + 1.0 << s; _ = a", err: true},
{stmt: "s := 1; t := 2; a := t + 1.0 << s; _ = a", err: false},
{stmt: "s := 1; b := 1 << s; _ = b", err: false},
{stmt: "var a = 1; b := a << 2.0; _ = b", err: false},
{stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
{stmt: "s := 1; var a int = int(1 << s); _ = a", err: false},
{stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false},
{stmt: "s := 1; a := 1 << s; _ = a", err: false},
{stmt: "s := 1; a := 1.0 << s; _ = a", err: true},
{stmt: "s := 1; a := int(1.0 << s); _ = a", err: false},
{stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
{stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false},
{stmt: "s := 1; var a int = 1 << s; _ = a", err: false},
{stmt: "var a float = 1.0 << 2.0; _ = a", err: false},
{stmt: "var a int = 1.0 << 2; _ = a", err: false},
{stmt: "var a float = 1.0 << 2; _ = a", err: false},
{stmt: "a := 1 << 2.0; _ = a", err: false}, {stmt: "a := 1 << 2.0; _ = a", err: false},
{stmt: "a := 1.0 << 2; _ = a", err: false}, {stmt: "a := 1.0 << 2; _ = a", err: false},
{stmt: "a := 1.0 << 2.0; _ = a", err: false}, {stmt: "a := 1.0 << 2.0; _ = a", err: false},
{stmt: "a := 1 << 2; _ = a", err: false}, {stmt: "var a = 1; b := a << 2.0; _ = b", err: false},
{stmt: "var a = 1; b := 2.0 << a; _ = b", err: false}, // PR: #2916
{stmt: "a := float(1.0) << 2; _ = a", err: true}, {stmt: "a := float(1.0) << 2; _ = a", err: true},
{stmt: "a := 1 << float(2.0); _ = a", err: false}, {stmt: "a := 1 << float(2.0); _ = a", err: true},
{stmt: "a := 1.0 << float(2.0); _ = a", err: false},
{stmt: "a := ivec2(1) << 2; _ = a", err: false}, {stmt: "a := ivec2(1) << 2; _ = a", err: false},
{stmt: "a := 1 << ivec2(2); _ = a", err: true}, {stmt: "a := 1 << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << float(2.0); _ = a", err: true}, {stmt: "a := ivec2(1) << float(2.0); _ = a", err: true},
@ -1395,62 +1347,14 @@ func TestSyntaxOperatorShift(t *testing.T) {
{stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true},
{stmt: "b := 2.0; a := 1.0 >> 2.0 == b; _ = a", err: false}, {stmt: "a := 1 >> 2; _ = a", err: false},
{stmt: "s := 1; b := 2.0; a := 1.0>>s == b; _ = a", err: true},
{stmt: "s := 1; b := 2; a := 1.0>>s == b; _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + ivec2(3.0>>s); _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + vec2(3); _ = a", err: true},
{stmt: "s := 1; a := 2.0>>s + ivec2(3); _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + foo_int_int(3.0>>s); _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + 3.0>>s; _ = a", err: true},
{stmt: "s := 1; a := 2>>s + 3.0>>s; _ = a", err: true},
{stmt: "s := 1; a := 2.0>>s + 3>>s; _ = a", err: true},
{stmt: "s := 1; a := 2>>s + 3>>s; _ = a", err: false},
{stmt: "s := 1; foo_multivar(0, 0, 2>>s)", err: false},
{stmt: "s := 1; foo_multivar(0, 2.0>>s, 0)", err: true},
{stmt: "s := 1; foo_multivar(2.0>>s, 0, 0)", err: false},
{stmt: "s := 1; a := foo_multivar(2.0>>s, 0, 0); _ = a", err: false},
{stmt: "s := 1; a := foo_multivar(0, 2.0>>s, 0); _ = a", err: true},
{stmt: "s := 1; a := foo_multivar(0, 0, 2.0>>s); _ = a", err: false},
{stmt: "a := foo_multivar(0, 0, 1.0>>2.0); _ = a", err: false},
{stmt: "a := foo_multivar(0, 1.0>>2.0, 0); _ = a", err: false},
{stmt: "s := 1; a := int(1) + 1.0>>s + int(float(1>>s)); _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 >> 2.0 >> 3.0 >> 4.0 >> s; _ = a", err: false},
{stmt: "s := 1; var a float = 1 >> 1 >> 1 >> 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 >> s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 >> s + foo_float_float(2); _ = a", err: true},
{stmt: "s := 1; a := 1.0 >> s + foo_float_int(2); _ = a", err: false},
{stmt: "s := 1; a := foo_float_int(1.0>>s) + foo_float_int(2); _ = a", err: true},
{stmt: "s := 1; a := foo_int_float(1>>s) + foo_int_float(2); _ = a", err: false},
{stmt: "s := 1; a := foo_int_int(1>>s) + foo_int_int(2); _ = a", err: false},
{stmt: "s := 1; t := 2.0; a := t + 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; t := 2; a := t + 1.0 >> s; _ = a", err: false},
{stmt: "s := 1; b := 1 >> s; _ = b", err: false},
{stmt: "var a = 1; b := a >> 2.0; _ = b", err: false},
{stmt: "s := 1; var a float; a = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true},
{stmt: "s := 1; var a int = int(1 >> s); _ = a", err: false},
{stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false},
{stmt: "s := 1; a := 1 >> s; _ = a", err: false},
{stmt: "s := 1; a := 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; a := int(1.0 >> s); _ = a", err: false},
{stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true},
{stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false},
{stmt: "s := 1; var a int = 1 >> s; _ = a", err: false},
{stmt: "var a float = 1.0 >> 2.0; _ = a", err: false},
{stmt: "var a int = 1.0 >> 2; _ = a", err: false},
{stmt: "var a float = 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1 >> 2.0; _ = a", err: false}, {stmt: "a := 1 >> 2.0; _ = a", err: false},
{stmt: "a := 1.0 >> 2; _ = a", err: false}, {stmt: "a := 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1.0 >> 2.0; _ = a", err: false}, {stmt: "a := 1.0 >> 2.0; _ = a", err: false},
{stmt: "a := 1 >> 2; _ = a", err: false}, {stmt: "var a = 1; b := a >> 2.0; _ = b", err: false},
{stmt: "var a = 1; b := 2.0 >> a; _ = b", err: false}, // PR: #2916
{stmt: "a := float(1.0) >> 2; _ = a", err: true}, {stmt: "a := float(1.0) >> 2; _ = a", err: true},
{stmt: "a := 1 >> float(2.0); _ = a", err: false}, {stmt: "a := 1 >> float(2.0); _ = a", err: true},
{stmt: "a := 1.0 >> float(2.0); _ = a", err: false},
{stmt: "a := ivec2(1) >> 2; _ = a", err: false}, {stmt: "a := ivec2(1) >> 2; _ = a", err: false},
{stmt: "a := 1 >> ivec2(2); _ = a", err: true}, {stmt: "a := 1 >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true},
@ -1473,11 +1377,7 @@ func TestSyntaxOperatorShift(t *testing.T) {
for _, c := range cases { for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main _, err := compileToIR([]byte(fmt.Sprintf(`package main
func foo_multivar(x int, y float, z int) int {return x}
func foo_int_int(x int) int {return x}
func foo_float_int(x float) int {return int(x)}
func foo_int_float(x int) float {return float(x)}
func foo_float_float(x float) float {return x}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s %s
return dstPos return dstPos

View File

@ -18,31 +18,21 @@ import (
"go/constant" "go/constant"
) )
func ResolveUntypedConstsForBitShiftOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { func ResolveUntypedConstsForBinaryOp(op Op, lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
cLhs := lhs
cRhs := rhs
// Right is const -> int
if rhs != nil {
cRhs = constant.ToInt(rhs)
if cRhs.Kind() == constant.Unknown {
return nil, nil, false
}
}
// Left if untyped const -> int
if lhs != nil && lhst.Main == None {
cLhs = constant.ToInt(lhs)
if cLhs.Kind() == constant.Unknown {
return nil, nil, false
}
}
return cLhs, cRhs, true
}
func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
if lhst.Main == None && rhst.Main == None { if lhst.Main == None && rhst.Main == None {
if op == LeftShift || op == RightShift {
newLhs = constant.ToInt(lhs)
newRhs = constant.ToInt(rhs)
if newLhs.Kind() == constant.Unknown {
return nil, nil, false
}
if newRhs.Kind() == constant.Unknown {
return nil, nil, false
}
return newLhs, newRhs, true
}
if lhs.Kind() == rhs.Kind() { if lhs.Kind() == rhs.Kind() {
return lhs, rhs, true return lhs, rhs, true
} }
@ -121,6 +111,13 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
return Type{}, false return Type{}, false
} }
if op == LeftShift || op == RightShift {
if lhsConst.Kind() == constant.Int && rhsConst.Kind() == constant.Int {
return Type{Main: Int}, true
}
return Type{}, false
}
if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp { if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp {
return Type{Main: Bool}, true return Type{Main: Bool}, true
} }
@ -144,16 +141,6 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
panic("shaderir: cannot resolve untyped values") panic("shaderir: cannot resolve untyped values")
} }
if op == LeftShift || op == RightShift {
if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int {
return lhst, true
}
if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() {
return lhst, true
}
return Type{}, false
}
if op == AndAnd || op == OrOr { if op == AndAnd || op == OrOr {
if lhst.Main == Bool && rhst.Main == Bool { if lhst.Main == Bool && rhst.Main == Bool {
return Type{Main: Bool}, true return Type{Main: Bool}, true
@ -228,6 +215,16 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
return Type{}, false return Type{}, false
} }
if op == LeftShift || op == RightShift {
if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int {
return lhst, true
}
if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() {
return lhst, true
}
return Type{}, false
}
if lhst.Equal(&rhst) { if lhst.Equal(&rhst) {
if lhst.Main == None { if lhst.Main == None {
return rhst, true return rhst, true

View File

@ -16,7 +16,6 @@
package shaderir package shaderir
import ( import (
"go/ast"
"go/constant" "go/constant"
"go/token" "go/token"
"sort" "sort"
@ -85,7 +84,6 @@ type Stmt struct {
ForOp Op ForOp Op
ForDelta constant.Value ForDelta constant.Value
InitIndex int InitIndex int
IsTypeGuessed bool
} }
type StmtType int type StmtType int
@ -111,7 +109,6 @@ type Expr struct {
Swizzling string Swizzling string
Index int Index int
Op Op Op Op
Ast ast.Expr
} }
type ExprType int type ExprType int