diff --git a/internal/shader/shader.go b/internal/shader/shader.go index df76390e8..5ef7ab0df 100644 --- a/internal/shader/shader.go +++ b/internal/shader/shader.go @@ -95,20 +95,20 @@ type block struct { ir shaderir.Block } -func (b *block) findLocalVariable(name string) (int, bool) { +func (b *block) findLocalVariable(name string) (int, shaderir.Type, bool) { idx := 0 for outer := b.outer; outer != nil; outer = outer.outer { idx += len(outer.vars) } for i, v := range b.vars { if v.name == name { - return idx + i, true + return idx + i, v.typ, true } } if b.outer != nil { return b.outer.findLocalVariable(name) } - return 0, false + return 0, shaderir.Type{}, false } type ParseError struct { @@ -357,7 +357,7 @@ func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variabl }) if len(vs.Values) > 1 || (len(vs.Values) == 1 && len(inits) == 0) { - es, ss := s.parseExpr(block, init) + es, _, ss := s.parseExpr(block, init) inits = append(inits, es...) stmts = append(stmts, ss...) } @@ -594,9 +594,9 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out case token.REM_ASSIGN: op = shaderir.ModOp } - rhs, stmts := cs.parseExpr(block, l.Rhs[0]) + rhs, _, stmts := cs.parseExpr(block, l.Rhs[0]) block.ir.Stmts = append(block.ir.Stmts, stmts...) - lhs, stmts := cs.parseExpr(block, l.Lhs[0]) + lhs, _, stmts := cs.parseExpr(block, l.Lhs[0]) block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ Type: shaderir.Assign, @@ -627,7 +627,7 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out cs.parseDecl(block, l.Decl) case *ast.ReturnStmt: for i, r := range l.Results { - exprs, stmts := cs.parseExpr(block, r) + exprs, _, stmts := cs.parseExpr(block, r) block.ir.Stmts = append(block.ir.Stmts, stmts...) if len(exprs) == 0 { continue @@ -651,7 +651,7 @@ func (cs *compileState) parseBlock(outer *block, b *ast.BlockStmt, inParams, out Type: shaderir.Return, }) case *ast.ExprStmt: - exprs, stmts := cs.parseExpr(block, l.X) + exprs, _, stmts := cs.parseExpr(block, l.X) block.ir.Stmts = append(block.ir.Stmts, stmts...) for _, expr := range exprs { if expr.Type != shaderir.Call { @@ -675,13 +675,13 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr) for i := range lhs { // Prase RHS first for the order of the statements. if len(lhs) == len(rhs) { - rhs, stmts := cs.parseExpr(block, rhs[i]) + rhs, _, stmts := cs.parseExpr(block, rhs[i]) if len(rhs) > 1 { cs.addError(pos, fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) } block.ir.Stmts = append(block.ir.Stmts, stmts...) - lhs, stmts := cs.parseExpr(block, lhs[i]) + lhs, _, stmts := cs.parseExpr(block, lhs[i]) block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ @@ -691,14 +691,14 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr) } else { if i == 0 { var stmts []shaderir.Stmt - rhsExprs, stmts = cs.parseExpr(block, rhs[0]) + rhsExprs, _, stmts = cs.parseExpr(block, rhs[0]) if len(rhsExprs) != len(lhs) { cs.addError(pos, fmt.Sprintf("single-value context and multiple-value context cannot be mixed")) } block.ir.Stmts = append(block.ir.Stmts, stmts...) } - lhs, stmts := cs.parseExpr(block, lhs[i]) + lhs, _, stmts := cs.parseExpr(block, lhs[i]) block.ir.Stmts = append(block.ir.Stmts, stmts...) block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ @@ -709,7 +709,7 @@ func (cs *compileState) assign(block *block, pos token.Pos, lhs, rhs []ast.Expr) } } -func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, []shaderir.Stmt) { +func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt) { switch e := expr.(type) { case *ast.BasicLit: switch e.Kind { @@ -717,26 +717,26 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, v, err := strconv.ParseInt(e.Value, 10, 32) if err != nil { cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) - return nil, nil + return nil, nil, nil } return []shaderir.Expr{ { Type: shaderir.IntExpr, Int: int32(v), }, - }, nil + }, []shaderir.Type{{Main: shaderir.Int}}, nil case token.FLOAT: v, err := strconv.ParseFloat(e.Value, 32) if err != nil { cs.addError(e.Pos(), fmt.Sprintf("unexpected literal: %s", e.Value)) - return nil, nil + return nil, nil, nil } return []shaderir.Expr{ { Type: shaderir.FloatExpr, Float: float32(v), }, - }, nil + }, []shaderir.Type{{Main: shaderir.Float}}, nil default: cs.addError(e.Pos(), fmt.Sprintf("literal not implemented: %#v", e)) } @@ -783,33 +783,36 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, op = shaderir.OrOr default: cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op)) - return nil, nil + return nil, nil, nil } var stmts []shaderir.Stmt // Prase RHS first for the order of the statements. - rhs, ss := cs.parseExpr(block, e.Y) + rhs, t0, ss := cs.parseExpr(block, e.Y) if len(rhs) != 1 { cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a binary operator: %s", e.Y)) - return nil, nil + return nil, nil, nil } stmts = append(stmts, ss...) - lhs, ss := cs.parseExpr(block, e.X) + lhs, t1, ss := cs.parseExpr(block, e.X) if len(lhs) != 1 { cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a binary operator: %s", e.X)) - return nil, nil + return nil, nil, nil } stmts = append(stmts, ss...) + // TODO: Check the compatibility of t0 and t1 + _ = t1 + return []shaderir.Expr{ { Type: shaderir.Binary, Op: op, Exprs: []shaderir.Expr{lhs[0], rhs[0]}, }, - }, stmts + }, t0, stmts case *ast.CallExpr: var ( callee shaderir.Expr @@ -819,10 +822,10 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, // Parse the argument first for the order of the statements. for _, a := range e.Args { - es, ss := cs.parseExpr(block, a) + es, _, ss := cs.parseExpr(block, a) if len(es) > 1 && len(e.Args) > 1 { cs.addError(e.Pos(), fmt.Sprintf("single-value context and multiple-value context cannot be mixed: %s", e.Fun)) - return nil, nil + return nil, nil, nil } // TODO: Convert integer literals to float literals if necessary. args = append(args, es...) @@ -830,10 +833,10 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, } // TODO: When len(ss) is not 0? - es, ss := cs.parseExpr(block, e.Fun) + es, _, ss := cs.parseExpr(block, e.Fun) if len(es) != 1 { cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a callee: %s", e.Fun)) - return nil, nil + return nil, nil, nil } callee = es[0] stmts = append(stmts, ss...) @@ -841,17 +844,19 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, // For built-in functions, we can call this in this position. Return an expression for the function // call. if callee.Type == shaderir.BuiltinFuncExpr { + var t shaderir.Type + // TODO: Decude the type based on the arguments. return []shaderir.Expr{ { Type: shaderir.Call, Exprs: append([]shaderir.Expr{callee}, args...), }, - }, stmts + }, []shaderir.Type{t}, stmts } if callee.Type != shaderir.FunctionExpr { cs.addError(e.Pos(), fmt.Sprintf("function callee must be a funciton name but %s", e.Fun)) - return nil, nil + return nil, nil, nil } f := cs.funcs[callee.Index] @@ -872,7 +877,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, if t := f.ir.Return; t.Main != shaderir.None { if len(outParams) == 0 { cs.addError(e.Pos(), fmt.Sprintf("a function returning value cannot have out-params so far: %s", e.Fun)) - return nil, nil + return nil, nil, nil } idx := len(block.vars) @@ -902,7 +907,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, Type: shaderir.LocalVariable, Index: idx, }, - }, stmts + }, []shaderir.Type{t}, stmts } // Even if the function doesn't return anything, calling the function should be done eariler to keep @@ -928,15 +933,15 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, Index: p, }) } - return exprs, stmts + return exprs, nil, stmts case *ast.Ident: - if i, ok := block.findLocalVariable(e.Name); ok { + if i, t, ok := block.findLocalVariable(e.Name); ok { return []shaderir.Expr{ { Type: shaderir.LocalVariable, Index: i, }, - }, nil + }, []shaderir.Type{t}, nil } if i, ok := cs.findFunction(e.Name); ok { return []shaderir.Expr{ @@ -944,7 +949,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, Type: shaderir.FunctionExpr, Index: i, }, - }, nil + }, nil, nil } if i, ok := cs.findUniformVariable(e.Name); ok { return []shaderir.Expr{ @@ -952,7 +957,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, Type: shaderir.UniformVariable, Index: i, }, - }, nil + }, []shaderir.Type{cs.ir.Uniforms[i]}, nil } if f, ok := shaderir.ParseBuiltinFunc(e.Name); ok { return []shaderir.Expr{ @@ -960,14 +965,28 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, Type: shaderir.BuiltinFuncExpr, BuiltinFunc: f, }, - }, nil + }, nil, nil } cs.addError(e.Pos(), fmt.Sprintf("unexpected identifier: %s", e.Name)) case *ast.SelectorExpr: - exprs, stmts := cs.parseExpr(block, e.X) + exprs, _, stmts := cs.parseExpr(block, e.X) if len(exprs) != 1 { cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a selector: %s", e.X)) - return nil, nil + return nil, nil, nil + } + var t shaderir.Type + switch len(e.Sel.Name) { + case 1: + t.Main = shaderir.Float + case 2: + t.Main = shaderir.Vec2 + case 3: + t.Main = shaderir.Vec3 + case 4: + t.Main = shaderir.Vec4 + default: + cs.addError(e.Pos(), fmt.Sprintf("unexpected swizzling: %s", e.Sel.Name)) + return nil, nil, nil } return []shaderir.Expr{ { @@ -980,7 +999,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, }, }, }, - }, stmts + }, []shaderir.Type{t}, stmts case *ast.UnaryExpr: var op shaderir.Op switch e.Op { @@ -992,12 +1011,12 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, op = shaderir.NotOp default: cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op)) - return nil, nil + return nil, nil, nil } - exprs, stmts := cs.parseExpr(block, e.X) + exprs, t, stmts := cs.parseExpr(block, e.X) if len(exprs) != 1 { cs.addError(e.Pos(), fmt.Sprintf("multiple-value context is not available at a unary operator: %s", e.X)) - return nil, nil + return nil, nil, nil } return []shaderir.Expr{ { @@ -1005,9 +1024,9 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr) ([]shaderir.Expr, Op: op, Exprs: exprs, }, - }, stmts + }, t, stmts default: cs.addError(e.Pos(), fmt.Sprintf("expression not implemented: %#v", e)) } - return nil, nil + return nil, nil, nil }