internal/graphicsdriver/metal, internal/shaderir/msl: use one integrated struct for uniforms

Closes #3164
This commit is contained in:
Hajime Hoshi 2024-11-25 22:51:31 +09:00
parent 81bb5044ea
commit 1ea14d5076
4 changed files with 144 additions and 82 deletions

View File

@ -473,7 +473,7 @@ func (g *Graphics) flushRenderCommandEncoderIfNeeded() {
g.lastDst = nil g.lastDst = nil
} }
func (g *Graphics) draw(dst *Image, dstRegions []graphicsdriver.DstRegion, srcs [graphics.ShaderSrcImageCount]*Image, indexOffset int, shader *Shader, uniforms [][]uint32, blend graphicsdriver.Blend, fillRule graphicsdriver.FillRule) error { func (g *Graphics) draw(dst *Image, dstRegions []graphicsdriver.DstRegion, srcs [graphics.ShaderSrcImageCount]*Image, indexOffset int, shader *Shader, uniforms []uint32, blend graphicsdriver.Blend, fillRule graphicsdriver.FillRule) error {
// When preparing a stencil buffer, flush the current render command encoder // When preparing a stencil buffer, flush the current render command encoder
// to make sure the stencil buffer is cleared when loading. // to make sure the stencil buffer is cleared when loading.
// TODO: What about clearing the stencil buffer by vertices? // TODO: What about clearing the stencil buffer by vertices?
@ -527,12 +527,11 @@ func (g *Graphics) draw(dst *Image, dstRegions []graphicsdriver.DstRegion, srcs
}) })
g.rce.SetVertexBuffer(g.vb, 0, 0) g.rce.SetVertexBuffer(g.vb, 0, 0)
for i, u := range uniforms { if len(uniforms) > 0 {
if u == nil { uniforms := adjustUniformVariablesLayout(shader.ir.Uniforms, uniforms)
continue head := unsafe.SliceData(uniforms)
} g.rce.SetVertexBytes(unsafe.Pointer(head), unsafe.Sizeof(uniforms[0])*uintptr(len(uniforms)), 1)
g.rce.SetVertexBytes(unsafe.Pointer(&u[0]), unsafe.Sizeof(u[0])*uintptr(len(u)), i+1) g.rce.SetFragmentBytes(unsafe.Pointer(head), unsafe.Sizeof(uniforms[0])*uintptr(len(uniforms)), 0)
g.rce.SetFragmentBytes(unsafe.Pointer(&u[0]), unsafe.Sizeof(u[0])*uintptr(len(u)), i+1)
} }
for i, src := range srcs { for i, src := range srcs {
@ -627,67 +626,7 @@ func (g *Graphics) DrawTriangles(dstID graphicsdriver.ImageID, srcIDs [graphics.
srcs[i] = g.images[srcID] srcs[i] = g.images[srcID]
} }
uniformVars := make([][]uint32, len(g.shaders[shaderID].ir.Uniforms)) if err := g.draw(dst, dstRegions, srcs, indexOffset, g.shaders[shaderID], uniforms, blend, fillRule); err != nil {
// Set the additional uniform variables.
var idx int
for i, t := range g.shaders[shaderID].ir.Uniforms {
if i == graphics.ProjectionMatrixUniformVariableIndex {
// In Metal, the NDC's Y direction (upward) and the framebuffer's Y direction (downward) don't
// match. Then, the Y direction must be inverted.
// Invert the sign bits as float32 values.
uniforms[idx+1] ^= 1 << 31
uniforms[idx+5] ^= 1 << 31
uniforms[idx+9] ^= 1 << 31
uniforms[idx+13] ^= 1 << 31
}
n := t.Uint32Count()
switch t.Main {
case shaderir.Vec3, shaderir.IVec3:
// float3 requires 16-byte alignment (#2463).
v1 := make([]uint32, 4)
copy(v1[0:3], uniforms[idx:idx+3])
uniformVars[i] = v1
case shaderir.Mat3:
// float3x3 requires 16-byte alignment (#2036).
v1 := make([]uint32, 12)
copy(v1[0:3], uniforms[idx:idx+3])
copy(v1[4:7], uniforms[idx+3:idx+6])
copy(v1[8:11], uniforms[idx+6:idx+9])
uniformVars[i] = v1
case shaderir.Array:
switch t.Sub[0].Main {
case shaderir.Vec3, shaderir.IVec3:
v1 := make([]uint32, t.Length*4)
for j := 0; j < t.Length; j++ {
offset0 := j * 3
offset1 := j * 4
copy(v1[offset1:offset1+3], uniforms[idx+offset0:idx+offset0+3])
}
uniformVars[i] = v1
case shaderir.Mat3:
v1 := make([]uint32, t.Length*12)
for j := 0; j < t.Length; j++ {
offset0 := j * 9
offset1 := j * 12
copy(v1[offset1:offset1+3], uniforms[idx+offset0:idx+offset0+3])
copy(v1[offset1+4:offset1+7], uniforms[idx+offset0+3:idx+offset0+6])
copy(v1[offset1+8:offset1+11], uniforms[idx+offset0+6:idx+offset0+9])
}
uniformVars[i] = v1
default:
uniformVars[i] = uniforms[idx : idx+n]
}
default:
uniformVars[i] = uniforms[idx : idx+n]
}
idx += n
}
if err := g.draw(dst, dstRegions, srcs, indexOffset, g.shaders[shaderID], uniformVars, blend, fillRule); err != nil {
return err return err
} }
@ -922,3 +861,109 @@ func (i *Image) ensureStencil() {
} }
i.stencil = i.graphics.view.getMTLDevice().NewTextureWithDescriptor(td) i.stencil = i.graphics.view.getMTLDevice().NewTextureWithDescriptor(td)
} }
// adjustUniformVariablesLayout returns adjusted uniform variables to match the Metal's memory layout.
func adjustUniformVariablesLayout(uniformTypes []shaderir.Type, uniforms []uint32) []uint32 {
// Each type's alignment is defined by the specification.
// See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
var values []uint32
fillZerosToFitAlignment := func(values []uint32, align int) []uint32 {
if len(values) == 0 {
return values
}
n0 := len(values)
n1 := ((len(values)-1)/align + 1) * align
if n0 == n1 {
return values
}
return append(values, make([]uint32, n1-n0)...)
}
var idx int
for i, typ := range uniformTypes {
n := typ.Uint32Count()
switch typ.Main {
case shaderir.Float, shaderir.Int:
values = append(values, uniforms[idx:idx+n]...)
case shaderir.Vec2, shaderir.IVec2:
values = fillZerosToFitAlignment(values, 2)
values = append(values, uniforms[idx:idx+n]...)
case shaderir.Vec3, shaderir.IVec3:
values = fillZerosToFitAlignment(values, 4)
values = append(values, uniforms[idx:idx+n]...)
values = append(values, 0)
case shaderir.Vec4, shaderir.IVec4:
values = fillZerosToFitAlignment(values, 4)
values = append(values, uniforms[idx:idx+n]...)
case shaderir.Mat2:
values = fillZerosToFitAlignment(values, 2)
values = append(values, uniforms[idx:idx+n]...)
case shaderir.Mat3:
values = fillZerosToFitAlignment(values, 4)
values = append(values, uniforms[idx:idx+3]...)
values = append(values, 0)
values = append(values, uniforms[idx+3:idx+6]...)
values = append(values, 0)
values = append(values, uniforms[idx+6:idx+9]...)
values = append(values, 0)
case shaderir.Mat4:
values = fillZerosToFitAlignment(values, 4)
u := uniforms[idx : idx+16]
if i == graphics.ProjectionMatrixUniformVariableIndex {
// In Metal, the NDC's Y direction (upward) and the framebuffer's Y direction (downward) don't
// match. Then, the Y direction must be inverted.
// Invert the sign bits as float32 values.
values = append(values,
u[0], u[1]^uint32(1<<31), u[2], u[3],
u[4], u[5]^uint32(1<<31), u[6], u[7],
u[8], u[9]^uint32(1<<31), u[10], u[11],
u[12], u[13]^uint32(1<<31), u[14], u[15],
)
} else {
values = append(values, uniforms[idx:idx+n]...)
}
case shaderir.Array:
switch typ.Sub[0].Main {
case shaderir.Float, shaderir.Int:
values = append(values, uniforms[idx:idx+n]...)
case shaderir.Vec2, shaderir.IVec2:
values = fillZerosToFitAlignment(values, 2)
values = append(values, uniforms[idx:idx+n]...)
case shaderir.Vec3, shaderir.IVec3:
values = fillZerosToFitAlignment(values, 4)
for j := 0; j < typ.Length; j++ {
values = append(values, uniforms[idx+3*j:idx+3*(j+1)]...)
values = append(values, 0)
}
case shaderir.Vec4, shaderir.IVec4:
values = fillZerosToFitAlignment(values, 4)
values = append(values, uniforms[idx:idx+n]...)
case shaderir.Mat2:
values = fillZerosToFitAlignment(values, 2)
values = append(values, uniforms[idx:idx+n]...)
case shaderir.Mat3:
values = fillZerosToFitAlignment(values, 4)
for j := 0; j < typ.Length; j++ {
values = append(values, uniforms[idx+9*j:idx+9*j+3]...)
values = append(values, 0)
values = append(values, uniforms[idx+9*j+3:idx+9*j+6]...)
values = append(values, 0)
values = append(values, uniforms[idx+9*j+6:idx+9*j+9]...)
values = append(values, 0)
}
case shaderir.Mat4:
values = fillZerosToFitAlignment(values, 4)
values = append(values, uniforms[idx:idx+n]...)
default:
panic(fmt.Sprintf("metal: not implemented type for uniform variables: %s", typ.String()))
}
default:
panic(fmt.Sprintf("metal: not implemented type for uniform variables: %s", typ.String()))
}
idx += n
}
return values
}

