shader: Add define (:=)

This commit is contained in:
Hajime Hoshi 2020-06-01 02:23:27 +09:00
parent 6fa7b4bb5a
commit 37cc30bc38
2 changed files with 72 additions and 29 deletions

View File

@ -304,20 +304,21 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
case *ast.AssignStmt: case *ast.AssignStmt:
switch l.Tok { switch l.Tok {
case token.DEFINE: case token.DEFINE:
for i, s := range l.Lhs { for i, e := range l.Lhs {
v := variable{ v := variable{
name: s.(*ast.Ident).Name, name: e.(*ast.Ident).Name,
}
if len(l.Rhs) > 0 {
v.typ = cs.detectType(block, l.Rhs[i])
} }
v.typ = cs.detectType(block, l.Rhs[i])
v.init = cs.parseExpr(block, l.Rhs[i])
block.vars = append(block.vars, v) block.vars = append(block.vars, v)
} block.ir.LocalVars = append(block.ir.LocalVars, v.typ.ir)
for range l.Rhs { block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{
/*block.stmts = append(block.stmts, stmt{ Type: shaderir.Assign,
stmtType: stmtAssign, Exprs: []shaderir.Expr{
exprs: []ast.Expr{l.Lhs[i], l.Rhs[i]}, cs.parseExpr(block, l.Lhs[i]),
})*/ v.init,
},
})
} }
case token.ASSIGN: case token.ASSIGN:
// TODO: What about the statement `a,b = b,a?` // TODO: What about the statement `a,b = b,a?`
@ -367,18 +368,40 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out
func (s *compileState) detectType(b *block, expr ast.Expr) typ { func (s *compileState) detectType(b *block, expr ast.Expr) typ {
switch e := expr.(type) { switch e := expr.(type) {
case *ast.BasicLit: case *ast.BasicLit:
if e.Kind == token.FLOAT { switch e.Kind {
case token.FLOAT:
return typ{ return typ{
ir: shaderir.Type{Main: shaderir.Float}, ir: shaderir.Type{Main: shaderir.Float},
} }
} case token.INT:
if e.Kind == token.INT {
return typ{ return typ{
ir: shaderir.Type{Main: shaderir.Int}, ir: shaderir.Type{Main: shaderir.Int},
} }
} }
s.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) s.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value))
return typ{} return typ{}
case *ast.CallExpr:
n := e.Fun.(*ast.Ident).Name
f, ok := shaderir.ParseBuiltinFunc(n)
if ok {
switch f {
case shaderir.Vec2F:
return typ{ir: shaderir.Type{Main: shaderir.Vec2}}
case shaderir.Vec3F:
return typ{ir: shaderir.Type{Main: shaderir.Vec3}}
case shaderir.Vec4F:
return typ{ir: shaderir.Type{Main: shaderir.Vec4}}
case shaderir.Mat2F:
return typ{ir: shaderir.Type{Main: shaderir.Mat2}}
case shaderir.Mat3F:
return typ{ir: shaderir.Type{Main: shaderir.Mat3}}
case shaderir.Mat4F:
return typ{ir: shaderir.Type{Main: shaderir.Mat4}}
// TODO: Add more functions
}
}
s.addError(expr.Pos(), fmt.Sprintf("unexpected call: %s", n))
return typ{}
case *ast.CompositeLit: case *ast.CompositeLit:
return s.parseType(e.Type) return s.parseType(e.Type)
case *ast.Ident: case *ast.Ident:
@ -440,7 +463,9 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) shaderir.Expr {
cs.parseExpr(block, e.Fun), cs.parseExpr(block, e.Fun),
} }
for _, a := range e.Args { for _, a := range e.Args {
exprs = append(exprs, cs.parseExpr(block, a)) e := cs.parseExpr(block, a)
// TODO: Convert integer literals to float literals if necessary.
exprs = append(exprs, e)
} }
return shaderir.Expr{ return shaderir.Expr{
Type: shaderir.Call, Type: shaderir.Call,

View File

@ -60,7 +60,7 @@ func Foo(foo vec2) vec4 {
}`, }`,
}, },
{ {
Name: "func multiple out params", Name: "multiple out params",
Src: `package main Src: `package main
func Foo(foo vec4) (float, float, float, float) { func Foo(foo vec4) (float, float, float, float) {
@ -104,23 +104,41 @@ func Foo(foo vec2) vec4 {
} }
l1 = l2; l1 = l2;
return; return;
}`,
},
{
Name: "define",
Src: `package main
func Foo(foo vec2) vec4 {
r := vec4(foo, 0, 1)
return r
}`,
// TODO: number literals must be floats.
VS: `void F0(in vec2 l0, out vec4 l1) {
vec4 l2 = vec4(0.0);
l2 = vec4(l0, 0, 1);
l1 = l2;
return;
}`, }`,
}, },
} }
for _, tc := range tests { for _, tc := range tests {
s, err := Compile([]byte(tc.Src)) t.Run(tc.Name, func(t *testing.T) {
if err != nil { s, err := Compile([]byte(tc.Src))
t.Error(err) if err != nil {
continue t.Error(err)
} return
vs, fs := s.Glsl()
if got, want := vs, tc.VS+"\n"; got != want {
t.Errorf("%s: got: %v, want: %v", tc.Name, got, want)
}
if tc.FS != "" {
if got, want := fs, tc.FS+"\n"; got != want {
t.Errorf("%s: got: %v, want: %v", tc.Name, got, want)
} }
} vs, fs := s.Glsl()
if got, want := vs, tc.VS+"\n"; got != want {
t.Errorf("got: %v, want: %v", got, want)
}
if tc.FS != "" {
if got, want := fs, tc.FS+"\n"; got != want {
t.Errorf("got: %v, want: %v", got, want)
}
}
})
} }
} }