From fe887e2565e70a28c63972780874f94c844f56fd Mon Sep 17 00:00:00 2001 From: aoyako Date: Sun, 25 Feb 2024 21:36:55 +0900 Subject: [PATCH 01/20] add typechecks for bitshifts ops --- internal/shaderir/check.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index 857bd1e2a..b6d8b08bb 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -173,7 +173,7 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } - if op == And || op == Or || op == Xor { + if op == And || op == Or || op == Xor || op == LeftShift || op == RightShift { if lhst.Main == Int && rhst.Main == Int { return Type{Main: Int}, true } @@ -190,7 +190,9 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return lhst, true } if lhst.Main == Int && rhst.IsIntVector() { - return rhst, true + if op == And || op == Or || op == Xor { + return rhst, true + } } return Type{}, false } From 7f01f9820064d5ac0644fc293a9f31101de76fea Mon Sep 17 00:00:00 2001 From: aoyako Date: Mon, 26 Feb 2024 17:02:02 +0900 Subject: [PATCH 02/20] add tests for binop shift --- internal/shader/expr.go | 7 ++++++ internal/shader/syntax_test.go | 44 ++++++++++++++++++++++++++++++++++ internal/shaderir/check.go | 32 +++++++++++++++++++++---- 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index b57c28db5..cc41cea32 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -149,6 +149,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar v = gconstant.MakeBool(b) case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ: v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const)) + case token.SHL, token.SHR: + shift, ok := gconstant.Int64Val(rhs[0].Const) + if !ok { + cs.addError(e.Pos(), fmt.Sprintf("unexpected %s type for: %s", rhs[0].Const.String(), e.Op)) + } else { + v = gconstant.Shift(lhs[0].Const, op, uint(shift)) + } default: v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const) } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 68ee4b5e8..8ed19be6f 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1314,6 +1314,50 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { } } +// Issue: #2755 +func TestSyntaxOperatorShift(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := 1 << 2; _ = a", err: false}, + {stmt: "a := float(1.0) << 2; _ = a", err: true}, + {stmt: "a := 1 << float(2.0); _ = a", err: true}, + {stmt: "a := ivec2(1) << 2; _ = a", err: false}, + {stmt: "a := 1 << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << ivec3(2); _ = a", err: true}, + {stmt: "a := 1 << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << 2; _ = a", err: true}, + {stmt: "a := float(1.0) << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1) << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1) << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, + } + + for _, c := range cases { + _, err := compileToIR([]byte(fmt.Sprintf(`package main + +func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + %s + return dstPos +}`, c.stmt))) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.stmt, err) + } + } +} + // Issue #1971 func TestSyntaxOperatorMultiplyAssign(t *testing.T) { cases := []struct { diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index b6d8b08bb..cb93f3b79 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -98,6 +98,13 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } + if op == LeftShift || op == RightShift { + if lhsConst.Kind() == constant.Int && rhsConst.Kind() == constant.Int { + return Type{Main: Int}, true + } + return Type{}, false + } + if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp { return Type{Main: Bool}, true } @@ -173,7 +180,7 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } - if op == And || op == Or || op == Xor || op == LeftShift || op == RightShift { + if op == And || op == Or || op == Xor { if lhst.Main == Int && rhst.Main == Int { return Type{Main: Int}, true } @@ -190,9 +197,26 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return lhst, true } if lhst.Main == Int && rhst.IsIntVector() { - if op == And || op == Or || op == Xor { - return rhst, true - } + return rhst, true + } + return Type{}, false + } + + if op == LeftShift || op == RightShift { + if lhst.Main == Int && rhst.Main == Int { + return Type{Main: Int}, true + } + if lhst.Main == IVec2 && rhst.Main == IVec2 { + return Type{Main: IVec2}, true + } + if lhst.Main == IVec3 && rhst.Main == IVec3 { + return Type{Main: IVec3}, true + } + if lhst.Main == IVec4 && rhst.Main == IVec4 { + return Type{Main: IVec4}, true + } + if lhst.IsIntVector() && rhst.Main == Int { + return lhst, true } return Type{}, false } From 5f61cf00e5443a2cbf74deeb50457ee4ac7f3482 Mon Sep 17 00:00:00 2001 From: aoyako Date: Mon, 26 Feb 2024 17:03:19 +0900 Subject: [PATCH 03/20] extend tests with right-shift op --- internal/shader/syntax_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 8ed19be6f..dc2a2b273 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1341,6 +1341,27 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := ivec2(1) << vec2(2); _ = a", err: true}, {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, + {stmt: "a := 1 >> 2; _ = a", err: false}, + {stmt: "a := float(1.0) >> 2; _ = a", err: true}, + {stmt: "a := 1 >> float(2.0); _ = a", err: true}, + {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, + {stmt: "a := 1 >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true}, + {stmt: "a := 1 >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> 2; _ = a", err: true}, + {stmt: "a := float(1.0) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true}, } for _, c := range cases { From d69bb04a561006db3b30aba26e862adcc564dab0 Mon Sep 17 00:00:00 2001 From: aoyako Date: Mon, 26 Feb 2024 18:00:52 +0900 Subject: [PATCH 04/20] add support for shift + assign --- internal/shader/stmt.go | 19 +++++++-- internal/shader/syntax_test.go | 78 ++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 42f0ddd13..ef1745853 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -60,7 +60,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return nil, false } stmts = append(stmts, ss...) - case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN: + case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN: rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true) if !ok { return nil, false @@ -100,6 +100,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP op = shaderir.Or case token.XOR_ASSIGN: op = shaderir.Xor + case token.SHL_ASSIGN: + op = shaderir.LeftShift + case token.SHR_ASSIGN: + op = shaderir.RightShift default: cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok)) return nil, false @@ -119,6 +123,15 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } return nil, false } + if op == shaderir.LeftShift || op == shaderir.RightShift { + if lts[0].Main != shaderir.Int && !lts[0].IsIntVector() { + cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) + } + if rts[0].Main != shaderir.Int && !rts[0].IsIntVector() { + cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, rts[0].String())) + } + return nil, false + } if lts[0].Main == shaderir.Int && rhs[0].Const != nil { if !cs.forceToInt(stmt, &rhs[0]) { return nil, false @@ -137,7 +150,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } } case shaderir.Float: - if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { + if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift { cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) } else if rhs[0].Const != nil && (rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) && @@ -148,7 +161,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return nil, false } case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { + if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift { cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) } else if (op == shaderir.MatrixMul || op == shaderir.Div) && (rts[0].Main == shaderir.Float || diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index dc2a2b273..c32544882 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1341,6 +1341,7 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := ivec2(1) << vec2(2); _ = a", err: true}, {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, + {stmt: "a := 1 >> 2; _ = a", err: false}, {stmt: "a := float(1.0) >> 2; _ = a", err: true}, {stmt: "a := 1 >> float(2.0); _ = a", err: true}, @@ -1379,6 +1380,83 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { } } +func TestSyntaxOperatorShiftAssign(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := 1; a <<= 2; _ = a", err: false}, + {stmt: "a := float(1.0); a <<= 2; _ = a", err: true}, + {stmt: "a := 1; a <<= float(2.0); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= 2; _ = a", err: false}, + {stmt: "a := 1; a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0); a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1); a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= ivec3(2); _ = a", err: true}, + {stmt: "a := 1; a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= 2; _ = a", err: true}, + {stmt: "a := float(1.0); a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1); a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1); a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= vec3(2); _ = a", err: true}, + {stmt: "const c = 2; a := 1; a <<= c; _ = a", err: false}, + {stmt: "const c = 2; a := float(1.0); a <<= c; _ = a", err: true}, + {stmt: "const c float = 2; a := 1; a <<= c; _ = a", err: true}, + {stmt: "const c float = 2.0; a := 1; a <<= c; _ = a", err: true}, + {stmt: "const c int = 2; a := ivec2(1); a <<= c; _ = a", err: false}, + {stmt: "const c int = 2; a := vec2(1); a <<= c; _ = a", err: true}, + + {stmt: "a := 1; a >>= 2; _ = a", err: false}, + {stmt: "a := float(1.0); a >>= 2; _ = a", err: true}, + {stmt: "a := 1; a >>= float(2.0); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= 2; _ = a", err: false}, + {stmt: "a := 1; a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0); a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1); a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= ivec3(2); _ = a", err: true}, + {stmt: "a := 1; a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= 2; _ = a", err: true}, + {stmt: "a := float(1.0); a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1); a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1); a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= vec3(2); _ = a", err: true}, + {stmt: "const c = 2; a := 1; a >>= c; _ = a", err: false}, + {stmt: "const c = 2; a := float(1.0); a >>= c; _ = a", err: true}, + {stmt: "const c float = 2; a := 1; a >>= c; _ = a", err: true}, + {stmt: "const c float = 2.0; a := 1; a >>= c; _ = a", err: true}, + {stmt: "const c int = 2; a := ivec2(1); a >>= c; _ = a", err: false}, + {stmt: "const c int = 2; a := vec2(1); a >>= c; _ = a", err: true}, + } + + for _, c := range cases { + _, err := compileToIR([]byte(fmt.Sprintf(`package main + +func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + %s + return dstPos +}`, c.stmt))) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.stmt, err) + } + } +} + // Issue #1971 func TestSyntaxOperatorMultiplyAssign(t *testing.T) { cases := []struct { From 2b7d20e7da1680ec717e3b3ac8a2196b1d11c8ad Mon Sep 17 00:00:00 2001 From: aoyako Date: Mon, 26 Feb 2024 18:17:41 +0900 Subject: [PATCH 05/20] fix: remove unnecessary branch --- internal/shader/stmt.go | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index ef1745853..6c07057da 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -114,16 +114,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator / not defined on %s", rts[0].String())) return nil, false } - if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { - if lts[0].Main != shaderir.Int && !lts[0].IsIntVector() { - cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) - } - if rts[0].Main != shaderir.Int && !rts[0].IsIntVector() { - cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, rts[0].String())) - } - return nil, false - } - if op == shaderir.LeftShift || op == shaderir.RightShift { + if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift { if lts[0].Main != shaderir.Int && !lts[0].IsIntVector() { cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) } From c90d02f8d408a028c3527b0acafd32e3b3f9bced Mon Sep 17 00:00:00 2001 From: aoyako Date: Mon, 26 Feb 2024 21:06:28 +0900 Subject: [PATCH 06/20] add: float->int cast tests --- internal/shader/syntax_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index c32544882..e50ff7d4a 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,6 +1320,9 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ + {stmt: "a := 1 << 2.0; _ = a", err: true}, + {stmt: "a := 1.0 << 2; _ = a", err: true}, + {stmt: "a := 1.0 << 2.0; _ = a", err: true}, {stmt: "a := 1 << 2; _ = a", err: false}, {stmt: "a := float(1.0) << 2; _ = a", err: true}, {stmt: "a := 1 << float(2.0); _ = a", err: true}, @@ -1386,6 +1389,7 @@ func TestSyntaxOperatorShiftAssign(t *testing.T) { err bool }{ {stmt: "a := 1; a <<= 2; _ = a", err: false}, + {stmt: "a := 1; a <<= 2.0; _ = a", err: false}, {stmt: "a := float(1.0); a <<= 2; _ = a", err: true}, {stmt: "a := 1; a <<= float(2.0); _ = a", err: true}, {stmt: "a := ivec2(1); a <<= 2; _ = a", err: false}, @@ -1407,6 +1411,7 @@ func TestSyntaxOperatorShiftAssign(t *testing.T) { {stmt: "a := vec3(1); a <<= ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1); a <<= vec3(2); _ = a", err: true}, {stmt: "const c = 2; a := 1; a <<= c; _ = a", err: false}, + {stmt: "const c = 2.0; a := 1; a <<= c; _ = a", err: false}, {stmt: "const c = 2; a := float(1.0); a <<= c; _ = a", err: true}, {stmt: "const c float = 2; a := 1; a <<= c; _ = a", err: true}, {stmt: "const c float = 2.0; a := 1; a <<= c; _ = a", err: true}, @@ -1414,6 +1419,7 @@ func TestSyntaxOperatorShiftAssign(t *testing.T) { {stmt: "const c int = 2; a := vec2(1); a <<= c; _ = a", err: true}, {stmt: "a := 1; a >>= 2; _ = a", err: false}, + {stmt: "a := 1; a >>= 2.0; _ = a", err: false}, {stmt: "a := float(1.0); a >>= 2; _ = a", err: true}, {stmt: "a := 1; a >>= float(2.0); _ = a", err: true}, {stmt: "a := ivec2(1); a >>= 2; _ = a", err: false}, @@ -1435,6 +1441,7 @@ func TestSyntaxOperatorShiftAssign(t *testing.T) { {stmt: "a := vec3(1); a >>= ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1); a >>= vec3(2); _ = a", err: true}, {stmt: "const c = 2; a := 1; a >>= c; _ = a", err: false}, + {stmt: "const c = 2.0; a := 1; a >>= c; _ = a", err: false}, {stmt: "const c = 2; a := float(1.0); a >>= c; _ = a", err: true}, {stmt: "const c float = 2; a := 1; a >>= c; _ = a", err: true}, {stmt: "const c float = 2.0; a := 1; a >>= c; _ = a", err: true}, From 7f9d9971758e979488b7e6a6e9ff1208036f54e4 Mon Sep 17 00:00:00 2001 From: aoyako Date: Tue, 27 Feb 2024 19:29:39 +0900 Subject: [PATCH 07/20] add return type for type resolving --- internal/shader/expr.go | 64 +++++++++++++++++++++++----------- internal/shader/stmt.go | 6 ++++ internal/shader/syntax_test.go | 41 +++++++++++++++++++--- internal/shader/type.go | 2 +- internal/shaderir/check.go | 62 ++++++++++++++++++-------------- internal/shaderir/type.go | 1 + 6 files changed, 124 insertions(+), 52 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index cc41cea32..b3dac6a40 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -101,7 +101,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } // Resolve untyped constants. - l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) + var l gconstant.Value + var r gconstant.Value + if op2 == shaderir.LeftShift || op2 == shaderir.RightShift { + l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst) + } else { + l, r, ok = shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) + } if !ok { // TODO: Show a better type name for untyped constants. cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) @@ -109,27 +115,45 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } lhs[0].Const, rhs[0].Const = l, r - // If either is typed, resolve the other type. - // If both are untyped, keep them untyped. - if lhst.Main != shaderir.None || rhst.Main != shaderir.None { - if lhs[0].Const != nil { - switch lhs[0].Const.Kind() { - case gconstant.Float: - lhst = shaderir.Type{Main: shaderir.Float} - case gconstant.Int: - lhst = shaderir.Type{Main: shaderir.Int} - case gconstant.Bool: - lhst = shaderir.Type{Main: shaderir.Bool} + if op2 == shaderir.LeftShift || op2 == shaderir.RightShift { + if !(lhst.Main == shaderir.None && rhst.Main == shaderir.None) { + // If both are const + if rhs[0].Const != nil && (rhst.Main == shaderir.None || lhs[0].Const != nil) { + rhst = shaderir.Type{Main: shaderir.Int} + } + + // If left is untyped const + if lhst.Main == shaderir.None && lhs[0].Const != nil { + if rhs[0].Const != nil { + lhst = shaderir.Type{Main: shaderir.Int} + } else { + lhst = shaderir.Type{Main: shaderir.DeducedInt} + } } } - if rhs[0].Const != nil { - switch rhs[0].Const.Kind() { - case gconstant.Float: - rhst = shaderir.Type{Main: shaderir.Float} - case gconstant.Int: - rhst = shaderir.Type{Main: shaderir.Int} - case gconstant.Bool: - rhst = shaderir.Type{Main: shaderir.Bool} + } else { + // If either is typed, resolve the other type. + // If both are untyped, keep them untyped. + if lhst.Main != shaderir.None || rhst.Main != shaderir.None { + if lhs[0].Const != nil { + switch lhs[0].Const.Kind() { + case gconstant.Float: + lhst = shaderir.Type{Main: shaderir.Float} + case gconstant.Int: + lhst = shaderir.Type{Main: shaderir.Int} + case gconstant.Bool: + lhst = shaderir.Type{Main: shaderir.Bool} + } + } + if rhs[0].Const != nil { + switch rhs[0].Const.Kind() { + case gconstant.Float: + rhst = shaderir.Type{Main: shaderir.Float} + case gconstant.Int: + rhst = shaderir.Type{Main: shaderir.Int} + case gconstant.Bool: + rhst = shaderir.Type{Main: shaderir.Bool} + } } } } diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 6c07057da..b3d8e1732 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -514,6 +514,9 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r return nil, false } t := ts[0] + if t.Main == shaderir.DeducedInt { + cs.addError(pos, "invalid operation: shifted operand 1 (type float) must be integer") + } if t.Main == shaderir.None { t = toDefaultType(r[0].Const) } @@ -705,6 +708,9 @@ func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool { if lt.Equal(rt) { return true } + if lt.Main == shaderir.Int && rt.Main == shaderir.DeducedInt { + return true + } if rc == nil { return false diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index e50ff7d4a..df2c1c84d 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,12 +1320,30 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - {stmt: "a := 1 << 2.0; _ = a", err: true}, - {stmt: "a := 1.0 << 2; _ = a", err: true}, - {stmt: "a := 1.0 << 2.0; _ = a", err: true}, + // {stmt: "s := 1; var a float = float(1 << s); _ = a", err: true}, + // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + // {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, + // {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, + + {stmt: "s := 1; a := 1 << s; _ = a", err: false}, + {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, + {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, + {stmt: "var a float = 1.0 << 2.0; _ = a", err: false}, + {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, + {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, + {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, + {stmt: "var a int = 1.0 << 2; _ = a", err: false}, + {stmt: "var a float = 1.0 << 2; _ = a", err: false}, + {stmt: "var a = 1.0 << 2; _ = a", err: false}, + {stmt: "a := 1 << 2.0; _ = a", err: false}, + {stmt: "a := 1.0 << 2; _ = a", err: false}, + {stmt: "a := 1.0 << 2.0; _ = a", err: false}, {stmt: "a := 1 << 2; _ = a", err: false}, {stmt: "a := float(1.0) << 2; _ = a", err: true}, - {stmt: "a := 1 << float(2.0); _ = a", err: true}, + {stmt: "a := 1 << float(2.0); _ = a", err: false}, + {stmt: "a := 1.0 << float(2.0); _ = a", err: false}, {stmt: "a := ivec2(1) << 2; _ = a", err: false}, {stmt: "a := 1 << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << float(2.0); _ = a", err: true}, @@ -1345,9 +1363,22 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, + {stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true}, + {stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false}, + {stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 >> s; _ = a", err: true}, + {stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false}, + {stmt: "s := 1; var a int = 1 >> s; _ = a", err: false}, + {stmt: "var a int = 1.0 >> 2; _ = a", err: false}, + {stmt: "var a float = 1.0 >> 2; _ = a", err: false}, + {stmt: "var a = 1.0 >> 2; _ = a", err: false}, + {stmt: "a := 1 >> 2.0; _ = a", err: false}, + {stmt: "a := 1.0 >> 2; _ = a", err: false}, + {stmt: "a := 1.0 >> 2.0; _ = a", err: false}, {stmt: "a := 1 >> 2; _ = a", err: false}, {stmt: "a := float(1.0) >> 2; _ = a", err: true}, - {stmt: "a := 1 >> float(2.0); _ = a", err: true}, + {stmt: "a := 1 >> float(2.0); _ = a", err: false}, + {stmt: "a := 1.0 >> float(2.0); _ = a", err: false}, {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, {stmt: "a := 1 >> ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, diff --git a/internal/shader/type.go b/internal/shader/type.go index 546407556..5c8aade39 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -165,7 +165,7 @@ func checkArgsForIntBuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) err if len(args) != 1 { return fmt.Errorf("number of int's arguments must be 1 but %d", len(args)) } - if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float { + if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float || argts[0].Main == shaderir.DeducedInt { return nil } if args[0].Const != nil && gconstant.ToInt(args[0].Const).Kind() != gconstant.Unknown { diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index cb93f3b79..e518726b1 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -18,6 +18,29 @@ import ( "go/constant" ) +func ResolveUntypedConstsForBitShiftOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { + cLhs := lhs + cRhs := rhs + + // Right is const -> int + if rhs != nil { + cRhs = constant.ToInt(rhs) + if cRhs.Kind() == constant.Unknown { + return nil, nil, false + } + } + + // Left if untyped const -> int + if lhs != nil && lhst.Main == None { + cLhs = constant.ToInt(lhs) + if cLhs.Kind() == constant.Unknown { + return nil, nil, false + } + } + + return cLhs, cRhs, true +} + func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { if lhst.Main == None && rhst.Main == None { if lhs.Kind() == rhs.Kind() { @@ -98,13 +121,6 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } - if op == LeftShift || op == RightShift { - if lhsConst.Kind() == constant.Int && rhsConst.Kind() == constant.Int { - return Type{Main: Int}, true - } - return Type{}, false - } - if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp { return Type{Main: Bool}, true } @@ -128,6 +144,19 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) panic("shaderir: cannot resolve untyped values") } + if op == LeftShift || op == RightShift { + if (lhst.Main == Int || lhst.Main == DeducedInt) && rhst.Main == Int { + return Type{Main: lhst.Main}, true + } + if lhst.IsIntVector() && rhst.Main == Int { + return Type{Main: lhst.Main}, true + } + if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() { + return Type{Main: lhst.Main}, true + } + return Type{}, false + } + if op == AndAnd || op == OrOr { if lhst.Main == Bool && rhst.Main == Bool { return Type{Main: Bool}, true @@ -202,25 +231,6 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } - if op == LeftShift || op == RightShift { - if lhst.Main == Int && rhst.Main == Int { - return Type{Main: Int}, true - } - if lhst.Main == IVec2 && rhst.Main == IVec2 { - return Type{Main: IVec2}, true - } - if lhst.Main == IVec3 && rhst.Main == IVec3 { - return Type{Main: IVec3}, true - } - if lhst.Main == IVec4 && rhst.Main == IVec4 { - return Type{Main: IVec4}, true - } - if lhst.IsIntVector() && rhst.Main == Int { - return lhst, true - } - return Type{}, false - } - if lhst.Equal(&rhst) { if lhst.Main == None { return rhst, true diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index ede2c91e2..885d579c4 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -180,6 +180,7 @@ const ( Texture Array Struct + DeducedInt ) func descendantLocalVars(block, target *Block) ([]Type, bool) { From 66a4b20bdabf90a28f7498b7937d08329538264a Mon Sep 17 00:00:00 2001 From: aoyako Date: Tue, 27 Feb 2024 19:39:14 +0900 Subject: [PATCH 08/20] remove return type for deduced int --- internal/shader/expr.go | 7 ++----- internal/shader/stmt.go | 6 ------ internal/shader/syntax_test.go | 23 +++++++++-------------- internal/shader/type.go | 2 +- internal/shaderir/check.go | 2 +- internal/shaderir/type.go | 1 - 6 files changed, 13 insertions(+), 28 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index b3dac6a40..1d06be676 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -124,11 +124,8 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar // If left is untyped const if lhst.Main == shaderir.None && lhs[0].Const != nil { - if rhs[0].Const != nil { - lhst = shaderir.Type{Main: shaderir.Int} - } else { - lhst = shaderir.Type{Main: shaderir.DeducedInt} - } + lhst = shaderir.Type{Main: shaderir.Int} + // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone. } } } else { diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index b3d8e1732..6c07057da 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -514,9 +514,6 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r return nil, false } t := ts[0] - if t.Main == shaderir.DeducedInt { - cs.addError(pos, "invalid operation: shifted operand 1 (type float) must be integer") - } if t.Main == shaderir.None { t = toDefaultType(r[0].Const) } @@ -708,9 +705,6 @@ func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool { if lt.Equal(rt) { return true } - if lt.Main == shaderir.Int && rt.Main == shaderir.DeducedInt { - return true - } if rc == nil { return false diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index df2c1c84d..1c2920cc4 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1324,16 +1324,16 @@ func TestSyntaxOperatorShift(t *testing.T) { // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, // {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, // {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, + // {stmt: "s := 1; a := 1 << s; _ = a", err: false}, + // {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, + // {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, + // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + // {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, + // {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, + // {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, + // {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, - {stmt: "s := 1; a := 1 << s; _ = a", err: false}, - {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, - {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, {stmt: "var a float = 1.0 << 2.0; _ = a", err: false}, - {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, - {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, - {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, - {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, {stmt: "var a int = 1.0 << 2; _ = a", err: false}, {stmt: "var a float = 1.0 << 2; _ = a", err: false}, {stmt: "var a = 1.0 << 2; _ = a", err: false}, @@ -1363,12 +1363,7 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, - {stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true}, - {stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false}, - {stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true}, - {stmt: "s := 1; var a float = 1 >> s; _ = a", err: true}, - {stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false}, - {stmt: "s := 1; var a int = 1 >> s; _ = a", err: false}, + {stmt: "var a float = 1.0 >> 2.0; _ = a", err: false}, {stmt: "var a int = 1.0 >> 2; _ = a", err: false}, {stmt: "var a float = 1.0 >> 2; _ = a", err: false}, {stmt: "var a = 1.0 >> 2; _ = a", err: false}, diff --git a/internal/shader/type.go b/internal/shader/type.go index 5c8aade39..546407556 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -165,7 +165,7 @@ func checkArgsForIntBuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) err if len(args) != 1 { return fmt.Errorf("number of int's arguments must be 1 but %d", len(args)) } - if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float || argts[0].Main == shaderir.DeducedInt { + if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float { return nil } if args[0].Const != nil && gconstant.ToInt(args[0].Const).Kind() != gconstant.Unknown { diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index e518726b1..57bacc749 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -145,7 +145,7 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) } if op == LeftShift || op == RightShift { - if (lhst.Main == Int || lhst.Main == DeducedInt) && rhst.Main == Int { + if lhst.Main == Int && rhst.Main == Int { return Type{Main: lhst.Main}, true } if lhst.IsIntVector() && rhst.Main == Int { diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index 885d579c4..ede2c91e2 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -180,7 +180,6 @@ const ( Texture Array Struct - DeducedInt ) func descendantLocalVars(block, target *Block) ([]Type, bool) { From f44640778d14571a26233bf98a860309a7cc948d Mon Sep 17 00:00:00 2001 From: aoyako Date: Wed, 28 Feb 2024 20:27:26 +0900 Subject: [PATCH 09/20] add basic checks --- internal/shader/delayed.go | 124 ++++++++++++++++++++++++++++++ internal/shader/expr.go | 11 ++- internal/shader/shader.go | 15 ++++ internal/shader/syntax_test.go | 134 ++++++++++++++++++++++----------- 4 files changed, 240 insertions(+), 44 deletions(-) create mode 100644 internal/shader/delayed.go diff --git a/internal/shader/delayed.go b/internal/shader/delayed.go new file mode 100644 index 000000000..935d5797e --- /dev/null +++ b/internal/shader/delayed.go @@ -0,0 +1,124 @@ +// Copyright 2024 The Ebiten Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shader + +import ( + "go/ast" + gconstant "go/constant" + "go/token" + + "github.com/hajimehoshi/ebiten/v2/internal/shaderir" +) + +type resolveTypeStatus int + +const ( + resolveUnsure resolveTypeStatus = iota + resolveOk + resolveFail +) + +type delayedValidator interface { + Validate(expr ast.Expr) resolveTypeStatus + Pos() token.Pos + Error() string +} + +func (cs *compileState) tryValidateDelayed(cexpr ast.Expr) (ok bool) { + valExprs := make([]ast.Expr, 0, len(cs.delayedTypeCheks)) + for k := range cs.delayedTypeCheks { + valExprs = append(valExprs, k) + } + for _, expr := range valExprs { + if cexpr == expr { + continue + } + // Check if delayed validation can be done by adding current context + cres := cs.delayedTypeCheks[expr].Validate(cexpr) + switch cres { + case resolveFail: + cs.addError(cs.delayedTypeCheks[expr].Pos(), cs.delayedTypeCheks[expr].Error()) + return false + case resolveOk: + delete(cs.delayedTypeCheks, expr) + } + } + + return true +} + +type delayedShiftValidator struct { + value gconstant.Value + pos token.Pos + last ast.Expr +} + +func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool { + return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F +} + +func (d *delayedShiftValidator) Validate(cexpr ast.Expr) (rs resolveTypeStatus) { + switch cexpr.(type) { + case *ast.Ident: + ident := cexpr.(*ast.Ident) + // For BuiltinFunc, only int* are allowed + if fname, ok := shaderir.ParseBuiltinFunc(ident.Name); ok { + if isArgDefaultTypeInt(fname) { + return resolveOk + } + return resolveFail + } + // Untyped constant must represent int + if ident.Name == "_" { + if d.value != nil && d.value.Kind() == gconstant.Int { + return resolveOk + } + return resolveFail + } + if ident.Obj != nil { + if t, ok := ident.Obj.Type.(*ast.Ident); ok { + return d.Validate(t) + } + if decl, ok := ident.Obj.Decl.(*ast.ValueSpec); ok { + return d.Validate(decl.Type) + } + if _, ok := ident.Obj.Decl.(*ast.AssignStmt); ok { + if d.value != nil && d.value.Kind() == gconstant.Int { + return resolveOk + } + return resolveFail + } + } + case *ast.BinaryExpr: + bs := cexpr.(*ast.BinaryExpr) + left, right := bs.X, bs.Y + if bs.Y == d.last { + left, right = right, left + } + + rightCheck := d.Validate(right) + d.last = cexpr + return rightCheck + } + return resolveUnsure +} + +func (d delayedShiftValidator) Pos() token.Pos { + return d.pos +} + +func (d delayedShiftValidator) Error() string { + return "left shift operand should be int" +} diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 1d06be676..ca11d4e66 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -36,7 +36,12 @@ func canTruncateToFloat(v gconstant.Value) bool { var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) -func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) { +func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) { + defer func() { + // Due to use of early return in the parsing, delayed checks are conducted in defer + ok = ok && cs.tryValidateDelayed(expr) + }() + switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { @@ -103,6 +108,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar // Resolve untyped constants. var l gconstant.Value var r gconstant.Value + origLvalue := lhs[0].Const if op2 == shaderir.LeftShift || op2 == shaderir.RightShift { l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst) } else { @@ -126,6 +132,9 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar if lhst.Main == shaderir.None && lhs[0].Const != nil { lhst = shaderir.Type{Main: shaderir.Int} // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone. + if rhs[0].Const == nil { + cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue, pos: e.Pos(), last: expr}) + } } } } else { diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 90ed2d611..c89d43cde 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -61,6 +61,8 @@ type compileState struct { varyingParsed bool + delayedTypeCheks map[ast.Expr]delayedValidator + errs []string } @@ -82,6 +84,13 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) { return 0, false } +func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedValidator) { + if cs.delayedTypeCheks == nil { + cs.delayedTypeCheks = make(map[ast.Expr]delayedValidator, 1) + } + cs.delayedTypeCheks[at] = check +} + type typ struct { name string ir shaderir.Type @@ -350,6 +359,12 @@ func (cs *compileState) parse(f *ast.File) { for _, f := range cs.funcs { cs.ir.Funcs = append(cs.ir.Funcs, f.ir) } + + // if len(cs.delayedTypeCheks) != 0 { + // for _, check := range cs.delayedTypeCheks { + // cs.addError(check.Pos(), check.Error()) + // } + // } } func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) { diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 1c2920cc4..2d78dd5a9 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,23 +1320,27 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - // {stmt: "s := 1; var a float = float(1 << s); _ = a", err: true}, - // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - // {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, - // {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, - // {stmt: "s := 1; a := 1 << s; _ = a", err: false}, - // {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, - // {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, - // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - // {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, - // {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, - // {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, - // {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, + {stmt: "s := 1; t := 2; _ = t + 1.0 << s", err: false}, + {stmt: "s := 1; _ = 1 << s", err: false}, + {stmt: "s := 1; _ = 1.0 << s", err: true}, + {stmt: "var a = 1; b := a << 2.0; _ = b", err: false}, + {stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, + {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, + {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, + {stmt: "s := 1; a := 1 << s; _ = a", err: false}, + {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, + {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, + {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, + {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, + {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, {stmt: "var a float = 1.0 << 2.0; _ = a", err: false}, {stmt: "var a int = 1.0 << 2; _ = a", err: false}, {stmt: "var a float = 1.0 << 2; _ = a", err: false}, - {stmt: "var a = 1.0 << 2; _ = a", err: false}, {stmt: "a := 1 << 2.0; _ = a", err: false}, {stmt: "a := 1.0 << 2; _ = a", err: false}, {stmt: "a := 1.0 << 2.0; _ = a", err: false}, @@ -1362,36 +1366,6 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := ivec2(1) << vec2(2); _ = a", err: true}, {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, - - {stmt: "var a float = 1.0 >> 2.0; _ = a", err: false}, - {stmt: "var a int = 1.0 >> 2; _ = a", err: false}, - {stmt: "var a float = 1.0 >> 2; _ = a", err: false}, - {stmt: "var a = 1.0 >> 2; _ = a", err: false}, - {stmt: "a := 1 >> 2.0; _ = a", err: false}, - {stmt: "a := 1.0 >> 2; _ = a", err: false}, - {stmt: "a := 1.0 >> 2.0; _ = a", err: false}, - {stmt: "a := 1 >> 2; _ = a", err: false}, - {stmt: "a := float(1.0) >> 2; _ = a", err: true}, - {stmt: "a := 1 >> float(2.0); _ = a", err: false}, - {stmt: "a := 1.0 >> float(2.0); _ = a", err: false}, - {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, - {stmt: "a := 1 >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, - {stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false}, - {stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true}, - {stmt: "a := 1 >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> 2; _ = a", err: true}, - {stmt: "a := float(1.0) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> float(2.0); _ = a", err: true}, - {stmt: "a := vec2(1) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> vec3(2); _ = a", err: true}, - {stmt: "a := vec3(1) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true}, } for _, c := range cases { @@ -1407,6 +1381,80 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { t.Errorf("%s must not return nil but returned %v", c.stmt, err) } } + + casesFunc := []struct { + prog string + err bool + }{ + { + prog: `package main + func Foo(x int) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + s := 1 + Foo(1 << s) + return dstPos + }`, + err: false, + }, + { + prog: `package main + func Foo(x int) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + s := 1 + Foo(1.0 << s) + return dstPos + }`, + err: false, + }, + { + prog: `package main + func Foo(x float) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + s := 1 + Foo(1 << s) + return dstPos + }`, + err: true, + }, + { + prog: `package main + func Foo(x float) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + s := 1 + Foo(1 << s) + return dstPos + }`, + err: true, + }, + { + prog: `package main + func Foo(x float) { + _ = x + } + func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + Foo(1.0 << 2.0) + return dstPos + }`, + err: false, + }, + } + + for _, c := range casesFunc { + _, err := compileToIR([]byte(c.prog)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.prog) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.prog, err) + } + } } func TestSyntaxOperatorShiftAssign(t *testing.T) { From f02e9fd4d060d1d050fe7254efe49906705a3792 Mon Sep 17 00:00:00 2001 From: aoyako Date: Sat, 2 Mar 2024 15:32:31 +0900 Subject: [PATCH 10/20] add shift type checks --- internal/shader/delayed.go | 217 ++++++++++++++++++++------------- internal/shader/expr.go | 12 +- internal/shader/shader.go | 6 +- internal/shader/stmt.go | 23 ++++ internal/shader/syntax_test.go | 112 ++++++----------- internal/shaderir/check.go | 9 +- internal/shaderir/program.go | 23 ++-- 7 files changed, 211 insertions(+), 191 deletions(-) diff --git a/internal/shader/delayed.go b/internal/shader/delayed.go index 935d5797e..abf675697 100644 --- a/internal/shader/delayed.go +++ b/internal/shader/delayed.go @@ -15,110 +15,151 @@ package shader import ( - "go/ast" + "fmt" gconstant "go/constant" - "go/token" "github.com/hajimehoshi/ebiten/v2/internal/shaderir" ) -type resolveTypeStatus int - -const ( - resolveUnsure resolveTypeStatus = iota - resolveOk - resolveFail -) - -type delayedValidator interface { - Validate(expr ast.Expr) resolveTypeStatus - Pos() token.Pos +type delayedTypeValidator interface { + Validate(t shaderir.Type) (shaderir.Type, bool) + IsValidated() (shaderir.Type, bool) Error() string } -func (cs *compileState) tryValidateDelayed(cexpr ast.Expr) (ok bool) { - valExprs := make([]ast.Expr, 0, len(cs.delayedTypeCheks)) - for k := range cs.delayedTypeCheks { - valExprs = append(valExprs, k) - } - for _, expr := range valExprs { - if cexpr == expr { - continue - } - // Check if delayed validation can be done by adding current context - cres := cs.delayedTypeCheks[expr].Validate(cexpr) - switch cres { - case resolveFail: - cs.addError(cs.delayedTypeCheks[expr].Pos(), cs.delayedTypeCheks[expr].Error()) - return false - case resolveOk: - delete(cs.delayedTypeCheks, expr) - } - } - - return true -} - -type delayedShiftValidator struct { - value gconstant.Value - pos token.Pos - last ast.Expr -} - func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool { return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F } -func (d *delayedShiftValidator) Validate(cexpr ast.Expr) (rs resolveTypeStatus) { - switch cexpr.(type) { - case *ast.Ident: - ident := cexpr.(*ast.Ident) - // For BuiltinFunc, only int* are allowed - if fname, ok := shaderir.ParseBuiltinFunc(ident.Name); ok { - if isArgDefaultTypeInt(fname) { - return resolveOk - } - return resolveFail - } - // Untyped constant must represent int - if ident.Name == "_" { - if d.value != nil && d.value.Kind() == gconstant.Int { - return resolveOk - } - return resolveFail - } - if ident.Obj != nil { - if t, ok := ident.Obj.Type.(*ast.Ident); ok { - return d.Validate(t) - } - if decl, ok := ident.Obj.Decl.(*ast.ValueSpec); ok { - return d.Validate(decl.Type) - } - if _, ok := ident.Obj.Decl.(*ast.AssignStmt); ok { - if d.value != nil && d.value.Kind() == gconstant.Int { - return resolveOk - } - return resolveFail - } - } - case *ast.BinaryExpr: - bs := cexpr.(*ast.BinaryExpr) - left, right := bs.X, bs.Y - if bs.Y == d.last { - left, right = right, left - } +func isIntType(t shaderir.Type) bool { + return t.Main == shaderir.Int || t.IsIntVector() +} - rightCheck := d.Validate(right) - d.last = cexpr - return rightCheck +func (cs *compileState) ValidateDefaultTypesForExpr(block *block, expr shaderir.Expr, t shaderir.Type) shaderir.Type { + if check, ok := cs.delayedTypeCheks[expr.Ast]; ok { + if resT, ok := check.IsValidated(); ok { + return resT + } + resT, ok := check.Validate(t) + if !ok { + return shaderir.Type{Main: shaderir.None} + } + return resT } - return resolveUnsure + + switch expr.Type { + case shaderir.LocalVariable: + return block.vars[expr.Index].typ + + case shaderir.Binary: + left := expr.Exprs[0] + right := expr.Exprs[1] + + leftType := cs.ValidateDefaultTypesForExpr(block, left, t) + rightType := cs.ValidateDefaultTypesForExpr(block, right, t) + + // Usure about top-level type, try to validate by neighbour type + // The same work is done twice. Can it be optimized? + if t.Main == shaderir.None { + cs.ValidateDefaultTypesForExpr(block, left, rightType) + cs.ValidateDefaultTypesForExpr(block, right, leftType) + } + case shaderir.Call: + fun := expr.Exprs[0] + if fun.Type == shaderir.BuiltinFuncExpr { + if isArgDefaultTypeInt(fun.BuiltinFunc) { + for _, e := range expr.Exprs[1:] { + cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Int}) + } + return shaderir.Type{Main: shaderir.Int} + } + + for _, e := range expr.Exprs[1:] { + cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Float}) + } + return shaderir.Type{Main: shaderir.Float} + } + + if fun.Type == shaderir.FunctionExpr { + args := cs.funcs[fun.Index].ir.InParams + + for i, e := range expr.Exprs[1:] { + cs.ValidateDefaultTypesForExpr(block, e, args[i]) + } + + retT := cs.funcs[fun.Index].ir.Return + + return retT + } + } + + return shaderir.Type{Main: shaderir.None} } -func (d delayedShiftValidator) Pos() token.Pos { - return d.pos +func (cs *compileState) ValidateDefaultTypes(block *block, stmt shaderir.Stmt) { + switch stmt.Type { + case shaderir.Assign: + left := stmt.Exprs[0] + right := stmt.Exprs[1] + if left.Type == shaderir.LocalVariable { + varType := block.vars[left.Index].typ + // Type is not explicitly specified + if stmt.IsTypeGuessed { + varType = shaderir.Type{Main: shaderir.None} + } + cs.ValidateDefaultTypesForExpr(block, right, varType) + } + case shaderir.ExprStmt: + for _, e := range stmt.Exprs { + cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.None}) + } + } } -func (d delayedShiftValidator) Error() string { - return "left shift operand should be int" +type delayedShiftValidator struct { + shiftType shaderir.Op + value gconstant.Value + validated bool + closestUnknown bool + failed bool +} + +func (d *delayedShiftValidator) IsValidated() (shaderir.Type, bool) { + if d.failed { + return shaderir.Type{}, false + } + if d.validated { + return shaderir.Type{Main: shaderir.Int}, true + } + // If only matched with None + if d.closestUnknown { + // Was it originally represented by an int constant? + if d.value.Kind() == gconstant.Int { + return shaderir.Type{Main: shaderir.Int}, true + } + } + return shaderir.Type{}, false +} + +func (d *delayedShiftValidator) Validate(t shaderir.Type) (shaderir.Type, bool) { + if d.validated { + return shaderir.Type{Main: shaderir.Int}, true + } + if isIntType(t) { + d.validated = true + return shaderir.Type{Main: shaderir.Int}, true + } + if t.Main == shaderir.None { + d.closestUnknown = true + return t, true + } + return shaderir.Type{Main: shaderir.None}, false +} + +func (d *delayedShiftValidator) Error() string { + st := "left shift" + if d.shiftType == shaderir.RightShift { + st = "right shift" + } + return fmt.Sprintf("left operand for %s should be int", st) } diff --git a/internal/shader/expr.go b/internal/shader/expr.go index ca11d4e66..53d328b4e 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -37,11 +37,6 @@ func canTruncateToFloat(v gconstant.Value) bool { var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) { - defer func() { - // Due to use of early return in the parsing, delayed checks are conducted in defer - ok = ok && cs.tryValidateDelayed(expr) - }() - switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { @@ -133,7 +128,11 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar lhst = shaderir.Type{Main: shaderir.Int} // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone. if rhs[0].Const == nil { - cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue, pos: e.Pos(), last: expr}) + defer func() { + if ok { + cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue}) + } + }() } } } @@ -202,6 +201,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar { Type: shaderir.Binary, Op: op2, + Ast: expr, Exprs: []shaderir.Expr{lhs[0], rhs[0]}, }, }, []shaderir.Type{t}, stmts, true diff --git a/internal/shader/shader.go b/internal/shader/shader.go index c89d43cde..017a7a81d 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -61,7 +61,7 @@ type compileState struct { varyingParsed bool - delayedTypeCheks map[ast.Expr]delayedValidator + delayedTypeCheks map[ast.Expr]delayedTypeValidator errs []string } @@ -84,9 +84,9 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) { return 0, false } -func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedValidator) { +func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedTypeValidator) { if cs.delayedTypeCheks == nil { - cs.delayedTypeCheks = make(map[ast.Expr]delayedValidator, 1) + cs.delayedTypeCheks = make(map[ast.Expr]delayedTypeValidator, 1) } cs.delayedTypeCheks[at] = check } diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 6c07057da..a794a749c 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -49,6 +49,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP if !ok { return nil, false } + for i := range ss { + ss[i].IsTypeGuessed = true + } + stmts = append(stmts, ss...) case token.ASSIGN: if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 { @@ -473,6 +477,25 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt)) return nil, false } + + // Need to run delayed checks + if len(cs.delayedTypeCheks) != 0 { + for _, st := range stmts { + cs.ValidateDefaultTypes(block, st) + } + + // Collect all errors first + foundErr := false + for s, v := range cs.delayedTypeCheks { + if _, ok := v.IsValidated(); !ok { + foundErr = true + cs.addError(s.Pos(), v.Error()) + } + } + if foundErr { + return nil, false + } + } return stmts, true } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 2d78dd5a9..a355dfa79 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,9 +1320,35 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - {stmt: "s := 1; t := 2; _ = t + 1.0 << s", err: false}, - {stmt: "s := 1; _ = 1 << s", err: false}, - {stmt: "s := 1; _ = 1.0 << s", err: true}, + {stmt: "s := 1; a := 1.0< Date: Sat, 2 Mar 2024 15:40:38 +0900 Subject: [PATCH 11/20] update tests for right shift --- internal/shader/syntax_test.go | 127 ++++++++++++++++++++++++++------- 1 file changed, 102 insertions(+), 25 deletions(-) diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index a355dfa79..165966158 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,35 +1320,38 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - {stmt: "s := 1; a := 1.0<> 2.0 == b; _ = a", err: false}, + {stmt: "s := 1; b := 2.0; a := 1.0>>s == b; _ = a", err: true}, + {stmt: "s := 1; b := 2; a := 1.0>>s == b; _ = a", err: false}, + {stmt: "s := 1; a := 2.0>>s + ivec2(3.0>>s); _ = a", err: false}, + {stmt: "s := 1; a := 2.0>>s + vec2(3); _ = a", err: true}, + {stmt: "s := 1; a := 2.0>>s + ivec2(3); _ = a", err: false}, + {stmt: "s := 1; a := 2.0>>s + foo_int_int(3.0>>s); _ = a", err: false}, + {stmt: "s := 1; a := 2.0>>s + 3.0>>s; _ = a", err: true}, + {stmt: "s := 1; a := 2>>s + 3.0>>s; _ = a", err: true}, + {stmt: "s := 1; a := 2.0>>s + 3>>s; _ = a", err: true}, + {stmt: "s := 1; a := 2>>s + 3>>s; _ = a", err: false}, + {stmt: "s := 1; foo_multivar(0, 0, 2>>s)", err: false}, + {stmt: "s := 1; foo_multivar(0, 2.0>>s, 0)", err: true}, + {stmt: "s := 1; foo_multivar(2.0>>s, 0, 0)", err: false}, + {stmt: "s := 1; a := foo_multivar(2.0>>s, 0, 0); _ = a", err: false}, + {stmt: "s := 1; a := foo_multivar(0, 2.0>>s, 0); _ = a", err: true}, + {stmt: "s := 1; a := foo_multivar(0, 0, 2.0>>s); _ = a", err: false}, + {stmt: "a := foo_multivar(0, 0, 1.0>>2.0); _ = a", err: false}, + {stmt: "a := foo_multivar(0, 1.0>>2.0, 0); _ = a", err: false}, + {stmt: "s := 1; a := int(1) + 1.0>>s + int(float(1>>s)); _ = a", err: true}, + {stmt: "s := 1; var a int = 1.0 >> 2.0 >> 3.0 >> 4.0 >> s; _ = a", err: false}, + {stmt: "s := 1; var a float = 1 >> 1 >> 1 >> 1 >> s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 >> s + 1.2; _ = a", err: true}, + {stmt: "s := 1; a := 1.0 >> s + 1.2; _ = a", err: true}, + {stmt: "s := 1; a := 1.0 >> s + foo_float_float(2); _ = a", err: true}, + {stmt: "s := 1; a := 1.0 >> s + foo_float_int(2); _ = a", err: false}, + {stmt: "s := 1; a := foo_float_int(1.0>>s) + foo_float_int(2); _ = a", err: true}, + {stmt: "s := 1; a := foo_int_float(1>>s) + foo_int_float(2); _ = a", err: false}, + {stmt: "s := 1; a := foo_int_int(1>>s) + foo_int_int(2); _ = a", err: false}, + {stmt: "s := 1; t := 2.0; a := t + 1.0 >> s; _ = a", err: true}, + {stmt: "s := 1; t := 2; a := t + 1.0 >> s; _ = a", err: false}, + {stmt: "s := 1; b := 1 >> s; _ = b", err: false}, + {stmt: "var a = 1; b := a >> 2.0; _ = b", err: false}, + {stmt: "s := 1; var a float; a = 1 >> s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 >> s; _ = a", err: true}, + {stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true}, + {stmt: "s := 1; var a int = int(1 >> s); _ = a", err: false}, + {stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false}, + {stmt: "s := 1; a := 1 >> s; _ = a", err: false}, + {stmt: "s := 1; a := 1.0 >> s; _ = a", err: true}, + {stmt: "s := 1; a := int(1.0 >> s); _ = a", err: false}, + {stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true}, + {stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 >> s; _ = a", err: true}, + {stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false}, + {stmt: "s := 1; var a int = 1 >> s; _ = a", err: false}, + {stmt: "var a float = 1.0 >> 2.0; _ = a", err: false}, + {stmt: "var a int = 1.0 >> 2; _ = a", err: false}, + {stmt: "var a float = 1.0 >> 2; _ = a", err: false}, + {stmt: "a := 1 >> 2.0; _ = a", err: false}, + {stmt: "a := 1.0 >> 2; _ = a", err: false}, + {stmt: "a := 1.0 >> 2.0; _ = a", err: false}, + {stmt: "a := 1 >> 2; _ = a", err: false}, + {stmt: "a := float(1.0) >> 2; _ = a", err: true}, + {stmt: "a := 1 >> float(2.0); _ = a", err: false}, + {stmt: "a := 1.0 >> float(2.0); _ = a", err: false}, + {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, + {stmt: "a := 1 >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true}, + {stmt: "a := 1 >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> 2; _ = a", err: true}, + {stmt: "a := float(1.0) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true}, } for _, c := range cases { From 359e7b8597b985cabfa0fb9cf85fbc1205734a02 Mon Sep 17 00:00:00 2001 From: aoyako Date: Sat, 2 Mar 2024 15:41:30 +0900 Subject: [PATCH 12/20] remove comment --- internal/shader/shader.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 017a7a81d..381630a2e 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -359,12 +359,6 @@ func (cs *compileState) parse(f *ast.File) { for _, f := range cs.funcs { cs.ir.Funcs = append(cs.ir.Funcs, f.ir) } - - // if len(cs.delayedTypeCheks) != 0 { - // for _, check := range cs.delayedTypeCheks { - // cs.addError(check.Pos(), check.Error()) - // } - // } } func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) { From 29fb6c7f6f8e0ea5a5cb518e9f276c22e74685ee Mon Sep 17 00:00:00 2001 From: aoyako Date: Thu, 21 Mar 2024 17:13:23 +0900 Subject: [PATCH 13/20] Revert "remove comment" This reverts commit 359e7b8597b985cabfa0fb9cf85fbc1205734a02. --- internal/shader/shader.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 381630a2e..017a7a81d 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -359,6 +359,12 @@ func (cs *compileState) parse(f *ast.File) { for _, f := range cs.funcs { cs.ir.Funcs = append(cs.ir.Funcs, f.ir) } + + // if len(cs.delayedTypeCheks) != 0 { + // for _, check := range cs.delayedTypeCheks { + // cs.addError(check.Pos(), check.Error()) + // } + // } } func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) { From 4f3e649bf03aff1d68303c82d2080d3595886167 Mon Sep 17 00:00:00 2001 From: aoyako Date: Thu, 21 Mar 2024 17:13:34 +0900 Subject: [PATCH 14/20] Revert "update tests for right shift" This reverts commit d1b9216ee1b51bcda7dd829a11b931d6cdb3b552. --- internal/shader/syntax_test.go | 127 +++++++-------------------------- 1 file changed, 25 insertions(+), 102 deletions(-) diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 165966158..a355dfa79 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,38 +1320,35 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - {stmt: "b := 2.0; a := 1.0 << 2.0 == b; _ = a", err: false}, - {stmt: "s := 1; b := 2.0; a := 1.0<> 2.0 == b; _ = a", err: false}, - {stmt: "s := 1; b := 2.0; a := 1.0>>s == b; _ = a", err: true}, - {stmt: "s := 1; b := 2; a := 1.0>>s == b; _ = a", err: false}, - {stmt: "s := 1; a := 2.0>>s + ivec2(3.0>>s); _ = a", err: false}, - {stmt: "s := 1; a := 2.0>>s + vec2(3); _ = a", err: true}, - {stmt: "s := 1; a := 2.0>>s + ivec2(3); _ = a", err: false}, - {stmt: "s := 1; a := 2.0>>s + foo_int_int(3.0>>s); _ = a", err: false}, - {stmt: "s := 1; a := 2.0>>s + 3.0>>s; _ = a", err: true}, - {stmt: "s := 1; a := 2>>s + 3.0>>s; _ = a", err: true}, - {stmt: "s := 1; a := 2.0>>s + 3>>s; _ = a", err: true}, - {stmt: "s := 1; a := 2>>s + 3>>s; _ = a", err: false}, - {stmt: "s := 1; foo_multivar(0, 0, 2>>s)", err: false}, - {stmt: "s := 1; foo_multivar(0, 2.0>>s, 0)", err: true}, - {stmt: "s := 1; foo_multivar(2.0>>s, 0, 0)", err: false}, - {stmt: "s := 1; a := foo_multivar(2.0>>s, 0, 0); _ = a", err: false}, - {stmt: "s := 1; a := foo_multivar(0, 2.0>>s, 0); _ = a", err: true}, - {stmt: "s := 1; a := foo_multivar(0, 0, 2.0>>s); _ = a", err: false}, - {stmt: "a := foo_multivar(0, 0, 1.0>>2.0); _ = a", err: false}, - {stmt: "a := foo_multivar(0, 1.0>>2.0, 0); _ = a", err: false}, - {stmt: "s := 1; a := int(1) + 1.0>>s + int(float(1>>s)); _ = a", err: true}, - {stmt: "s := 1; var a int = 1.0 >> 2.0 >> 3.0 >> 4.0 >> s; _ = a", err: false}, - {stmt: "s := 1; var a float = 1 >> 1 >> 1 >> 1 >> s; _ = a", err: true}, - {stmt: "s := 1; var a float = 1 >> s + 1.2; _ = a", err: true}, - {stmt: "s := 1; a := 1.0 >> s + 1.2; _ = a", err: true}, - {stmt: "s := 1; a := 1.0 >> s + foo_float_float(2); _ = a", err: true}, - {stmt: "s := 1; a := 1.0 >> s + foo_float_int(2); _ = a", err: false}, - {stmt: "s := 1; a := foo_float_int(1.0>>s) + foo_float_int(2); _ = a", err: true}, - {stmt: "s := 1; a := foo_int_float(1>>s) + foo_int_float(2); _ = a", err: false}, - {stmt: "s := 1; a := foo_int_int(1>>s) + foo_int_int(2); _ = a", err: false}, - {stmt: "s := 1; t := 2.0; a := t + 1.0 >> s; _ = a", err: true}, - {stmt: "s := 1; t := 2; a := t + 1.0 >> s; _ = a", err: false}, - {stmt: "s := 1; b := 1 >> s; _ = b", err: false}, - {stmt: "var a = 1; b := a >> 2.0; _ = b", err: false}, - {stmt: "s := 1; var a float; a = 1 >> s; _ = a", err: true}, - {stmt: "s := 1; var a float = 1 >> s; _ = a", err: true}, - {stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true}, - {stmt: "s := 1; var a int = int(1 >> s); _ = a", err: false}, - {stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false}, - {stmt: "s := 1; a := 1 >> s; _ = a", err: false}, - {stmt: "s := 1; a := 1.0 >> s; _ = a", err: true}, - {stmt: "s := 1; a := int(1.0 >> s); _ = a", err: false}, - {stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true}, - {stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true}, - {stmt: "s := 1; var a float = 1 >> s; _ = a", err: true}, - {stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false}, - {stmt: "s := 1; var a int = 1 >> s; _ = a", err: false}, - {stmt: "var a float = 1.0 >> 2.0; _ = a", err: false}, - {stmt: "var a int = 1.0 >> 2; _ = a", err: false}, - {stmt: "var a float = 1.0 >> 2; _ = a", err: false}, - {stmt: "a := 1 >> 2.0; _ = a", err: false}, - {stmt: "a := 1.0 >> 2; _ = a", err: false}, - {stmt: "a := 1.0 >> 2.0; _ = a", err: false}, - {stmt: "a := 1 >> 2; _ = a", err: false}, - {stmt: "a := float(1.0) >> 2; _ = a", err: true}, - {stmt: "a := 1 >> float(2.0); _ = a", err: false}, - {stmt: "a := 1.0 >> float(2.0); _ = a", err: false}, - {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, - {stmt: "a := 1 >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, - {stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false}, - {stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true}, - {stmt: "a := 1 >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> 2; _ = a", err: true}, - {stmt: "a := float(1.0) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> float(2.0); _ = a", err: true}, - {stmt: "a := vec2(1) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> vec3(2); _ = a", err: true}, - {stmt: "a := vec3(1) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true}, - {stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true}, - {stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true}, } for _, c := range cases { From 0651a090528974d16230c5fd8e6877bfcc4f7695 Mon Sep 17 00:00:00 2001 From: aoyako Date: Thu, 21 Mar 2024 17:14:04 +0900 Subject: [PATCH 15/20] Revert "add shift type checks" This reverts commit f02e9fd4d060d1d050fe7254efe49906705a3792. --- internal/shader/delayed.go | 207 +++++++++++++-------------------- internal/shader/expr.go | 12 +- internal/shader/shader.go | 6 +- internal/shader/stmt.go | 23 ---- internal/shader/syntax_test.go | 112 ++++++++++++------ internal/shaderir/check.go | 9 +- internal/shaderir/program.go | 23 ++-- 7 files changed, 186 insertions(+), 206 deletions(-) diff --git a/internal/shader/delayed.go b/internal/shader/delayed.go index abf675697..935d5797e 100644 --- a/internal/shader/delayed.go +++ b/internal/shader/delayed.go @@ -15,151 +15,110 @@ package shader import ( - "fmt" + "go/ast" gconstant "go/constant" + "go/token" "github.com/hajimehoshi/ebiten/v2/internal/shaderir" ) -type delayedTypeValidator interface { - Validate(t shaderir.Type) (shaderir.Type, bool) - IsValidated() (shaderir.Type, bool) +type resolveTypeStatus int + +const ( + resolveUnsure resolveTypeStatus = iota + resolveOk + resolveFail +) + +type delayedValidator interface { + Validate(expr ast.Expr) resolveTypeStatus + Pos() token.Pos Error() string } +func (cs *compileState) tryValidateDelayed(cexpr ast.Expr) (ok bool) { + valExprs := make([]ast.Expr, 0, len(cs.delayedTypeCheks)) + for k := range cs.delayedTypeCheks { + valExprs = append(valExprs, k) + } + for _, expr := range valExprs { + if cexpr == expr { + continue + } + // Check if delayed validation can be done by adding current context + cres := cs.delayedTypeCheks[expr].Validate(cexpr) + switch cres { + case resolveFail: + cs.addError(cs.delayedTypeCheks[expr].Pos(), cs.delayedTypeCheks[expr].Error()) + return false + case resolveOk: + delete(cs.delayedTypeCheks, expr) + } + } + + return true +} + +type delayedShiftValidator struct { + value gconstant.Value + pos token.Pos + last ast.Expr +} + func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool { return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F } -func isIntType(t shaderir.Type) bool { - return t.Main == shaderir.Int || t.IsIntVector() -} - -func (cs *compileState) ValidateDefaultTypesForExpr(block *block, expr shaderir.Expr, t shaderir.Type) shaderir.Type { - if check, ok := cs.delayedTypeCheks[expr.Ast]; ok { - if resT, ok := check.IsValidated(); ok { - return resT +func (d *delayedShiftValidator) Validate(cexpr ast.Expr) (rs resolveTypeStatus) { + switch cexpr.(type) { + case *ast.Ident: + ident := cexpr.(*ast.Ident) + // For BuiltinFunc, only int* are allowed + if fname, ok := shaderir.ParseBuiltinFunc(ident.Name); ok { + if isArgDefaultTypeInt(fname) { + return resolveOk + } + return resolveFail } - resT, ok := check.Validate(t) - if !ok { - return shaderir.Type{Main: shaderir.None} + // Untyped constant must represent int + if ident.Name == "_" { + if d.value != nil && d.value.Kind() == gconstant.Int { + return resolveOk + } + return resolveFail } - return resT - } - - switch expr.Type { - case shaderir.LocalVariable: - return block.vars[expr.Index].typ - - case shaderir.Binary: - left := expr.Exprs[0] - right := expr.Exprs[1] - - leftType := cs.ValidateDefaultTypesForExpr(block, left, t) - rightType := cs.ValidateDefaultTypesForExpr(block, right, t) - - // Usure about top-level type, try to validate by neighbour type - // The same work is done twice. Can it be optimized? - if t.Main == shaderir.None { - cs.ValidateDefaultTypesForExpr(block, left, rightType) - cs.ValidateDefaultTypesForExpr(block, right, leftType) - } - case shaderir.Call: - fun := expr.Exprs[0] - if fun.Type == shaderir.BuiltinFuncExpr { - if isArgDefaultTypeInt(fun.BuiltinFunc) { - for _, e := range expr.Exprs[1:] { - cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Int}) + if ident.Obj != nil { + if t, ok := ident.Obj.Type.(*ast.Ident); ok { + return d.Validate(t) + } + if decl, ok := ident.Obj.Decl.(*ast.ValueSpec); ok { + return d.Validate(decl.Type) + } + if _, ok := ident.Obj.Decl.(*ast.AssignStmt); ok { + if d.value != nil && d.value.Kind() == gconstant.Int { + return resolveOk } - return shaderir.Type{Main: shaderir.Int} + return resolveFail } - - for _, e := range expr.Exprs[1:] { - cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Float}) - } - return shaderir.Type{Main: shaderir.Float} + } + case *ast.BinaryExpr: + bs := cexpr.(*ast.BinaryExpr) + left, right := bs.X, bs.Y + if bs.Y == d.last { + left, right = right, left } - if fun.Type == shaderir.FunctionExpr { - args := cs.funcs[fun.Index].ir.InParams - - for i, e := range expr.Exprs[1:] { - cs.ValidateDefaultTypesForExpr(block, e, args[i]) - } - - retT := cs.funcs[fun.Index].ir.Return - - return retT - } + rightCheck := d.Validate(right) + d.last = cexpr + return rightCheck } - - return shaderir.Type{Main: shaderir.None} + return resolveUnsure } -func (cs *compileState) ValidateDefaultTypes(block *block, stmt shaderir.Stmt) { - switch stmt.Type { - case shaderir.Assign: - left := stmt.Exprs[0] - right := stmt.Exprs[1] - if left.Type == shaderir.LocalVariable { - varType := block.vars[left.Index].typ - // Type is not explicitly specified - if stmt.IsTypeGuessed { - varType = shaderir.Type{Main: shaderir.None} - } - cs.ValidateDefaultTypesForExpr(block, right, varType) - } - case shaderir.ExprStmt: - for _, e := range stmt.Exprs { - cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.None}) - } - } +func (d delayedShiftValidator) Pos() token.Pos { + return d.pos } -type delayedShiftValidator struct { - shiftType shaderir.Op - value gconstant.Value - validated bool - closestUnknown bool - failed bool -} - -func (d *delayedShiftValidator) IsValidated() (shaderir.Type, bool) { - if d.failed { - return shaderir.Type{}, false - } - if d.validated { - return shaderir.Type{Main: shaderir.Int}, true - } - // If only matched with None - if d.closestUnknown { - // Was it originally represented by an int constant? - if d.value.Kind() == gconstant.Int { - return shaderir.Type{Main: shaderir.Int}, true - } - } - return shaderir.Type{}, false -} - -func (d *delayedShiftValidator) Validate(t shaderir.Type) (shaderir.Type, bool) { - if d.validated { - return shaderir.Type{Main: shaderir.Int}, true - } - if isIntType(t) { - d.validated = true - return shaderir.Type{Main: shaderir.Int}, true - } - if t.Main == shaderir.None { - d.closestUnknown = true - return t, true - } - return shaderir.Type{Main: shaderir.None}, false -} - -func (d *delayedShiftValidator) Error() string { - st := "left shift" - if d.shiftType == shaderir.RightShift { - st = "right shift" - } - return fmt.Sprintf("left operand for %s should be int", st) +func (d delayedShiftValidator) Error() string { + return "left shift operand should be int" } diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 53d328b4e..ca11d4e66 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -37,6 +37,11 @@ func canTruncateToFloat(v gconstant.Value) bool { var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) { + defer func() { + // Due to use of early return in the parsing, delayed checks are conducted in defer + ok = ok && cs.tryValidateDelayed(expr) + }() + switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { @@ -128,11 +133,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar lhst = shaderir.Type{Main: shaderir.Int} // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone. if rhs[0].Const == nil { - defer func() { - if ok { - cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue}) - } - }() + cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue, pos: e.Pos(), last: expr}) } } } @@ -201,7 +202,6 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar { Type: shaderir.Binary, Op: op2, - Ast: expr, Exprs: []shaderir.Expr{lhs[0], rhs[0]}, }, }, []shaderir.Type{t}, stmts, true diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 017a7a81d..c89d43cde 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -61,7 +61,7 @@ type compileState struct { varyingParsed bool - delayedTypeCheks map[ast.Expr]delayedTypeValidator + delayedTypeCheks map[ast.Expr]delayedValidator errs []string } @@ -84,9 +84,9 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) { return 0, false } -func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedTypeValidator) { +func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedValidator) { if cs.delayedTypeCheks == nil { - cs.delayedTypeCheks = make(map[ast.Expr]delayedTypeValidator, 1) + cs.delayedTypeCheks = make(map[ast.Expr]delayedValidator, 1) } cs.delayedTypeCheks[at] = check } diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index a794a749c..6c07057da 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -49,10 +49,6 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP if !ok { return nil, false } - for i := range ss { - ss[i].IsTypeGuessed = true - } - stmts = append(stmts, ss...) case token.ASSIGN: if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 { @@ -477,25 +473,6 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt)) return nil, false } - - // Need to run delayed checks - if len(cs.delayedTypeCheks) != 0 { - for _, st := range stmts { - cs.ValidateDefaultTypes(block, st) - } - - // Collect all errors first - foundErr := false - for s, v := range cs.delayedTypeCheks { - if _, ok := v.IsValidated(); !ok { - foundErr = true - cs.addError(s.Pos(), v.Error()) - } - } - if foundErr { - return nil, false - } - } return stmts, true } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index a355dfa79..2d78dd5a9 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,35 +1320,9 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - {stmt: "s := 1; a := 1.0< Date: Thu, 21 Mar 2024 17:14:16 +0900 Subject: [PATCH 16/20] Revert "add basic checks" This reverts commit f44640778d14571a26233bf98a860309a7cc948d. --- internal/shader/delayed.go | 124 ------------------------------ internal/shader/expr.go | 11 +-- internal/shader/shader.go | 15 ---- internal/shader/syntax_test.go | 134 +++++++++++---------------------- 4 files changed, 44 insertions(+), 240 deletions(-) delete mode 100644 internal/shader/delayed.go diff --git a/internal/shader/delayed.go b/internal/shader/delayed.go deleted file mode 100644 index 935d5797e..000000000 --- a/internal/shader/delayed.go +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2024 The Ebiten Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shader - -import ( - "go/ast" - gconstant "go/constant" - "go/token" - - "github.com/hajimehoshi/ebiten/v2/internal/shaderir" -) - -type resolveTypeStatus int - -const ( - resolveUnsure resolveTypeStatus = iota - resolveOk - resolveFail -) - -type delayedValidator interface { - Validate(expr ast.Expr) resolveTypeStatus - Pos() token.Pos - Error() string -} - -func (cs *compileState) tryValidateDelayed(cexpr ast.Expr) (ok bool) { - valExprs := make([]ast.Expr, 0, len(cs.delayedTypeCheks)) - for k := range cs.delayedTypeCheks { - valExprs = append(valExprs, k) - } - for _, expr := range valExprs { - if cexpr == expr { - continue - } - // Check if delayed validation can be done by adding current context - cres := cs.delayedTypeCheks[expr].Validate(cexpr) - switch cres { - case resolveFail: - cs.addError(cs.delayedTypeCheks[expr].Pos(), cs.delayedTypeCheks[expr].Error()) - return false - case resolveOk: - delete(cs.delayedTypeCheks, expr) - } - } - - return true -} - -type delayedShiftValidator struct { - value gconstant.Value - pos token.Pos - last ast.Expr -} - -func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool { - return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F -} - -func (d *delayedShiftValidator) Validate(cexpr ast.Expr) (rs resolveTypeStatus) { - switch cexpr.(type) { - case *ast.Ident: - ident := cexpr.(*ast.Ident) - // For BuiltinFunc, only int* are allowed - if fname, ok := shaderir.ParseBuiltinFunc(ident.Name); ok { - if isArgDefaultTypeInt(fname) { - return resolveOk - } - return resolveFail - } - // Untyped constant must represent int - if ident.Name == "_" { - if d.value != nil && d.value.Kind() == gconstant.Int { - return resolveOk - } - return resolveFail - } - if ident.Obj != nil { - if t, ok := ident.Obj.Type.(*ast.Ident); ok { - return d.Validate(t) - } - if decl, ok := ident.Obj.Decl.(*ast.ValueSpec); ok { - return d.Validate(decl.Type) - } - if _, ok := ident.Obj.Decl.(*ast.AssignStmt); ok { - if d.value != nil && d.value.Kind() == gconstant.Int { - return resolveOk - } - return resolveFail - } - } - case *ast.BinaryExpr: - bs := cexpr.(*ast.BinaryExpr) - left, right := bs.X, bs.Y - if bs.Y == d.last { - left, right = right, left - } - - rightCheck := d.Validate(right) - d.last = cexpr - return rightCheck - } - return resolveUnsure -} - -func (d delayedShiftValidator) Pos() token.Pos { - return d.pos -} - -func (d delayedShiftValidator) Error() string { - return "left shift operand should be int" -} diff --git a/internal/shader/expr.go b/internal/shader/expr.go index ca11d4e66..1d06be676 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -36,12 +36,7 @@ func canTruncateToFloat(v gconstant.Value) bool { var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) -func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) { - defer func() { - // Due to use of early return in the parsing, delayed checks are conducted in defer - ok = ok && cs.tryValidateDelayed(expr) - }() - +func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) { switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { @@ -108,7 +103,6 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar // Resolve untyped constants. var l gconstant.Value var r gconstant.Value - origLvalue := lhs[0].Const if op2 == shaderir.LeftShift || op2 == shaderir.RightShift { l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst) } else { @@ -132,9 +126,6 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar if lhst.Main == shaderir.None && lhs[0].Const != nil { lhst = shaderir.Type{Main: shaderir.Int} // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone. - if rhs[0].Const == nil { - cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue, pos: e.Pos(), last: expr}) - } } } } else { diff --git a/internal/shader/shader.go b/internal/shader/shader.go index c89d43cde..90ed2d611 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -61,8 +61,6 @@ type compileState struct { varyingParsed bool - delayedTypeCheks map[ast.Expr]delayedValidator - errs []string } @@ -84,13 +82,6 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) { return 0, false } -func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedValidator) { - if cs.delayedTypeCheks == nil { - cs.delayedTypeCheks = make(map[ast.Expr]delayedValidator, 1) - } - cs.delayedTypeCheks[at] = check -} - type typ struct { name string ir shaderir.Type @@ -359,12 +350,6 @@ func (cs *compileState) parse(f *ast.File) { for _, f := range cs.funcs { cs.ir.Funcs = append(cs.ir.Funcs, f.ir) } - - // if len(cs.delayedTypeCheks) != 0 { - // for _, check := range cs.delayedTypeCheks { - // cs.addError(check.Pos(), check.Error()) - // } - // } } func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) { diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 2d78dd5a9..1c2920cc4 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,27 +1320,23 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - {stmt: "s := 1; t := 2; _ = t + 1.0 << s", err: false}, - {stmt: "s := 1; _ = 1 << s", err: false}, - {stmt: "s := 1; _ = 1.0 << s", err: true}, - {stmt: "var a = 1; b := a << 2.0; _ = b", err: false}, - {stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true}, - {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, - {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, - {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, - {stmt: "s := 1; a := 1 << s; _ = a", err: false}, - {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, - {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, - {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, - {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, - {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, - {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, + // {stmt: "s := 1; var a float = float(1 << s); _ = a", err: true}, + // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + // {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, + // {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, + // {stmt: "s := 1; a := 1 << s; _ = a", err: false}, + // {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, + // {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, + // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + // {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, + // {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, + // {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, + // {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, {stmt: "var a float = 1.0 << 2.0; _ = a", err: false}, {stmt: "var a int = 1.0 << 2; _ = a", err: false}, {stmt: "var a float = 1.0 << 2; _ = a", err: false}, + {stmt: "var a = 1.0 << 2; _ = a", err: false}, {stmt: "a := 1 << 2.0; _ = a", err: false}, {stmt: "a := 1.0 << 2; _ = a", err: false}, {stmt: "a := 1.0 << 2.0; _ = a", err: false}, @@ -1366,6 +1362,36 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := ivec2(1) << vec2(2); _ = a", err: true}, {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, + + {stmt: "var a float = 1.0 >> 2.0; _ = a", err: false}, + {stmt: "var a int = 1.0 >> 2; _ = a", err: false}, + {stmt: "var a float = 1.0 >> 2; _ = a", err: false}, + {stmt: "var a = 1.0 >> 2; _ = a", err: false}, + {stmt: "a := 1 >> 2.0; _ = a", err: false}, + {stmt: "a := 1.0 >> 2; _ = a", err: false}, + {stmt: "a := 1.0 >> 2.0; _ = a", err: false}, + {stmt: "a := 1 >> 2; _ = a", err: false}, + {stmt: "a := float(1.0) >> 2; _ = a", err: true}, + {stmt: "a := 1 >> float(2.0); _ = a", err: false}, + {stmt: "a := 1.0 >> float(2.0); _ = a", err: false}, + {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, + {stmt: "a := 1 >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true}, + {stmt: "a := 1 >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> 2; _ = a", err: true}, + {stmt: "a := float(1.0) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true}, } for _, c := range cases { @@ -1381,80 +1407,6 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { t.Errorf("%s must not return nil but returned %v", c.stmt, err) } } - - casesFunc := []struct { - prog string - err bool - }{ - { - prog: `package main - func Foo(x int) { - _ = x - } - func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { - s := 1 - Foo(1 << s) - return dstPos - }`, - err: false, - }, - { - prog: `package main - func Foo(x int) { - _ = x - } - func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { - s := 1 - Foo(1.0 << s) - return dstPos - }`, - err: false, - }, - { - prog: `package main - func Foo(x float) { - _ = x - } - func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { - s := 1 - Foo(1 << s) - return dstPos - }`, - err: true, - }, - { - prog: `package main - func Foo(x float) { - _ = x - } - func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { - s := 1 - Foo(1 << s) - return dstPos - }`, - err: true, - }, - { - prog: `package main - func Foo(x float) { - _ = x - } - func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { - Foo(1.0 << 2.0) - return dstPos - }`, - err: false, - }, - } - - for _, c := range casesFunc { - _, err := compileToIR([]byte(c.prog)) - if err == nil && c.err { - t.Errorf("%s must return an error but does not", c.prog) - } else if err != nil && !c.err { - t.Errorf("%s must not return nil but returned %v", c.prog, err) - } - } } func TestSyntaxOperatorShiftAssign(t *testing.T) { From 8d15a459cffc5133a41b1bfa21091ee243b38872 Mon Sep 17 00:00:00 2001 From: aoyako Date: Thu, 21 Mar 2024 17:14:24 +0900 Subject: [PATCH 17/20] Revert "remove return type for deduced int" This reverts commit 66a4b20bdabf90a28f7498b7937d08329538264a. --- internal/shader/expr.go | 7 +++++-- internal/shader/stmt.go | 6 ++++++ internal/shader/syntax_test.go | 23 ++++++++++++++--------- internal/shader/type.go | 2 +- internal/shaderir/check.go | 2 +- internal/shaderir/type.go | 1 + 6 files changed, 28 insertions(+), 13 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 1d06be676..b3dac6a40 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -124,8 +124,11 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar // If left is untyped const if lhst.Main == shaderir.None && lhs[0].Const != nil { - lhst = shaderir.Type{Main: shaderir.Int} - // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone. + if rhs[0].Const != nil { + lhst = shaderir.Type{Main: shaderir.Int} + } else { + lhst = shaderir.Type{Main: shaderir.DeducedInt} + } } } } else { diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 6c07057da..b3d8e1732 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -514,6 +514,9 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r return nil, false } t := ts[0] + if t.Main == shaderir.DeducedInt { + cs.addError(pos, "invalid operation: shifted operand 1 (type float) must be integer") + } if t.Main == shaderir.None { t = toDefaultType(r[0].Const) } @@ -705,6 +708,9 @@ func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool { if lt.Equal(rt) { return true } + if lt.Main == shaderir.Int && rt.Main == shaderir.DeducedInt { + return true + } if rc == nil { return false diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 1c2920cc4..df2c1c84d 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1324,16 +1324,16 @@ func TestSyntaxOperatorShift(t *testing.T) { // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, // {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, // {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, - // {stmt: "s := 1; a := 1 << s; _ = a", err: false}, - // {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, - // {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, - // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - // {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, - // {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, - // {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, - // {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, + {stmt: "s := 1; a := 1 << s; _ = a", err: false}, + {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, + {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, {stmt: "var a float = 1.0 << 2.0; _ = a", err: false}, + {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, + {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, + {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, + {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, {stmt: "var a int = 1.0 << 2; _ = a", err: false}, {stmt: "var a float = 1.0 << 2; _ = a", err: false}, {stmt: "var a = 1.0 << 2; _ = a", err: false}, @@ -1363,7 +1363,12 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, - {stmt: "var a float = 1.0 >> 2.0; _ = a", err: false}, + {stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true}, + {stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false}, + {stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true}, + {stmt: "s := 1; var a float = 1 >> s; _ = a", err: true}, + {stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false}, + {stmt: "s := 1; var a int = 1 >> s; _ = a", err: false}, {stmt: "var a int = 1.0 >> 2; _ = a", err: false}, {stmt: "var a float = 1.0 >> 2; _ = a", err: false}, {stmt: "var a = 1.0 >> 2; _ = a", err: false}, diff --git a/internal/shader/type.go b/internal/shader/type.go index 546407556..5c8aade39 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -165,7 +165,7 @@ func checkArgsForIntBuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) err if len(args) != 1 { return fmt.Errorf("number of int's arguments must be 1 but %d", len(args)) } - if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float { + if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float || argts[0].Main == shaderir.DeducedInt { return nil } if args[0].Const != nil && gconstant.ToInt(args[0].Const).Kind() != gconstant.Unknown { diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index 57bacc749..e518726b1 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -145,7 +145,7 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) } if op == LeftShift || op == RightShift { - if lhst.Main == Int && rhst.Main == Int { + if (lhst.Main == Int || lhst.Main == DeducedInt) && rhst.Main == Int { return Type{Main: lhst.Main}, true } if lhst.IsIntVector() && rhst.Main == Int { diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index ede2c91e2..885d579c4 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -180,6 +180,7 @@ const ( Texture Array Struct + DeducedInt ) func descendantLocalVars(block, target *Block) ([]Type, bool) { From 49682f1097002bb7a30cf443bc7f7c894cebf7e3 Mon Sep 17 00:00:00 2001 From: aoyako Date: Thu, 21 Mar 2024 17:14:32 +0900 Subject: [PATCH 18/20] Revert "add return type for type resolving" This reverts commit 7f9d9971758e979488b7e6a6e9ff1208036f54e4. --- internal/shader/expr.go | 64 +++++++++++----------------------- internal/shader/stmt.go | 6 ---- internal/shader/syntax_test.go | 41 +++------------------- internal/shader/type.go | 2 +- internal/shaderir/check.go | 62 ++++++++++++++------------------ internal/shaderir/type.go | 1 - 6 files changed, 52 insertions(+), 124 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index b3dac6a40..cc41cea32 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -101,13 +101,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } // Resolve untyped constants. - var l gconstant.Value - var r gconstant.Value - if op2 == shaderir.LeftShift || op2 == shaderir.RightShift { - l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst) - } else { - l, r, ok = shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) - } + l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) if !ok { // TODO: Show a better type name for untyped constants. cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) @@ -115,45 +109,27 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } lhs[0].Const, rhs[0].Const = l, r - if op2 == shaderir.LeftShift || op2 == shaderir.RightShift { - if !(lhst.Main == shaderir.None && rhst.Main == shaderir.None) { - // If both are const - if rhs[0].Const != nil && (rhst.Main == shaderir.None || lhs[0].Const != nil) { - rhst = shaderir.Type{Main: shaderir.Int} - } - - // If left is untyped const - if lhst.Main == shaderir.None && lhs[0].Const != nil { - if rhs[0].Const != nil { - lhst = shaderir.Type{Main: shaderir.Int} - } else { - lhst = shaderir.Type{Main: shaderir.DeducedInt} - } + // If either is typed, resolve the other type. + // If both are untyped, keep them untyped. + if lhst.Main != shaderir.None || rhst.Main != shaderir.None { + if lhs[0].Const != nil { + switch lhs[0].Const.Kind() { + case gconstant.Float: + lhst = shaderir.Type{Main: shaderir.Float} + case gconstant.Int: + lhst = shaderir.Type{Main: shaderir.Int} + case gconstant.Bool: + lhst = shaderir.Type{Main: shaderir.Bool} } } - } else { - // If either is typed, resolve the other type. - // If both are untyped, keep them untyped. - if lhst.Main != shaderir.None || rhst.Main != shaderir.None { - if lhs[0].Const != nil { - switch lhs[0].Const.Kind() { - case gconstant.Float: - lhst = shaderir.Type{Main: shaderir.Float} - case gconstant.Int: - lhst = shaderir.Type{Main: shaderir.Int} - case gconstant.Bool: - lhst = shaderir.Type{Main: shaderir.Bool} - } - } - if rhs[0].Const != nil { - switch rhs[0].Const.Kind() { - case gconstant.Float: - rhst = shaderir.Type{Main: shaderir.Float} - case gconstant.Int: - rhst = shaderir.Type{Main: shaderir.Int} - case gconstant.Bool: - rhst = shaderir.Type{Main: shaderir.Bool} - } + if rhs[0].Const != nil { + switch rhs[0].Const.Kind() { + case gconstant.Float: + rhst = shaderir.Type{Main: shaderir.Float} + case gconstant.Int: + rhst = shaderir.Type{Main: shaderir.Int} + case gconstant.Bool: + rhst = shaderir.Type{Main: shaderir.Bool} } } } diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index b3d8e1732..6c07057da 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -514,9 +514,6 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r return nil, false } t := ts[0] - if t.Main == shaderir.DeducedInt { - cs.addError(pos, "invalid operation: shifted operand 1 (type float) must be integer") - } if t.Main == shaderir.None { t = toDefaultType(r[0].Const) } @@ -708,9 +705,6 @@ func canAssign(lt *shaderir.Type, rt *shaderir.Type, rc gconstant.Value) bool { if lt.Equal(rt) { return true } - if lt.Main == shaderir.Int && rt.Main == shaderir.DeducedInt { - return true - } if rc == nil { return false diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index df2c1c84d..e50ff7d4a 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,30 +1320,12 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - // {stmt: "s := 1; var a float = float(1 << s); _ = a", err: true}, - // {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - // {stmt: "s := 1; var a int = int(1 << s); _ = a", err: false}, - // {stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false}, - - {stmt: "s := 1; a := 1 << s; _ = a", err: false}, - {stmt: "s := 1; a := 1.0 << s; _ = a", err: true}, - {stmt: "s := 1; a := int(1.0 << s); _ = a", err: false}, - {stmt: "var a float = 1.0 << 2.0; _ = a", err: false}, - {stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true}, - {stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true}, - {stmt: "s := 1; var a float = 1 << s; _ = a", err: true}, - {stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false}, - {stmt: "s := 1; var a int = 1 << s; _ = a", err: false}, - {stmt: "var a int = 1.0 << 2; _ = a", err: false}, - {stmt: "var a float = 1.0 << 2; _ = a", err: false}, - {stmt: "var a = 1.0 << 2; _ = a", err: false}, - {stmt: "a := 1 << 2.0; _ = a", err: false}, - {stmt: "a := 1.0 << 2; _ = a", err: false}, - {stmt: "a := 1.0 << 2.0; _ = a", err: false}, + {stmt: "a := 1 << 2.0; _ = a", err: true}, + {stmt: "a := 1.0 << 2; _ = a", err: true}, + {stmt: "a := 1.0 << 2.0; _ = a", err: true}, {stmt: "a := 1 << 2; _ = a", err: false}, {stmt: "a := float(1.0) << 2; _ = a", err: true}, - {stmt: "a := 1 << float(2.0); _ = a", err: false}, - {stmt: "a := 1.0 << float(2.0); _ = a", err: false}, + {stmt: "a := 1 << float(2.0); _ = a", err: true}, {stmt: "a := ivec2(1) << 2; _ = a", err: false}, {stmt: "a := 1 << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << float(2.0); _ = a", err: true}, @@ -1363,22 +1345,9 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, - {stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true}, - {stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false}, - {stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true}, - {stmt: "s := 1; var a float = 1 >> s; _ = a", err: true}, - {stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false}, - {stmt: "s := 1; var a int = 1 >> s; _ = a", err: false}, - {stmt: "var a int = 1.0 >> 2; _ = a", err: false}, - {stmt: "var a float = 1.0 >> 2; _ = a", err: false}, - {stmt: "var a = 1.0 >> 2; _ = a", err: false}, - {stmt: "a := 1 >> 2.0; _ = a", err: false}, - {stmt: "a := 1.0 >> 2; _ = a", err: false}, - {stmt: "a := 1.0 >> 2.0; _ = a", err: false}, {stmt: "a := 1 >> 2; _ = a", err: false}, {stmt: "a := float(1.0) >> 2; _ = a", err: true}, - {stmt: "a := 1 >> float(2.0); _ = a", err: false}, - {stmt: "a := 1.0 >> float(2.0); _ = a", err: false}, + {stmt: "a := 1 >> float(2.0); _ = a", err: true}, {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, {stmt: "a := 1 >> ivec2(2); _ = a", err: true}, {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, diff --git a/internal/shader/type.go b/internal/shader/type.go index 5c8aade39..546407556 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -165,7 +165,7 @@ func checkArgsForIntBuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) err if len(args) != 1 { return fmt.Errorf("number of int's arguments must be 1 but %d", len(args)) } - if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float || argts[0].Main == shaderir.DeducedInt { + if argts[0].Main == shaderir.Int || argts[0].Main == shaderir.Float { return nil } if args[0].Const != nil && gconstant.ToInt(args[0].Const).Kind() != gconstant.Unknown { diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index e518726b1..cb93f3b79 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -18,29 +18,6 @@ import ( "go/constant" ) -func ResolveUntypedConstsForBitShiftOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { - cLhs := lhs - cRhs := rhs - - // Right is const -> int - if rhs != nil { - cRhs = constant.ToInt(rhs) - if cRhs.Kind() == constant.Unknown { - return nil, nil, false - } - } - - // Left if untyped const -> int - if lhs != nil && lhst.Main == None { - cLhs = constant.ToInt(lhs) - if cLhs.Kind() == constant.Unknown { - return nil, nil, false - } - } - - return cLhs, cRhs, true -} - func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { if lhst.Main == None && rhst.Main == None { if lhs.Kind() == rhs.Kind() { @@ -121,6 +98,13 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } + if op == LeftShift || op == RightShift { + if lhsConst.Kind() == constant.Int && rhsConst.Kind() == constant.Int { + return Type{Main: Int}, true + } + return Type{}, false + } + if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp { return Type{Main: Bool}, true } @@ -144,19 +128,6 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) panic("shaderir: cannot resolve untyped values") } - if op == LeftShift || op == RightShift { - if (lhst.Main == Int || lhst.Main == DeducedInt) && rhst.Main == Int { - return Type{Main: lhst.Main}, true - } - if lhst.IsIntVector() && rhst.Main == Int { - return Type{Main: lhst.Main}, true - } - if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() { - return Type{Main: lhst.Main}, true - } - return Type{}, false - } - if op == AndAnd || op == OrOr { if lhst.Main == Bool && rhst.Main == Bool { return Type{Main: Bool}, true @@ -231,6 +202,25 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } + if op == LeftShift || op == RightShift { + if lhst.Main == Int && rhst.Main == Int { + return Type{Main: Int}, true + } + if lhst.Main == IVec2 && rhst.Main == IVec2 { + return Type{Main: IVec2}, true + } + if lhst.Main == IVec3 && rhst.Main == IVec3 { + return Type{Main: IVec3}, true + } + if lhst.Main == IVec4 && rhst.Main == IVec4 { + return Type{Main: IVec4}, true + } + if lhst.IsIntVector() && rhst.Main == Int { + return lhst, true + } + return Type{}, false + } + if lhst.Equal(&rhst) { if lhst.Main == None { return rhst, true diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index 885d579c4..ede2c91e2 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -180,7 +180,6 @@ const ( Texture Array Struct - DeducedInt ) func descendantLocalVars(block, target *Block) ([]Type, bool) { From 706011c2753feb2b32ed217e837a1413c9ab53c4 Mon Sep 17 00:00:00 2001 From: aoyako Date: Thu, 21 Mar 2024 17:21:11 +0900 Subject: [PATCH 19/20] add: new binary shift operator rules --- internal/shader/expr.go | 2 +- internal/shader/syntax_test.go | 13 ++++++++++--- internal/shaderir/check.go | 15 ++++++++++++++- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index cc41cea32..dd89c8bfc 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -101,7 +101,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } // Resolve untyped constants. - l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) + l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(op2, lhs[0].Const, rhs[0].Const, lhst, rhst) if !ok { // TODO: Show a better type name for untyped constants. cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index e50ff7d4a..1e905130d 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1320,10 +1320,12 @@ func TestSyntaxOperatorShift(t *testing.T) { stmt string err bool }{ - {stmt: "a := 1 << 2.0; _ = a", err: true}, - {stmt: "a := 1.0 << 2; _ = a", err: true}, - {stmt: "a := 1.0 << 2.0; _ = a", err: true}, {stmt: "a := 1 << 2; _ = a", err: false}, + {stmt: "a := 1 << 2.0; _ = a", err: false}, + {stmt: "a := 1.0 << 2; _ = a", err: false}, + {stmt: "a := 1.0 << 2.0; _ = a", err: false}, + {stmt: "var a = 1; b := a << 2.0; _ = b", err: false}, + {stmt: "var a = 1; b := 2.0 << a; _ = b", err: false}, // PR: #2916 {stmt: "a := float(1.0) << 2; _ = a", err: true}, {stmt: "a := 1 << float(2.0); _ = a", err: true}, {stmt: "a := ivec2(1) << 2; _ = a", err: false}, @@ -1346,6 +1348,11 @@ func TestSyntaxOperatorShift(t *testing.T) { {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, {stmt: "a := 1 >> 2; _ = a", err: false}, + {stmt: "a := 1 >> 2.0; _ = a", err: false}, + {stmt: "a := 1.0 >> 2; _ = a", err: false}, + {stmt: "a := 1.0 >> 2.0; _ = a", err: false}, + {stmt: "var a = 1; b := a >> 2.0; _ = b", err: false}, + {stmt: "var a = 1; b := 2.0 >> a; _ = b", err: false}, // PR: #2916 {stmt: "a := float(1.0) >> 2; _ = a", err: true}, {stmt: "a := 1 >> float(2.0); _ = a", err: true}, {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index cb93f3b79..d279b3c88 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -18,8 +18,21 @@ import ( "go/constant" ) -func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { +func ResolveUntypedConstsForBinaryOp(op Op, lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { if lhst.Main == None && rhst.Main == None { + if op == LeftShift || op == RightShift { + newLhs = constant.ToInt(lhs) + newRhs = constant.ToInt(rhs) + + if newLhs.Kind() == constant.Unknown { + return nil, nil, false + } + if newRhs.Kind() == constant.Unknown { + return nil, nil, false + } + return newLhs, newRhs, true + } + if lhs.Kind() == rhs.Kind() { return lhs, rhs, true } From b0a7fb36c78c9ec1c32477aa3b06dcb6e4b77d26 Mon Sep 17 00:00:00 2001 From: aoyako Date: Thu, 21 Mar 2024 17:28:32 +0900 Subject: [PATCH 20/20] simplify type deduction in TypeFromBinaryOp for shift op --- internal/shaderir/check.go | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index d279b3c88..88fa18d49 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -216,19 +216,10 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) } if op == LeftShift || op == RightShift { - if lhst.Main == Int && rhst.Main == Int { - return Type{Main: Int}, true + if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int { + return lhst, true } - if lhst.Main == IVec2 && rhst.Main == IVec2 { - return Type{Main: IVec2}, true - } - if lhst.Main == IVec3 && rhst.Main == IVec3 { - return Type{Main: IVec3}, true - } - if lhst.Main == IVec4 && rhst.Main == IVec4 { - return Type{Main: IVec4}, true - } - if lhst.IsIntVector() && rhst.Main == Int { + if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() { return lhst, true } return Type{}, false