shader: Implement 'for' statement

Fixes #1230
This commit is contained in:
Hajime Hoshi 2020-07-11 21:08:19 +09:00
parent e15ee77e8e
commit 2ca551cdc6
5 changed files with 201 additions and 7 deletions

View File

@ -26,6 +26,7 @@ import (
type variable struct { type variable struct {
name string name string
typ shaderir.Type typ shaderir.Type
forLoopCounter bool
} }
type constant struct { type constant struct {
@ -645,6 +646,10 @@ func (cs *compileState) parseBlock(outer *block, stmts []ast.Stmt, inParams, out
offset = len(inParams) + len(outParams) offset = len(inParams) + len(outParams)
} }
for _, v := range block.vars[offset:] { for _, v := range block.vars[offset:] {
if v.forLoopCounter {
block.ir.LocalVars = append(block.ir.LocalVars, shaderir.Type{})
continue
}
block.ir.LocalVars = append(block.ir.LocalVars, v.typ) block.ir.LocalVars = append(block.ir.LocalVars, v.typ)
} }
}() }()

View File

@ -113,6 +113,163 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab
return nil, false return nil, false
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
case *ast.ForStmt:
msg := "for-statement must follow this format: for (varname) := (constant); (varname) (op) (constant); (varname) (op) (constant) { ..."
if stmt.Init == nil {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if stmt.Cond == nil {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if stmt.Post == nil {
cs.addError(stmt.Pos(), msg)
return nil, false
}
// Create a new pseudo block for the initial statement, so that the counter variable belongs to the
// new pseudo block for each for-loop. Without this, the samely named counter variables in different
// for-loops confuses the parser.
pseudoBlock, ok := cs.parseBlock(block, []ast.Stmt{stmt.Init}, inParams, nil)
if !ok {
return nil, false
}
ss := pseudoBlock.ir.Stmts
if len(ss) != 1 {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if ss[0].Type != shaderir.Assign {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if ss[0].Exprs[0].Type != shaderir.LocalVariable {
cs.addError(stmt.Pos(), msg)
return nil, false
}
varidx := ss[0].Exprs[0].Index
if ss[0].Exprs[1].Type != shaderir.NumberExpr {
cs.addError(stmt.Pos(), msg)
return nil, false
}
vartype := pseudoBlock.vars[0].typ
init := ss[0].Exprs[1].Const
exprs, ts, ss, ok := cs.parseExpr(pseudoBlock, stmt.Cond)
if !ok {
return nil, false
}
if len(exprs) != 1 {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if len(ts) != 1 || ts[0].Main != shaderir.Bool {
cs.addError(stmt.Pos(), "for-statement's condition must be bool")
return nil, false
}
if len(ss) != 0 {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if exprs[0].Type != shaderir.Binary {
cs.addError(stmt.Pos(), msg)
return nil, false
}
op := exprs[0].Op
if op != shaderir.LessThanOp && op != shaderir.LessThanEqualOp && op != shaderir.GreaterThanOp && op != shaderir.GreaterThanEqualOp && op != shaderir.EqualOp && op != shaderir.NotEqualOp {
cs.addError(stmt.Pos(), "for-statement's condition must have one of these operators: <, <=, >, >=, ==, !=")
return nil, false
}
if exprs[0].Exprs[0].Type != shaderir.LocalVariable {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if exprs[0].Exprs[0].Index != varidx {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if exprs[0].Exprs[1].Type != shaderir.NumberExpr {
cs.addError(stmt.Pos(), msg)
return nil, false
}
end := exprs[0].Exprs[1].Const
postSs, ok := cs.parseStmt(pseudoBlock, stmt.Post, inParams)
if !ok {
return nil, false
}
if len(postSs) != 1 {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if postSs[0].Type != shaderir.Assign {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if postSs[0].Exprs[0].Type != shaderir.LocalVariable {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if postSs[0].Exprs[0].Index != varidx {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if postSs[0].Exprs[1].Type != shaderir.Binary {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if postSs[0].Exprs[1].Exprs[0].Type != shaderir.LocalVariable {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if postSs[0].Exprs[1].Exprs[0].Index != varidx {
cs.addError(stmt.Pos(), msg)
return nil, false
}
if postSs[0].Exprs[1].Exprs[1].Type != shaderir.NumberExpr {
cs.addError(stmt.Pos(), msg)
return nil, false
}
delta := postSs[0].Exprs[1].Exprs[1].Const
switch postSs[0].Exprs[1].Op {
case shaderir.Add:
case shaderir.Sub:
delta = gconstant.UnaryOp(token.SUB, delta, 0)
default:
cs.addError(stmt.Pos(), "for-statement's post statement must have one of these operators: +=, -=, ++, --")
return nil, false
}
b, ok := cs.parseBlock(pseudoBlock, []ast.Stmt{stmt.Body}, inParams, nil)
if !ok {
return nil, false
}
bodyir := b.ir
for len(bodyir.Stmts) == 1 && bodyir.Stmts[0].Type == shaderir.BlockStmt {
bodyir = bodyir.Stmts[0].Blocks[0]
}
// As the pseudo block is not actually used, copy the variable part to the actual block.
// This must be done after parsing the for-loop is done, or the duplicated variables confuses the
// parsing.
block.vars = append(block.vars, pseudoBlock.vars[0])
block.vars[len(block.vars)-1].forLoopCounter = true
stmts = append(stmts, shaderir.Stmt{
Type: shaderir.For,
Blocks: []shaderir.Block{bodyir},
ForVarType: vartype,
ForVarIndex: varidx,
ForInit: init,
ForEnd: end,
ForOp: op,
ForDelta: delta,
})
case *ast.IfStmt: case *ast.IfStmt:
if stmt.Init != nil { if stmt.Init != nil {
init := stmt.Init init := stmt.Init
@ -124,9 +281,7 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab
stmts = append(stmts, shaderir.Stmt{ stmts = append(stmts, shaderir.Stmt{
Type: shaderir.BlockStmt, Type: shaderir.BlockStmt,
Blocks: []shaderir.Block{ Blocks: []shaderir.Block{b.ir},
b.ir,
},
}) })
return stmts, true return stmts, true
} }
@ -174,6 +329,7 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab
Exprs: exprs, Exprs: exprs,
Blocks: bs, Blocks: bs,
}) })
case *ast.IncDecStmt: case *ast.IncDecStmt:
exprs, _, ss, ok := cs.parseExpr(block, stmt.X) exprs, _, ss, ok := cs.parseExpr(block, stmt.X)
if !ok { if !ok {
@ -205,6 +361,7 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab
}, },
}, },
}) })
case *ast.ReturnStmt: case *ast.ReturnStmt:
for i, r := range stmt.Results { for i, r := range stmt.Results {
exprs, _, ss, ok := cs.parseExpr(block, r) exprs, _, ss, ok := cs.parseExpr(block, r)
@ -233,6 +390,7 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab
stmts = append(stmts, shaderir.Stmt{ stmts = append(stmts, shaderir.Stmt{
Type: shaderir.Return, Type: shaderir.Return,
}) })
case *ast.ExprStmt: case *ast.ExprStmt:
exprs, _, ss, ok := cs.parseExpr(block, stmt.X) exprs, _, ss, ok := cs.parseExpr(block, stmt.X)
if !ok { if !ok {
@ -249,6 +407,7 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab
Exprs: []shaderir.Expr{expr}, Exprs: []shaderir.Expr{expr},
}) })
} }
default: default:
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt)) cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt))
return nil, false return nil, false

