From 37cc30bc38fd5902449f8800b7d6611743a40e22 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Mon, 1 Jun 2020 02:23:27 +0900 Subject: [PATCH] shader: Add define (:=) --- internal/shader/shader.go | 55 ++++++++++++++++++++++++---------- internal/shader/shader_test.go | 46 +++++++++++++++++++--------- 2 files changed, 72 insertions(+), 29 deletions(-) diff --git a/internal/shader/shader.go b/internal/shader/shader.go index fcd3ff4f7..5508df8f6 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -304,20 +304,21 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out case *ast.AssignStmt: switch l.Tok { case token.DEFINE: - for i, s := range l.Lhs { + for i, e := range l.Lhs { v := variable{ - name: s.(*ast.Ident).Name, - } - if len(l.Rhs) > 0 { - v.typ = cs.detectType(block, l.Rhs[i]) + name: e.(*ast.Ident).Name, } + v.typ = cs.detectType(block, l.Rhs[i]) + v.init = cs.parseExpr(block, l.Rhs[i]) block.vars = append(block.vars, v) - } - for range l.Rhs { - /*block.stmts = append(block.stmts, stmt{ - stmtType: stmtAssign, - exprs: []ast.Expr{l.Lhs[i], l.Rhs[i]}, - })*/ + block.ir.LocalVars = append(block.ir.LocalVars, v.typ.ir) + block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ + Type: shaderir.Assign, + Exprs: []shaderir.Expr{ + cs.parseExpr(block, l.Lhs[i]), + v.init, + }, + }) } case token.ASSIGN: // TODO: What about the statement `a,b = b,a?` @@ -367,18 +368,40 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out func (s *compileState) detectType(b *block, expr ast.Expr) typ { switch e := expr.(type) { case *ast.BasicLit: - if e.Kind == token.FLOAT { + switch e.Kind { + case token.FLOAT: return typ{ ir: shaderir.Type{Main: shaderir.Float}, } - } - if e.Kind == token.INT { + case token.INT: return typ{ ir: shaderir.Type{Main: shaderir.Int}, } } s.addError(expr.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) return typ{} + case *ast.CallExpr: + n := e.Fun.(*ast.Ident).Name + f, ok := shaderir.ParseBuiltinFunc(n) + if ok { + switch f { + case shaderir.Vec2F: + return typ{ir: shaderir.Type{Main: shaderir.Vec2}} + case shaderir.Vec3F: + return typ{ir: shaderir.Type{Main: shaderir.Vec3}} + case shaderir.Vec4F: + return typ{ir: shaderir.Type{Main: shaderir.Vec4}} + case shaderir.Mat2F: + return typ{ir: shaderir.Type{Main: shaderir.Mat2}} + case shaderir.Mat3F: + return typ{ir: shaderir.Type{Main: shaderir.Mat3}} + case shaderir.Mat4F: + return typ{ir: shaderir.Type{Main: shaderir.Mat4}} + // TODO: Add more functions + } + } + s.addError(expr.Pos(), fmt.Sprintf("unexpected call: %s", n)) + return typ{} case *ast.CompositeLit: return s.parseType(e.Type) case *ast.Ident: @@ -440,7 +463,9 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) shaderir.Expr { cs.parseExpr(block, e.Fun), } for _, a := range e.Args { - exprs = append(exprs, cs.parseExpr(block, a)) + e := cs.parseExpr(block, a) + // TODO: Convert integer literals to float literals if necessary. + exprs = append(exprs, e) } return shaderir.Expr{ Type: shaderir.Call, diff --git a/internal/shader/shader_test.go b/internal/shader/shader_test.go index 3232c21e1..d91d959e8 100644 --- a/internal/shader/shader_test.go +++ b/internal/shader/shader_test.go @@ -60,7 +60,7 @@ func Foo(foo vec2) vec4 { }`, }, { - Name: "func multiple out params", + Name: "multiple out params", Src: `package main func Foo(foo vec4) (float, float, float, float) { @@ -104,23 +104,41 @@ func Foo(foo vec2) vec4 { } l1 = l2; return; +}`, + }, + { + Name: "define", + Src: `package main + +func Foo(foo vec2) vec4 { + r := vec4(foo, 0, 1) + return r +}`, + // TODO: number literals must be floats. + VS: `void F0(in vec2 l0, out vec4 l1) { + vec4 l2 = vec4(0.0); + l2 = vec4(l0, 0, 1); + l1 = l2; + return; }`, }, } for _, tc := range tests { - s, err := Compile([]byte(tc.Src)) - if err != nil { - t.Error(err) - continue - } - vs, fs := s.Glsl() - if got, want := vs, tc.VS+"\n"; got != want { - t.Errorf("%s: got: %v, want: %v", tc.Name, got, want) - } - if tc.FS != "" { - if got, want := fs, tc.FS+"\n"; got != want { - t.Errorf("%s: got: %v, want: %v", tc.Name, got, want) + t.Run(tc.Name, func(t *testing.T) { + s, err := Compile([]byte(tc.Src)) + if err != nil { + t.Error(err) + return } - } + vs, fs := s.Glsl() + if got, want := vs, tc.VS+"\n"; got != want { + t.Errorf("got: %v, want: %v", got, want) + } + if tc.FS != "" { + if got, want := fs, tc.FS+"\n"; got != want { + t.Errorf("got: %v, want: %v", got, want) + } + } + }) } }