diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 4b2a7df10..0e7866484 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -582,68 +582,6 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out return block } -func (s *compileState) detectType(b *block, expr ast.Expr) []shaderir.Type { - switch e := expr.(type) { - case *ast.BasicLit: - switch e.Kind { - case token.FLOAT: - return []shaderir.Type{{Main: shaderir.Float}} - case token.INT: - return []shaderir.Type{{Main: shaderir.Int}} - } - s.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) - return nil - case *ast.CallExpr: - n := e.Fun.(*ast.Ident).Name - f, ok := shaderir.ParseBuiltinFunc(n) - if ok { - switch f { - case shaderir.Vec2F: - return []shaderir.Type{{Main: shaderir.Vec2}} - case shaderir.Vec3F: - return []shaderir.Type{{Main: shaderir.Vec3}} - case shaderir.Vec4F: - return []shaderir.Type{{Main: shaderir.Vec4}} - case shaderir.Mat2F: - return []shaderir.Type{{Main: shaderir.Mat2}} - case shaderir.Mat3F: - return []shaderir.Type{{Main: shaderir.Mat3}} - case shaderir.Mat4F: - return []shaderir.Type{{Main: shaderir.Mat4}} - // TODO: Add more functions - } - } - s.addError(expr.Pos(), fmt.Sprintf("unexpected call: %s", n)) - return nil - case *ast.CompositeLit: - return []shaderir.Type{s.parseType(e.Type)} - case *ast.Ident: - n := e.Name - for _, v := range b.vars { - if v.name == n { - return []shaderir.Type{v.typ} - } - } - if b == &s.global { - for i, v := range s.uniforms { - if v == n { - return []shaderir.Type{s.ir.Uniforms[i]} - } - } - } - if b.outer != nil { - return s.detectType(b.outer, e) - } - s.addError(expr.Pos(), fmt.Sprintf("unexpected identifier: %s", n)) - return nil - //case *ast.SelectorExpr: - //return fmt.Sprintf("%s.%s", dumpExpr(e.X), dumpExpr(e.Sel)) - default: - s.addError(expr.Pos(), fmt.Sprintf("detecting type not implemented: %#v", expr)) - return nil - } -} - func (cs *compileState) parseExpr(block *block, expr ast.Expr) (shaderir.Expr, []shaderir.Stmt) { switch e := expr.(type) { case *ast.BasicLit: diff --git a/internal/shader/type.go b/internal/shader/type.go index 2adab2edd..7084ba59b 100644 --- a/internal/shader/type.go +++ b/internal/shader/type.go @@ -17,6 +17,7 @@ package shader import ( "fmt" "go/ast" + "go/token" "github.com/hajimehoshi/ebiten/internal/shaderir" ) @@ -59,3 +60,65 @@ func (cs *compileState) parseType(expr ast.Expr) shaderir.Type { return shaderir.Type{} } } + +func (cs *compileState) detectType(b *block, expr ast.Expr) []shaderir.Type { + switch e := expr.(type) { + case *ast.BasicLit: + switch e.Kind { + case token.FLOAT: + return []shaderir.Type{{Main: shaderir.Float}} + case token.INT: + return []shaderir.Type{{Main: shaderir.Int}} + } + cs.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) + return nil + case *ast.CallExpr: + n := e.Fun.(*ast.Ident).Name + f, ok := shaderir.ParseBuiltinFunc(n) + if ok { + switch f { + case shaderir.Vec2F: + return []shaderir.Type{{Main: shaderir.Vec2}} + case shaderir.Vec3F: + return []shaderir.Type{{Main: shaderir.Vec3}} + case shaderir.Vec4F: + return []shaderir.Type{{Main: shaderir.Vec4}} + case shaderir.Mat2F: + return []shaderir.Type{{Main: shaderir.Mat2}} + case shaderir.Mat3F: + return []shaderir.Type{{Main: shaderir.Mat3}} + case shaderir.Mat4F: + return []shaderir.Type{{Main: shaderir.Mat4}} + // TODO: Add more functions + } + } + cs.addError(expr.Pos(), fmt.Sprintf("unexpected call: %s", n)) + return nil + case *ast.CompositeLit: + return []shaderir.Type{cs.parseType(e.Type)} + case *ast.Ident: + n := e.Name + for _, v := range b.vars { + if v.name == n { + return []shaderir.Type{v.typ} + } + } + if b == &cs.global { + for i, v := range cs.uniforms { + if v == n { + return []shaderir.Type{cs.ir.Uniforms[i]} + } + } + } + if b.outer != nil { + return cs.detectType(b.outer, e) + } + cs.addError(expr.Pos(), fmt.Sprintf("unexpected identifier: %s", n)) + return nil + //case *ast.SelectorExpr: + //return fmt.Sprintf("%cs.%s", dumpExpr(e.X), dumpExpr(e.Sel)) + default: + cs.addError(expr.Pos(), fmt.Sprintf("detecting type not implemented: %#v", expr)) + return nil + } +}