shader: Check returning value types and the number

This commit is contained in:
Hajime Hoshi 2020-09-06 22:08:57 +09:00
parent 36179636d1
commit e0b8b9945f
2 changed files with 40 additions and 11 deletions

View File

@ -111,7 +111,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok)) cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok))
} }
case *ast.BlockStmt: case *ast.BlockStmt:
b, ok := cs.parseBlock(block, fname, stmt.List, inParams, nil) b, ok := cs.parseBlock(block, fname, stmt.List, inParams, outParams)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -146,7 +146,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
// Create a new pseudo block for the initial statement, so that the counter variable belongs to the // Create a new pseudo block for the initial statement, so that the counter variable belongs to the
// new pseudo block for each for-loop. Without this, the samely named counter variables in different // new pseudo block for each for-loop. Without this, the samely named counter variables in different
// for-loops confuses the parser. // for-loops confuses the parser.
pseudoBlock, ok := cs.parseBlock(block, fname, []ast.Stmt{stmt.Init}, inParams, nil) pseudoBlock, ok := cs.parseBlock(block, fname, []ast.Stmt{stmt.Init}, inParams, outParams)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -258,7 +258,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false return nil, false
} }
b, ok := cs.parseBlock(pseudoBlock, fname, []ast.Stmt{stmt.Body}, inParams, nil) b, ok := cs.parseBlock(pseudoBlock, fname, []ast.Stmt{stmt.Body}, inParams, outParams)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -289,7 +289,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
if stmt.Init != nil { if stmt.Init != nil {
init := stmt.Init init := stmt.Init
stmt.Init = nil stmt.Init = nil
b, ok := cs.parseBlock(block, fname, []ast.Stmt{init, stmt}, inParams, nil) b, ok := cs.parseBlock(block, fname, []ast.Stmt{init, stmt}, inParams, outParams)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -316,7 +316,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
var bs []*shaderir.Block var bs []*shaderir.Block
b, ok := cs.parseBlock(block, fname, stmt.Body.List, inParams, nil) b, ok := cs.parseBlock(block, fname, stmt.Body.List, inParams, outParams)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -325,13 +325,13 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
if stmt.Else != nil { if stmt.Else != nil {
switch s := stmt.Else.(type) { switch s := stmt.Else.(type) {
case *ast.BlockStmt: case *ast.BlockStmt:
b, ok := cs.parseBlock(block, fname, s.List, inParams, nil) b, ok := cs.parseBlock(block, fname, s.List, inParams, outParams)
if !ok { if !ok {
return nil, false return nil, false
} }
bs = append(bs, b.ir) bs = append(bs, b.ir)
default: default:
b, ok := cs.parseBlock(block, fname, []ast.Stmt{s}, inParams, nil) b, ok := cs.parseBlock(block, fname, []ast.Stmt{s}, inParams, outParams)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -378,8 +378,14 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
}) })
case *ast.ReturnStmt: case *ast.ReturnStmt:
if len(stmt.Results) != len(outParams) {
// TODO: Implenet multiple-context.
cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(stmt.Results)))
return nil, false
}
for i, r := range stmt.Results { for i, r := range stmt.Results {
exprs, _, ss, ok := cs.parseExpr(block, r) exprs, ts, ss, ok := cs.parseExpr(block, r)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -392,13 +398,25 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
continue continue
} }
t := ts[0]
expr := exprs[0] expr := exprs[0]
if expr.Type == shaderir.NumberExpr && outParams[i].typ.Main == shaderir.Int { if expr.Type == shaderir.NumberExpr {
if !cs.forceToInt(stmt, &expr) { switch outParams[i].typ.Main {
return nil, false case shaderir.Int:
if !cs.forceToInt(stmt, &expr) {
return nil, false
}
t = shaderir.Type{Main: shaderir.Int}
case shaderir.Float:
t = shaderir.Type{Main: shaderir.Float}
} }
} }
if !t.Equal(&outParams[i].typ) {
cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", &t, &outParams[i].typ))
return nil, false
}
stmts = append(stmts, shaderir.Stmt{ stmts = append(stmts, shaderir.Stmt{
Type: shaderir.Assign, Type: shaderir.Assign,
Exprs: []shaderir.Expr{ Exprs: []shaderir.Expr{

View File

@ -239,3 +239,14 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
t.Errorf("error must be nil but non-nil: %v", err) t.Errorf("error must be nil but non-nil: %v", err)
} }
} }
func TestShaderWrongReturn(t *testing.T) {
if _, err := NewShader([]byte(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
return 0.0;
}
`)); err == nil {
t.Errorf("error must be non-nil but was nil")
}
}