View File

@ -1,3 +1,7 @@
struct Uniforms {
float2 U0;
};
struct Attributes { struct Attributes {
float2 M0; float2 M0;
float2 M1; float2 M1;
@ -13,10 +17,10 @@ struct Varyings {
vertex Varyings Vertex( vertex Varyings Vertex(
uint vid [[vertex_id]], uint vid [[vertex_id]],
const device Attributes* attributes [[buffer(0)]], const device Attributes* attributes [[buffer(0)]],
constant float2& U0 [[buffer(1)]]) { constant Uniforms& uniforms [[buffer(1)]]) {
Varyings varyings = {}; Varyings varyings = {};
float4x4 l0 = float4x4(0); float4x4 l0 = float4x4(0);
l0 = float4x4((2.0) / ((U0).x), 0.0, 0.0, 0.0, 0.0, (2.0) / ((U0).y), 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0); l0 = float4x4((2.0) / ((uniforms.U0).x), 0.0, 0.0, 0.0, 0.0, (2.0) / ((uniforms.U0).y), 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0);
varyings.Position = (l0) * (float4(attributes[vid].M0, 0.0, 1.0)); varyings.Position = (l0) * (float4(attributes[vid].M0, 0.0, 1.0));
varyings.M0 = attributes[vid].M1; varyings.M0 = attributes[vid].M1;
varyings.M1 = attributes[vid].M2; varyings.M1 = attributes[vid].M2;

View File

@ -1,3 +1,7 @@
struct Uniforms {
float2 U0;
};
struct Attributes { struct Attributes {
float2 M0; float2 M0;
float2 M1; float2 M1;
@ -13,10 +17,10 @@ struct Varyings {
vertex Varyings Vertex( vertex Varyings Vertex(
uint vid [[vertex_id]], uint vid [[vertex_id]],
const device Attributes* attributes [[buffer(0)]], const device Attributes* attributes [[buffer(0)]],
constant float2& U0 [[buffer(1)]]) { constant Uniforms& uniforms [[buffer(1)]]) {
Varyings varyings = {}; Varyings varyings = {};
float4x4 l0 = float4x4(0); float4x4 l0 = float4x4(0);
l0 = float4x4((2.0) / ((U0).x), 0.0, 0.0, 0.0, 0.0, (2.0) / ((U0).y), 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0); l0 = float4x4((2.0) / ((uniforms.U0).x), 0.0, 0.0, 0.0, 0.0, (2.0) / ((uniforms.U0).y), 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0);
varyings.Position = (l0) * (float4(attributes[vid].M0, 0.0, 1.0)); varyings.Position = (l0) * (float4(attributes[vid].M0, 0.0, 1.0));
varyings.M0 = attributes[vid].M1; varyings.M0 = attributes[vid].M1;
varyings.M1 = attributes[vid].M2; varyings.M1 = attributes[vid].M2;
@ -25,6 +29,6 @@ vertex Varyings Vertex(
fragment float4 Fragment( fragment float4 Fragment(
Varyings varyings [[stage_in]], Varyings varyings [[stage_in]],
constant float2& U0 [[buffer(1)]]) { constant Uniforms& uniforms [[buffer(0)]]) {
return float4((varyings.Position).x, (varyings.M0).y, (varyings.M1).z, 1.0); return float4((varyings.Position).x, (varyings.M0).y, (varyings.M1).z, 1.0);
} }

View File

@ -79,6 +79,15 @@ func Compile(p *shaderir.Program) (shader string) {
lines = append(lines, strings.Split(Prelude(p.Unit), "\n")...) lines = append(lines, strings.Split(Prelude(p.Unit), "\n")...)
lines = append(lines, "", "{{.Structs}}") lines = append(lines, "", "{{.Structs}}")
if len(p.Uniforms) > 0 {
lines = append(lines, "")
lines = append(lines, "struct Uniforms {")
for i, u := range p.Uniforms {
lines = append(lines, fmt.Sprintf("\t%s;", c.varDecl(p, &u, fmt.Sprintf("U%d", i), false)))
}
lines = append(lines, "};")
}
if len(p.Attributes) > 0 { if len(p.Attributes) > 0 {
lines = append(lines, "") lines = append(lines, "")
lines = append(lines, "struct Attributes {") lines = append(lines, "struct Attributes {")
@ -117,9 +126,9 @@ func Compile(p *shaderir.Program) (shader string) {
fmt.Sprintf("vertex Varyings %s(", VertexName), fmt.Sprintf("vertex Varyings %s(", VertexName),
"\tuint vid [[vertex_id]],", "\tuint vid [[vertex_id]],",
"\tconst device Attributes* attributes [[buffer(0)]]") "\tconst device Attributes* attributes [[buffer(0)]]")
for i, u := range p.Uniforms { if len(p.Uniforms) > 0 {
lines[len(lines)-1] += "," lines[len(lines)-1] += ","
lines = append(lines, fmt.Sprintf("\tconstant %s [[buffer(%d)]]", c.varDecl(p, &u, fmt.Sprintf("U%d", i), true), i+1)) lines = append(lines, "\tconstant Uniforms& uniforms [[buffer(1)]]")
} }
for i := 0; i < p.TextureCount; i++ { for i := 0; i < p.TextureCount; i++ {
lines[len(lines)-1] += "," lines[len(lines)-1] += ","
@ -139,9 +148,9 @@ func Compile(p *shaderir.Program) (shader string) {
lines = append(lines, lines = append(lines,
fmt.Sprintf("fragment float4 %s(", FragmentName), fmt.Sprintf("fragment float4 %s(", FragmentName),
"\tVaryings varyings [[stage_in]]") "\tVaryings varyings [[stage_in]]")
for i, u := range p.Uniforms { if len(p.Uniforms) > 0 {
lines[len(lines)-1] += "," lines[len(lines)-1] += ","
lines = append(lines, fmt.Sprintf("\tconstant %s [[buffer(%d)]]", c.varDecl(p, &u, fmt.Sprintf("U%d", i), true), i+1)) lines = append(lines, "\tconstant Uniforms& uniforms [[buffer(0)]]")
} }
for i := 0; i < p.TextureCount; i++ { for i := 0; i < p.TextureCount; i++ {
lines[len(lines)-1] += "," lines[len(lines)-1] += ","
@ -229,8 +238,8 @@ func (c *compileContext) function(p *shaderir.Program, f *shaderir.Func, prototy
var args []string var args []string
// Uniform variables and texture variables. In Metal, non-const global variables are not available. // Uniform variables and texture variables. In Metal, non-const global variables are not available.
for i, u := range p.Uniforms { if len(p.Uniforms) > 0 {
args = append(args, "constant "+c.varDecl(p, &u, fmt.Sprintf("U%d", i), true)) args = append(args, "constant Uniforms& uniforms")
} }
for i := 0; i < p.TextureCount; i++ { for i := 0; i < p.TextureCount; i++ {
args = append(args, fmt.Sprintf("texture2d<float> T%d", i)) args = append(args, fmt.Sprintf("texture2d<float> T%d", i))
@ -350,7 +359,7 @@ func (c *compileContext) block(p *shaderir.Program, topBlock, block *shaderir.Bl
case shaderir.NumberExpr: case shaderir.NumberExpr:
return constantToNumberLiteral(e.Const) return constantToNumberLiteral(e.Const)
case shaderir.UniformVariable: case shaderir.UniformVariable:
return fmt.Sprintf("U%d", e.Index) return fmt.Sprintf("uniforms.U%d", e.Index)
case shaderir.TextureVariable: case shaderir.TextureVariable:
return fmt.Sprintf("T%d", e.Index) return fmt.Sprintf("T%d", e.Index)
case shaderir.LocalVariable: case shaderir.LocalVariable:
@ -389,8 +398,8 @@ func (c *compileContext) block(p *shaderir.Program, topBlock, block *shaderir.Bl
callee := e.Exprs[0] callee := e.Exprs[0]
var args []string var args []string
if callee.Type != shaderir.BuiltinFuncExpr { if callee.Type != shaderir.BuiltinFuncExpr {
for i := range p.Uniforms { if len(p.Uniforms) > 0 {
args = append(args, fmt.Sprintf("U%d", i)) args = append(args, "uniforms")
} }
for i := 0; i < p.TextureCount; i++ { for i := 0; i < p.TextureCount; i++ {
args = append(args, fmt.Sprintf("T%d", i)) args = append(args, fmt.Sprintf("T%d", i))