Compare commits

...

21 Commits

Author SHA1 Message Date
Mykhailo Lohachov
0ddc018190
Merge b0a7fb36c7 into 4b1c0526a7 2024-03-21 08:29:08 +00:00
aoyako
b0a7fb36c7 simplify type deduction in TypeFromBinaryOp for shift op 2024-03-21 17:28:32 +09:00
aoyako
706011c275 add: new binary shift operator rules 2024-03-21 17:21:11 +09:00
aoyako
49682f1097 Revert "add return type for type resolving"
This reverts commit 7f9d997175.
2024-03-21 17:14:32 +09:00
aoyako
8d15a459cf Revert "remove return type for deduced int"
This reverts commit 66a4b20bda.
2024-03-21 17:14:24 +09:00
aoyako
ced2d6ec8b Revert "add basic checks"
This reverts commit f44640778d.
2024-03-21 17:14:16 +09:00
aoyako
0651a09052 Revert "add shift type checks"
This reverts commit f02e9fd4d0.
2024-03-21 17:14:04 +09:00
aoyako
4f3e649bf0 Revert "update tests for right shift"
This reverts commit d1b9216ee1.
2024-03-21 17:13:34 +09:00
aoyako
29fb6c7f6f Revert "remove comment"
This reverts commit 359e7b8597.
2024-03-21 17:13:23 +09:00
aoyako
359e7b8597 remove comment 2024-03-02 15:41:30 +09:00
aoyako
d1b9216ee1 update tests for right shift 2024-03-02 15:40:38 +09:00
aoyako
f02e9fd4d0 add shift type checks 2024-03-02 15:32:31 +09:00
aoyako
f44640778d add basic checks 2024-02-28 20:27:26 +09:00
aoyako
66a4b20bda remove return type for deduced int 2024-02-27 19:39:14 +09:00
aoyako
7f9d997175 add return type for type resolving 2024-02-27 19:29:39 +09:00
aoyako
c90d02f8d4 add: float->int cast tests 2024-02-26 21:06:28 +09:00
aoyako
2b7d20e7da fix: remove unnecessary branch 2024-02-26 18:17:41 +09:00
aoyako
d69bb04a56 add support for shift + assign 2024-02-26 18:00:52 +09:00
aoyako
5f61cf00e5 extend tests with right-shift op 2024-02-26 17:03:19 +09:00
aoyako
7f01f98200 add tests for binop shift 2024-02-26 17:02:02 +09:00
aoyako
fe887e2565 add typechecks for bitshifts ops 2024-02-26 16:14:55 +09:00
4 changed files with 204 additions and 6 deletions

View File

@ -105,7 +105,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
// Resolve untyped constants.
l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(op2, lhs[0].Const, rhs[0].Const, lhst, rhst)
if !ok {
// TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
@ -153,6 +153,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
v = gconstant.MakeBool(b)
case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ:
v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const))
case token.SHL, token.SHR:
shift, ok := gconstant.Int64Val(rhs[0].Const)
if !ok {
cs.addError(e.Pos(), fmt.Sprintf("unexpected %s type for: %s", rhs[0].Const.String(), e.Op))
} else {
v = gconstant.Shift(lhs[0].Const, op, uint(shift))
}
default:
v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)
}

View File

