internal/shader: use all functions for vector comparisons

Updates #2186
This commit is contained in:
Hajime Hoshi 2022-07-09 02:26:33 +09:00
parent 3d5031571d
commit c01821ca5c
6 changed files with 63 additions and 2 deletions

View File

@ -172,7 +172,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
var t shaderir.Type var t shaderir.Type
switch { switch {
case op == shaderir.LessThanOp || op == shaderir.LessThanEqualOp || op == shaderir.GreaterThanOp || op == shaderir.GreaterThanEqualOp || op == shaderir.EqualOp || op == shaderir.NotEqualOp || op == shaderir.AndAnd || op == shaderir.OrOr: case op == shaderir.LessThanOp || op == shaderir.LessThanEqualOp || op == shaderir.GreaterThanOp || op == shaderir.GreaterThanEqualOp || op == shaderir.EqualOp || op == shaderir.NotEqualOp || op == shaderir.VectorEqualOp || op == shaderir.VectorNotEqualOp || op == shaderir.AndAnd || op == shaderir.OrOr:
// TODO: Check types of the operands. // TODO: Check types of the operands.
t = shaderir.Type{Main: shaderir.Bool} t = shaderir.Type{Main: shaderir.Bool}
case lhs[0].Type == shaderir.NumberExpr && rhs[0].Type != shaderir.NumberExpr: case lhs[0].Type == shaderir.NumberExpr && rhs[0].Type != shaderir.NumberExpr:

View File

@ -390,7 +390,12 @@ func (c *compileContext) block(p *shaderir.Program, topBlock, block *shaderir.Bl
} }
return fmt.Sprintf("%s(%s)", op, expr(&e.Exprs[0])) return fmt.Sprintf("%s(%s)", op, expr(&e.Exprs[0]))
case shaderir.Binary: case shaderir.Binary:
if e.Op == shaderir.MatrixMul { switch e.Op {
case shaderir.VectorEqualOp:
return fmt.Sprintf("all(%s == %s)", expr(&e.Exprs[0]), expr(&e.Exprs[1]))
case shaderir.VectorNotEqualOp:
return fmt.Sprintf("!all(%s == %s)", expr(&e.Exprs[0]), expr(&e.Exprs[1]))
case shaderir.MatrixMul:
// If either is a matrix, use the mul function. // If either is a matrix, use the mul function.
// Swap the order of the lhs and the rhs since matrices are row-major in HLSL. // Swap the order of the lhs and the rhs since matrices are row-major in HLSL.
return fmt.Sprintf("mul(%s, %s)", expr(&e.Exprs[1]), expr(&e.Exprs[0])) return fmt.Sprintf("mul(%s, %s)", expr(&e.Exprs[1]), expr(&e.Exprs[0]))

View File

@ -377,6 +377,12 @@ func (c *compileContext) block(p *shaderir.Program, topBlock, block *shaderir.Bl
} }
return fmt.Sprintf("%s(%s)", op, expr(&e.Exprs[0])) return fmt.Sprintf("%s(%s)", op, expr(&e.Exprs[0]))
case shaderir.Binary: case shaderir.Binary:
switch e.Op {
case shaderir.VectorEqualOp:
return fmt.Sprintf("all((%s) == (%s))", expr(&e.Exprs[0]), expr(&e.Exprs[1]))
case shaderir.VectorNotEqualOp:
return fmt.Sprintf("!all((%s) == (%s))", expr(&e.Exprs[0]), expr(&e.Exprs[1]))
}
return fmt.Sprintf("(%s) %s (%s)", expr(&e.Exprs[0]), opString(e.Op), expr(&e.Exprs[1])) return fmt.Sprintf("(%s) %s (%s)", expr(&e.Exprs[0]), opString(e.Op), expr(&e.Exprs[1]))
case shaderir.Selection: case shaderir.Selection:
return fmt.Sprintf("(%s) ? (%s) : (%s)", expr(&e.Exprs[0]), expr(&e.Exprs[1]), expr(&e.Exprs[2])) return fmt.Sprintf("(%s) ? (%s) : (%s)", expr(&e.Exprs[0]), expr(&e.Exprs[1]), expr(&e.Exprs[2]))

View File

@ -149,6 +149,8 @@ const (
GreaterThanEqualOp GreaterThanEqualOp
EqualOp EqualOp
NotEqualOp NotEqualOp
VectorEqualOp
VectorNotEqualOp
And And
Xor Xor
Or Or
@ -186,8 +188,14 @@ func OpFromToken(t token.Token, lhs, rhs Type) (Op, bool) {
case token.GEQ: case token.GEQ:
return GreaterThanEqualOp, true return GreaterThanEqualOp, true
case token.EQL: case token.EQL:
if lhs.IsVector() || rhs.IsVector() {
return VectorEqualOp, true
}
return EqualOp, true return EqualOp, true
case token.NEQ: case token.NEQ:
if lhs.IsVector() || rhs.IsVector() {
return VectorNotEqualOp, true
}
return NotEqualOp, true return NotEqualOp, true
case token.AND: case token.AND:
return And, true return And, true

View File

@ -104,6 +104,14 @@ func (t *Type) FloatNum() int {
} }
} }
func (t *Type) IsVector() bool {
switch t.Main {
case Vec2, Vec3, Vec4:
return true
}
return false
}
func (t *Type) IsMatrix() bool { func (t *Type) IsMatrix() bool {
switch t.Main { switch t.Main {
case Mat2, Mat3, Mat4: case Mat2, Mat3, Mat4:

View File

@ -1095,3 +1095,37 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
testPixels("DrawTrianglesShader", dst) testPixels("DrawTrianglesShader", dst)
}) })
} }
// Issue #2186
func TestShaderVectorEqual(t *testing.T) {
const w, h = 16, 16
dst := ebiten.NewImage(w, h)
s, err := ebiten.NewShader([]byte(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
a := vec3(1)
b := vec3(1)
if a == b {
return vec4(1, 0, 0, 1)
} else {
return vec4(0, 1, 0, 1)
}
}
`))
if err != nil {
t.Fatal(err)
}
dst.DrawRectShader(w, h, s, nil)
for j := 0; j < h; j++ {
for i := 0; i < w; i++ {
got := dst.At(i, j).(color.RGBA)
want := color.RGBA{0xff, 0, 0x00, 0xff}
if !sameColors(got, want, 2) {
t.Errorf("dst.At(%d, %d): got: %v, want: %v", i, j, got, want)
}
}
}
}