internal/shader: disallow 'discard' in other functions than the fragment entry point

Closes #2248
This commit is contained in:
Hajime Hoshi 2022-08-17 23:13:21 +09:00
parent 9d303e8dc5
commit fb775d806c
5 changed files with 90 additions and 50 deletions

View File

@ -67,7 +67,7 @@ func goConstantKindString(k gconstant.Kind) string {
var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`)
func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) {
func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) {
switch e := expr.(type) {
case *ast.BasicLit:
switch e.Kind {
@ -93,7 +93,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
var stmts []shaderir.Stmt
// Prase LHS first for the order of the statements.
lhs, ts, ss, ok := cs.parseExpr(block, e.X, markLocalVariableUsed)
lhs, ts, ss, ok := cs.parseExpr(block, fname, e.X, markLocalVariableUsed)
if !ok {
return nil, nil, nil, false
}
@ -104,7 +104,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
stmts = append(stmts, ss...)
lhst := ts[0]
rhs, ts, ss, ok := cs.parseExpr(block, e.Y, markLocalVariableUsed)
rhs, ts, ss, ok := cs.parseExpr(block, fname, e.Y, markLocalVariableUsed)
if !ok {
return nil, nil, nil, false
}
@ -299,7 +299,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
// Parse the argument first for the order of the statements.
for _, a := range e.Args {
es, ts, ss, ok := cs.parseExpr(block, a, markLocalVariableUsed)
es, ts, ss, ok := cs.parseExpr(block, fname, a, markLocalVariableUsed)
if !ok {
return nil, nil, nil, false
}
@ -313,7 +313,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
}
// TODO: When len(ss) is not 0?
es, _, ss, ok := cs.parseExpr(block, e.Fun, markLocalVariableUsed)
es, _, ss, ok := cs.parseExpr(block, fname, e.Fun, markLocalVariableUsed)
if !ok {
return nil, nil, nil, false
}
@ -473,6 +473,9 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
if len(args) != 0 {
cs.addError(e.Pos(), fmt.Sprintf("number of %s's arguments must be 0 but %d", callee.BuiltinFunc, len(args)))
}
if fname != cs.fragmentEntry {
cs.addError(e.Pos(), fmt.Sprintf("discard is available only in %s", cs.fragmentEntry))
}
stmts = append(stmts, shaderir.Stmt{
Type: shaderir.Discard,
})
@ -653,10 +656,10 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
cs.addError(e.Pos(), fmt.Sprintf("unexpected identifier: %s", e.Name))
case *ast.ParenExpr:
return cs.parseExpr(block, e.X, markLocalVariableUsed)
return cs.parseExpr(block, fname, e.X, markLocalVariableUsed)
case *ast.SelectorExpr:
exprs, _, stmts, ok := cs.parseExpr(block, e.X, true)
exprs, _, stmts, ok := cs.parseExpr(block, fname, e.X, true)
if !ok {
return nil, nil, nil, false
}
@ -692,7 +695,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
}, []shaderir.Type{t}, stmts, true
case *ast.UnaryExpr:
exprs, t, stmts, ok := cs.parseExpr(block, e.X, markLocalVariableUsed)
exprs, t, stmts, ok := cs.parseExpr(block, fname, e.X, markLocalVariableUsed)
if !ok {
return nil, nil, nil, false
}
@ -736,7 +739,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
}, t, stmts, true
case *ast.CompositeLit:
t, ok := cs.parseType(block, e.Type)
t, ok := cs.parseType(block, fname, e.Type)
if !ok {
return nil, nil, nil, false
}
@ -751,7 +754,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
var stmts []shaderir.Stmt
for i, e := range e.Elts {
exprs, _, ss, ok := cs.parseExpr(block, e, markLocalVariableUsed)
exprs, _, ss, ok := cs.parseExpr(block, fname, e, markLocalVariableUsed)
if !ok {
return nil, nil, nil, false
}
@ -793,7 +796,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
var stmts []shaderir.Stmt
// Parse the index first
exprs, _, ss, ok := cs.parseExpr(block, e.Index, markLocalVariableUsed)
exprs, _, ss, ok := cs.parseExpr(block, fname, e.Index, markLocalVariableUsed)
if !ok {
return nil, nil, nil, false
}
@ -812,7 +815,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
idx.ConstType = shaderir.ConstTypeInt
}
exprs, ts, ss, ok := cs.parseExpr(block, e.X, markLocalVariableUsed)
exprs, ts, ss, ok := cs.parseExpr(block, fname, e.X, markLocalVariableUsed)
if !ok {
return nil, nil, nil, false
}

View File

@ -39,7 +39,6 @@ type constant struct {
type function struct {
name string
block *block
ir shaderir.Func
}
@ -212,7 +211,7 @@ func (cs *compileState) parse(f *ast.File) {
// Parse GenDecl for global variables, and then parse functions.
for _, d := range f.Decls {
if _, ok := d.(*ast.FuncDecl); !ok {
ss, ok := cs.parseDecl(&cs.global, d)
ss, ok := cs.parseDecl(&cs.global, "", d)
if !ok {
return
}
@ -261,7 +260,7 @@ func (cs *compileState) parse(f *ast.File) {
}
}
inParams, outParams, ret := cs.parseFuncParams(&cs.global, fd)
inParams, outParams, ret := cs.parseFuncParams(&cs.global, n, fd)
var inT, outT []shaderir.Type
for _, v := range inParams {
inT = append(inT, v.typ)
@ -284,8 +283,8 @@ func (cs *compileState) parse(f *ast.File) {
// Parse functions.
for _, d := range f.Decls {
if _, ok := d.(*ast.FuncDecl); ok {
ss, ok := cs.parseDecl(&cs.global, d)
if f, ok := d.(*ast.FuncDecl); ok {
ss, ok := cs.parseDecl(&cs.global, f.Name.Name, d)
if !ok {
return
}
@ -302,7 +301,7 @@ func (cs *compileState) parse(f *ast.File) {
}
}
func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool) {
func (cs *compileState) parseDecl(b *block, fname string, d ast.Decl) ([]shaderir.Stmt, bool) {
var stmts []shaderir.Stmt
switch d := d.(type) {
@ -312,7 +311,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool)
// TODO: Parse other types
for _, s := range d.Specs {
s := s.(*ast.TypeSpec)
t, ok := cs.parseType(b, s.Type)
t, ok := cs.parseType(b, fname, s.Type)
if !ok {
return nil, false
}
@ -324,7 +323,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool)
case token.CONST:
for _, s := range d.Specs {
s := s.(*ast.ValueSpec)
cs, ok := cs.parseConstant(b, s)
cs, ok := cs.parseConstant(b, fname, s)
if !ok {
return nil, false
}
@ -333,7 +332,7 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool)
case token.VAR:
for _, s := range d.Specs {
s := s.(*ast.ValueSpec)
vs, inits, ss, ok := cs.parseVariable(b, s)
vs, inits, ss, ok := cs.parseVariable(b, fname, s)
if !ok {
return nil, false
}
@ -439,7 +438,7 @@ func (cs *compileState) functionReturnTypes(block *block, expr ast.Expr) ([]shad
return nil, false
}
func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variable, []shaderir.Expr, []shaderir.Stmt, bool) {
func (s *compileState) parseVariable(block *block, fname string, vs *ast.ValueSpec) ([]variable, []shaderir.Expr, []shaderir.Stmt, bool) {
if len(vs.Names) != len(vs.Values) && len(vs.Values) != 1 && len(vs.Values) != 0 {
s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match"))
return nil, nil, nil, false
@ -448,7 +447,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
var declt shaderir.Type
if vs.Type != nil {
var ok bool
declt, ok = s.parseType(block, vs.Type)
declt, ok = s.parseType(block, fname, vs.Type)
if !ok {
return nil, nil, nil, false
}
@ -475,7 +474,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
init := vs.Values[i]
es, rts, ss, ok := s.parseExpr(block, init, true)
es, rts, ss, ok := s.parseExpr(block, fname, init, true)
if !ok {
return nil, nil, nil, false
}
@ -517,7 +516,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
var ss []shaderir.Stmt
var ok bool
initexprs, inittypes, ss, ok = s.parseExpr(block, init, true)
initexprs, inittypes, ss, ok = s.parseExpr(block, fname, init, true)
if !ok {
return nil, nil, nil, false
}
@ -569,11 +568,11 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
return vars, inits, stmts, true
}
func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constant, bool) {
func (s *compileState) parseConstant(block *block, fname string, vs *ast.ValueSpec) ([]constant, bool) {
var t shaderir.Type
if vs.Type != nil {
var ok bool
t, ok = s.parseType(block, vs.Type)
t, ok = s.parseType(block, fname, vs.Type)
if !ok {
return nil, false
}
@ -595,7 +594,7 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan
}
}
es, ts, ss, ok := s.parseExpr(block, vs.Values[i], false)
es, ts, ss, ok := s.parseExpr(block, fname, vs.Values[i], false)
if !ok {
return nil, false
}
@ -621,9 +620,9 @@ func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) ([]constan
return cs, true
}
func (cs *compileState) parseFuncParams(block *block, d *ast.FuncDecl) (in, out []variable, ret shaderir.Type) {
func (cs *compileState) parseFuncParams(block *block, fname string, d *ast.FuncDecl) (in, out []variable, ret shaderir.Type) {
for _, f := range d.Type.Params.List {
t, ok := cs.parseType(block, f.Type)
t, ok := cs.parseType(block, fname, f.Type)
if !ok {
return
}
@ -640,7 +639,7 @@ func (cs *compileState) parseFuncParams(block *block, d *ast.FuncDecl) (in, out
}
for _, f := range d.Type.Results.List {
t, ok := cs.parseType(block, f.Type)
t, ok := cs.parseType(block, fname, f.Type)
if !ok {
return
}
@ -681,7 +680,7 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) (function, bool
return function{}, false
}
inParams, outParams, returnType := cs.parseFuncParams(block, d)
inParams, outParams, returnType := cs.parseFuncParams(block, d.Name.Name, d)
checkVaryings := func(vs []variable) {
if len(cs.ir.Varyings) != len(vs) {
@ -790,7 +789,6 @@ func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) (function, bool
return function{
name: d.Name.Name,
block: b,
ir: shaderir.Func{
InParams: inT,
OutParams: outT,

View File

@ -62,13 +62,13 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
}
stmts = append(stmts, ss...)
case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN:
rhs, rts, ss, ok := cs.parseExpr(block, stmt.Rhs[0], true)
rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true)
if !ok {
return nil, false
}
stmts = append(stmts, ss...)
lhs, lts, ss, ok := cs.parseExpr(block, stmt.Lhs[0], true)
lhs, lts, ss, ok := cs.parseExpr(block, fname, stmt.Lhs[0], true)
if !ok {
return nil, false
}
@ -177,7 +177,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
},
})
case *ast.DeclStmt:
ss, ok := cs.parseDecl(block, stmt.Decl)
ss, ok := cs.parseDecl(block, fname, stmt.Decl)
if !ok {
return nil, false
}
@ -228,7 +228,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
vartype := pseudoBlock.vars[0].typ
init := ss[0].Exprs[1].Const
exprs, ts, ss, ok := cs.parseExpr(pseudoBlock, stmt.Cond, true)
exprs, ts, ss, ok := cs.parseExpr(pseudoBlock, fname, stmt.Cond, true)
if !ok {
return nil, false
}
@ -356,7 +356,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return stmts, true
}
exprs, ts, ss, ok := cs.parseExpr(block, stmt.Cond, true)
exprs, ts, ss, ok := cs.parseExpr(block, fname, stmt.Cond, true)
if !ok {
return nil, false
}
@ -401,7 +401,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
})
case *ast.IncDecStmt:
exprs, _, ss, ok := cs.parseExpr(block, stmt.X, true)
exprs, _, ss, ok := cs.parseExpr(block, fname, stmt.X, true)
if !ok {
return nil, false
}
@ -445,7 +445,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
var exprs []shaderir.Expr
var types []shaderir.Type
for _, r := range stmt.Results {
es, ts, ss, ok := cs.parseExpr(block, r, true)
es, ts, ss, ok := cs.parseExpr(block, fname, r, true)
if !ok {
return nil, false
}
@ -544,7 +544,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false
}
exprs, _, ss, ok := cs.parseExpr(block, stmt.X, true)
exprs, _, ss, ok := cs.parseExpr(block, fname, stmt.X, true)
if !ok {
return nil, false
}
@ -582,7 +582,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
for i, e := range lhs {
if len(lhs) == len(rhs) {
// Prase RHS first for the order of the statements.
r, rts, ss, ok := cs.parseExpr(block, rhs[i], true)
r, rts, ss, ok := cs.parseExpr(block, fname, rhs[i], true)
if !ok {
return nil, false
}
@ -619,7 +619,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
return nil, false
}
l, lts, ss, ok := cs.parseExpr(block, lhs[i], false)
l, lts, ss, ok := cs.parseExpr(block, fname, lhs[i], false)
if !ok {
return nil, false
}
@ -714,7 +714,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
if i == 0 {
var ss []shaderir.Stmt
var ok bool
rhsExprs, rhsTypes, ss, ok = cs.parseExpr(block, rhs[0], true)
rhsExprs, rhsTypes, ss, ok = cs.parseExpr(block, fname, rhs[0], true)
if !ok {
return nil, false
}
@ -743,7 +743,7 @@ func (cs *compileState) assign(block *block, fname string, pos token.Pos, lhs, r
block.addNamedLocalVariable(name, t, e.Pos())
}
l, lts, ss, ok := cs.parseExpr(block, lhs[i], false)
l, lts, ss, ok := cs.parseExpr(block, fname, lhs[i], false)
if !ok {
return nil, false
}

View File

@ -1575,3 +1575,42 @@ func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
}
}
}
// Issue #2248
func TestSyntaxDiscard(t *testing.T) {
if _, err := compileToIR([]byte(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
if true {
discard()
}
return vec4(0)
}
`)); err != nil {
t.Error(err)
}
// discard without return doesn't work so far.
// TODO: Allow discard without return.
if _, err := compileToIR([]byte(`package main
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
discard()
return vec4(0)
}
`)); err != nil {
t.Error(err)
}
if _, err := compileToIR([]byte(`package main
func foo() {
discard()
}
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
foo()
return vec4(0)
}
`)); err == nil {
t.Errorf("error must be non-nil but was nil")
}
}

View File

@ -23,7 +23,7 @@ import (
"github.com/hajimehoshi/ebiten/v2/internal/shaderir"
)
func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, bool) {
func (cs *compileState) parseType(block *block, fname string, expr ast.Expr) (shaderir.Type, bool) {
switch t := expr.(type) {
case *ast.Ident:
switch t.Name {
@ -58,7 +58,7 @@ func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, b
if _, ok := t.Len.(*ast.Ellipsis); ok {
length = -1 // Determine the length later.
} else {
exprs, _, _, ok := cs.parseExpr(block, t.Len, true)
exprs, _, _, ok := cs.parseExpr(block, fname, t.Len, true)
if !ok {
return shaderir.Type{}, false
}
@ -78,7 +78,7 @@ func (cs *compileState) parseType(block *block, expr ast.Expr) (shaderir.Type, b
length = int(l)
}
elm, ok := cs.parseType(block, t.Elt)
elm, ok := cs.parseType(block, fname, t.Elt)
if !ok {
return shaderir.Type{}, false
}