diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 5c03d19dd..6196930c0 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -19,6 +19,7 @@ import ( "go/ast" "go/parser" "go/token" + "strconv" "strings" "github.com/hajimehoshi/ebiten/internal/shaderir" @@ -38,8 +39,6 @@ type constant struct { type function struct { name string - in []string - out []string block *block ir shaderir.Func @@ -218,39 +217,46 @@ func (cs *compileState) parseFunc(d *ast.FuncDecl, block *block) function { return function{} } - var in []string + var vars []variable + var inT []shaderir.Type for _, f := range d.Type.Params.List { t := cs.parseType(f.Type) for _, n := range f.Names { - in = append(in, n.Name) + vars = append(vars, variable{ + name: n.Name, + typ: t, + }) inT = append(inT, t.ir) } } - var out []string var outT []shaderir.Type if d.Type.Results != nil { for _, f := range d.Type.Results.List { t := cs.parseType(f.Type) if len(f.Names) == 0 { - out = append(out, "") + vars = append(vars, variable{ + name: "", + typ: t, + }) outT = append(outT, t.ir) } else { for _, n := range f.Names { - out = append(out, n.Name) + vars = append(vars, variable{ + name: n.Name, + typ: t, + }) outT = append(outT, t.ir) } } } } - b := cs.parseBlock(block, d.Body) + b := cs.parseBlock(block, d.Body, vars) return function{ name: d.Name.Name, - in: in, - out: out, block: b, ir: shaderir.Func{ Index: len(cs.ir.Funcs), @@ -261,8 +267,9 @@ func (cs *compileState) parseFunc(d *ast.FuncDecl, block *block) function { } } -func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt) *block { +func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, locals []variable) *block { block := &block{ + vars: locals, outer: outer, } @@ -304,7 +311,7 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt) *block { cs.parseDecl(block, l.Decl, false) case *ast.ReturnStmt: for _, r := range l.Results { - e := cs.parseExpr(r) + e := cs.parseExpr(block, r) block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ Type: shaderir.Assign, Exprs: []shaderir.Expr{ @@ -369,22 +376,74 @@ func (s *compileState) detectType(b *block, expr ast.Expr) typ { } } -func (cs *compileState) parseExpr(expr ast.Expr) shaderir.Expr { +func (b *block) findLocalVariable(name string) (int, bool) { + for i, v := range b.vars { + if v.name == name { + return i, true + } + } + if b.outer != nil { + return b.outer.findLocalVariable(name) + } + return 0, false +} + +func (cs *compileState) parseExpr(block *block, expr ast.Expr) shaderir.Expr { switch e := expr.(type) { + case *ast.BasicLit: + switch e.Kind { + case token.INT: + v, err := strconv.ParseInt(e.Value, 10, 32) + if err != nil { + cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) + return shaderir.Expr{} + } + return shaderir.Expr{ + Type: shaderir.IntExpr, + Int: int32(v), + } + case token.FLOAT: + v, err := strconv.ParseFloat(e.Value, 32) + if err != nil { + cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) + return shaderir.Expr{} + } + return shaderir.Expr{ + Type: shaderir.FloatExpr, + Float: float32(v), + } + default: + cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e)) + } case *ast.CallExpr: + exprs := []shaderir.Expr{ + cs.parseExpr(block, e.Fun), + } + for _, a := range e.Args { + exprs = append(exprs, cs.parseExpr(block, a)) + } return shaderir.Expr{ - Type: shaderir.Call, - Exprs: []shaderir.Expr{ - cs.parseExpr(e.Fun), - }, + Type: shaderir.Call, + Exprs: exprs, } case *ast.Ident: - return shaderir.Expr{ - Type: shaderir.BuiltinFuncExpr, - BuiltinFunc: shaderir.BuiltinFunc(e.Name), + i, ok := block.findLocalVariable(e.Name) + if ok { + return shaderir.Expr{ + Type: shaderir.LocalVariable, + Index: i, + } } + f, ok := shaderir.ParseBuiltinFunc(e.Name) + if ok { + return shaderir.Expr{ + Type: shaderir.BuiltinFuncExpr, + BuiltinFunc: f, + } + } + cs.addError(e.Pos(), fmt.Sprintf("unexpected identifier: %s", e.Name)) default: - cs.addError(expr.Pos(), fmt.Sprintf("detecting expression not implemented: %#v", e)) + cs.addError(e.Pos(), fmt.Sprintf("expression not implemented: %#v", e)) } return shaderir.Expr{} } diff --git a/internal/shaderir/program.go b/internal/shaderir/program.go index e9fbc62b0..3b1177eee 100644 --- a/internal/shaderir/program.go +++ b/internal/shaderir/program.go @@ -192,3 +192,63 @@ const ( Not BuiltinFunc = "not" Texture2DF BuiltinFunc = "texture2D" ) + +func ParseBuiltinFunc(str string) (BuiltinFunc, bool) { + switch BuiltinFunc(str) { + case Vec2F, + Vec3F, + Vec4F, + Mat2F, + Mat3F, + Mat4F, + Radians, + Degrees, + Sin, + Cos, + Tan, + Asin, + Acos, + Atan, + Pow, + Exp, + Log, + Exp2, + Log2, + Sqrt, + Inversesqrt, + Abs, + Sign, + Floor, + Ceil, + Fract, + Mod, + Min, + Max, + Clamp, + Mix, + Step, + Smoothstep, + Length, + Distance, + Dot, + Cross, + Normalize, + Faceforward, + Reflect, + MatrixCompMult, + OuterProduct, + Transpose, + LessThan, + LessThanEqual, + GreaterThan, + GreaterThanEqual, + Equal, + NotEqual, + Any, + All, + Not, + Texture2DF: + return BuiltinFunc(str), true + } + return "", false +}