From 6b94de4ef69b36d4d3cbf2e2b00d5122785961e6 Mon Sep 17 00:00:00 2001 From: Hajime Hoshi Date: Tue, 25 Jul 2023 01:53:08 +0900 Subject: [PATCH] internal/shader: refactoring: integrate type checks to shaderir.AreValidTypesForBinaryOp --- internal/shader/expr.go | 124 +++++-------------------------------- internal/shaderir/check.go | 43 ++++++++++++- 2 files changed, 58 insertions(+), 109 deletions(-) diff --git a/internal/shader/expr.go b/internal/shader/expr.go index a72b63520..c55a52810 100644 --- a/internal/shader/expr.go +++ b/internal/shader/expr.go @@ -169,38 +169,29 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar } } - // TODO: Integrate AreValidTypesForBinaryOp calls to here. + if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) { + // TODO: Show a better type name for untyped constants. + cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) + return nil, nil, nil, false + } if lhs[0].Const != nil && rhs[0].Const != nil { - if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) { - // TODO: Show a better type name for untyped constants. - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) - return nil, nil, nil, false - } - var v gconstant.Value var t shaderir.Type switch op { - case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ, token.LAND, token.LOR: - switch op { - case token.LAND, token.LOR: - b := gconstant.BoolVal(gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)) - v = gconstant.MakeBool(b) - default: - v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const)) + case token.LAND, token.LOR: + b := gconstant.BoolVal(gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)) + v = gconstant.MakeBool(b) + if lhst.Main != shaderir.None || rhst.Main != shaderir.None { + t = shaderir.Type{Main: shaderir.Bool} } - t = shaderir.Type{Main: shaderir.Bool} - case token.REM: - if !cs.forceToInt(e, &lhs[0]) { - return nil, nil, nil, false + case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ: + v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const)) + if lhst.Main != shaderir.None || rhst.Main != shaderir.None { + t = shaderir.Type{Main: shaderir.Bool} } - if !cs.forceToInt(e, &rhs[0]) { - return nil, nil, nil, false - } - fallthrough default: v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const) - switch { case lhst.Main == shaderir.Float || rhst.Main == shaderir.Float: t = shaderir.Type{Main: shaderir.Float} @@ -225,100 +216,17 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar var t shaderir.Type switch { case op2 == shaderir.LessThanOp || op2 == shaderir.LessThanEqualOp || op2 == shaderir.GreaterThanOp || op2 == shaderir.GreaterThanEqualOp || op2 == shaderir.EqualOp || op2 == shaderir.NotEqualOp || op2 == shaderir.VectorEqualOp || op2 == shaderir.VectorNotEqualOp || op2 == shaderir.AndAnd || op2 == shaderir.OrOr: - if !shaderir.AreValidTypesForBinaryOp(op2, &lhs[0], &rhs[0], lhst, rhst) { - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } t = shaderir.Type{Main: shaderir.Bool} case lhs[0].Const != nil && rhs[0].Const == nil: - switch rhst.Main { - case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op2 != shaderir.MatrixMul { - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } - fallthrough - case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: - if lhst.Main != shaderir.Float { - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } - case shaderir.IVec2, shaderir.IVec3, shaderir.IVec4: - if lhst.Main != shaderir.Int { - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } - case shaderir.Int: - if lhst.Main != shaderir.Int { - cs.addError(e.Pos(), fmt.Sprintf("constant %s truncated to integer", lhs[0].Const.String())) - return nil, nil, nil, false - } - lhs[0].ConstType = shaderir.ConstTypeInt - } t = rhst case lhs[0].Const == nil && rhs[0].Const != nil: - switch lhst.Main { - case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op2 != shaderir.MatrixMul && op2 != shaderir.Div { - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } - fallthrough - case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: - if rhst.Main != shaderir.Float { - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } - case shaderir.IVec2, shaderir.IVec3, shaderir.IVec4: - if rhst.Main != shaderir.Int { - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } - if !cs.forceToInt(e, &rhs[0]) { - return nil, nil, nil, false - } - case shaderir.Int: - if rhst.Main != shaderir.Int { - cs.addError(e.Pos(), fmt.Sprintf("constant %s truncated to integer", rhs[0].Const.String())) - return nil, nil, nil, false - } - rhs[0].ConstType = shaderir.ConstTypeInt - } t = lhst case lhst.Equal(&rhst): - if op2 == shaderir.Div && (rhst.Main == shaderir.Mat2 || rhst.Main == shaderir.Mat3 || rhst.Main == shaderir.Mat4) { - cs.addError(e.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", e.Op, rhst.String())) - return nil, nil, nil, false - } t = lhst case lhst.Main == shaderir.Float: - switch rhst.Main { - case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: - t = rhst - case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op2 != shaderir.MatrixMul { - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } - t = rhst - default: - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } + t = rhst case rhst.Main == shaderir.Float: - switch lhst.Main { - case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4: - t = lhst - case shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: - if op2 != shaderir.MatrixMul && op2 != shaderir.Div { - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } - t = lhst - default: - cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), e.Op, rhst.String())) - return nil, nil, nil, false - } + t = lhst case op2 == shaderir.MatrixMul && (lhst.Main == shaderir.Vec2 && rhst.Main == shaderir.Mat2 || lhst.Main == shaderir.Mat2 && rhst.Main == shaderir.Vec2): t = shaderir.Type{Main: shaderir.Vec2} diff --git a/internal/shaderir/check.go b/internal/shaderir/check.go index 1772749f5..66c4d4f60 100644 --- a/internal/shaderir/check.go +++ b/internal/shaderir/check.go @@ -84,6 +84,10 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { } } + if op == Div && rhst.IsMatrix() { + return false + } + // If both are untyped consts, compare the constants and try to truncate them if necessary. if lhst.Main == None && rhst.Main == None { // Assume that the constant types are already adjusted. @@ -103,5 +107,42 @@ func AreValidTypesForBinaryOp(op Op, lhs, rhs *Expr, lhst, rhst Type) bool { panic("shaderir: cannot resolve untyped values") } - return lhst.Equal(&rhst) + if lhst.Equal(&rhst) { + return true + } + + if op == MatrixMul { + if lhst.IsMatrix() && (rhst.isFloatVector() || rhst.Main == Float) { + // TODO: Check dimensions + return true + } + if rhst.IsMatrix() && (lhst.isFloatVector() || lhst.Main == Float) { + // TODO: Check dimensions + return true + } + return false + } + + if op == Div { + if lhst.IsMatrix() && (rhst.isFloatVector() || rhst.Main == Float) { + // TODO: Check dimensions + return true + } + // fallback + } + + if lhst.isFloatVector() && rhst.Main == Float { + return true + } + if rhst.isFloatVector() && lhst.Main == Float { + return true + } + if lhst.isIntVector() && rhst.Main == Int { + return true + } + if rhst.isIntVector() && lhst.Main == Int { + return true + } + + return false }