diff --git a/shader.go b/shader.go index 79a869994..0ee67dfa3 100644 --- a/shader.go +++ b/shader.go @@ -24,6 +24,7 @@ import ( "github.com/hajimehoshi/ebiten/internal/graphics" "github.com/hajimehoshi/ebiten/internal/mipmap" "github.com/hajimehoshi/ebiten/internal/shader" + "github.com/hajimehoshi/ebiten/internal/shaderir" ) var shaderSuffix string @@ -105,6 +106,7 @@ func __vertex(position vec2, texCoord vec2, color vec4) (vec4, vec2, vec4) { type Shader struct { shader *mipmap.Shader uniformNames []string + uniformTypes []shaderir.Type } func NewShader(src []byte) (*Shader, error) { @@ -137,6 +139,7 @@ func NewShader(src []byte) (*Shader, error) { return &Shader{ shader: mipmap.NewShader(s), uniformNames: s.UniformNames, + uniformTypes: s.Uniforms, }, nil } @@ -157,15 +160,67 @@ func (s *Shader) convertUniforms(uniforms map[string]interface{}) []interface{} } us := make([]interface{}, len(names)) - for n, u := range uniforms { - idx, ok := names[n] - if !ok { - // TODO: Panic here? + for name, idx := range names { + if v, ok := uniforms[name]; ok { + // TODO: Check the uniform variable types? + us[idx] = v continue } - us[idx] = u + + t := s.uniformTypes[idx] + v := zeroUniformValue(t) + if v == nil { + panic(fmt.Sprintf("ebiten: unexpected uniform variable type: %s", t.String())) + } + us[idx] = v } - // TODO: Check the uniform variable types? + // TODO: Panic if uniforms include an invalid name + return us } + +func zeroUniformValue(t shaderir.Type) interface{} { + switch t.Main { + case shaderir.Bool: + return false + case shaderir.Int: + return 0 + case shaderir.Float: + return float32(0) + case shaderir.Vec2: + return make([]float32, 2) + case shaderir.Vec3: + return make([]float32, 3) + case shaderir.Vec4: + return make([]float32, 4) + case shaderir.Mat2: + return make([]float32, 4) + case shaderir.Mat3: + return make([]float32, 9) + case shaderir.Mat4: + return make([]float32, 16) + case shaderir.Array: + switch t.Sub[0].Main { + case shaderir.Bool: + return make([]bool, t.Length) + case shaderir.Int: + return make([]int, t.Length) + case shaderir.Float: + return make([]float32, t.Length) + case shaderir.Vec2: + return make([]float32, t.Length*2) + case shaderir.Vec3: + return make([]float32, t.Length*3) + case shaderir.Vec4: + return make([]float32, t.Length*4) + case shaderir.Mat2: + return make([]float32, t.Length*4) + case shaderir.Mat3: + return make([]float32, t.Length*9) + case shaderir.Mat4: + return make([]float32, t.Length*16) + } + } + return nil +} diff --git a/shader_test.go b/shader_test.go index c2b8f4400..139047975 100644 --- a/shader_test.go +++ b/shader_test.go @@ -419,3 +419,32 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 { t.Errorf("error must be non-nil but was nil") } } + +func TestShaderUninitializedUniformVariables(t *testing.T) { + const w, h = 16, 16 + + dst, _ := NewImage(w, h, FilterDefault) + s, err := NewShader([]byte(`package main + +var U vec4 + +func Fragment(position vec4, texCoord vec2, color vec4) vec4 { + return U +} +`)) + if err != nil { + t.Fatal(err) + } + + dst.DrawRectShader(w, h, s, nil) + + for j := 0; j < h; j++ { + for i := 0; i < w; i++ { + got := dst.At(i, j).(color.RGBA) + var want color.RGBA + if got != want { + t.Errorf("dst.At(%d, %d): got: %v, want: %v", i, j, got, want) + } + } + } +}