diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 7f7953275..c1f6c32ea 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -148,9 +148,11 @@ func (cs *compileState) parse(f *ast.File) { // Parse GenDecl for global variables, and then parse functions. for _, d := range f.Decls { if _, ok := d.(*ast.FuncDecl); !ok { - if !cs.parseDecl(&cs.global, d) { + ss, ok := cs.parseDecl(&cs.global, d) + if !ok { return } + cs.global.ir.Stmts = append(cs.global.ir.Stmts, ss...) } } @@ -209,9 +211,11 @@ func (cs *compileState) parse(f *ast.File) { // Parse functions. for _, d := range f.Decls { if _, ok := d.(*ast.FuncDecl); ok { - if !cs.parseDecl(&cs.global, d) { + ss, ok := cs.parseDecl(&cs.global, d) + if !ok { return } + cs.global.ir.Stmts = append(cs.global.ir.Stmts, ss...) } } @@ -224,7 +228,9 @@ func (cs *compileState) parse(f *ast.File) { } } -func (cs *compileState) parseDecl(b *block, d ast.Decl) bool { +func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool) { + var stmts []shaderir.Stmt + switch d := d.(type) { case *ast.GenDecl: switch d.Tok { @@ -247,11 +253,11 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) bool { case token.VAR: for _, s := range d.Specs { s := s.(*ast.ValueSpec) - vs, inits, stmts, ok := cs.parseVariable(b, s) + vs, inits, ss, ok := cs.parseVariable(b, s) if !ok { - return false + return nil, false } - b.ir.Stmts = append(b.ir.Stmts, stmts...) + stmts = append(stmts, ss...) if b == &cs.global { // TODO: Should rhs be ignored? for i, v := range vs { @@ -271,7 +277,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) bool { if len(inits) > 0 { for i := range vs { - b.ir.Stmts = append(b.ir.Stmts, shaderir.Stmt{ + stmts = append(stmts, shaderir.Stmt{ Type: shaderir.Assign, Exprs: []shaderir.Expr{ { @@ -292,11 +298,11 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) bool { case *ast.FuncDecl: f, ok := cs.parseFunc(b, d) if !ok { - return false + return nil, false } if b != &cs.global { cs.addError(d.Pos(), "non-global function is not implemented") - return false + return nil, false } switch d.Name.Name { case cs.vertexEntry: @@ -316,10 +322,10 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) bool { } default: cs.addError(d.Pos(), "unexpected decl") - return false + return nil, false } - return true + return stmts, true } // functionReturnTypes returns the original returning value types, if the given expression is call. diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index d75eed9dc..37bd2a13e 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -107,9 +107,11 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab }, }) case *ast.DeclStmt: - if !cs.parseDecl(block, stmt.Decl) { + ss, ok := cs.parseDecl(block, stmt.Decl) + if !ok { return nil, false } + stmts = append(stmts, ss...) case *ast.IfStmt: if stmt.Init != nil { init := stmt.Init