From a61757687976a0980d6d2aadffcfa81de1ec5e7b Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Sun, 13 Mar 2022 19:17:44 +0900 Subject: [PATCH] internal/shaderir: replace Mul with ComponentWiseMul and MatrixMul This is a preparation for DirectX / HLSL. Updates #1007 --- internal/shader/expr.go | 16 +++++++-------- internal/shader/stmt.go | 36 +++++++++++++++++++--------------- internal/shaderir/glsl/type.go | 2 +- internal/shaderir/msl/type.go | 2 +- internal/shaderir/program.go | 10 +++++++--- internal/shaderir/type.go | 8 ++++++++ internal/testing/shader.go | 2 +- 7 files changed, 46 insertions(+), 30 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index 1e0167edb..7ba1fd4f9 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -164,7 +164,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable }, []shaderir.Type{t}, stmts, true } - op, ok := shaderir.OpFromToken(e.Op) + op, ok := shaderir.OpFromToken(e.Op, lhst, rhst) if !ok { cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op)) return nil, nil, nil, false @@ -178,7 +178,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable case lhs[0].Type == shaderir.NumberExpr && rhs[0].Type != shaderir.NumberExpr: switch rhst.Main { case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op != shaderir.Mul { + if op != shaderir.MatrixMul { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } @@ -199,7 +199,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable case lhs[0].Type != shaderir.NumberExpr && rhs[0].Type == shaderir.NumberExpr: switch lhst.Main { case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op != shaderir.Mul && op != shaderir.Div { + if op != shaderir.MatrixMul && op != shaderir.Div { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } @@ -228,7 +228,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: t = rhst case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op != shaderir.Mul { + if op != shaderir.MatrixMul { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } @@ -242,7 +242,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: t = lhst case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op != shaderir.Mul && op != shaderir.Div { + if op != shaderir.MatrixMul && op != shaderir.Div { cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } @@ -251,13 +251,13 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) return nil, nil, nil, false } - case op == shaderir.Mul && (lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 || + case op == shaderir.MatrixMul && (lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 || lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2): t = shaderir.Type{Main: shaderir.Vec2} - case op == shaderir.Mul && (lhst.Main == shaderir.Vec3 && rhst.Main == shaderir.Mat3 || + case op == shaderir.MatrixMul && (lhst.Main == shaderir.Vec3 && rhst.Main == shaderir.Mat3 || lhst.Main == shaderir.Mat3 && rhst.Main == shaderir.Vec3): t = shaderir.Type{Main: shaderir.Vec3} - case op == shaderir.Mul && (lhst.Main == shaderir.Vec4 && rhst.Main == shaderir.Mat4 || + case op == shaderir.MatrixMul && (lhst.Main == shaderir.Vec4 && rhst.Main == shaderir.Mat4 || lhst.Main == shaderir.Mat4 && rhst.Main == shaderir.Vec4): t = shaderir.Type{Main: shaderir.Vec4} default: diff --git a/internal/shader/stmt.go b/internal/shader/stmt.go index 8e69d1d9a..2aefdd235 100644 --- a/internal/shader/stmt.go +++ b/internal/shader/stmt.go @@ -62,20 +62,6 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } stmts = append(stmts, ss...) case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN: - var op shaderir.Op - switch stmt.Tok { - case token.ADD_ASSIGN: - op = shaderir.Add - case token.SUB_ASSIGN: - op = shaderir.Sub - case token.MUL_ASSIGN: - op = shaderir.Mul - case token.QUO_ASSIGN: - op = shaderir.Div - case token.REM_ASSIGN: - op = shaderir.ModOp - } - rhs, rts, ss, ok := cs.parseExpr(block, stmt.Rhs[0], true) if !ok { return nil, false @@ -88,6 +74,24 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP } stmts = append(stmts, ss...) + var op shaderir.Op + switch stmt.Tok { + case token.ADD_ASSIGN: + op = shaderir.Add + case token.SUB_ASSIGN: + op = shaderir.Sub + case token.MUL_ASSIGN: + if lts[0].IsMatrix() || rts[0].IsMatrix() { + op = shaderir.MatrixMul + } else { + op = shaderir.ComponentWiseMul + } + case token.QUO_ASSIGN: + op = shaderir.Div + case token.REM_ASSIGN: + op = shaderir.ModOp + } + // Treat an integer literal as an integer constant value. wasTypedConstInt := rhs[0].ConstType == shaderir.ConstTypeInt if rhs[0].Type == shaderir.NumberExpr && rts[0].Main == shaderir.Int { @@ -116,14 +120,14 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP return nil, false } case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if (op == shaderir.Mul || op == shaderir.Div) && rts[0].Main == shaderir.Float { + if (op == shaderir.MatrixMul || op == shaderir.Div) && rts[0].Main == shaderir.Float { // OK } else if (lts[0].Main == shaderir.Vec2 || lts[0].Main == shaderir.Vec3 || lts[0].Main == shaderir.Vec4) && rts[0].Main == shaderir.Float { // OK - } else if op == shaderir.Mul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) || + } else if op == shaderir.MatrixMul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) || (lts[0].Main == shaderir.Vec3 && rts[0].Main == shaderir.Mat3) || (lts[0].Main == shaderir.Vec4 && rts[0].Main == shaderir.Mat4)) { // OK diff --git a/internal/shaderir/glsl/type.go b/internal/shaderir/glsl/type.go index bd1f71c4e..2888a9768 100644 --- a/internal/shaderir/glsl/type.go +++ b/internal/shaderir/glsl/type.go @@ -28,7 +28,7 @@ func opString(op shaderir.Op) string { return "-" case shaderir.NotOp: return "!" - case shaderir.Mul: + case shaderir.ComponentWiseMul, shaderir.MatrixMul: return "*" case shaderir.Div: return "/" diff --git a/internal/shaderir/msl/type.go b/internal/shaderir/msl/type.go index 6484cc31d..319b72b96 100644 --- a/internal/shaderir/msl/type.go +++ b/internal/shaderir/msl/type.go @@ -28,7 +28,7 @@ func opString(op shaderir.Op) string { return "-" case shaderir.NotOp: return "!" - case shaderir.Mul: + case shaderir.ComponentWiseMul, shaderir.MatrixMul: return "*" case shaderir.Div: return "/" diff --git a/internal/shaderir/program.go b/internal/shaderir/program.go index 2ae6d59f3..afb50af3d 100644 --- a/internal/shaderir/program.go +++ b/internal/shaderir/program.go @@ -137,7 +137,8 @@ const ( Add Op = iota Sub NotOp - Mul // TODO: Separate Hadamard-product and Matrix-product + ComponentWiseMul + MatrixMul Div ModOp LeftShift @@ -155,7 +156,7 @@ const ( OrOr ) -func OpFromToken(t token.Token) (Op, bool) { +func OpFromToken(t token.Token, lhs, rhs Type) (Op, bool) { switch t { case token.ADD: return Add, true @@ -164,7 +165,10 @@ func OpFromToken(t token.Token) (Op, bool) { case token.NOT: return NotOp, true case token.MUL: - return Mul, true + if lhs.IsMatrix() || rhs.IsMatrix() { + return MatrixMul, true + } + return ComponentWiseMul, true case token.QUO: return Div, true case token.REM: diff --git a/internal/shaderir/type.go b/internal/shaderir/type.go index 50de40480..1f4b19467 100644 --- a/internal/shaderir/type.go +++ b/internal/shaderir/type.go @@ -104,6 +104,14 @@ func (t *Type) FloatNum() int { } } +func (t *Type) IsMatrix() bool { + switch t.Main { + case Mat2, Mat3, Mat4: + return true + } + return false +} + type BasicType int const ( diff --git a/internal/testing/shader.go b/internal/testing/shader.go index 22080cd90..a46a9ac87 100644 --- a/internal/testing/shader.go +++ b/internal/testing/shader.go @@ -213,7 +213,7 @@ func defaultVertexFunc(invertY bool) shaderir.VertexFunc { }, { Type: shaderir.Binary, - Op: shaderir.Mul, + Op: shaderir.MatrixMul, Exprs: []shaderir.Expr{ projectionMatrix(invertY), vertexPosition(),