diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 045b65172..46fb01667 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -191,7 +191,8 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) { case token.VAR: for _, s := range d.Specs { s := s.(*ast.ValueSpec) - vs, inits := cs.parseVariable(b, s) + vs, inits, stmts := cs.parseVariable(b, s) + b.ir.Stmts = append(b.ir.Stmts, stmts...) if b == &cs.global { // TODO: Should rhs be ignored? for i, v := range vs { @@ -246,7 +247,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) { } } -func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variable, []*shaderir.Expr) { +func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variable, []*shaderir.Expr, []shaderir.Stmt) { var t typ if vs.Type != nil { t = s.parseType(vs.Type) @@ -254,6 +255,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl var vars []variable var inits []*shaderir.Expr + var stmts []shaderir.Stmt for i, n := range vs.Names { var init ast.Expr if len(vs.Values) > 0 { @@ -270,12 +272,13 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl var expr *shaderir.Expr if init != nil { - e := s.parseExpr(block, init) + e, ss := s.parseExpr(block, init) expr = &e + stmts = append(stmts, ss...) } inits = append(inits, expr) } - return vars, inits + return vars, inits, stmts } func (s *compileState) parseConstant(vs *ast.ValueSpec) []constant { @@ -446,23 +449,30 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out v.typ = cs.detectType(block, l.Rhs[i]) block.vars = append(block.vars, v) block.ir.LocalVars = append(block.ir.LocalVars, v.typ.ir) + + // Prase RHS first for the order of the statements. + rhs, stmts := cs.parseExpr(block, l.Rhs[i]) + block.ir.Stmts = append(block.ir.Stmts, stmts...) + lhs, stmts := cs.parseExpr(block, l.Lhs[i]) + block.ir.Stmts = append(block.ir.Stmts, stmts...) + block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ - Type: shaderir.Assign, - Exprs: []shaderir.Expr{ - cs.parseExpr(block, l.Lhs[i]), - cs.parseExpr(block, l.Rhs[i]), - }, + Type: shaderir.Assign, + Exprs: []shaderir.Expr{lhs, rhs}, }) } case token.ASSIGN: // TODO: What about the statement `a,b = b,a?` for i := range l.Rhs { + // Prase RHS first for the order of the statements. + rhs, stmts := cs.parseExpr(block, l.Rhs[i]) + block.ir.Stmts = append(block.ir.Stmts, stmts...) + lhs, stmts := cs.parseExpr(block, l.Lhs[i]) + block.ir.Stmts = append(block.ir.Stmts, stmts...) + block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ - Type: shaderir.Assign, - Exprs: []shaderir.Expr{ - cs.parseExpr(block, l.Lhs[i]), - cs.parseExpr(block, l.Rhs[i]), - }, + Type: shaderir.Assign, + Exprs: []shaderir.Expr{lhs, rhs}, }) } } @@ -478,7 +488,8 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out cs.parseDecl(block, l.Decl) case *ast.ReturnStmt: for i, r := range l.Results { - e := cs.parseExpr(block, r) + e, stmts := cs.parseExpr(block, r) + block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ Type: shaderir.Assign, Exprs: []shaderir.Expr{ @@ -565,7 +576,7 @@ func (s *compileState) detectType(b *block, expr ast.Expr) typ { } } -func (cs *compileState) parseExpr(block *block, expr ast.Expr) shaderir.Expr { +func (cs *compileState) parseExpr(block *block, expr ast.Expr) (shaderir.Expr, []shaderir.Stmt) { switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { @@ -573,22 +584,22 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) shaderir.Expr { v, err := strconv.ParseInt(e.Value, 10, 32) if err != nil { cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) - return shaderir.Expr{} + return shaderir.Expr{}, nil } return shaderir.Expr{ Type: shaderir.IntExpr, Int: int32(v), - } + }, nil case token.FLOAT: v, err := strconv.ParseFloat(e.Value, 32) if err != nil { cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) - return shaderir.Expr{} + return shaderir.Expr{}, nil } return shaderir.Expr{ Type: shaderir.FloatExpr, Float: float32(v), - } + }, nil default: cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e)) } @@ -635,60 +646,76 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) shaderir.Expr { op = shaderir.OrOr default: cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op)) - return shaderir.Expr{} + return shaderir.Expr{}, nil } + + var stmts []shaderir.Stmt + + // Prase RHS first for the order of the statements. + rhs, ss := cs.parseExpr(block, e.Y) + stmts = append(stmts, ss...) + lhs, ss := cs.parseExpr(block, e.X) + stmts = append(stmts, ss...) + return shaderir.Expr{ - Type: shaderir.Binary, - Op: op, - Exprs: []shaderir.Expr{ - cs.parseExpr(block, e.X), - cs.parseExpr(block, e.Y), - }, - } + Type: shaderir.Binary, + Op: op, + Exprs: []shaderir.Expr{lhs, rhs}, + }, stmts case *ast.CallExpr: - exprs := []shaderir.Expr{ - cs.parseExpr(block, e.Fun), - } + var exprs []shaderir.Expr + var stmts []shaderir.Stmt + + // Parse the argument first for the order of the statements. for _, a := range e.Args { - e := cs.parseExpr(block, a) + e, ss := cs.parseExpr(block, a) // TODO: Convert integer literals to float literals if necessary. exprs = append(exprs, e) + stmts = append(stmts, ss...) } + + // TODO: When len(stmts) is not 0? + expr, ss := cs.parseExpr(block, e.Fun) + exprs = append([]shaderir.Expr{expr}, exprs...) + stmts = append(stmts, ss...) + + // TODO: Return statements to call the function separately. return shaderir.Expr{ Type: shaderir.Call, Exprs: exprs, - } + }, stmts case *ast.Ident: if i, ok := block.findLocalVariable(e.Name); ok { return shaderir.Expr{ Type: shaderir.LocalVariable, Index: i, - } + }, nil } if i, ok := cs.findUniformVariable(e.Name); ok { return shaderir.Expr{ Type: shaderir.UniformVariable, Index: i, - } + }, nil } if f, ok := shaderir.ParseBuiltinFunc(e.Name); ok { return shaderir.Expr{ Type: shaderir.BuiltinFuncExpr, BuiltinFunc: f, - } + }, nil } cs.addError(e.Pos(), fmt.Sprintf("unexpected identifier: %s", e.Name)) case *ast.SelectorExpr: + expr, stmts := cs.parseExpr(block, e.X) return shaderir.Expr{ Type: shaderir.FieldSelector, Exprs: []shaderir.Expr{ - cs.parseExpr(block, e.X), + expr, { Type: shaderir.SwizzlingExpr, Swizzling: e.Sel.Name, }, }, - } + }, stmts case *ast.UnaryExpr: var op shaderir.Op switch e.Op { @@ -700,17 +727,16 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) shaderir.Expr { op = shaderir.NotOp default: cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op)) - return shaderir.Expr{} + return shaderir.Expr{}, nil } + expr, stmts := cs.parseExpr(block, e.X) return shaderir.Expr{ - Type: shaderir.Unary, - Op: op, - Exprs: []shaderir.Expr{ - cs.parseExpr(block, e.X), - }, - } + Type: shaderir.Unary, + Op: op, + Exprs: []shaderir.Expr{expr}, + }, stmts default: cs.addError(e.Pos(), fmt.Sprintf("expression not implemented: %#v", e)) } - return shaderir.Expr{} + return shaderir.Expr{}, nil }