internal/shader: add type checks for bitwise operators

Updates #2754
This commit is contained in:
Hajime Hoshi 2023-09-12 02:27:36 +09:00
parent b88d02851f
commit c13980158f
3 changed files with 112 additions and 0 deletions

View File

@ -155,6 +155,16 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
t = shaderir.Type{Main: shaderir.Bool}
}
case token.AND, token.OR, token.XOR, token.AND_NOT:
if lhs[0].Const.Kind() != gconstant.Int {
cs.addError(e.Pos(), fmt.Sprintf("operator %s not defined on %s (%s)", op, lhs[0].Const.String(), lhst.String()))
return nil, nil, nil, false
}
if rhs[0].Const.Kind() != gconstant.Int {
cs.addError(e.Pos(), fmt.Sprintf("operator %s not defined on %s (%s)", op, rhs[0].Const.String(), rhst.String()))
return nil, nil, nil, false
}
fallthrough
default:
v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)
switch {

View File

@ -3620,3 +3620,83 @@ func Bar() {
t.Error(err)
}
}
func TestSyntaxBitwiseOperator(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "_ = false & true", err: true},
{stmt: "_ = int(0) & int(1)", err: false},
{stmt: "_ = float(0) & float(1)", err: true},
{stmt: "_ = vec2(0) & vec2(1)", err: true},
{stmt: "_ = vec3(0) & vec3(1)", err: true},
{stmt: "_ = vec4(0) & vec4(1)", err: true},
{stmt: "_ = ivec2(0) & ivec2(1)", err: false},
{stmt: "_ = ivec3(0) & ivec3(1)", err: false},
{stmt: "_ = ivec4(0) & ivec4(1)", err: false},
{stmt: "_ = ivec2(0) & int(1)", err: false},
{stmt: "_ = ivec3(0) & int(1)", err: false},
{stmt: "_ = ivec4(0) & int(1)", err: false},
{stmt: "_ = int(0) & ivec2(1)", err: false},
{stmt: "_ = int(0) & ivec3(1)", err: false},
{stmt: "_ = int(0) & ivec4(1)", err: false},
{stmt: "_ = mat2(0) & mat2(1)", err: true},
{stmt: "_ = mat3(0) & mat3(1)", err: true},
{stmt: "_ = mat4(0) & mat4(1)", err: true},
{stmt: "_ = false | true", err: true},
{stmt: "_ = int(0) | int(1)", err: false},
{stmt: "_ = float(0) | float(1)", err: true},
{stmt: "_ = vec2(0) | vec2(1)", err: true},
{stmt: "_ = vec3(0) | vec3(1)", err: true},
{stmt: "_ = vec4(0) | vec4(1)", err: true},
{stmt: "_ = ivec2(0) | ivec2(1)", err: false},
{stmt: "_ = ivec3(0) | ivec3(1)", err: false},
{stmt: "_ = ivec4(0) | ivec4(1)", err: false},
{stmt: "_ = ivec2(0) | int(1)", err: false},
{stmt: "_ = ivec3(0) | int(1)", err: false},
{stmt: "_ = ivec4(0) | int(1)", err: false},
{stmt: "_ = int(0) | ivec2(1)", err: false},
{stmt: "_ = int(0) | ivec3(1)", err: false},
{stmt: "_ = int(0) | ivec4(1)", err: false},
{stmt: "_ = mat2(0) | mat2(1)", err: true},
{stmt: "_ = mat3(0) | mat3(1)", err: true},
{stmt: "_ = mat4(0) | mat4(1)", err: true},
{stmt: "_ = false ^ true", err: true},
{stmt: "_ = int(0) ^ int(1)", err: false},
{stmt: "_ = float(0) ^ float(1)", err: true},
{stmt: "_ = vec2(0) ^ vec2(1)", err: true},
{stmt: "_ = vec3(0) ^ vec3(1)", err: true},
{stmt: "_ = vec4(0) ^ vec4(1)", err: true},
{stmt: "_ = ivec2(0) ^ ivec2(1)", err: false},
{stmt: "_ = ivec3(0) ^ ivec3(1)", err: false},
{stmt: "_ = ivec4(0) ^ ivec4(1)", err: false},
{stmt: "_ = ivec2(0) ^ int(1)", err: false},
{stmt: "_ = ivec3(0) ^ int(1)", err: false},
{stmt: "_ = ivec4(0) ^ int(1)", err: false},
{stmt: "_ = int(0) ^ ivec2(1)", err: false},
{stmt: "_ = int(0) ^ ivec3(1)", err: false},
{stmt: "_ = int(0) ^ ivec4(1)", err: false},
{stmt: "_ = mat2(0) ^ mat2(1)", err: true},
{stmt: "_ = mat3(0) ^ mat3(1)", err: true},
{stmt: "_ = mat4(0) ^ mat4(1)", err: true},
}
for _, c := range cases {
stmt := c.stmt
src := fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", stmt, err)
}
}
}

View File

@ -129,6 +129,28 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
return (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int
}
if op == And || op == Or || op == Xor {
if lhst.Main == Int && rhst.Main == Int {
return true
}
if lhst.Main == IVec2 && rhst.Main == IVec2 {
return true
}
if lhst.Main == IVec3 && rhst.Main == IVec3 {
return true
}
if lhst.Main == IVec4 && rhst.Main == IVec4 {
return true
}
if lhst.IsIntVector() && rhst.Main == Int {
return true
}
if lhst.Main == Int && rhst.IsIntVector() {
return true
}
return false
}
if lhst.Equal(&rhst) {
return true
}