internal/shader: add type checks to vec2/vec3/vec4

Updates #2184
This commit is contained in:
Hajime Hoshi 2022-07-06 01:25:40 +09:00
parent a866fe7391
commit f89277fd85
3 changed files with 225 additions and 0 deletions

View File

@ -379,34 +379,58 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
var t shaderir.Type var t shaderir.Type
switch callee.BuiltinFunc { switch callee.BuiltinFunc {
case shaderir.BoolF: case shaderir.BoolF:
// TODO: Check arg types.
t = shaderir.Type{Main: shaderir.Bool} t = shaderir.Type{Main: shaderir.Bool}
case shaderir.IntF: case shaderir.IntF:
// TODO: Check arg types.
t = shaderir.Type{Main: shaderir.Int} t = shaderir.Type{Main: shaderir.Int}
case shaderir.FloatF: case shaderir.FloatF:
// TODO: Check arg types.
t = shaderir.Type{Main: shaderir.Float} t = shaderir.Type{Main: shaderir.Float}
case shaderir.Vec2F: case shaderir.Vec2F:
if err := checkArgsForVec2BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Vec2} t = shaderir.Type{Main: shaderir.Vec2}
case shaderir.Vec3F: case shaderir.Vec3F:
if err := checkArgsForVec3BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Vec3} t = shaderir.Type{Main: shaderir.Vec3}
case shaderir.Vec4F: case shaderir.Vec4F:
if err := checkArgsForVec4BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Vec4} t = shaderir.Type{Main: shaderir.Vec4}
case shaderir.Mat2F: case shaderir.Mat2F:
// TODO: Check arg types.
t = shaderir.Type{Main: shaderir.Mat2} t = shaderir.Type{Main: shaderir.Mat2}
case shaderir.Mat3F: case shaderir.Mat3F:
// TODO: Check arg types.
t = shaderir.Type{Main: shaderir.Mat3} t = shaderir.Type{Main: shaderir.Mat3}
case shaderir.Mat4F: case shaderir.Mat4F:
// TODO: Check arg types.
t = shaderir.Type{Main: shaderir.Mat4} t = shaderir.Type{Main: shaderir.Mat4}
case shaderir.Step: case shaderir.Step:
// TODO: Check arg types.
t = argts[1] t = argts[1]
case shaderir.Smoothstep: case shaderir.Smoothstep:
// TODO: Check arg types.
t = argts[2] t = argts[2]
case shaderir.Length, shaderir.Distance, shaderir.Dot: case shaderir.Length, shaderir.Distance, shaderir.Dot:
// TODO: Check arg types.
t = shaderir.Type{Main: shaderir.Float} t = shaderir.Type{Main: shaderir.Float}
case shaderir.Cross: case shaderir.Cross:
// TODO: Check arg types.
t = shaderir.Type{Main: shaderir.Vec3} t = shaderir.Type{Main: shaderir.Vec3}
case shaderir.Texture2DF: case shaderir.Texture2DF:
// TODO: Check arg types.
t = shaderir.Type{Main: shaderir.Vec4} t = shaderir.Type{Main: shaderir.Vec4}
default: default:
// TODO: Check arg types?
// If the argument is a non-typed constant value, treat is as a float value (#1874). // If the argument is a non-typed constant value, treat is as a float value (#1874).
if args[0].Type == shaderir.NumberExpr && args[0].ConstType == shaderir.ConstTypeNone { if args[0].Type == shaderir.NumberExpr && args[0].ConstType == shaderir.ConstTypeNone {
args[0].ConstType = shaderir.ConstTypeFloat args[0].ConstType = shaderir.ConstTypeFloat

View File

@ -1351,3 +1351,83 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
t.Errorf("error must be non-nil but was nil") t.Errorf("error must be non-nil but was nil")
} }
} }
// Issue #2184
func TestSyntaxBuiltinFuncType(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "a := vec2(1); _ = a", err: false},
{stmt: "a := vec2(1.0); _ = a", err: false},
{stmt: "i := 1; a := vec2(i); _ = a", err: false},
{stmt: "i := 1.0; a := vec2(i); _ = a", err: false},
{stmt: "a := vec2(vec2(1)); _ = a", err: false},
{stmt: "a := vec2(vec3(1)); _ = a", err: true},
{stmt: "a := vec2(1, 1); _ = a", err: false},
{stmt: "a := vec2(1.0, 1.0); _ = a", err: false},
{stmt: "i := 1; a := vec2(i, i); _ = a", err: false},
{stmt: "i := 1.0; a := vec2(i, i); _ = a", err: false},
{stmt: "a := vec2(vec2(1), 1); _ = a", err: true},
{stmt: "a := vec2(1, vec2(1)); _ = a", err: true},
{stmt: "a := vec2(vec2(1), vec2(1)); _ = a", err: true},
{stmt: "a := vec2(1, 1, 1); _ = a", err: true},
{stmt: "a := vec3(1); _ = a", err: false},
{stmt: "a := vec3(1.0); _ = a", err: false},
{stmt: "i := 1; a := vec3(i); _ = a", err: false},
{stmt: "i := 1.0; a := vec3(i); _ = a", err: false},
{stmt: "a := vec3(vec3(1)); _ = a", err: false},
{stmt: "a := vec3(vec2(1)); _ = a", err: true},
{stmt: "a := vec3(vec4(1)); _ = a", err: true},
{stmt: "a := vec3(1, 1, 1); _ = a", err: false},
{stmt: "a := vec3(1.0, 1.0, 1.0); _ = a", err: false},
{stmt: "i := 1; a := vec3(i, i, i); _ = a", err: false},
{stmt: "i := 1.0; a := vec3(i, i, i); _ = a", err: false},
{stmt: "a := vec3(vec2(1), 1); _ = a", err: false},
{stmt: "a := vec3(1, vec2(1)); _ = a", err: false},
{stmt: "a := vec3(vec3(1), 1); _ = a", err: true},
{stmt: "a := vec3(1, vec3(1)); _ = a", err: true},
{stmt: "a := vec3(vec3(1), vec3(1), vec3(1)); _ = a", err: true},
{stmt: "a := vec3(1, 1, 1, 1); _ = a", err: true},
{stmt: "a := vec4(1); _ = a", err: false},
{stmt: "a := vec4(1.0); _ = a", err: false},
{stmt: "i := 1; a := vec4(i); _ = a", err: false},
{stmt: "i := 1.0; a := vec4(i); _ = a", err: false},
{stmt: "a := vec4(vec4(1)); _ = a", err: false},
{stmt: "a := vec4(vec2(1)); _ = a", err: true},
{stmt: "a := vec4(vec3(1)); _ = a", err: true},
{stmt: "a := vec4(1, 1, 1, 1); _ = a", err: false},
{stmt: "a := vec4(1.0, 1.0, 1.0, 1.0); _ = a", err: false},
{stmt: "i := 1; a := vec4(i, i, i, i); _ = a", err: false},
{stmt: "i := 1.0; a := vec4(i, i, i, i); _ = a", err: false},
{stmt: "a := vec4(vec2(1), 1, 1); _ = a", err: false},
{stmt: "a := vec4(1, vec2(1), 1); _ = a", err: false},
{stmt: "a := vec4(1, 1, vec2(1)); _ = a", err: false},
{stmt: "a := vec4(vec2(1), vec2(1)); _ = a", err: false},
{stmt: "a := vec4(vec3(1), 1); _ = a", err: false},
{stmt: "a := vec4(1, vec3(1)); _ = a", err: false},
{stmt: "a := vec4(vec4(1), 1); _ = a", err: true},
{stmt: "a := vec4(1, vec4(1)); _ = a", err: true},
{stmt: "a := vec4(vec4(1), vec4(1), vec4(1), vec4(1)); _ = a", err: true},
{stmt: "a := vec4(1, 1, 1, 1, 1); _ = a", err: true},
}
for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, c.stmt)))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
}
}
}

