internal/shaderir: replace Mul with ComponentWiseMul and MatrixMul

This is a preparation for DirectX / HLSL.

Updates #1007
This commit is contained in:
Hajime Hoshi 2022-03-13 19:17:44 +09:00
parent 044d41dd2d
commit a617576879
7 changed files with 46 additions and 30 deletions

View File

@ -164,7 +164,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
}, []shaderir.Type{t}, stmts, true }, []shaderir.Type{t}, stmts, true
} }
op, ok := shaderir.OpFromToken(e.Op) op, ok := shaderir.OpFromToken(e.Op, lhst, rhst)
if !ok { if !ok {
cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op)) cs.addError(e.Pos(), fmt.Sprintf("unexpected operator: %s", e.Op))
return nil, nil, nil, false return nil, nil, nil, false
@ -178,7 +178,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
case lhs[0].Type == shaderir.NumberExpr && rhs[0].Type != shaderir.NumberExpr: case lhs[0].Type == shaderir.NumberExpr && rhs[0].Type != shaderir.NumberExpr:
switch rhst.Main { switch rhst.Main {
case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op != shaderir.Mul { if op != shaderir.MatrixMul {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -199,7 +199,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
case lhs[0].Type != shaderir.NumberExpr && rhs[0].Type == shaderir.NumberExpr: case lhs[0].Type != shaderir.NumberExpr && rhs[0].Type == shaderir.NumberExpr:
switch lhst.Main { switch lhst.Main {
case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op != shaderir.Mul && op != shaderir.Div { if op != shaderir.MatrixMul && op != shaderir.Div {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -228,7 +228,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:
t = rhst t = rhst
case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op != shaderir.Mul { if op != shaderir.MatrixMul {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -242,7 +242,7 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:
t = lhst t = lhst
case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op != shaderir.Mul && op != shaderir.Div { if op != shaderir.MatrixMul && op != shaderir.Div {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false return nil, nil, nil, false
} }
@ -251,13 +251,13 @@ func (cs *compileState) parseExpr(block *block, expr ast.Expr, markLocalVariable
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false return nil, nil, nil, false
} }
case op == shaderir.Mul && (lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 || case op == shaderir.MatrixMul && (lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 ||
lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2): lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2):
t = shaderir.Type{Main: shaderir.Vec2} t = shaderir.Type{Main: shaderir.Vec2}
case op == shaderir.Mul && (lhst.Main == shaderir.Vec3 && rhst.Main == shaderir.Mat3 || case op == shaderir.MatrixMul && (lhst.Main == shaderir.Vec3 && rhst.Main == shaderir.Mat3 ||
lhst.Main == shaderir.Mat3 && rhst.Main == shaderir.Vec3): lhst.Main == shaderir.Mat3 && rhst.Main == shaderir.Vec3):
t = shaderir.Type{Main: shaderir.Vec3} t = shaderir.Type{Main: shaderir.Vec3}
case op == shaderir.Mul && (lhst.Main == shaderir.Vec4 && rhst.Main == shaderir.Mat4 || case op == shaderir.MatrixMul && (lhst.Main == shaderir.Vec4 && rhst.Main == shaderir.Mat4 ||
lhst.Main == shaderir.Mat4 && rhst.Main == shaderir.Vec4): lhst.Main == shaderir.Mat4 && rhst.Main == shaderir.Vec4):
t = shaderir.Type{Main: shaderir.Vec4} t = shaderir.Type{Main: shaderir.Vec4}
default: default:

View File

@ -62,20 +62,6 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN: case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN:
var op shaderir.Op
switch stmt.Tok {
case token.ADD_ASSIGN:
op = shaderir.Add
case token.SUB_ASSIGN:
op = shaderir.Sub
case token.MUL_ASSIGN:
op = shaderir.Mul
case token.QUO_ASSIGN:
op = shaderir.Div
case token.REM_ASSIGN:
op = shaderir.ModOp
}
rhs, rts, ss, ok := cs.parseExpr(block, stmt.Rhs[0], true) rhs, rts, ss, ok := cs.parseExpr(block, stmt.Rhs[0], true)
if !ok { if !ok {
return nil, false return nil, false
@ -88,6 +74,24 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
var op shaderir.Op
switch stmt.Tok {
case token.ADD_ASSIGN:
op = shaderir.Add
case token.SUB_ASSIGN:
op = shaderir.Sub
case token.MUL_ASSIGN:
if lts[0].IsMatrix() || rts[0].IsMatrix() {
op = shaderir.MatrixMul
} else {
op = shaderir.ComponentWiseMul
}
case token.QUO_ASSIGN:
op = shaderir.Div
case token.REM_ASSIGN:
op = shaderir.ModOp
}
// Treat an integer literal as an integer constant value. // Treat an integer literal as an integer constant value.
wasTypedConstInt := rhs[0].ConstType == shaderir.ConstTypeInt wasTypedConstInt := rhs[0].ConstType == shaderir.ConstTypeInt
if rhs[0].Type == shaderir.NumberExpr && rts[0].Main == shaderir.Int { if rhs[0].Type == shaderir.NumberExpr && rts[0].Main == shaderir.Int {
@ -116,14 +120,14 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false return nil, false
} }
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if (op == shaderir.Mul || op == shaderir.Div) && rts[0].Main == shaderir.Float { if (op == shaderir.MatrixMul || op == shaderir.Div) && rts[0].Main == shaderir.Float {
// OK // OK
} else if (lts[0].Main == shaderir.Vec2 || } else if (lts[0].Main == shaderir.Vec2 ||
lts[0].Main == shaderir.Vec3 || lts[0].Main == shaderir.Vec3 ||
lts[0].Main == shaderir.Vec4) && lts[0].Main == shaderir.Vec4) &&
rts[0].Main == shaderir.Float { rts[0].Main == shaderir.Float {
// OK // OK
} else if op == shaderir.Mul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) || } else if op == shaderir.MatrixMul && ((lts[0].Main == shaderir.Vec2 && rts[0].Main == shaderir.Mat2) ||
(lts[0].Main == shaderir.Vec3 && rts[0].Main == shaderir.Mat3) || (lts[0].Main == shaderir.Vec3 && rts[0].Main == shaderir.Mat3) ||
(lts[0].Main == shaderir.Vec4 && rts[0].Main == shaderir.Mat4)) { (lts[0].Main == shaderir.Vec4 && rts[0].Main == shaderir.Mat4)) {
// OK // OK

View File

@ -28,7 +28,7 @@ func opString(op shaderir.Op) string {
return "-" return "-"
case shaderir.NotOp: case shaderir.NotOp:
return "!" return "!"
case shaderir.Mul: case shaderir.ComponentWiseMul, shaderir.MatrixMul:
return "*" return "*"
case shaderir.Div: case shaderir.Div:
return "/" return "/"

View File

@ -28,7 +28,7 @@ func opString(op shaderir.Op) string {
return "-" return "-"
case shaderir.NotOp: case shaderir.NotOp:
return "!" return "!"
case shaderir.Mul: case shaderir.ComponentWiseMul, shaderir.MatrixMul:
return "*" return "*"
case shaderir.Div: case shaderir.Div:
return "/" return "/"

View File

@ -137,7 +137,8 @@ const (
Add Op = iota Add Op = iota
Sub Sub
NotOp NotOp
Mul // TODO: Separate Hadamard-product and Matrix-product ComponentWiseMul
MatrixMul
Div Div
ModOp ModOp
LeftShift LeftShift
@ -155,7 +156,7 @@ const (
OrOr OrOr
) )
func OpFromToken(t token.Token) (Op, bool) { func OpFromToken(t token.Token, lhs, rhs Type) (Op, bool) {
switch t { switch t {
case token.ADD: case token.ADD:
return Add, true return Add, true
@ -164,7 +165,10 @@ func OpFromToken(t token.Token) (Op, bool) {
case token.NOT: case token.NOT:
return NotOp, true return NotOp, true
case token.MUL: case token.MUL:
return Mul, true if lhs.IsMatrix() || rhs.IsMatrix() {
return MatrixMul, true
}
return ComponentWiseMul, true
case token.QUO: case token.QUO:
return Div, true return Div, true
case token.REM: case token.REM:

View File

@ -104,6 +104,14 @@ func (t *Type) FloatNum() int {
} }
} }
func (t *Type) IsMatrix() bool {
switch t.Main {
case Mat2, Mat3, Mat4:
return true
}
return false
}
type BasicType int type BasicType int
const ( const (

View File

@ -213,7 +213,7 @@ func defaultVertexFunc(invertY bool) shaderir.VertexFunc {
}, },
{ {
Type: shaderir.Binary, Type: shaderir.Binary,
Op: shaderir.Mul, Op: shaderir.MatrixMul,
Exprs: []shaderir.Expr{ Exprs: []shaderir.Expr{
projectionMatrix(invertY), projectionMatrix(invertY),
vertexPosition(), vertexPosition(),