shader: Refactoring: Let parseDecl return statements

This commit is contained in:
Hajime Hoshi 2020-07-12 15:51:06 +09:00
parent d6bce5b9d2
commit dcb693460e
2 changed files with 20 additions and 12 deletions

View File

@ -148,9 +148,11 @@ func (cs *compileState) parse(f *ast.File) {
// Parse GenDecl for global variables, and then parse functions. // Parse GenDecl for global variables, and then parse functions.
for _, d := range f.Decls { for _, d := range f.Decls {
if _, ok := d.(*ast.FuncDecl); !ok { if _, ok := d.(*ast.FuncDecl); !ok {
if !cs.parseDecl(&cs.global, d) { ss, ok := cs.parseDecl(&cs.global, d)
if !ok {
return return
} }
cs.global.ir.Stmts = append(cs.global.ir.Stmts, ss...)
} }
} }
@ -209,9 +211,11 @@ func (cs *compileState) parse(f *ast.File) {
// Parse functions. // Parse functions.
for _, d := range f.Decls { for _, d := range f.Decls {
if _, ok := d.(*ast.FuncDecl); ok { if _, ok := d.(*ast.FuncDecl); ok {
if !cs.parseDecl(&cs.global, d) { ss, ok := cs.parseDecl(&cs.global, d)
if !ok {
return return
} }
cs.global.ir.Stmts = append(cs.global.ir.Stmts, ss...)
} }
} }
@ -224,7 +228,9 @@ func (cs *compileState) parse(f *ast.File) {
} }
} }
func (cs *compileState) parseDecl(b *block, d ast.Decl) bool { func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool) {
var stmts []shaderir.Stmt
switch d := d.(type) { switch d := d.(type) {
case *ast.GenDecl: case *ast.GenDecl:
switch d.Tok { switch d.Tok {
@ -247,11 +253,11 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) bool {
case token.VAR: case token.VAR:
for _, s := range d.Specs { for _, s := range d.Specs {
s := s.(*ast.ValueSpec) s := s.(*ast.ValueSpec)
vs, inits, stmts, ok := cs.parseVariable(b, s) vs, inits, ss, ok := cs.parseVariable(b, s)
if !ok { if !ok {
return false return nil, false
} }
b.ir.Stmts = append(b.ir.Stmts, stmts...) stmts = append(stmts, ss...)
if b == &cs.global { if b == &cs.global {
// TODO: Should rhs be ignored? // TODO: Should rhs be ignored?
for i, v := range vs { for i, v := range vs {
@ -271,7 +277,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) bool {
if len(inits) > 0 { if len(inits) > 0 {
for i := range vs { for i := range vs {
b.ir.Stmts = append(b.ir.Stmts, shaderir.Stmt{ stmts = append(stmts, shaderir.Stmt{
Type: shaderir.Assign, Type: shaderir.Assign,
Exprs: []shaderir.Expr{ Exprs: []shaderir.Expr{
{ {
@ -292,11 +298,11 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) bool {
case *ast.FuncDecl: case *ast.FuncDecl:
f, ok := cs.parseFunc(b, d) f, ok := cs.parseFunc(b, d)
if !ok { if !ok {
return false return nil, false
} }
if b != &cs.global { if b != &cs.global {
cs.addError(d.Pos(), "non-global function is not implemented") cs.addError(d.Pos(), "non-global function is not implemented")
return false return nil, false
} }
switch d.Name.Name { switch d.Name.Name {
case cs.vertexEntry: case cs.vertexEntry:
@ -316,10 +322,10 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) bool {
} }
default: default:
cs.addError(d.Pos(), "unexpected decl") cs.addError(d.Pos(), "unexpected decl")
return false return nil, false
} }
return true return stmts, true
} }
// functionReturnTypes returns the original returning value types, if the given expression is call. // functionReturnTypes returns the original returning value types, if the given expression is call.

View File

@ -107,9 +107,11 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab
}, },
}) })
case *ast.DeclStmt: case *ast.DeclStmt:
if !cs.parseDecl(block, stmt.Decl) { ss, ok := cs.parseDecl(block, stmt.Decl)
if !ok {
return nil, false return nil, false
} }
stmts = append(stmts, ss...)
case *ast.IfStmt: case *ast.IfStmt:
if stmt.Init != nil { if stmt.Init != nil {
init := stmt.Init init := stmt.Init