shader: Add 'if'

Updates #1230
This commit is contained in:
Hajime Hoshi 2020-07-04 22:49:29 +09:00
parent 3ca6e41194
commit 380b7382ac
8 changed files with 105 additions and 7 deletions

View File

@ -297,6 +297,14 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
return exprs, f.ir.OutParams, stmts, true
case *ast.Ident:
if e.Name == "true" || e.Name == "false" {
return []shaderir.Expr{
{
Type: shaderir.NumberExpr,
Const: gconstant.MakeBool(e.Name == "true"),
},
}, []shaderir.Type{{Main: shaderir.Bool}}, nil, true
}
if i, t, ok := block.findLocalVariable(e.Name); ok {
return []shaderir.Expr{
{

View File

@ -582,7 +582,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) (function, bool
}
}
b, ok := cs.parseBlock(block, d.Body, inParams, outParams)
b, ok := cs.parseBlock(block, d.Body.List, inParams, outParams)
if !ok {
return function{}, false
}
@ -606,7 +606,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) (function, bool
}, true
}
func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, outParams []variable) (*block, bool) {
func (cs *compileState) parseBlock(outer *block, stmts []ast.Stmt, inParams, outParams []variable) (*block, bool) {
vars := make([]variable, 0, len(inParams)+len(outParams))
vars = append(vars, inParams...)
vars = append(vars, outParams...)
@ -620,12 +620,12 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
}
}()
for _, l := range b.List {
stmts, ok := cs.parseStmt(block, l, inParams)
for _, stmt := range stmts {
ss, ok := cs.parseStmt(block, stmt, inParams)
if !ok {
return nil, false
}
block.ir.Stmts = append(block.ir.Stmts, stmts...)
block.ir.Stmts = append(block.ir.Stmts, ss...)
}
return block, true

View File

@ -95,7 +95,7 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok))
}
case *ast.BlockStmt:
b, ok := cs.parseBlock(block, stmt, nil, nil)
b, ok := cs.parseBlock(block, stmt.List, nil, nil)
if !ok {
return nil, false
}
@ -109,6 +109,48 @@ func (cs *compileState) parseStmt(block *block, stmt ast.Stmt, inParams []variab
if !cs.parseDecl(block, stmt.Decl) {
return nil, false
}
case *ast.IfStmt:
// TODO: Parse stmt.Init
exprs, ts, ss, ok := cs.parseExpr(block, stmt.Cond)
if !ok {
return nil, false
}
if len(ts) != 1 || ts[0].Main != shaderir.Bool {
cs.addError(stmt.Pos(), fmt.Sprintf("if-condition must be bool but: %#v", ts))
return nil, false
}
stmts = append(stmts, ss...)
var bs []shaderir.Block
b, ok := cs.parseBlock(block, stmt.Body.List, nil, nil)
if !ok {
return nil, false
}
bs = append(bs, b.ir)
if stmt.Else != nil {
switch s := stmt.Else.(type) {
case *ast.BlockStmt:
b, ok := cs.parseBlock(block, s.List, nil, nil)
if !ok {
return nil, false
}
bs = append(bs, b.ir)
default:
b, ok := cs.parseBlock(block, []ast.Stmt{s}, nil, nil)
if !ok {
return nil, false
}
bs = append(bs, b.ir)
}
}
stmts = append(stmts, shaderir.Stmt{
Type: shaderir.If,
Exprs: exprs,
Blocks: bs,
})
case *ast.ReturnStmt:
for i, r := range stmt.Results {
exprs, _, ss, ok := cs.parseExpr(block, r)

10
internal/shader/testdata/if.expected.vs vendored Normal file
View File

@ -0,0 +1,10 @@
void F0(out vec2 l0) {
bool l1 = false;
l1 = true;
if (l1) {
l0 = vec2(0.0);
return;
}
l0 = vec2(1.0);
return;
}

9
internal/shader/testdata/if.go vendored Normal file
View File

@ -0,0 +1,9 @@
package main
func Foo() vec2 {
x := true
if x {
return vec2(0)
}
return vec2(1)
}

View File

@ -0,0 +1,11 @@
void F0(out vec2 l0) {
bool l1 = false;
l1 = true;
if (l1) {
l0 = vec2(0.0);
return;
} else {
l0 = vec2(1.0);
return;
}
}

10
internal/shader/testdata/if_else.go vendored Normal file
View File

@ -0,0 +1,10 @@
package main
func Foo() vec2 {
x := true
if x {
return vec2(0)
} else {
return vec2(1)
}
}

View File

@ -237,7 +237,15 @@ func (p *Program) glslBlock(b *Block, level int, localVarIndex int) []string {
switch e.Type {
case NumberExpr:
switch e.ConstType {
case ConstTypeNone, ConstTypeFloat:
case ConstTypeNone:
if e.Const.Kind() == constant.Bool {
if constant.BoolVal(e.Const) {
return "true"
}
return "false"
}
fallthrough
case ConstTypeFloat:
if i := constant.ToInt(e.Const); i.Kind() == constant.Int {
x, _ := constant.Int64Val(i)
return fmt.Sprintf("%d.0", x)