internal/shader: bug fix: forbide comparing non-scalar values

Closes #2718
This commit is contained in:
Hajime Hoshi 2023-08-01 12:32:13 +09:00
parent 63df6168d9
commit d8630f940d
2 changed files with 107 additions and 1 deletions

View File

@ -3419,3 +3419,105 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
}
}
// Issue #2718
func TestSyntaxCompare(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "_ = false == true", err: false},
{stmt: "_ = int(0) == int(1)", err: false},
{stmt: "_ = float(0) == float(1)", err: false},
{stmt: "_ = vec2(0) == vec2(1)", err: false},
{stmt: "_ = vec3(0) == vec3(1)", err: false},
{stmt: "_ = vec4(0) == vec4(1)", err: false},
{stmt: "_ = ivec2(0) == ivec2(1)", err: false},
{stmt: "_ = ivec3(0) == ivec3(1)", err: false},
{stmt: "_ = ivec4(0) == ivec4(1)", err: false},
{stmt: "_ = mat2(0) == mat2(1)", err: true},
{stmt: "_ = mat3(0) == mat3(1)", err: true},
{stmt: "_ = mat4(0) == mat4(1)", err: true},
{stmt: "_ = false != true", err: false},
{stmt: "_ = int(0) != int(1)", err: false},
{stmt: "_ = float(0) != float(1)", err: false},
{stmt: "_ = vec2(0) != vec2(1)", err: false},
{stmt: "_ = vec3(0) != vec3(1)", err: false},
{stmt: "_ = vec4(0) != vec4(1)", err: false},
{stmt: "_ = ivec2(0) != ivec2(1)", err: false},
{stmt: "_ = ivec3(0) != ivec3(1)", err: false},
{stmt: "_ = ivec4(0) != ivec4(1)", err: false},
{stmt: "_ = mat2(0) != mat2(1)", err: true},
{stmt: "_ = mat3(0) != mat3(1)", err: true},
{stmt: "_ = mat4(0) != mat4(1)", err: true},
{stmt: "_ = false < true", err: true},
{stmt: "_ = int(0) < int(1)", err: false},
{stmt: "_ = float(0) < float(1)", err: false},
{stmt: "_ = vec2(0) < vec2(1)", err: true},
{stmt: "_ = vec3(0) < vec3(1)", err: true},
{stmt: "_ = vec4(0) < vec4(1)", err: true},
{stmt: "_ = ivec2(0) < ivec2(1)", err: true},
{stmt: "_ = ivec3(0) < ivec3(1)", err: true},
{stmt: "_ = ivec4(0) < ivec4(1)", err: true},
{stmt: "_ = mat2(0) < mat2(1)", err: true},
{stmt: "_ = mat3(0) < mat3(1)", err: true},
{stmt: "_ = mat4(0) < mat4(1)", err: true},
{stmt: "_ = false <= true", err: true},
{stmt: "_ = int(0) <= int(1)", err: false},
{stmt: "_ = float(0) <= float(1)", err: false},
{stmt: "_ = vec2(0) <= vec2(1)", err: true},
{stmt: "_ = vec3(0) <= vec3(1)", err: true},
{stmt: "_ = vec4(0) <= vec4(1)", err: true},
{stmt: "_ = ivec2(0) <= ivec2(1)", err: true},
{stmt: "_ = ivec3(0) <= ivec3(1)", err: true},
{stmt: "_ = ivec4(0) <= ivec4(1)", err: true},
{stmt: "_ = mat2(0) <= mat2(1)", err: true},
{stmt: "_ = mat3(0) <= mat3(1)", err: true},
{stmt: "_ = mat4(0) <= mat4(1)", err: true},
{stmt: "_ = false > true", err: true},
{stmt: "_ = int(0) > int(1)", err: false},
{stmt: "_ = float(0) > float(1)", err: false},
{stmt: "_ = vec2(0) > vec2(1)", err: true},
{stmt: "_ = vec3(0) > vec3(1)", err: true},
{stmt: "_ = vec4(0) > vec4(1)", err: true},
{stmt: "_ = ivec2(0) > ivec2(1)", err: true},
{stmt: "_ = ivec3(0) > ivec3(1)", err: true},
{stmt: "_ = ivec4(0) > ivec4(1)", err: true},
{stmt: "_ = mat2(0) > mat2(1)", err: true},
{stmt: "_ = mat3(0) > mat3(1)", err: true},
{stmt: "_ = mat4(0) > mat4(1)", err: true},
{stmt: "_ = false >= true", err: true},
{stmt: "_ = int(0) >= int(1)", err: false},
{stmt: "_ = float(0) >= float(1)", err: false},
{stmt: "_ = vec2(0) >= vec2(1)", err: true},
{stmt: "_ = vec3(0) >= vec3(1)", err: true},
{stmt: "_ = vec4(0) >= vec4(1)", err: true},
{stmt: "_ = ivec2(0) >= ivec2(1)", err: true},
{stmt: "_ = ivec3(0) >= ivec3(1)", err: true},
{stmt: "_ = ivec4(0) >= ivec4(1)", err: true},
{stmt: "_ = mat2(0) >= mat2(1)", err: true},
{stmt: "_ = mat3(0) >= mat3(1)", err: true},
{stmt: "_ = mat4(0) >= mat4(1)", err: true},
}
for _, c := range cases {
stmt := c.stmt
src := fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", stmt, err)
}
}
}

View File

@ -100,8 +100,12 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
return (lhst.IsFloatVector() || lhst.IsIntVector()) && (rhst.IsFloatVector() || lhst.IsIntVector()) && lhst.Equal(&rhst)
}
if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp {
return (lhst.Main == Int && rhst.Main == Int) || (lhst.Main == Float && rhst.Main == Float)
}
// Comparing matrices are forbidden (#2187).
if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp || op == EqualOp || op == NotEqualOp {
if op == EqualOp || op == NotEqualOp {
if lhst.IsMatrix() || rhst.IsMatrix() {
return false
}