internal/shader: refactoring: move type deduction to shaderir package

Updates #2754
This commit is contained in:
Hajime Hoshi 2023-09-13 00:15:17 +09:00
parent 5e30e1ee1d
commit 19413c2805
2 changed files with 84 additions and 91 deletions

View File

@ -134,7 +134,8 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
}
if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) {
t, ok := shaderir.TypeFromBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst)
if !ok {
// TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
return nil, nil, nil, false
@ -142,32 +143,14 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
if lhs[0].Const != nil && rhs[0].Const != nil {
var v gconstant.Value
var t shaderir.Type
switch op {
case token.LAND, token.LOR:
b := gconstant.BoolVal(gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const))
v = gconstant.MakeBool(b)
if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
t = shaderir.Type{Main: shaderir.Bool}
}
case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ:
v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const))
if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
t = shaderir.Type{Main: shaderir.Bool}
}
default:
v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)
switch {
case lhst.Main == shaderir.Float || rhst.Main == shaderir.Float:
t = shaderir.Type{Main: shaderir.Float}
case lhst.Main == shaderir.Int || rhst.Main == shaderir.Int:
t = shaderir.Type{Main: shaderir.Int}
case lhst.Main == shaderir.Bool || rhst.Main == shaderir.Bool:
t = shaderir.Type{Main: shaderir.Bool}
default:
// If both operands are untyped, keep untyped.
t = shaderir.Type{}
}
}
return []shaderir.Expr{
@ -178,38 +161,6 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}, []shaderir.Type{t}, stmts, true
}
var t shaderir.Type
switch {
case op2 == shaderir.LessThanOp || op2 == shaderir.LessThanEqualOp || op2 == shaderir.GreaterThanOp || op2 == shaderir.GreaterThanEqualOp || op2 == shaderir.EqualOp || op2 == shaderir.NotEqualOp || op2 == shaderir.VectorEqualOp || op2 == shaderir.VectorNotEqualOp || op2 == shaderir.AndAnd || op2 == shaderir.OrOr:
t = shaderir.Type{Main: shaderir.Bool}
case lhs[0].Const != nil && rhs[0].Const == nil:
t = rhst
case lhs[0].Const == nil && rhs[0].Const != nil:
t = lhst
case lhst.Equal(&rhst):
t = lhst
case op2 == shaderir.MatrixMul && lhst.Main == shaderir.Float && rhst.IsMatrix():
t = rhst
case op2 == shaderir.MatrixMul && lhst.IsFloatVector() && rhst.IsMatrix():
t = lhst
case op2 == shaderir.MatrixMul && lhst.IsMatrix() && rhst.Main == shaderir.Float:
t = lhst
case op2 == shaderir.MatrixMul && lhst.IsMatrix() && rhst.IsFloatVector():
t = rhst
case op2 == shaderir.Div && lhst.IsMatrix() && rhst.Main == shaderir.Float:
t = lhst
case lhst.Main == shaderir.Float && rhst.IsFloatVector():
t = rhst
case lhst.Main == shaderir.Int && rhst.IsIntVector():
t = rhst
case lhst.IsFloatVector() && rhst.Main == shaderir.Float:
t = lhst
case lhst.IsIntVector() && rhst.Main == shaderir.Int:
t = lhst
default:
panic(fmt.Sprintf("shaderir: invalid expression: %s %s %s", lhst.String(), e.Op, rhst.String()))
}
return []shaderir.Expr{
{
Type: shaderir.Binary,

View File

@ -68,7 +68,7 @@ func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (
return lhs, rhs, true
}
func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
func TypeFromBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) (Type, bool) {
// If both are untyped consts, compare the constants and try to truncate them if necessary.
if lhst.Main == None && rhst.Main == None {
// Assume that the constant types are already adjusted.
@ -77,19 +77,43 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
}
if op == AndAnd || op == OrOr {
return lhs.Const.Kind() == constant.Bool && rhs.Const.Kind() == constant.Bool
if lhs.Const.Kind() == constant.Bool && rhs.Const.Kind() == constant.Bool {
return Type{Main: Bool}, true
}
return Type{}, false
}
// For %, both operands must be integers if both are constants. Truncatable to an integer is not enough.
if op == ModOp {
return lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int
if lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int {
return Type{Main: Int}, true
}
return Type{}, false
}
if op == And || op == Or || op == Xor {
return lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int
if lhs.Const.Kind() == constant.Int && rhs.Const.Kind() == constant.Int {
return Type{Main: Int}, true
}
return Type{}, false
}
return true
if op == EqualOp || op == NotEqualOp || op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp {
return Type{Main: Bool}, true
}
if lhst.Main == Float || rhst.Main == Float {
return Type{Main: Float}, true
}
if lhst.Main == Int || rhst.Main == Int {
return Type{Main: Int}, true
}
if lhst.Main == Bool || rhst.Main == Bool {
return Type{Main: Bool}, true
}
// If both operands are untyped, keep untyped.
return Type{}, true
}
// Both types must not be untyped.
@ -98,115 +122,133 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
}
if op == AndAnd || op == OrOr {
return lhst.Main == Bool && rhst.Main == Bool
if lhst.Main == Bool && rhst.Main == Bool {
return Type{Main: Bool}, true
}
return Type{}, false
}
if op == VectorEqualOp || op == VectorNotEqualOp {
return (lhst.IsFloatVector() || lhst.IsIntVector()) && (rhst.IsFloatVector() || lhst.IsIntVector()) && lhst.Equal(&rhst)
if (lhst.IsFloatVector() || lhst.IsIntVector()) && (rhst.IsFloatVector() || lhst.IsIntVector()) && lhst.Equal(&rhst) {
return Type{Main: Bool}, true
}
return Type{}, false
}
if op == LessThanOp || op == LessThanEqualOp || op == GreaterThanOp || op == GreaterThanEqualOp {
return (lhst.Main == Int && rhst.Main == Int) || (lhst.Main == Float && rhst.Main == Float)
if (lhst.Main == Int && rhst.Main == Int) || (lhst.Main == Float && rhst.Main == Float) {
return Type{Main: Bool}, true
}
return Type{}, false
}
// Comparing matrices are forbidden (#2187).
if op == EqualOp || op == NotEqualOp {
if lhst.IsMatrix() || rhst.IsMatrix() {
return false
return Type{}, false
}
return lhst.Equal(&rhst)
if lhst.Equal(&rhst) {
return Type{Main: Bool}, true
}
return Type{}, false
}
if op == Div && rhst.IsMatrix() {
return false
return Type{}, false
}
if op == ModOp {
if lhst.Main == IVec2 && rhst.Main == IVec2 {
return true
return Type{Main: IVec2}, true
}
if lhst.Main == IVec3 && rhst.Main == IVec3 {
return true
return Type{Main: IVec3}, true
}
if lhst.Main == IVec4 && rhst.Main == IVec4 {
return true
return Type{Main: IVec4}, true
}
return (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int
if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int {
return lhst, true
}
return Type{}, false
}
if op == And || op == Or || op == Xor {
if lhst.Main == Int && rhst.Main == Int {
return true
return Type{Main: Int}, true
}
if lhst.Main == IVec2 && rhst.Main == IVec2 {
return true
return Type{Main: IVec2}, true
}
if lhst.Main == IVec3 && rhst.Main == IVec3 {
return true
return Type{Main: IVec3}, true
}
if lhst.Main == IVec4 && rhst.Main == IVec4 {
return true
return Type{Main: IVec4}, true
}
if lhst.IsIntVector() && rhst.Main == Int {
return true
return lhst, true
}
if lhst.Main == Int && rhst.IsIntVector() {
return true
return rhst, true
}
return false
return Type{}, false
}
if lhst.Equal(&rhst) {
return true
if lhst.Main == None {
return rhst, true
}
return lhst, true
}
if op == MatrixMul {
if lhst.IsMatrix() && rhst.Main == Float {
return true
return lhst, true
}
if lhst.Main == Mat2 && rhst.Main == Vec2 {
return true
return rhst, true
}
if lhst.Main == Mat3 && rhst.Main == Vec3 {
return true
return rhst, true
}
if lhst.Main == Mat4 && rhst.Main == Vec4 {
return true
return rhst, true
}
if lhst.Main == Float && rhst.IsMatrix() {
return true
return rhst, true
}
if lhst.Main == Vec2 && rhst.Main == Mat2 {
return true
return lhst, true
}
if lhst.Main == Vec3 && rhst.Main == Mat3 {
return true
return lhst, true
}
if lhst.Main == Vec4 && rhst.Main == Mat4 {
return true
return lhst, true
}
return false
return Type{}, false
}
if op == Div {
if lhst.IsMatrix() && rhst.Main == Float {
return true
return lhst, true
}
// fallback
}
if lhst.IsFloatVector() && rhst.Main == Float {
return true
return lhst, true
}
if rhst.IsFloatVector() && lhst.Main == Float {
return true
if lhst.Main == Float && rhst.IsFloatVector() {
return rhst, true
}
if lhst.IsIntVector() && rhst.Main == Int {
return true
return lhst, true
}
if rhst.IsIntVector() && lhst.Main == Int {
return true
if lhst.Main == Int && rhst.IsIntVector() {
return rhst, true
}
return false
return Type{}, false
}