internal/shader: add type checks for the builtin function transpose

Updates #2184
This commit is contained in:
Hajime Hoshi 2022-08-19 01:57:13 +09:00
parent 7a94cbbd62
commit 63eee0600e
2 changed files with 48 additions and 3 deletions

View File

@ -572,9 +572,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
args[0].ConstType = shaderir.ConstTypeFloat
argts[0] = shaderir.Type{Main: shaderir.Float}
}
if argts[0].Main != shaderir.Float && argts[0].Main != shaderir.Vec2 && argts[0].Main != shaderir.Vec3 && argts[0].Main != shaderir.Vec4 {
cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as float, vec2, vec3, or vec4 value in argument to %s", argts[0].String(), callee.BuiltinFunc))
return nil, nil, nil, false
switch callee.BuiltinFunc {
case shaderir.Transpose:
if argts[0].Main != shaderir.Mat2 && argts[0].Main != shaderir.Mat3 && argts[0].Main != shaderir.Mat4 {
cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as mat2, mat3, or mat4 value in argument to %s", argts[0].String(), callee.BuiltinFunc))
return nil, nil, nil, false
}
default:
if argts[0].Main != shaderir.Float && argts[0].Main != shaderir.Vec2 && argts[0].Main != shaderir.Vec3 && argts[0].Main != shaderir.Vec4 {
cs.addError(e.Pos(), fmt.Sprintf("cannot use %s as float, vec2, vec3, or vec4 value in argument to %s", argts[0].String(), callee.BuiltinFunc))
return nil, nil, nil, false
}
}
if callee.BuiltinFunc == shaderir.Length {
t = shaderir.Type{Main: shaderir.Float}

View File

@ -2152,3 +2152,40 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
}
}
// Issue #2184
func TestSyntaxBuiltinFuncTransposeType(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "a := transpose(); _ = a", err: true},
{stmt: "a := transpose(false); _ = a", err: true},
{stmt: "a := transpose(1); _ = a", err: true},
{stmt: "a := transpose(1.0); _ = a", err: true},
{stmt: "a := transpose(int(1)); _ = a", err: true},
{stmt: "a := transpose(vec2(1)); _ = a", err: true},
{stmt: "a := transpose(vec3(1)); _ = a", err: true},
{stmt: "a := transpose(vec4(1)); _ = a", err: true},
{stmt: "a := transpose(mat2(1)); _ = a", err: false},
{stmt: "a := transpose(mat3(1)); _ = a", err: false},
{stmt: "a := transpose(mat4(1)); _ = a", err: false},
{stmt: "a := transpose(1, 1); _ = a", err: true},
}
for _, c := range cases {
stmt := c.stmt
src := fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", stmt, err)
}
}
}