shader: Refactoring: Remove detectType

This commit is contained in:
Hajime Hoshi 2020-06-20 00:20:17 +09:00
parent cfc8b4505d
commit f36d6c02a9
8 changed files with 136 additions and 138 deletions

View File

@ -308,15 +308,42 @@ func (cs *compileState) parseDecl(b *block, d ast.Decl) {
}
}
// functionReturnTypes returns the original returning value types, if the given expression is call.
//
// Note that parseExpr returns the returning types for IR, not the original function.
func (cs *compileState) functionReturnTypes(block *block, expr ast.Expr) ([]shaderir.Type, bool) {
call, ok := expr.(*ast.CallExpr)
if !ok {
return nil, false
}
ident, ok := call.Fun.(*ast.Ident)
if !ok {
return nil, false
}
for _, f := range cs.funcs {
if f.name == ident.Name {
// TODO: Is it correct to combine out-params and return param?
ts := f.ir.OutParams
if f.ir.Return.Main != shaderir.None {
ts = append(ts, f.ir.Return)
}
return ts, true
}
}
return nil, false
}
func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variable, []shaderir.Expr, []shaderir.Stmt) {
if len(vs.Names) != len(vs.Values) && len(vs.Values) != 1 && len(vs.Values) != 0 {
s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match"))
return nil, nil, nil
}
var t shaderir.Type
var declt shaderir.Type
if vs.Type != nil {
t = s.parseType(vs.Type)
declt = s.parseType(vs.Type)
}
var (
@ -326,13 +353,19 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
)
for i, n := range vs.Names {
// TODO: Reduce calls of parseExpr
var init ast.Expr
t := declt
switch len(vs.Values) {
case 0:
case 1:
init = vs.Values[0]
if t.Main == shaderir.None {
ts := s.detectType(block, init)
ts, ok := s.functionReturnTypes(block, init)
if !ok {
_, ts, _ = s.parseExpr(block, init)
}
if len(ts) != len(vs.Names) {
s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match"))
continue
@ -342,7 +375,10 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl
default:
init = vs.Values[i]
if t.Main == shaderir.None {
ts := s.detectType(block, init)
ts, ok := s.functionReturnTypes(block, init)
if !ok {
_, ts, _ = s.parseExpr(block, init)
}
if len(ts) > 1 {
s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match"))
}
@ -547,13 +583,18 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
return nil
}
// TODO: Reduce calls of parseExpr
var rhsTypes []shaderir.Type
for i, e := range l.Lhs {
v := variable{
name: e.(*ast.Ident).Name,
}
if len(l.Lhs) == len(l.Rhs) {
ts := cs.detectType(block, l.Rhs[i])
ts, ok := cs.functionReturnTypes(block, l.Rhs[i])
if !ok {
_, ts, _ = cs.parseExpr(block, l.Rhs[i])
}
if len(ts) > 1 {
cs.addError(l.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed"))
}
@ -562,7 +603,11 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
}
} else {
if i == 0 {
rhsTypes = cs.detectType(block, l.Rhs[0])
var ok bool
rhsTypes, ok = cs.functionReturnTypes(block, l.Rhs[0])
if !ok {
_, rhsTypes, _ = cs.parseExpr(block, l.Rhs[0])
}
if len(rhsTypes) != len(l.Lhs) {
cs.addError(l.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed"))
}
@ -817,18 +862,19 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
var (
callee shaderir.Expr
args []shaderir.Expr
argts []shaderir.Type
stmts []shaderir.Stmt
)
// Parse the argument first for the order of the statements.
for _, a := range e.Args {
es, _, ss := cs.parseExpr(block, a)
es, ts, ss := cs.parseExpr(block, a)
if len(es) > 1 && len(e.Args) > 1 {
cs.addError(e.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed: %s", e.Fun))
return nil, nil, nil
}
// TODO: Convert integer literals to float literals if necessary.
args = append(args, es...)
argts = append(argts, ts...)
stmts = append(stmts, ss...)
}
@ -845,7 +891,32 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
// call.
if callee.Type == shaderir.BuiltinFuncExpr {
var t shaderir.Type
// TODO: Decude the type based on the arguments.
switch callee.BuiltinFunc {
case shaderir.Vec2F:
t = shaderir.Type{Main: shaderir.Vec2}
case shaderir.Vec3F:
t = shaderir.Type{Main: shaderir.Vec3}
case shaderir.Vec4F:
t = shaderir.Type{Main: shaderir.Vec4}
case shaderir.Mat2F:
t = shaderir.Type{Main: shaderir.Mat2}
case shaderir.Mat3F:
t = shaderir.Type{Main: shaderir.Mat3}
case shaderir.Mat4F:
t = shaderir.Type{Main: shaderir.Mat4}
case shaderir.Step:
t = argts[1]
case shaderir.Smoothstep:
t = argts[2]
case shaderir.Length, shaderir.Distance, shaderir.Dot:
t = shaderir.Type{Main: shaderir.Float}
case shaderir.Cross:
t = shaderir.Type{Main: shaderir.Vec3}
case shaderir.Texture2DF:
t = shaderir.Type{Main: shaderir.Vec4}
default:
t = argts[0]
}
return []shaderir.Expr{
{
Type: shaderir.Call,
@ -875,7 +946,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
}
if t := f.ir.Return; t.Main != shaderir.None {
if len(outParams) == 0 {
if len(outParams) != 0 {
cs.addError(e.Pos(), fmt.Sprintf("a function returning value cannot have out-params so far: %s", e.Fun))
return nil, nil, nil
}
@ -933,7 +1004,8 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr,
Index: p,
})
}
return exprs, nil, stmts
return exprs, f.ir.OutParams, stmts
case *ast.Ident:
if i, t, ok := block.findLocalVariable(e.Name); ok {
return []shaderir.Expr{

View File

@ -0,0 +1,14 @@
void F0(out vec2 l0) {
vec2 l1 = vec2(0);
vec2 l2 = vec2(0);
vec2 l3 = vec2(0);
F1(l3);
l2 = (1.0) * (l3);
l0 = l2;
return;
}
void F1(out vec2 l0) {
l0 = vec2(0.0);
return;
}

10
internal/shader/testdata/define2.go vendored Normal file
View File

@ -0,0 +1,10 @@
package main
func Foo() vec2 {
x := 1 * Bar()
return x
}
func Bar() vec2 {
return vec2(0)
}

View File

@ -6,3 +6,21 @@ void F0(in vec2 l0, out vec4 l1) {
l1 = vec4(l2, l3);
return;
}
void F1(in vec2 l0, out vec4 l1) {
vec2 l2 = vec2(0);
vec2 l3 = vec2(0);
vec2 l4 = vec2(0);
vec2 l5 = vec2(0);
F2(l2, l3);
l4 = l2;
l5 = l3;
l1 = vec4(l4, l5);
return;
}
void F2(out vec2 l0, out vec2 l1) {
l0 = vec2(0.0);
l1 = vec2(0.0);
return;
}

View File

@ -4,3 +4,12 @@ func Foo(foo vec2) vec4 {
var bar1, bar2 vec2 = foo, foo
return vec4(bar1, bar2)
}
func Foo2(foo vec2) vec4 {
var bar1, bar2 = Bar()
return vec4(bar1, bar2)
}
func Bar() (vec2, vec2) {
return vec2(0), vec2(0)
}

View File

@ -3,6 +3,6 @@ varying vec2 V0;
varying vec4 V1;
void main(void) {
gl_FragColor = vec4(1.0, 0.0, 0.0, 1.0);
gl_FragColor = vec4((gl_FragCoord).x, (V0).y, (V1).z, 1.0);
return;
}

View File

@ -11,7 +11,7 @@ func Vertex(position vec2, texCoord vec2, color vec4) (position vec4, texCoord v
}
func Fragment(position vec4, texCoord vec2, color vec4) vec4 {
return vec4(1, 0, 0, 1)
return vec4(position.x, texCoord.y, color.z, 1)
}
var ScreenSize vec2

View File

@ -17,7 +17,6 @@ package shader
import (
"fmt"
"go/ast"
"go/token"
"github.com/hajimehoshi/ebiten/internal/shaderir"
)
@ -60,127 +59,3 @@ func (cs *compileState) parseType(expr ast.Expr) shaderir.Type {
return shaderir.Type{}
}
}
func (cs *compileState) detectType(b *block, expr ast.Expr) []shaderir.Type {
switch e := expr.(type) {
case *ast.BasicLit:
switch e.Kind {
case token.FLOAT:
return []shaderir.Type{{Main: shaderir.Float}}
case token.INT:
return []shaderir.Type{{Main: shaderir.Int}}
}
cs.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value))
return nil
case *ast.BinaryExpr:
t1, t2 := cs.detectType(b, e.X), cs.detectType(b, e.Y)
if len(t1) != 1 || len(t2) != 1 {
cs.addError(expr.Pos(), fmt.Sprintf("binary operator cannot be used for multiple-value context: %v", expr))
return nil
}
if !t1[0].Equal(&t2[0]) {
// TODO: Move this checker to shaderir
if t1[0].Main == shaderir.Float {
switch t2[0].Main {
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:
return t2
}
}
if t2[0].Main == shaderir.Float {
switch t1[0].Main {
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:
return t1
}
}
cs.addError(expr.Pos(), fmt.Sprintf("types between a binary operator don't match"))
return nil
}
return t1
case *ast.CallExpr:
n := e.Fun.(*ast.Ident).Name
f, ok := shaderir.ParseBuiltinFunc(n)
if ok {
switch f {
case shaderir.Vec2F:
return []shaderir.Type{{Main: shaderir.Vec2}}
case shaderir.Vec3F:
return []shaderir.Type{{Main: shaderir.Vec3}}
case shaderir.Vec4F:
return []shaderir.Type{{Main: shaderir.Vec4}}
case shaderir.Mat2F:
return []shaderir.Type{{Main: shaderir.Mat2}}
case shaderir.Mat3F:
return []shaderir.Type{{Main: shaderir.Mat3}}
case shaderir.Mat4F:
return []shaderir.Type{{Main: shaderir.Mat4}}
default:
// TODO: Add more functions
cs.addError(expr.Pos(), fmt.Sprintf("detecting types is not implemented for: %s", n))
}
}
for _, f := range cs.funcs {
if f.name == n {
// TODO: Is it correct to combine out-params and return param?
ts := f.ir.OutParams
if f.ir.Return.Main != shaderir.None {
ts = append(ts, f.ir.Return)
}
return ts
}
}
cs.addError(expr.Pos(), fmt.Sprintf("unexpected call: %s", n))
return nil
case *ast.CompositeLit:
return []shaderir.Type{cs.parseType(e.Type)}
case *ast.Ident:
n := e.Name
for _, v := range b.vars {
if v.name == n {
return []shaderir.Type{v.typ}
}
}
if b == &cs.global {
for i, v := range cs.uniforms {
if v == n {
return []shaderir.Type{cs.ir.Uniforms[i]}
}
}
}
if b.outer != nil {
return cs.detectType(b.outer, e)
}
cs.addError(expr.Pos(), fmt.Sprintf("unexpected identifier: %s", n))
return nil
case *ast.SelectorExpr:
t := cs.detectType(b, e.X)
if len(t) != 1 {
cs.addError(expr.Pos(), fmt.Sprintf("selector is not available in multiple-value context: %v", e.X))
return nil
}
switch t[0].Main {
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:
switch len(e.Sel.Name) {
case 1:
return []shaderir.Type{{Main: shaderir.Float}}
case 2:
return []shaderir.Type{{Main: shaderir.Vec2}}
case 3:
return []shaderir.Type{{Main: shaderir.Float}}
case 4:
return []shaderir.Type{{Main: shaderir.Float}}
default:
cs.addError(expr.Pos(), fmt.Sprintf("invalid selector: %s", e.Sel.Name))
}
return nil
case shaderir.Struct:
cs.addError(expr.Pos(), fmt.Sprintf("selector for a struct is not implemented yet"))
return nil
default:
cs.addError(expr.Pos(), fmt.Sprintf("selector is not available for: %v", expr))
return nil
}
default:
cs.addError(expr.Pos(), fmt.Sprintf("detecting type not implemented: %#v", expr))
return nil
}
}