shader: Parse structs

This commit is contained in:
Hajime Hoshi 2020-05-11 00:13:08 +09:00
parent 90fdececa0
commit 6d182c4b55
4 changed files with 152 additions and 72 deletions

View File

@ -22,6 +22,7 @@ import (
) )
type block struct { type block struct {
types []typ
vars []variable vars []variable
consts []constant consts []constant
funcs []function funcs []function
@ -35,20 +36,25 @@ func (b *block) dump(indent int) []string {
var lines []string var lines []string
for _, t := range b.types {
ls := t.dump(indent)
ls[0] = fmt.Sprintf("type %s %s", t.name, ls[0])
lines = append(lines, ls...)
}
for _, v := range b.vars { for _, v := range b.vars {
init := "" init := ""
if v.init != nil { if v.init != nil {
init = " = " + dumpExpr(v.init) init = " = " + dumpExpr(v.init)
} }
lines = append(lines, fmt.Sprintf("%svar %s %s%s", idt, v.name, v.basicType, init)) lines = append(lines, fmt.Sprintf("%svar %s %s%s", idt, v.name, v.typ, init))
} }
for _, c := range b.consts { for _, c := range b.consts {
lines = append(lines, fmt.Sprintf("%sconst %s %s = %s", idt, c.name, c.basicType, dumpExpr(c.init))) lines = append(lines, fmt.Sprintf("%sconst %s %s = %s", idt, c.name, c.typ, dumpExpr(c.init)))
} }
for _, f := range b.funcs { for _, f := range b.funcs {
var args []string var args []string
for _, a := range f.args { for _, a := range f.args {
args = append(args, fmt.Sprintf("%s %s", a.name, a.basicType)) args = append(args, fmt.Sprintf("%s %s", a.name, a.typ))
} }
var rets []string var rets []string
for _, r := range f.rets { for _, r := range f.rets {
@ -56,7 +62,7 @@ func (b *block) dump(indent int) []string {
if name == "" { if name == "" {
name = "_" name = "_"
} }
rets = append(rets, fmt.Sprintf("%s %s", name, r.basicType)) rets = append(rets, fmt.Sprintf("%s %s", name, r.typ))
} }
l := fmt.Sprintf("func %s(%s)", f.name, strings.Join(args, ", ")) l := fmt.Sprintf("func %s(%s)", f.name, strings.Join(args, ", "))
if len(rets) > 0 { if len(rets) > 0 {

View File

@ -34,15 +34,15 @@ var (
) )
type variable struct { type variable struct {
name string name string
basicType basicType typ typ
init ast.Expr init ast.Expr
} }
type constant struct { type constant struct {
name string name string
basicType basicType typ typ
init ast.Expr init ast.Expr
} }
type function struct { type function struct {
@ -133,7 +133,9 @@ func (sh *Shader) parseDecl(b *block, d ast.Decl, global bool) {
// TODO: Parse other types // TODO: Parse other types
for _, s := range d.Specs { for _, s := range d.Specs {
s := s.(*ast.TypeSpec) s := s.(*ast.TypeSpec)
sh.parseStruct(s) t := sh.parseType(s.Type)
t.name = s.Name.Name
b.types = append(b.types, t)
} }
case token.CONST: case token.CONST:
for _, s := range d.Specs { for _, s := range d.Specs {
@ -193,32 +195,28 @@ func (sh *Shader) parseStruct(t *ast.TypeSpec) {
continue continue
} }
t := sh.parseType(f.Type) t := sh.parseType(f.Type)
if t != basicTypeVec4 { if t.basic != basicTypeVec4 {
sh.addError(f.Type.Pos(), fmt.Sprintf("position must be vec4 but %s", t)) sh.addError(f.Type.Pos(), fmt.Sprintf("position must be vec4 but %s", t))
continue continue
} }
sh.position = variable{ sh.position = variable{
name: f.Names[0].Name, name: f.Names[0].Name,
basicType: t, typ: t,
} }
continue continue
} }
t := sh.parseType(f.Type) t := sh.parseType(f.Type)
if !t.numeric() {
sh.addError(f.Type.Pos(), fmt.Sprintf("members in %s must be numeric but %s", varyingStructName, t))
continue
}
for _, n := range f.Names { for _, n := range f.Names {
sh.varyings = append(sh.varyings, variable{ sh.varyings = append(sh.varyings, variable{
name: n.Name, name: n.Name,
basicType: t, typ: t,
}) })
} }
} }
} }
func (s *Shader) parseVariable(block *block, vs *ast.ValueSpec) []variable { func (s *Shader) parseVariable(block *block, vs *ast.ValueSpec) []variable {
var t basicType var t typ
if vs.Type != nil { if vs.Type != nil {
t = s.parseType(vs.Type) t = s.parseType(vs.Type)
} }
@ -228,22 +226,22 @@ func (s *Shader) parseVariable(block *block, vs *ast.ValueSpec) []variable {
var init ast.Expr var init ast.Expr
if len(vs.Values) > 0 { if len(vs.Values) > 0 {
init = vs.Values[i] init = vs.Values[i]
if t == basicTypeNone { if t.isNone() {
t = s.detectType(block, init) t = s.detectType(block, init)
} }
} }
name := n.Name name := n.Name
vars = append(vars, variable{ vars = append(vars, variable{
name: name, name: name,
basicType: t, typ: t,
init: init, init: init,
}) })
} }
return vars return vars
} }
func (s *Shader) parseConstant(vs *ast.ValueSpec) []constant { func (s *Shader) parseConstant(vs *ast.ValueSpec) []constant {
var t basicType var t typ
if vs.Type != nil { if vs.Type != nil {
t = s.parseType(vs.Type) t = s.parseType(vs.Type)
} }
@ -251,9 +249,9 @@ func (s *Shader) parseConstant(vs *ast.ValueSpec) []constant {
var cs []constant var cs []constant
for i, n := range vs.Names { for i, n := range vs.Names {
cs = append(cs, constant{ cs = append(cs, constant{
name: n.Name, name: n.Name,
basicType: t, typ: t,
init: vs.Values[i], init: vs.Values[i],
}) })
} }
return cs return cs
@ -274,8 +272,8 @@ func (sh *Shader) parseFunc(d *ast.FuncDecl, block *block) function {
t := sh.parseType(f.Type) t := sh.parseType(f.Type)
for _, n := range f.Names { for _, n := range f.Names {
args = append(args, variable{ args = append(args, variable{
name: n.Name, name: n.Name,
basicType: t, typ: t,
}) })
} }
} }
@ -286,14 +284,14 @@ func (sh *Shader) parseFunc(d *ast.FuncDecl, block *block) function {
t := sh.parseType(f.Type) t := sh.parseType(f.Type)
if len(f.Names) == 0 { if len(f.Names) == 0 {
rets = append(rets, variable{ rets = append(rets, variable{
name: "", name: "",
basicType: t, typ: t,
}) })
} else { } else {
for _, n := range f.Names { for _, n := range f.Names {
rets = append(rets, variable{ rets = append(rets, variable{
name: n.Name, name: n.Name,
basicType: t, typ: t,
}) })
} }
} }
@ -323,7 +321,7 @@ func (sh *Shader) parseBlock(outer *block, b *ast.BlockStmt) *block {
name: s.(*ast.Ident).Name, name: s.(*ast.Ident).Name,
} }
if len(l.Rhs) > 0 { if len(l.Rhs) > 0 {
v.basicType = sh.detectType(block, l.Rhs[i]) v.typ = sh.detectType(block, l.Rhs[i])
} }
block.vars = append(block.vars, v) block.vars = append(block.vars, v)
} }
@ -371,31 +369,33 @@ func (sh *Shader) parseBlock(outer *block, b *ast.BlockStmt) *block {
return block return block
} }
func (s *Shader) detectType(b *block, expr ast.Expr) basicType { func (s *Shader) detectType(b *block, expr ast.Expr) typ {
switch e := expr.(type) { switch e := expr.(type) {
case *ast.BasicLit: case *ast.BasicLit:
if e.Kind == token.FLOAT { if e.Kind == token.FLOAT {
return basicTypeFloat return typ{
basic: basicTypeFloat,
}
} }
if e.Kind == token.INT { if e.Kind == token.INT {
s.addError(expr.Pos(), fmt.Sprintf("integer literal is not implemented yet: %s", e.Value)) s.addError(expr.Pos(), fmt.Sprintf("integer literal is not implemented yet: %s", e.Value))
} else { } else {
s.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) s.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value))
} }
return basicTypeNone return typ{}
case *ast.CompositeLit: case *ast.CompositeLit:
return s.parseType(e.Type) return s.parseType(e.Type)
case *ast.Ident: case *ast.Ident:
n := e.Name n := e.Name
for _, v := range b.vars { for _, v := range b.vars {
if v.name == n { if v.name == n {
return v.basicType return v.typ
} }
} }
if b == &s.global { if b == &s.global {
for _, v := range s.uniforms { for _, v := range s.uniforms {
if v.name == n { if v.name == n {
return v.basicType return v.typ
} }
} }
} }
@ -403,12 +403,12 @@ func (s *Shader) detectType(b *block, expr ast.Expr) basicType {
return s.detectType(b.outer, e) return s.detectType(b.outer, e)
} }
s.addError(expr.Pos(), fmt.Sprintf("unexpected identity: %s", n)) s.addError(expr.Pos(), fmt.Sprintf("unexpected identity: %s", n))
return basicTypeNone return typ{}
//case *ast.SelectorExpr: //case *ast.SelectorExpr:
//return fmt.Sprintf("%s.%s", dumpExpr(e.X), dumpExpr(e.Sel)) //return fmt.Sprintf("%s.%s", dumpExpr(e.X), dumpExpr(e.Sel))
default: default:
s.addError(expr.Pos(), fmt.Sprintf("detecting type not implemented: %#v", expr)) s.addError(expr.Pos(), fmt.Sprintf("detecting type not implemented: %#v", expr))
return basicTypeNone return typ{}
} }
} }
@ -417,27 +417,16 @@ func (s *Shader) Dump() string {
var lines []string var lines []string
if s.position.name != "" { if s.position.name != "" {
lines = append(lines, fmt.Sprintf("var %s varying %s // position", s.position.name, s.position.basicType)) lines = append(lines, fmt.Sprintf("var %s varying %s // position", s.position.name, s.position.typ))
} }
for _, v := range s.varyings { for _, v := range s.varyings {
lines = append(lines, fmt.Sprintf("var %s varying %s", v.name, v.basicType)) lines = append(lines, fmt.Sprintf("var %s varying %s", v.name, v.typ))
} }
for _, u := range s.uniforms { for _, u := range s.uniforms {
lines = append(lines, fmt.Sprintf("var %s uniform %s", u.name, u.basicType)) lines = append(lines, fmt.Sprintf("var %s uniform %s", u.name, u.typ))
} }
lines = append(lines, s.global.dump(0)...) lines = append(lines, s.global.dump(0)...)
return strings.Join(lines, "\n") + "\n" return strings.Join(lines, "\n") + "\n"
} }
func (s *Shader) GlslVertex() string {
var lines []string
for _, v := range s.varyings {
// TODO: variable names must be escaped not to conflict with keywords.
lines = append(lines, fmt.Sprintf("varying %s %s;", v.basicType.glslString(), v.name))
}
return strings.Join(lines, "\n") + "\n"
}

