internal/shader: add swizzling check

This commit is contained in:
Hajime Hoshi 2022-11-19 22:37:31 +09:00
parent 6121856836
commit 5f4e3a0348
2 changed files with 81 additions and 1 deletions

View File

@ -21,6 +21,7 @@ import (
"go/token"
"regexp"
"strconv"
"strings"
"github.com/hajimehoshi/ebiten/v2/internal/shaderir"
)
@ -869,7 +870,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
return cs.parseExpr(block, fname, e.X, markLocalVariableUsed)
case *ast.SelectorExpr:
exprs, _, stmts, ok := cs.parseExpr(block, fname, e.X, true)
exprs, types, stmts, ok := cs.parseExpr(block, fname, e.X, true)
if !ok {
return nil, nil, nil, false
}
@ -877,6 +878,12 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a selector: %s", e.X))
return nil, nil, nil, false
}
if !isValidSwizzling(e.Sel.Name, types[0]) {
cs.addError(e.Pos(), fmt.Sprintf("unexpected swizzling: %s", e.Sel.Name))
return nil, nil, nil, false
}
var t shaderir.Type
switch len(e.Sel.Name) {
case 1:
@ -1069,3 +1076,20 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
return nil, nil, nil, false
}
func isValidSwizzling(swizzling string, t shaderir.Type) bool {
if !shaderir.IsValidSwizzling(swizzling) {
return false
}
switch t.Main {
case shaderir.Vec2:
return !strings.ContainsAny(swizzling, "zwbarq")
case shaderir.Vec3:
return !strings.ContainsAny(swizzling, "waq")
case shaderir.Vec4:
return true
default:
return false
}
}

View File

@ -2409,3 +2409,59 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
}
}
func TestSwizzling(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "var a vec2; var b float = a.x; _ = b", err: false},
{stmt: "var a vec2; var b float = a.y; _ = b", err: false},
{stmt: "var a vec2; var b float = a.z; _ = b", err: true},
{stmt: "var a vec2; var b float = a.w; _ = b", err: true},
{stmt: "var a vec2; var b vec2 = a.xy; _ = b", err: false},
{stmt: "var a vec2; var b vec3 = a.xyz; _ = b", err: true},
{stmt: "var a vec2; var b vec3 = a.xyw; _ = b", err: true},
{stmt: "var a vec2; var b vec3 = a.xyy; _ = b", err: false},
{stmt: "var a vec2; var b vec3 = a.xyz; _ = b", err: true},
{stmt: "var a vec2; var b vec4 = a.xyzw; _ = b", err: true},
{stmt: "var a vec3; var b float = a.x; _ = b", err: false},
{stmt: "var a vec3; var b float = a.y; _ = b", err: false},
{stmt: "var a vec3; var b float = a.z; _ = b", err: false},
{stmt: "var a vec3; var b float = a.w; _ = b", err: true},
{stmt: "var a vec3; var b vec2 = a.xy; _ = b", err: false},
{stmt: "var a vec3; var b vec3 = a.xyz; _ = b", err: false},
{stmt: "var a vec3; var b vec3 = a.xyw; _ = b", err: true},
{stmt: "var a vec3; var b vec3 = a.xyy; _ = b", err: false},
{stmt: "var a vec3; var b vec3 = a.xyz; _ = b", err: false},
{stmt: "var a vec3; var b vec4 = a.xyzw; _ = b", err: true},
{stmt: "var a vec4; var b float = a.x; _ = b", err: false},
{stmt: "var a vec4; var b float = a.y; _ = b", err: false},
{stmt: "var a vec4; var b float = a.z; _ = b", err: false},
{stmt: "var a vec4; var b float = a.w; _ = b", err: false},
{stmt: "var a vec4; var b vec2 = a.xy; _ = b", err: false},
{stmt: "var a vec4; var b vec3 = a.xyz; _ = b", err: false},
{stmt: "var a vec4; var b vec3 = a.xyw; _ = b", err: false},
{stmt: "var a vec4; var b vec3 = a.xyy; _ = b", err: false},
{stmt: "var a vec4; var b vec3 = a.xyz; _ = b", err: false},
{stmt: "var a vec4; var b vec4 = a.xyzw; _ = b", err: false},
}
for _, c := range cases {
stmt := c.stmt
src := fmt.Sprintf(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
%s
return position
}`, stmt)
_, err := compileToIR([]byte(src))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", stmt, err)
}
}
}