internal/graphicsdriver/directx: reuse a compiled vertex shader if possible

D3DCompile can be a very slow function, so let's skip calling it whenever possible.
This commit is contained in:
Hajime Hoshi 2022-12-24 16:12:35 +09:00
parent 88fe3d1287
commit fda0b1cbcb
5 changed files with 128 additions and 54 deletions

View File

@ -2309,6 +2309,11 @@ type _ID3DBlob_Vtbl struct {
GetBufferSize uintptr
}
// AddRef invokes the AddRef entry of the blob's vtable, passing the blob
// itself as the single argument, and returns the call's first result
// converted to uint32. The two remaining results of syscall.Syscall are
// discarded.
func (i *_ID3DBlob) AddRef() uint32 {
// NOTE(review): the uintptr(unsafe.Pointer(i)) conversion is written directly
// in the argument list; the unsafe package documentation requires that form
// for syscall.Syscall arguments, so do not hoist it into a local variable.
r, _, _ := syscall.Syscall(i.vtbl.AddRef, 1, uintptr(unsafe.Pointer(i)), 0, 0)
return uint32(r)
}
func (i *_ID3DBlob) GetBufferPointer() uintptr {
r, _, _ := syscall.Syscall(i.vtbl.GetBufferPointer, 1, uintptr(unsafe.Pointer(i)),
0, 0)

View File

@ -1214,8 +1214,8 @@ func (g *Graphics) MaxImageSize() int {
}
func (g *Graphics) NewShader(program *shaderir.Program) (graphicsdriver.Shader, error) {
src, offsets := hlsl.Compile(program)
vsh, psh, err := newShader([]byte(src), nil)
vs, ps, offsets := hlsl.Compile(program)
vsh, psh, err := newShader(vs, ps)
if err != nil {
return nil, err
}
@ -1742,7 +1742,14 @@ func (s *Shader) disposeImpl() {
s.pixelShader = nil
}
if s.vertexShader != nil {
s.vertexShader.Release()
count := s.vertexShader.Release()
if count == 0 {
for k, v := range vertexShaderCache {
if v == s.vertexShader {
delete(vertexShaderCache, k)
}
}
}
s.vertexShader = nil
}
}

View File

@ -376,7 +376,9 @@ func (p *pipelineStates) ensureRootSignature(device *_ID3D12Device) (rootSignatu
return p.rootSignature, nil
}
func newShader(source []byte, defs []_D3D_SHADER_MACRO) (vsh, psh *_ID3DBlob, ferr error) {
var vertexShaderCache = map[string]*_ID3DBlob{}
func newShader(vs, ps string) (vsh, psh *_ID3DBlob, ferr error) {
var flag uint32 = uint32(_D3DCOMPILE_OPTIMIZATION_LEVEL3)
defer func() {
@ -392,18 +394,28 @@ func newShader(source []byte, defs []_D3D_SHADER_MACRO) (vsh, psh *_ID3DBlob, fe
}()
var wg errgroup.Group
wg.Go(func() error {
v, err := _D3DCompile(source, "shader", defs, nil, "VSMain", "vs_5_0", flag, 0)
if err != nil {
return fmt.Errorf("directx: D3DCompile for VSMain failed, original source: %s, %w", string(source), err)
}
// Vertex shaders are likely the same. If so, reuse the same _ID3DBlob.
if v, ok := vertexShaderCache[vs]; ok {
// Increment the reference count not to release this object unexpectedly.
// The value will be removed from the cache when the count reaches 0.
// See (*Shader).disposeImpl.
v.AddRef()
vsh = v
return nil
})
} else {
wg.Go(func() error {
v, err := _D3DCompile([]byte(vs), "shader", nil, nil, "VSMain", "vs_5_0", flag, 0)
if err != nil {
return fmt.Errorf("directx: D3DCompile for VSMain failed, original source: %s, %w", vs, err)
}
vsh = v
return nil
})
}
wg.Go(func() error {
p, err := _D3DCompile(source, "shader", defs, nil, "PSMain", "ps_5_0", flag, 0)
p, err := _D3DCompile([]byte(ps), "shader", nil, nil, "PSMain", "ps_5_0", flag, 0)
if err != nil {
return fmt.Errorf("directx: D3DCompile for PSMain failed, original source: %s, %w", string(source), err)
return fmt.Errorf("directx: D3DCompile for PSMain failed, original source: %s, %w", ps, err)
}
psh = p
return nil
@ -413,6 +425,8 @@ func newShader(source []byte, defs []_D3D_SHADER_MACRO) (vsh, psh *_ID3DBlob, fe
return nil, nil, err
}
vertexShaderCache[vs] = vsh
return
}

View File

@ -194,8 +194,8 @@ func TestCompile(t *testing.T) {
}
if tc.HLSL != nil {
h, _ := hlsl.Compile(s)
if got, want := hlslNormalize(h), hlslNormalize(string(tc.HLSL)); got != want {
vs, _, _ := hlsl.Compile(s)
if got, want := hlslNormalize(vs), hlslNormalize(string(tc.HLSL)); got != want {
compare(t, "HLSL", got, want)
}
}

View File

@ -84,16 +84,16 @@ float4x4 float4x4FromScalar(float x) {
return float4x4(x, 0, 0, 0, 0, x, 0, 0, 0, 0, x, 0, 0, 0, 0, x);
}`
func Compile(p *shaderir.Program) (string, []int) {
func Compile(p *shaderir.Program) (vertexShader, pixelShader string, offsets []int) {
offsets = calculateMemoryOffsets(p.Uniforms)
c := &compileContext{}
var lines []string
lines = append(lines, strings.Split(Prelude, "\n")...)
lines = append(lines, "", "{{.Structs}}")
var offsets []int
if len(p.Uniforms) > 0 {
offsets = calculateMemoryOffsets(p.Uniforms)
lines = append(lines, "")
lines = append(lines, "cbuffer Uniforms : register(b0) {")
for i, t := range p.Uniforms {
@ -111,6 +111,7 @@ func Compile(p *shaderir.Program) (string, []int) {
}
lines = append(lines, "}")
}
if p.TextureCount > 0 {
lines = append(lines, "")
for i := 0; i < p.TextureCount; i++ {
@ -119,61 +120,108 @@ func Compile(p *shaderir.Program) (string, []int) {
lines = append(lines, "SamplerState samp : register(s0);")
}
if len(p.Funcs) > 0 {
lines = append(lines, "")
vslines := make([]string, len(lines))
copy(vslines, lines)
pslines := make([]string, len(lines))
copy(pslines, lines)
var vsfuncs []*shaderir.Func
if p.VertexFunc.Block != nil {
vsfuncs = p.ReachableFuncsFromBlock(p.VertexFunc.Block)
} else {
// Use all the functions for testing.
vsfuncs = make([]*shaderir.Func, 0, len(p.Funcs))
for _, f := range p.Funcs {
lines = append(lines, c.function(p, &f, true)...)
f := f
vsfuncs = append(vsfuncs, &f)
}
for _, f := range p.Funcs {
if len(lines) > 0 && lines[len(lines)-1] != "" {
lines = append(lines, "")
}
if len(vsfuncs) > 0 {
vslines = append(vslines, "")
for _, f := range vsfuncs {
vslines = append(vslines, c.function(p, f, true)...)
}
for _, f := range vsfuncs {
if len(vslines) > 0 && vslines[len(vslines)-1] != "" {
vslines = append(vslines, "")
}
lines = append(lines, c.function(p, &f, false)...)
vslines = append(vslines, c.function(p, f, false)...)
}
}
if p.VertexFunc.Block != nil && len(p.VertexFunc.Block.Stmts) > 0 {
lines = append(lines, "")
lines = append(lines, "Varyings VSMain(float2 A0 : POSITION, float2 A1 : TEXCOORD, float4 A2 : COLOR) {")
lines = append(lines, fmt.Sprintf("\tVaryings %s;", vsOut))
lines = append(lines, c.block(p, p.VertexFunc.Block, p.VertexFunc.Block, 0)...)
if last := fmt.Sprintf("\treturn %s;", vsOut); lines[len(lines)-1] != last {
lines = append(lines, last)
vslines = append(vslines, "")
vslines = append(vslines, "Varyings VSMain(float2 A0 : POSITION, float2 A1 : TEXCOORD, float4 A2 : COLOR) {")
vslines = append(vslines, fmt.Sprintf("\tVaryings %s;", vsOut))
vslines = append(vslines, c.block(p, p.VertexFunc.Block, p.VertexFunc.Block, 0)...)
if last := fmt.Sprintf("\treturn %s;", vsOut); vslines[len(vslines)-1] != last {
vslines = append(vslines, last)
}
lines = append(lines, "}")
vslines = append(vslines, "}")
}
var psfuncs []*shaderir.Func
if p.FragmentFunc.Block != nil {
psfuncs = p.ReachableFuncsFromBlock(p.FragmentFunc.Block)
} else {
// Use all the functions for testing.
psfuncs = make([]*shaderir.Func, 0, len(p.Funcs))
for _, f := range p.Funcs {
f := f
psfuncs = append(psfuncs, &f)
}
}
if len(psfuncs) > 0 {
pslines = append(pslines, "")
for _, f := range psfuncs {
pslines = append(pslines, c.function(p, f, true)...)
}
for _, f := range psfuncs {
if len(pslines) > 0 && pslines[len(pslines)-1] != "" {
pslines = append(pslines, "")
}
pslines = append(pslines, c.function(p, f, false)...)
}
}
if p.FragmentFunc.Block != nil && len(p.FragmentFunc.Block.Stmts) > 0 {
lines = append(lines, "")
lines = append(lines, fmt.Sprintf("float4 PSMain(Varyings %s) : SV_TARGET {", vsOut))
lines = append(lines, c.block(p, p.FragmentFunc.Block, p.FragmentFunc.Block, 0)...)
lines = append(lines, "}")
pslines = append(pslines, "")
pslines = append(pslines, fmt.Sprintf("float4 PSMain(Varyings %s) : SV_TARGET {", vsOut))
pslines = append(pslines, c.block(p, p.FragmentFunc.Block, p.FragmentFunc.Block, 0)...)
pslines = append(pslines, "}")
}
ls := strings.Join(lines, "\n")
vertexShader = strings.Join(vslines, "\n")
pixelShader = strings.Join(pslines, "\n")
// Struct types are determined after converting the program.
if len(c.structTypes) > 0 {
var stlines []string
for i, t := range c.structTypes {
stlines = append(stlines, fmt.Sprintf("struct S%d {", i))
for j, st := range t.Sub {
stlines = append(stlines, fmt.Sprintf("\t%s;", c.varDecl(p, &st, fmt.Sprintf("M%d", j))))
shaders := []string{vertexShader, pixelShader}
for i, shader := range shaders {
if len(c.structTypes) > 0 {
var stlines []string
for i, t := range c.structTypes {
stlines = append(stlines, fmt.Sprintf("struct S%d {", i))
for j, st := range t.Sub {
stlines = append(stlines, fmt.Sprintf("\t%s;", c.varDecl(p, &st, fmt.Sprintf("M%d", j))))
}
stlines = append(stlines, "};")
}
stlines = append(stlines, "};")
st := strings.Join(stlines, "\n")
shader = strings.ReplaceAll(shader, "{{.Structs}}", st)
} else {
shader = strings.ReplaceAll(shader, "{{.Structs}}", "")
}
st := strings.Join(stlines, "\n")
ls = strings.ReplaceAll(ls, "{{.Structs}}", st)
} else {
ls = strings.ReplaceAll(ls, "{{.Structs}}", "")
nls := regexp.MustCompile(`\n\n+`)
shader = nls.ReplaceAllString(shader, "\n\n")
shader = strings.TrimSpace(shader) + "\n"
shaders[i] = shader
}
nls := regexp.MustCompile(`\n\n+`)
ls = nls.ReplaceAllString(ls, "\n\n")
vertexShader = shaders[0]
pixelShader = shaders[1]
ls = strings.TrimSpace(ls) + "\n"
return ls, offsets
return
}
func (c *compileContext) typ(p *shaderir.Program, t *shaderir.Type) (string, string) {