View File

@ -53,13 +53,15 @@ func F1(a, b vec2) vec4 {
return c2 return c2
} }
`, `,
Dump: `var Position varying vec4 // position Dump: `var Bar uniform vec2
var Color varying vec4
var TexCoord varying vec2
var Bar uniform vec2
var Baz uniform vec3 var Baz uniform vec3
var Foo uniform float var Foo uniform float
var Quux uniform vec3 var Quux uniform vec3
type VertexOut struct {
Position vec4
TexCoord vec2
Color vec4
}
const C1 float = 1 const C1 float = 1
const C2 float = 2 const C2 float = 2
const C3 float = 3 const C3 float = 3

View File

@ -17,6 +17,7 @@ package shader
import ( import (
"fmt" "fmt"
"go/ast" "go/ast"
"strings"
) )
type basicType int type basicType int
@ -33,33 +34,113 @@ const (
basicTypeMat3 basicTypeMat3
basicTypeMat4 basicTypeMat4
basicTypeSampler2d basicTypeSampler2d
basicTypeStruct
) )
func (s *Shader) parseType(expr ast.Expr) basicType { type structMember struct {
name string
typ typ
tag string
}
type typ struct {
basic basicType
name string
structMembers []structMember
}
func (t *typ) isNone() bool {
return t.basic == basicTypeNone
}
func (sh *Shader) parseType(expr ast.Expr) typ {
switch t := expr.(type) { switch t := expr.(type) {
case *ast.Ident: case *ast.Ident:
switch t.Name { switch t.Name {
case "float": case "float":
return basicTypeFloat return typ{
basic: basicTypeFloat,
}
case "vec2": case "vec2":
return basicTypeVec2 return typ{
basic: basicTypeVec2,
}
case "vec3": case "vec3":
return basicTypeVec3 return typ{
basic: basicTypeVec3,
}
case "vec4": case "vec4":
return basicTypeVec4 return typ{
basic: basicTypeVec4,
}
case "mat2": case "mat2":
return basicTypeMat2 return typ{
basic: basicTypeMat2,
}
case "mat3": case "mat3":
return basicTypeMat3 return typ{
basic: basicTypeMat3,
}
case "mat4": case "mat4":
return basicTypeMat4 return typ{
basic: basicTypeMat4,
}
case "sampler2d": case "sampler2d":
return basicTypeSampler2d return typ{
basic: basicTypeSampler2d,
}
default: default:
s.addError(t.Pos(), fmt.Sprintf("unexpected type: %s", t.Name)) sh.addError(t.Pos(), fmt.Sprintf("unexpected type: %s", t.Name))
return typ{}
} }
case *ast.StructType:
str := typ{
basic: basicTypeStruct,
}
for _, f := range t.Fields.List {
typ := sh.parseType(f.Type)
var tag string
if f.Tag != nil {
tag = f.Tag.Value
}
for _, n := range f.Names {
str.structMembers = append(str.structMembers, structMember{
name: n.Name,
typ: typ,
tag: tag,
})
}
}
return str
default:
sh.addError(t.Pos(), fmt.Sprintf("unepxected type: %v", t))
return typ{}
} }
return basicTypeNone }
func (t typ) dump(indent int) []string {
idt := strings.Repeat("\t", indent)
switch t.basic {
case basicTypeStruct:
ls := []string{
fmt.Sprintf("%sstruct {", idt),
}
for _, m := range t.structMembers {
ls = append(ls, fmt.Sprintf("%s\t%s %s", idt, m.name, m.typ))
}
ls = append(ls, fmt.Sprintf("%s}", idt))
return ls
default:
return []string{t.basic.String()}
}
}
func (t typ) String() string {
if t.name != "" {
return t.name
}
return t.basic.String()
} }
func (t basicType) String() string { func (t basicType) String() string {
@ -82,6 +163,8 @@ func (t basicType) String() string {
return "mat4" return "mat4"
case basicTypeSampler2d: case basicTypeSampler2d:
return "sampler2d" return "sampler2d"
case basicTypeStruct:
return "(struct)"
default: default:
return fmt.Sprintf("unknown(%d)", t) return fmt.Sprintf("unknown(%d)", t)
} }