View File

@ -0,0 +1,14 @@
void F0(out vec2 l0) {
vec2 l1 = vec2(0);
vec2 l3 = vec2(0);
l1 = vec2(0.0);
for (int l2 = 0; l2 < 100; l2++) {
(l1).x = ((l1).x) + (1);
}
l3 = vec2(0.0);
for (float l4 = 10.0; l4 >= 0.0; l4 -= 2.0) {
(l3).x = ((l3).x) - (1);
}
l0 = l1;
return;
}

13
internal/shader/testdata/for.go vendored Normal file
View File

@ -0,0 +1,13 @@
package main
func Foo() vec2 {
v := vec2(0)
for i := 0; i < 100; i++ {
v.x++
}
v2 := vec2(0)
for i := 10.0; i >= 0; i -= 2 {
v2.x--
}
return v
}

View File

@ -265,7 +265,10 @@ func (p *Program) glslBlock(topBlock, block *Block, level int, localVarIndex int
var lines []string var lines []string
for _, t := range block.LocalVars { for _, t := range block.LocalVars {
// The type is None e.g., when the variable is a for-loop counter.
if t.Main != None {
lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, p.glslVarDecl(&t, fmt.Sprintf("l%d", localVarIndex)), p.glslVarInit(&t))) lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, p.glslVarDecl(&t, fmt.Sprintf("l%d", localVarIndex)), p.glslVarInit(&t)))
}
localVarIndex++ localVarIndex++
} }