diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 365ef636f..7c1ad1992 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -365,44 +365,70 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl stmts []shaderir.Stmt ) - for i, n := range vs.Names { - // TODO: Reduce calls of parseExpr + // These variables are used only in multiple-value context. + var inittypes []shaderir.Type + var initexprs []shaderir.Expr - var init ast.Expr + for i, n := range vs.Names { t := declt - switch len(vs.Values) { - case 0: - case 1: - init = vs.Values[0] - if t.Main == shaderir.None { - ts, ok := s.functionReturnTypes(block, init) - if !ok { - _, ts, _, ok = s.parseExpr(block, init) - if !ok { - return nil, nil, nil, false - } - } - if len(ts) != len(vs.Names) { - s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match")) - continue - } - t = ts[i] + switch { + case len(vs.Values) == 0: + // No initialization + + case len(vs.Names) == len(vs.Values): + // Single-value context + + init := vs.Values[i] + + es, origts, ss, ok := s.parseExpr(block, init) + if !ok { + return nil, nil, nil, false } - default: - init = vs.Values[i] + inits = append(inits, es...) + stmts = append(stmts, ss...) + if t.Main == shaderir.None { ts, ok := s.functionReturnTypes(block, init) if !ok { - _, ts, _, ok = s.parseExpr(block, init) - if !ok { - return nil, nil, nil, false - } + ts = origts } if len(ts) > 1 { s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match")) } t = ts[0] } + + default: + // Multiple-value context + + if i == 0 { + init := vs.Values[0] + + var ss []shaderir.Stmt + var ok bool + initexprs, inittypes, ss, ok = s.parseExpr(block, init) + if !ok { + return nil, nil, nil, false + } + stmts = append(stmts, ss...) + + if t.Main == shaderir.None { + ts, ok := s.functionReturnTypes(block, init) + if ok { + inittypes = ts + } + if len(ts) != len(vs.Names) { + s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match")) + continue + } + } + } + if len(inittypes) > 0 { + t = inittypes[i] + } + + // Add the same initexprs for each variable. + inits = append(inits, initexprs...) } name := n.Name @@ -410,20 +436,6 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl name: name, typ: t, }) - - if len(vs.Values) > 1 || (len(vs.Values) == 1 && len(inits) == 0) { - es, _, ss, ok := s.parseExpr(block, init) - if !ok { - return nil, nil, nil, false - } - inits = append(inits, es...) - stmts = append(stmts, ss...) - } - } - - if len(inits) > 0 && len(vars) != len(inits) { - s.addError(vs.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) - return nil, nil, nil, false } return vars, inits, stmts, true diff --git a/internal/shader/testdata/number.expected.vs b/internal/shader/testdata/number.expected.vs new file mode 100644 index 000000000..29c5561e6 --- /dev/null +++ b/internal/shader/testdata/number.expected.vs @@ -0,0 +1,43 @@ +void F0(out vec2 l0) { + float l1 = float(0); + int l2 = 0; + float l3 = float(0); + int l4 = 0; + float l5 = float(0); + float l6 = float(0); + F2(l1); + F3(l2); + l3 = (l1) * (l2); + F3(l4); + F2(l5); + l6 = (l4) * (l5); + l0 = vec2(l3, l6); + return; +} + +void F1(out vec2 l0) { + float l1 = float(0); + int l2 = 0; + float l3 = float(0); + int l4 = 0; + float l5 = float(0); + float l6 = float(0); + F2(l1); + F3(l2); + l3 = (l1) * (l2); + F3(l4); + F2(l5); + l6 = (l4) * (l5); + l0 = vec2(l3, l6); + return; +} + +void F2(out float l0) { + l0 = 1.0; + return; +} + +void F3(out int l0) { + l0 = 1.0; + return; +} diff --git a/internal/shader/testdata/number.go b/internal/shader/testdata/number.go new file mode 100644 index 000000000..c99279cd2 --- /dev/null +++ b/internal/shader/testdata/number.go @@ -0,0 +1,21 @@ +package main + +func Foo() vec2 { + x := Float() * Int() + y := Int() * Float() + return vec2(x, y) +} + +func Foo2() vec2 { + var x = Float() * Int() + var y = Int() * Float() + return vec2(x, y) +} + +func Float() float { + return 1.0 +} + +func Int() int { + return 1 +}