diff --git a/internal/shader/shader.go b/internal/shader/shader.go index dfe472906..bec8e363a 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -41,9 +41,15 @@ type Shader struct { // position is the field name of VertexOut that represents a vertex position (gl_Position in GLSL). position nameAndType - // varyings is a set of varying variables. + // varyings is a collection of varying variables. varyings []nameAndType + // uniforms is a collection of uniform variables. + uniforms []nameAndType + + // globals is a collection of global variables. + globals []nameAndType + errs []string } @@ -68,16 +74,35 @@ func NewShader(src []byte) (*Shader, error) { return nil, &ParseError{s.errs} } + sort.Slice(s.varyings, func(a, b int) bool { + return s.varyings[a].name < s.varyings[b].name + }) + sort.Slice(s.uniforms, func(a, b int) bool { + return s.uniforms[a].name < s.uniforms[b].name + }) + sort.Slice(s.globals, func(a, b int) bool { + return s.globals[a].name < s.globals[b].name + }) + // TODO: Make a call graph and reorder the elements. return s, nil } +func (s *Shader) addError(str string) { + s.errs = append(s.errs, str) +} + func (s *Shader) parse(f *ast.File) { // TODO: Accumulate errors for name, obj := range f.Scope.Objects { switch name { case varyingStructName: s.parseVaryingStruct(obj) + default: + switch obj.Kind { + case ast.Var: + s.parsePackageLevelVariable(name, obj) + } } } } @@ -85,13 +110,13 @@ func (s *Shader) parse(f *ast.File) { func (sh *Shader) parseVaryingStruct(obj *ast.Object) { name := obj.Name if obj.Kind != ast.Typ { - sh.errs = append(sh.errs, fmt.Sprintf("%s must be a type but %s", name, obj.Kind)) + sh.addError(fmt.Sprintf("%s must be a type but %s", name, obj.Kind)) return } t := obj.Decl.(*ast.TypeSpec).Type s, ok := t.(*ast.StructType) if !ok { - sh.errs = append(sh.errs, fmt.Sprintf("%s must be a struct but not", name)) + sh.addError(fmt.Sprintf("%s must be a struct but not", name)) return } @@ -100,24 +125,24 @@ func (sh *Shader) parseVaryingStruct(obj *ast.Object) { tag := f.Tag.Value m := kageTagRe.FindStringSubmatch(tag) if m == nil { - sh.errs = append(sh.errs, fmt.Sprintf("invalid struct tag: %s", tag)) + sh.addError(fmt.Sprintf("invalid struct tag: %s", tag)) continue } if m[1] != "position" { - sh.errs = append(sh.errs, fmt.Sprintf("struct tag value must be position in %s but %s", varyingStructName, m[1])) + sh.addError(fmt.Sprintf("struct tag value must be position in %s but %s", varyingStructName, m[1])) continue } if len(f.Names) != 1 { - sh.errs = append(sh.errs, fmt.Sprintf("position members must be one")) + sh.addError(fmt.Sprintf("position members must be one")) continue } t, err := parseType(f.Type) if err != nil { - sh.errs = append(sh.errs, err.Error()) + sh.addError(err.Error()) continue } if t != typVec4 { - sh.errs = append(sh.errs, fmt.Sprintf("position must be vec4 but %s", t)) + sh.addError(fmt.Sprintf("position must be vec4 but %s", t)) continue } sh.position = nameAndType{ @@ -128,11 +153,11 @@ func (sh *Shader) parseVaryingStruct(obj *ast.Object) { } t, err := parseType(f.Type) if err != nil { - sh.errs = append(sh.errs, err.Error()) + sh.addError(err.Error()) continue } if !t.numeric() { - sh.errs = append(sh.errs, fmt.Sprintf("members in %s must be numeric but %s", varyingStructName, t)) + sh.addError(fmt.Sprintf("members in %s must be numeric but %s", varyingStructName, t)) continue } for _, n := range f.Names { @@ -142,9 +167,28 @@ func (sh *Shader) parseVaryingStruct(obj *ast.Object) { }) } } - sort.Slice(sh.varyings, func(a, b int) bool { - return sh.varyings[a].name < sh.varyings[b].name - }) +} + +func (s *Shader) parsePackageLevelVariable(name string, obj *ast.Object) { + v, ok := obj.Decl.(*ast.ValueSpec) + if !ok { + s.addError("value spec expected") + return + } + t, err := parseType(v.Type) + if err != nil { + s.addError(err.Error()) + return + } + nt := nameAndType{ + name: name, + typ: t, + } + if 'A' <= name[0] && name[0] <= 'Z' { + s.uniforms = append(s.uniforms, nt) + } else { + s.globals = append(s.globals, nt) + } } // Dump dumps the shader state in an intermediate language. @@ -155,6 +199,15 @@ func (s *Shader) Dump() string { for _, v := range s.varyings { lines = append(lines, fmt.Sprintf("var %s varying %s", v.name, v.typ)) } + + for _, u := range s.uniforms { + lines = append(lines, fmt.Sprintf("var %s uniform %s", u.name, u.typ)) + } + + for _, g := range s.globals { + lines = append(lines, fmt.Sprintf("var %s %s", g.name, g.typ)) + } + return strings.Join(lines, "\n") + "\n" } diff --git a/internal/shader/shader_test.go b/internal/shader/shader_test.go index 466b941ef..75e7b3cde 100644 --- a/internal/shader/shader_test.go +++ b/internal/shader/shader_test.go @@ -33,10 +33,22 @@ type VertexOut struct { TexCoord vec2 Color vec4 } + +var Foo float +var ( + Bar vec2 + Baz, Quux vec3 + qux vec4 +) `, Dump: `var Position varying vec4 // position var Color varying vec4 var TexCoord varying vec2 +var Bar uniform vec2 +var Baz uniform vec3 +var Foo uniform float +var Quux uniform vec3 +var qux vec4 `, }, }