internal/shader: add type checks for mat2/mat3/mat4

Updates #2184
This commit is contained in:
Hajime Hoshi 2022-07-07 21:25:02 +09:00
parent f89277fd85
commit faa2ad5c6f
3 changed files with 226 additions and 12 deletions

View File

@ -406,13 +406,22 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
} }
t = shaderir.Type{Main: shaderir.Vec4} t = shaderir.Type{Main: shaderir.Vec4}
case shaderir.Mat2F: case shaderir.Mat2F:
// TODO: Check arg types. if err := checkArgsForMat2BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Mat2} t = shaderir.Type{Main: shaderir.Mat2}
case shaderir.Mat3F: case shaderir.Mat3F:
// TODO: Check arg types. if err := checkArgsForMat3BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Mat3} t = shaderir.Type{Main: shaderir.Mat3}
case shaderir.Mat4F: case shaderir.Mat4F:
// TODO: Check arg types. if err := checkArgsForMat4BuiltinFunc(args, argts); err != nil {
cs.addError(e.Pos(), err.Error())
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Mat4} t = shaderir.Type{Main: shaderir.Mat4}
case shaderir.Step: case shaderir.Step:
// TODO: Check arg types. // TODO: Check arg types.

View File

@ -1415,6 +1415,84 @@ func TestSyntaxBuiltinFuncType(t *testing.T) {
{stmt: "a := vec4(1, vec4(1)); _ = a", err: true}, {stmt: "a := vec4(1, vec4(1)); _ = a", err: true},
{stmt: "a := vec4(vec4(1), vec4(1), vec4(1), vec4(1)); _ = a", err: true}, {stmt: "a := vec4(vec4(1), vec4(1), vec4(1), vec4(1)); _ = a", err: true},
{stmt: "a := vec4(1, 1, 1, 1, 1); _ = a", err: true}, {stmt: "a := vec4(1, 1, 1, 1, 1); _ = a", err: true},
{stmt: "a := mat2(1); _ = a", err: false},
{stmt: "a := mat2(1.0); _ = a", err: false},
{stmt: "i := 1; a := mat2(i); _ = a", err: false},
{stmt: "i := 1.0; a := mat2(i); _ = a", err: false},
{stmt: "a := mat2(mat2(1)); _ = a", err: false},
{stmt: "a := mat2(vec2(1)); _ = a", err: true},
{stmt: "a := mat2(mat3(1)); _ = a", err: true},
{stmt: "a := mat2(mat4(1)); _ = a", err: true},
{stmt: "a := mat2(vec2(1), vec2(1)); _ = a", err: false},
{stmt: "a := mat2(1, 1); _ = a", err: true},
{stmt: "a := mat2(1, vec2(1)); _ = a", err: true},
{stmt: "a := mat2(vec2(1), vec3(1)); _ = a", err: true},
{stmt: "a := mat2(mat2(1), mat2(1)); _ = a", err: true},
{stmt: "a := mat2(1, 1, 1, 1); _ = a", err: false},
{stmt: "a := mat2(1.0, 1.0, 1.0, 1.0); _ = a", err: false},
{stmt: "i := 1; a := mat2(i, i, i, i); _ = a", err: false},
{stmt: "i := 1.0; a := mat2(i, i, i, i); _ = a", err: false},
{stmt: "a := mat2(vec2(1), vec2(1), vec2(1), vec2(1)); _ = a", err: true},
{stmt: "a := mat2(1, 1, 1, vec2(1)); _ = a", err: true},
{stmt: "a := mat2(1, 1, 1, vec3(1)); _ = a", err: true},
{stmt: "a := mat2(mat2(1), mat2(1), mat2(1), mat2(1)); _ = a", err: true},
{stmt: "a := mat2(1, 1, 1); _ = a", err: true},
{stmt: "a := mat2(1, 1, 1, 1, 1); _ = a", err: true},
{stmt: "a := mat3(1); _ = a", err: false},
{stmt: "a := mat3(1.0); _ = a", err: false},
{stmt: "i := 1; a := mat3(i); _ = a", err: false},
{stmt: "i := 1.0; a := mat3(i); _ = a", err: false},
{stmt: "a := mat3(mat3(1)); _ = a", err: false},
{stmt: "a := mat3(vec2(1)); _ = a", err: true},
{stmt: "a := mat3(mat2(1)); _ = a", err: true},
{stmt: "a := mat3(mat4(1)); _ = a", err: true},
{stmt: "a := mat3(vec3(1), vec3(1), vec3(1)); _ = a", err: false},
{stmt: "a := mat3(1, 1, 1); _ = a", err: true},
{stmt: "a := mat3(1, 1, vec3(1)); _ = a", err: true},
{stmt: "a := mat3(vec3(1), vec3(1), vec4(1)); _ = a", err: true},
{stmt: "a := mat3(mat3(1), mat3(1), mat3(1)); _ = a", err: true},
{stmt: "a := mat3(1, 1, 1, 1, 1, 1, 1, 1, 1); _ = a", err: false},
{stmt: "a := mat3(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0); _ = a", err: false},
{stmt: "i := 1; a := mat3(i, i, i, i, i, i, i, i, i); _ = a", err: false},
{stmt: "i := 1.0; a := mat3(i, i, i, i, i, i, i, i, i); _ = a", err: false},
{stmt: "a := mat3(vec3(1), vec3(1), vec3(1), vec3(1), vec3(1), vec3(1), vec3(1), vec3(1), vec3(1)); _ = a", err: true},
{stmt: "a := mat3(1, 1, 1, 1, 1, 1, 1, 1, vec2(1)); _ = a", err: true},
{stmt: "a := mat3(1, 1, 1, 1, 1, 1, 1, 1, vec3(1)); _ = a", err: true},
{stmt: "a := mat3(mat3(1), mat3(1), mat3(1), mat3(1), mat3(1), mat3(1), mat3(1), mat3(1), mat3(1)); _ = a", err: true},
{stmt: "a := mat3(1, 1, 1, 1, 1, 1, 1, 1); _ = a", err: true},
{stmt: "a := mat3(1, 1, 1, 1, 1, 1, 1, 1, 1, 1); _ = a", err: true},
{stmt: "a := mat4(1); _ = a", err: false},
{stmt: "a := mat4(1.0); _ = a", err: false},
{stmt: "i := 1; a := mat4(i); _ = a", err: false},
{stmt: "i := 1.0; a := mat4(i); _ = a", err: false},
{stmt: "a := mat4(mat4(1)); _ = a", err: false},
{stmt: "a := mat4(vec2(1)); _ = a", err: true},
{stmt: "a := mat4(mat2(1)); _ = a", err: true},
{stmt: "a := mat4(mat3(1)); _ = a", err: true},
{stmt: "a := mat4(vec4(1), vec4(1), vec4(1), vec4(1)); _ = a", err: false},
{stmt: "a := mat4(1, 1, 1, 1); _ = a", err: true},
{stmt: "a := mat4(1, 1, 1, vec4(1)); _ = a", err: true},
{stmt: "a := mat4(vec4(1), vec4(1), vec4(1), vec2(1)); _ = a", err: true},
{stmt: "a := mat4(mat4(1), mat4(1), mat4(1), mat4(1)); _ = a", err: true},
{stmt: "a := mat4(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); _ = a", err: false},
{stmt: "a := mat4(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0); _ = a", err: false},
{stmt: "i := 1; a := mat4(i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i); _ = a", err: false},
{stmt: "i := 1.0; a := mat4(i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i); _ = a", err: false},
{stmt: "a := mat4(vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1), vec4(1)); _ = a", err: true},
{stmt: "a := mat4(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, vec2(1)); _ = a", err: true},
{stmt: "a := mat4(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, vec3(1)); _ = a", err: true},
{stmt: "a := mat4(mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1), mat4(1)); _ = a", err: true},
{stmt: "a := mat4(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); _ = a", err: true},
{stmt: "a := mat4(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); _ = a", err: true},
} }
for _, c := range cases { for _, c := range cases {

View File

@ -18,6 +18,7 @@ import (
"fmt" "fmt"
"go/ast" "go/ast"
gconstant "go/constant" gconstant "go/constant"
"strings"
"github.com/hajimehoshi/ebiten/v2/internal/shaderir" "github.com/hajimehoshi/ebiten/v2/internal/shaderir"
) )
@ -131,15 +132,19 @@ func checkArgsForVec2BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er
if argts[0].Main == shaderir.Vec2 { if argts[0].Main == shaderir.Vec2 {
return nil return nil
} }
return fmt.Errorf("invalid arguments for vec2: (%s)", argts[0].String())
case 2: case 2:
if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) { if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) {
return nil return nil
} }
return fmt.Errorf("invalid arguments for vec2: (%s, %s)", argts[0].String(), argts[1].String())
default: default:
return fmt.Errorf("too many arguments for vec2") return fmt.Errorf("too many arguments for vec2")
} }
var str []string
for _, t := range argts {
str = append(str, t.String())
}
return fmt.Errorf("invalid arguments for vec2: (%s)", strings.Join(str, ", "))
} }
func checkArgsForVec3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { func checkArgsForVec3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error {
@ -155,7 +160,6 @@ func checkArgsForVec3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er
if argts[0].Main == shaderir.Vec3 { if argts[0].Main == shaderir.Vec3 {
return nil return nil
} }
return fmt.Errorf("invalid arguments for vec3: (%s)", argts[0].String())
case 2: case 2:
if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec2 { if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec2 {
return nil return nil
@ -163,15 +167,19 @@ func checkArgsForVec3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er
if argts[0].Main == shaderir.Vec2 && canBeFloatImplicitly(args[1], argts[1]) { if argts[0].Main == shaderir.Vec2 && canBeFloatImplicitly(args[1], argts[1]) {
return nil return nil
} }
return fmt.Errorf("invalid arguments for vec3: (%s, %s)", argts[0].String(), argts[1].String())
case 3: case 3:
if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) { if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) {
return nil return nil
} }
return fmt.Errorf("invalid arguments for vec3: (%s, %s, %s)", argts[0].String(), argts[1].String(), argts[2].String())
default: default:
return fmt.Errorf("too many arguments for vec3") return fmt.Errorf("too many arguments for vec3")
} }
var str []string
for _, t := range argts {
str = append(str, t.String())
}
return fmt.Errorf("invalid arguments for vec3: (%s)", strings.Join(str, ", "))
} }
func checkArgsForVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error { func checkArgsForVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error {
@ -187,7 +195,6 @@ func checkArgsForVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er
if argts[0].Main == shaderir.Vec4 { if argts[0].Main == shaderir.Vec4 {
return nil return nil
} }
return fmt.Errorf("invalid arguments for vec4: (%s)", argts[0].String())
case 2: case 2:
if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec3 { if canBeFloatImplicitly(args[0], argts[0]) && argts[1].Main == shaderir.Vec3 {
return nil return nil
@ -198,7 +205,6 @@ func checkArgsForVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er
if argts[0].Main == shaderir.Vec3 && canBeFloatImplicitly(args[1], argts[1]) { if argts[0].Main == shaderir.Vec3 && canBeFloatImplicitly(args[1], argts[1]) {
return nil return nil
} }
return fmt.Errorf("invalid arguments for vec4: (%s, %s)", argts[0].String(), argts[1].String())
case 3: case 3:
if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && argts[2].Main == shaderir.Vec2 { if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && argts[2].Main == shaderir.Vec2 {
return nil return nil
@ -209,13 +215,134 @@ func checkArgsForVec4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) er
if argts[0].Main == shaderir.Vec2 && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) { if argts[0].Main == shaderir.Vec2 && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) {
return nil return nil
} }
return fmt.Errorf("invalid arguments for vec4: (%s, %s, %s)", argts[0].String(), argts[1].String(), argts[2].String())
case 4: case 4:
if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) && canBeFloatImplicitly(args[3], argts[3]) { if canBeFloatImplicitly(args[0], argts[0]) && canBeFloatImplicitly(args[1], argts[1]) && canBeFloatImplicitly(args[2], argts[2]) && canBeFloatImplicitly(args[3], argts[3]) {
return nil return nil
} }
return fmt.Errorf("invalid arguments for vec4: (%s, %s, %s, %s)", argts[0].String(), argts[1].String(), argts[2].String(), argts[3].String())
default: default:
return fmt.Errorf("too many arguments for vec4") return fmt.Errorf("too many arguments for vec4")
} }
var str []string
for _, t := range argts {
str = append(str, t.String())
}
return fmt.Errorf("invalid arguments for vec4: (%s)", strings.Join(str, ", "))
}
func checkArgsForMat2BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error {
if len(args) != len(argts) {
return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts))
}
switch len(args) {
case 1:
if canBeFloatImplicitly(args[0], argts[0]) {
return nil
}
if argts[0].Main == shaderir.Mat2 {
return nil
}
case 2:
if argts[0].Main == shaderir.Vec2 && argts[1].Main == shaderir.Vec2 {
return nil
}
case 4:
ok := true
for i := range argts {
if !canBeFloatImplicitly(args[i], argts[i]) {
ok = false
break
}
}
if ok {
return nil
}
default:
return fmt.Errorf("invalid number of arguments for mat2")
}
var str []string
for _, t := range argts {
str = append(str, t.String())
}
return fmt.Errorf("invalid arguments for mat2: (%s)", strings.Join(str, ", "))
}
func checkArgsForMat3BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error {
if len(args) != len(argts) {
return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts))
}
switch len(args) {
case 1:
if canBeFloatImplicitly(args[0], argts[0]) {
return nil
}
if argts[0].Main == shaderir.Mat3 {
return nil
}
case 3:
if argts[0].Main == shaderir.Vec3 && argts[1].Main == shaderir.Vec3 && argts[2].Main == shaderir.Vec3 {
return nil
}
case 9:
ok := true
for i := range argts {
if !canBeFloatImplicitly(args[i], argts[i]) {
ok = false
break
}
}
if ok {
return nil
}
default:
return fmt.Errorf("invalid number of arguments for mat3")
}
var str []string
for _, t := range argts {
str = append(str, t.String())
}
return fmt.Errorf("invalid arguments for mat3: (%s)", strings.Join(str, ", "))
}
func checkArgsForMat4BuiltinFunc(args []shaderir.Expr, argts []shaderir.Type) error {
if len(args) != len(argts) {
return fmt.Errorf("the number of arguments and types doesn't match: %d vs %d", len(args), len(argts))
}
switch len(args) {
case 1:
if canBeFloatImplicitly(args[0], argts[0]) {
return nil
}
if argts[0].Main == shaderir.Mat4 {
return nil
}
case 4:
if argts[0].Main == shaderir.Vec4 && argts[1].Main == shaderir.Vec4 && argts[2].Main == shaderir.Vec4 && argts[3].Main == shaderir.Vec4 {
return nil
}
case 16:
ok := true
for i := range argts {
if !canBeFloatImplicitly(args[i], argts[i]) {
ok = false
break
}
}
if ok {
return nil
}
default:
return fmt.Errorf("invalid number of arguments for mat4")
}
var str []string
for _, t := range argts {
str = append(str, t.String())
}
return fmt.Errorf("invalid arguments for mat4: (%s)", strings.Join(str, ", "))
} }