internal/shader: disallow 'discard' in other functions than the fragment entry point

Closes #2248
This commit is contained in:
Hajime Hoshi 2022-08-17 23:13:21 +09:00
parent 9d303e8dc5
commit fb775d806c
5 changed files with 90 additions and 50 deletions

View File

@ -67,7 +67,7 @@ func goConstantKindString(k gconstant.Kind) string {
var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`)
func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) { func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) {
switch e := expr.(type) { switch e := expr.(type) {
case *ast.BasicLit: case *ast.BasicLit:
switch e.Kind { switch e.Kind {
@ -93,7 +93,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
var stmts []shaderir.Stmt var stmts []shaderir.Stmt
// Prase LHS first for the order of the statements. // Prase LHS first for the order of the statements.
lhs, ts, ss, ok := cs.parseExpr(block, e.X, markLocalVariableUsed) lhs, ts, ss, ok := cs.parseExpr(block, fname, e.X, markLocalVariableUsed)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -104,7 +104,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
lhst := ts[0] lhst := ts[0]
rhs, ts, ss, ok := cs.parseExpr(block, e.Y, markLocalVariableUsed) rhs, ts, ss, ok := cs.parseExpr(block, fname, e.Y, markLocalVariableUsed)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -299,7 +299,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
// Parse the argument first for the order of the statements. // Parse the argument first for the order of the statements.
for _, a := range e.Args { for _, a := range e.Args {
es, ts, ss, ok := cs.parseExpr(block, a, markLocalVariableUsed) es, ts, ss, ok := cs.parseExpr(block, fname, a, markLocalVariableUsed)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -313,7 +313,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
} }
// TODO: When len(ss) is not 0? // TODO: When len(ss) is not 0?
es, _, ss, ok := cs.parseExpr(block, e.Fun, markLocalVariableUsed) es, _, ss, ok := cs.parseExpr(block, fname, e.Fun, markLocalVariableUsed)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -473,6 +473,9 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
if len(args) != 0 { if len(args) != 0 {
cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 0 but %d", callee.BuiltinFunc, len(args))) cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 0 but %d", callee.BuiltinFunc, len(args)))
} }
if fname != cs.fragmentEntry {
cs.addError(e.Pos(), fmt.Sprintf("discard is available only in %s", cs.fragmentEntry))
}
stmts = append(stmts, shaderir.Stmt{ stmts = append(stmts, shaderir.Stmt{
Type: shaderir.Discard, Type: shaderir.Discard,
}) })
@ -653,10 +656,10 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
cs.addError(e.Pos(), fmt.Sprintf("unexpected identifier: %s", e.Name)) cs.addError(e.Pos(), fmt.Sprintf("unexpected identifier: %s", e.Name))
case *ast.ParenExpr: case *ast.ParenExpr:
return cs.parseExpr(block, e.X, markLocalVariableUsed) return cs.parseExpr(block, fname, e.X, markLocalVariableUsed)
case *ast.SelectorExpr: case *ast.SelectorExpr:
exprs, _, stmts, ok := cs.parseExpr(block, e.X, true) exprs, _, stmts, ok := cs.parseExpr(block, fname, e.X, true)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -692,7 +695,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
}, []shaderir.Type{t}, stmts, true }, []shaderir.Type{t}, stmts, true
case *ast.UnaryExpr: case *ast.UnaryExpr:
exprs, t, stmts, ok := cs.parseExpr(block, e.X, markLocalVariableUsed) exprs, t, stmts, ok := cs.parseExpr(block, fname, e.X, markLocalVariableUsed)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -736,7 +739,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
}, t, stmts, true }, t, stmts, true
case *ast.CompositeLit: case *ast.CompositeLit:
t, ok := cs.parseType(block, e.Type) t, ok := cs.parseType(block, fname, e.Type)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -751,7 +754,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
var stmts []shaderir.Stmt var stmts []shaderir.Stmt
for i, e := range e.Elts { for i, e := range e.Elts {
exprs, _, ss, ok := cs.parseExpr(block, e, markLocalVariableUsed) exprs, _, ss, ok := cs.parseExpr(block, fname, e, markLocalVariableUsed)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -793,7 +796,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
var stmts []shaderir.Stmt var stmts []shaderir.Stmt
// Parse the index first // Parse the index first
exprs, _, ss, ok := cs.parseExpr(block, e.Index, markLocalVariableUsed) exprs, _, ss, ok := cs.parseExpr(block, fname, e.Index, markLocalVariableUsed)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -812,7 +815,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
idx.ConstType = shaderir.ConstTypeInt idx.ConstType = shaderir.ConstTypeInt
} }
exprs, ts, ss, ok := cs.parseExpr(block, e.X, markLocalVariableUsed) exprs, ts, ss, ok := cs.parseExpr(block, fname, e.X, markLocalVariableUsed)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }

View File

@ -39,7 +39,6 @@ type constant struct {
type function struct { type function struct {
name string name string
block *block
ir shaderir.Func ir shaderir.Func
} }
@ -212,7 +211,7 @@ func (cs *compileState) parse(f *ast.File) {
// Parse GenDecl for global variables, and then parse functions. // Parse GenDecl for global variables, and then parse functions.
for _, d := range f.Decls { for _, d := range f.Decls {
if _, ok := d.(*ast.FuncDecl); !ok { if _, ok := d.(*ast.FuncDecl); !ok {
ss, ok := cs.parseDecl(&cs.global, d) ss, ok := cs.parseDecl(&cs.global, "", d)
if !ok { if !ok {
return return
} }
@ -261,7 +260,7 @@ func (cs *compileState) parse(f *ast.File) {
} }
} }
inParams, outParams, ret := cs.parseFuncParams(&cs.global, fd) inParams, outParams, ret := cs.parseFuncParams(&cs.global, n, fd)
var inT, outT []shaderir.Type var inT, outT []shaderir.Type
for _, v := range inParams { for _, v := range inParams {
inT = append(inT, v.typ) inT = append(inT, v.typ)
@ -284,8 +283,8 @@ func (cs *compileState) parse(f *ast.File) {
// Parse functions. // Parse functions.
for _, d := range f.Decls { for _, d := range f.Decls {
if _, ok := d.(*ast.FuncDecl); ok { if f, ok := d.(*ast.FuncDecl); ok {
ss, ok := cs.parseDecl(&cs.global, d) ss, ok := cs.parseDecl(&cs.global, f.Name.Name, d)
if !ok { if !ok {
return return
} }
@ -302,7 +301,7 @@ func (cs *compileState) parse(f *ast.File) {
} }
} }
func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool) { func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) {
var stmts []shaderir.Stmt var stmts []shaderir.Stmt
switch d := d.(type) { switch d := d.(type) {
@ -312,7 +311,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool)
// TODO: Parse other types // TODO: Parse other types
for _, s := range d.Specs { for _, s := range d.Specs {
s := s.(*ast.TypeSpec) s := s.(*ast.TypeSpec)
t, ok := cs.parseType(b, s.Type) t, ok := cs.parseType(b, fname, s.Type)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -324,7 +323,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool)
case token.CONST: case token.CONST:
for _, s := range d.Specs { for _, s := range d.Specs {
s := s.(*ast.ValueSpec) s := s.(*ast.ValueSpec)
cs, ok := cs.parseConstant(b, s) cs, ok := cs.parseConstant(b, fname, s)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -333,7 +332,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool)
case token.VAR: case token.VAR:
for _, s := range d.Specs { for _, s := range d.Specs {
s := s.(*ast.ValueSpec) s := s.(*ast.ValueSpec)
vs, inits, ss, ok := cs.parseVariable(b, s) vs, inits, ss, ok := cs.parseVariable(b, fname, s)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -439,7 +438,7 @@ func (cs *compileState) functionReturnTypes(block *block, expr ast.Expr) ([]shad
return nil, false return nil, false
} }
func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variable, []shaderir.Expr, []shaderir.Stmt, bool) { func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSpec) ([]variable, []shaderir.Expr, []shaderir.Stmt, bool) {
if len(vs.Names) != len(vs.Values) && len(vs.Values) != 1 && len(vs.Values) != 0 { if len(vs.Names) != len(vs.Values) && len(vs.Values) != 1 && len(vs.Values) != 0 {
s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match")) s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match"))
return nil, nil, nil, false return nil, nil, nil, false
@ -448,7 +447,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
var declt shaderir.Type var declt shaderir.Type
if vs.Type != nil { if vs.Type != nil {
var ok bool var ok bool
declt, ok = s.parseType(block, vs.Type) declt, ok = s.parseType(block, fname, vs.Type)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -475,7 +474,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
init := vs.Values[i] init := vs.Values[i]
es, rts, ss, ok := s.parseExpr(block, init, true) es, rts, ss, ok := s.parseExpr(block, fname, init, true)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -517,7 +516,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
var ss []shaderir.Stmt var ss []shaderir.Stmt
var ok bool var ok bool
initexprs, inittypes, ss, ok = s.parseExpr(block, init, true) initexprs, inittypes, ss, ok = s.parseExpr(block, fname, init, true)
if !ok { if !ok {
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -569,11 +568,11 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
return vars, inits, stmts, true return vars, inits, stmts, true
} }
func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constant, bool) { func (s *compileState) parseConstant(block *block, fname string, vs *ast.ValueSpec) ([]constant, bool) {
var t shaderir.Type var t shaderir.Type
if vs.Type != nil { if vs.Type != nil {
var ok bool var ok bool
t, ok = s.parseType(block, vs.Type) t, ok = s.parseType(block, fname, vs.Type)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -595,7 +594,7 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan
} }
} }
es, ts, ss, ok := s.parseExpr(block, vs.Values[i], false) es, ts, ss, ok := s.parseExpr(block, fname, vs.Values[i], false)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -621,9 +620,9 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan
return cs, true return cs, true
} }
func (cs *compileState) parseFuncParams(block *block, d *ast.FuncDecl) (in, out []variable, ret shaderir.Type) { func (cs *compileState) parseFuncParams(block *block, fname string, d *ast.FuncDecl) (in, out []variable, ret shaderir.Type) {
for _, f := range d.Type.Params.List { for _, f := range d.Type.Params.List {
t, ok := cs.parseType(block, f.Type) t, ok := cs.parseType(block, fname, f.Type)
if !ok { if !ok {
return return
} }
@ -640,7 +639,7 @@ func (cs *compileState) parseFuncParams(block *block, d *ast.FuncDecl) (in, out
} }
for _, f := range d.Type.Results.List { for _, f := range d.Type.Results.List {
t, ok := cs.parseType(block, f.Type) t, ok := cs.parseType(block, fname, f.Type)
if !ok { if !ok {
return return
} }
@ -681,7 +680,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) (function, bool
return function{}, false return function{}, false
} }
inParams, outParams, returnType := cs.parseFuncParams(block, d) inParams, outParams, returnType := cs.parseFuncParams(block, d.Name.Name, d)
checkVaryings := func(vs []variable) { checkVaryings := func(vs []variable) {
if len(cs.ir.Varyings) != len(vs) { if len(cs.ir.Varyings) != len(vs) {
@ -790,7 +789,6 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) (function, bool
return function{ return function{
name: d.Name.Name, name: d.Name.Name,
block: b,
ir: shaderir.Func{ ir: shaderir.Func{
InParams: inT, InParams: inT,
OutParams: outT, OutParams: outT,

View File

@ -62,13 +62,13 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN: case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN:
rhs, rts, ss, ok := cs.parseExpr(block, stmt.Rhs[0], true) rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true)
if !ok { if !ok {
return nil, false return nil, false
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
lhs, lts, ss, ok := cs.parseExpr(block, stmt.Lhs[0], true) lhs, lts, ss, ok := cs.parseExpr(block, fname, stmt.Lhs[0], true)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -177,7 +177,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
}, },
}) })
case *ast.DeclStmt: case *ast.DeclStmt:
ss, ok := cs.parseDecl(block, stmt.Decl) ss, ok := cs.parseDecl(block, fname, stmt.Decl)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -228,7 +228,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
vartype := pseudoBlock.vars[0].typ vartype := pseudoBlock.vars[0].typ
init := ss[0].Exprs[1].Const init := ss[0].Exprs[1].Const
exprs, ts, ss, ok := cs.parseExpr(pseudoBlock, stmt.Cond, true) exprs, ts, ss, ok := cs.parseExpr(pseudoBlock, fname, stmt.Cond, true)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -356,7 +356,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return stmts, true return stmts, true
} }
exprs, ts, ss, ok := cs.parseExpr(block, stmt.Cond, true) exprs, ts, ss, ok := cs.parseExpr(block, fname, stmt.Cond, true)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -401,7 +401,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
}) })
case *ast.IncDecStmt: case *ast.IncDecStmt:
exprs, _, ss, ok := cs.parseExpr(block, stmt.X, true) exprs, _, ss, ok := cs.parseExpr(block, fname, stmt.X, true)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -445,7 +445,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
var exprs []shaderir.Expr var exprs []shaderir.Expr
var types []shaderir.Type var types []shaderir.Type
for _, r := range stmt.Results { for _, r := range stmt.Results {
es, ts, ss, ok := cs.parseExpr(block, r, true) es, ts, ss, ok := cs.parseExpr(block, fname, r, true)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -544,7 +544,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false return nil, false
} }
exprs, _, ss, ok := cs.parseExpr(block, stmt.X, true) exprs, _, ss, ok := cs.parseExpr(block, fname, stmt.X, true)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -582,7 +582,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
for i, e := range lhs { for i, e := range lhs {
if len(lhs) == len(rhs) { if len(lhs) == len(rhs) {
// Prase RHS first for the order of the statements. // Prase RHS first for the order of the statements.
r, rts, ss, ok := cs.parseExpr(block, rhs[i], true) r, rts, ss, ok := cs.parseExpr(block, fname, rhs[i], true)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -619,7 +619,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
return nil, false return nil, false
} }
l, lts, ss, ok := cs.parseExpr(block, lhs[i], false) l, lts, ss, ok := cs.parseExpr(block, fname, lhs[i], false)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -714,7 +714,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
if i == 0 { if i == 0 {
var ss []shaderir.Stmt var ss []shaderir.Stmt
var ok bool var ok bool
rhsExprs, rhsTypes, ss, ok = cs.parseExpr(block, rhs[0], true) rhsExprs, rhsTypes, ss, ok = cs.parseExpr(block, fname, rhs[0], true)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -743,7 +743,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
block.addNamedLocalVariable(name, t, e.Pos()) block.addNamedLocalVariable(name, t, e.Pos())
} }
l, lts, ss, ok := cs.parseExpr(block, lhs[i], false) l, lts, ss, ok := cs.parseExpr(block, fname, lhs[i], false)
if !ok { if !ok {
return nil, false return nil, false
} }

View File

@ -1575,3 +1575,42 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
} }
} }
} }
// Issue #2248
func TestSyntaxDiscard(t *testing.T) {
if _, err := compileToIR([]byte(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
if true {
discard()
}
return vec4(0)
}
`)); err != nil {
t.Error(err)
}
// discard without return doesn't work so far.
// TODO: Allow discard without return.
if _, err := compileToIR([]byte(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
discard()
return vec4(0)
}
`)); err != nil {
t.Error(err)
}
if _, err := compileToIR([]byte(`package main
func foo() {
discard()
}
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
foo()
return vec4(0)
}
`)); err == nil {
t.Errorf("error must be non-nil but was nil")
}
}

View File

@ -23,7 +23,7 @@ import (
"github.com/hajimehoshi/ebiten/v2/internal/shaderir" "github.com/hajimehoshi/ebiten/v2/internal/shaderir"
) )
func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, bool) { func (cs *compileState) parseType(block *block, fname string, expr ast.Expr) (shaderir.Type, bool) {
switch t := expr.(type) { switch t := expr.(type) {
case *ast.Ident: case *ast.Ident:
switch t.Name { switch t.Name {
@ -58,7 +58,7 @@ func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, b
if _, ok := t.Len.(*ast.Ellipsis); ok { if _, ok := t.Len.(*ast.Ellipsis); ok {
length = -1 // Determine the length later. length = -1 // Determine the length later.
} else { } else {
exprs, _, _, ok := cs.parseExpr(block, t.Len, true) exprs, _, _, ok := cs.parseExpr(block, fname, t.Len, true)
if !ok { if !ok {
return shaderir.Type{}, false return shaderir.Type{}, false
} }
@ -78,7 +78,7 @@ func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, b
length = int(l) length = int(l)
} }
elm, ok := cs.parseType(block, t.Elt) elm, ok := cs.parseType(block, fname, t.Elt)
if !ok { if !ok {
return shaderir.Type{}, false return shaderir.Type{}, false
} }