diff --git a/internal/graphicsdriver/directx/graphics12_windows.go b/internal/graphicsdriver/directx/graphics12_windows.go index ec8c795dd..114119b2b 100644 --- a/internal/graphicsdriver/directx/graphics12_windows.go +++ b/internal/graphicsdriver/directx/graphics12_windows.go @@ -1068,7 +1068,7 @@ func (g *graphics12) MaxImageSize() int { func (g *graphics12) NewShader(program *shaderir.Program) (graphicsdriver.Shader, error) { vs, ps, offsets := hlsl.Compile(program) - vsh, psh, err := newShader(vs, ps) + vsh, psh, err := compileShader(vs, ps) if err != nil { return nil, err } diff --git a/internal/graphicsdriver/directx/pipeline12_windows.go b/internal/graphicsdriver/directx/pipeline12_windows.go index f599a2eb6..b2aa5a31e 100644 --- a/internal/graphicsdriver/directx/pipeline12_windows.go +++ b/internal/graphicsdriver/directx/pipeline12_windows.go @@ -19,8 +19,6 @@ import ( "math" "unsafe" - "golang.org/x/sync/errgroup" - "github.com/hajimehoshi/ebiten/v2/internal/graphics" "github.com/hajimehoshi/ebiten/v2/internal/graphicsdriver" ) @@ -406,63 +404,6 @@ func (p *pipelineStates) ensureRootSignature(device *_ID3D12Device) (rootSignatu return p.rootSignature, nil } -var vertexShaderCache = map[string]*_ID3DBlob{} - -func newShader(vs, ps string) (vsh, psh *_ID3DBlob, ferr error) { - var flag uint32 = uint32(_D3DCOMPILE_OPTIMIZATION_LEVEL3) - - defer func() { - if ferr == nil { - return - } - if vsh != nil { - vsh.Release() - } - if psh != nil { - psh.Release() - } - }() - - var wg errgroup.Group - - // Vertex shaders are likely the same. If so, reuse the same _ID3DBlob. - if v, ok := vertexShaderCache[vs]; ok { - // Increment the reference count not to release this object unexpectedly. - // The value will be removed when the count reached 0. - // See (*Shader).disposeImpl. - v.AddRef() - vsh = v - } else { - defer func() { - if ferr == nil { - vertexShaderCache[vs] = vsh - } - }() - wg.Go(func() error { - v, err := _D3DCompile([]byte(vs), "shader", nil, nil, "VSMain", "vs_5_0", flag, 0) - if err != nil { - return fmt.Errorf("directx: D3DCompile for VSMain failed, original source: %s, %w", vs, err) - } - vsh = v - return nil - }) - } - wg.Go(func() error { - p, err := _D3DCompile([]byte(ps), "shader", nil, nil, "PSMain", "ps_5_0", flag, 0) - if err != nil { - return fmt.Errorf("directx: D3DCompile for PSMain failed, original source: %s, %w", ps, err) - } - psh = p - return nil - }) - - if err := wg.Wait(); err != nil { - return nil, nil, err - } - - return -} - func (p *pipelineStates) newPipelineState(device *_ID3D12Device, vsh, psh *_ID3DBlob, blend graphicsdriver.Blend, stencilMode stencilMode, screen bool) (state *_ID3D12PipelineState, ferr error) { rootSignature, err := p.ensureRootSignature(device) if err != nil { diff --git a/internal/graphicsdriver/directx/shader_windows.go b/internal/graphicsdriver/directx/shader_windows.go index 8623ed137..79d1d9dd2 100644 --- a/internal/graphicsdriver/directx/shader_windows.go +++ b/internal/graphicsdriver/directx/shader_windows.go @@ -17,6 +17,8 @@ package directx import ( "fmt" + "golang.org/x/sync/errgroup" + "github.com/hajimehoshi/ebiten/v2/internal/graphics" "github.com/hajimehoshi/ebiten/v2/internal/graphicsdriver" "github.com/hajimehoshi/ebiten/v2/internal/shaderir" @@ -46,6 +48,63 @@ type Shader struct { pipelineStates map[pipelineStateKey]*_ID3D12PipelineState } +var vertexShaderCache = map[string]*_ID3DBlob{} + +func compileShader(vs, ps string) (vsh, psh *_ID3DBlob, ferr error) { + var flag uint32 = uint32(_D3DCOMPILE_OPTIMIZATION_LEVEL3) + + defer func() { + if ferr == nil { + return + } + if vsh != nil { + vsh.Release() + } + if psh != nil { + psh.Release() + } + }() + + var wg errgroup.Group + + // Vertex shaders are likely the same. If so, reuse the same _ID3DBlob. + if v, ok := vertexShaderCache[vs]; ok { + // Increment the reference count not to release this object unexpectedly. + // The value will be removed when the count reached 0. + // See (*Shader).disposeImpl. + v.AddRef() + vsh = v + } else { + defer func() { + if ferr == nil { + vertexShaderCache[vs] = vsh + } + }() + wg.Go(func() error { + v, err := _D3DCompile([]byte(vs), "shader", nil, nil, "VSMain", "vs_5_0", flag, 0) + if err != nil { + return fmt.Errorf("directx: D3DCompile for VSMain failed, original source: %s, %w", vs, err) + } + vsh = v + return nil + }) + } + wg.Go(func() error { + p, err := _D3DCompile([]byte(ps), "shader", nil, nil, "PSMain", "ps_5_0", flag, 0) + if err != nil { + return fmt.Errorf("directx: D3DCompile for PSMain failed, original source: %s, %w", ps, err) + } + psh = p + return nil + }) + + if err := wg.Wait(); err != nil { + return nil, nil, err + } + + return +} + func (s *Shader) ID() graphicsdriver.ShaderID { return s.id }