internal/shader: bug fix: wrong type checkings for operator *

Updates #1971
This commit is contained in:
Hajime Hoshi 2022-01-21 00:50:21 +09:00
parent 8321cecfdd
commit 8d2bf6525c
2 changed files with 70 additions and 12 deletions

View File

@ -147,6 +147,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
var t shaderir.Type
switch {
case op == shaderir.LessThanOp || op == shaderir.LessThanEqualOp || op == shaderir.GreaterThanOp || op == shaderir.GreaterThanEqualOp || op == shaderir.EqualOp || op == shaderir.NotEqualOp || op == shaderir.AndAnd || op == shaderir.OrOr:
// TODO: Check types of the operands.
t = shaderir.Type{Main: shaderir.Bool}
case lhs[0].Type == shaderir.NumberExpr && rhs[0].Type != shaderir.NumberExpr:
if rhst.Main == shaderir.Int {
@ -168,34 +169,30 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
t = lhst
case lhst.Equal(&rhst):
t = lhst
case lhst.Main == shaderir.Float || lhst.Main == shaderir.Int:
case lhst.Main == shaderir.Float:
switch rhst.Main {
case shaderir.Int:
t = lhst
case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
t = rhst
default:
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
case rhst.Main == shaderir.Float || rhst.Main == shaderir.Int:
case rhst.Main == shaderir.Float:
switch lhst.Main {
case shaderir.Int:
t = rhst
case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
t = lhst
default:
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
case lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 ||
lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2:
case op == shaderir.Mul && (lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 ||
lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2):
t = shaderir.Type{Main: shaderir.Vec2}
case lhst.Main == shaderir.Vec3 && rhst.Main == shaderir.Mat3 ||
lhst.Main == shaderir.Mat3 && rhst.Main == shaderir.Vec3:
case op == shaderir.Mul && (lhst.Main == shaderir.Vec3 && rhst.Main == shaderir.Mat3 ||
lhst.Main == shaderir.Mat3 && rhst.Main == shaderir.Vec3):
t = shaderir.Type{Main: shaderir.Vec3}
case lhst.Main == shaderir.Vec4 && rhst.Main == shaderir.Mat4 ||
lhst.Main == shaderir.Mat4 && rhst.Main == shaderir.Vec4:
case op == shaderir.Mul && (lhst.Main == shaderir.Vec4 && rhst.Main == shaderir.Mat4 ||
lhst.Main == shaderir.Mat4 && rhst.Main == shaderir.Vec4):
t = shaderir.Type{Main: shaderir.Vec4}
default:
cs.addError(e.Pos(), fmt.Sprintf("invalid expression: %s %s %s", lhst.String(), e.Op, rhst.String()))

View File

@ -15,6 +15,7 @@
package ebiten_test
import (
"fmt"
"image"
"image/color"
"testing"
@ -1538,3 +1539,63 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
t.Errorf("error must be non-nil but was nil")
}
}
// Issue #1971
func TestShaderOperatorMultiply(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "a := 1 * vec2(2); _ = a", err: false},
{stmt: "a := int(1) * vec2(2); _ = a", err: true},
{stmt: "a := 1.0 * vec2(2); _ = a", err: false},
{stmt: "a := 1 + vec2(2); _ = a", err: false},
{stmt: "a := int(1) + vec2(2); _ = a", err: true},
{stmt: "a := 1.0 + vec2(2); _ = a", err: false},
{stmt: "a := 1 * vec3(2); _ = a", err: false},
{stmt: "a := 1.0 * vec3(2); _ = a", err: false},
{stmt: "a := 1 * vec4(2); _ = a", err: false},
{stmt: "a := 1.0 * vec4(2); _ = a", err: false},
{stmt: "a := 1 * mat2(2); _ = a", err: false},
{stmt: "a := 1.0 * mat2(2); _ = a", err: false},
{stmt: "a := 1 * mat3(2); _ = a", err: false},
{stmt: "a := 1.0 * mat3(2); _ = a", err: false},
{stmt: "a := 1 * mat4(2); _ = a", err: false},
{stmt: "a := 1.0 * mat4(2); _ = a", err: false},
{stmt: "a := vec2(1) * 2; _ = a", err: false},
{stmt: "a := vec2(1) * 2.0; _ = a", err: false},
{stmt: "a := vec2(1) * int(2); _ = a", err: true},
{stmt: "a := vec2(1) * vec2(2); _ = a", err: false},
{stmt: "a := vec2(1) + vec2(2); _ = a", err: false},
{stmt: "a := vec2(1) * vec3(2); _ = a", err: true},
{stmt: "a := vec2(1) * vec4(2); _ = a", err: true},
{stmt: "a := vec2(1) * mat2(2); _ = a", err: false},
{stmt: "a := vec2(1) * mat3(2); _ = a", err: true},
{stmt: "a := vec2(1) * mat4(2); _ = a", err: true},
{stmt: "a := mat2(1) * 2; _ = a", err: false},
{stmt: "a := mat2(1) * 2.0; _ = a", err: false},
{stmt: "a := mat2(1) * int(2); _ = a", err: true},
{stmt: "a := mat2(1) + 2.0; _ = a", err: false},
{stmt: "a := mat2(1) * vec2(2); _ = a", err: false},
{stmt: "a := mat2(1) + vec2(2); _ = a", err: true},
{stmt: "a := mat2(1) * vec3(2); _ = a", err: true},
{stmt: "a := mat2(1) * vec4(2); _ = a", err: true},
{stmt: "a := mat2(1) * mat2(2); _ = a", err: false},
{stmt: "a := mat2(1) * mat3(2); _ = a", err: true},
{stmt: "a := mat2(1) * mat4(2); _ = a", err: true},
}
for _, c := range cases {
_, err := ebiten.NewShader([]byte(fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, c.stmt)))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
}
}
}