diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 918413227..986964e9f 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -724,6 +724,15 @@ func (cs *compileState) parseBlock(outer *block, fname string, stmts []ast.Stmt, } }() + if outer.outer == nil && len(outParams) > 0 && outParams[0].name != "" { + for i := range outParams { + block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ + Type: shaderir.Init, + InitIndex: len(inParams) + i, + }) + } + } + for _, stmt := range stmts { ss, ok := cs.parseStmt(block, fname, stmt, inParams, outParams) if !ok { diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 0e050da5e..ebcc5375d 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -379,9 +379,11 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP case *ast.ReturnStmt: if len(stmt.Results) != len(outParams) { - // TODO: Implenet multiple-value context. - cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(stmt.Results))) - return nil, false + if !(len(stmt.Results) == 0 && len(outParams) > 0 && outParams[0].name != "") { + // TODO: Implenet multiple-value context. + cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(stmt.Results))) + return nil, false + } } for i, r := range stmt.Results { diff --git a/internal/shader/testdata/out.expected.metal b/internal/shader/testdata/out.expected.metal new file mode 100644 index 000000000..70c3652ff --- /dev/null +++ b/internal/shader/testdata/out.expected.metal @@ -0,0 +1,8 @@ +void F0(thread float& l0, thread array& l1, thread float4& l2); + +void F0(thread float& l0, thread array& l1, thread float4& l2) { + l0 = float(0); + l1 = {}; + l2 = float4(0); + return; +} diff --git a/internal/shader/testdata/out.expected.vs b/internal/shader/testdata/out.expected.vs new file mode 100644 index 000000000..2a13481ba --- /dev/null +++ b/internal/shader/testdata/out.expected.vs @@ -0,0 +1,11 @@ +void F0(out float l0, out float l1[4], out vec4 l2); + +void F0(out float l0, out float l1[4], out vec4 l2) { + l0 = float(0); + l1[0] = float(0); + l1[1] = float(0); + l1[2] = float(0); + l1[3] = float(0); + l2 = vec4(0); + return; +} diff --git a/internal/shader/testdata/out.go b/internal/shader/testdata/out.go new file mode 100644 index 000000000..1e8a8ba85 --- /dev/null +++ b/internal/shader/testdata/out.go @@ -0,0 +1,5 @@ +package main + +func Foo() (a float, b [4]float, c vec4) { + return +} diff --git a/internal/shader/testdata/return.expected.vs b/internal/shader/testdata/return.expected.vs new file mode 100644 index 000000000..10d6a37bf --- /dev/null +++ b/internal/shader/testdata/return.expected.vs @@ -0,0 +1,8 @@ +void F0(in vec2 l0, out vec3 l1, out vec4 l2); + +void F0(in vec2 l0, out vec3 l1, out vec4 l2) { + l1 = vec3(0); + l2 = vec4(0); + l1 = vec3(1.0); + return; +} diff --git a/internal/shader/testdata/return.go b/internal/shader/testdata/return.go new file mode 100644 index 000000000..4a78cc926 --- /dev/null +++ b/internal/shader/testdata/return.go @@ -0,0 +1,6 @@ +package main + +func Foo(a vec2) (b vec3, c vec4) { + b = vec3(1) + return +} diff --git a/internal/shader/testdata/vertex.expected.vs b/internal/shader/testdata/vertex.expected.vs index da841bca3..7ef00ba3c 100644 --- a/internal/shader/testdata/vertex.expected.vs +++ b/internal/shader/testdata/vertex.expected.vs @@ -7,6 +7,9 @@ varying vec4 V1; void main(void) { mat4 l0 = mat4(0); + gl_Position = vec4(0); + V0 = vec2(0); + V1 = vec4(0); l0 = mat4((2.0) / ((U0).x), 0.0, 0.0, 0.0, 0.0, (2.0) / ((U0).y), 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0); gl_Position = (l0) * (vec4(A0, 0.0, 1.0)); V0 = A1; diff --git a/internal/shader/testdata/vertex_fragment.expected.vs b/internal/shader/testdata/vertex_fragment.expected.vs index da841bca3..7ef00ba3c 100644 --- a/internal/shader/testdata/vertex_fragment.expected.vs +++ b/internal/shader/testdata/vertex_fragment.expected.vs @@ -7,6 +7,9 @@ varying vec4 V1; void main(void) { mat4 l0 = mat4(0); + gl_Position = vec4(0); + V0 = vec2(0); + V1 = vec4(0); l0 = mat4((2.0) / ((U0).x), 0.0, 0.0, 0.0, 0.0, (2.0) / ((U0).y), 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0); gl_Position = (l0) * (vec4(A0, 0.0, 1.0)); V0 = A1; diff --git a/internal/shaderir/glsl/glsl.go b/internal/shaderir/glsl/glsl.go index bd779d33e..2a998c318 100644 --- a/internal/shaderir/glsl/glsl.go +++ b/internal/shaderir/glsl/glsl.go @@ -414,6 +414,8 @@ func (c *compileContext) glslBlock(p *shaderir.Program, topBlock, block *shaderi } } lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, glslExpr(&lhs), glslExpr(&rhs))) + case shaderir.Init: + lines = append(lines, c.initVariable(p, topBlock, block, s.InitIndex, false, level)...) case shaderir.If: lines = append(lines, fmt.Sprintf("%sif (%s) {", idt, glslExpr(&s.Exprs[0]))) lines = append(lines, c.glslBlock(p, topBlock, s.Blocks[0], level+1)...) diff --git a/internal/shaderir/metal/metal.go b/internal/shaderir/metal/metal.go index 46cf210dd..69d21826b 100644 --- a/internal/shaderir/metal/metal.go +++ b/internal/shaderir/metal/metal.go @@ -313,10 +313,10 @@ func localVariableName(p *shaderir.Program, topBlock *shaderir.Block, idx int) s func (c *compileContext) initVariable(p *shaderir.Program, topBlock, block *shaderir.Block, index int, decl bool, level int) []string { idt := strings.Repeat("\t", level+1) + name := localVariableName(p, topBlock, index) t := p.LocalVariableType(topBlock, block, index) var lines []string - name := localVariableName(p, topBlock, index) if decl { lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, c.metalVarDecl(p, &t, name, false, false), c.metalVarInit(p, &t))) } else { @@ -412,6 +412,20 @@ func (c *compileContext) metalBlock(p *shaderir.Program, topBlock, block *shader lines = append(lines, idt+"}") case shaderir.Assign: lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, metalExpr(&s.Exprs[0]), metalExpr(&s.Exprs[1]))) + case shaderir.Init: + init := true + if topBlock == p.VertexFunc.Block { + // In the vertex function, varying values are the output parameters. + // These values are represented as a struct and not needed to be initialized. + na := len(p.Attributes) + nv := len(p.Varyings) + if s.InitIndex < na+nv+1 { + init = false + } + } + if init { + lines = append(lines, c.initVariable(p, topBlock, block, s.InitIndex, false, level)...) + } case shaderir.If: lines = append(lines, fmt.Sprintf("%sif (%s) {", idt, metalExpr(&s.Exprs[0]))) lines = append(lines, c.metalBlock(p, topBlock, s.Blocks[0], level+1)...) diff --git a/internal/shaderir/program.go b/internal/shaderir/program.go index b30867385..8fe9ec241 100644 --- a/internal/shaderir/program.go +++ b/internal/shaderir/program.go @@ -73,6 +73,7 @@ type Stmt struct { ForEnd constant.Value ForOp Op ForDelta constant.Value + InitIndex int } type StmtType int @@ -81,6 +82,7 @@ const ( ExprStmt StmtType = iota BlockStmt Assign + Init If For Continue