shader: Let detectType return multiple types

This commit is contained in:
Hajime Hoshi 2020-06-07 23:32:50 +09:00
parent c986da8970
commit 3fd8062fbe

View File

@ -320,7 +320,11 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
if len(vs.Values) > 0 { if len(vs.Values) > 0 {
init = vs.Values[i] init = vs.Values[i]
if t.Main == shaderir.None { if t.Main == shaderir.None {
t = s.detectType(block, init) ts := s.detectType(block, init)
if len(ts) > 1 {
s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match"))
}
t = ts[0]
} }
} }
name := n.Name name := n.Name
@ -510,7 +514,11 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
v := variable{ v := variable{
name: e.(*ast.Ident).Name, name: e.(*ast.Ident).Name,
} }
v.typ = cs.detectType(block, l.Rhs[i]) ts := cs.detectType(block, l.Rhs[i])
if len(ts) > 1 {
cs.addError(l.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match"))
}
v.typ = ts[0]
block.vars = append(block.vars, v) block.vars = append(block.vars, v)
block.ir.LocalVars = append(block.ir.LocalVars, v.typ) block.ir.LocalVars = append(block.ir.LocalVars, v.typ)
@ -574,52 +582,52 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
return block return block
} }
func (s *compileState) detectType(b *block, expr ast.Expr) shaderir.Type { func (s *compileState) detectType(b *block, expr ast.Expr) []shaderir.Type {
switch e := expr.(type) { switch e := expr.(type) {
case *ast.BasicLit: case *ast.BasicLit:
switch e.Kind { switch e.Kind {
case token.FLOAT: case token.FLOAT:
return shaderir.Type{Main: shaderir.Float} return []shaderir.Type{{Main: shaderir.Float}}
case token.INT: case token.INT:
return shaderir.Type{Main: shaderir.Int} return []shaderir.Type{{Main: shaderir.Int}}
} }
s.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) s.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value))
return shaderir.Type{} return nil
case *ast.CallExpr: case *ast.CallExpr:
n := e.Fun.(*ast.Ident).Name n := e.Fun.(*ast.Ident).Name
f, ok := shaderir.ParseBuiltinFunc(n) f, ok := shaderir.ParseBuiltinFunc(n)
if ok { if ok {
switch f { switch f {
case shaderir.Vec2F: case shaderir.Vec2F:
return shaderir.Type{Main: shaderir.Vec2} return []shaderir.Type{{Main: shaderir.Vec2}}
case shaderir.Vec3F: case shaderir.Vec3F:
return shaderir.Type{Main: shaderir.Vec3} return []shaderir.Type{{Main: shaderir.Vec3}}
case shaderir.Vec4F: case shaderir.Vec4F:
return shaderir.Type{Main: shaderir.Vec4} return []shaderir.Type{{Main: shaderir.Vec4}}
case shaderir.Mat2F: case shaderir.Mat2F:
return shaderir.Type{Main: shaderir.Mat2} return []shaderir.Type{{Main: shaderir.Mat2}}
case shaderir.Mat3F: case shaderir.Mat3F:
return shaderir.Type{Main: shaderir.Mat3} return []shaderir.Type{{Main: shaderir.Mat3}}
case shaderir.Mat4F: case shaderir.Mat4F:
return shaderir.Type{Main: shaderir.Mat4} return []shaderir.Type{{Main: shaderir.Mat4}}
// TODO: Add more functions // TODO: Add more functions
} }
} }
s.addError(expr.Pos(), fmt.Sprintf("unexpected call: %s", n)) s.addError(expr.Pos(), fmt.Sprintf("unexpected call: %s", n))
return shaderir.Type{} return nil
case *ast.CompositeLit: case *ast.CompositeLit:
return s.parseType(e.Type) return []shaderir.Type{s.parseType(e.Type)}
case *ast.Ident: case *ast.Ident:
n := e.Name n := e.Name
for _, v := range b.vars { for _, v := range b.vars {
if v.name == n { if v.name == n {
return v.typ return []shaderir.Type{v.typ}
} }
} }
if b == &s.global { if b == &s.global {
for i, v := range s.uniforms { for i, v := range s.uniforms {
if v == n { if v == n {
return s.ir.Uniforms[i] return []shaderir.Type{s.ir.Uniforms[i]}
} }
} }
} }
@ -627,12 +635,12 @@ func (s *compileState) detectType(b *block, expr ast.Expr) shaderir.Type {
return s.detectType(b.outer, e) return s.detectType(b.outer, e)
} }
s.addError(expr.Pos(), fmt.Sprintf("unexpected identifier: %s", n)) s.addError(expr.Pos(), fmt.Sprintf("unexpected identifier: %s", n))
return shaderir.Type{} return nil
//case *ast.SelectorExpr: //case *ast.SelectorExpr:
//return fmt.Sprintf("%s.%s", dumpExpr(e.X), dumpExpr(e.Sel)) //return fmt.Sprintf("%s.%s", dumpExpr(e.X), dumpExpr(e.Sel))
default: default:
s.addError(expr.Pos(), fmt.Sprintf("detecting type not implemented: %#v", expr)) s.addError(expr.Pos(), fmt.Sprintf("detecting type not implemented: %#v", expr))
return shaderir.Type{} return nil
} }
} }