diff --git a/internal/shader/shader.go b/internal/shader/shader.go index da27f4316..365ef636f 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -608,47 +608,7 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out return nil, false } - // TODO: Reduce calls of parseExpr - - var rhsTypes []shaderir.Type - for i, e := range l.Lhs { - v := variable{ - name: e.(*ast.Ident).Name, - } - if len(l.Lhs) == len(l.Rhs) { - ts, ok := cs.functionReturnTypes(block, l.Rhs[i]) - if !ok { - _, ts, _, ok = cs.parseExpr(block, l.Rhs[i]) - if !ok { - return nil, false - } - } - if len(ts) > 1 { - cs.addError(l.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) - } - if len(ts) == 1 { - v.typ = ts[0] - } - } else { - if i == 0 { - var ok bool - rhsTypes, ok = cs.functionReturnTypes(block, l.Rhs[0]) - if !ok { - _, rhsTypes, _, ok = cs.parseExpr(block, l.Rhs[0]) - if !ok { - return nil, false - } - } - if len(rhsTypes) != len(l.Lhs) { - cs.addError(l.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) - } - } - v.typ = rhsTypes[i] - } - block.vars = append(block.vars, v) - } - - if !cs.assign(block, l.Pos(), l.Lhs, l.Rhs) { + if !cs.assign(block, l.Pos(), l.Lhs, l.Rhs, true) { return nil, false } case token.ASSIGN: @@ -657,7 +617,7 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out cs.addError(l.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) return nil, false } - if !cs.assign(block, l.Pos(), l.Lhs, l.Rhs) { + if !cs.assign(block, l.Pos(), l.Lhs, l.Rhs, false) { return nil, false } case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN: @@ -768,21 +728,43 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out return block, true } -func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr) bool { +func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr, define bool) bool { var rhsExprs []shaderir.Expr - for i := range lhs { + var rhsTypes []shaderir.Type + + for i, e := range lhs { // Prase RHS first for the order of the statements. if len(lhs) == len(rhs) { - rhs, _, stmts, ok := cs.parseExpr(block, rhs[i]) + r, origts, stmts, ok := cs.parseExpr(block, rhs[i]) if !ok { return false } - if len(rhs) > 1 { - cs.addError(pos, fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) - } block.ir.Stmts = append(block.ir.Stmts, stmts...) - lhs, _, stmts, ok := cs.parseExpr(block, lhs[i]) + if define { + v := variable{ + name: e.(*ast.Ident).Name, + } + ts, ok := cs.functionReturnTypes(block, rhs[i]) + if !ok { + ts = origts + } + if len(ts) > 1 { + cs.addError(pos, fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) + return false + } + if len(ts) == 1 { + v.typ = ts[0] + } + block.vars = append(block.vars, v) + } + + if len(r) > 1 { + cs.addError(pos, fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) + return false + } + + l, _, stmts, ok := cs.parseExpr(block, lhs[i]) if !ok { return false } @@ -790,13 +772,13 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr) block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ Type: shaderir.Assign, - Exprs: []shaderir.Expr{lhs[0], rhs[0]}, + Exprs: []shaderir.Expr{l[0], r[0]}, }) } else { if i == 0 { var stmts []shaderir.Stmt var ok bool - rhsExprs, _, stmts, ok = cs.parseExpr(block, rhs[0]) + rhsExprs, rhsTypes, stmts, ok = cs.parseExpr(block, rhs[0]) if !ok { return false } @@ -806,7 +788,15 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr) block.ir.Stmts = append(block.ir.Stmts, stmts...) } - lhs, _, stmts, ok := cs.parseExpr(block, lhs[i]) + if define { + v := variable{ + name: e.(*ast.Ident).Name, + } + v.typ = rhsTypes[i] + block.vars = append(block.vars, v) + } + + l, _, stmts, ok := cs.parseExpr(block, lhs[i]) if !ok { return false } @@ -814,7 +804,7 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr) block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ Type: shaderir.Assign, - Exprs: []shaderir.Expr{lhs[0], rhsExprs[i]}, + Exprs: []shaderir.Expr{l[0], rhsExprs[i]}, }) } } diff --git a/internal/shader/testdata/define2.expected.vs b/internal/shader/testdata/define2.expected.vs index 4fb5c8e2a..e1f67f63a 100644 --- a/internal/shader/testdata/define2.expected.vs +++ b/internal/shader/testdata/define2.expected.vs @@ -3,12 +3,10 @@ void F0(out vec2 l0) { vec2 l2 = vec2(0); vec2 l3 = vec2(0); vec2 l4 = vec2(0); - vec2 l5 = vec2(0); - vec2 l6 = vec2(0); + F1(l1); + l2 = (1.0) * (l1); F1(l3); - l2 = (1.0) * (l3); - F1(l6); - l5 = (l6) * (1.0); + l4 = (l3) * (1.0); l0 = l2; return; } diff --git a/internal/shader/testdata/define_multiple.expected.vs b/internal/shader/testdata/define_multiple.expected.vs index dfeb5c92b..6676720bb 100644 --- a/internal/shader/testdata/define_multiple.expected.vs +++ b/internal/shader/testdata/define_multiple.expected.vs @@ -3,10 +3,10 @@ void F0(in vec2 l0, out vec4 l1) { float l3 = float(0); float l4 = float(0); float l5 = float(0); - F1(l4, l5); - l2 = l4; - l3 = l5; - l1 = vec4(l0, l2, l2); + F1(l2, l3); + l4 = l2; + l5 = l3; + l1 = vec4(l0, l4, l4); return; }