diff --git a/internal/shader/block.go b/internal/shader/block.go index a35d6388c..99263faaf 100644 --- a/internal/shader/block.go +++ b/internal/shader/block.go @@ -15,10 +15,8 @@ package shader import ( - "fmt" "go/ast" "go/token" - "strings" ) type block struct { @@ -31,56 +29,6 @@ type block struct { outer *block } -func (b *block) dump(indent int) []string { - idt := strings.Repeat("\t", indent) - - var lines []string - - for _, t := range b.types { - ls := t.dump(indent) - ls[0] = fmt.Sprintf("type %s %s", t.name, ls[0]) - lines = append(lines, ls...) - } - for _, v := range b.vars { - init := "" - if v.init != nil { - init = " = " + dumpExpr(v.init) - } - lines = append(lines, fmt.Sprintf("%svar %s %s%s", idt, v.name, v.typ, init)) - } - for _, c := range b.consts { - lines = append(lines, fmt.Sprintf("%sconst %s %s = %s", idt, c.name, c.typ, dumpExpr(c.init))) - } - for _, f := range b.funcs { - var args []string - for _, a := range f.args { - args = append(args, fmt.Sprintf("%s %s", a.name, a.typ)) - } - var rets []string - for _, r := range f.rets { - name := r.name - if name == "" { - name = "_" - } - rets = append(rets, fmt.Sprintf("%s %s", name, r.typ)) - } - l := fmt.Sprintf("func %s(%s)", f.name, strings.Join(args, ", ")) - if len(rets) > 0 { - l += " (" + strings.Join(rets, ", ") + ")" - } - l += " {" - lines = append(lines, l) - lines = append(lines, f.body.dump(indent+1)...) - lines = append(lines, "}") - } - - for _, s := range b.stmts { - lines = append(lines, s.dump(indent)...) - } - - return lines -} - type stmtType int const ( @@ -95,52 +43,3 @@ type stmt struct { exprs []ast.Expr block *block } - -func (s *stmt) dump(indent int) []string { - idt := strings.Repeat("\t", indent) - - var lines []string - switch s.stmtType { - case stmtNone: - lines = append(lines, "%s(none)", idt) - case stmtAssign: - lines = append(lines, fmt.Sprintf("%s%s = %s", idt, dumpExpr(s.exprs[0]), dumpExpr(s.exprs[1]))) - case stmtBlock: - lines = append(lines, fmt.Sprintf("%s{", idt)) - lines = append(lines, s.block.dump(indent+1)...) - lines = append(lines, fmt.Sprintf("%s}", idt)) - case stmtReturn: - var expr string - if len(s.exprs) > 0 { - var strs []string - for _, e := range s.exprs { - strs = append(strs, dumpExpr(e)) - } - expr = " " + strings.Join(strs, ", ") - } - lines = append(lines, fmt.Sprintf("%sreturn%s", idt, expr)) - default: - lines = append(lines, fmt.Sprintf("%s(unknown stmt: %d)", idt, s.stmtType)) - } - - return lines -} - -func dumpExpr(e ast.Expr) string { - switch e := e.(type) { - case *ast.BasicLit: - return e.Value - case *ast.CompositeLit: - var vals []string - for _, e := range e.Elts { - vals = append(vals, dumpExpr(e)) - } - return fmt.Sprintf("%s{%s}", e.Type, strings.Join(vals, ", ")) - case *ast.Ident: - return e.Name - case *ast.SelectorExpr: - return fmt.Sprintf("%s.%s", dumpExpr(e.X), dumpExpr(e.Sel)) - default: - return fmt.Sprintf("(unkown expr: %#v)", e) - } -} diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 0015fcb43..6ed6f2a3d 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -19,18 +19,10 @@ import ( "go/ast" "go/parser" "go/token" - "regexp" "sort" "strings" -) -const ( - // TODO: Remove this. This struct is a user-defined struct. - varyingStructName = "VertexOut" -) - -var ( - kageTagRe = regexp.MustCompile("^`" + `kage:\"(.+)\"` + "`$") + "github.com/hajimehoshi/ebiten/internal/shaderir" ) type variable struct { @@ -52,9 +44,11 @@ type function struct { body *block } -type Shader struct { +type compileState struct { fs *token.FileSet + result *shaderir.Program + // position is the field name of VertexOut that represents a vertex position (gl_Position in GLSL). position variable @@ -77,14 +71,14 @@ func (p *ParseError) Error() string { return strings.Join(p.errs, "\n") } -func NewShader(src []byte) (*Shader, error) { +func Compile(src []byte) (*shaderir.Program, error) { fs := token.NewFileSet() f, err := parser.ParseFile(fs, "", src, parser.AllErrors) if err != nil { return nil, err } - s := &Shader{ + s := &compileState{ fs: fs, } s.parse(f) @@ -97,35 +91,35 @@ func NewShader(src []byte) (*Shader, error) { // TODO: Resolve constants // TODO: Make a call graph and reorder the elements. - return s, nil + return s.result, nil } -func (s *Shader) addError(pos token.Pos, str string) { +func (s *compileState) addError(pos token.Pos, str string) { p := s.fs.Position(pos) s.errs = append(s.errs, fmt.Sprintf("%s: %s", p, str)) } -func (sh *Shader) parse(f *ast.File) { +func (cs *compileState) parse(f *ast.File) { for _, d := range f.Decls { - sh.parseDecl(&sh.global, d, true) + cs.parseDecl(&cs.global, d, true) } // TODO: This is duplicated with parseBlock. - sort.Slice(sh.global.consts, func(a, b int) bool { - return sh.global.consts[a].name < sh.global.consts[b].name + sort.Slice(cs.global.consts, func(a, b int) bool { + return cs.global.consts[a].name < cs.global.consts[b].name }) - sort.Slice(sh.global.funcs, func(a, b int) bool { - return sh.global.funcs[a].name < sh.global.funcs[b].name + sort.Slice(cs.global.funcs, func(a, b int) bool { + return cs.global.funcs[a].name < cs.global.funcs[b].name }) - sort.Slice(sh.varyings, func(a, b int) bool { - return sh.varyings[a].name < sh.varyings[b].name + sort.Slice(cs.varyings, func(a, b int) bool { + return cs.varyings[a].name < cs.varyings[b].name }) - sort.Slice(sh.uniforms, func(a, b int) bool { - return sh.uniforms[a].name < sh.uniforms[b].name + sort.Slice(cs.uniforms, func(a, b int) bool { + return cs.uniforms[a].name < cs.uniforms[b].name }) } -func (sh *Shader) parseDecl(b *block, d ast.Decl, global bool) { +func (cs *compileState) parseDecl(b *block, d ast.Decl, global bool) { switch d := d.(type) { case *ast.GenDecl: switch d.Tok { @@ -133,81 +127,55 @@ func (sh *Shader) parseDecl(b *block, d ast.Decl, global bool) { // TODO: Parse other types for _, s := range d.Specs { s := s.(*ast.TypeSpec) - t := sh.parseType(s.Type) + t := cs.parseType(s.Type) t.name = s.Name.Name b.types = append(b.types, t) } case token.CONST: for _, s := range d.Specs { s := s.(*ast.ValueSpec) - cs := sh.parseConstant(s) + cs := cs.parseConstant(s) b.consts = append(b.consts, cs...) } case token.VAR: for _, s := range d.Specs { s := s.(*ast.ValueSpec) - vs := sh.parseVariable(b, s) + vs := cs.parseVariable(b, s) if !global { b.vars = append(b.vars, vs...) continue } for i, v := range vs { if v.name[0] < 'A' || 'Z' < v.name[0] { - sh.addError(s.Names[i].Pos(), fmt.Sprintf("global variables must be exposed: %s", v.name)) + cs.addError(s.Names[i].Pos(), fmt.Sprintf("global variables must be exposed: %s", v.name)) } // TODO: Check RHS - sh.uniforms = append(sh.uniforms, v) + cs.uniforms = append(cs.uniforms, v) } } case token.IMPORT: - sh.addError(d.Pos(), "import is forbidden") + cs.addError(d.Pos(), "import is forbidden") default: - sh.addError(d.Pos(), "unexpected token") + cs.addError(d.Pos(), "unexpected token") } case *ast.FuncDecl: - b.funcs = append(b.funcs, sh.parseFunc(d, b)) + b.funcs = append(b.funcs, cs.parseFunc(d, b)) default: - sh.addError(d.Pos(), "unexpected decl") + cs.addError(d.Pos(), "unexpected decl") } } -func (sh *Shader) parseStruct(t *ast.TypeSpec) { +func (cs *compileState) parseStruct(t *ast.TypeSpec) { s, ok := t.Type.(*ast.StructType) if !ok { - sh.addError(t.Type.Pos(), fmt.Sprintf("%s must be a struct but not", t.Name)) + cs.addError(t.Type.Pos(), fmt.Sprintf("%s must be a struct but not", t.Name)) return } for _, f := range s.Fields.List { - if f.Tag != nil { - tag := f.Tag.Value - m := kageTagRe.FindStringSubmatch(tag) - if m == nil { - sh.addError(f.Tag.Pos(), fmt.Sprintf("invalid struct tag: %s", tag)) - continue - } - if m[1] != "position" { - sh.addError(f.Tag.Pos(), fmt.Sprintf("struct tag value must be position in %s but %s", varyingStructName, m[1])) - continue - } - if len(f.Names) != 1 { - sh.addError(f.Pos(), fmt.Sprintf("position members must be one")) - continue - } - t := sh.parseType(f.Type) - if t.basic != basicTypeVec4 { - sh.addError(f.Type.Pos(), fmt.Sprintf("position must be vec4 but %s", t)) - continue - } - sh.position = variable{ - name: f.Names[0].Name, - typ: t, - } - continue - } - t := sh.parseType(f.Type) + t := cs.parseType(f.Type) for _, n := range f.Names { - sh.varyings = append(sh.varyings, variable{ + cs.varyings = append(cs.varyings, variable{ name: n.Name, typ: t, }) @@ -215,7 +183,7 @@ func (sh *Shader) parseStruct(t *ast.TypeSpec) { } } -func (s *Shader) parseVariable(block *block, vs *ast.ValueSpec) []variable { +func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) []variable { var t typ if vs.Type != nil { t = s.parseType(vs.Type) @@ -226,7 +194,7 @@ func (s *Shader) parseVariable(block *block, vs *ast.ValueSpec) []variable { var init ast.Expr if len(vs.Values) > 0 { init = vs.Values[i] - if t.isNone() { + if t.ir.Main == shaderir.None { t = s.detectType(block, init) } } @@ -240,7 +208,7 @@ func (s *Shader) parseVariable(block *block, vs *ast.ValueSpec) []variable { return vars } -func (s *Shader) parseConstant(vs *ast.ValueSpec) []constant { +func (s *compileState) parseConstant(vs *ast.ValueSpec) []constant { var t typ if vs.Type != nil { t = s.parseType(vs.Type) @@ -257,19 +225,19 @@ func (s *Shader) parseConstant(vs *ast.ValueSpec) []constant { return cs } -func (sh *Shader) parseFunc(d *ast.FuncDecl, block *block) function { +func (cs *compileState) parseFunc(d *ast.FuncDecl, block *block) function { if d.Name == nil { - sh.addError(d.Pos(), "function must have a name") + cs.addError(d.Pos(), "function must have a name") return function{} } if d.Body == nil { - sh.addError(d.Pos(), "function must have a body") + cs.addError(d.Pos(), "function must have a body") return function{} } var args []variable for _, f := range d.Type.Params.List { - t := sh.parseType(f.Type) + t := cs.parseType(f.Type) for _, n := range f.Names { args = append(args, variable{ name: n.Name, @@ -281,7 +249,7 @@ func (sh *Shader) parseFunc(d *ast.FuncDecl, block *block) function { var rets []variable if d.Type.Results != nil { for _, f := range d.Type.Results.List { - t := sh.parseType(f.Type) + t := cs.parseType(f.Type) if len(f.Names) == 0 { rets = append(rets, variable{ name: "", @@ -302,11 +270,11 @@ func (sh *Shader) parseFunc(d *ast.FuncDecl, block *block) function { name: d.Name.Name, args: args, rets: rets, - body: sh.parseBlock(block, d.Body), + body: cs.parseBlock(block, d.Body), } } -func (sh *Shader) parseBlock(outer *block, b *ast.BlockStmt) *block { +func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt) *block { block := &block{ outer: outer, } @@ -321,7 +289,7 @@ func (sh *Shader) parseBlock(outer *block, b *ast.BlockStmt) *block { name: s.(*ast.Ident).Name, } if len(l.Rhs) > 0 { - v.typ = sh.detectType(block, l.Rhs[i]) + v.typ = cs.detectType(block, l.Rhs[i]) } block.vars = append(block.vars, v) } @@ -343,10 +311,10 @@ func (sh *Shader) parseBlock(outer *block, b *ast.BlockStmt) *block { case *ast.BlockStmt: block.stmts = append(block.stmts, stmt{ stmtType: stmtBlock, - block: sh.parseBlock(block, l), + block: cs.parseBlock(block, l), }) case *ast.DeclStmt: - sh.parseDecl(block, l.Decl, false) + cs.parseDecl(block, l.Decl, false) case *ast.ReturnStmt: var exprs []ast.Expr for _, r := range l.Results { @@ -369,12 +337,12 @@ func (sh *Shader) parseBlock(outer *block, b *ast.BlockStmt) *block { return block } -func (s *Shader) detectType(b *block, expr ast.Expr) typ { +func (s *compileState) detectType(b *block, expr ast.Expr) typ { switch e := expr.(type) { case *ast.BasicLit: if e.Kind == token.FLOAT { return typ{ - basic: basicTypeFloat, + ir: shaderir.Type{Main: shaderir.Float}, } } if e.Kind == token.INT { @@ -411,22 +379,3 @@ func (s *Shader) detectType(b *block, expr ast.Expr) typ { return typ{} } } - -// Dump dumps the shader state in an intermediate language. -func (s *Shader) Dump() string { - var lines []string - - if s.position.name != "" { - lines = append(lines, fmt.Sprintf("var %s varying %s // position", s.position.name, s.position.typ)) - } - for _, v := range s.varyings { - lines = append(lines, fmt.Sprintf("var %s varying %s", v.name, v.typ)) - } - for _, u := range s.uniforms { - lines = append(lines, fmt.Sprintf("var %s uniform %s", u.name, u.typ)) - } - - lines = append(lines, s.global.dump(0)...) - - return strings.Join(lines, "\n") + "\n" -} diff --git a/internal/shader/shader_test.go b/internal/shader/shader_test.go index 250c83f26..1b37ea77b 100644 --- a/internal/shader/shader_test.go +++ b/internal/shader/shader_test.go @@ -24,97 +24,98 @@ func TestDump(t *testing.T) { tests := []struct { Name string Src string - Dump string + VS string + FS string }{ - { - Name: "general", - Src: `package main + /*{ + Name: "general", + Src: `package main -type VertexOut struct { - Position vec4 ` + "`kage:\"position\"`" + ` - TexCoord vec2 - Color vec4 -} + type VertexOut struct { + Position vec4 ` + "`kage:\"position\"`" + ` + TexCoord vec2 + Color vec4 + } -var Foo float -var ( - Bar vec2 - Baz, Quux vec3 -) + var Foo float + var ( + Bar vec2 + Baz, Quux vec3 + ) -const C1 float = 1 -const C2, C3 float = 2, 3 + const C1 float = 1 + const C2, C3 float = 2, 3 -func F1(a, b vec2) vec4 { - var c0 vec2 = a - var c1, c2 = c0, 1.0 - c1.x = c2.x - c3 := vec4{c0, c1} - return c2 -} -`, - Dump: `var Bar uniform vec2 -var Baz uniform vec3 -var Foo uniform float -var Quux uniform vec3 -type VertexOut struct { - Position vec4 - TexCoord vec2 - Color vec4 -} -const C1 float = 1 -const C2 float = 2 -const C3 float = 3 -func F1(a vec2, b vec2) (_ vec4) { - var c0 vec2 = a - var c1 vec2 = c0 - var c2 vec2 = 1.0 - var c3 vec4 - c1.x = c2.x - c3 = vec4{c0, c1} - return c2 -} -`, - }, - { - Name: "AutoType", - Src: `package main + func F1(a, b vec2) vec4 { + var c0 vec2 = a + var c1, c2 = c0, 1.0 + c1.x = c2.x + c3 := vec4{c0, c1} + return c2 + } + `, + Dump: `var Bar uniform vec2 + var Baz uniform vec3 + var Foo uniform float + var Quux uniform vec3 + type VertexOut struct { + Position vec4 + TexCoord vec2 + Color vec4 + } + const C1 float = 1 + const C2 float = 2 + const C3 float = 3 + func F1(a vec2, b vec2) (_ vec4) { + var c0 vec2 = a + var c1 vec2 = c0 + var c2 vec2 = 1.0 + var c3 vec4 + c1.x = c2.x + c3 = vec4{c0, c1} + return c2 + } + `, + }, + { + Name: "AutoType", + Src: `package main -var V0 = 0.0 -func F() { - v1 := V0 -} -`, - Dump: `var V0 uniform float -func F() { - var v1 float - v1 = V0 -} -`, - }, - { - Name: "AutoType2", - Src: `package main + var V0 = 0.0 + func F() { + v1 := V0 + } + `, + Dump: `var V0 uniform float + func F() { + var v1 float + v1 = V0 + } + `, + }, + { + Name: "AutoType2", + Src: `package main -var V0 = 0.0 -func F() { - v1 := V0 - { - v2 := v1 - } -} -`, - Dump: `var V0 uniform float -func F() { - var v1 float - v1 = V0 - { - var v2 float - v2 = v1 - } -} -`, - }, + var V0 = 0.0 + func F() { + v1 := V0 + { + v2 := v1 + } + } + `, + Dump: `var V0 uniform float + func F() { + var v1 float + v1 = V0 + { + var v2 float + v2 = v1 + } + } + `, + },*/ /*{ Name: "Struct", Src: `package main @@ -138,12 +139,16 @@ func F() { },*/ } for _, tc := range tests { - s, err := NewShader([]byte(tc.Src)) + s, err := Compile([]byte(tc.Src)) if err != nil { t.Error(err) continue } - if got, want := s.Dump(), tc.Dump; got != want { + vs, fs := s.Glsl() + if got, want := vs, tc.VS; got != want { + t.Errorf("%s: got: %v, want: %v", tc.Name, got, want) + } + if got, want := fs, tc.FS; got != want { t.Errorf("%s: got: %v, want: %v", tc.Name, got, want) } } diff --git a/internal/shader/type.go b/internal/shader/type.go index 2cc14d77b..93e82147e 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -17,184 +17,70 @@ package shader import ( "fmt" "go/ast" - "strings" -) -type basicType int + "github.com/hajimehoshi/ebiten/internal/shaderir" +) // TODO: What about array types? -const ( - basicTypeNone basicType = iota - basicTypeFloat - basicTypeVec2 - basicTypeVec3 - basicTypeVec4 - basicTypeMat2 - basicTypeMat3 - basicTypeMat4 - basicTypeSampler2d - basicTypeStruct -) - -type structMember struct { - name string - typ typ - tag string -} - type typ struct { - basic basicType - name string - structMembers []structMember + ir shaderir.Type + name string } -func (t *typ) isNone() bool { - return t.basic == basicTypeNone -} - -func (sh *Shader) parseType(expr ast.Expr) typ { +func (cs *compileState) parseType(expr ast.Expr) typ { switch t := expr.(type) { case *ast.Ident: switch t.Name { + case "bool": + return typ{ + ir: shaderir.Type{Main: shaderir.Bool}, + } + case "int": + return typ{ + ir: shaderir.Type{Main: shaderir.Int}, + } case "float": return typ{ - basic: basicTypeFloat, + ir: shaderir.Type{Main: shaderir.Float}, } case "vec2": return typ{ - basic: basicTypeVec2, + ir: shaderir.Type{Main: shaderir.Vec2}, } case "vec3": return typ{ - basic: basicTypeVec3, + ir: shaderir.Type{Main: shaderir.Vec3}, } case "vec4": return typ{ - basic: basicTypeVec4, + ir: shaderir.Type{Main: shaderir.Vec4}, } case "mat2": return typ{ - basic: basicTypeMat2, + ir: shaderir.Type{Main: shaderir.Mat2}, } case "mat3": return typ{ - basic: basicTypeMat3, + ir: shaderir.Type{Main: shaderir.Mat3}, } case "mat4": return typ{ - basic: basicTypeMat4, + ir: shaderir.Type{Main: shaderir.Mat4}, } - case "sampler2d": + case "texture2d": return typ{ - basic: basicTypeSampler2d, + ir: shaderir.Type{Main: shaderir.Texture2D}, } default: - sh.addError(t.Pos(), fmt.Sprintf("unexpected type: %s", t.Name)) + cs.addError(t.Pos(), fmt.Sprintf("unexpected type: %s", t.Name)) return typ{} } case *ast.StructType: - str := typ{ - basic: basicTypeStruct, - } - for _, f := range t.Fields.List { - typ := sh.parseType(f.Type) - var tag string - if f.Tag != nil { - tag = f.Tag.Value - } - for _, n := range f.Names { - str.structMembers = append(str.structMembers, structMember{ - name: n.Name, - typ: typ, - tag: tag, - }) - } - } - return str + cs.addError(t.Pos(), "struct is not implemented") + return typ{} default: - sh.addError(t.Pos(), fmt.Sprintf("unepxected type: %v", t)) + cs.addError(t.Pos(), fmt.Sprintf("unepxected type: %v", t)) return typ{} } } - -func (t typ) dump(indent int) []string { - idt := strings.Repeat("\t", indent) - - switch t.basic { - case basicTypeStruct: - ls := []string{ - fmt.Sprintf("%sstruct {", idt), - } - for _, m := range t.structMembers { - ls = append(ls, fmt.Sprintf("%s\t%s %s", idt, m.name, m.typ)) - } - ls = append(ls, fmt.Sprintf("%s}", idt)) - return ls - default: - return []string{t.basic.String()} - } -} - -func (t typ) String() string { - if t.name != "" { - return t.name - } - return t.basic.String() -} - -func (t basicType) String() string { - switch t { - case basicTypeNone: - return "(none)" - case basicTypeFloat: - return "float" - case basicTypeVec2: - return "vec2" - case basicTypeVec3: - return "vec3" - case basicTypeVec4: - return "vec4" - case basicTypeMat2: - return "mat2" - case basicTypeMat3: - return "mat3" - case basicTypeMat4: - return "mat4" - case basicTypeSampler2d: - return "sampler2d" - case basicTypeStruct: - return "(struct)" - default: - return fmt.Sprintf("unknown(%d)", t) - } -} - -func (t basicType) numeric() bool { - return t != basicTypeNone && t != basicTypeSampler2d -} - -func (t basicType) glslString() string { - switch t { - case basicTypeNone: - return "?(none)" - case basicTypeFloat: - return "float" - case basicTypeVec2: - return "vec2" - case basicTypeVec3: - return "vec3" - case basicTypeVec4: - return "vec4" - case basicTypeMat2: - return "mat2" - case basicTypeMat3: - return "mat3" - case basicTypeMat4: - return "mat4" - case basicTypeSampler2d: - return "?(sampler2d)" - default: - return fmt.Sprintf("?(%d)", t) - } -}