diff --git a/internal/shaderir/glsl/glsl.go b/internal/shaderir/glsl/glsl.go index 0d1ef1cd6..31b906e84 100644 --- a/internal/shaderir/glsl/glsl.go +++ b/internal/shaderir/glsl/glsl.go @@ -273,68 +273,35 @@ func constantToNumberLiteral(t shaderir.ConstType, v constant.Value) string { return fmt.Sprintf("?(unexpected literal: %s)", v) } -func descendantLocalVars(block, target *shaderir.Block) ([]shaderir.Type, bool) { - if block == target { - return block.LocalVars, true - } - - var ts []shaderir.Type - for _, s := range block.Stmts { - for _, b := range s.Blocks { - if ts2, found := descendantLocalVars(b, target); found { - n := b.LocalVarIndexOffset - block.LocalVarIndexOffset - ts = append(ts, block.LocalVars[:n]...) - ts = append(ts, ts2...) - return ts, true - } - } - } - return nil, false -} - -func localVariableType(p *shaderir.Program, topBlock, block *shaderir.Block, absidx int) shaderir.Type { - var ts []shaderir.Type - for _, f := range p.Funcs { - if f.Block == topBlock { - ts = append(f.InParams, f.OutParams...) - break - } - } - - ts2, _ := descendantLocalVars(topBlock, block) - ts = append(ts, ts2...) - return ts[absidx] -} - -func localVariable(p *shaderir.Program, topBlock, block *shaderir.Block, idx int) (string, shaderir.Type) { +func localVariableName(p *shaderir.Program, topBlock, block *shaderir.Block, idx int) string { switch topBlock { case p.VertexFunc.Block: na := len(p.Attributes) nv := len(p.Varyings) switch { case idx < na: - return fmt.Sprintf("A%d", idx), p.Attributes[idx] + return fmt.Sprintf("A%d", idx) case idx == na: - return "gl_Position", shaderir.Type{Main: shaderir.Vec4} + return "gl_Position" case idx < na+nv+1: - return fmt.Sprintf("V%d", idx-na-1), p.Varyings[idx-na-1] + return fmt.Sprintf("V%d", idx-na-1) default: - return fmt.Sprintf("l%d", idx-(na+nv+1)), localVariableType(p, topBlock, block, idx-(na+nv+1)) + return fmt.Sprintf("l%d", idx-(na+nv+1)) } case p.FragmentFunc.Block: nv := len(p.Varyings) switch { case idx == 0: - return "gl_FragCoord", shaderir.Type{Main: shaderir.Vec4} + return "gl_FragCoord" case idx < nv+1: - return fmt.Sprintf("V%d", idx-1), p.Varyings[idx-1] + return fmt.Sprintf("V%d", idx-1) case idx == nv+1: - return "gl_FragColor", shaderir.Type{Main: shaderir.Vec4} + return "gl_FragColor" default: - return fmt.Sprintf("l%d", idx-(nv+2)), localVariableType(p, topBlock, block, idx-(nv+2)) + return fmt.Sprintf("l%d", idx-(nv+2)) } default: - return fmt.Sprintf("l%d", idx), localVariableType(p, topBlock, block, idx) + return fmt.Sprintf("l%d", idx) } } @@ -347,7 +314,8 @@ func (c *compileContext) glslBlock(p *shaderir.Program, topBlock, block *shaderi var lines []string for i := range block.LocalVars { - name, t := localVariable(p, topBlock, block, block.LocalVarIndexOffset+i) + name := localVariableName(p, topBlock, block, block.LocalVarIndexOffset+i) + t := p.LocalVariableType(topBlock, block, block.LocalVarIndexOffset+i) switch t.Main { case shaderir.Array: lines = append(lines, fmt.Sprintf("%s%s;", idt, c.glslVarDecl(p, &t, name))) @@ -372,8 +340,7 @@ func (c *compileContext) glslBlock(p *shaderir.Program, topBlock, block *shaderi case shaderir.TextureVariable: return fmt.Sprintf("T%d", e.Index) case shaderir.LocalVariable: - n, _ := localVariable(p, topBlock, block, e.Index) - return n + return localVariableName(p, topBlock, block, e.Index) case shaderir.StructMember: return fmt.Sprintf("M%d", e.Index) case shaderir.BuiltinFuncExpr: @@ -426,7 +393,7 @@ func (c *compileContext) glslBlock(p *shaderir.Program, topBlock, block *shaderi lhs := s.Exprs[0] rhs := s.Exprs[1] if lhs.Type == shaderir.LocalVariable { - if _, t := localVariable(p, topBlock, block, lhs.Index); t.Main == shaderir.Array { + if t := p.LocalVariableType(topBlock, block, lhs.Index); t.Main == shaderir.Array { for i := 0; i < t.Length; i++ { lines = append(lines, fmt.Sprintf("%[1]s%[2]s[%[3]d] = %[4]s[%[3]d];", idt, glslExpr(&lhs), i, glslExpr(&rhs))) } @@ -451,7 +418,7 @@ func (c *compileContext) glslBlock(p *shaderir.Program, topBlock, block *shaderi ct = shaderir.ConstTypeFloat } - v, _ := localVariable(p, topBlock, block, s.ForVarIndex) + v := localVariableName(p, topBlock, block, s.ForVarIndex) var delta string switch val, _ := constant.Float64Val(s.ForDelta); val { case 0: diff --git a/internal/shaderir/metal/metal.go b/internal/shaderir/metal/metal.go index 75c970214..46cf210dd 100644 --- a/internal/shaderir/metal/metal.go +++ b/internal/shaderir/metal/metal.go @@ -311,6 +311,20 @@ func localVariableName(p *shaderir.Program, topBlock *shaderir.Block, idx int) s } } +func (c *compileContext) initVariable(p *shaderir.Program, topBlock, block *shaderir.Block, index int, decl bool, level int) []string { + idt := strings.Repeat("\t", level+1) + t := p.LocalVariableType(topBlock, block, index) + + var lines []string + name := localVariableName(p, topBlock, index) + if decl { + lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, c.metalVarDecl(p, &t, name, false, false), c.metalVarInit(p, &t))) + } else { + lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, name, c.metalVarInit(p, &t))) + } + return lines +} + func (c *compileContext) metalBlock(p *shaderir.Program, topBlock, block *shaderir.Block, level int) []string { if block == nil { return nil @@ -322,8 +336,7 @@ func (c *compileContext) metalBlock(p *shaderir.Program, topBlock, block *shader for i, t := range block.LocalVars { // The type is None e.g., when the variable is a for-loop counter. if t.Main != shaderir.None { - name := localVariableName(p, topBlock, block.LocalVarIndexOffset+i) - lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, c.metalVarDecl(p, &t, name, false, false), c.metalVarInit(p, &t))) + lines = append(lines, c.initVariable(p, topBlock, block, block.LocalVarIndexOffset+i, true, level)...) } } diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index a3255b4d7..05760d557 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -120,3 +120,69 @@ const ( Array Struct ) + +func descendantLocalVars(block, target *Block) ([]Type, bool) { + if block == target { + return block.LocalVars, true + } + + var ts []Type + for _, s := range block.Stmts { + for _, b := range s.Blocks { + if ts2, found := descendantLocalVars(b, target); found { + n := b.LocalVarIndexOffset - block.LocalVarIndexOffset + ts = append(ts, block.LocalVars[:n]...) + ts = append(ts, ts2...) + return ts, true + } + } + } + return nil, false +} + +func localVariableType(p *Program, topBlock, block *Block, absidx int) Type { + // TODO: Rename this function (truly-local variable?) + var ts []Type + for _, f := range p.Funcs { + if f.Block == topBlock { + ts = append(f.InParams, f.OutParams...) + break + } + } + + ts2, _ := descendantLocalVars(topBlock, block) + ts = append(ts, ts2...) + return ts[absidx] +} + +func (p *Program) LocalVariableType(topBlock, block *Block, idx int) Type { + switch topBlock { + case p.VertexFunc.Block: + na := len(p.Attributes) + nv := len(p.Varyings) + switch { + case idx < na: + return p.Attributes[idx] + case idx == na: + return Type{Main: Vec4} + case idx < na+nv+1: + return p.Varyings[idx-na-1] + default: + return localVariableType(p, topBlock, block, idx-(na+nv+1)) + } + case p.FragmentFunc.Block: + nv := len(p.Varyings) + switch { + case idx == 0: + return Type{Main: Vec4} + case idx < nv+1: + return p.Varyings[idx-1] + case idx == nv+1: + return Type{Main: Vec4} + default: + return localVariableType(p, topBlock, block, idx-(nv+2)) + } + default: + return localVariableType(p, topBlock, block, idx) + } +}