mirror of
https://github.com/hajimehoshi/ebiten.git
synced 2024-12-26 03:38:55 +01:00
shader: Check returning value types and the number
This commit is contained in:
parent
36179636d1
commit
e0b8b9945f
@ -111,7 +111,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
|
|||||||
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok))
|
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok))
|
||||||
}
|
}
|
||||||
case *ast.BlockStmt:
|
case *ast.BlockStmt:
|
||||||
b, ok := cs.parseBlock(block, fname, stmt.List, inParams, nil)
|
b, ok := cs.parseBlock(block, fname, stmt.List, inParams, outParams)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -146,7 +146,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
|
|||||||
// Create a new pseudo block for the initial statement, so that the counter variable belongs to the
|
// Create a new pseudo block for the initial statement, so that the counter variable belongs to the
|
||||||
// new pseudo block for each for-loop. Without this, the samely named counter variables in different
|
// new pseudo block for each for-loop. Without this, the samely named counter variables in different
|
||||||
// for-loops confuses the parser.
|
// for-loops confuses the parser.
|
||||||
pseudoBlock, ok := cs.parseBlock(block, fname, []ast.Stmt{stmt.Init}, inParams, nil)
|
pseudoBlock, ok := cs.parseBlock(block, fname, []ast.Stmt{stmt.Init}, inParams, outParams)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -258,7 +258,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
b, ok := cs.parseBlock(pseudoBlock, fname, []ast.Stmt{stmt.Body}, inParams, nil)
|
b, ok := cs.parseBlock(pseudoBlock, fname, []ast.Stmt{stmt.Body}, inParams, outParams)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -289,7 +289,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
|
|||||||
if stmt.Init != nil {
|
if stmt.Init != nil {
|
||||||
init := stmt.Init
|
init := stmt.Init
|
||||||
stmt.Init = nil
|
stmt.Init = nil
|
||||||
b, ok := cs.parseBlock(block, fname, []ast.Stmt{init, stmt}, inParams, nil)
|
b, ok := cs.parseBlock(block, fname, []ast.Stmt{init, stmt}, inParams, outParams)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -316,7 +316,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
|
|||||||
stmts = append(stmts, ss...)
|
stmts = append(stmts, ss...)
|
||||||
|
|
||||||
var bs []*shaderir.Block
|
var bs []*shaderir.Block
|
||||||
b, ok := cs.parseBlock(block, fname, stmt.Body.List, inParams, nil)
|
b, ok := cs.parseBlock(block, fname, stmt.Body.List, inParams, outParams)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -325,13 +325,13 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
|
|||||||
if stmt.Else != nil {
|
if stmt.Else != nil {
|
||||||
switch s := stmt.Else.(type) {
|
switch s := stmt.Else.(type) {
|
||||||
case *ast.BlockStmt:
|
case *ast.BlockStmt:
|
||||||
b, ok := cs.parseBlock(block, fname, s.List, inParams, nil)
|
b, ok := cs.parseBlock(block, fname, s.List, inParams, outParams)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
bs = append(bs, b.ir)
|
bs = append(bs, b.ir)
|
||||||
default:
|
default:
|
||||||
b, ok := cs.parseBlock(block, fname, []ast.Stmt{s}, inParams, nil)
|
b, ok := cs.parseBlock(block, fname, []ast.Stmt{s}, inParams, outParams)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -378,8 +378,14 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
|
|||||||
})
|
})
|
||||||
|
|
||||||
case *ast.ReturnStmt:
|
case *ast.ReturnStmt:
|
||||||
|
if len(stmt.Results) != len(outParams) {
|
||||||
|
// TODO: Implenet multiple-context.
|
||||||
|
cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(stmt.Results)))
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
for i, r := range stmt.Results {
|
for i, r := range stmt.Results {
|
||||||
exprs, _, ss, ok := cs.parseExpr(block, r)
|
exprs, ts, ss, ok := cs.parseExpr(block, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -392,11 +398,23 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t := ts[0]
|
||||||
expr := exprs[0]
|
expr := exprs[0]
|
||||||
if expr.Type == shaderir.NumberExpr && outParams[i].typ.Main == shaderir.Int {
|
if expr.Type == shaderir.NumberExpr {
|
||||||
|
switch outParams[i].typ.Main {
|
||||||
|
case shaderir.Int:
|
||||||
if !cs.forceToInt(stmt, &expr) {
|
if !cs.forceToInt(stmt, &expr) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
t = shaderir.Type{Main: shaderir.Int}
|
||||||
|
case shaderir.Float:
|
||||||
|
t = shaderir.Type{Main: shaderir.Float}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !t.Equal(&outParams[i].typ) {
|
||||||
|
cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", &t, &outParams[i].typ))
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
stmts = append(stmts, shaderir.Stmt{
|
stmts = append(stmts, shaderir.Stmt{
|
||||||
|
@ -239,3 +239,14 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
|
|||||||
t.Errorf("error must be nil but non-nil: %v", err)
|
t.Errorf("error must be nil but non-nil: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShaderWrongReturn(t *testing.T) {
|
||||||
|
if _, err := NewShader([]byte(`package main
|
||||||
|
|
||||||
|
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
`)); err == nil {
|
||||||
|
t.Errorf("error must be non-nil but was nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user