diff --git a/internal/shader/expr.go b/internal/shader/expr.go index cf1dea277..8aa4fd21c 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -172,7 +172,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable var t shaderir.Type switch { - case op == shaderir.LessThanOp || op == shaderir.LessThanEqualOp || op == shaderir.GreaterThanOp || op == shaderir.GreaterThanEqualOp || op == shaderir.EqualOp || op == shaderir.NotEqualOp || op == shaderir.AndAnd || op == shaderir.OrOr: + case op == shaderir.LessThanOp || op == shaderir.LessThanEqualOp || op == shaderir.GreaterThanOp || op == shaderir.GreaterThanEqualOp || op == shaderir.EqualOp || op == shaderir.NotEqualOp || op == shaderir.VectorEqualOp || op == shaderir.VectorNotEqualOp || op == shaderir.AndAnd || op == shaderir.OrOr: // TODO: Check types of the operands. t = shaderir.Type{Main: shaderir.Bool} case lhs[0].Type == shaderir.NumberExpr && rhs[0].Type != shaderir.NumberExpr: diff --git a/internal/shaderir/hlsl/hlsl.go b/internal/shaderir/hlsl/hlsl.go index b2bfde782..a8c31d38b 100644 --- a/internal/shaderir/hlsl/hlsl.go +++ b/internal/shaderir/hlsl/hlsl.go @@ -390,7 +390,12 @@ func (c *compileContext) block(p *shaderir.Program, topBlock, block *shaderir.Bl } return fmt.Sprintf("%s(%s)", op, expr(&e.Exprs[0])) case shaderir.Binary: - if e.Op == shaderir.MatrixMul { + switch e.Op { + case shaderir.VectorEqualOp: + return fmt.Sprintf("all(%s == %s)", expr(&e.Exprs[0]), expr(&e.Exprs[1])) + case shaderir.VectorNotEqualOp: + return fmt.Sprintf("!all(%s == %s)", expr(&e.Exprs[0]), expr(&e.Exprs[1])) + case shaderir.MatrixMul: // If either is a matrix, use the mul function. // Swap the order of the lhs and the rhs since matrices are row-major in HLSL. return fmt.Sprintf("mul(%s, %s)", expr(&e.Exprs[1]), expr(&e.Exprs[0])) diff --git a/internal/shaderir/msl/msl.go b/internal/shaderir/msl/msl.go index c4afdaeb5..22acd3789 100644 --- a/internal/shaderir/msl/msl.go +++ b/internal/shaderir/msl/msl.go @@ -377,6 +377,12 @@ func (c *compileContext) block(p *shaderir.Program, topBlock, block *shaderir.Bl } return fmt.Sprintf("%s(%s)", op, expr(&e.Exprs[0])) case shaderir.Binary: + switch e.Op { + case shaderir.VectorEqualOp: + return fmt.Sprintf("all((%s) == (%s))", expr(&e.Exprs[0]), expr(&e.Exprs[1])) + case shaderir.VectorNotEqualOp: + return fmt.Sprintf("!all((%s) == (%s))", expr(&e.Exprs[0]), expr(&e.Exprs[1])) + } return fmt.Sprintf("(%s) %s (%s)", expr(&e.Exprs[0]), opString(e.Op), expr(&e.Exprs[1])) case shaderir.Selection: return fmt.Sprintf("(%s) ? (%s) : (%s)", expr(&e.Exprs[0]), expr(&e.Exprs[1]), expr(&e.Exprs[2])) diff --git a/internal/shaderir/program.go b/internal/shaderir/program.go index afb50af3d..0a75ccebe 100644 --- a/internal/shaderir/program.go +++ b/internal/shaderir/program.go @@ -149,6 +149,8 @@ const ( GreaterThanEqualOp EqualOp NotEqualOp + VectorEqualOp + VectorNotEqualOp And Xor Or @@ -186,8 +188,14 @@ func OpFromToken(t token.Token, lhs, rhs Type) (Op, bool) { case token.GEQ: return GreaterThanEqualOp, true case token.EQL: + if lhs.IsVector() || rhs.IsVector() { + return VectorEqualOp, true + } return EqualOp, true case token.NEQ: + if lhs.IsVector() || rhs.IsVector() { + return VectorNotEqualOp, true + } return NotEqualOp, true case token.AND: return And, true diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index 1f4b19467..be699409f 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -104,6 +104,14 @@ func (t *Type) FloatNum() int { } } +func (t *Type) IsVector() bool { + switch t.Main { + case Vec2, Vec3, Vec4: + return true + } + return false +} + func (t *Type) IsMatrix() bool { switch t.Main { case Mat2, Mat3, Mat4: diff --git a/shader_test.go b/shader_test.go index c9e12124f..5697c6d92 100644 --- a/shader_test.go +++ b/shader_test.go @@ -1095,3 +1095,37 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { testPixels("DrawTrianglesShader", dst) }) } + +// Issue #2186 +func TestShaderVectorEqual(t *testing.T) { + const w, h = 16, 16 + + dst := ebiten.NewImage(w, h) + s, err := ebiten.NewShader([]byte(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + a := vec3(1) + b := vec3(1) + if a == b { + return vec4(1, 0, 0, 1) + } else { + return vec4(0, 1, 0, 1) + } +} +`)) + if err != nil { + t.Fatal(err) + } + + dst.DrawRectShader(w, h, s, nil) + + for j := 0; j < h; j++ { + for i := 0; i < w; i++ { + got := dst.At(i, j).(color.RGBA) + want := color.RGBA{0xff, 0, 0x00, 0xff} + if !sameColors(got, want, 2) { + t.Errorf("dst.At(%d, %d): got: %v, want: %v", i, j, got, want) + } + } + } +}