View File

@ -98,3 +98,124 @@ func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, b
return shaderir.Type{}, false return shaderir.Type{}, false
} }
} }
func canBeFloatImplicitly(expr shaderir.Expr, t shaderir.Type) bool {
// TODO: For integers, should only constants be allowed?
if t.Main == shaderir.Int {
return true
}
if t.Main == shaderir.Float {
return true
}
if expr.Const != nil {
if expr.Const.Kind() == gconstant.Int {
return true
}
if expr.Const.Kind() == gconstant.Float {
return true
}
}
return false
}
func checkArgsForVec2BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error {
if len(args) != len(argts) {
return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts))
}
switch len(args) {
case 1:
if canBeFloatImplicitly(args[0], argts[0]) {
return nil
}
if argts[0].Main == shaderir.Vec2 {
return nil
}
return fmt.Errorf("invalid arguments for vec2: (%s)", argts[0].String())
case 2:
if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) {
return nil
}
return fmt.Errorf("invalid arguments for vec2: (%s, %s)", argts[0].String(), argts[1].String())
default:
return fmt.Errorf("too many arguments for vec2")
}
}
func checkArgsForVec3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error {
if len(args) != len(argts) {
return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts))
}
switch len(args) {
case 1:
if canBeFloatImplicitly(args[0], argts[0]) {
return nil
}
if argts[0].Main == shaderir.Vec3 {
return nil
}
return fmt.Errorf("invalid arguments for vec3: (%s)", argts[0].String())
case 2:
if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec2 {
return nil
}
if argts[0].Main == shaderir.Vec2 && canBeFloatImplicitly(args[1], argts[1]) {
return nil
}
return fmt.Errorf("invalid arguments for vec3: (%s, %s)", argts[0].String(), argts[1].String())
case 3:
if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) {
return nil
}
return fmt.Errorf("invalid arguments for vec3: (%s, %s, %s)", argts[0].String(), argts[1].String(), argts[2].String())
default:
return fmt.Errorf("too many arguments for vec3")
}
}
func checkArgsForVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error {
if len(args) != len(argts) {
return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts))
}
switch len(args) {
case 1:
if canBeFloatImplicitly(args[0], argts[0]) {
return nil
}
if argts[0].Main == shaderir.Vec4 {
return nil
}
return fmt.Errorf("invalid arguments for vec4: (%s)", argts[0].String())
case 2:
if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec3 {
return nil
}
if argts[0].Main == shaderir.Vec2 && argts[1].Main == shaderir.Vec2 {
return nil
}
if argts[0].Main == shaderir.Vec3 && canBeFloatImplicitly(args[1], argts[1]) {
return nil
}
return fmt.Errorf("invalid arguments for vec4: (%s, %s)", argts[0].String(), argts[1].String())
case 3:
if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && argts[2].Main == shaderir.Vec2 {
return nil
}
if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec2 && canBeFloatImplicitly(args[2], argts[2]) {
return nil
}
if argts[0].Main == shaderir.Vec2 && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) {
return nil
}
return fmt.Errorf("invalid arguments for vec4: (%s, %s, %s)", argts[0].String(), argts[1].String(), argts[2].String())
case 4:
if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) && canBeFloatImplicitly(args[3], argts[3]) {
return nil
}
return fmt.Errorf("invalid arguments for vec4: (%s, %s, %s, %s)", argts[0].String(), argts[1].String(), argts[2].String(), argts[3].String())
default:
return fmt.Errorf("too many arguments for vec4")
}
}