From 9430b6be376c641033c4f3981ab2ec6cb897efea Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Sun, 10 May 2020 02:01:28 +0900 Subject: [PATCH] shader: Parse assignments --- internal/shader/block.go | 49 ++++++++--------- internal/shader/shader.go | 97 ++++++++++++++++------------------ internal/shader/shader_test.go | 8 +-- internal/shader/type.go | 20 +++---- 4 files changed, 85 insertions(+), 89 deletions(-) diff --git a/internal/shader/block.go b/internal/shader/block.go index 23bd272d5..ff3b8bd00 100644 --- a/internal/shader/block.go +++ b/internal/shader/block.go @@ -16,6 +16,7 @@ package shader import ( "fmt" + "go/ast" "go/token" "strings" ) @@ -35,13 +36,13 @@ func (b *block) dump(indent int) []string { for _, v := range b.vars { init := "" - if v.init != "" { - init = " = " + v.init + if v.init != nil { + init = " = " + dumpExpr(v.init) } lines = append(lines, fmt.Sprintf("%svar %s %s%s", idt, v.name, v.typ, init)) } for _, c := range b.consts { - lines = append(lines, fmt.Sprintf("%sconst %s %s = %s", idt, c.name, c.typ, c.init)) + lines = append(lines, fmt.Sprintf("%sconst %s %s = %s", idt, c.name, c.typ, dumpExpr(c.init))) } for _, f := range b.funcs { var args []string @@ -77,12 +78,13 @@ type stmtType int const ( stmtNone stmtType = iota + stmtAssign stmtReturn ) type stmt struct { stmtType stmtType - exprs []expr + exprs []ast.Expr } func (s *stmt) dump(indent int) []string { @@ -92,12 +94,14 @@ func (s *stmt) dump(indent int) []string { switch s.stmtType { case stmtNone: lines = append(lines, "%s(none)", idt) + case stmtAssign: + lines = append(lines, fmt.Sprintf("%s%s = %s", idt, dumpExpr(s.exprs[0]), dumpExpr(s.exprs[1]))) case stmtReturn: var expr string if len(s.exprs) > 0 { var strs []string for _, e := range s.exprs { - strs = append(strs, e.dump()) + strs = append(strs, dumpExpr(e)) } expr = " " + strings.Join(strs, ", ") } @@ -109,25 +113,22 @@ func (s *stmt) dump(indent int) []string { return lines } -type exprType int - -const ( - exprNone exprType = iota - exprIdent -) - -type expr struct { - exprType exprType - value string -} - -func (e *expr) dump() string { - switch e.exprType { - case exprNone: - return "(none)" - case exprIdent: - return e.value +func dumpExpr(e ast.Expr) string { + switch e := e.(type) { + case *ast.BasicLit: + return e.Value + case *ast.CompositeLit: + t := parseType(e.Type) + var vals []string + for _, e := range e.Elts { + vals = append(vals, dumpExpr(e)) + } + return fmt.Sprintf("%s{%s}", t, strings.Join(vals, ", ")) + case *ast.Ident: + return e.Name + case *ast.SelectorExpr: + return fmt.Sprintf("%s.%s", dumpExpr(e.X), dumpExpr(e.Sel)) default: - return fmt.Sprintf("(unkown expr: %d)", e.exprType) + return fmt.Sprintf("(unkown expr: %#v)", e) } } diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 6b6ecba30..c91ab030f 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -35,13 +35,13 @@ var ( type variable struct { name string typ typ - init string + init ast.Expr } type constant struct { name string typ typ - init string + init ast.Expr } type function struct { @@ -194,9 +194,9 @@ func (sh *Shader) parseVaryingStruct(t *ast.TypeSpec) { sh.addError(f.Pos(), fmt.Sprintf("position members must be one")) continue } - t, err := parseType(f.Type) - if err != nil { - sh.addError(f.Type.Pos(), err.Error()) + t := parseType(f.Type) + if t == typNone { + sh.addError(f.Type.Pos(), fmt.Sprintf("unexpected type: %s", f.Type)) continue } if t != typVec4 { @@ -209,9 +209,9 @@ func (sh *Shader) parseVaryingStruct(t *ast.TypeSpec) { } continue } - t, err := parseType(f.Type) - if err != nil { - sh.addError(f.Type.Pos(), err.Error()) + t := parseType(f.Type) + if t == typNone { + sh.addError(f.Type.Pos(), fmt.Sprintf("unexpected type: %s", f.Type)) continue } if !t.numeric() { @@ -230,20 +230,24 @@ func (sh *Shader) parseVaryingStruct(t *ast.TypeSpec) { func (s *Shader) parseVariable(vs *ast.ValueSpec) []variable { var t typ if vs.Type != nil { - var err error - t, err = parseType(vs.Type) - if err != nil { - s.addError(vs.Type.Pos(), err.Error()) + t = parseType(vs.Type) + if t == typNone { + s.addError(vs.Type.Pos(), fmt.Sprintf("unexpected type: %s", vs.Type)) return nil } } var vars []variable - for _, n := range vs.Names { + for i, n := range vs.Names { + var init ast.Expr + if len(vs.Values) > 0 { + init = vs.Values[i] + } name := n.Name vars = append(vars, variable{ name: name, typ: t, + init: init, }) } return vars @@ -252,32 +256,19 @@ func (s *Shader) parseVariable(vs *ast.ValueSpec) []variable { func (s *Shader) parseConstant(vs *ast.ValueSpec) []constant { var t typ if vs.Type != nil { - var err error - t, err = parseType(vs.Type) - if err != nil { - s.addError(vs.Type.Pos(), err.Error()) + t = parseType(vs.Type) + if t == typNone { + s.addError(vs.Type.Pos(), fmt.Sprintf("unexpected type: %s", vs.Type)) return nil } } var cs []constant for i, n := range vs.Names { - v := vs.Values[i] - var init string - switch v := v.(type) { - case *ast.BasicLit: - if v.Kind != token.INT && v.Kind != token.FLOAT { - s.addError(v.Pos(), fmt.Sprintf("literal must be int or float but %s", v.Kind)) - return cs - } - init = v.Value // TODO: This should be go/constant.Value - default: - // TODO: Parse the expression. - } cs = append(cs, constant{ name: n.Name, typ: t, - init: init, + init: vs.Values[i], }) } return cs @@ -295,9 +286,9 @@ func (sh *Shader) parseFunc(d *ast.FuncDecl) function { var args []variable for _, f := range d.Type.Params.List { - t, err := parseType(f.Type) - if err != nil { - sh.addError(f.Type.Pos(), err.Error()) + t := parseType(f.Type) + if t == typNone { + sh.addError(f.Type.Pos(), fmt.Sprintf("unexpected type: %s", f.Type)) continue } for _, n := range f.Names { @@ -310,9 +301,9 @@ func (sh *Shader) parseFunc(d *ast.FuncDecl) function { var rets []variable for _, f := range d.Type.Results.List { - t, err := parseType(f.Type) - if err != nil { - sh.addError(f.Type.Pos(), err.Error()) + t := parseType(f.Type) + if t == typNone { + sh.addError(f.Type.Pos(), fmt.Sprintf("unexpected type: %s", f.Type)) continue } if len(f.Names) == 0 { @@ -344,22 +335,35 @@ func (sh *Shader) parseBlock(b *ast.BlockStmt) *block { for _, l := range b.List { switch l := l.(type) { case *ast.AssignStmt: - if l.Tok == token.DEFINE { + switch l.Tok { + case token.DEFINE: for _, s := range l.Lhs { ident := s.(*ast.Ident) block.vars = append(block.vars, variable{ name: ident.Name, }) } - } else { - // TODO + for i := range l.Rhs { + block.stmts = append(block.stmts, stmt{ + stmtType: stmtAssign, + exprs: []ast.Expr{l.Lhs[i], l.Rhs[i]}, + }) + } + case token.ASSIGN: + // TODO: What about the statement `a,b = b,a?` + for i := range l.Rhs { + block.stmts = append(block.stmts, stmt{ + stmtType: stmtAssign, + exprs: []ast.Expr{l.Lhs[i], l.Rhs[i]}, + }) + } } case *ast.DeclStmt: sh.parseDecl(block, l.Decl) case *ast.ReturnStmt: - var exprs []expr + var exprs []ast.Expr for _, r := range l.Results { - exprs = append(exprs, sh.parseExpr(r)) + exprs = append(exprs, r) } block.stmts = append(block.stmts, stmt{ stmtType: stmtReturn, @@ -379,17 +383,6 @@ func (sh *Shader) parseBlock(b *ast.BlockStmt) *block { return block } -func (sh *Shader) parseExpr(e ast.Expr) expr { - switch e := e.(type) { - case *ast.Ident: - return expr{ - exprType: exprIdent, - value: e.Name, - } - } - return expr{} -} - // Dump dumps the shader state in an intermediate language. func (s *Shader) Dump() string { var lines []string diff --git a/internal/shader/shader_test.go b/internal/shader/shader_test.go index 5be4a5ee5..2946bef44 100644 --- a/internal/shader/shader_test.go +++ b/internal/shader/shader_test.go @@ -64,10 +64,12 @@ const C1 float = 1 const C2 float = 2 const C3 float = 3 func F1(a vec2, b vec2) (_ vec4) { - var c0 vec2 - var c1 (none) - var c2 (none) + var c0 vec2 = a + var c1 (none) = b + var c2 (none) = 1.0 var c3 (none) + c1.x = c2.x + c3 = vec4{c0, c1} return c2 } `, diff --git a/internal/shader/type.go b/internal/shader/type.go index 85fe989eb..674df2228 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -35,30 +35,30 @@ const ( typSampler2d ) -func parseType(expr ast.Expr) (typ, error) { +func parseType(expr ast.Expr) typ { switch t := expr.(type) { case *ast.Ident: switch t.Name { case "float": - return typFloat, nil + return typFloat case "vec2": - return typVec2, nil + return typVec2 case "vec3": - return typVec3, nil + return typVec3 case "vec4": - return typVec4, nil + return typVec4 case "mat2": - return typMat2, nil + return typMat2 case "mat3": - return typMat3, nil + return typMat3 case "mat4": - return typMat4, nil + return typMat4 case "sampler2d": - return typSampler2d, nil + return typSampler2d } // TODO: Parse array types } - return 0, fmt.Errorf("invalid type: %s", expr) + return typNone } func (t typ) String() string {