add: new binary shift operator rules

This commit is contained in:
aoyako 2024-03-21 17:21:11 +09:00
parent 49682f1097
commit 706011c275
3 changed files with 25 additions and 5 deletions

View File

@ -101,7 +101,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
// Resolve untyped constants. // Resolve untyped constants.
l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(op2, lhs[0].Const, rhs[0].Const, lhst, rhst)
if !ok { if !ok {
// TODO: Show a better type name for untyped constants. // TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))

View File

@ -1320,10 +1320,12 @@ func TestSyntaxOperatorShift(t *testing.T) {
stmt string stmt string
err bool err bool
}{ }{
{stmt: "a := 1 << 2.0; _ = a", err: true},
{stmt: "a := 1.0 << 2; _ = a", err: true},
{stmt: "a := 1.0 << 2.0; _ = a", err: true},
{stmt: "a := 1 << 2; _ = a", err: false}, {stmt: "a := 1 << 2; _ = a", err: false},
{stmt: "a := 1 << 2.0; _ = a", err: false},
{stmt: "a := 1.0 << 2; _ = a", err: false},
{stmt: "a := 1.0 << 2.0; _ = a", err: false},
{stmt: "var a = 1; b := a << 2.0; _ = b", err: false},
{stmt: "var a = 1; b := 2.0 << a; _ = b", err: false}, // PR: #2916
{stmt: "a := float(1.0) << 2; _ = a", err: true}, {stmt: "a := float(1.0) << 2; _ = a", err: true},
{stmt: "a := 1 << float(2.0); _ = a", err: true}, {stmt: "a := 1 << float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1) << 2; _ = a", err: false}, {stmt: "a := ivec2(1) << 2; _ = a", err: false},
@ -1346,6 +1348,11 @@ func TestSyntaxOperatorShift(t *testing.T) {
{stmt: "a := ivec2(1) << vec3(2); _ = a", err: true}, {stmt: "a := ivec2(1) << vec3(2); _ = a", err: true},
{stmt: "a := 1 >> 2; _ = a", err: false}, {stmt: "a := 1 >> 2; _ = a", err: false},
{stmt: "a := 1 >> 2.0; _ = a", err: false},
{stmt: "a := 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1.0 >> 2.0; _ = a", err: false},
{stmt: "var a = 1; b := a >> 2.0; _ = b", err: false},
{stmt: "var a = 1; b := 2.0 >> a; _ = b", err: false}, // PR: #2916
{stmt: "a := float(1.0) >> 2; _ = a", err: true}, {stmt: "a := float(1.0) >> 2; _ = a", err: true},
{stmt: "a := 1 >> float(2.0); _ = a", err: true}, {stmt: "a := 1 >> float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1) >> 2; _ = a", err: false}, {stmt: "a := ivec2(1) >> 2; _ = a", err: false},

View File

@ -18,8 +18,21 @@ import (
"go/constant" "go/constant"
) )
func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { func ResolveUntypedConstsForBinaryOp(op Op, lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
if lhst.Main == None && rhst.Main == None { if lhst.Main == None && rhst.Main == None {
if op == LeftShift || op == RightShift {
newLhs = constant.ToInt(lhs)
newRhs = constant.ToInt(rhs)
if newLhs.Kind() == constant.Unknown {
return nil, nil, false
}
if newRhs.Kind() == constant.Unknown {
return nil, nil, false
}
return newLhs, newRhs, true
}
if lhs.Kind() == rhs.Kind() { if lhs.Kind() == rhs.Kind() {
return lhs, rhs, true return lhs, rhs, true
} }