diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 46fb01667..420bf2a35 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -26,12 +26,12 @@ import ( type variable struct { name string - typ typ + typ shaderir.Type } type constant struct { name string - typ typ + typ shaderir.Type init ast.Expr } @@ -69,6 +69,11 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) { return 0, false } +type typ struct { + name string + ir shaderir.Type +} + type block struct { types []typ vars []variable @@ -179,8 +184,10 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) { for _, s := range d.Specs { s := s.(*ast.TypeSpec) t := cs.parseType(s.Type) - t.name = s.Name.Name - b.types = append(b.types, t) + b.types = append(b.types, typ{ + name: s.Name.Name, + ir: t, + }) } case token.CONST: for _, s := range d.Specs { @@ -202,13 +209,13 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) { } } cs.uniforms = append(cs.uniforms, v.name) - cs.ir.Uniforms = append(cs.ir.Uniforms, v.typ.ir) + cs.ir.Uniforms = append(cs.ir.Uniforms, v.typ) } continue } for i, v := range vs { b.vars = append(b.vars, v) - b.ir.LocalVars = append(b.ir.LocalVars, v.typ.ir) + b.ir.LocalVars = append(b.ir.LocalVars, v.typ) if inits[i] != nil { b.ir.Stmts = append(b.ir.Stmts, shaderir.Stmt{ Type: shaderir.Assign, @@ -248,7 +255,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) { } func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variable, []*shaderir.Expr, []shaderir.Stmt) { - var t typ + var t shaderir.Type if vs.Type != nil { t = s.parseType(vs.Type) } @@ -260,7 +267,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl var init ast.Expr if len(vs.Values) > 0 { init = vs.Values[i] - if t.ir.Main == shaderir.None { + if t.Main == shaderir.None { t = s.detectType(block, init) } } @@ -282,7 +289,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl } func (s *compileState) parseConstant(vs *ast.ValueSpec) []constant { - var t typ + var t shaderir.Type if vs.Type != nil { t = s.parseType(vs.Type) } @@ -318,7 +325,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { name: n.Name, typ: t, }) - inT = append(inT, t.ir) + inT = append(inT, t) } } @@ -333,14 +340,14 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { name: "", typ: t, }) - outT = append(outT, t.ir) + outT = append(outT, t) } else { for _, n := range f.Names { outParams = append(outParams, variable{ name: n.Name, typ: t, }) - outT = append(outT, t.ir) + outT = append(outT, t) } } } @@ -448,7 +455,7 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out } v.typ = cs.detectType(block, l.Rhs[i]) block.vars = append(block.vars, v) - block.ir.LocalVars = append(block.ir.LocalVars, v.typ.ir) + block.ir.LocalVars = append(block.ir.LocalVars, v.typ) // Prase RHS first for the order of the statements. rhs, stmts := cs.parseExpr(block, l.Rhs[i]) @@ -510,43 +517,39 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out return block } -func (s *compileState) detectType(b *block, expr ast.Expr) typ { +func (s *compileState) detectType(b *block, expr ast.Expr) shaderir.Type { switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { case token.FLOAT: - return typ{ - ir: shaderir.Type{Main: shaderir.Float}, - } + return shaderir.Type{Main: shaderir.Float} case token.INT: - return typ{ - ir: shaderir.Type{Main: shaderir.Int}, - } + return shaderir.Type{Main: shaderir.Int} } s.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) - return typ{} + return shaderir.Type{} case *ast.CallExpr: n := e.Fun.(*ast.Ident).Name f, ok := shaderir.ParseBuiltinFunc(n) if ok { switch f { case shaderir.Vec2F: - return typ{ir: shaderir.Type{Main: shaderir.Vec2}} + return shaderir.Type{Main: shaderir.Vec2} case shaderir.Vec3F: - return typ{ir: shaderir.Type{Main: shaderir.Vec3}} + return shaderir.Type{Main: shaderir.Vec3} case shaderir.Vec4F: - return typ{ir: shaderir.Type{Main: shaderir.Vec4}} + return shaderir.Type{Main: shaderir.Vec4} case shaderir.Mat2F: - return typ{ir: shaderir.Type{Main: shaderir.Mat2}} + return shaderir.Type{Main: shaderir.Mat2} case shaderir.Mat3F: - return typ{ir: shaderir.Type{Main: shaderir.Mat3}} + return shaderir.Type{Main: shaderir.Mat3} case shaderir.Mat4F: - return typ{ir: shaderir.Type{Main: shaderir.Mat4}} + return shaderir.Type{Main: shaderir.Mat4} // TODO: Add more functions } } s.addError(expr.Pos(), fmt.Sprintf("unexpected call: %s", n)) - return typ{} + return shaderir.Type{} case *ast.CompositeLit: return s.parseType(e.Type) case *ast.Ident: @@ -559,7 +562,7 @@ func (s *compileState) detectType(b *block, expr ast.Expr) typ { if b == &s.global { for i, v := range s.uniforms { if v == n { - return typ{ir: s.ir.Uniforms[i]} + return s.ir.Uniforms[i] } } } @@ -567,12 +570,12 @@ func (s *compileState) detectType(b *block, expr ast.Expr) typ { return s.detectType(b.outer, e) } s.addError(expr.Pos(), fmt.Sprintf("unexpected identifier: %s", n)) - return typ{} + return shaderir.Type{} //case *ast.SelectorExpr: //return fmt.Sprintf("%s.%s", dumpExpr(e.X), dumpExpr(e.Sel)) default: s.addError(expr.Pos(), fmt.Sprintf("detecting type not implemented: %#v", expr)) - return typ{} + return shaderir.Type{} } } diff --git a/internal/shader/type.go b/internal/shader/type.go index 93e82147e..2adab2edd 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -23,64 +23,39 @@ import ( // TODO: What about array types? -type typ struct { - ir shaderir.Type - name string -} - -func (cs *compileState) parseType(expr ast.Expr) typ { +func (cs *compileState) parseType(expr ast.Expr) shaderir.Type { switch t := expr.(type) { case *ast.Ident: switch t.Name { case "bool": - return typ{ - ir: shaderir.Type{Main: shaderir.Bool}, - } + return shaderir.Type{Main: shaderir.Bool} case "int": - return typ{ - ir: shaderir.Type{Main: shaderir.Int}, - } + return shaderir.Type{Main: shaderir.Int} case "float": - return typ{ - ir: shaderir.Type{Main: shaderir.Float}, - } + return shaderir.Type{Main: shaderir.Float} case "vec2": - return typ{ - ir: shaderir.Type{Main: shaderir.Vec2}, - } + return shaderir.Type{Main: shaderir.Vec2} case "vec3": - return typ{ - ir: shaderir.Type{Main: shaderir.Vec3}, - } + return shaderir.Type{Main: shaderir.Vec3} case "vec4": - return typ{ - ir: shaderir.Type{Main: shaderir.Vec4}, - } + return shaderir.Type{Main: shaderir.Vec4} case "mat2": - return typ{ - ir: shaderir.Type{Main: shaderir.Mat2}, - } + return shaderir.Type{Main: shaderir.Mat2} case "mat3": - return typ{ - ir: shaderir.Type{Main: shaderir.Mat3}, - } + return shaderir.Type{Main: shaderir.Mat3} case "mat4": - return typ{ - ir: shaderir.Type{Main: shaderir.Mat4}, - } + return shaderir.Type{Main: shaderir.Mat4} case "texture2d": - return typ{ - ir: shaderir.Type{Main: shaderir.Texture2D}, - } + return shaderir.Type{Main: shaderir.Texture2D} default: cs.addError(t.Pos(), fmt.Sprintf("unexpected type: %s", t.Name)) - return typ{} + return shaderir.Type{} } case *ast.StructType: cs.addError(t.Pos(), "struct is not implemented") - return typ{} + return shaderir.Type{} default: cs.addError(t.Pos(), fmt.Sprintf("unepxected type: %v", t)) - return typ{} + return shaderir.Type{} } }