diff --git a/internal/shader/expr.go b/internal/shader/expr.go index b57c28db5..cc41cea32 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -149,6 +149,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar v = gconstant.MakeBool(b) case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ: v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const)) + case token.SHL, token.SHR: + shift, ok := gconstant.Int64Val(rhs[0].Const) + if !ok { + cs.addError(e.Pos(), fmt.Sprintf("unexpected %s type for: %s", rhs[0].Const.String(), e.Op)) + } else { + v = gconstant.Shift(lhs[0].Const, op, uint(shift)) + } default: v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const) } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 68ee4b5e8..8ed19be6f 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1314,6 +1314,50 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { } } +// Issue: #2755 +func TestSyntaxOperatorShift(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "a := 1 << 2; _ = a", err: false}, + {stmt: "a := float(1.0) << 2; _ = a", err: true}, + {stmt: "a := 1 << float(2.0); _ = a", err: true}, + {stmt: "a := ivec2(1) << 2; _ = a", err: false}, + {stmt: "a := 1 << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << float(2.0); _ = a", err: true}, + {stmt: "a := float(1.0) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << ivec2(2); _ = a", err: false}, + {stmt: "a := ivec3(1) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << ivec3(2); _ = a", err: true}, + {stmt: "a := 1 << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << 2; _ = a", err: true}, + {stmt: "a := float(1.0) << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << float(2.0); _ = a", err: true}, + {stmt: "a := vec2(1) << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << vec3(2); _ = a", err: true}, + {stmt: "a := vec3(1) << vec2(2); _ = a", err: true}, + {stmt: "a := vec2(1) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << vec2(2); _ = a", err: true}, + {stmt: "a := vec3(1) << ivec2(2); _ = a", err: true}, + {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, + } + + for _, c := range cases { + _, err := compileToIR([]byte(fmt.Sprintf(`package main + +func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 { + %s + return dstPos +}`, c.stmt))) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", c.stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", c.stmt, err) + } + } +} + // Issue #1971 func TestSyntaxOperatorMultiplyAssign(t *testing.T) { cases := []struct { diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index b6d8b08bb..cb93f3b79 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -98,6 +98,13 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } + if op == LeftShift || op == RightShift { + if lhsConst.Kind() == constant.Int && rhsConst.Kind() == constant.Int { + return Type{Main: Int}, true + } + return Type{}, false + } + if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp { return Type{Main: Bool}, true } @@ -173,7 +180,7 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return Type{}, false } - if op == And || op == Or || op == Xor || op == LeftShift || op == RightShift { + if op == And || op == Or || op == Xor { if lhst.Main == Int && rhst.Main == Int { return Type{Main: Int}, true } @@ -190,9 +197,26 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value) return lhst, true } if lhst.Main == Int && rhst.IsIntVector() { - if op == And || op == Or || op == Xor { - return rhst, true - } + return rhst, true + } + return Type{}, false + } + + if op == LeftShift || op == RightShift { + if lhst.Main == Int && rhst.Main == Int { + return Type{Main: Int}, true + } + if lhst.Main == IVec2 && rhst.Main == IVec2 { + return Type{Main: IVec2}, true + } + if lhst.Main == IVec3 && rhst.Main == IVec3 { + return Type{Main: IVec3}, true + } + if lhst.Main == IVec4 && rhst.Main == IVec4 { + return Type{Main: IVec4}, true + } + if lhst.IsIntVector() && rhst.Main == Int { + return lhst, true } return Type{}, false }