internal/graphicsdriver/directx: reuse a compiled vertex shader if possible

D3DCompile can be very slow, so skip the call whenever an already-compiled vertex shader can be reused.
This commit is contained in:
Hajime Hoshi 2022-12-24 16:12:35 +09:00
parent 88fe3d1287
commit fda0b1cbcb
5 changed files with 128 additions and 54 deletions

View File

@ -2309,6 +2309,11 @@ type _ID3DBlob_Vtbl struct {
GetBufferSize uintptr GetBufferSize uintptr
} }
// AddRef calls the AddRef entry of the blob's vtable, passing the receiver as
// the only argument, and returns the call's first result as a uint32.
// The error value reported by syscall.Syscall is ignored.
// NOTE(review): presumably the returned value is the new COM reference count — confirm against the IUnknown::AddRef documentation.
func (i *_ID3DBlob) AddRef() uint32 {
r, _, _ := syscall.Syscall(i.vtbl.AddRef, 1, uintptr(unsafe.Pointer(i)), 0, 0)
return uint32(r)
}
func (i *_ID3DBlob) GetBufferPointer() uintptr { func (i *_ID3DBlob) GetBufferPointer() uintptr {
r, _, _ := syscall.Syscall(i.vtbl.GetBufferPointer, 1, uintptr(unsafe.Pointer(i)), r, _, _ := syscall.Syscall(i.vtbl.GetBufferPointer, 1, uintptr(unsafe.Pointer(i)),
0, 0) 0, 0)

View File

@ -1214,8 +1214,8 @@ func (g *Graphics) MaxImageSize() int {
} }
func (g *Graphics) NewShader(program *shaderir.Program) (graphicsdriver.Shader, error) { func (g *Graphics) NewShader(program *shaderir.Program) (graphicsdriver.Shader, error) {
src, offsets := hlsl.Compile(program) vs, ps, offsets := hlsl.Compile(program)
vsh, psh, err := newShader([]byte(src), nil) vsh, psh, err := newShader(vs, ps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1742,7 +1742,14 @@ func (s *Shader) disposeImpl() {
s.pixelShader = nil s.pixelShader = nil
} }
if s.vertexShader != nil { if s.vertexShader != nil {
s.vertexShader.Release() count := s.vertexShader.Release()
if count == 0 {
for k, v := range vertexShaderCache {
if v == s.vertexShader {
delete(vertexShaderCache, k)
}
}
}
s.vertexShader = nil s.vertexShader = nil
} }
} }

View File

@ -376,7 +376,9 @@ func (p *pipelineStates) ensureRootSignature(device *_ID3D12Device) (rootSignatu
return p.rootSignature, nil return p.rootSignature, nil
} }
func newShader(source []byte, defs []_D3D_SHADER_MACRO) (vsh, psh *_ID3DBlob, ferr error) { var vertexShaderCache = map[string]*_ID3DBlob{}
func newShader(vs, ps string) (vsh, psh *_ID3DBlob, ferr error) {
var flag uint32 = uint32(_D3DCOMPILE_OPTIMIZATION_LEVEL3) var flag uint32 = uint32(_D3DCOMPILE_OPTIMIZATION_LEVEL3)
defer func() { defer func() {
@ -392,18 +394,28 @@ func newShader(source []byte, defs []_D3D_SHADER_MACRO) (vsh, psh *_ID3DBlob, fe
}() }()
var wg errgroup.Group var wg errgroup.Group
wg.Go(func() error {
v, err := _D3DCompile(source, "shader", defs, nil, "VSMain", "vs_5_0", flag, 0) // Vertex shaders are likely the same. If so, reuse the same _ID3DBlob.
if err != nil { if v, ok := vertexShaderCache[vs]; ok {
return fmt.Errorf("directx: D3DCompile for VSMain failed, original source: %s, %w", string(source), err) // Increment the reference count not to release this object unexpectedly.
} // The value will be removed when the count reaches 0.
// See (*Shader).disposeImpl.
v.AddRef()
vsh = v vsh = v
return nil } else {
}) wg.Go(func() error {
v, err := _D3DCompile([]byte(vs), "shader", nil, nil, "VSMain", "vs_5_0", flag, 0)
if err != nil {
return fmt.Errorf("directx: D3DCompile for VSMain failed, original source: %s, %w", vs, err)
}
vsh = v
return nil
})
}
wg.Go(func() error { wg.Go(func() error {
p, err := _D3DCompile(source, "shader", defs, nil, "PSMain", "ps_5_0", flag, 0) p, err := _D3DCompile([]byte(ps), "shader", nil, nil, "PSMain", "ps_5_0", flag, 0)
if err != nil { if err != nil {
return fmt.Errorf("directx: D3DCompile for PSMain failed, original source: %s, %w", string(source), err) return fmt.Errorf("directx: D3DCompile for PSMain failed, original source: %s, %w", ps, err)
} }
psh = p psh = p
return nil return nil
@ -413,6 +425,8 @@ func newShader(source []byte, defs []_D3D_SHADER_MACRO) (vsh, psh *_ID3DBlob, fe
return nil, nil, err return nil, nil, err
} }
vertexShaderCache[vs] = vsh
return return
} }

View File

@ -194,8 +194,8 @@ func TestCompile(t *testing.T) {
} }
if tc.HLSL != nil { if tc.HLSL != nil {
h, _ := hlsl.Compile(s) vs, _, _ := hlsl.Compile(s)
if got, want := hlslNormalize(h), hlslNormalize(string(tc.HLSL)); got != want { if got, want := hlslNormalize(vs), hlslNormalize(string(tc.HLSL)); got != want {
compare(t, "HLSL", got, want) compare(t, "HLSL", got, want)
} }
} }

View File

@ -84,16 +84,16 @@ float4x4 float4x4FromScalar(float x) {
return float4x4(x, 0, 0, 0, 0, x, 0, 0, 0, 0, x, 0, 0, 0, 0, x); return float4x4(x, 0, 0, 0, 0, x, 0, 0, 0, 0, x, 0, 0, 0, 0, x);
}` }`
func Compile(p *shaderir.Program) (string, []int) { func Compile(p *shaderir.Program) (vertexShader, pixelShader string, offsets []int) {
offsets = calculateMemoryOffsets(p.Uniforms)
c := &compileContext{} c := &compileContext{}
var lines []string var lines []string
lines = append(lines, strings.Split(Prelude, "\n")...) lines = append(lines, strings.Split(Prelude, "\n")...)
lines = append(lines, "", "{{.Structs}}") lines = append(lines, "", "{{.Structs}}")
var offsets []int
if len(p.Uniforms) > 0 { if len(p.Uniforms) > 0 {
offsets = calculateMemoryOffsets(p.Uniforms)
lines = append(lines, "") lines = append(lines, "")
lines = append(lines, "cbuffer Uniforms : register(b0) {") lines = append(lines, "cbuffer Uniforms : register(b0) {")
for i, t := range p.Uniforms { for i, t := range p.Uniforms {
@ -111,6 +111,7 @@ func Compile(p *shaderir.Program) (string, []int) {
} }
lines = append(lines, "}") lines = append(lines, "}")
} }
if p.TextureCount > 0 { if p.TextureCount > 0 {
lines = append(lines, "") lines = append(lines, "")
for i := 0; i < p.TextureCount; i++ { for i := 0; i < p.TextureCount; i++ {
@ -119,61 +120,108 @@ func Compile(p *shaderir.Program) (string, []int) {
lines = append(lines, "SamplerState samp : register(s0);") lines = append(lines, "SamplerState samp : register(s0);")
} }
if len(p.Funcs) > 0 { vslines := make([]string, len(lines))
lines = append(lines, "") copy(vslines, lines)
pslines := make([]string, len(lines))
copy(pslines, lines)
var vsfuncs []*shaderir.Func
if p.VertexFunc.Block != nil {
vsfuncs = p.ReachableFuncsFromBlock(p.VertexFunc.Block)
} else {
// Use all the functions for testing.
vsfuncs = make([]*shaderir.Func, 0, len(p.Funcs))
for _, f := range p.Funcs { for _, f := range p.Funcs {
lines = append(lines, c.function(p, &f, true)...) f := f
vsfuncs = append(vsfuncs, &f)
} }
for _, f := range p.Funcs { }
if len(lines) > 0 && lines[len(lines)-1] != "" { if len(vsfuncs) > 0 {
lines = append(lines, "") vslines = append(vslines, "")
for _, f := range vsfuncs {
vslines = append(vslines, c.function(p, f, true)...)
}
for _, f := range vsfuncs {
if len(vslines) > 0 && vslines[len(vslines)-1] != "" {
vslines = append(vslines, "")
} }
lines = append(lines, c.function(p, &f, false)...) vslines = append(vslines, c.function(p, f, false)...)
} }
} }
if p.VertexFunc.Block != nil && len(p.VertexFunc.Block.Stmts) > 0 { if p.VertexFunc.Block != nil && len(p.VertexFunc.Block.Stmts) > 0 {
lines = append(lines, "") vslines = append(vslines, "")
lines = append(lines, "Varyings VSMain(float2 A0 : POSITION, float2 A1 : TEXCOORD, float4 A2 : COLOR) {") vslines = append(vslines, "Varyings VSMain(float2 A0 : POSITION, float2 A1 : TEXCOORD, float4 A2 : COLOR) {")
lines = append(lines, fmt.Sprintf("\tVaryings %s;", vsOut)) vslines = append(vslines, fmt.Sprintf("\tVaryings %s;", vsOut))
lines = append(lines, c.block(p, p.VertexFunc.Block, p.VertexFunc.Block, 0)...) vslines = append(vslines, c.block(p, p.VertexFunc.Block, p.VertexFunc.Block, 0)...)
if last := fmt.Sprintf("\treturn %s;", vsOut); lines[len(lines)-1] != last { if last := fmt.Sprintf("\treturn %s;", vsOut); vslines[len(vslines)-1] != last {
lines = append(lines, last) vslines = append(vslines, last)
} }
lines = append(lines, "}") vslines = append(vslines, "}")
} }
var psfuncs []*shaderir.Func
if p.FragmentFunc.Block != nil {
psfuncs = p.ReachableFuncsFromBlock(p.FragmentFunc.Block)
} else {
// Use all the functions for testing.
psfuncs = make([]*shaderir.Func, 0, len(p.Funcs))
for _, f := range p.Funcs {
f := f
psfuncs = append(psfuncs, &f)
}
}
if len(psfuncs) > 0 {
pslines = append(pslines, "")
for _, f := range psfuncs {
pslines = append(pslines, c.function(p, f, true)...)
}
for _, f := range psfuncs {
if len(pslines) > 0 && pslines[len(pslines)-1] != "" {
pslines = append(pslines, "")
}
pslines = append(pslines, c.function(p, f, false)...)
}
}
if p.FragmentFunc.Block != nil && len(p.FragmentFunc.Block.Stmts) > 0 { if p.FragmentFunc.Block != nil && len(p.FragmentFunc.Block.Stmts) > 0 {
lines = append(lines, "") pslines = append(pslines, "")
lines = append(lines, fmt.Sprintf("float4 PSMain(Varyings %s) : SV_TARGET {", vsOut)) pslines = append(pslines, fmt.Sprintf("float4 PSMain(Varyings %s) : SV_TARGET {", vsOut))
lines = append(lines, c.block(p, p.FragmentFunc.Block, p.FragmentFunc.Block, 0)...) pslines = append(pslines, c.block(p, p.FragmentFunc.Block, p.FragmentFunc.Block, 0)...)
lines = append(lines, "}") pslines = append(pslines, "}")
} }
ls := strings.Join(lines, "\n") vertexShader = strings.Join(vslines, "\n")
pixelShader = strings.Join(pslines, "\n")
// Struct types are determined after converting the program. // Struct types are determined after converting the program.
if len(c.structTypes) > 0 { shaders := []string{vertexShader, pixelShader}
var stlines []string for i, shader := range shaders {
for i, t := range c.structTypes { if len(c.structTypes) > 0 {
stlines = append(stlines, fmt.Sprintf("struct S%d {", i)) var stlines []string
for j, st := range t.Sub { for i, t := range c.structTypes {
stlines = append(stlines, fmt.Sprintf("\t%s;", c.varDecl(p, &st, fmt.Sprintf("M%d", j)))) stlines = append(stlines, fmt.Sprintf("struct S%d {", i))
for j, st := range t.Sub {
stlines = append(stlines, fmt.Sprintf("\t%s;", c.varDecl(p, &st, fmt.Sprintf("M%d", j))))
}
stlines = append(stlines, "};")
} }
stlines = append(stlines, "};") st := strings.Join(stlines, "\n")
shader = strings.ReplaceAll(shader, "{{.Structs}}", st)
} else {
shader = strings.ReplaceAll(shader, "{{.Structs}}", "")
} }
st := strings.Join(stlines, "\n")
ls = strings.ReplaceAll(ls, "{{.Structs}}", st) nls := regexp.MustCompile(`\n\n+`)
} else { shader = nls.ReplaceAllString(shader, "\n\n")
ls = strings.ReplaceAll(ls, "{{.Structs}}", "")
shader = strings.TrimSpace(shader) + "\n"
shaders[i] = shader
} }
nls := regexp.MustCompile(`\n\n+`) vertexShader = shaders[0]
ls = nls.ReplaceAllString(ls, "\n\n") pixelShader = shaders[1]
ls = strings.TrimSpace(ls) + "\n" return
return ls, offsets
} }
func (c *compileContext) typ(p *shaderir.Program, t *shaderir.Type) (string, string) { func (c *compileContext) typ(p *shaderir.Program, t *shaderir.Type) (string, string) {