shader: Refactoring

This commit is contained in:
Hajime Hoshi 2020-05-10 23:56:56 +09:00
parent ba9d27b8ba
commit 90fdececa0
3 changed files with 11 additions and 35 deletions

View File

@ -125,12 +125,11 @@ func dumpExpr(e ast.Expr) string {
case *ast.BasicLit: case *ast.BasicLit:
return e.Value return e.Value
case *ast.CompositeLit: case *ast.CompositeLit:
t := parseType(e.Type)
var vals []string var vals []string
for _, e := range e.Elts { for _, e := range e.Elts {
vals = append(vals, dumpExpr(e)) vals = append(vals, dumpExpr(e))
} }
return fmt.Sprintf("%s{%s}", t, strings.Join(vals, ", ")) return fmt.Sprintf("%s{%s}", e.Type, strings.Join(vals, ", "))
case *ast.Ident: case *ast.Ident:
return e.Name return e.Name
case *ast.SelectorExpr: case *ast.SelectorExpr:

View File

@ -192,11 +192,7 @@ func (sh *Shader) parseStruct(t *ast.TypeSpec) {
sh.addError(f.Pos(), fmt.Sprintf("position members must be one")) sh.addError(f.Pos(), fmt.Sprintf("position members must be one"))
continue continue
} }
t := parseType(f.Type) t := sh.parseType(f.Type)
if t == basicTypeNone {
sh.addError(f.Type.Pos(), fmt.Sprintf("unexpected type: %s", f.Type))
continue
}
if t != basicTypeVec4 { if t != basicTypeVec4 {
sh.addError(f.Type.Pos(), fmt.Sprintf("position must be vec4 but %s", t)) sh.addError(f.Type.Pos(), fmt.Sprintf("position must be vec4 but %s", t))
continue continue
@ -207,11 +203,7 @@ func (sh *Shader) parseStruct(t *ast.TypeSpec) {
} }
continue continue
} }
t := parseType(f.Type) t := sh.parseType(f.Type)
if t == basicTypeNone {
sh.addError(f.Type.Pos(), fmt.Sprintf("unexpected type: %s", f.Type))
continue
}
if !t.numeric() { if !t.numeric() {
sh.addError(f.Type.Pos(), fmt.Sprintf("members in %s must be numeric but %s", varyingStructName, t)) sh.addError(f.Type.Pos(), fmt.Sprintf("members in %s must be numeric but %s", varyingStructName, t))
continue continue
@ -228,11 +220,7 @@ func (sh *Shader) parseStruct(t *ast.TypeSpec) {
func (s *Shader) parseVariable(block *block, vs *ast.ValueSpec) []variable { func (s *Shader) parseVariable(block *block, vs *ast.ValueSpec) []variable {
var t basicType var t basicType
if vs.Type != nil { if vs.Type != nil {
t = parseType(vs.Type) t = s.parseType(vs.Type)
if t == basicTypeNone {
s.addError(vs.Type.Pos(), fmt.Sprintf("unexpected type: %s", vs.Type))
return nil
}
} }
var vars []variable var vars []variable
@ -257,11 +245,7 @@ func (s *Shader) parseVariable(block *block, vs *ast.ValueSpec) []variable {
func (s *Shader) parseConstant(vs *ast.ValueSpec) []constant { func (s *Shader) parseConstant(vs *ast.ValueSpec) []constant {
var t basicType var t basicType
if vs.Type != nil { if vs.Type != nil {
t = parseType(vs.Type) t = s.parseType(vs.Type)
if t == basicTypeNone {
s.addError(vs.Type.Pos(), fmt.Sprintf("unexpected type: %s", vs.Type))
return nil
}
} }
var cs []constant var cs []constant
@ -287,11 +271,7 @@ func (sh *Shader) parseFunc(d *ast.FuncDecl, block *block) function {
var args []variable var args []variable
for _, f := range d.Type.Params.List { for _, f := range d.Type.Params.List {
t := parseType(f.Type) t := sh.parseType(f.Type)
if t == basicTypeNone {
sh.addError(f.Type.Pos(), fmt.Sprintf("unexpected type: %s", f.Type))
continue
}
for _, n := range f.Names { for _, n := range f.Names {
args = append(args, variable{ args = append(args, variable{
name: n.Name, name: n.Name,
@ -303,11 +283,7 @@ func (sh *Shader) parseFunc(d *ast.FuncDecl, block *block) function {
var rets []variable var rets []variable
if d.Type.Results != nil { if d.Type.Results != nil {
for _, f := range d.Type.Results.List { for _, f := range d.Type.Results.List {
t := parseType(f.Type) t := sh.parseType(f.Type)
if t == basicTypeNone {
sh.addError(f.Type.Pos(), fmt.Sprintf("unexpected type: %s", f.Type))
continue
}
if len(f.Names) == 0 { if len(f.Names) == 0 {
rets = append(rets, variable{ rets = append(rets, variable{
name: "", name: "",
@ -408,7 +384,7 @@ func (s *Shader) detectType(b *block, expr ast.Expr) basicType {
} }
return basicTypeNone return basicTypeNone
case *ast.CompositeLit: case *ast.CompositeLit:
return parseType(e.Type) return s.parseType(e.Type)
case *ast.Ident: case *ast.Ident:
n := e.Name n := e.Name
for _, v := range b.vars { for _, v := range b.vars {

View File

@ -35,7 +35,7 @@ const (
basicTypeSampler2d basicTypeSampler2d
) )
func parseType(expr ast.Expr) basicType { func (s *Shader) parseType(expr ast.Expr) basicType {
switch t := expr.(type) { switch t := expr.(type) {
case *ast.Ident: case *ast.Ident:
switch t.Name { switch t.Name {
@ -55,8 +55,9 @@ func parseType(expr ast.Expr) basicType {
return basicTypeMat4 return basicTypeMat4
case "sampler2d": case "sampler2d":
return basicTypeSampler2d return basicTypeSampler2d
default:
s.addError(t.Pos(), fmt.Sprintf("unexpected type: %s", t.Name))
} }
// TODO: Parse array types
} }
return basicTypeNone return basicTypeNone
} }