From c986da897082362fdf4287de6b7c1e9e9d9e4000 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Sun, 7 Jun 2020 00:49:28 +0900 Subject: [PATCH] shader: Implement function call --- internal/shader/shader.go | 288 ++++++++++++++++++++++++--------- internal/shader/shader_test.go | 22 +++ 2 files changed, 236 insertions(+), 74 deletions(-) diff --git a/internal/shader/shader.go b/internal/shader/shader.go index 420bf2a35..e5475b2b9 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -53,6 +53,8 @@ type compileState struct { // uniforms is a collection of uniform variable names. uniforms []string + funcs []function + global block varyingParsed bool @@ -60,6 +62,15 @@ type compileState struct { errs []string } +func (cs *compileState) findFunction(name string) (int, bool) { + for i, f := range cs.funcs { + if f.name == name { + return i, true + } + } + return 0, false +} + func (cs *compileState) findUniformVariable(name string) (int, bool) { for i, u := range cs.uniforms { if u == name { @@ -78,7 +89,6 @@ type block struct { types []typ vars []variable consts []constant - funcs []function pos token.Pos outer *block @@ -159,6 +169,40 @@ func (cs *compileState) parse(f *ast.File) { cs.uniforms = unames cs.ir.Uniforms = utypes + // Parse function names so that any other function call the others. + // The function data is provisional and will be updated soon. + for _, d := range f.Decls { + fd, ok := d.(*ast.FuncDecl) + if !ok { + continue + } + n := fd.Name.Name + if n == cs.vertexEntry { + continue + } + if n == cs.fragmentEntry { + continue + } + + inParams, outParams := cs.parseFuncParams(fd) + var inT, outT []shaderir.Type + for _, v := range inParams { + inT = append(inT, v.typ) + } + for _, v := range outParams { + outT = append(outT, v.typ) + } + + cs.funcs = append(cs.funcs, function{ + name: n, + ir: shaderir.Func{ + Index: len(cs.funcs), + InParams: inT, + OutParams: outT, + }, + }) + } + // Parse functions. for _, d := range f.Decls { if _, ok := d.(*ast.FuncDecl); ok { @@ -170,7 +214,7 @@ func (cs *compileState) parse(f *ast.File) { return } - for _, f := range cs.global.funcs { + for _, f := range cs.funcs { cs.ir.Funcs = append(cs.ir.Funcs, f.ir) } } @@ -237,17 +281,25 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) { } case *ast.FuncDecl: f := cs.parseFunc(b, d) - if b == &cs.global { - switch d.Name.Name { - case cs.vertexEntry: - cs.ir.VertexFunc.Block = f.ir.Block - case cs.fragmentEntry: - cs.ir.FragmentFunc.Block = f.ir.Block - default: - b.funcs = append(b.funcs, f) + if b != &cs.global { + cs.addError(d.Pos(), "non-global function is not implemented") + return + } + switch d.Name.Name { + case cs.vertexEntry: + cs.ir.VertexFunc.Block = f.ir.Block + case cs.fragmentEntry: + cs.ir.FragmentFunc.Block = f.ir.Block + default: + // The function is already registered for their names. + for i := range cs.funcs { + if cs.funcs[i].name == d.Name.Name { + // Index is already determined by the provisional parsing. + f.ir.Index = cs.funcs[i].ir.Index + cs.funcs[i] = f + break + } } - } else { - b.funcs = append(b.funcs, f) } default: cs.addError(d.Pos(), "unexpected decl") @@ -305,6 +357,40 @@ func (s *compileState) parseConstant(vs *ast.ValueSpec) []constant { return cs } +func (cs *compileState) parseFuncParams(d *ast.FuncDecl) (in, out []variable) { + for _, f := range d.Type.Params.List { + t := cs.parseType(f.Type) + for _, n := range f.Names { + in = append(in, variable{ + name: n.Name, + typ: t, + }) + } + } + + if d.Type.Results == nil { + return + } + + for _, f := range d.Type.Results.List { + t := cs.parseType(f.Type) + if len(f.Names) == 0 { + out = append(out, variable{ + name: "", + typ: t, + }) + } else { + for _, n := range f.Names { + out = append(out, variable{ + name: n.Name, + typ: t, + }) + } + } + } + return +} + func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { if d.Name == nil { cs.addError(d.Pos(), "function must have a name") @@ -315,51 +401,15 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { return function{} } - var inT []shaderir.Type - var inParams []variable + inParams, outParams := cs.parseFuncParams(d) - for _, f := range d.Type.Params.List { - t := cs.parseType(f.Type) - for _, n := range f.Names { - inParams = append(inParams, variable{ - name: n.Name, - typ: t, - }) - inT = append(inT, t) - } - } - - var outT []shaderir.Type - var outParams []variable - - if d.Type.Results != nil { - for _, f := range d.Type.Results.List { - t := cs.parseType(f.Type) - if len(f.Names) == 0 { - outParams = append(outParams, variable{ - name: "", - typ: t, - }) - outT = append(outT, t) - } else { - for _, n := range f.Names { - outParams = append(outParams, variable{ - name: n.Name, - typ: t, - }) - outT = append(outT, t) - } - } - } - } - - checkVaryings := func(types []shaderir.Type) { - if len(cs.ir.Varyings) != len(types) { + checkVaryings := func(vs []variable) { + if len(cs.ir.Varyings) != len(vs) { cs.addError(d.Pos(), fmt.Sprintf("the number of vertex entry point's returning values and the number of framgent entry point's params must be the same")) return } for i, t := range cs.ir.Varyings { - if t.Main != types[i].Main { + if t.Main != vs[i].typ.Main { cs.addError(d.Pos(), fmt.Sprintf("vertex entry point's returning value types and framgent entry point's param types must match")) } } @@ -368,8 +418,8 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { if block == &cs.global { switch d.Name.Name { case cs.vertexEntry: - for _, t := range inT { - cs.ir.Attributes = append(cs.ir.Attributes, t) + for _, v := range inParams { + cs.ir.Attributes = append(cs.ir.Attributes, v.typ) } // The first out-param is treated as gl_Position in GLSL. @@ -377,17 +427,17 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { cs.addError(d.Pos(), fmt.Sprintf("vertex entry point must have at least one returning vec4 value for a position")) return function{} } - if outT[0].Main != shaderir.Vec4 { + if outParams[0].typ.Main != shaderir.Vec4 { cs.addError(d.Pos(), fmt.Sprintf("vertex entry point must have at least one returning vec4 value for a position")) return function{} } if cs.varyingParsed { - checkVaryings(outT[1:]) + checkVaryings(outParams[1:]) } else { - for _, t := range outT[1:] { + for _, v := range outParams[1:] { // TODO: Check that these params are not arrays or structs - cs.ir.Varyings = append(cs.ir.Varyings, t) + cs.ir.Varyings = append(cs.ir.Varyings, v.typ) } } cs.varyingParsed = true @@ -396,7 +446,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have at least one vec4 parameter for a position")) return function{} } - if inT[0].Main != shaderir.Vec4 { + if inParams[0].typ.Main != shaderir.Vec4 { cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have at least one vec4 parameter for a position")) return function{} } @@ -405,16 +455,16 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have one returning vec4 value for a color")) return function{} } - if outT[0].Main != shaderir.Vec4 { + if outParams[0].typ.Main != shaderir.Vec4 { cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have one returning vec4 value for a color")) return function{} } if cs.varyingParsed { - checkVaryings(inT[1:]) + checkVaryings(inParams[1:]) } else { - for _, t := range inT[1:] { - cs.ir.Varyings = append(cs.ir.Varyings, t) + for _, v := range inParams[1:] { + cs.ir.Varyings = append(cs.ir.Varyings, v.typ) } } cs.varyingParsed = true @@ -423,11 +473,18 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { b := cs.parseBlock(block, d.Body, inParams, outParams) + var inT, outT []shaderir.Type + for _, v := range inParams { + inT = append(inT, v.typ) + } + for _, v := range outParams { + outT = append(outT, v.typ) + } + return function{ name: d.Name.Name, block: b, ir: shaderir.Func{ - Index: len(cs.ir.Funcs), InParams: inT, OutParams: outT, Block: b.ir, @@ -666,27 +723,104 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) (shaderir.Expr, [ Exprs: []shaderir.Expr{lhs, rhs}, }, stmts case *ast.CallExpr: - var exprs []shaderir.Expr - var stmts []shaderir.Stmt + var ( + callee shaderir.Expr + args []shaderir.Expr + stmts []shaderir.Stmt + ) // Parse the argument first for the order of the statements. for _, a := range e.Args { e, ss := cs.parseExpr(block, a) // TODO: Convert integer literals to float literals if necessary. - exprs = append(exprs, e) + args = append(args, e) stmts = append(stmts, ss...) } - // TODO: When len(stmts) is not 0? + // TODO: When len(ss) is not 0? expr, ss := cs.parseExpr(block, e.Fun) - exprs = append([]shaderir.Expr{expr}, exprs...) + callee = expr stmts = append(stmts, ss...) - // TODO: Return statements to call the function separately. - return shaderir.Expr{ - Type: shaderir.Call, - Exprs: exprs, - }, stmts + // For built-in functions, we can call this in this position. Return an expression for the function + // call. + if expr.Type == shaderir.BuiltinFuncExpr { + return shaderir.Expr{ + Type: shaderir.Call, + Exprs: append([]shaderir.Expr{callee}, args...), + }, stmts + } + + if expr.Type != shaderir.FunctionExpr { + cs.addError(e.Pos(), fmt.Sprintf("function callee must be a funciton name but %s", e.Fun)) + } + f := cs.funcs[expr.Index] + + var outParams []int + for _, p := range f.ir.OutParams { + idx := len(block.vars) + block.vars = append(block.vars, variable{ + typ: p, + }) + block.ir.LocalVars = append(block.ir.LocalVars, p) + args = append(args, shaderir.Expr{ + Type: shaderir.LocalVariable, + Index: idx, + }) + outParams = append(outParams, idx) + } + + if t := f.ir.Return; t.Main != shaderir.None { + idx := len(block.vars) + block.vars = append(block.vars, variable{ + typ: t, + }) + + // Calling the function should be done eariler to treat out-params correctly. + stmts = append(stmts, shaderir.Stmt{ + Type: shaderir.Assign, + Exprs: []shaderir.Expr{ + { + Type: shaderir.LocalVariable, + Index: idx, + }, + { + Type: shaderir.Call, + Exprs: append([]shaderir.Expr{callee}, args...), + }, + }, + }) + + // The actual expression here is just a local variable that includes the result of the + // function call. + return shaderir.Expr{ + Type: shaderir.LocalVariable, + Index: idx, + }, stmts + } + + // Even if the function doesn't return anything, calling the function should be done eariler to keep + // the evaluation order. + stmts = append(stmts, shaderir.Stmt{ + Type: shaderir.ExprStmt, + Exprs: []shaderir.Expr{ + { + Type: shaderir.Call, + Exprs: append([]shaderir.Expr{callee}, args...), + }, + }, + }) + + // TODO: What about the other params? + if len(outParams) > 0 { + return shaderir.Expr{ + Type: shaderir.LocalVariable, + Index: outParams[0], + }, stmts + } + + // TODO: Is an empty expression work? + return shaderir.Expr{}, stmts case *ast.Ident: if i, ok := block.findLocalVariable(e.Name); ok { return shaderir.Expr{ @@ -694,6 +828,12 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) (shaderir.Expr, [ Index: i, }, nil } + if i, ok := cs.findFunction(e.Name); ok { + return shaderir.Expr{ + Type: shaderir.FunctionExpr, + Index: i, + }, nil + } if i, ok := cs.findUniformVariable(e.Name); ok { return shaderir.Expr{ Type: shaderir.UniformVariable, diff --git a/internal/shader/shader_test.go b/internal/shader/shader_test.go index 2cea24b4c..a6cf233bd 100644 --- a/internal/shader/shader_test.go +++ b/internal/shader/shader_test.go @@ -183,6 +183,28 @@ func Foo(foo vec2) vec4 { l2 = vec4(l0, 0.0, 1.0); l1 = l2; return; +}`, + }, + { + Name: "call", + Src: `package main + +func Foo(x vec2) vec2 { + return Bar(x) +} + +func Bar(x vec2) vec2 { + return x +}`, + VS: `void F0(in vec2 l0, out vec2 l1) { + vec2 l2 = vec2(0.0); + F1(l0, l2); + l1 = l2; + return; +} +void F1(in vec2 l0, out vec2 l1) { + l1 = l0; + return; }`, }, {