@ -60,7 +60,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false
}
stmts = append(stmts, ss...)
case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN:
case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN:
rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true)
if !ok {
return nil, false
@ -100,6 +100,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
op = shaderir.Or
case token.XOR_ASSIGN:
op = shaderir.Xor
case token.SHL_ASSIGN:
op = shaderir.LeftShift
case token.SHR_ASSIGN:
op = shaderir.RightShift
default:
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok))
return nil, false
@ -110,7 +114,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator / not defined on %s", rts[0].String()))
return nil, false
}
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor {
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift {
if lts[0].Main != shaderir.Int && !lts[0].IsIntVector() {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String()))
}
@ -137,7 +141,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
}
}
case shaderir.Float:
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor {
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String()))
} else if rhs[0].Const != nil &&
(rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) &&
@ -148,7 +152,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false
}
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor {
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String()))
} else if (op == shaderir.MatrixMul || op == shaderir.Div) &&
(rts[0].Main == shaderir.Float ||

View File

@ -1314,6 +1314,163 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
}
}
// Issue: #2755
func TestSyntaxOperatorShift(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "a := 1 << 2; _ = a", err: false},
{stmt: "a := 1 << 2.0; _ = a", err: false},
{stmt: "a := 1.0 << 2; _ = a", err: false},
{stmt: "a := 1.0 << 2.0; _ = a", err: false},
{stmt: "var a = 1; b := a << 2.0; _ = b", err: false},
{stmt: "var a = 1; b := 2.0 << a; _ = b", err: false}, // PR: #2916
{stmt: "a := float(1.0) << 2; _ = a", err: true},
{stmt: "a := 1 << float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1) << 2; _ = a", err: false},
{stmt: "a := 1 << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << float(2.0); _ = a", err: true},
{stmt: "a := float(1.0) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << ivec3(2); _ = a", err: true},
{stmt: "a := 1 << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << 2; _ = a", err: true},
{stmt: "a := float(1.0) << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << float(2.0); _ = a", err: true},
{stmt: "a := vec2(1) << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << vec3(2); _ = a", err: true},
{stmt: "a := vec3(1) << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << vec2(2); _ = a", err: true},
{stmt: "a := vec3(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << vec3(2); _ = a", err: true},
{stmt: "a := 1 >> 2; _ = a", err: false},
{stmt: "a := 1 >> 2.0; _ = a", err: false},
{stmt: "a := 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1.0 >> 2.0; _ = a", err: false},
{stmt: "var a = 1; b := a >> 2.0; _ = b", err: false},
{stmt: "var a = 1; b := 2.0 >> a; _ = b", err: false}, // PR: #2916
{stmt: "a := float(1.0) >> 2; _ = a", err: true},
{stmt: "a := 1 >> float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1) >> 2; _ = a", err: false},
{stmt: "a := 1 >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true},
{stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true},
{stmt: "a := 1 >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> 2; _ = a", err: true},
{stmt: "a := float(1.0) >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> float(2.0); _ = a", err: true},
{stmt: "a := vec2(1) >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> vec3(2); _ = a", err: true},
{stmt: "a := vec3(1) >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true},
{stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true},
}
for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s
return dstPos
}`, c.stmt)))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
}
}
}
func TestSyntaxOperatorShiftAssign(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "a := 1; a <<= 2; _ = a", err: false},
{stmt: "a := 1; a <<= 2.0; _ = a", err: false},
{stmt: "a := float(1.0); a <<= 2; _ = a", err: true},
{stmt: "a := 1; a <<= float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= 2; _ = a", err: false},
{stmt: "a := 1; a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= float(2.0); _ = a", err: true},
{stmt: "a := float(1.0); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= ivec3(2); _ = a", err: true},
{stmt: "a := 1; a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= 2; _ = a", err: true},
{stmt: "a := float(1.0); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= float(2.0); _ = a", err: true},
{stmt: "a := vec2(1); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= vec3(2); _ = a", err: true},
{stmt: "a := vec3(1); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec3(1); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= vec3(2); _ = a", err: true},
{stmt: "const c = 2; a := 1; a <<= c; _ = a", err: false},
{stmt: "const c = 2.0; a := 1; a <<= c; _ = a", err: false},
{stmt: "const c = 2; a := float(1.0); a <<= c; _ = a", err: true},
{stmt: "const c float = 2; a := 1; a <<= c; _ = a", err: true},
{stmt: "const c float = 2.0; a := 1; a <<= c; _ = a", err: true},
{stmt: "const c int = 2; a := ivec2(1); a <<= c; _ = a", err: false},
{stmt: "const c int = 2; a := vec2(1); a <<= c; _ = a", err: true},
{stmt: "a := 1; a >>= 2; _ = a", err: false},
{stmt: "a := 1; a >>= 2.0; _ = a", err: false},
{stmt: "a := float(1.0); a >>= 2; _ = a", err: true},
{stmt: "a := 1; a >>= float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= 2; _ = a", err: false},
{stmt: "a := 1; a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= float(2.0); _ = a", err: true},
{stmt: "a := float(1.0); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= ivec3(2); _ = a", err: true},
{stmt: "a := 1; a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= 2; _ = a", err: true},
{stmt: "a := float(1.0); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= float(2.0); _ = a", err: true},
{stmt: "a := vec2(1); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= vec3(2); _ = a", err: true},
{stmt: "a := vec3(1); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec3(1); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= vec3(2); _ = a", err: true},
{stmt: "const c = 2; a := 1; a >>= c; _ = a", err: false},
{stmt: "const c = 2.0; a := 1; a >>= c; _ = a", err: false},
{stmt: "const c = 2; a := float(1.0); a >>= c; _ = a", err: true},
{stmt: "const c float = 2; a := 1; a >>= c; _ = a", err: true},
{stmt: "const c float = 2.0; a := 1; a >>= c; _ = a", err: true},
{stmt: "const c int = 2; a := ivec2(1); a >>= c; _ = a", err: false},
{stmt: "const c int = 2; a := vec2(1); a >>= c; _ = a", err: true},
}
for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s
return dstPos
}`, c.stmt)))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
}
}
}
// Issue #1971
func TestSyntaxOperatorMultiplyAssign(t *testing.T) {
cases := []struct {

View File

@ -18,8 +18,21 @@ import (
"go/constant"
)
func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
func ResolveUntypedConstsForBinaryOp(op Op, lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
if lhst.Main == None && rhst.Main == None {
if op == LeftShift || op == RightShift {
newLhs = constant.ToInt(lhs)
newRhs = constant.ToInt(rhs)
if newLhs.Kind() == constant.Unknown {
return nil, nil, false
}
if newRhs.Kind() == constant.Unknown {
return nil, nil, false
}
return newLhs, newRhs, true
}
if lhs.Kind() == rhs.Kind() {
return lhs, rhs, true
}
@ -98,6 +111,13 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
return Type{}, false
}
if op == LeftShift || op == RightShift {
if lhsConst.Kind() == constant.Int && rhsConst.Kind() == constant.Int {
return Type{Main: Int}, true
}
return Type{}, false
}
if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp {
return Type{Main: Bool}, true
}
@ -195,6 +215,16 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
return Type{}, false
}
if op == LeftShift || op == RightShift {
if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int {
return lhst, true
}
if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() {
return lhst, true
}
return Type{}, false
}
if lhst.Equal(&rhst) {
if lhst.Main == None {
return rhst, true