From fb775d806c4aaae0aa6dda5a0e1280be21c10f91 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Wed, 17 Aug 2022 23:13:21 +0900 Subject: [PATCH] internal/shader: disallow 'discard' in other functions than the fragment entry point Closes #2248 --- internal/shader/expr.go | 27 +++++++++++---------- internal/shader/shader.go | 44 ++++++++++++++++------------------ internal/shader/stmt.go | 24 +++++++++---------- internal/shader/syntax_test.go | 39 ++++++++++++++++++++++++++++++ internal/shader/type.go | 6 ++--- 5 files changed, 90 insertions(+), 50 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 40399cf29..c9cadcaaa 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -67,7 +67,7 @@ func goConstantKindString(k gconstant.Kind) string { var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) -func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) { +func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) { switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { @@ -93,7 +93,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable var stmts []shaderir.Stmt // Prase LHS first for the order of the statements. - lhs, ts, ss, ok := cs.parseExpr(block, e.X, markLocalVariableUsed) + lhs, ts, ss, ok := cs.parseExpr(block, fname, e.X, markLocalVariableUsed) if !ok { return nil, nil, nil, false } @@ -104,7 +104,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable stmts = append(stmts, ss...) lhst := ts[0] - rhs, ts, ss, ok := cs.parseExpr(block, e.Y, markLocalVariableUsed) + rhs, ts, ss, ok := cs.parseExpr(block, fname, e.Y, markLocalVariableUsed) if !ok { return nil, nil, nil, false } @@ -299,7 +299,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable // Parse the argument first for the order of the statements. for _, a := range e.Args { - es, ts, ss, ok := cs.parseExpr(block, a, markLocalVariableUsed) + es, ts, ss, ok := cs.parseExpr(block, fname, a, markLocalVariableUsed) if !ok { return nil, nil, nil, false } @@ -313,7 +313,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable } // TODO: When len(ss) is not 0? - es, _, ss, ok := cs.parseExpr(block, e.Fun, markLocalVariableUsed) + es, _, ss, ok := cs.parseExpr(block, fname, e.Fun, markLocalVariableUsed) if !ok { return nil, nil, nil, false } @@ -473,6 +473,9 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable if len(args) != 0 { cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 0 but %d", callee.BuiltinFunc, len(args))) } + if fname != cs.fragmentEntry { + cs.addError(e.Pos(), fmt.Sprintf("discard is available only in %s", cs.fragmentEntry)) + } stmts = append(stmts, shaderir.Stmt{ Type: shaderir.Discard, }) @@ -653,10 +656,10 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable cs.addError(e.Pos(), fmt.Sprintf("unexpected identifier: %s", e.Name)) case *ast.ParenExpr: - return cs.parseExpr(block, e.X, markLocalVariableUsed) + return cs.parseExpr(block, fname, e.X, markLocalVariableUsed) case *ast.SelectorExpr: - exprs, _, stmts, ok := cs.parseExpr(block, e.X, true) + exprs, _, stmts, ok := cs.parseExpr(block, fname, e.X, true) if !ok { return nil, nil, nil, false } @@ -692,7 +695,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable }, []shaderir.Type{t}, stmts, true case *ast.UnaryExpr: - exprs, t, stmts, ok := cs.parseExpr(block, e.X, markLocalVariableUsed) + exprs, t, stmts, ok := cs.parseExpr(block, fname, e.X, markLocalVariableUsed) if !ok { return nil, nil, nil, false } @@ -736,7 +739,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable }, t, stmts, true case *ast.CompositeLit: - t, ok := cs.parseType(block, e.Type) + t, ok := cs.parseType(block, fname, e.Type) if !ok { return nil, nil, nil, false } @@ -751,7 +754,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable var stmts []shaderir.Stmt for i, e := range e.Elts { - exprs, _, ss, ok := cs.parseExpr(block, e, markLocalVariableUsed) + exprs, _, ss, ok := cs.parseExpr(block, fname, e, markLocalVariableUsed) if !ok { return nil, nil, nil, false } @@ -793,7 +796,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable var stmts []shaderir.Stmt // Parse the index first - exprs, _, ss, ok := cs.parseExpr(block, e.Index, markLocalVariableUsed) + exprs, _, ss, ok := cs.parseExpr(block, fname, e.Index, markLocalVariableUsed) if !ok { return nil, nil, nil, false } @@ -812,7 +815,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable idx.ConstType = shaderir.ConstTypeInt } - exprs, ts, ss, ok := cs.parseExpr(block, e.X, markLocalVariableUsed) + exprs, ts, ss, ok := cs.parseExpr(block, fname, e.X, markLocalVariableUsed) if !ok { return nil, nil, nil, false } diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 60d9e76ad..b6783c446 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -38,8 +38,7 @@ type constant struct { } type function struct { - name string - block *block + name string ir shaderir.Func } @@ -212,7 +211,7 @@ func (cs *compileState) parse(f *ast.File) { // Parse GenDecl for global variables, and then parse functions. for _, d := range f.Decls { if _, ok := d.(*ast.FuncDecl); !ok { - ss, ok := cs.parseDecl(&cs.global, d) + ss, ok := cs.parseDecl(&cs.global, "", d) if !ok { return } @@ -261,7 +260,7 @@ func (cs *compileState) parse(f *ast.File) { } } - inParams, outParams, ret := cs.parseFuncParams(&cs.global, fd) + inParams, outParams, ret := cs.parseFuncParams(&cs.global, n, fd) var inT, outT []shaderir.Type for _, v := range inParams { inT = append(inT, v.typ) @@ -284,8 +283,8 @@ func (cs *compileState) parse(f *ast.File) { // Parse functions. for _, d := range f.Decls { - if _, ok := d.(*ast.FuncDecl); ok { - ss, ok := cs.parseDecl(&cs.global, d) + if f, ok := d.(*ast.FuncDecl); ok { + ss, ok := cs.parseDecl(&cs.global, f.Name.Name, d) if !ok { return } @@ -302,7 +301,7 @@ func (cs *compileState) parse(f *ast.File) { } } -func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool) { +func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) { var stmts []shaderir.Stmt switch d := d.(type) { @@ -312,7 +311,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool) // TODO: Parse other types for _, s := range d.Specs { s := s.(*ast.TypeSpec) - t, ok := cs.parseType(b, s.Type) + t, ok := cs.parseType(b, fname, s.Type) if !ok { return nil, false } @@ -324,7 +323,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool) case token.CONST: for _, s := range d.Specs { s := s.(*ast.ValueSpec) - cs, ok := cs.parseConstant(b, s) + cs, ok := cs.parseConstant(b, fname, s) if !ok { return nil, false } @@ -333,7 +332,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool) case token.VAR: for _, s := range d.Specs { s := s.(*ast.ValueSpec) - vs, inits, ss, ok := cs.parseVariable(b, s) + vs, inits, ss, ok := cs.parseVariable(b, fname, s) if !ok { return nil, false } @@ -439,7 +438,7 @@ func (cs *compileState) functionReturnTypes(block *block, expr ast.Expr) ([]shad return nil, false } -func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variable, []shaderir.Expr, []shaderir.Stmt, bool) { +func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSpec) ([]variable, []shaderir.Expr, []shaderir.Stmt, bool) { if len(vs.Names) != len(vs.Values) && len(vs.Values) != 1 && len(vs.Values) != 0 { s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match")) return nil, nil, nil, false @@ -448,7 +447,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl var declt shaderir.Type if vs.Type != nil { var ok bool - declt, ok = s.parseType(block, vs.Type) + declt, ok = s.parseType(block, fname, vs.Type) if !ok { return nil, nil, nil, false } @@ -475,7 +474,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl init := vs.Values[i] - es, rts, ss, ok := s.parseExpr(block, init, true) + es, rts, ss, ok := s.parseExpr(block, fname, init, true) if !ok { return nil, nil, nil, false } @@ -517,7 +516,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl var ss []shaderir.Stmt var ok bool - initexprs, inittypes, ss, ok = s.parseExpr(block, init, true) + initexprs, inittypes, ss, ok = s.parseExpr(block, fname, init, true) if !ok { return nil, nil, nil, false } @@ -569,11 +568,11 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl return vars, inits, stmts, true } -func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constant, bool) { +func (s *compileState) parseConstant(block *block, fname string, vs *ast.ValueSpec) ([]constant, bool) { var t shaderir.Type if vs.Type != nil { var ok bool - t, ok = s.parseType(block, vs.Type) + t, ok = s.parseType(block, fname, vs.Type) if !ok { return nil, false } @@ -595,7 +594,7 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan } } - es, ts, ss, ok := s.parseExpr(block, vs.Values[i], false) + es, ts, ss, ok := s.parseExpr(block, fname, vs.Values[i], false) if !ok { return nil, false } @@ -621,9 +620,9 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan return cs, true } -func (cs *compileState) parseFuncParams(block *block, d *ast.FuncDecl) (in, out []variable, ret shaderir.Type) { +func (cs *compileState) parseFuncParams(block *block, fname string, d *ast.FuncDecl) (in, out []variable, ret shaderir.Type) { for _, f := range d.Type.Params.List { - t, ok := cs.parseType(block, f.Type) + t, ok := cs.parseType(block, fname, f.Type) if !ok { return } @@ -640,7 +639,7 @@ func (cs *compileState) parseFuncParams(block *block, d *ast.FuncDecl) (in, out } for _, f := range d.Type.Results.List { - t, ok := cs.parseType(block, f.Type) + t, ok := cs.parseType(block, fname, f.Type) if !ok { return } @@ -681,7 +680,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) (function, bool return function{}, false } - inParams, outParams, returnType := cs.parseFuncParams(block, d) + inParams, outParams, returnType := cs.parseFuncParams(block, d.Name.Name, d) checkVaryings := func(vs []variable) { if len(cs.ir.Varyings) != len(vs) { @@ -789,8 +788,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) (function, bool } return function{ - name: d.Name.Name, - block: b, + name: d.Name.Name, ir: shaderir.Func{ InParams: inT, OutParams: outT, diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 47a965a8f..ff3a3c714 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -62,13 +62,13 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } stmts = append(stmts, ss...) case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN: - rhs, rts, ss, ok := cs.parseExpr(block, stmt.Rhs[0], true) + rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true) if !ok { return nil, false } stmts = append(stmts, ss...) - lhs, lts, ss, ok := cs.parseExpr(block, stmt.Lhs[0], true) + lhs, lts, ss, ok := cs.parseExpr(block, fname, stmt.Lhs[0], true) if !ok { return nil, false } @@ -177,7 +177,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP }, }) case *ast.DeclStmt: - ss, ok := cs.parseDecl(block, stmt.Decl) + ss, ok := cs.parseDecl(block, fname, stmt.Decl) if !ok { return nil, false } @@ -228,7 +228,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP vartype := pseudoBlock.vars[0].typ init := ss[0].Exprs[1].Const - exprs, ts, ss, ok := cs.parseExpr(pseudoBlock, stmt.Cond, true) + exprs, ts, ss, ok := cs.parseExpr(pseudoBlock, fname, stmt.Cond, true) if !ok { return nil, false } @@ -356,7 +356,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return stmts, true } - exprs, ts, ss, ok := cs.parseExpr(block, stmt.Cond, true) + exprs, ts, ss, ok := cs.parseExpr(block, fname, stmt.Cond, true) if !ok { return nil, false } @@ -401,7 +401,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP }) case *ast.IncDecStmt: - exprs, _, ss, ok := cs.parseExpr(block, stmt.X, true) + exprs, _, ss, ok := cs.parseExpr(block, fname, stmt.X, true) if !ok { return nil, false } @@ -445,7 +445,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP var exprs []shaderir.Expr var types []shaderir.Type for _, r := range stmt.Results { - es, ts, ss, ok := cs.parseExpr(block, r, true) + es, ts, ss, ok := cs.parseExpr(block, fname, r, true) if !ok { return nil, false } @@ -544,7 +544,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return nil, false } - exprs, _, ss, ok := cs.parseExpr(block, stmt.X, true) + exprs, _, ss, ok := cs.parseExpr(block, fname, stmt.X, true) if !ok { return nil, false } @@ -582,7 +582,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r for i, e := range lhs { if len(lhs) == len(rhs) { // Prase RHS first for the order of the statements. - r, rts, ss, ok := cs.parseExpr(block, rhs[i], true) + r, rts, ss, ok := cs.parseExpr(block, fname, rhs[i], true) if !ok { return nil, false } @@ -619,7 +619,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r return nil, false } - l, lts, ss, ok := cs.parseExpr(block, lhs[i], false) + l, lts, ss, ok := cs.parseExpr(block, fname, lhs[i], false) if !ok { return nil, false } @@ -714,7 +714,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r if i == 0 { var ss []shaderir.Stmt var ok bool - rhsExprs, rhsTypes, ss, ok = cs.parseExpr(block, rhs[0], true) + rhsExprs, rhsTypes, ss, ok = cs.parseExpr(block, fname, rhs[0], true) if !ok { return nil, false } @@ -743,7 +743,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r block.addNamedLocalVariable(name, t, e.Pos()) } - l, lts, ss, ok := cs.parseExpr(block, lhs[i], false) + l, lts, ss, ok := cs.parseExpr(block, fname, lhs[i], false) if !ok { return nil, false } diff --git a/internal/shader/syntax_test.go b/internal/shader/syntax_test.go index 759367b0d..56e6ad13a 100644 --- a/internal/shader/syntax_test.go +++ b/internal/shader/syntax_test.go @@ -1575,3 +1575,42 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { } } } + +// Issue #2248 +func TestSyntaxDiscard(t *testing.T) { + if _, err := compileToIR([]byte(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + if true { + discard() + } + return vec4(0) +} +`)); err != nil { + t.Error(err) + } + // discard without return doesn't work so far. + // TODO: Allow discard without return. + if _, err := compileToIR([]byte(`package main + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + discard() + return vec4(0) +} +`)); err != nil { + t.Error(err) + } + if _, err := compileToIR([]byte(`package main + +func foo() { + discard() +} + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + foo() + return vec4(0) +} +`)); err == nil { + t.Errorf("error must be non-nil but was nil") + } +} diff --git a/internal/shader/type.go b/internal/shader/type.go index e32a61237..da8c108df 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -23,7 +23,7 @@ import ( "github.com/hajimehoshi/ebiten/v2/internal/shaderir" ) -func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, bool) { +func (cs *compileState) parseType(block *block, fname string, expr ast.Expr) (shaderir.Type, bool) { switch t := expr.(type) { case *ast.Ident: switch t.Name { @@ -58,7 +58,7 @@ func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, b if _, ok := t.Len.(*ast.Ellipsis); ok { length = -1 // Determine the length later. } else { - exprs, _, _, ok := cs.parseExpr(block, t.Len, true) + exprs, _, _, ok := cs.parseExpr(block, fname, t.Len, true) if !ok { return shaderir.Type{}, false } @@ -78,7 +78,7 @@ func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, b length = int(l) } - elm, ok := cs.parseType(block, t.Elt) + elm, ok := cs.parseType(block, fname, t.Elt) if !ok { return shaderir.Type{}, false }