shader: Implement function call

This commit is contained in:
Hajime Hoshi 2020-06-07 00:49:28 +09:00
parent 2ffbd49602
commit c986da8970
2 changed files with 236 additions and 74 deletions

View File

@ -53,6 +53,8 @@ type compileState struct {
// uniforms is a collection of uniform variable names. // uniforms is a collection of uniform variable names.
uniforms []string uniforms []string
funcs []function
global block global block
varyingParsed bool varyingParsed bool
@ -60,6 +62,15 @@ type compileState struct {
errs []string errs []string
} }
func (cs *compileState) findFunction(name string) (int, bool) {
for i, f := range cs.funcs {
if f.name == name {
return i, true
}
}
return 0, false
}
func (cs *compileState) findUniformVariable(name string) (int, bool) { func (cs *compileState) findUniformVariable(name string) (int, bool) {
for i, u := range cs.uniforms { for i, u := range cs.uniforms {
if u == name { if u == name {
@ -78,7 +89,6 @@ type block struct {
types []typ types []typ
vars []variable vars []variable
consts []constant consts []constant
funcs []function
pos token.Pos pos token.Pos
outer *block outer *block
@ -159,6 +169,40 @@ func (cs *compileState) parse(f *ast.File) {
cs.uniforms = unames cs.uniforms = unames
cs.ir.Uniforms = utypes cs.ir.Uniforms = utypes
// Parse function names so that any other function call the others.
// The function data is provisional and will be updated soon.
for _, d := range f.Decls {
fd, ok := d.(*ast.FuncDecl)
if !ok {
continue
}
n := fd.Name.Name
if n == cs.vertexEntry {
continue
}
if n == cs.fragmentEntry {
continue
}
inParams, outParams := cs.parseFuncParams(fd)
var inT, outT []shaderir.Type
for _, v := range inParams {
inT = append(inT, v.typ)
}
for _, v := range outParams {
outT = append(outT, v.typ)
}
cs.funcs = append(cs.funcs, function{
name: n,
ir: shaderir.Func{
Index: len(cs.funcs),
InParams: inT,
OutParams: outT,
},
})
}
// Parse functions. // Parse functions.
for _, d := range f.Decls { for _, d := range f.Decls {
if _, ok := d.(*ast.FuncDecl); ok { if _, ok := d.(*ast.FuncDecl); ok {
@ -170,7 +214,7 @@ func (cs *compileState) parse(f *ast.File) {
return return
} }
for _, f := range cs.global.funcs { for _, f := range cs.funcs {
cs.ir.Funcs = append(cs.ir.Funcs, f.ir) cs.ir.Funcs = append(cs.ir.Funcs, f.ir)
} }
} }
@ -237,17 +281,25 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) {
} }
case *ast.FuncDecl: case *ast.FuncDecl:
f := cs.parseFunc(b, d) f := cs.parseFunc(b, d)
if b == &cs.global { if b != &cs.global {
cs.addError(d.Pos(), "non-global function is not implemented")
return
}
switch d.Name.Name { switch d.Name.Name {
case cs.vertexEntry: case cs.vertexEntry:
cs.ir.VertexFunc.Block = f.ir.Block cs.ir.VertexFunc.Block = f.ir.Block
case cs.fragmentEntry: case cs.fragmentEntry:
cs.ir.FragmentFunc.Block = f.ir.Block cs.ir.FragmentFunc.Block = f.ir.Block
default: default:
b.funcs = append(b.funcs, f) // The function is already registered for their names.
for i := range cs.funcs {
if cs.funcs[i].name == d.Name.Name {
// Index is already determined by the provisional parsing.
f.ir.Index = cs.funcs[i].ir.Index
cs.funcs[i] = f
break
}
} }
} else {
b.funcs = append(b.funcs, f)
} }
default: default:
cs.addError(d.Pos(), "unexpected decl") cs.addError(d.Pos(), "unexpected decl")
@ -305,6 +357,40 @@ func (s *compileState) parseConstant(vs *ast.ValueSpec) []constant {
return cs return cs
} }
func (cs *compileState) parseFuncParams(d *ast.FuncDecl) (in, out []variable) {
for _, f := range d.Type.Params.List {
t := cs.parseType(f.Type)
for _, n := range f.Names {
in = append(in, variable{
name: n.Name,
typ: t,
})
}
}
if d.Type.Results == nil {
return
}
for _, f := range d.Type.Results.List {
t := cs.parseType(f.Type)
if len(f.Names) == 0 {
out = append(out, variable{
name: "",
typ: t,
})
} else {
for _, n := range f.Names {
out = append(out, variable{
name: n.Name,
typ: t,
})
}
}
}
return
}
func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function { func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function {
if d.Name == nil { if d.Name == nil {
cs.addError(d.Pos(), "function must have a name") cs.addError(d.Pos(), "function must have a name")
@ -315,51 +401,15 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function {
return function{} return function{}
} }
var inT []shaderir.Type inParams, outParams := cs.parseFuncParams(d)
var inParams []variable
for _, f := range d.Type.Params.List { checkVaryings := func(vs []variable) {
t := cs.parseType(f.Type) if len(cs.ir.Varyings) != len(vs) {
for _, n := range f.Names {
inParams = append(inParams, variable{
name: n.Name,
typ: t,
})
inT = append(inT, t)
}
}
var outT []shaderir.Type
var outParams []variable
if d.Type.Results != nil {
for _, f := range d.Type.Results.List {
t := cs.parseType(f.Type)
if len(f.Names) == 0 {
outParams = append(outParams, variable{
name: "",
typ: t,
})
outT = append(outT, t)
} else {
for _, n := range f.Names {
outParams = append(outParams, variable{
name: n.Name,
typ: t,
})
outT = append(outT, t)
}
}
}
}
checkVaryings := func(types []shaderir.Type) {
if len(cs.ir.Varyings) != len(types) {
cs.addError(d.Pos(), fmt.Sprintf("the number of vertex entry point's returning values and the number of framgent entry point's params must be the same")) cs.addError(d.Pos(), fmt.Sprintf("the number of vertex entry point's returning values and the number of framgent entry point's params must be the same"))
return return
} }
for i, t := range cs.ir.Varyings { for i, t := range cs.ir.Varyings {
if t.Main != types[i].Main { if t.Main != vs[i].typ.Main {
cs.addError(d.Pos(), fmt.Sprintf("vertex entry point's returning value types and framgent entry point's param types must match")) cs.addError(d.Pos(), fmt.Sprintf("vertex entry point's returning value types and framgent entry point's param types must match"))
} }
} }
@ -368,8 +418,8 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function {
if block == &cs.global { if block == &cs.global {
switch d.Name.Name { switch d.Name.Name {
case cs.vertexEntry: case cs.vertexEntry:
for _, t := range inT { for _, v := range inParams {
cs.ir.Attributes = append(cs.ir.Attributes, t) cs.ir.Attributes = append(cs.ir.Attributes, v.typ)
} }
// The first out-param is treated as gl_Position in GLSL. // The first out-param is treated as gl_Position in GLSL.
@ -377,17 +427,17 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function {
cs.addError(d.Pos(), fmt.Sprintf("vertex entry point must have at least one returning vec4 value for a position")) cs.addError(d.Pos(), fmt.Sprintf("vertex entry point must have at least one returning vec4 value for a position"))
return function{} return function{}
} }
if outT[0].Main != shaderir.Vec4 { if outParams[0].typ.Main != shaderir.Vec4 {
cs.addError(d.Pos(), fmt.Sprintf("vertex entry point must have at least one returning vec4 value for a position")) cs.addError(d.Pos(), fmt.Sprintf("vertex entry point must have at least one returning vec4 value for a position"))
return function{} return function{}
} }
if cs.varyingParsed { if cs.varyingParsed {
checkVaryings(outT[1:]) checkVaryings(outParams[1:])
} else { } else {
for _, t := range outT[1:] { for _, v := range outParams[1:] {
// TODO: Check that these params are not arrays or structs // TODO: Check that these params are not arrays or structs
cs.ir.Varyings = append(cs.ir.Varyings, t) cs.ir.Varyings = append(cs.ir.Varyings, v.typ)
} }
} }
cs.varyingParsed = true cs.varyingParsed = true
@ -396,7 +446,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function {
cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have at least one vec4 parameter for a position")) cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have at least one vec4 parameter for a position"))
return function{} return function{}
} }
if inT[0].Main != shaderir.Vec4 { if inParams[0].typ.Main != shaderir.Vec4 {
cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have at least one vec4 parameter for a position")) cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have at least one vec4 parameter for a position"))
return function{} return function{}
} }
@ -405,16 +455,16 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function {
cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have one returning vec4 value for a color")) cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have one returning vec4 value for a color"))
return function{} return function{}
} }
if outT[0].Main != shaderir.Vec4 { if outParams[0].typ.Main != shaderir.Vec4 {
cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have one returning vec4 value for a color")) cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have one returning vec4 value for a color"))
return function{} return function{}
} }
if cs.varyingParsed { if cs.varyingParsed {
checkVaryings(inT[1:]) checkVaryings(inParams[1:])
} else { } else {
for _, t := range inT[1:] { for _, v := range inParams[1:] {
cs.ir.Varyings = append(cs.ir.Varyings, t) cs.ir.Varyings = append(cs.ir.Varyings, v.typ)
} }
} }
cs.varyingParsed = true cs.varyingParsed = true
@ -423,11 +473,18 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) function {
b := cs.parseBlock(block, d.Body, inParams, outParams) b := cs.parseBlock(block, d.Body, inParams, outParams)
var inT, outT []shaderir.Type
for _, v := range inParams {
inT = append(inT, v.typ)
}
for _, v := range outParams {
outT = append(outT, v.typ)
}
return function{ return function{
name: d.Name.Name, name: d.Name.Name,
block: b, block: b,
ir: shaderir.Func{ ir: shaderir.Func{
Index: len(cs.ir.Funcs),
InParams: inT, InParams: inT,
OutParams: outT, OutParams: outT,
Block: b.ir, Block: b.ir,
@ -666,27 +723,104 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) (shaderir.Expr, [
Exprs: []shaderir.Expr{lhs, rhs}, Exprs: []shaderir.Expr{lhs, rhs},
}, stmts }, stmts
case *ast.CallExpr: case *ast.CallExpr:
var exprs []shaderir.Expr var (
var stmts []shaderir.Stmt callee shaderir.Expr
args []shaderir.Expr
stmts []shaderir.Stmt
)
// Parse the argument first for the order of the statements. // Parse the argument first for the order of the statements.
for _, a := range e.Args { for _, a := range e.Args {
e, ss := cs.parseExpr(block, a) e, ss := cs.parseExpr(block, a)
// TODO: Convert integer literals to float literals if necessary. // TODO: Convert integer literals to float literals if necessary.
exprs = append(exprs, e) args = append(args, e)
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
} }
// TODO: When len(stmts) is not 0? // TODO: When len(ss) is not 0?
expr, ss := cs.parseExpr(block, e.Fun) expr, ss := cs.parseExpr(block, e.Fun)
exprs = append([]shaderir.Expr{expr}, exprs...) callee = expr
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
// TODO: Return statements to call the function separately. // For built-in functions, we can call this in this position. Return an expression for the function
// call.
if expr.Type == shaderir.BuiltinFuncExpr {
return shaderir.Expr{ return shaderir.Expr{
Type: shaderir.Call, Type: shaderir.Call,
Exprs: exprs, Exprs: append([]shaderir.Expr{callee}, args...),
}, stmts }, stmts
}
if expr.Type != shaderir.FunctionExpr {
cs.addError(e.Pos(), fmt.Sprintf("function callee must be a funciton name but %s", e.Fun))
}
f := cs.funcs[expr.Index]
var outParams []int
for _, p := range f.ir.OutParams {
idx := len(block.vars)
block.vars = append(block.vars, variable{
typ: p,
})
block.ir.LocalVars = append(block.ir.LocalVars, p)
args = append(args, shaderir.Expr{
Type: shaderir.LocalVariable,
Index: idx,
})
outParams = append(outParams, idx)
}
if t := f.ir.Return; t.Main != shaderir.None {
idx := len(block.vars)
block.vars = append(block.vars, variable{
typ: t,
})
// Calling the function should be done eariler to treat out-params correctly.
stmts = append(stmts, shaderir.Stmt{
Type: shaderir.Assign,
Exprs: []shaderir.Expr{
{
Type: shaderir.LocalVariable,
Index: idx,
},
{
Type: shaderir.Call,
Exprs: append([]shaderir.Expr{callee}, args...),
},
},
})
// The actual expression here is just a local variable that includes the result of the
// function call.
return shaderir.Expr{
Type: shaderir.LocalVariable,
Index: idx,
}, stmts
}
// Even if the function doesn't return anything, calling the function should be done eariler to keep
// the evaluation order.
stmts = append(stmts, shaderir.Stmt{
Type: shaderir.ExprStmt,
Exprs: []shaderir.Expr{
{
Type: shaderir.Call,
Exprs: append([]shaderir.Expr{callee}, args...),
},
},
})
// TODO: What about the other params?
if len(outParams) > 0 {
return shaderir.Expr{
Type: shaderir.LocalVariable,
Index: outParams[0],
}, stmts
}
// TODO: Is an empty expression work?
return shaderir.Expr{}, stmts
case *ast.Ident: case *ast.Ident:
if i, ok := block.findLocalVariable(e.Name); ok { if i, ok := block.findLocalVariable(e.Name); ok {
return shaderir.Expr{ return shaderir.Expr{
@ -694,6 +828,12 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) (shaderir.Expr, [
Index: i, Index: i,
}, nil }, nil
} }
if i, ok := cs.findFunction(e.Name); ok {
return shaderir.Expr{
Type: shaderir.FunctionExpr,
Index: i,
}, nil
}
if i, ok := cs.findUniformVariable(e.Name); ok { if i, ok := cs.findUniformVariable(e.Name); ok {
return shaderir.Expr{ return shaderir.Expr{
Type: shaderir.UniformVariable, Type: shaderir.UniformVariable,

View File

@ -183,6 +183,28 @@ func Foo(foo vec2) vec4 {
l2 = vec4(l0, 0.0, 1.0); l2 = vec4(l0, 0.0, 1.0);
l1 = l2; l1 = l2;
return; return;
}`,
},
{
Name: "call",
Src: `package main
func Foo(x vec2) vec2 {
return Bar(x)
}
func Bar(x vec2) vec2 {
return x
}`,
VS: `void F0(in vec2 l0, out vec2 l1) {
vec2 l2 = vec2(0.0);
F1(l0, l2);
l1 = l2;
return;
}
void F1(in vec2 l0, out vec2 l1) {
l1 = l0;
return;
}`, }`,
}, },
{ {