diff --git a/internal/shader/shader.go b/internal/shader/shader.go index f7d705f54..f141ee1a1 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -100,32 +100,40 @@ func (s *Shader) addError(pos token.Pos, str string) { s.errs = append(s.errs, fmt.Sprintf("%s: %s", p, str)) } -func (s *Shader) parse(f *ast.File) { - for name, obj := range f.Scope.Objects { - switch name { - case varyingStructName: - s.parseVaryingStruct(obj) - default: - switch obj.Kind { - case ast.Con: - s.parsePackageLevelConstant(name, obj) - case ast.Var: - s.parsePackageLevelVariable(name, obj) +func (sh *Shader) parse(f *ast.File) { + for _, d := range f.Decls { + switch d := d.(type) { + case *ast.GenDecl: + switch d.Tok { + case token.TYPE: + // TODO: Parse regular structs or other types + for _, s := range d.Specs { + s := s.(*ast.TypeSpec) + if s.Name.Name == varyingStructName { + sh.parseVaryingStruct(s) + } + } + case token.CONST: + for _, s := range d.Specs { + s := s.(*ast.ValueSpec) + sh.parsePackageLevelConstant(s) + } + case token.VAR: + for _, s := range d.Specs { + s := s.(*ast.ValueSpec) + sh.parsePackageLevelVariable(s) + } } + default: + // TODO: Parse functions } } } -func (sh *Shader) parseVaryingStruct(obj *ast.Object) { - name := obj.Name - if obj.Kind != ast.Typ { - sh.addError(obj.Pos(), fmt.Sprintf("%s must be a type but %s", name, obj.Kind)) - return - } - t := obj.Decl.(*ast.TypeSpec).Type - s, ok := t.(*ast.StructType) +func (sh *Shader) parseVaryingStruct(t *ast.TypeSpec) { + s, ok := t.Type.(*ast.StructType) if !ok { - sh.addError(t.Pos(), fmt.Sprintf("%s must be a struct but not", name)) + sh.addError(t.Type.Pos(), fmt.Sprintf("%s must be a struct but not", t.Name)) return } @@ -178,45 +186,35 @@ func (sh *Shader) parseVaryingStruct(obj *ast.Object) { } } -func (s *Shader) parsePackageLevelVariable(name string, obj *ast.Object) { - v, ok := obj.Decl.(*ast.ValueSpec) - if !ok { - s.addError(obj.Pos(), "value spec expected") - return - } - t, err := parseType(v.Type) +func (s *Shader) parsePackageLevelVariable(vs *ast.ValueSpec) { + t, err := parseType(vs.Type) if err != nil { - s.addError(v.Type.Pos(), err.Error()) + s.addError(vs.Type.Pos(), err.Error()) return } - val := variable{ - name: name, - typ: t, - } - // TODO: Parse initial value. - if 'A' <= name[0] && name[0] <= 'Z' { - s.uniforms = append(s.uniforms, val) - } else { - s.globals = append(s.globals, val) + for _, n := range vs.Names { + name := n.Name + val := variable{ + name: name, + typ: t, + } + // TODO: Parse initial value. + if 'A' <= name[0] && name[0] <= 'Z' { + s.uniforms = append(s.uniforms, val) + } else { + s.globals = append(s.globals, val) + } } } -func (s *Shader) parsePackageLevelConstant(name string, obj *ast.Object) { - vs, ok := obj.Decl.(*ast.ValueSpec) - if !ok { - s.addError(obj.Pos(), "value spec expected") - return - } +func (s *Shader) parsePackageLevelConstant(vs *ast.ValueSpec) { t, err := parseType(vs.Type) if err != nil { - s.addError(vs.Pos(), err.Error()) + s.addError(vs.Type.Pos(), err.Error()) return } - for i, v := range vs.Values { - if vs.Names[i].Name != name { - continue - } - + for i, n := range vs.Names { + v := vs.Values[i] var init string switch v := v.(type) { case *ast.BasicLit: @@ -228,6 +226,7 @@ func (s *Shader) parsePackageLevelConstant(name string, obj *ast.Object) { default: // TODO: Parse the expression. } + name := n.Name val := variable{ name: name, typ: t, // TODO: Treat consts without types