| from sympy.core.relational import Eq | |
| from sympy.core.expr import Expr | |
| from sympy.core.numbers import Integer | |
| from sympy.logic.boolalg import Boolean, And | |
| from sympy.matrices.expressions.matexpr import MatrixExpr | |
| from sympy.matrices.exceptions import ShapeError | |
| from typing import Union | |
def is_matadd_valid(*args: MatrixExpr) -> Boolean:
    """Return the symbolic condition under which ``MatAdd`` or
    ``HadamardProduct`` of the given matrices makes sense.

    Both operations are elementwise, so each operand must have the same
    number of rows and the same number of columns as the operand that
    follows it.

    Parameters
    ==========

    args
        The matrices to be tested.

    Returns
    =======

    The conjunction (``And``) of ``Eq`` conditions equating the row counts
    and the column counts of every consecutive pair of arguments. With fewer
    than two arguments (including none at all) there is nothing to compare,
    and the empty conjunction ``true`` is returned.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matadd_valid
    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matadd_valid(A, B)
    Eq(m, p) & Eq(n, q)
    >>> is_matadd_valid(A)
    True
    """
    shapes = [arg.shape for arg in args]
    # Pair each shape with its successor. Working on the list of shapes
    # (instead of unpacking ``zip(*shapes)`` into rows and cols) keeps a call
    # with no arguments from failing with a ValueError.
    pairs = list(zip(shapes[:-1], shapes[1:]))
    return And(
        *(Eq(prev[0], curr[0]) for prev, curr in pairs),
        *(Eq(prev[1], curr[1]) for prev, curr in pairs),
    )
def is_matmul_valid(*args: Union[MatrixExpr, Expr]) -> Boolean:
    """Return the symbolic condition under which ``MatMul`` of the given
    operands makes sense.

    Operands that are not ``MatrixExpr`` instances (scalar expressions) are
    ignored. For the remaining matrices, the number of columns of each one
    must equal the number of rows of the next one.

    Parameters
    ==========

    args
        The matrices and scalar expressions to be tested.

    Returns
    =======

    The conjunction (``And``) of ``Eq`` conditions, one per consecutive pair
    of matrix operands. If fewer than two matrices are present (for example,
    when every operand is a scalar), the empty conjunction ``true`` is
    returned.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matmul_valid
    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matmul_valid(A, B)
    Eq(n, p)
    >>> is_matmul_valid(m, n)
    True
    """
    shapes = [arg.shape for arg in args if isinstance(arg, MatrixExpr)]
    # Pair each matrix shape with its right neighbour. When there are no
    # matrix operands this yields no pairs, rather than breaking a
    # ``rows, cols = zip(*...)`` unpack.
    return And(*(
        Eq(left[1], right[0]) for left, right in zip(shapes[:-1], shapes[1:])
    ))
def is_square(arg: MatrixExpr, /) -> Boolean:
    """Return the symbolic condition under which the matrix is square.

    Parameters
    ==========

    arg
        The matrix to be tested.

    Returns
    =======

    The equation ``Eq(arg.rows, arg.cols)``, which states that the row count
    equals the column count.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_square
    >>> m, n = symbols('m n')
    >>> A = MatrixSymbol('A', m, n)
    >>> is_square(A)
    Eq(m, n)
    """
    n_rows = arg.rows
    n_cols = arg.cols
    return Eq(n_rows, n_cols)
def validate_matadd_integer(*args: MatrixExpr) -> None:
    """Validate matrix shapes for addition, checking only integer dimensions.

    Only dimensions that are ``int`` or SymPy ``Integer`` instances are
    compared. Any other (symbolic) dimension is skipped, because its
    compatibility cannot be decided here.

    Parameters
    ==========

    args
        The matrices to be validated. Calling with no matrices is a no-op.

    Raises
    ======

    ShapeError
        If the integer row counts are not all equal, or the integer column
        counts are not all equal.
    """
    # Nothing to compare. This also avoids ``rows, cols = zip()`` raising a
    # ValueError when no matrices are given.
    if not args:
        return
    rows, cols = zip(*(x.shape for x in args))
    # More than one distinct concrete value means a definite mismatch.
    if len({r for r in rows if isinstance(r, (int, Integer))}) > 1:
        raise ShapeError(f"Matrices have mismatching shape: {rows}")
    if len({c for c in cols if isinstance(c, (int, Integer))}) > 1:
        raise ShapeError(f"Matrices have mismatching shape: {cols}")
def validate_matmul_integer(*args: MatrixExpr) -> None:
    """Validate matrix shapes for multiplication, checking only integer
    dimensions.

    For every consecutive pair of matrices, the column count of the left
    matrix is compared with the row count of the right matrix. The comparison
    is made only when both values are ``int`` or SymPy ``Integer``
    instances.

    Parameters
    ==========

    args
        The matrices to be validated, in multiplication order.

    Raises
    ======

    ShapeError
        For the first consecutive pair whose integer inner dimensions differ.
    """
    def _is_concrete(dim):
        # Symbolic dimensions cannot be checked here, so only integers count.
        return isinstance(dim, (int, Integer))

    for left, right in zip(args, args[1:]):
        inner_left = left.cols
        inner_right = right.rows
        if not (_is_concrete(inner_left) and _is_concrete(inner_right)):
            continue
        if inner_left != inner_right:
            raise ShapeError("Matrices are not aligned", inner_left, inner_right)