shader: Make parseExpr return types

Updates #1190
This commit is contained in:
Hajime Hoshi 2020-06-19 02:37:03 +09:00
parent afb4e6dc3d
commit 3da0af5de2

View File

@ -95,20 +95,20 @@ type block struct {
ir shaderir.Block ir shaderir.Block
} }
func (b *block) findLocalVariable(name string) (int, bool) { func (b *block) findLocalVariable(name string) (int, shaderir.Type, bool) {
idx := 0 idx := 0
for outer := b.outer; outer != nil; outer = outer.outer { for outer := b.outer; outer != nil; outer = outer.outer {
idx += len(outer.vars) idx += len(outer.vars)
} }
for i, v := range b.vars { for i, v := range b.vars {
if v.name == name { if v.name == name {
return idx + i, true return idx + i, v.typ, true
} }
} }
if b.outer != nil { if b.outer != nil {
return b.outer.findLocalVariable(name) return b.outer.findLocalVariable(name)
} }
return 0, false return 0, shaderir.Type{}, false
} }
type ParseError struct { type ParseError struct {
@ -357,7 +357,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
}) })
if len(vs.Values) > 1 || (len(vs.Values) == 1 && len(inits) == 0) { if len(vs.Values) > 1 || (len(vs.Values) == 1 && len(inits) == 0) {
es, ss := s.parseExpr(block, init) es, _, ss := s.parseExpr(block, init)
inits = append(inits, es...) inits = append(inits, es...)
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
} }
@ -594,9 +594,9 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
case token.REM_ASSIGN: case token.REM_ASSIGN:
op = shaderir.ModOp op = shaderir.ModOp
} }
rhs, stmts := cs.parseExpr(block, l.Rhs[0]) rhs, _, stmts := cs.parseExpr(block, l.Rhs[0])
block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, stmts...)
lhs, stmts := cs.parseExpr(block, l.Lhs[0]) lhs, _, stmts := cs.parseExpr(block, l.Lhs[0])
block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, stmts...)
block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{
Type: shaderir.Assign, Type: shaderir.Assign,
@ -627,7 +627,7 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
cs.parseDecl(block, l.Decl) cs.parseDecl(block, l.Decl)
case *ast.ReturnStmt: case *ast.ReturnStmt:
for i, r := range l.Results { for i, r := range l.Results {
exprs, stmts := cs.parseExpr(block, r) exprs, _, stmts := cs.parseExpr(block, r)
block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, stmts...)
if len(exprs) == 0 { if len(exprs) == 0 {
continue continue
@ -651,7 +651,7 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
Type: shaderir.Return, Type: shaderir.Return,
}) })
case *ast.ExprStmt: case *ast.ExprStmt:
exprs, stmts := cs.parseExpr(block, l.X) exprs, _, stmts := cs.parseExpr(block, l.X)
block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, stmts...)
for _, expr := range exprs { for _, expr := range exprs {
if expr.Type != shaderir.Call { if expr.Type != shaderir.Call {
@ -675,13 +675,13 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr)
for i := range lhs { for i := range lhs {
// Prase RHS first for the order of the statements. // Prase RHS first for the order of the statements.
if len(lhs) == len(rhs) { if len(lhs) == len(rhs) {
rhs, stmts := cs.parseExpr(block, rhs[i]) rhs, _, stmts := cs.parseExpr(block, rhs[i])
if len(rhs) > 1 { if len(rhs) > 1 {
cs.addError(pos, fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) cs.addError(pos, fmt.Sprintf("single-value context and multiple-value context cannot be mixed"))
} }
block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, stmts...)
lhs, stmts := cs.parseExpr(block, lhs[i]) lhs, _, stmts := cs.parseExpr(block, lhs[i])
block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, stmts...)
block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{
@ -691,14 +691,14 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr)
} else { } else {
if i == 0 { if i == 0 {
var stmts []shaderir.Stmt var stmts []shaderir.Stmt
rhsExprs, stmts = cs.parseExpr(block, rhs[0]) rhsExprs, _, stmts = cs.parseExpr(block, rhs[0])
if len(rhsExprs) != len(lhs) { if len(rhsExprs) != len(lhs) {
cs.addError(pos, fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) cs.addError(pos, fmt.Sprintf("single-value context and multiple-value context cannot be mixed"))
} }
block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, stmts...)
} }
lhs, stmts := cs.parseExpr(block, lhs[i]) lhs, _, stmts := cs.parseExpr(block, lhs[i])
block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, stmts...)
block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{
@ -709,7 +709,7 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr)
} }
} }
func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, []shaderir.Stmt) { func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt) {
switch e := expr.(type) { switch e := expr.(type) {
case *ast.BasicLit: case *ast.BasicLit:
switch e.Kind { switch e.Kind {
@ -717,26 +717,26 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
v, err := strconv.ParseInt(e.Value, 10, 32) v, err := strconv.ParseInt(e.Value, 10, 32)
if err != nil { if err != nil {
cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value))
return nil, nil return nil, nil, nil
} }
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.IntExpr, Type: shaderir.IntExpr,
Int: int32(v), Int: int32(v),
}, },
}, nil }, []shaderir.Type{{Main: shaderir.Int}}, nil
case token.FLOAT: case token.FLOAT:
v, err := strconv.ParseFloat(e.Value, 32) v, err := strconv.ParseFloat(e.Value, 32)
if err != nil { if err != nil {
cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value))
return nil, nil return nil, nil, nil
} }
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.FloatExpr, Type: shaderir.FloatExpr,
Float: float32(v), Float: float32(v),
}, },
}, nil }, []shaderir.Type{{Main: shaderir.Float}}, nil
default: default:
cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e)) cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e))
} }
@ -783,33 +783,36 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
op = shaderir.OrOr op = shaderir.OrOr
default: default:
cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op)) cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op))
return nil, nil return nil, nil, nil
} }
var stmts []shaderir.Stmt var stmts []shaderir.Stmt
// Prase RHS first for the order of the statements. // Prase RHS first for the order of the statements.
rhs, ss := cs.parseExpr(block, e.Y) rhs, t0, ss := cs.parseExpr(block, e.Y)
if len(rhs) != 1 { if len(rhs) != 1 {
cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a binary operator: %s", e.Y)) cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a binary operator: %s", e.Y))
return nil, nil return nil, nil, nil
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
lhs, ss := cs.parseExpr(block, e.X) lhs, t1, ss := cs.parseExpr(block, e.X)
if len(lhs) != 1 { if len(lhs) != 1 {
cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a binary operator: %s", e.X)) cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a binary operator: %s", e.X))
return nil, nil return nil, nil, nil
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
// TODO: Check the compatibility of t0 and t1
_ = t1
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.Binary, Type: shaderir.Binary,
Op: op, Op: op,
Exprs: []shaderir.Expr{lhs[0], rhs[0]}, Exprs: []shaderir.Expr{lhs[0], rhs[0]},
}, },
}, stmts }, t0, stmts
case *ast.CallExpr: case *ast.CallExpr:
var ( var (
callee shaderir.Expr callee shaderir.Expr
@ -819,10 +822,10 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
// Parse the argument first for the order of the statements. // Parse the argument first for the order of the statements.
for _, a := range e.Args { for _, a := range e.Args {
es, ss := cs.parseExpr(block, a) es, _, ss := cs.parseExpr(block, a)
if len(es) > 1 && len(e.Args) > 1 { if len(es) > 1 && len(e.Args) > 1 {
cs.addError(e.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed: %s", e.Fun)) cs.addError(e.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed: %s", e.Fun))
return nil, nil return nil, nil, nil
} }
// TODO: Convert integer literals to float literals if necessary. // TODO: Convert integer literals to float literals if necessary.
args = append(args, es...) args = append(args, es...)
@ -830,10 +833,10 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
} }
// TODO: When len(ss) is not 0? // TODO: When len(ss) is not 0?
es, ss := cs.parseExpr(block, e.Fun) es, _, ss := cs.parseExpr(block, e.Fun)
if len(es) != 1 { if len(es) != 1 {
cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a callee: %s", e.Fun)) cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a callee: %s", e.Fun))
return nil, nil return nil, nil, nil
} }
callee = es[0] callee = es[0]
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
@ -841,17 +844,19 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
// For built-in functions, we can call this in this position. Return an expression for the function // For built-in functions, we can call this in this position. Return an expression for the function
// call. // call.
if callee.Type == shaderir.BuiltinFuncExpr { if callee.Type == shaderir.BuiltinFuncExpr {
var t shaderir.Type
// TODO: Decude the type based on the arguments.
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.Call, Type: shaderir.Call,
Exprs: append([]shaderir.Expr{callee}, args...), Exprs: append([]shaderir.Expr{callee}, args...),
}, },
}, stmts }, []shaderir.Type{t}, stmts
} }
if callee.Type != shaderir.FunctionExpr { if callee.Type != shaderir.FunctionExpr {
cs.addError(e.Pos(), fmt.Sprintf("function callee must be a funciton name but %s", e.Fun)) cs.addError(e.Pos(), fmt.Sprintf("function callee must be a funciton name but %s", e.Fun))
return nil, nil return nil, nil, nil
} }
f := cs.funcs[callee.Index] f := cs.funcs[callee.Index]
@ -872,7 +877,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
if t := f.ir.Return; t.Main != shaderir.None { if t := f.ir.Return; t.Main != shaderir.None {
if len(outParams) == 0 { if len(outParams) == 0 {
cs.addError(e.Pos(), fmt.Sprintf("a function returning value cannot have out-params so far: %s", e.Fun)) cs.addError(e.Pos(), fmt.Sprintf("a function returning value cannot have out-params so far: %s", e.Fun))
return nil, nil return nil, nil, nil
} }
idx := len(block.vars) idx := len(block.vars)
@ -902,7 +907,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
Type: shaderir.LocalVariable, Type: shaderir.LocalVariable,
Index: idx, Index: idx,
}, },
}, stmts }, []shaderir.Type{t}, stmts
} }
// Even if the function doesn't return anything, calling the function should be done eariler to keep // Even if the function doesn't return anything, calling the function should be done eariler to keep
@ -928,15 +933,15 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
Index: p, Index: p,
}) })
} }
return exprs, stmts return exprs, nil, stmts
case *ast.Ident: case *ast.Ident:
if i, ok := block.findLocalVariable(e.Name); ok { if i, t, ok := block.findLocalVariable(e.Name); ok {
return []shaderir.Expr{ return []shaderir.Expr{
{ {
Type: shaderir.LocalVariable, Type: shaderir.LocalVariable,
Index: i, Index: i,
}, },
}, nil }, []shaderir.Type{t}, nil
} }
if i, ok := cs.findFunction(e.Name); ok { if i, ok := cs.findFunction(e.Name); ok {
return []shaderir.Expr{ return []shaderir.Expr{
@ -944,7 +949,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
Type: shaderir.FunctionExpr, Type: shaderir.FunctionExpr,
Index: i, Index: i,
}, },
}, nil }, nil, nil
} }
if i, ok := cs.findUniformVariable(e.Name); ok { if i, ok := cs.findUniformVariable(e.Name); ok {
return []shaderir.Expr{ return []shaderir.Expr{
@ -952,7 +957,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
Type: shaderir.UniformVariable, Type: shaderir.UniformVariable,
Index: i, Index: i,
}, },
}, nil }, []shaderir.Type{cs.ir.Uniforms[i]}, nil
} }
if f, ok := shaderir.ParseBuiltinFunc(e.Name); ok { if f, ok := shaderir.ParseBuiltinFunc(e.Name); ok {
return []shaderir.Expr{ return []shaderir.Expr{
@ -960,14 +965,28 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
Type: shaderir.BuiltinFuncExpr, Type: shaderir.BuiltinFuncExpr,
BuiltinFunc: f, BuiltinFunc: f,
}, },
}, nil }, nil, nil
} }
cs.addError(e.Pos(), fmt.Sprintf("unexpected identifier: %s", e.Name)) cs.addError(e.Pos(), fmt.Sprintf("unexpected identifier: %s", e.Name))
case *ast.SelectorExpr: case *ast.SelectorExpr:
exprs, stmts := cs.parseExpr(block, e.X) exprs, _, stmts := cs.parseExpr(block, e.X)
if len(exprs) != 1 { if len(exprs) != 1 {
cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a selector: %s", e.X)) cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a selector: %s", e.X))
return nil, nil return nil, nil, nil
}
var t shaderir.Type
switch len(e.Sel.Name) {
case 1:
t.Main = shaderir.Float
case 2:
t.Main = shaderir.Vec2
case 3:
t.Main = shaderir.Vec3
case 4:
t.Main = shaderir.Vec4
default:
cs.addError(e.Pos(), fmt.Sprintf("unexpected swizzling: %s", e.Sel.Name))
return nil, nil, nil
} }
return []shaderir.Expr{ return []shaderir.Expr{
{ {
@ -980,7 +999,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
}, },
}, },
}, },
}, stmts }, []shaderir.Type{t}, stmts
case *ast.UnaryExpr: case *ast.UnaryExpr:
var op shaderir.Op var op shaderir.Op
switch e.Op { switch e.Op {
@ -992,12 +1011,12 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
op = shaderir.NotOp op = shaderir.NotOp
default: default:
cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op)) cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op))
return nil, nil return nil, nil, nil
} }
exprs, stmts := cs.parseExpr(block, e.X) exprs, t, stmts := cs.parseExpr(block, e.X)
if len(exprs) != 1 { if len(exprs) != 1 {
cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a unary operator: %s", e.X)) cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a unary operator: %s", e.X))
return nil, nil return nil, nil, nil
} }
return []shaderir.Expr{ return []shaderir.Expr{
{ {
@ -1005,9 +1024,9 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
Op: op, Op: op,
Exprs: exprs, Exprs: exprs,
}, },
}, stmts }, t, stmts
default: default:
cs.addError(e.Pos(), fmt.Sprintf("expression not implemented: %#v", e)) cs.addError(e.Pos(), fmt.Sprintf("expression not implemented: %#v", e))
} }
return nil, nil return nil, nil, nil
} }