diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index ebcc5375d..80fae51da 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -378,9 +378,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP }) case *ast.ReturnStmt: - if len(stmt.Results) != len(outParams) { + if len(stmt.Results) != len(outParams) && len(stmt.Results) != 1 { if !(len(stmt.Results) == 0 && len(outParams) > 0 && outParams[0].name != "") { - // TODO: Implenet multiple-value context. + // TODO: Check variable shadowings. + // https://golang.org/ref/spec#Return_statements cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(stmt.Results))) return nil, false } @@ -393,42 +394,51 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } stmts = append(stmts, ss...) if len(exprs) == 0 { + if len(exprs) != len(outParams) { + cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(stmt.Results))) + } continue } if len(exprs) > 1 { - cs.addError(r.Pos(), "multiple-value context with return is not implemented yet") - continue - } - - t := ts[0] - expr := exprs[0] - if expr.Type == shaderir.NumberExpr { - switch outParams[i].typ.Main { - case shaderir.Int: - if !cs.forceToInt(stmt, &expr) { - return nil, false - } - t = shaderir.Type{Main: shaderir.Int} - case shaderir.Float: - t = shaderir.Type{Main: shaderir.Float} + if len(stmt.Results) > 1 || len(outParams) == 1 { + cs.addError(r.Pos(), "single-value context and multiple-value context cannot be mixed") + return nil, false + } + if len(exprs) != len(outParams) { + cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(stmt.Results))) } } - if !t.Equal(&outParams[i].typ) { - cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", &t, &outParams[i].typ)) - return nil, false - } + for j, t := range ts { + expr := exprs[j] + if expr.Type == shaderir.NumberExpr { + switch outParams[i+j].typ.Main { + case shaderir.Int: + if !cs.forceToInt(stmt, &expr) { + return nil, false + } + t = shaderir.Type{Main: shaderir.Int} + case shaderir.Float: + t = shaderir.Type{Main: shaderir.Float} + } + } - stmts = append(stmts, shaderir.Stmt{ - Type: shaderir.Assign, - Exprs: []shaderir.Expr{ - { - Type: shaderir.LocalVariable, - Index: len(inParams) + i, + if !t.Equal(&outParams[i+j].typ) { + cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", &t, &outParams[i].typ)) + return nil, false + } + + stmts = append(stmts, shaderir.Stmt{ + Type: shaderir.Assign, + Exprs: []shaderir.Expr{ + { + Type: shaderir.LocalVariable, + Index: len(inParams) + i + j, + }, + expr, }, - expr, - }, - }) + }) + } } stmts = append(stmts, shaderir.Stmt{ Type: shaderir.Return, diff --git a/internal/shader/testdata/out2.expected.vs b/internal/shader/testdata/out2.expected.vs new file mode 100644 index 000000000..a2b096998 --- /dev/null +++ b/internal/shader/testdata/out2.expected.vs @@ -0,0 +1,36 @@ +void F0(out float l0, out float l1[4], out vec4 l2); +void F1(out float l0, out float l1[4], out vec4 l2); + +void F0(out float l0, out float l1[4], out vec4 l2) { + l0 = float(0); + l1[0] = float(0); + l1[1] = float(0); + l1[2] = float(0); + l1[3] = float(0); + l2 = vec4(0); + return; +} + +void F1(out float l0, out float l1[4], out vec4 l2) { + float l3 = float(0); + float l4[4]; + l4[0] = float(0); + l4[1] = float(0); + l4[2] = float(0); + l4[3] = float(0); + vec4 l5 = vec4(0); + l0 = float(0); + l1[0] = float(0); + l1[1] = float(0); + l1[2] = float(0); + l1[3] = float(0); + l2 = vec4(0); + F0(l3, l4, l5); + l0 = l3; + l1[0] = l4[0]; + l1[1] = l4[1]; + l1[2] = l4[2]; + l1[3] = l4[3]; + l2 = l5; + return; +} diff --git a/internal/shader/testdata/out2.go b/internal/shader/testdata/out2.go new file mode 100644 index 000000000..ae0ce078c --- /dev/null +++ b/internal/shader/testdata/out2.go @@ -0,0 +1,9 @@ +package main + +func Foo() (a float, b [4]float, c vec4) { + return +} + +func Foo2() (a float, b [4]float, c vec4) { + return Foo() +}