diff --git a/internal/shader/expr.go b/internal/shader/expr.go index f5a9425a7..1c063820f 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -105,7 +105,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } // Resolve untyped constants. - l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) + l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(op2, lhs[0].Const, rhs[0].Const, lhst, rhst) if !ok { // TODO: Show a better type name for untyped constants. cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) @@ -153,6 +153,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar v = gconstant.MakeBool(b) case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ: v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const)) + case token.SHL, token.SHR: + shift, ok := gconstant.Int64Val(rhs[0].Const) + if !ok { + cs.addError(e.Pos(), fmt.Sprintf("unexpected %s type for: %s", rhs[0].Const.String(), e.Op)) + return nil, nil, nil, false + } + v = gconstant.Shift(lhs[0].Const, op, uint(shift)) default: v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const) } diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 42f0ddd13..6c07057da 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -60,7 +60,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return nil, false } stmts = append(stmts, ss...) - case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN: + case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN: rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true) if !ok { return nil, false @@ -100,6 +100,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP op = shaderir.Or case token.XOR_ASSIGN: op = shaderir.Xor + case token.SHL_ASSIGN: + op = shaderir.LeftShift + case token.SHR_ASSIGN: + op = shaderir.RightShift default: cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok)) return nil, false @@ -110,7 +114,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator / not defined on %s", rts[0].String())) return nil, false } - if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { + if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift { if lts[0].Main != shaderir.Int && !lts[0].IsIntVector() { cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) } @@ -137,7 +141,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } } case shaderir.Float: - if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { + if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift { cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) } else if rhs[0].Const != nil && (rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) && @@ -148,7 +152,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return nil, false } case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { + if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift { cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) } else if (op == shaderir.MatrixMul || op == shaderir.Div) && (rts[0].Main == shaderir.Float || diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index a595c0c4a..e94a41134 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1314,6 +1314,169 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { } } +// Issue: #2755 +func TestSyntaxOperatorShift(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := 1 << 2; _ = a", err: false}, + {stmt: "a := 1 << 2.0; _ = a", err: false}, + {stmt: "a := 1.0 << 2; _ = a", err: false}, + {stmt: "a := 1.0 << 2.0; _ = a", err: false}, + {stmt: "a := 1.0 << int(1); _ = a", err: false}, + {stmt: "a := int(1) << 2.0; _ = a", err: false}, + {stmt: "a := ivec2(1) << 2.0; _ = a", err: false}, + {stmt: "var a = 1; b := a << 2.0; _ = b", err: false}, + {stmt: "var a = 1; b := 2.0 << a; _ = b", err: false}, // PR: #2916 + {stmt: "a := float(1.0) << 2; _ = a", err: true}, + {stmt: "a := 1 << float(2.0); _ = a", err: true}, + {stmt: "a := ivec2(1) << 2; _ = a", err: false}, + {stmt: "a := 1 << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << ivec3(2); _ = a", err: true}, + {stmt: "a := 1 << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << 2; _ = a", err: true}, + {stmt: "a := float(1.0) << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1) << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1) << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, + + {stmt: "a := 1 >> 2; _ = a", err: false}, + {stmt: "a := 1 >> 2.0; _ = a", err: false}, + {stmt: "a := 1.0 >> 2; _ = a", err: false}, + {stmt: "a := 1.0 >> 2.0; _ = a", err: false}, + {stmt: "a := 1.0 >> int(1); _ = a", err: false}, + {stmt: "a := int(1) >> 2.0; _ = a", err: false}, + {stmt: "a := ivec2(1) >> 2.0; _ = a", err: false}, + {stmt: "var a = 1; b := a >> 2.0; _ = b", err: false}, + {stmt: "var a = 1; b := 2.0 >> a; _ = b", err: false}, // PR: #2916 + {stmt: "a := float(1.0) >> 2; _ = a", err: true}, + {stmt: "a := 1 >> float(2.0); _ = a", err: true}, + {stmt: "a := ivec2(1) >> 2; _ = a", err: false}, + {stmt: "a := 1 >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true}, + {stmt: "a := 1 >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> 2; _ = a", err: true}, + {stmt: "a := float(1.0) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true}, + } + + for _, c := range cases { + _, err := compileToIR([]byte(fmt.Sprintf(`package main + +func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + %s + return dstPos +}`, c.stmt))) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.stmt, err) + } + } +} + +func TestSyntaxOperatorShiftAssign(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := 1; a <<= 2; _ = a", err: false}, + {stmt: "a := 1; a <<= 2.0; _ = a", err: false}, + {stmt: "a := float(1.0); a <<= 2; _ = a", err: true}, + {stmt: "a := 1; a <<= float(2.0); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= 2; _ = a", err: false}, + {stmt: "a := 1; a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0); a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1); a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= ivec3(2); _ = a", err: true}, + {stmt: "a := 1; a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= 2; _ = a", err: true}, + {stmt: "a := float(1.0); a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1); a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1); a <<= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a <<= vec3(2); _ = a", err: true}, + {stmt: "const c = 2; a := 1; a <<= c; _ = a", err: false}, + {stmt: "const c = 2.0; a := 1; a <<= c; _ = a", err: false}, + {stmt: "const c = 2; a := float(1.0); a <<= c; _ = a", err: true}, + {stmt: "const c float = 2; a := 1; a <<= c; _ = a", err: true}, + {stmt: "const c float = 2.0; a := 1; a <<= c; _ = a", err: true}, + {stmt: "const c int = 2; a := ivec2(1); a <<= c; _ = a", err: false}, + {stmt: "const c int = 2; a := vec2(1); a <<= c; _ = a", err: true}, + + {stmt: "a := 1; a >>= 2; _ = a", err: false}, + {stmt: "a := 1; a >>= 2.0; _ = a", err: false}, + {stmt: "a := float(1.0); a >>= 2; _ = a", err: true}, + {stmt: "a := 1; a >>= float(2.0); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= 2; _ = a", err: false}, + {stmt: "a := 1; a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0); a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1); a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= ivec3(2); _ = a", err: true}, + {stmt: "a := 1; a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= 2; _ = a", err: true}, + {stmt: "a := float(1.0); a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1); a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1); a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1); a >>= ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1); a >>= vec3(2); _ = a", err: true}, + {stmt: "const c = 2; a := 1; a >>= c; _ = a", err: false}, + {stmt: "const c = 2.0; a := 1; a >>= c; _ = a", err: false}, + {stmt: "const c = 2; a := float(1.0); a >>= c; _ = a", err: true}, + {stmt: "const c float = 2; a := 1; a >>= c; _ = a", err: true}, + {stmt: "const c float = 2.0; a := 1; a >>= c; _ = a", err: true}, + {stmt: "const c int = 2; a := ivec2(1); a >>= c; _ = a", err: false}, + {stmt: "const c int = 2; a := vec2(1); a >>= c; _ = a", err: true}, + } + + for _, c := range cases { + _, err := compileToIR([]byte(fmt.Sprintf(`package main + +func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + %s + return dstPos +}`, c.stmt))) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.stmt, err) + } + } +} + // Issue #1971 func TestSyntaxOperatorMultiplyAssign(t *testing.T) { cases := []struct { diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index 857bd1e2a..88fa18d49 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -18,8 +18,21 @@ import ( "go/constant" ) -func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { +func ResolveUntypedConstsForBinaryOp(op Op, lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { if lhst.Main == None && rhst.Main == None { + if op == LeftShift || op == RightShift { + newLhs = constant.ToInt(lhs) + newRhs = constant.ToInt(rhs) + + if newLhs.Kind() == constant.Unknown { + return nil, nil, false + } + if newRhs.Kind() == constant.Unknown { + return nil, nil, false + } + return newLhs, newRhs, true + } + if lhs.Kind() == rhs.Kind() { return lhs, rhs, true } @@ -98,6 +111,13 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } + if op == LeftShift || op == RightShift { + if lhsConst.Kind() == constant.Int && rhsConst.Kind() == constant.Int { + return Type{Main: Int}, true + } + return Type{}, false + } + if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp { return Type{Main: Bool}, true } @@ -195,6 +215,16 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } + if op == LeftShift || op == RightShift { + if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int { + return lhst, true + } + if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() { + return lhst, true + } + return Type{}, false + } + if lhst.Equal(&rhst) { if lhst.Main == None { return rhst, true