internal/shader: refactoring

This commit is contained in:
Hajime Hoshi 2022-07-10 16:02:50 +09:00
parent bf0f3d304b
commit 8c879c7bcf

View File

@ -442,35 +442,39 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
} }
} }
for i, r := range stmt.Results { var exprs []shaderir.Expr
exprs, ts, ss, ok := cs.parseExpr(block, r, true) var types []shaderir.Type
for _, r := range stmt.Results {
es, ts, ss, ok := cs.parseExpr(block, r, true)
if !ok { if !ok {
return nil, false return nil, false
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
if len(exprs) > 1 { if len(es) > 1 && (len(stmt.Results) > 1 || len(outParams) == 1) {
if len(stmt.Results) > 1 || len(outParams) == 1 {
cs.addError(r.Pos(), "single-value context and multiple-value context cannot be mixed") cs.addError(r.Pos(), "single-value context and multiple-value context cannot be mixed")
return nil, false return nil, false
} }
}
if len(outParams) > 1 && len(stmt.Results) == 1 { if len(outParams) > 1 && len(stmt.Results) == 1 {
if len(exprs) == 1 { if len(es) == 1 {
cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(stmt.Results))) cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(stmt.Results)))
return nil, false return nil, false
} }
if len(exprs) > 1 && len(exprs) != len(outParams) { if len(es) > 1 && len(es) != len(outParams) {
cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(exprs))) cs.addError(stmt.Pos(), fmt.Sprintf("the number of returning variables must be %d but %d", len(outParams), len(es)))
return nil, false return nil, false
} }
} }
for j, t := range ts { exprs = append(exprs, es...)
expr := exprs[j] types = append(types, ts...)
}
for i, t := range types {
expr := exprs[i]
if expr.Type == shaderir.NumberExpr { if expr.Type == shaderir.NumberExpr {
switch outParams[i+j].typ.Main { switch outParams[i].typ.Main {
case shaderir.Int: case shaderir.Int:
if !cs.forceToInt(stmt, &expr) { if !cs.forceToInt(stmt, &expr) {
return nil, false return nil, false
@ -481,7 +485,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
} }
} }
if !t.Equal(&outParams[i+j].typ) { if !t.Equal(&outParams[i].typ) {
cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", t.String(), &outParams[i].typ)) cs.addError(stmt.Pos(), fmt.Sprintf("cannot use type %s as type %s in return argument", t.String(), &outParams[i].typ))
return nil, false return nil, false
} }
@ -491,13 +495,13 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
Exprs: []shaderir.Expr{ Exprs: []shaderir.Expr{
{ {
Type: shaderir.LocalVariable, Type: shaderir.LocalVariable,
Index: len(inParams) + i + j, Index: len(inParams) + i,
}, },
expr, expr,
}, },
}) })
} }
}
stmts = append(stmts, shaderir.Stmt{ stmts = append(stmts, shaderir.Stmt{
Type: shaderir.Return, Type: shaderir.Return,
}) })