internal/shader: refactoring: integrate type checks to shaderir.AreValidTypesForBinaryOp

This commit is contained in:
Hajime Hoshi 2023-07-25 01:53:08 +09:00
parent b743b7ab50
commit 6b94de4ef6
2 changed files with 58 additions and 109 deletions

View File

@ -169,38 +169,29 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
}
}
// TODO: Integrate AreValidTypesForBinaryOp calls to here.
if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) {
// TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
return nil, nil, nil, false
}
if lhs[0].Const != nil && rhs[0].Const != nil {
if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) {
// TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
return nil, nil, nil, false
}
var v gconstant.Value
var t shaderir.Type
switch op {
case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ, token.LAND, token.LOR:
switch op {
case token.LAND, token.LOR:
b := gconstant.BoolVal(gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const))
v = gconstant.MakeBool(b)
default:
v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const))
case token.LAND, token.LOR:
b := gconstant.BoolVal(gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const))
v = gconstant.MakeBool(b)
if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
t = shaderir.Type{Main: shaderir.Bool}
}
t = shaderir.Type{Main: shaderir.Bool}
case token.REM:
if !cs.forceToInt(e, &lhs[0]) {
return nil, nil, nil, false
case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ:
v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const))
if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
t = shaderir.Type{Main: shaderir.Bool}
}
if !cs.forceToInt(e, &rhs[0]) {
return nil, nil, nil, false
}
fallthrough
default:
v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)
switch {
case lhst.Main == shaderir.Float || rhst.Main == shaderir.Float:
t = shaderir.Type{Main: shaderir.Float}
@ -225,100 +216,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
var t shaderir.Type
switch {
case op2 == shaderir.LessThanOp || op2 == shaderir.LessThanEqualOp || op2 == shaderir.GreaterThanOp || op2 == shaderir.GreaterThanEqualOp || op2 == shaderir.EqualOp || op2 == shaderir.NotEqualOp || op2 == shaderir.VectorEqualOp || op2 == shaderir.VectorNotEqualOp || op2 == shaderir.AndAnd || op2 == shaderir.OrOr:
if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
t = shaderir.Type{Main: shaderir.Bool}
case lhs[0].Const != nil && rhs[0].Const == nil:
switch rhst.Main {
case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op2 != shaderir.MatrixMul {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
fallthrough
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:
if lhst.Main != shaderir.Float {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
case shaderir.IVec2, shaderir.IVec3, shaderir.IVec4:
if lhst.Main != shaderir.Int {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
case shaderir.Int:
if lhst.Main != shaderir.Int {
cs.addError(e.Pos(), fmt.Sprintf("constant %s truncated to integer", lhs[0].Const.String()))
return nil, nil, nil, false
}
lhs[0].ConstType = shaderir.ConstTypeInt
}
t = rhst
case lhs[0].Const == nil && rhs[0].Const != nil:
switch lhst.Main {
case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op2 != shaderir.MatrixMul && op2 != shaderir.Div {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
fallthrough
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:
if rhst.Main != shaderir.Float {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
case shaderir.IVec2, shaderir.IVec3, shaderir.IVec4:
if rhst.Main != shaderir.Int {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
if !cs.forceToInt(e, &rhs[0]) {
return nil, nil, nil, false
}
case shaderir.Int:
if rhst.Main != shaderir.Int {
cs.addError(e.Pos(), fmt.Sprintf("constant %s truncated to integer", rhs[0].Const.String()))
return nil, nil, nil, false
}
rhs[0].ConstType = shaderir.ConstTypeInt
}
t = lhst
case lhst.Equal(&rhst):
if op2 == shaderir.Div && (rhst.Main == shaderir.Mat2 || rhst.Main == shaderir.Mat3 || rhst.Main == shaderir.Mat4) {
cs.addError(e.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", e.Op, rhst.String()))
return nil, nil, nil, false
}
t = lhst
case lhst.Main == shaderir.Float:
switch rhst.Main {
case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:
t = rhst
case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op2 != shaderir.MatrixMul {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
t = rhst
default:
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
t = rhst
case rhst.Main == shaderir.Float:
switch lhst.Main {
case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4:
t = lhst
case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op2 != shaderir.MatrixMul && op2 != shaderir.Div {
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
t = lhst
default:
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String()))
return nil, nil, nil, false
}
t = lhst
case op2 == shaderir.MatrixMul && (lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 ||
lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2):
t = shaderir.Type{Main: shaderir.Vec2}

View File

@ -84,6 +84,10 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
}
}
if op == Div && rhst.IsMatrix() {
return false
}
// If both are untyped consts, compare the constants and try to truncate them if necessary.
if lhst.Main == None && rhst.Main == None {
// Assume that the constant types are already adjusted.
@ -103,5 +107,42 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool {
panic("shaderir: cannot resolve untyped values")
}
return lhst.Equal(&rhst)
if lhst.Equal(&rhst) {
return true
}
if op == MatrixMul {
if lhst.IsMatrix() && (rhst.isFloatVector() || rhst.Main == Float) {
// TODO: Check dimensions
return true
}
if rhst.IsMatrix() && (lhst.isFloatVector() || lhst.Main == Float) {
// TODO: Check dimensions
return true
}
return false
}
if op == Div {
if lhst.IsMatrix() && (rhst.isFloatVector() || rhst.Main == Float) {
// TODO: Check dimensions
return true
}
// fallback
}
if lhst.isFloatVector() && rhst.Main == Float {
return true
}
if rhst.isFloatVector() && lhst.Main == Float {
return true
}
if lhst.isIntVector() && rhst.Main == Int {
return true
}
if rhst.isIntVector() && lhst.Main == Int {
return true
}
return false
}