From d8630f940d7a4bcfda10a7cad7830062c68246e3 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Tue, 1 Aug 2023 12:32:13 +0900 Subject: [PATCH] internal/shader: bug fix: forbide comparing non-scalar values Closes #2718 --- internal/shader/syntax_test.go | 102 +++++++++++++++++++++++++++++++++ internal/shaderir/check.go | 6 +- 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 0128b07ca..ae8fe733c 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -3419,3 +3419,105 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +// Issue #2718 +func TestSyntaxCompare(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "_ = false == true", err: false}, + {stmt: "_ = int(0) == int(1)", err: false}, + {stmt: "_ = float(0) == float(1)", err: false}, + {stmt: "_ = vec2(0) == vec2(1)", err: false}, + {stmt: "_ = vec3(0) == vec3(1)", err: false}, + {stmt: "_ = vec4(0) == vec4(1)", err: false}, + {stmt: "_ = ivec2(0) == ivec2(1)", err: false}, + {stmt: "_ = ivec3(0) == ivec3(1)", err: false}, + {stmt: "_ = ivec4(0) == ivec4(1)", err: false}, + {stmt: "_ = mat2(0) == mat2(1)", err: true}, + {stmt: "_ = mat3(0) == mat3(1)", err: true}, + {stmt: "_ = mat4(0) == mat4(1)", err: true}, + + {stmt: "_ = false != true", err: false}, + {stmt: "_ = int(0) != int(1)", err: false}, + {stmt: "_ = float(0) != float(1)", err: false}, + {stmt: "_ = vec2(0) != vec2(1)", err: false}, + {stmt: "_ = vec3(0) != vec3(1)", err: false}, + {stmt: "_ = vec4(0) != vec4(1)", err: false}, + {stmt: "_ = ivec2(0) != ivec2(1)", err: false}, + {stmt: "_ = ivec3(0) != ivec3(1)", err: false}, + {stmt: "_ = ivec4(0) != ivec4(1)", err: false}, + {stmt: "_ = mat2(0) != mat2(1)", err: true}, + {stmt: "_ = mat3(0) != mat3(1)", err: true}, + {stmt: "_ = mat4(0) != mat4(1)", err: true}, + + {stmt: "_ = false < true", err: true}, + {stmt: "_ = int(0) < int(1)", err: false}, + {stmt: "_ = float(0) < float(1)", err: false}, + {stmt: "_ = vec2(0) < vec2(1)", err: true}, + {stmt: "_ = vec3(0) < vec3(1)", err: true}, + {stmt: "_ = vec4(0) < vec4(1)", err: true}, + {stmt: "_ = ivec2(0) < ivec2(1)", err: true}, + {stmt: "_ = ivec3(0) < ivec3(1)", err: true}, + {stmt: "_ = ivec4(0) < ivec4(1)", err: true}, + {stmt: "_ = mat2(0) < mat2(1)", err: true}, + {stmt: "_ = mat3(0) < mat3(1)", err: true}, + {stmt: "_ = mat4(0) < mat4(1)", err: true}, + + {stmt: "_ = false <= true", err: true}, + {stmt: "_ = int(0) <= int(1)", err: false}, + {stmt: "_ = float(0) <= float(1)", err: false}, + {stmt: "_ = vec2(0) <= vec2(1)", err: true}, + {stmt: "_ = vec3(0) <= vec3(1)", err: true}, + {stmt: "_ = vec4(0) <= vec4(1)", err: true}, + {stmt: "_ = ivec2(0) <= ivec2(1)", err: true}, + {stmt: "_ = ivec3(0) <= ivec3(1)", err: true}, + {stmt: "_ = ivec4(0) <= ivec4(1)", err: true}, + {stmt: "_ = mat2(0) <= mat2(1)", err: true}, + {stmt: "_ = mat3(0) <= mat3(1)", err: true}, + {stmt: "_ = mat4(0) <= mat4(1)", err: true}, + + {stmt: "_ = false > true", err: true}, + {stmt: "_ = int(0) > int(1)", err: false}, + {stmt: "_ = float(0) > float(1)", err: false}, + {stmt: "_ = vec2(0) > vec2(1)", err: true}, + {stmt: "_ = vec3(0) > vec3(1)", err: true}, + {stmt: "_ = vec4(0) > vec4(1)", err: true}, + {stmt: "_ = ivec2(0) > ivec2(1)", err: true}, + {stmt: "_ = ivec3(0) > ivec3(1)", err: true}, + {stmt: "_ = ivec4(0) > ivec4(1)", err: true}, + {stmt: "_ = mat2(0) > mat2(1)", err: true}, + {stmt: "_ = mat3(0) > mat3(1)", err: true}, + {stmt: "_ = mat4(0) > mat4(1)", err: true}, + + {stmt: "_ = false >= true", err: true}, + {stmt: "_ = int(0) >= int(1)", err: false}, + {stmt: "_ = float(0) >= float(1)", err: false}, + {stmt: "_ = vec2(0) >= vec2(1)", err: true}, + {stmt: "_ = vec3(0) >= vec3(1)", err: true}, + {stmt: "_ = vec4(0) >= vec4(1)", err: true}, + {stmt: "_ = ivec2(0) >= ivec2(1)", err: true}, + {stmt: "_ = ivec3(0) >= ivec3(1)", err: true}, + {stmt: "_ = ivec4(0) >= ivec4(1)", err: true}, + {stmt: "_ = mat2(0) >= mat2(1)", err: true}, + {stmt: "_ = mat3(0) >= mat3(1)", err: true}, + {stmt: "_ = mat4(0) >= mat4(1)", err: true}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +} diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index dca67f7b9..949b90b65 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -100,8 +100,12 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { return (lhst.IsFloatVector() || lhst.IsIntVector()) && (rhst.IsFloatVector() || lhst.IsIntVector()) && lhst.Equal(&rhst) } + if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp { + return (lhst.Main == Int && rhst.Main == Int) || (lhst.Main == Float && rhst.Main == Float) + } + // Comparing matrices are forbidden (#2187). - if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp || op == EqualOp || op == NotEqualOp { + if op == EqualOp || op == NotEqualOp { if lhst.IsMatrix() || rhst.IsMatrix() { return false }