diff --git a/internal/shader/expr.go b/internal/shader/expr.go index f1f50656e..a2693812f 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -21,6 +21,7 @@ import ( "go/token" "regexp" "strconv" + "strings" "github.com/hajimehoshi/ebiten/v2/internal/shaderir" ) @@ -869,7 +870,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar return cs.parseExpr(block, fname, e.X, markLocalVariableUsed) case *ast.SelectorExpr: - exprs, _, stmts, ok := cs.parseExpr(block, fname, e.X, true) + exprs, types, stmts, ok := cs.parseExpr(block, fname, e.X, true) if !ok { return nil, nil, nil, false } @@ -877,6 +878,12 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a selector: %s", e.X)) return nil, nil, nil, false } + + if !isValidSwizzling(e.Sel.Name, types[0]) { + cs.addError(e.Pos(), fmt.Sprintf("unexpected swizzling: %s", e.Sel.Name)) + return nil, nil, nil, false + } + var t shaderir.Type switch len(e.Sel.Name) { case 1: @@ -1069,3 +1076,20 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } return nil, nil, nil, false } + +func isValidSwizzling(swizzling string, t shaderir.Type) bool { + if !shaderir.IsValidSwizzling(swizzling) { + return false + } + + switch t.Main { + case shaderir.Vec2: + return !strings.ContainsAny(swizzling, "zwbarq") + case shaderir.Vec3: + return !strings.ContainsAny(swizzling, "waq") + case shaderir.Vec4: + return true + default: + return false + } +} diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 38eda8788..c9b193126 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -2409,3 +2409,59 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +func TestSwizzling(t *testing.T) { + cases := []struct { + stmt string + err bool + }{ + {stmt: "var a vec2; var b float = a.x; _ = b", err: false}, + {stmt: "var a vec2; var b float = a.y; _ = b", err: false}, + {stmt: "var a vec2; var b float = a.z; _ = b", err: true}, + {stmt: "var a vec2; var b float = a.w; _ = b", err: true}, + {stmt: "var a vec2; var b vec2 = a.xy; _ = b", err: false}, + {stmt: "var a vec2; var b vec3 = a.xyz; _ = b", err: true}, + {stmt: "var a vec2; var b vec3 = a.xyw; _ = b", err: true}, + {stmt: "var a vec2; var b vec3 = a.xyy; _ = b", err: false}, + {stmt: "var a vec2; var b vec3 = a.xyz; _ = b", err: true}, + {stmt: "var a vec2; var b vec4 = a.xyzw; _ = b", err: true}, + + {stmt: "var a vec3; var b float = a.x; _ = b", err: false}, + {stmt: "var a vec3; var b float = a.y; _ = b", err: false}, + {stmt: "var a vec3; var b float = a.z; _ = b", err: false}, + {stmt: "var a vec3; var b float = a.w; _ = b", err: true}, + {stmt: "var a vec3; var b vec2 = a.xy; _ = b", err: false}, + {stmt: "var a vec3; var b vec3 = a.xyz; _ = b", err: false}, + {stmt: "var a vec3; var b vec3 = a.xyw; _ = b", err: true}, + {stmt: "var a vec3; var b vec3 = a.xyy; _ = b", err: false}, + {stmt: "var a vec3; var b vec3 = a.xyz; _ = b", err: false}, + {stmt: "var a vec3; var b vec4 = a.xyzw; _ = b", err: true}, + + {stmt: "var a vec4; var b float = a.x; _ = b", err: false}, + {stmt: "var a vec4; var b float = a.y; _ = b", err: false}, + {stmt: "var a vec4; var b float = a.z; _ = b", err: false}, + {stmt: "var a vec4; var b float = a.w; _ = b", err: false}, + {stmt: "var a vec4; var b vec2 = a.xy; _ = b", err: false}, + {stmt: "var a vec4; var b vec3 = a.xyz; _ = b", err: false}, + {stmt: "var a vec4; var b vec3 = a.xyw; _ = b", err: false}, + {stmt: "var a vec4; var b vec3 = a.xyy; _ = b", err: false}, + {stmt: "var a vec4; var b vec3 = a.xyz; _ = b", err: false}, + {stmt: "var a vec4; var b vec4 = a.xyzw; _ = b", err: false}, + } + + for _, c := range cases { + stmt := c.stmt + src := fmt.Sprintf(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + %s + return position +}`, stmt) + _, err := compileToIR([]byte(src)) + if err == nil && c.err { + t.Errorf("%s must return an error but does not", stmt) + } else if err != nil && !c.err { + t.Errorf("%s must not return nil but returned %v", stmt, err) + } + } +}