koichi12 commited on
Commit
ad4135d
·
verified ·
1 Parent(s): d1d6563

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .venv/lib/python3.11/site-packages/sympy/stats/sampling/tests/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_continuous_rv.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/sympy/tensor/__init__.py +23 -0
  5. .venv/lib/python3.11/site-packages/sympy/tensor/array/__init__.py +271 -0
  6. .venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/__init__.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/array_comprehension.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/array_derivatives.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/arrayop.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/dense_ndim_array.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/mutable_ndim_array.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/ndim_array.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/sparse_ndim_array.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/sympy/tensor/array/array_comprehension.py +399 -0
  15. .venv/lib/python3.11/site-packages/sympy/tensor/array/array_derivatives.py +129 -0
  16. .venv/lib/python3.11/site-packages/sympy/tensor/array/arrayop.py +528 -0
  17. .venv/lib/python3.11/site-packages/sympy/tensor/array/dense_ndim_array.py +206 -0
  18. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__init__.py +178 -0
  19. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/__init__.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/arrayexpr_derivatives.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_indexed.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_matrix.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/conv_indexed_to_array.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/conv_matrix_to_array.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_indexed.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_matrix.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/from_indexed_to_array.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/from_matrix_to_array.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/utils.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/array_expressions.py +1967 -0
  31. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py +194 -0
  32. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py +12 -0
  33. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py +6 -0
  34. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py +4 -0
  35. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py +4 -0
  36. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py +84 -0
  37. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py +1003 -0
  38. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py +257 -0
  39. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py +87 -0
  40. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__init__.py +0 -0
  41. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/__init__.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_array_expressions.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_arrayexpr_derivatives.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_as_explicit.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_indexed.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_matrix.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_indexed_to_array.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_matrix_to_array.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_deprecated_conv_modules.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py +808 -0
.gitattributes CHANGED
@@ -437,3 +437,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
437
  .venv/lib/python3.11/site-packages/transformers/models/wav2vec2/__pycache__/modeling_tf_wav2vec2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
438
  .venv/lib/python3.11/site-packages/transformers/models/wav2vec2_conformer/__pycache__/modeling_wav2vec2_conformer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
439
  .venv/lib/python3.11/site-packages/transformers/models/whisper/__pycache__/modeling_whisper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
437
  .venv/lib/python3.11/site-packages/transformers/models/wav2vec2/__pycache__/modeling_tf_wav2vec2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
438
  .venv/lib/python3.11/site-packages/transformers/models/wav2vec2_conformer/__pycache__/modeling_wav2vec2_conformer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
439
  .venv/lib/python3.11/site-packages/transformers/models/whisper/__pycache__/modeling_whisper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
440
+ .venv/lib/python3.11/site-packages/transformers/utils/__pycache__/dummy_tf_objects.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
441
+ .venv/lib/python3.11/site-packages/transformers/utils/__pycache__/dummy_pt_objects.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/sympy/stats/sampling/tests/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes). View file
 
.venv/lib/python3.11/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_continuous_rv.cpython-311.pyc ADDED
Binary file (11.3 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A module to manipulate symbolic objects with indices including tensors
2
+
3
+ """
4
+ from .indexed import IndexedBase, Idx, Indexed
5
+ from .index_methods import get_contraction_structure, get_indices
6
+ from .functions import shape
7
+ from .array import (MutableDenseNDimArray, ImmutableDenseNDimArray,
8
+ MutableSparseNDimArray, ImmutableSparseNDimArray, NDimArray, tensorproduct,
9
+ tensorcontraction, tensordiagonal, derive_by_array, permutedims, Array,
10
+ DenseNDimArray, SparseNDimArray,)
11
+
12
+ __all__ = [
13
+ 'IndexedBase', 'Idx', 'Indexed',
14
+
15
+ 'get_contraction_structure', 'get_indices',
16
+
17
+ 'shape',
18
+
19
+ 'MutableDenseNDimArray', 'ImmutableDenseNDimArray',
20
+ 'MutableSparseNDimArray', 'ImmutableSparseNDimArray', 'NDimArray',
21
+ 'tensorproduct', 'tensorcontraction', 'tensordiagonal', 'derive_by_array', 'permutedims',
22
+ 'Array', 'DenseNDimArray', 'SparseNDimArray',
23
+ ]
.venv/lib/python3.11/site-packages/sympy/tensor/array/__init__.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ N-dim array module for SymPy.
3
+
4
+ Four classes are provided to handle N-dim arrays, given by the combinations
5
+ dense/sparse (i.e. whether to store all elements or only the non-zero ones in
6
+ memory) and mutable/immutable (immutable classes are SymPy objects, but cannot
7
+ change after they have been created).
8
+
9
+ Examples
10
+ ========
11
+
12
+ The following examples show the usage of ``Array``. This is an abbreviation for
13
+ ``ImmutableDenseNDimArray``, that is an immutable and dense N-dim array, the
14
+ other classes are analogous. For mutable classes it is also possible to change
15
+ element values after the object has been constructed.
16
+
17
+ Array construction can detect the shape of nested lists and tuples:
18
+
19
+ >>> from sympy import Array
20
+ >>> a1 = Array([[1, 2], [3, 4], [5, 6]])
21
+ >>> a1
22
+ [[1, 2], [3, 4], [5, 6]]
23
+ >>> a1.shape
24
+ (3, 2)
25
+ >>> a1.rank()
26
+ 2
27
+ >>> from sympy.abc import x, y, z
28
+ >>> a2 = Array([[[x, y], [z, x*z]], [[1, x*y], [1/x, x/y]]])
29
+ >>> a2
30
+ [[[x, y], [z, x*z]], [[1, x*y], [1/x, x/y]]]
31
+ >>> a2.shape
32
+ (2, 2, 2)
33
+ >>> a2.rank()
34
+ 3
35
+
36
+ Otherwise one could pass a 1-dim array followed by a shape tuple:
37
+
38
+ >>> m1 = Array(range(12), (3, 4))
39
+ >>> m1
40
+ [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
41
+ >>> m2 = Array(range(12), (3, 2, 2))
42
+ >>> m2
43
+ [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]]
44
+ >>> m2[1,1,1]
45
+ 7
46
+ >>> m2.reshape(4, 3)
47
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
48
+
49
+ Slice support:
50
+
51
+ >>> m2[:, 1, 1]
52
+ [3, 7, 11]
53
+
54
+ Elementwise derivative:
55
+
56
+ >>> from sympy.abc import x, y, z
57
+ >>> m3 = Array([x**3, x*y, z])
58
+ >>> m3.diff(x)
59
+ [3*x**2, y, 0]
60
+ >>> m3.diff(z)
61
+ [0, 0, 1]
62
+
63
+ Multiplication with other SymPy expressions is applied elementwisely:
64
+
65
+ >>> (1+x)*m3
66
+ [x**3*(x + 1), x*y*(x + 1), z*(x + 1)]
67
+
68
+ To apply a function to each element of the N-dim array, use ``applyfunc``:
69
+
70
+ >>> m3.applyfunc(lambda x: x/2)
71
+ [x**3/2, x*y/2, z/2]
72
+
73
+ N-dim arrays can be converted to nested lists by the ``tolist()`` method:
74
+
75
+ >>> m2.tolist()
76
+ [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]]
77
+ >>> isinstance(m2.tolist(), list)
78
+ True
79
+
80
+ If the rank is 2, it is possible to convert them to matrices with ``tomatrix()``:
81
+
82
+ >>> m1.tomatrix()
83
+ Matrix([
84
+ [0, 1, 2, 3],
85
+ [4, 5, 6, 7],
86
+ [8, 9, 10, 11]])
87
+
88
+ Products and contractions
89
+ -------------------------
90
+
91
+ Tensor product between arrays `A_{i_1,\ldots,i_n}` and `B_{j_1,\ldots,j_m}`
92
+ creates the combined array `P = A \otimes B` defined as
93
+
94
+ `P_{i_1,\ldots,i_n,j_1,\ldots,j_m} := A_{i_1,\ldots,i_n}\cdot B_{j_1,\ldots,j_m}.`
95
+
96
+ It is available through ``tensorproduct(...)``:
97
+
98
+ >>> from sympy import Array, tensorproduct
99
+ >>> from sympy.abc import x,y,z,t
100
+ >>> A = Array([x, y, z, t])
101
+ >>> B = Array([1, 2, 3, 4])
102
+ >>> tensorproduct(A, B)
103
+ [[x, 2*x, 3*x, 4*x], [y, 2*y, 3*y, 4*y], [z, 2*z, 3*z, 4*z], [t, 2*t, 3*t, 4*t]]
104
+
105
+ In case you don't want to evaluate the tensor product immediately, you can use
106
+ ``ArrayTensorProduct``, which creates an unevaluated tensor product expression:
107
+
108
+ >>> from sympy.tensor.array.expressions import ArrayTensorProduct
109
+ >>> ArrayTensorProduct(A, B)
110
+ ArrayTensorProduct([x, y, z, t], [1, 2, 3, 4])
111
+
112
+ Calling ``.as_explicit()`` on ``ArrayTensorProduct`` is equivalent to just calling
113
+ ``tensorproduct(...)``:
114
+
115
+ >>> ArrayTensorProduct(A, B).as_explicit()
116
+ [[x, 2*x, 3*x, 4*x], [y, 2*y, 3*y, 4*y], [z, 2*z, 3*z, 4*z], [t, 2*t, 3*t, 4*t]]
117
+
118
+ Tensor product between a rank-1 array and a matrix creates a rank-3 array:
119
+
120
+ >>> from sympy import eye
121
+ >>> p1 = tensorproduct(A, eye(4))
122
+ >>> p1
123
+ [[[x, 0, 0, 0], [0, x, 0, 0], [0, 0, x, 0], [0, 0, 0, x]], [[y, 0, 0, 0], [0, y, 0, 0], [0, 0, y, 0], [0, 0, 0, y]], [[z, 0, 0, 0], [0, z, 0, 0], [0, 0, z, 0], [0, 0, 0, z]], [[t, 0, 0, 0], [0, t, 0, 0], [0, 0, t, 0], [0, 0, 0, t]]]
124
+
125
+ Now, to get back `A_0 \otimes \mathbf{1}` one can access `p_{0,m,n}` by slicing:
126
+
127
+ >>> p1[0,:,:]
128
+ [[x, 0, 0, 0], [0, x, 0, 0], [0, 0, x, 0], [0, 0, 0, x]]
129
+
130
+ Tensor contraction sums over the specified axes, for example contracting
131
+ positions `a` and `b` means
132
+
133
+ `A_{i_1,\ldots,i_a,\ldots,i_b,\ldots,i_n} \implies \sum_k A_{i_1,\ldots,k,\ldots,k,\ldots,i_n}`
134
+
135
+ Remember that Python indexing is zero starting, to contract the a-th and b-th
136
+ axes it is therefore necessary to specify `a-1` and `b-1`
137
+
138
+ >>> from sympy import tensorcontraction
139
+ >>> C = Array([[x, y], [z, t]])
140
+
141
+ The matrix trace is equivalent to the contraction of a rank-2 array:
142
+
143
+ `A_{m,n} \implies \sum_k A_{k,k}`
144
+
145
+ >>> tensorcontraction(C, (0, 1))
146
+ t + x
147
+
148
+ To create an expression representing a tensor contraction that does not get
149
+ evaluated immediately, use ``ArrayContraction``, which is equivalent to
150
+ ``tensorcontraction(...)`` if it is followed by ``.as_explicit()``:
151
+
152
+ >>> from sympy.tensor.array.expressions import ArrayContraction
153
+ >>> ArrayContraction(C, (0, 1))
154
+ ArrayContraction([[x, y], [z, t]], (0, 1))
155
+ >>> ArrayContraction(C, (0, 1)).as_explicit()
156
+ t + x
157
+
158
+ Matrix product is equivalent to a tensor product of two rank-2 arrays, followed
159
+ by a contraction of the 2nd and 3rd axes (in Python indexing axes number 1, 2).
160
+
161
+ `A_{m,n}\cdot B_{i,j} \implies \sum_k A_{m, k}\cdot B_{k, j}`
162
+
163
+ >>> D = Array([[2, 1], [0, -1]])
164
+ >>> tensorcontraction(tensorproduct(C, D), (1, 2))
165
+ [[2*x, x - y], [2*z, -t + z]]
166
+
167
+ One may verify that the matrix product is equivalent:
168
+
169
+ >>> from sympy import Matrix
170
+ >>> Matrix([[x, y], [z, t]])*Matrix([[2, 1], [0, -1]])
171
+ Matrix([
172
+ [2*x, x - y],
173
+ [2*z, -t + z]])
174
+
175
+ or equivalently
176
+
177
+ >>> C.tomatrix()*D.tomatrix()
178
+ Matrix([
179
+ [2*x, x - y],
180
+ [2*z, -t + z]])
181
+
182
+ Diagonal operator
183
+ -----------------
184
+
185
+ The ``tensordiagonal`` function acts in a similar manner as ``tensorcontraction``,
186
+ but the joined indices are not summed over, for example diagonalizing
187
+ positions `a` and `b` means
188
+
189
+ `A_{i_1,\ldots,i_a,\ldots,i_b,\ldots,i_n} \implies A_{i_1,\ldots,k,\ldots,k,\ldots,i_n}
190
+ \implies \tilde{A}_{i_1,\ldots,i_{a-1},i_{a+1},\ldots,i_{b-1},i_{b+1},\ldots,i_n,k}`
191
+
192
+ where `\tilde{A}` is the array equivalent to the diagonal of `A` at positions
193
+ `a` and `b` moved to the last index slot.
194
+
195
+ Compare the difference between contraction and diagonal operators:
196
+
197
+ >>> from sympy import tensordiagonal
198
+ >>> from sympy.abc import a, b, c, d
199
+ >>> m = Matrix([[a, b], [c, d]])
200
+ >>> tensorcontraction(m, [0, 1])
201
+ a + d
202
+ >>> tensordiagonal(m, [0, 1])
203
+ [a, d]
204
+
205
+ In short, no summation occurs with ``tensordiagonal``.
206
+
207
+
208
+ Derivatives by array
209
+ --------------------
210
+
211
+ The usual derivative operation may be extended to support derivation with
212
+ respect to arrays, provided that all elements in the that array are symbols or
213
+ expressions suitable for derivations.
214
+
215
+ The definition of a derivative by an array is as follows: given the array
216
+ `A_{i_1, \ldots, i_N}` and the array `X_{j_1, \ldots, j_M}`
217
+ the derivative of arrays will return a new array `B` defined by
218
+
219
+ `B_{j_1,\ldots,j_M,i_1,\ldots,i_N} := \frac{\partial A_{i_1,\ldots,i_N}}{\partial X_{j_1,\ldots,j_M}}`
220
+
221
+ The function ``derive_by_array`` performs such an operation:
222
+
223
+ >>> from sympy import derive_by_array
224
+ >>> from sympy.abc import x, y, z, t
225
+ >>> from sympy import sin, exp
226
+
227
+ With scalars, it behaves exactly as the ordinary derivative:
228
+
229
+ >>> derive_by_array(sin(x*y), x)
230
+ y*cos(x*y)
231
+
232
+ Scalar derived by an array basis:
233
+
234
+ >>> derive_by_array(sin(x*y), [x, y, z])
235
+ [y*cos(x*y), x*cos(x*y), 0]
236
+
237
+ Deriving array by an array basis: `B^{nm} := \frac{\partial A^m}{\partial x^n}`
238
+
239
+ >>> basis = [x, y, z]
240
+ >>> ax = derive_by_array([exp(x), sin(y*z), t], basis)
241
+ >>> ax
242
+ [[exp(x), 0, 0], [0, z*cos(y*z), 0], [0, y*cos(y*z), 0]]
243
+
244
+ Contraction of the resulting array: `\sum_m \frac{\partial A^m}{\partial x^m}`
245
+
246
+ >>> tensorcontraction(ax, (0, 1))
247
+ z*cos(y*z) + exp(x)
248
+
249
+ """
250
+
251
+ from .dense_ndim_array import MutableDenseNDimArray, ImmutableDenseNDimArray, DenseNDimArray
252
+ from .sparse_ndim_array import MutableSparseNDimArray, ImmutableSparseNDimArray, SparseNDimArray
253
+ from .ndim_array import NDimArray, ArrayKind
254
+ from .arrayop import tensorproduct, tensorcontraction, tensordiagonal, derive_by_array, permutedims
255
+ from .array_comprehension import ArrayComprehension, ArrayComprehensionMap
256
+
257
+ Array = ImmutableDenseNDimArray
258
+
259
+ __all__ = [
260
+ 'MutableDenseNDimArray', 'ImmutableDenseNDimArray', 'DenseNDimArray',
261
+
262
+ 'MutableSparseNDimArray', 'ImmutableSparseNDimArray', 'SparseNDimArray',
263
+
264
+ 'NDimArray', 'ArrayKind',
265
+
266
+ 'tensorproduct', 'tensorcontraction', 'tensordiagonal', 'derive_by_array',
267
+
268
+ 'permutedims', 'ArrayComprehension', 'ArrayComprehensionMap',
269
+
270
+ 'Array',
271
+ ]
.venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (8.48 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/array_comprehension.cpython-311.pyc ADDED
Binary file (19.2 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/array_derivatives.cpython-311.pyc ADDED
Binary file (8.06 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/arrayop.cpython-311.pyc ADDED
Binary file (25.3 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/dense_ndim_array.cpython-311.pyc ADDED
Binary file (12.5 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/mutable_ndim_array.cpython-311.pyc ADDED
Binary file (1.08 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/ndim_array.cpython-311.pyc ADDED
Binary file (34 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/__pycache__/sparse_ndim_array.cpython-311.pyc ADDED
Binary file (11.8 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/array_comprehension.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools, itertools
2
+ from sympy.core.sympify import _sympify, sympify
3
+ from sympy.core.expr import Expr
4
+ from sympy.core import Basic, Tuple
5
+ from sympy.tensor.array import ImmutableDenseNDimArray
6
+ from sympy.core.symbol import Symbol
7
+ from sympy.core.numbers import Integer
8
+
9
+
10
+ class ArrayComprehension(Basic):
11
+ """
12
+ Generate a list comprehension.
13
+
14
+ Explanation
15
+ ===========
16
+
17
+ If there is a symbolic dimension, for example, say [i for i in range(1, N)] where
18
+ N is a Symbol, then the expression will not be expanded to an array. Otherwise,
19
+ calling the doit() function will launch the expansion.
20
+
21
+ Examples
22
+ ========
23
+
24
+ >>> from sympy.tensor.array import ArrayComprehension
25
+ >>> from sympy import symbols
26
+ >>> i, j, k = symbols('i j k')
27
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
28
+ >>> a
29
+ ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
30
+ >>> a.doit()
31
+ [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]
32
+ >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k))
33
+ >>> b.doit()
34
+ ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k))
35
+ """
36
+ def __new__(cls, function, *symbols, **assumptions):
37
+ if any(len(l) != 3 or None for l in symbols):
38
+ raise ValueError('ArrayComprehension requires values lower and upper bound'
39
+ ' for the expression')
40
+ arglist = [sympify(function)]
41
+ arglist.extend(cls._check_limits_validity(function, symbols))
42
+ obj = Basic.__new__(cls, *arglist, **assumptions)
43
+ obj._limits = obj._args[1:]
44
+ obj._shape = cls._calculate_shape_from_limits(obj._limits)
45
+ obj._rank = len(obj._shape)
46
+ obj._loop_size = cls._calculate_loop_size(obj._shape)
47
+ return obj
48
+
49
+ @property
50
+ def function(self):
51
+ """The function applied across limits.
52
+
53
+ Examples
54
+ ========
55
+
56
+ >>> from sympy.tensor.array import ArrayComprehension
57
+ >>> from sympy import symbols
58
+ >>> i, j = symbols('i j')
59
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
60
+ >>> a.function
61
+ 10*i + j
62
+ """
63
+ return self._args[0]
64
+
65
+ @property
66
+ def limits(self):
67
+ """
68
+ The list of limits that will be applied while expanding the array.
69
+
70
+ Examples
71
+ ========
72
+
73
+ >>> from sympy.tensor.array import ArrayComprehension
74
+ >>> from sympy import symbols
75
+ >>> i, j = symbols('i j')
76
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
77
+ >>> a.limits
78
+ ((i, 1, 4), (j, 1, 3))
79
+ """
80
+ return self._limits
81
+
82
+ @property
83
+ def free_symbols(self):
84
+ """
85
+ The set of the free_symbols in the array.
86
+ Variables appeared in the bounds are supposed to be excluded
87
+ from the free symbol set.
88
+
89
+ Examples
90
+ ========
91
+
92
+ >>> from sympy.tensor.array import ArrayComprehension
93
+ >>> from sympy import symbols
94
+ >>> i, j, k = symbols('i j k')
95
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
96
+ >>> a.free_symbols
97
+ set()
98
+ >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
99
+ >>> b.free_symbols
100
+ {k}
101
+ """
102
+ expr_free_sym = self.function.free_symbols
103
+ for var, inf, sup in self._limits:
104
+ expr_free_sym.discard(var)
105
+ curr_free_syms = inf.free_symbols.union(sup.free_symbols)
106
+ expr_free_sym = expr_free_sym.union(curr_free_syms)
107
+ return expr_free_sym
108
+
109
+ @property
110
+ def variables(self):
111
+ """The tuples of the variables in the limits.
112
+
113
+ Examples
114
+ ========
115
+
116
+ >>> from sympy.tensor.array import ArrayComprehension
117
+ >>> from sympy import symbols
118
+ >>> i, j, k = symbols('i j k')
119
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
120
+ >>> a.variables
121
+ [i, j]
122
+ """
123
+ return [l[0] for l in self._limits]
124
+
125
+ @property
126
+ def bound_symbols(self):
127
+ """The list of dummy variables.
128
+
129
+ Note
130
+ ====
131
+
132
+ Note that all variables are dummy variables since a limit without
133
+ lower bound or upper bound is not accepted.
134
+ """
135
+ return [l[0] for l in self._limits if len(l) != 1]
136
+
137
+ @property
138
+ def shape(self):
139
+ """
140
+ The shape of the expanded array, which may have symbols.
141
+
142
+ Note
143
+ ====
144
+
145
+ Both the lower and the upper bounds are included while
146
+ calculating the shape.
147
+
148
+ Examples
149
+ ========
150
+
151
+ >>> from sympy.tensor.array import ArrayComprehension
152
+ >>> from sympy import symbols
153
+ >>> i, j, k = symbols('i j k')
154
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
155
+ >>> a.shape
156
+ (4, 3)
157
+ >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
158
+ >>> b.shape
159
+ (4, k + 3)
160
+ """
161
+ return self._shape
162
+
163
+ @property
164
+ def is_shape_numeric(self):
165
+ """
166
+ Test if the array is shape-numeric which means there is no symbolic
167
+ dimension.
168
+
169
+ Examples
170
+ ========
171
+
172
+ >>> from sympy.tensor.array import ArrayComprehension
173
+ >>> from sympy import symbols
174
+ >>> i, j, k = symbols('i j k')
175
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
176
+ >>> a.is_shape_numeric
177
+ True
178
+ >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
179
+ >>> b.is_shape_numeric
180
+ False
181
+ """
182
+ for _, inf, sup in self._limits:
183
+ if Basic(inf, sup).atoms(Symbol):
184
+ return False
185
+ return True
186
+
187
+ def rank(self):
188
+ """The rank of the expanded array.
189
+
190
+ Examples
191
+ ========
192
+
193
+ >>> from sympy.tensor.array import ArrayComprehension
194
+ >>> from sympy import symbols
195
+ >>> i, j, k = symbols('i j k')
196
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
197
+ >>> a.rank()
198
+ 2
199
+ """
200
+ return self._rank
201
+
202
+ def __len__(self):
203
+ """
204
+ The length of the expanded array which means the number
205
+ of elements in the array.
206
+
207
+ Raises
208
+ ======
209
+
210
+ ValueError : When the length of the array is symbolic
211
+
212
+ Examples
213
+ ========
214
+
215
+ >>> from sympy.tensor.array import ArrayComprehension
216
+ >>> from sympy import symbols
217
+ >>> i, j = symbols('i j')
218
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
219
+ >>> len(a)
220
+ 12
221
+ """
222
+ if self._loop_size.free_symbols:
223
+ raise ValueError('Symbolic length is not supported')
224
+ return self._loop_size
225
+
226
+ @classmethod
227
+ def _check_limits_validity(cls, function, limits):
228
+ #limits = sympify(limits)
229
+ new_limits = []
230
+ for var, inf, sup in limits:
231
+ var = _sympify(var)
232
+ inf = _sympify(inf)
233
+ #since this is stored as an argument, it should be
234
+ #a Tuple
235
+ if isinstance(sup, list):
236
+ sup = Tuple(*sup)
237
+ else:
238
+ sup = _sympify(sup)
239
+ new_limits.append(Tuple(var, inf, sup))
240
+ if any((not isinstance(i, Expr)) or i.atoms(Symbol, Integer) != i.atoms()
241
+ for i in [inf, sup]):
242
+ raise TypeError('Bounds should be an Expression(combination of Integer and Symbol)')
243
+ if (inf > sup) == True:
244
+ raise ValueError('Lower bound should be inferior to upper bound')
245
+ if var in inf.free_symbols or var in sup.free_symbols:
246
+ raise ValueError('Variable should not be part of its bounds')
247
+ return new_limits
248
+
249
+ @classmethod
250
+ def _calculate_shape_from_limits(cls, limits):
251
+ return tuple([sup - inf + 1 for _, inf, sup in limits])
252
+
253
+ @classmethod
254
+ def _calculate_loop_size(cls, shape):
255
+ if not shape:
256
+ return 0
257
+ loop_size = 1
258
+ for l in shape:
259
+ loop_size = loop_size * l
260
+
261
+ return loop_size
262
+
263
+ def doit(self, **hints):
264
+ if not self.is_shape_numeric:
265
+ return self
266
+
267
+ return self._expand_array()
268
+
269
+ def _expand_array(self):
270
+ res = []
271
+ for values in itertools.product(*[range(inf, sup+1)
272
+ for var, inf, sup
273
+ in self._limits]):
274
+ res.append(self._get_element(values))
275
+
276
+ return ImmutableDenseNDimArray(res, self.shape)
277
+
278
+ def _get_element(self, values):
279
+ temp = self.function
280
+ for var, val in zip(self.variables, values):
281
+ temp = temp.subs(var, val)
282
+ return temp
283
+
284
+ def tolist(self):
285
+ """Transform the expanded array to a list.
286
+
287
+ Raises
288
+ ======
289
+
290
+ ValueError : When there is a symbolic dimension
291
+
292
+ Examples
293
+ ========
294
+
295
+ >>> from sympy.tensor.array import ArrayComprehension
296
+ >>> from sympy import symbols
297
+ >>> i, j = symbols('i j')
298
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
299
+ >>> a.tolist()
300
+ [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]
301
+ """
302
+ if self.is_shape_numeric:
303
+ return self._expand_array().tolist()
304
+
305
+ raise ValueError("A symbolic array cannot be expanded to a list")
306
+
307
+ def tomatrix(self):
308
+ """Transform the expanded array to a matrix.
309
+
310
+ Raises
311
+ ======
312
+
313
+ ValueError : When there is a symbolic dimension
314
+ ValueError : When the rank of the expanded array is not equal to 2
315
+
316
+ Examples
317
+ ========
318
+
319
+ >>> from sympy.tensor.array import ArrayComprehension
320
+ >>> from sympy import symbols
321
+ >>> i, j = symbols('i j')
322
+ >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
323
+ >>> a.tomatrix()
324
+ Matrix([
325
+ [11, 12, 13],
326
+ [21, 22, 23],
327
+ [31, 32, 33],
328
+ [41, 42, 43]])
329
+ """
330
+ from sympy.matrices import Matrix
331
+
332
+ if not self.is_shape_numeric:
333
+ raise ValueError("A symbolic array cannot be expanded to a matrix")
334
+ if self._rank != 2:
335
+ raise ValueError('Dimensions must be of size of 2')
336
+
337
+ return Matrix(self._expand_array().tomatrix())
338
+
339
+
340
+ def isLambda(v):
341
+ LAMBDA = lambda: 0
342
+ return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__
343
+
344
+ class ArrayComprehensionMap(ArrayComprehension):
345
+ '''
346
+ A subclass of ArrayComprehension dedicated to map external function lambda.
347
+
348
+ Notes
349
+ =====
350
+
351
+ Only the lambda function is considered.
352
+ At most one argument in lambda function is accepted in order to avoid ambiguity
353
+ in value assignment.
354
+
355
+ Examples
356
+ ========
357
+
358
+ >>> from sympy.tensor.array import ArrayComprehensionMap
359
+ >>> from sympy import symbols
360
+ >>> i, j, k = symbols('i j k')
361
+ >>> a = ArrayComprehensionMap(lambda: 1, (i, 1, 4))
362
+ >>> a.doit()
363
+ [1, 1, 1, 1]
364
+ >>> b = ArrayComprehensionMap(lambda a: a+1, (j, 1, 4))
365
+ >>> b.doit()
366
+ [2, 3, 4, 5]
367
+
368
+ '''
369
+ def __new__(cls, function, *symbols, **assumptions):
370
+ if any(len(l) != 3 or None for l in symbols):
371
+ raise ValueError('ArrayComprehension requires values lower and upper bound'
372
+ ' for the expression')
373
+
374
+ if not isLambda(function):
375
+ raise ValueError('Data type not supported')
376
+
377
+ arglist = cls._check_limits_validity(function, symbols)
378
+ obj = Basic.__new__(cls, *arglist, **assumptions)
379
+ obj._limits = obj._args
380
+ obj._shape = cls._calculate_shape_from_limits(obj._limits)
381
+ obj._rank = len(obj._shape)
382
+ obj._loop_size = cls._calculate_loop_size(obj._shape)
383
+ obj._lambda = function
384
+ return obj
385
+
386
+ @property
387
+ def func(self):
388
+ class _(ArrayComprehensionMap):
389
+ def __new__(cls, *args, **kwargs):
390
+ return ArrayComprehensionMap(self._lambda, *args, **kwargs)
391
+ return _
392
+
393
+ def _get_element(self, values):
394
+ temp = self._lambda
395
+ if self._lambda.__code__.co_argcount == 0:
396
+ temp = temp()
397
+ elif self._lambda.__code__.co_argcount == 1:
398
+ temp = temp(functools.reduce(lambda a, b: a*b, values))
399
+ return temp
.venv/lib/python3.11/site-packages/sympy/tensor/array/array_derivatives.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from sympy.core.expr import Expr
4
+ from sympy.core.function import Derivative
5
+ from sympy.core.numbers import Integer
6
+ from sympy.matrices.matrixbase import MatrixBase
7
+ from .ndim_array import NDimArray
8
+ from .arrayop import derive_by_array
9
+ from sympy.matrices.expressions.matexpr import MatrixExpr
10
+ from sympy.matrices.expressions.special import ZeroMatrix
11
+ from sympy.matrices.expressions.matexpr import _matrix_derivative
12
+
13
+
14
+ class ArrayDerivative(Derivative):
15
+
16
+ is_scalar = False
17
+
18
+ def __new__(cls, expr, *variables, **kwargs):
19
+ obj = super().__new__(cls, expr, *variables, **kwargs)
20
+ if isinstance(obj, ArrayDerivative):
21
+ obj._shape = obj._get_shape()
22
+ return obj
23
+
24
+ def _get_shape(self):
25
+ shape = ()
26
+ for v, count in self.variable_count:
27
+ if hasattr(v, "shape"):
28
+ for i in range(count):
29
+ shape += v.shape
30
+ if hasattr(self.expr, "shape"):
31
+ shape += self.expr.shape
32
+ return shape
33
+
34
+ @property
35
+ def shape(self):
36
+ return self._shape
37
+
38
+ @classmethod
39
+ def _get_zero_with_shape_like(cls, expr):
40
+ if isinstance(expr, (MatrixBase, NDimArray)):
41
+ return expr.zeros(*expr.shape)
42
+ elif isinstance(expr, MatrixExpr):
43
+ return ZeroMatrix(*expr.shape)
44
+ else:
45
+ raise RuntimeError("Unable to determine shape of array-derivative.")
46
+
47
+ @staticmethod
48
+ def _call_derive_scalar_by_matrix(expr: Expr, v: MatrixBase) -> Expr:
49
+ return v.applyfunc(lambda x: expr.diff(x))
50
+
51
+ @staticmethod
52
+ def _call_derive_scalar_by_matexpr(expr: Expr, v: MatrixExpr) -> Expr:
53
+ if expr.has(v):
54
+ return _matrix_derivative(expr, v)
55
+ else:
56
+ return ZeroMatrix(*v.shape)
57
+
58
+ @staticmethod
59
+ def _call_derive_scalar_by_array(expr: Expr, v: NDimArray) -> Expr:
60
+ return v.applyfunc(lambda x: expr.diff(x))
61
+
62
+ @staticmethod
63
+ def _call_derive_matrix_by_scalar(expr: MatrixBase, v: Expr) -> Expr:
64
+ return _matrix_derivative(expr, v)
65
+
66
+ @staticmethod
67
+ def _call_derive_matexpr_by_scalar(expr: MatrixExpr, v: Expr) -> Expr:
68
+ return expr._eval_derivative(v)
69
+
70
+ @staticmethod
71
+ def _call_derive_array_by_scalar(expr: NDimArray, v: Expr) -> Expr:
72
+ return expr.applyfunc(lambda x: x.diff(v))
73
+
74
+ @staticmethod
75
+ def _call_derive_default(expr: Expr, v: Expr) -> Expr | None:
76
+ if expr.has(v):
77
+ return _matrix_derivative(expr, v)
78
+ else:
79
+ return None
80
+
81
+ @classmethod
82
+ def _dispatch_eval_derivative_n_times(cls, expr, v, count):
83
+ # Evaluate the derivative `n` times. If
84
+ # `_eval_derivative_n_times` is not overridden by the current
85
+ # object, the default in `Basic` will call a loop over
86
+ # `_eval_derivative`:
87
+
88
+ if not isinstance(count, (int, Integer)) or ((count <= 0) == True):
89
+ return None
90
+
91
+ # TODO: this could be done with multiple-dispatching:
92
+ if expr.is_scalar:
93
+ if isinstance(v, MatrixBase):
94
+ result = cls._call_derive_scalar_by_matrix(expr, v)
95
+ elif isinstance(v, MatrixExpr):
96
+ result = cls._call_derive_scalar_by_matexpr(expr, v)
97
+ elif isinstance(v, NDimArray):
98
+ result = cls._call_derive_scalar_by_array(expr, v)
99
+ elif v.is_scalar:
100
+ # scalar by scalar has a special
101
+ return super()._dispatch_eval_derivative_n_times(expr, v, count)
102
+ else:
103
+ return None
104
+ elif v.is_scalar:
105
+ if isinstance(expr, MatrixBase):
106
+ result = cls._call_derive_matrix_by_scalar(expr, v)
107
+ elif isinstance(expr, MatrixExpr):
108
+ result = cls._call_derive_matexpr_by_scalar(expr, v)
109
+ elif isinstance(expr, NDimArray):
110
+ result = cls._call_derive_array_by_scalar(expr, v)
111
+ else:
112
+ return None
113
+ else:
114
+ # Both `expr` and `v` are some array/matrix type:
115
+ if isinstance(expr, MatrixBase) or isinstance(v, MatrixBase):
116
+ result = derive_by_array(expr, v)
117
+ elif isinstance(expr, MatrixExpr) and isinstance(v, MatrixExpr):
118
+ result = cls._call_derive_default(expr, v)
119
+ elif isinstance(expr, MatrixExpr) or isinstance(v, MatrixExpr):
120
+ # if one expression is a symbolic matrix expression while the other isn't, don't evaluate:
121
+ return None
122
+ else:
123
+ result = derive_by_array(expr, v)
124
+ if result is None:
125
+ return None
126
+ if count == 1:
127
+ return result
128
+ else:
129
+ return cls._dispatch_eval_derivative_n_times(result, v, count - 1)
.venv/lib/python3.11/site-packages/sympy/tensor/array/arrayop.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from collections.abc import Iterable
3
+
4
+ from sympy.core._print_helpers import Printable
5
+ from sympy.core.containers import Tuple
6
+ from sympy.core.function import diff
7
+ from sympy.core.singleton import S
8
+ from sympy.core.sympify import _sympify
9
+
10
+ from sympy.tensor.array.ndim_array import NDimArray
11
+ from sympy.tensor.array.dense_ndim_array import DenseNDimArray, ImmutableDenseNDimArray
12
+ from sympy.tensor.array.sparse_ndim_array import SparseNDimArray
13
+
14
+
15
+ def _arrayfy(a):
16
+ from sympy.matrices import MatrixBase
17
+
18
+ if isinstance(a, NDimArray):
19
+ return a
20
+ if isinstance(a, (MatrixBase, list, tuple, Tuple)):
21
+ return ImmutableDenseNDimArray(a)
22
+ return a
23
+
24
+
25
+ def tensorproduct(*args):
26
+ """
27
+ Tensor product among scalars or array-like objects.
28
+
29
+ The equivalent operator for array expressions is ``ArrayTensorProduct``,
30
+ which can be used to keep the expression unevaluated.
31
+
32
+ Examples
33
+ ========
34
+
35
+ >>> from sympy.tensor.array import tensorproduct, Array
36
+ >>> from sympy.abc import x, y, z, t
37
+ >>> A = Array([[1, 2], [3, 4]])
38
+ >>> B = Array([x, y])
39
+ >>> tensorproduct(A, B)
40
+ [[[x, y], [2*x, 2*y]], [[3*x, 3*y], [4*x, 4*y]]]
41
+ >>> tensorproduct(A, x)
42
+ [[x, 2*x], [3*x, 4*x]]
43
+ >>> tensorproduct(A, B, B)
44
+ [[[[x**2, x*y], [x*y, y**2]], [[2*x**2, 2*x*y], [2*x*y, 2*y**2]]], [[[3*x**2, 3*x*y], [3*x*y, 3*y**2]], [[4*x**2, 4*x*y], [4*x*y, 4*y**2]]]]
45
+
46
+ Applying this function on two matrices will result in a rank 4 array.
47
+
48
+ >>> from sympy import Matrix, eye
49
+ >>> m = Matrix([[x, y], [z, t]])
50
+ >>> p = tensorproduct(eye(3), m)
51
+ >>> p
52
+ [[[[x, y], [z, t]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[x, y], [z, t]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[x, y], [z, t]]]]
53
+
54
+ See Also
55
+ ========
56
+
57
+ sympy.tensor.array.expressions.array_expressions.ArrayTensorProduct
58
+
59
+ """
60
+ from sympy.tensor.array import SparseNDimArray, ImmutableSparseNDimArray
61
+
62
+ if len(args) == 0:
63
+ return S.One
64
+ if len(args) == 1:
65
+ return _arrayfy(args[0])
66
+ from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract
67
+ from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
68
+ from sympy.tensor.array.expressions.array_expressions import _ArrayExpr
69
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
70
+ if any(isinstance(arg, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)) for arg in args):
71
+ return ArrayTensorProduct(*args)
72
+ if len(args) > 2:
73
+ return tensorproduct(tensorproduct(args[0], args[1]), *args[2:])
74
+
75
+ # length of args is 2:
76
+ a, b = map(_arrayfy, args)
77
+
78
+ if not isinstance(a, NDimArray) or not isinstance(b, NDimArray):
79
+ return a*b
80
+
81
+ if isinstance(a, SparseNDimArray) and isinstance(b, SparseNDimArray):
82
+ lp = len(b)
83
+ new_array = {k1*lp + k2: v1*v2 for k1, v1 in a._sparse_array.items() for k2, v2 in b._sparse_array.items()}
84
+ return ImmutableSparseNDimArray(new_array, a.shape + b.shape)
85
+
86
+ product_list = [i*j for i in Flatten(a) for j in Flatten(b)]
87
+ return ImmutableDenseNDimArray(product_list, a.shape + b.shape)
88
+
89
+
90
+ def _util_contraction_diagonal(array, *contraction_or_diagonal_axes):
91
+ array = _arrayfy(array)
92
+
93
+ # Verify contraction_axes:
94
+ taken_dims = set()
95
+ for axes_group in contraction_or_diagonal_axes:
96
+ if not isinstance(axes_group, Iterable):
97
+ raise ValueError("collections of contraction/diagonal axes expected")
98
+
99
+ dim = array.shape[axes_group[0]]
100
+
101
+ for d in axes_group:
102
+ if d in taken_dims:
103
+ raise ValueError("dimension specified more than once")
104
+ if dim != array.shape[d]:
105
+ raise ValueError("cannot contract or diagonalize between axes of different dimension")
106
+ taken_dims.add(d)
107
+
108
+ rank = array.rank()
109
+
110
+ remaining_shape = [dim for i, dim in enumerate(array.shape) if i not in taken_dims]
111
+ cum_shape = [0]*rank
112
+ _cumul = 1
113
+ for i in range(rank):
114
+ cum_shape[rank - i - 1] = _cumul
115
+ _cumul *= int(array.shape[rank - i - 1])
116
+
117
+ # DEFINITION: by absolute position it is meant the position along the one
118
+ # dimensional array containing all the tensor components.
119
+
120
+ # Possible future work on this module: move computation of absolute
121
+ # positions to a class method.
122
+
123
+ # Determine absolute positions of the uncontracted indices:
124
+ remaining_indices = [[cum_shape[i]*j for j in range(array.shape[i])]
125
+ for i in range(rank) if i not in taken_dims]
126
+
127
+ # Determine absolute positions of the contracted indices:
128
+ summed_deltas = []
129
+ for axes_group in contraction_or_diagonal_axes:
130
+ lidx = []
131
+ for js in range(array.shape[axes_group[0]]):
132
+ lidx.append(sum(cum_shape[ig] * js for ig in axes_group))
133
+ summed_deltas.append(lidx)
134
+
135
+ return array, remaining_indices, remaining_shape, summed_deltas
136
+
137
+
138
+ def tensorcontraction(array, *contraction_axes):
139
+ """
140
+ Contraction of an array-like object on the specified axes.
141
+
142
+ The equivalent operator for array expressions is ``ArrayContraction``,
143
+ which can be used to keep the expression unevaluated.
144
+
145
+ Examples
146
+ ========
147
+
148
+ >>> from sympy import Array, tensorcontraction
149
+ >>> from sympy import Matrix, eye
150
+ >>> tensorcontraction(eye(3), (0, 1))
151
+ 3
152
+ >>> A = Array(range(18), (3, 2, 3))
153
+ >>> A
154
+ [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]]
155
+ >>> tensorcontraction(A, (0, 2))
156
+ [21, 30]
157
+
158
+ Matrix multiplication may be emulated with a proper combination of
159
+ ``tensorcontraction`` and ``tensorproduct``
160
+
161
+ >>> from sympy import tensorproduct
162
+ >>> from sympy.abc import a,b,c,d,e,f,g,h
163
+ >>> m1 = Matrix([[a, b], [c, d]])
164
+ >>> m2 = Matrix([[e, f], [g, h]])
165
+ >>> p = tensorproduct(m1, m2)
166
+ >>> p
167
+ [[[[a*e, a*f], [a*g, a*h]], [[b*e, b*f], [b*g, b*h]]], [[[c*e, c*f], [c*g, c*h]], [[d*e, d*f], [d*g, d*h]]]]
168
+ >>> tensorcontraction(p, (1, 2))
169
+ [[a*e + b*g, a*f + b*h], [c*e + d*g, c*f + d*h]]
170
+ >>> m1*m2
171
+ Matrix([
172
+ [a*e + b*g, a*f + b*h],
173
+ [c*e + d*g, c*f + d*h]])
174
+
175
+ See Also
176
+ ========
177
+
178
+ sympy.tensor.array.expressions.array_expressions.ArrayContraction
179
+
180
+ """
181
+ from sympy.tensor.array.expressions.array_expressions import _array_contraction
182
+ from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract
183
+ from sympy.tensor.array.expressions.array_expressions import _ArrayExpr
184
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
185
+ if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)):
186
+ return _array_contraction(array, *contraction_axes)
187
+
188
+ array, remaining_indices, remaining_shape, summed_deltas = _util_contraction_diagonal(array, *contraction_axes)
189
+
190
+ # Compute the contracted array:
191
+ #
192
+ # 1. external for loops on all uncontracted indices.
193
+ # Uncontracted indices are determined by the combinatorial product of
194
+ # the absolute positions of the remaining indices.
195
+ # 2. internal loop on all contracted indices.
196
+ # It sums the values of the absolute contracted index and the absolute
197
+ # uncontracted index for the external loop.
198
+ contracted_array = []
199
+ for icontrib in itertools.product(*remaining_indices):
200
+ index_base_position = sum(icontrib)
201
+ isum = S.Zero
202
+ for sum_to_index in itertools.product(*summed_deltas):
203
+ idx = array._get_tuple_index(index_base_position + sum(sum_to_index))
204
+ isum += array[idx]
205
+
206
+ contracted_array.append(isum)
207
+
208
+ if len(remaining_indices) == 0:
209
+ assert len(contracted_array) == 1
210
+ return contracted_array[0]
211
+
212
+ return type(array)(contracted_array, remaining_shape)
213
+
214
+
215
+ def tensordiagonal(array, *diagonal_axes):
216
+ """
217
+ Diagonalization of an array-like object on the specified axes.
218
+
219
+ This is equivalent to multiplying the expression by Kronecker deltas
220
+ uniting the axes.
221
+
222
+ The diagonal indices are put at the end of the axes.
223
+
224
+ The equivalent operator for array expressions is ``ArrayDiagonal``, which
225
+ can be used to keep the expression unevaluated.
226
+
227
+ Examples
228
+ ========
229
+
230
+ ``tensordiagonal`` acting on a 2-dimensional array by axes 0 and 1 is
231
+ equivalent to the diagonal of the matrix:
232
+
233
+ >>> from sympy import Array, tensordiagonal
234
+ >>> from sympy import Matrix, eye
235
+ >>> tensordiagonal(eye(3), (0, 1))
236
+ [1, 1, 1]
237
+
238
+ >>> from sympy.abc import a,b,c,d
239
+ >>> m1 = Matrix([[a, b], [c, d]])
240
+ >>> tensordiagonal(m1, [0, 1])
241
+ [a, d]
242
+
243
+ In case of higher dimensional arrays, the diagonalized out dimensions
244
+ are appended removed and appended as a single dimension at the end:
245
+
246
+ >>> A = Array(range(18), (3, 2, 3))
247
+ >>> A
248
+ [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]]
249
+ >>> tensordiagonal(A, (0, 2))
250
+ [[0, 7, 14], [3, 10, 17]]
251
+ >>> from sympy import permutedims
252
+ >>> tensordiagonal(A, (0, 2)) == permutedims(Array([A[0, :, 0], A[1, :, 1], A[2, :, 2]]), [1, 0])
253
+ True
254
+
255
+ See Also
256
+ ========
257
+
258
+ sympy.tensor.array.expressions.array_expressions.ArrayDiagonal
259
+
260
+ """
261
+ if any(len(i) <= 1 for i in diagonal_axes):
262
+ raise ValueError("need at least two axes to diagonalize")
263
+
264
+ from sympy.tensor.array.expressions.array_expressions import _ArrayExpr
265
+ from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract
266
+ from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal, _array_diagonal
267
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
268
+ if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)):
269
+ return _array_diagonal(array, *diagonal_axes)
270
+
271
+ ArrayDiagonal._validate(array, *diagonal_axes)
272
+
273
+ array, remaining_indices, remaining_shape, diagonal_deltas = _util_contraction_diagonal(array, *diagonal_axes)
274
+
275
+ # Compute the diagonalized array:
276
+ #
277
+ # 1. external for loops on all undiagonalized indices.
278
+ # Undiagonalized indices are determined by the combinatorial product of
279
+ # the absolute positions of the remaining indices.
280
+ # 2. internal loop on all diagonal indices.
281
+ # It appends the values of the absolute diagonalized index and the absolute
282
+ # undiagonalized index for the external loop.
283
+ diagonalized_array = []
284
+ diagonal_shape = [len(i) for i in diagonal_deltas]
285
+ for icontrib in itertools.product(*remaining_indices):
286
+ index_base_position = sum(icontrib)
287
+ isum = []
288
+ for sum_to_index in itertools.product(*diagonal_deltas):
289
+ idx = array._get_tuple_index(index_base_position + sum(sum_to_index))
290
+ isum.append(array[idx])
291
+
292
+ isum = type(array)(isum).reshape(*diagonal_shape)
293
+ diagonalized_array.append(isum)
294
+
295
+ return type(array)(diagonalized_array, remaining_shape + diagonal_shape)
296
+
297
+
298
+ def derive_by_array(expr, dx):
299
+ r"""
300
+ Derivative by arrays. Supports both arrays and scalars.
301
+
302
+ The equivalent operator for array expressions is ``array_derive``.
303
+
304
+ Explanation
305
+ ===========
306
+
307
+ Given the array `A_{i_1, \ldots, i_N}` and the array `X_{j_1, \ldots, j_M}`
308
+ this function will return a new array `B` defined by
309
+
310
+ `B_{j_1,\ldots,j_M,i_1,\ldots,i_N} := \frac{\partial A_{i_1,\ldots,i_N}}{\partial X_{j_1,\ldots,j_M}}`
311
+
312
+ Examples
313
+ ========
314
+
315
+ >>> from sympy import derive_by_array
316
+ >>> from sympy.abc import x, y, z, t
317
+ >>> from sympy import cos
318
+ >>> derive_by_array(cos(x*t), x)
319
+ -t*sin(t*x)
320
+ >>> derive_by_array(cos(x*t), [x, y, z, t])
321
+ [-t*sin(t*x), 0, 0, -x*sin(t*x)]
322
+ >>> derive_by_array([x, y**2*z], [[x, y], [z, t]])
323
+ [[[1, 0], [0, 2*y*z]], [[0, y**2], [0, 0]]]
324
+
325
+ """
326
+ from sympy.matrices import MatrixBase
327
+ from sympy.tensor.array import SparseNDimArray
328
+ array_types = (Iterable, MatrixBase, NDimArray)
329
+
330
+ if isinstance(dx, array_types):
331
+ dx = ImmutableDenseNDimArray(dx)
332
+ for i in dx:
333
+ if not i._diff_wrt:
334
+ raise ValueError("cannot derive by this array")
335
+
336
+ if isinstance(expr, array_types):
337
+ if isinstance(expr, NDimArray):
338
+ expr = expr.as_immutable()
339
+ else:
340
+ expr = ImmutableDenseNDimArray(expr)
341
+
342
+ if isinstance(dx, array_types):
343
+ if isinstance(expr, SparseNDimArray):
344
+ lp = len(expr)
345
+ new_array = {k + i*lp: v
346
+ for i, x in enumerate(Flatten(dx))
347
+ for k, v in expr.diff(x)._sparse_array.items()}
348
+ else:
349
+ new_array = [[y.diff(x) for y in Flatten(expr)] for x in Flatten(dx)]
350
+ return type(expr)(new_array, dx.shape + expr.shape)
351
+ else:
352
+ return expr.diff(dx)
353
+ else:
354
+ expr = _sympify(expr)
355
+ if isinstance(dx, array_types):
356
+ return ImmutableDenseNDimArray([expr.diff(i) for i in Flatten(dx)], dx.shape)
357
+ else:
358
+ dx = _sympify(dx)
359
+ return diff(expr, dx)
360
+
361
+
362
+ def permutedims(expr, perm=None, index_order_old=None, index_order_new=None):
363
+ """
364
+ Permutes the indices of an array.
365
+
366
+ Parameter specifies the permutation of the indices.
367
+
368
+ The equivalent operator for array expressions is ``PermuteDims``, which can
369
+ be used to keep the expression unevaluated.
370
+
371
+ Examples
372
+ ========
373
+
374
+ >>> from sympy.abc import x, y, z, t
375
+ >>> from sympy import sin
376
+ >>> from sympy import Array, permutedims
377
+ >>> a = Array([[x, y, z], [t, sin(x), 0]])
378
+ >>> a
379
+ [[x, y, z], [t, sin(x), 0]]
380
+ >>> permutedims(a, (1, 0))
381
+ [[x, t], [y, sin(x)], [z, 0]]
382
+
383
+ If the array is of second order, ``transpose`` can be used:
384
+
385
+ >>> from sympy import transpose
386
+ >>> transpose(a)
387
+ [[x, t], [y, sin(x)], [z, 0]]
388
+
389
+ Examples on higher dimensions:
390
+
391
+ >>> b = Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
392
+ >>> permutedims(b, (2, 1, 0))
393
+ [[[1, 5], [3, 7]], [[2, 6], [4, 8]]]
394
+ >>> permutedims(b, (1, 2, 0))
395
+ [[[1, 5], [2, 6]], [[3, 7], [4, 8]]]
396
+
397
+ An alternative way to specify the same permutations as in the previous
398
+ lines involves passing the *old* and *new* indices, either as a list or as
399
+ a string:
400
+
401
+ >>> permutedims(b, index_order_old="cba", index_order_new="abc")
402
+ [[[1, 5], [3, 7]], [[2, 6], [4, 8]]]
403
+ >>> permutedims(b, index_order_old="cab", index_order_new="abc")
404
+ [[[1, 5], [2, 6]], [[3, 7], [4, 8]]]
405
+
406
+ ``Permutation`` objects are also allowed:
407
+
408
+ >>> from sympy.combinatorics import Permutation
409
+ >>> permutedims(b, Permutation([1, 2, 0]))
410
+ [[[1, 5], [2, 6]], [[3, 7], [4, 8]]]
411
+
412
+ See Also
413
+ ========
414
+
415
+ sympy.tensor.array.expressions.array_expressions.PermuteDims
416
+
417
+ """
418
+ from sympy.tensor.array import SparseNDimArray
419
+
420
+ from sympy.tensor.array.expressions.array_expressions import _ArrayExpr
421
+ from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract
422
+ from sympy.tensor.array.expressions.array_expressions import _permute_dims
423
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
424
+ from sympy.tensor.array.expressions import PermuteDims
425
+ from sympy.tensor.array.expressions.array_expressions import get_rank
426
+ perm = PermuteDims._get_permutation_from_arguments(perm, index_order_old, index_order_new, get_rank(expr))
427
+ if isinstance(expr, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)):
428
+ return _permute_dims(expr, perm)
429
+
430
+ if not isinstance(expr, NDimArray):
431
+ expr = ImmutableDenseNDimArray(expr)
432
+
433
+ from sympy.combinatorics import Permutation
434
+ if not isinstance(perm, Permutation):
435
+ perm = Permutation(list(perm))
436
+
437
+ if perm.size != expr.rank():
438
+ raise ValueError("wrong permutation size")
439
+
440
+ # Get the inverse permutation:
441
+ iperm = ~perm
442
+ new_shape = perm(expr.shape)
443
+
444
+ if isinstance(expr, SparseNDimArray):
445
+ return type(expr)({tuple(perm(expr._get_tuple_index(k))): v
446
+ for k, v in expr._sparse_array.items()}, new_shape)
447
+
448
+ indices_span = perm([range(i) for i in expr.shape])
449
+
450
+ new_array = [None]*len(expr)
451
+ for i, idx in enumerate(itertools.product(*indices_span)):
452
+ t = iperm(idx)
453
+ new_array[i] = expr[t]
454
+
455
+ return type(expr)(new_array, new_shape)
456
+
457
+
458
+ class Flatten(Printable):
459
+ """
460
+ Flatten an iterable object to a list in a lazy-evaluation way.
461
+
462
+ Notes
463
+ =====
464
+
465
+ This class is an iterator with which the memory cost can be economised.
466
+ Optimisation has been considered to ameliorate the performance for some
467
+ specific data types like DenseNDimArray and SparseNDimArray.
468
+
469
+ Examples
470
+ ========
471
+
472
+ >>> from sympy.tensor.array.arrayop import Flatten
473
+ >>> from sympy.tensor.array import Array
474
+ >>> A = Array(range(6)).reshape(2, 3)
475
+ >>> Flatten(A)
476
+ Flatten([[0, 1, 2], [3, 4, 5]])
477
+ >>> [i for i in Flatten(A)]
478
+ [0, 1, 2, 3, 4, 5]
479
+ """
480
+ def __init__(self, iterable):
481
+ from sympy.matrices.matrixbase import MatrixBase
482
+ from sympy.tensor.array import NDimArray
483
+
484
+ if not isinstance(iterable, (Iterable, MatrixBase)):
485
+ raise NotImplementedError("Data type not yet supported")
486
+
487
+ if isinstance(iterable, list):
488
+ iterable = NDimArray(iterable)
489
+
490
+ self._iter = iterable
491
+ self._idx = 0
492
+
493
+ def __iter__(self):
494
+ return self
495
+
496
+ def __next__(self):
497
+ from sympy.matrices.matrixbase import MatrixBase
498
+
499
+ if len(self._iter) > self._idx:
500
+ if isinstance(self._iter, DenseNDimArray):
501
+ result = self._iter._array[self._idx]
502
+
503
+ elif isinstance(self._iter, SparseNDimArray):
504
+ if self._idx in self._iter._sparse_array:
505
+ result = self._iter._sparse_array[self._idx]
506
+ else:
507
+ result = 0
508
+
509
+ elif isinstance(self._iter, MatrixBase):
510
+ result = self._iter[self._idx]
511
+
512
+ elif hasattr(self._iter, '__next__'):
513
+ result = next(self._iter)
514
+
515
+ else:
516
+ result = self._iter[self._idx]
517
+
518
+ else:
519
+ raise StopIteration
520
+
521
+ self._idx += 1
522
+ return result
523
+
524
+ def next(self):
525
+ return self.__next__()
526
+
527
+ def _sympystr(self, printer):
528
+ return type(self).__name__ + '(' + printer._print(self._iter) + ')'
.venv/lib/python3.11/site-packages/sympy/tensor/array/dense_ndim_array.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import List
3
+
4
+ from sympy.core.basic import Basic
5
+ from sympy.core.containers import Tuple
6
+ from sympy.core.singleton import S
7
+ from sympy.core.sympify import _sympify
8
+ from sympy.tensor.array.mutable_ndim_array import MutableNDimArray
9
+ from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray, ArrayKind
10
+ from sympy.utilities.iterables import flatten
11
+
12
+
13
+ class DenseNDimArray(NDimArray):
14
+
15
+ _array: List[Basic]
16
+
17
+ def __new__(self, *args, **kwargs):
18
+ return ImmutableDenseNDimArray(*args, **kwargs)
19
+
20
+ @property
21
+ def kind(self) -> ArrayKind:
22
+ return ArrayKind._union(self._array)
23
+
24
+ def __getitem__(self, index):
25
+ """
26
+ Allows to get items from N-dim array.
27
+
28
+ Examples
29
+ ========
30
+
31
+ >>> from sympy import MutableDenseNDimArray
32
+ >>> a = MutableDenseNDimArray([0, 1, 2, 3], (2, 2))
33
+ >>> a
34
+ [[0, 1], [2, 3]]
35
+ >>> a[0, 0]
36
+ 0
37
+ >>> a[1, 1]
38
+ 3
39
+ >>> a[0]
40
+ [0, 1]
41
+ >>> a[1]
42
+ [2, 3]
43
+
44
+
45
+ Symbolic index:
46
+
47
+ >>> from sympy.abc import i, j
48
+ >>> a[i, j]
49
+ [[0, 1], [2, 3]][i, j]
50
+
51
+ Replace `i` and `j` to get element `(1, 1)`:
52
+
53
+ >>> a[i, j].subs({i: 1, j: 1})
54
+ 3
55
+
56
+ """
57
+ syindex = self._check_symbolic_index(index)
58
+ if syindex is not None:
59
+ return syindex
60
+
61
+ index = self._check_index_for_getitem(index)
62
+
63
+ if isinstance(index, tuple) and any(isinstance(i, slice) for i in index):
64
+ sl_factors, eindices = self._get_slice_data_for_array_access(index)
65
+ array = [self._array[self._parse_index(i)] for i in eindices]
66
+ nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)]
67
+ return type(self)(array, nshape)
68
+ else:
69
+ index = self._parse_index(index)
70
+ return self._array[index]
71
+
72
+ @classmethod
73
+ def zeros(cls, *shape):
74
+ list_length = functools.reduce(lambda x, y: x*y, shape, S.One)
75
+ return cls._new(([0]*list_length,), shape)
76
+
77
+ def tomatrix(self):
78
+ """
79
+ Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error.
80
+
81
+ Examples
82
+ ========
83
+
84
+ >>> from sympy import MutableDenseNDimArray
85
+ >>> a = MutableDenseNDimArray([1 for i in range(9)], (3, 3))
86
+ >>> b = a.tomatrix()
87
+ >>> b
88
+ Matrix([
89
+ [1, 1, 1],
90
+ [1, 1, 1],
91
+ [1, 1, 1]])
92
+
93
+ """
94
+ from sympy.matrices import Matrix
95
+
96
+ if self.rank() != 2:
97
+ raise ValueError('Dimensions must be of size of 2')
98
+
99
+ return Matrix(self.shape[0], self.shape[1], self._array)
100
+
101
+ def reshape(self, *newshape):
102
+ """
103
+ Returns MutableDenseNDimArray instance with new shape. Elements number
104
+ must be suitable to new shape. The only argument of method sets
105
+ new shape.
106
+
107
+ Examples
108
+ ========
109
+
110
+ >>> from sympy import MutableDenseNDimArray
111
+ >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3))
112
+ >>> a.shape
113
+ (2, 3)
114
+ >>> a
115
+ [[1, 2, 3], [4, 5, 6]]
116
+ >>> b = a.reshape(3, 2)
117
+ >>> b.shape
118
+ (3, 2)
119
+ >>> b
120
+ [[1, 2], [3, 4], [5, 6]]
121
+
122
+ """
123
+ new_total_size = functools.reduce(lambda x,y: x*y, newshape)
124
+ if new_total_size != self._loop_size:
125
+ raise ValueError('Expecting reshape size to %d but got prod(%s) = %d' % (
126
+ self._loop_size, str(newshape), new_total_size))
127
+
128
+ # there is no `.func` as this class does not subtype `Basic`:
129
+ return type(self)(self._array, newshape)
130
+
131
+
132
+ class ImmutableDenseNDimArray(DenseNDimArray, ImmutableNDimArray): # type: ignore
133
+ def __new__(cls, iterable, shape=None, **kwargs):
134
+ return cls._new(iterable, shape, **kwargs)
135
+
136
+ @classmethod
137
+ def _new(cls, iterable, shape, **kwargs):
138
+ shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs)
139
+ shape = Tuple(*map(_sympify, shape))
140
+ cls._check_special_bounds(flat_list, shape)
141
+ flat_list = flatten(flat_list)
142
+ flat_list = Tuple(*flat_list)
143
+ self = Basic.__new__(cls, flat_list, shape, **kwargs)
144
+ self._shape = shape
145
+ self._array = list(flat_list)
146
+ self._rank = len(shape)
147
+ self._loop_size = functools.reduce(lambda x,y: x*y, shape, 1)
148
+ return self
149
+
150
+ def __setitem__(self, index, value):
151
+ raise TypeError('immutable N-dim array')
152
+
153
+ def as_mutable(self):
154
+ return MutableDenseNDimArray(self)
155
+
156
+ def _eval_simplify(self, **kwargs):
157
+ from sympy.simplify.simplify import simplify
158
+ return self.applyfunc(simplify)
159
+
160
+ class MutableDenseNDimArray(DenseNDimArray, MutableNDimArray):
161
+
162
+ def __new__(cls, iterable=None, shape=None, **kwargs):
163
+ return cls._new(iterable, shape, **kwargs)
164
+
165
+ @classmethod
166
+ def _new(cls, iterable, shape, **kwargs):
167
+ shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs)
168
+ flat_list = flatten(flat_list)
169
+ self = object.__new__(cls)
170
+ self._shape = shape
171
+ self._array = list(flat_list)
172
+ self._rank = len(shape)
173
+ self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list)
174
+ return self
175
+
176
+ def __setitem__(self, index, value):
177
+ """Allows to set items to MutableDenseNDimArray.
178
+
179
+ Examples
180
+ ========
181
+
182
+ >>> from sympy import MutableDenseNDimArray
183
+ >>> a = MutableDenseNDimArray.zeros(2, 2)
184
+ >>> a[0,0] = 1
185
+ >>> a[1,1] = 1
186
+ >>> a
187
+ [[1, 0], [0, 1]]
188
+
189
+ """
190
+ if isinstance(index, tuple) and any(isinstance(i, slice) for i in index):
191
+ value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value)
192
+ for i in eindices:
193
+ other_i = [ind - j for ind, j in zip(i, slice_offsets) if j is not None]
194
+ self._array[self._parse_index(i)] = value[other_i]
195
+ else:
196
+ index = self._parse_index(index)
197
+ self._setter_iterable_check(value)
198
+ value = _sympify(value)
199
+ self._array[index] = value
200
+
201
+ def as_immutable(self):
202
+ return ImmutableDenseNDimArray(self)
203
+
204
+ @property
205
+ def free_symbols(self):
206
+ return {i for j in self._array for i in j.free_symbols}
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__init__.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ Array expressions are expressions representing N-dimensional arrays, without
3
+ evaluating them. These expressions represent in a certain way abstract syntax
4
+ trees of operations on N-dimensional arrays.
5
+
6
+ Every N-dimensional array operator has a corresponding array expression object.
7
+
8
+ Table of correspondences:
9
+
10
+ =============================== =============================
11
+ Array operator Array expression operator
12
+ =============================== =============================
13
+ tensorproduct ArrayTensorProduct
14
+ tensorcontraction ArrayContraction
15
+ tensordiagonal ArrayDiagonal
16
+ permutedims PermuteDims
17
+ =============================== =============================
18
+
19
+ Examples
20
+ ========
21
+
22
+ ``ArraySymbol`` objects are the N-dimensional equivalent of ``MatrixSymbol``
23
+ objects in the matrix module:
24
+
25
+ >>> from sympy.tensor.array.expressions import ArraySymbol
26
+ >>> from sympy.abc import i, j, k
27
+ >>> A = ArraySymbol("A", (3, 2, 4))
28
+ >>> A.shape
29
+ (3, 2, 4)
30
+ >>> A[i, j, k]
31
+ A[i, j, k]
32
+ >>> A.as_explicit()
33
+ [[[A[0, 0, 0], A[0, 0, 1], A[0, 0, 2], A[0, 0, 3]],
34
+ [A[0, 1, 0], A[0, 1, 1], A[0, 1, 2], A[0, 1, 3]]],
35
+ [[A[1, 0, 0], A[1, 0, 1], A[1, 0, 2], A[1, 0, 3]],
36
+ [A[1, 1, 0], A[1, 1, 1], A[1, 1, 2], A[1, 1, 3]]],
37
+ [[A[2, 0, 0], A[2, 0, 1], A[2, 0, 2], A[2, 0, 3]],
38
+ [A[2, 1, 0], A[2, 1, 1], A[2, 1, 2], A[2, 1, 3]]]]
39
+
40
+ Component-explicit arrays can be added inside array expressions:
41
+
42
+ >>> from sympy import Array
43
+ >>> from sympy import tensorproduct
44
+ >>> from sympy.tensor.array.expressions import ArrayTensorProduct
45
+ >>> a = Array([1, 2, 3])
46
+ >>> b = Array([i, j, k])
47
+ >>> expr = ArrayTensorProduct(a, b, b)
48
+ >>> expr
49
+ ArrayTensorProduct([1, 2, 3], [i, j, k], [i, j, k])
50
+ >>> expr.as_explicit() == tensorproduct(a, b, b)
51
+ True
52
+
53
+ Constructing array expressions from index-explicit forms
54
+ --------------------------------------------------------
55
+
56
+ Array expressions are index-implicit. This means they do not use any indices to
57
+ represent array operations. The function ``convert_indexed_to_array( ... )``
58
+ may be used to convert index-explicit expressions to array expressions.
59
+ It takes as input two parameters: the index-explicit expression and the order
60
+ of the indices:
61
+
62
+ >>> from sympy.tensor.array.expressions import convert_indexed_to_array
63
+ >>> from sympy import Sum
64
+ >>> A = ArraySymbol("A", (3, 3))
65
+ >>> B = ArraySymbol("B", (3, 3))
66
+ >>> convert_indexed_to_array(A[i, j], [i, j])
67
+ A
68
+ >>> convert_indexed_to_array(A[i, j], [j, i])
69
+ PermuteDims(A, (0 1))
70
+ >>> convert_indexed_to_array(A[i, j] + B[j, i], [i, j])
71
+ ArrayAdd(A, PermuteDims(B, (0 1)))
72
+ >>> convert_indexed_to_array(Sum(A[i, j]*B[j, k], (j, 0, 2)), [i, k])
73
+ ArrayContraction(ArrayTensorProduct(A, B), (1, 2))
74
+
75
+ The diagonal of a matrix in the array expression form:
76
+
77
+ >>> convert_indexed_to_array(A[i, i], [i])
78
+ ArrayDiagonal(A, (0, 1))
79
+
80
+ The trace of a matrix in the array expression form:
81
+
82
+ >>> convert_indexed_to_array(Sum(A[i, i], (i, 0, 2)), [i])
83
+ ArrayContraction(A, (0, 1))
84
+
85
+ Compatibility with matrices
86
+ ---------------------------
87
+
88
+ Array expressions can be mixed with objects from the matrix module:
89
+
90
+ >>> from sympy import MatrixSymbol
91
+ >>> from sympy.tensor.array.expressions import ArrayContraction
92
+ >>> M = MatrixSymbol("M", 3, 3)
93
+ >>> N = MatrixSymbol("N", 3, 3)
94
+
95
+ Express the matrix product in the array expression form:
96
+
97
+ >>> from sympy.tensor.array.expressions import convert_matrix_to_array
98
+ >>> expr = convert_matrix_to_array(M*N)
99
+ >>> expr
100
+ ArrayContraction(ArrayTensorProduct(M, N), (1, 2))
101
+
102
+ The expression can be converted back to matrix form:
103
+
104
+ >>> from sympy.tensor.array.expressions import convert_array_to_matrix
105
+ >>> convert_array_to_matrix(expr)
106
+ M*N
107
+
108
+ Add a second contraction on the remaining axes in order to get the trace of `M \cdot N`:
109
+
110
+ >>> expr_tr = ArrayContraction(expr, (0, 1))
111
+ >>> expr_tr
112
+ ArrayContraction(ArrayContraction(ArrayTensorProduct(M, N), (1, 2)), (0, 1))
113
+
114
+ Flatten the expression by calling ``.doit()`` and remove the nested array contraction operations:
115
+
116
+ >>> expr_tr.doit()
117
+ ArrayContraction(ArrayTensorProduct(M, N), (0, 3), (1, 2))
118
+
119
+ Get the explicit form of the array expression:
120
+
121
+ >>> expr.as_explicit()
122
+ [[M[0, 0]*N[0, 0] + M[0, 1]*N[1, 0] + M[0, 2]*N[2, 0], M[0, 0]*N[0, 1] + M[0, 1]*N[1, 1] + M[0, 2]*N[2, 1], M[0, 0]*N[0, 2] + M[0, 1]*N[1, 2] + M[0, 2]*N[2, 2]],
123
+ [M[1, 0]*N[0, 0] + M[1, 1]*N[1, 0] + M[1, 2]*N[2, 0], M[1, 0]*N[0, 1] + M[1, 1]*N[1, 1] + M[1, 2]*N[2, 1], M[1, 0]*N[0, 2] + M[1, 1]*N[1, 2] + M[1, 2]*N[2, 2]],
124
+ [M[2, 0]*N[0, 0] + M[2, 1]*N[1, 0] + M[2, 2]*N[2, 0], M[2, 0]*N[0, 1] + M[2, 1]*N[1, 1] + M[2, 2]*N[2, 1], M[2, 0]*N[0, 2] + M[2, 1]*N[1, 2] + M[2, 2]*N[2, 2]]]
125
+
126
+ Express the trace of a matrix:
127
+
128
+ >>> from sympy import Trace
129
+ >>> convert_matrix_to_array(Trace(M))
130
+ ArrayContraction(M, (0, 1))
131
+ >>> convert_matrix_to_array(Trace(M*N))
132
+ ArrayContraction(ArrayTensorProduct(M, N), (0, 3), (1, 2))
133
+
134
+ Express the transposition of a matrix (will be expressed as a permutation of the axes:
135
+
136
+ >>> convert_matrix_to_array(M.T)
137
+ PermuteDims(M, (0 1))
138
+
139
+ Compute the derivative array expressions:
140
+
141
+ >>> from sympy.tensor.array.expressions import array_derive
142
+ >>> d = array_derive(M, M)
143
+ >>> d
144
+ PermuteDims(ArrayTensorProduct(I, I), (3)(1 2))
145
+
146
+ Verify that the derivative corresponds to the form computed with explicit matrices:
147
+
148
+ >>> d.as_explicit()
149
+ [[[[1, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 1, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 1], [0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [1, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 1]]]]
150
+ >>> Me = M.as_explicit()
151
+ >>> Me.diff(Me)
152
+ [[[[1, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 1, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 1], [0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [1, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 1]]]]
153
+
154
+ """
155
+
156
+ __all__ = [
157
+ "ArraySymbol", "ArrayElement", "ZeroArray", "OneArray",
158
+ "ArrayTensorProduct",
159
+ "ArrayContraction",
160
+ "ArrayDiagonal",
161
+ "PermuteDims",
162
+ "ArrayAdd",
163
+ "ArrayElementwiseApplyFunc",
164
+ "Reshape",
165
+ "convert_array_to_matrix",
166
+ "convert_matrix_to_array",
167
+ "convert_array_to_indexed",
168
+ "convert_indexed_to_array",
169
+ "array_derive",
170
+ ]
171
+
172
+ from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, PermuteDims, ArrayDiagonal, \
173
+ ArrayContraction, Reshape, ArraySymbol, ArrayElement, ZeroArray, OneArray, ArrayElementwiseApplyFunc
174
+ from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive
175
+ from sympy.tensor.array.expressions.from_array_to_indexed import convert_array_to_indexed
176
+ from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
177
+ from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array
178
+ from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (7.39 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/arrayexpr_derivatives.cpython-311.pyc ADDED
Binary file (13.9 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_indexed.cpython-311.pyc ADDED
Binary file (724 Bytes). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_matrix.cpython-311.pyc ADDED
Binary file (640 Bytes). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/conv_indexed_to_array.cpython-311.pyc ADDED
Binary file (501 Bytes). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/conv_matrix_to_array.cpython-311.pyc ADDED
Binary file (498 Bytes). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_indexed.cpython-311.pyc ADDED
Binary file (8.16 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_matrix.cpython-311.pyc ADDED
Binary file (66.1 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/from_indexed_to_array.cpython-311.pyc ADDED
Binary file (18.9 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/from_matrix_to_array.cpython-311.pyc ADDED
Binary file (8.58 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/__pycache__/utils.cpython-311.pyc ADDED
Binary file (9.02 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/array_expressions.py ADDED
@@ -0,0 +1,1967 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ import operator
3
+ from collections import defaultdict, Counter
4
+ from functools import reduce
5
+ import itertools
6
+ from itertools import accumulate
7
+ from typing import Optional, List, Tuple as tTuple
8
+
9
+ import typing
10
+
11
+ from sympy.core.numbers import Integer
12
+ from sympy.core.relational import Equality
13
+ from sympy.functions.special.tensor_functions import KroneckerDelta
14
+ from sympy.core.basic import Basic
15
+ from sympy.core.containers import Tuple
16
+ from sympy.core.expr import Expr
17
+ from sympy.core.function import (Function, Lambda)
18
+ from sympy.core.mul import Mul
19
+ from sympy.core.singleton import S
20
+ from sympy.core.sorting import default_sort_key
21
+ from sympy.core.symbol import (Dummy, Symbol)
22
+ from sympy.matrices.matrixbase import MatrixBase
23
+ from sympy.matrices.expressions.diagonal import diagonalize_vector
24
+ from sympy.matrices.expressions.matexpr import MatrixExpr
25
+ from sympy.matrices.expressions.special import ZeroMatrix
26
+ from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensordiagonal, tensorproduct)
27
+ from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray
28
+ from sympy.tensor.array.ndim_array import NDimArray
29
+ from sympy.tensor.indexed import (Indexed, IndexedBase)
30
+ from sympy.matrices.expressions.matexpr import MatrixElement
31
+ from sympy.tensor.array.expressions.utils import _apply_recursively_over_nested_lists, _sort_contraction_indices, \
32
+ _get_mapping_from_subranks, _build_push_indices_up_func_transformation, _get_contraction_links, \
33
+ _build_push_indices_down_func_transformation
34
+ from sympy.combinatorics import Permutation
35
+ from sympy.combinatorics.permutations import _af_invert
36
+ from sympy.core.sympify import _sympify
37
+
38
+
39
+ class _ArrayExpr(Expr):
40
+ shape: tTuple[Expr, ...]
41
+
42
+ def __getitem__(self, item):
43
+ if not isinstance(item, collections.abc.Iterable):
44
+ item = (item,)
45
+ ArrayElement._check_shape(self, item)
46
+ return self._get(item)
47
+
48
+ def _get(self, item):
49
+ return _get_array_element_or_slice(self, item)
50
+
51
+
52
+ class ArraySymbol(_ArrayExpr):
53
+ """
54
+ Symbol representing an array expression
55
+ """
56
+
57
+ def __new__(cls, symbol, shape: typing.Iterable) -> "ArraySymbol":
58
+ if isinstance(symbol, str):
59
+ symbol = Symbol(symbol)
60
+ # symbol = _sympify(symbol)
61
+ shape = Tuple(*map(_sympify, shape))
62
+ obj = Expr.__new__(cls, symbol, shape)
63
+ return obj
64
+
65
+ @property
66
+ def name(self):
67
+ return self._args[0]
68
+
69
+ @property
70
+ def shape(self):
71
+ return self._args[1]
72
+
73
+ def as_explicit(self):
74
+ if not all(i.is_Integer for i in self.shape):
75
+ raise ValueError("cannot express explicit array with symbolic shape")
76
+ data = [self[i] for i in itertools.product(*[range(j) for j in self.shape])]
77
+ return ImmutableDenseNDimArray(data).reshape(*self.shape)
78
+
79
+
80
+ class ArrayElement(Expr):
81
+ """
82
+ An element of an array.
83
+ """
84
+
85
+ _diff_wrt = True
86
+ is_symbol = True
87
+ is_commutative = True
88
+
89
+ def __new__(cls, name, indices):
90
+ if isinstance(name, str):
91
+ name = Symbol(name)
92
+ name = _sympify(name)
93
+ if not isinstance(indices, collections.abc.Iterable):
94
+ indices = (indices,)
95
+ indices = _sympify(tuple(indices))
96
+ cls._check_shape(name, indices)
97
+ obj = Expr.__new__(cls, name, indices)
98
+ return obj
99
+
100
+ @classmethod
101
+ def _check_shape(cls, name, indices):
102
+ indices = tuple(indices)
103
+ if hasattr(name, "shape"):
104
+ index_error = IndexError("number of indices does not match shape of the array")
105
+ if len(indices) != len(name.shape):
106
+ raise index_error
107
+ if any((i >= s) == True for i, s in zip(indices, name.shape)):
108
+ raise ValueError("shape is out of bounds")
109
+ if any((i < 0) == True for i in indices):
110
+ raise ValueError("shape contains negative values")
111
+
112
+ @property
113
+ def name(self):
114
+ return self._args[0]
115
+
116
+ @property
117
+ def indices(self):
118
+ return self._args[1]
119
+
120
+ def _eval_derivative(self, s):
121
+ if not isinstance(s, ArrayElement):
122
+ return S.Zero
123
+
124
+ if s == self:
125
+ return S.One
126
+
127
+ if s.name != self.name:
128
+ return S.Zero
129
+
130
+ return Mul.fromiter(KroneckerDelta(i, j) for i, j in zip(self.indices, s.indices))
131
+
132
+
133
+ class ZeroArray(_ArrayExpr):
134
+ """
135
+ Symbolic array of zeros. Equivalent to ``ZeroMatrix`` for matrices.
136
+ """
137
+
138
+ def __new__(cls, *shape):
139
+ if len(shape) == 0:
140
+ return S.Zero
141
+ shape = map(_sympify, shape)
142
+ obj = Expr.__new__(cls, *shape)
143
+ return obj
144
+
145
+ @property
146
+ def shape(self):
147
+ return self._args
148
+
149
+ def as_explicit(self):
150
+ if not all(i.is_Integer for i in self.shape):
151
+ raise ValueError("Cannot return explicit form for symbolic shape.")
152
+ return ImmutableDenseNDimArray.zeros(*self.shape)
153
+
154
+ def _get(self, item):
155
+ return S.Zero
156
+
157
+
158
+ class OneArray(_ArrayExpr):
159
+ """
160
+ Symbolic array of ones.
161
+ """
162
+
163
+ def __new__(cls, *shape):
164
+ if len(shape) == 0:
165
+ return S.One
166
+ shape = map(_sympify, shape)
167
+ obj = Expr.__new__(cls, *shape)
168
+ return obj
169
+
170
+ @property
171
+ def shape(self):
172
+ return self._args
173
+
174
+ def as_explicit(self):
175
+ if not all(i.is_Integer for i in self.shape):
176
+ raise ValueError("Cannot return explicit form for symbolic shape.")
177
+ return ImmutableDenseNDimArray([S.One for i in range(reduce(operator.mul, self.shape))]).reshape(*self.shape)
178
+
179
+ def _get(self, item):
180
+ return S.One
181
+
182
+
183
+ class _CodegenArrayAbstract(Basic):
184
+
185
+ @property
186
+ def subranks(self):
187
+ """
188
+ Returns the ranks of the objects in the uppermost tensor product inside
189
+ the current object. In case no tensor products are contained, return
190
+ the atomic ranks.
191
+
192
+ Examples
193
+ ========
194
+
195
+ >>> from sympy.tensor.array import tensorproduct, tensorcontraction
196
+ >>> from sympy import MatrixSymbol
197
+ >>> M = MatrixSymbol("M", 3, 3)
198
+ >>> N = MatrixSymbol("N", 3, 3)
199
+ >>> P = MatrixSymbol("P", 3, 3)
200
+
201
+ Important: do not confuse the rank of the matrix with the rank of an array.
202
+
203
+ >>> tp = tensorproduct(M, N, P)
204
+ >>> tp.subranks
205
+ [2, 2, 2]
206
+
207
+ >>> co = tensorcontraction(tp, (1, 2), (3, 4))
208
+ >>> co.subranks
209
+ [2, 2, 2]
210
+ """
211
+ return self._subranks[:]
212
+
213
+ def subrank(self):
214
+ """
215
+ The sum of ``subranks``.
216
+ """
217
+ return sum(self.subranks)
218
+
219
+ @property
220
+ def shape(self):
221
+ return self._shape
222
+
223
+ def doit(self, **hints):
224
+ deep = hints.get("deep", True)
225
+ if deep:
226
+ return self.func(*[arg.doit(**hints) for arg in self.args])._canonicalize()
227
+ else:
228
+ return self._canonicalize()
229
+
230
+ class ArrayTensorProduct(_CodegenArrayAbstract):
231
+ r"""
232
+ Class to represent the tensor product of array-like objects.
233
+ """
234
+
235
+ def __new__(cls, *args, **kwargs):
236
+ args = [_sympify(arg) for arg in args]
237
+
238
+ canonicalize = kwargs.pop("canonicalize", False)
239
+
240
+ ranks = [get_rank(arg) for arg in args]
241
+
242
+ obj = Basic.__new__(cls, *args)
243
+ obj._subranks = ranks
244
+ shapes = [get_shape(i) for i in args]
245
+
246
+ if any(i is None for i in shapes):
247
+ obj._shape = None
248
+ else:
249
+ obj._shape = tuple(j for i in shapes for j in i)
250
+ if canonicalize:
251
+ return obj._canonicalize()
252
+ return obj
253
+
254
+ def _canonicalize(self):
255
+ args = self.args
256
+ args = self._flatten(args)
257
+
258
+ ranks = [get_rank(arg) for arg in args]
259
+
260
+ # Check if there are nested permutation and lift them up:
261
+ permutation_cycles = []
262
+ for i, arg in enumerate(args):
263
+ if not isinstance(arg, PermuteDims):
264
+ continue
265
+ permutation_cycles.extend([[k + sum(ranks[:i]) for k in j] for j in arg.permutation.cyclic_form])
266
+ args[i] = arg.expr
267
+ if permutation_cycles:
268
+ return _permute_dims(_array_tensor_product(*args), Permutation(sum(ranks)-1)*Permutation(permutation_cycles))
269
+
270
+ if len(args) == 1:
271
+ return args[0]
272
+
273
+ # If any object is a ZeroArray, return a ZeroArray:
274
+ if any(isinstance(arg, (ZeroArray, ZeroMatrix)) for arg in args):
275
+ shapes = reduce(operator.add, [get_shape(i) for i in args], ())
276
+ return ZeroArray(*shapes)
277
+
278
+ # If there are contraction objects inside, transform the whole
279
+ # expression into `ArrayContraction`:
280
+ contractions = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayContraction)}
281
+ if contractions:
282
+ ranks = [_get_subrank(arg) if isinstance(arg, ArrayContraction) else get_rank(arg) for arg in args]
283
+ cumulative_ranks = list(accumulate([0] + ranks))[:-1]
284
+ tp = _array_tensor_product(*[arg.expr if isinstance(arg, ArrayContraction) else arg for arg in args])
285
+ contraction_indices = [tuple(cumulative_ranks[i] + k for k in j) for i, arg in contractions.items() for j in arg.contraction_indices]
286
+ return _array_contraction(tp, *contraction_indices)
287
+
288
+ diagonals = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayDiagonal)}
289
+ if diagonals:
290
+ inverse_permutation = []
291
+ last_perm = []
292
+ ranks = [get_rank(arg) for arg in args]
293
+ cumulative_ranks = list(accumulate([0] + ranks))[:-1]
294
+ for i, arg in enumerate(args):
295
+ if isinstance(arg, ArrayDiagonal):
296
+ i1 = get_rank(arg) - len(arg.diagonal_indices)
297
+ i2 = len(arg.diagonal_indices)
298
+ inverse_permutation.extend([cumulative_ranks[i] + j for j in range(i1)])
299
+ last_perm.extend([cumulative_ranks[i] + j for j in range(i1, i1 + i2)])
300
+ else:
301
+ inverse_permutation.extend([cumulative_ranks[i] + j for j in range(get_rank(arg))])
302
+ inverse_permutation.extend(last_perm)
303
+ tp = _array_tensor_product(*[arg.expr if isinstance(arg, ArrayDiagonal) else arg for arg in args])
304
+ ranks2 = [_get_subrank(arg) if isinstance(arg, ArrayDiagonal) else get_rank(arg) for arg in args]
305
+ cumulative_ranks2 = list(accumulate([0] + ranks2))[:-1]
306
+ diagonal_indices = [tuple(cumulative_ranks2[i] + k for k in j) for i, arg in diagonals.items() for j in arg.diagonal_indices]
307
+ return _permute_dims(_array_diagonal(tp, *diagonal_indices), _af_invert(inverse_permutation))
308
+
309
+ return self.func(*args, canonicalize=False)
310
+
311
+ @classmethod
312
+ def _flatten(cls, args):
313
+ args = [i for arg in args for i in (arg.args if isinstance(arg, cls) else [arg])]
314
+ return args
315
+
316
+ def as_explicit(self):
317
+ return tensorproduct(*[arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args])
318
+
319
+
320
+ class ArrayAdd(_CodegenArrayAbstract):
321
+ r"""
322
+ Class for elementwise array additions.
323
+ """
324
+
325
+ def __new__(cls, *args, **kwargs):
326
+ args = [_sympify(arg) for arg in args]
327
+ ranks = [get_rank(arg) for arg in args]
328
+ ranks = list(set(ranks))
329
+ if len(ranks) != 1:
330
+ raise ValueError("summing arrays of different ranks")
331
+ shapes = [arg.shape for arg in args]
332
+ if len({i for i in shapes if i is not None}) > 1:
333
+ raise ValueError("mismatching shapes in addition")
334
+
335
+ canonicalize = kwargs.pop("canonicalize", False)
336
+
337
+ obj = Basic.__new__(cls, *args)
338
+ obj._subranks = ranks
339
+ if any(i is None for i in shapes):
340
+ obj._shape = None
341
+ else:
342
+ obj._shape = shapes[0]
343
+ if canonicalize:
344
+ return obj._canonicalize()
345
+ return obj
346
+
347
+ def _canonicalize(self):
348
+ args = self.args
349
+
350
+ # Flatten:
351
+ args = self._flatten_args(args)
352
+
353
+ shapes = [get_shape(arg) for arg in args]
354
+ args = [arg for arg in args if not isinstance(arg, (ZeroArray, ZeroMatrix))]
355
+ if len(args) == 0:
356
+ if any(i for i in shapes if i is None):
357
+ raise NotImplementedError("cannot handle addition of ZeroMatrix/ZeroArray and undefined shape object")
358
+ return ZeroArray(*shapes[0])
359
+ elif len(args) == 1:
360
+ return args[0]
361
+ return self.func(*args, canonicalize=False)
362
+
363
+ @classmethod
364
+ def _flatten_args(cls, args):
365
+ new_args = []
366
+ for arg in args:
367
+ if isinstance(arg, ArrayAdd):
368
+ new_args.extend(arg.args)
369
+ else:
370
+ new_args.append(arg)
371
+ return new_args
372
+
373
+ def as_explicit(self):
374
+ return reduce(
375
+ operator.add,
376
+ [arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args])
377
+
378
+
379
+ class PermuteDims(_CodegenArrayAbstract):
380
+ r"""
381
+ Class to represent permutation of axes of arrays.
382
+
383
+ Examples
384
+ ========
385
+
386
+ >>> from sympy.tensor.array import permutedims
387
+ >>> from sympy import MatrixSymbol
388
+ >>> M = MatrixSymbol("M", 3, 3)
389
+ >>> cg = permutedims(M, [1, 0])
390
+
391
+ The object ``cg`` represents the transposition of ``M``, as the permutation
392
+ ``[1, 0]`` will act on its indices by switching them:
393
+
394
+ `M_{ij} \Rightarrow M_{ji}`
395
+
396
+ This is evident when transforming back to matrix form:
397
+
398
+ >>> from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
399
+ >>> convert_array_to_matrix(cg)
400
+ M.T
401
+
402
+ >>> N = MatrixSymbol("N", 3, 2)
403
+ >>> cg = permutedims(N, [1, 0])
404
+ >>> cg.shape
405
+ (2, 3)
406
+
407
+ There are optional parameters that can be used as alternative to the permutation:
408
+
409
+ >>> from sympy.tensor.array.expressions import ArraySymbol, PermuteDims
410
+ >>> M = ArraySymbol("M", (1, 2, 3, 4, 5))
411
+ >>> expr = PermuteDims(M, index_order_old="ijklm", index_order_new="kijml")
412
+ >>> expr
413
+ PermuteDims(M, (0 2 1)(3 4))
414
+ >>> expr.shape
415
+ (3, 1, 2, 5, 4)
416
+
417
+ Permutations of tensor products are simplified in order to achieve a
418
+ standard form:
419
+
420
+ >>> from sympy.tensor.array import tensorproduct
421
+ >>> M = MatrixSymbol("M", 4, 5)
422
+ >>> tp = tensorproduct(M, N)
423
+ >>> tp.shape
424
+ (4, 5, 3, 2)
425
+ >>> perm1 = permutedims(tp, [2, 3, 1, 0])
426
+
427
+ The args ``(M, N)`` have been sorted and the permutation has been
428
+ simplified, the expression is equivalent:
429
+
430
+ >>> perm1.expr.args
431
+ (N, M)
432
+ >>> perm1.shape
433
+ (3, 2, 5, 4)
434
+ >>> perm1.permutation
435
+ (2 3)
436
+
437
+ The permutation in its array form has been simplified from
438
+ ``[2, 3, 1, 0]`` to ``[0, 1, 3, 2]``, as the arguments of the tensor
439
+ product `M` and `N` have been switched:
440
+
441
+ >>> perm1.permutation.array_form
442
+ [0, 1, 3, 2]
443
+
444
+ We can nest a second permutation:
445
+
446
+ >>> perm2 = permutedims(perm1, [1, 0, 2, 3])
447
+ >>> perm2.shape
448
+ (2, 3, 5, 4)
449
+ >>> perm2.permutation.array_form
450
+ [1, 0, 3, 2]
451
+ """
452
+
453
+ def __new__(cls, expr, permutation=None, index_order_old=None, index_order_new=None, **kwargs):
454
+ from sympy.combinatorics import Permutation
455
+ expr = _sympify(expr)
456
+ expr_rank = get_rank(expr)
457
+ permutation = cls._get_permutation_from_arguments(permutation, index_order_old, index_order_new, expr_rank)
458
+ permutation = Permutation(permutation)
459
+ permutation_size = permutation.size
460
+ if permutation_size != expr_rank:
461
+ raise ValueError("Permutation size must be the length of the shape of expr")
462
+
463
+ canonicalize = kwargs.pop("canonicalize", False)
464
+
465
+ obj = Basic.__new__(cls, expr, permutation)
466
+ obj._subranks = [get_rank(expr)]
467
+ shape = get_shape(expr)
468
+ if shape is None:
469
+ obj._shape = None
470
+ else:
471
+ obj._shape = tuple(shape[permutation(i)] for i in range(len(shape)))
472
+ if canonicalize:
473
+ return obj._canonicalize()
474
+ return obj
475
+
476
+ def _canonicalize(self):
477
+ expr = self.expr
478
+ permutation = self.permutation
479
+ if isinstance(expr, PermuteDims):
480
+ subexpr = expr.expr
481
+ subperm = expr.permutation
482
+ permutation = permutation * subperm
483
+ expr = subexpr
484
+ if isinstance(expr, ArrayContraction):
485
+ expr, permutation = self._PermuteDims_denestarg_ArrayContraction(expr, permutation)
486
+ if isinstance(expr, ArrayTensorProduct):
487
+ expr, permutation = self._PermuteDims_denestarg_ArrayTensorProduct(expr, permutation)
488
+ if isinstance(expr, (ZeroArray, ZeroMatrix)):
489
+ return ZeroArray(*[expr.shape[i] for i in permutation.array_form])
490
+ plist = permutation.array_form
491
+ if plist == sorted(plist):
492
+ return expr
493
+ return self.func(expr, permutation, canonicalize=False)
494
+
495
+ @property
496
+ def expr(self):
497
+ return self.args[0]
498
+
499
+ @property
500
+ def permutation(self):
501
+ return self.args[1]
502
+
503
+ @classmethod
504
+ def _PermuteDims_denestarg_ArrayTensorProduct(cls, expr, permutation):
505
+ # Get the permutation in its image-form:
506
+ perm_image_form = _af_invert(permutation.array_form)
507
+ args = list(expr.args)
508
+ # Starting index global position for every arg:
509
+ cumul = list(accumulate([0] + expr.subranks))
510
+ # Split `perm_image_form` into a list of list corresponding to the indices
511
+ # of every argument:
512
+ perm_image_form_in_components = [perm_image_form[cumul[i]:cumul[i+1]] for i in range(len(args))]
513
+ # Create an index, target-position-key array:
514
+ ps = [(i, sorted(comp)) for i, comp in enumerate(perm_image_form_in_components)]
515
+ # Sort the array according to the target-position-key:
516
+ # In this way, we define a canonical way to sort the arguments according
517
+ # to the permutation.
518
+ ps.sort(key=lambda x: x[1])
519
+ # Read the inverse-permutation (i.e. image-form) of the args:
520
+ perm_args_image_form = [i[0] for i in ps]
521
+ # Apply the args-permutation to the `args`:
522
+ args_sorted = [args[i] for i in perm_args_image_form]
523
+ # Apply the args-permutation to the array-form of the permutation of the axes (of `expr`):
524
+ perm_image_form_sorted_args = [perm_image_form_in_components[i] for i in perm_args_image_form]
525
+ new_permutation = Permutation(_af_invert([j for i in perm_image_form_sorted_args for j in i]))
526
+ return _array_tensor_product(*args_sorted), new_permutation
527
+
528
+ @classmethod
529
+ def _PermuteDims_denestarg_ArrayContraction(cls, expr, permutation):
530
+ if not isinstance(expr, ArrayContraction):
531
+ return expr, permutation
532
+ if not isinstance(expr.expr, ArrayTensorProduct):
533
+ return expr, permutation
534
+ args = expr.expr.args
535
+ subranks = [get_rank(arg) for arg in expr.expr.args]
536
+
537
+ contraction_indices = expr.contraction_indices
538
+ contraction_indices_flat = [j for i in contraction_indices for j in i]
539
+ cumul = list(accumulate([0] + subranks))
540
+
541
+ # Spread the permutation in its array form across the args in the corresponding
542
+ # tensor-product arguments with free indices:
543
+ permutation_array_blocks_up = []
544
+ image_form = _af_invert(permutation.array_form)
545
+ counter = 0
546
+ for i, e in enumerate(subranks):
547
+ current = []
548
+ for j in range(cumul[i], cumul[i+1]):
549
+ if j in contraction_indices_flat:
550
+ continue
551
+ current.append(image_form[counter])
552
+ counter += 1
553
+ permutation_array_blocks_up.append(current)
554
+
555
+ # Get the map of axis repositioning for every argument of tensor-product:
556
+ index_blocks = [list(range(cumul[i], cumul[i+1])) for i, e in enumerate(expr.subranks)]
557
+ index_blocks_up = expr._push_indices_up(expr.contraction_indices, index_blocks)
558
+ inverse_permutation = permutation**(-1)
559
+ index_blocks_up_permuted = [[inverse_permutation(j) for j in i if j is not None] for i in index_blocks_up]
560
+
561
+ # Sorting key is a list of tuple, first element is the index of `args`, second element of
562
+ # the tuple is the sorting key to sort `args` of the tensor product:
563
+ sorting_keys = list(enumerate(index_blocks_up_permuted))
564
+ sorting_keys.sort(key=lambda x: x[1])
565
+
566
+ # Now we can get the permutation acting on the args in its image-form:
567
+ new_perm_image_form = [i[0] for i in sorting_keys]
568
+ # Apply the args-level permutation to various elements:
569
+ new_index_blocks = [index_blocks[i] for i in new_perm_image_form]
570
+ new_index_perm_array_form = _af_invert([j for i in new_index_blocks for j in i])
571
+ new_args = [args[i] for i in new_perm_image_form]
572
+ new_contraction_indices = [tuple(new_index_perm_array_form[j] for j in i) for i in contraction_indices]
573
+ new_expr = _array_contraction(_array_tensor_product(*new_args), *new_contraction_indices)
574
+ new_permutation = Permutation(_af_invert([j for i in [permutation_array_blocks_up[k] for k in new_perm_image_form] for j in i]))
575
+ return new_expr, new_permutation
576
+
577
+ @classmethod
578
+ def _check_permutation_mapping(cls, expr, permutation):
579
+ subranks = expr.subranks
580
+ index2arg = [i for i, arg in enumerate(expr.args) for j in range(expr.subranks[i])]
581
+ permuted_indices = [permutation(i) for i in range(expr.subrank())]
582
+ new_args = list(expr.args)
583
+ arg_candidate_index = index2arg[permuted_indices[0]]
584
+ current_indices = []
585
+ new_permutation = []
586
+ inserted_arg_cand_indices = set()
587
+ for i, idx in enumerate(permuted_indices):
588
+ if index2arg[idx] != arg_candidate_index:
589
+ new_permutation.extend(current_indices)
590
+ current_indices = []
591
+ arg_candidate_index = index2arg[idx]
592
+ current_indices.append(idx)
593
+ arg_candidate_rank = subranks[arg_candidate_index]
594
+ if len(current_indices) == arg_candidate_rank:
595
+ new_permutation.extend(sorted(current_indices))
596
+ local_current_indices = [j - min(current_indices) for j in current_indices]
597
+ i1 = index2arg[i]
598
+ new_args[i1] = _permute_dims(new_args[i1], Permutation(local_current_indices))
599
+ inserted_arg_cand_indices.add(arg_candidate_index)
600
+ current_indices = []
601
+ new_permutation.extend(current_indices)
602
+
603
+ # TODO: swap args positions in order to simplify the expression:
604
+ # TODO: this should be in a function
605
+ args_positions = list(range(len(new_args)))
606
+ # Get possible shifts:
607
+ maps = {}
608
+ cumulative_subranks = [0] + list(accumulate(subranks))
609
+ for i in range(len(subranks)):
610
+ s = {index2arg[new_permutation[j]] for j in range(cumulative_subranks[i], cumulative_subranks[i+1])}
611
+ if len(s) != 1:
612
+ continue
613
+ elem = next(iter(s))
614
+ if i != elem:
615
+ maps[i] = elem
616
+
617
+ # Find cycles in the map:
618
+ lines = []
619
+ current_line = []
620
+ while maps:
621
+ if len(current_line) == 0:
622
+ k, v = maps.popitem()
623
+ current_line.append(k)
624
+ else:
625
+ k = current_line[-1]
626
+ if k not in maps:
627
+ current_line = []
628
+ continue
629
+ v = maps.pop(k)
630
+ if v in current_line:
631
+ lines.append(current_line)
632
+ current_line = []
633
+ continue
634
+ current_line.append(v)
635
+ for line in lines:
636
+ for i, e in enumerate(line):
637
+ args_positions[line[(i + 1) % len(line)]] = e
638
+
639
+ # TODO: function in order to permute the args:
640
+ permutation_blocks = [[new_permutation[cumulative_subranks[i] + j] for j in range(e)] for i, e in enumerate(subranks)]
641
+ new_args = [new_args[i] for i in args_positions]
642
+ new_permutation_blocks = [permutation_blocks[i] for i in args_positions]
643
+ new_permutation2 = [j for i in new_permutation_blocks for j in i]
644
+ return _array_tensor_product(*new_args), Permutation(new_permutation2) # **(-1)
645
+
646
+ @classmethod
647
+ def _check_if_there_are_closed_cycles(cls, expr, permutation):
648
+ args = list(expr.args)
649
+ subranks = expr.subranks
650
+ cyclic_form = permutation.cyclic_form
651
+ cumulative_subranks = [0] + list(accumulate(subranks))
652
+ cyclic_min = [min(i) for i in cyclic_form]
653
+ cyclic_max = [max(i) for i in cyclic_form]
654
+ cyclic_keep = []
655
+ for i, cycle in enumerate(cyclic_form):
656
+ flag = True
657
+ for j in range(len(cumulative_subranks) - 1):
658
+ if cyclic_min[i] >= cumulative_subranks[j] and cyclic_max[i] < cumulative_subranks[j+1]:
659
+ # Found a sinkable cycle.
660
+ args[j] = _permute_dims(args[j], Permutation([[k - cumulative_subranks[j] for k in cyclic_form[i]]]))
661
+ flag = False
662
+ break
663
+ if flag:
664
+ cyclic_keep.append(cyclic_form[i])
665
+ return _array_tensor_product(*args), Permutation(cyclic_keep, size=permutation.size)
666
+
667
+ def nest_permutation(self):
668
+ r"""
669
+ DEPRECATED.
670
+ """
671
+ ret = self._nest_permutation(self.expr, self.permutation)
672
+ if ret is None:
673
+ return self
674
+ return ret
675
+
676
+ @classmethod
677
+ def _nest_permutation(cls, expr, permutation):
678
+ if isinstance(expr, ArrayTensorProduct):
679
+ return _permute_dims(*cls._check_if_there_are_closed_cycles(expr, permutation))
680
+ elif isinstance(expr, ArrayContraction):
681
+ # Invert tree hierarchy: put the contraction above.
682
+ cycles = permutation.cyclic_form
683
+ newcycles = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *cycles)
684
+ newpermutation = Permutation(newcycles)
685
+ new_contr_indices = [tuple(newpermutation(j) for j in i) for i in expr.contraction_indices]
686
+ return _array_contraction(PermuteDims(expr.expr, newpermutation), *new_contr_indices)
687
+ elif isinstance(expr, ArrayAdd):
688
+ return _array_add(*[PermuteDims(arg, permutation) for arg in expr.args])
689
+ return None
690
+
691
+ def as_explicit(self):
692
+ expr = self.expr
693
+ if hasattr(expr, "as_explicit"):
694
+ expr = expr.as_explicit()
695
+ return permutedims(expr, self.permutation)
696
+
697
+ @classmethod
698
+ def _get_permutation_from_arguments(cls, permutation, index_order_old, index_order_new, dim):
699
+ if permutation is None:
700
+ if index_order_new is None or index_order_old is None:
701
+ raise ValueError("Permutation not defined")
702
+ return PermuteDims._get_permutation_from_index_orders(index_order_old, index_order_new, dim)
703
+ else:
704
+ if index_order_new is not None:
705
+ raise ValueError("index_order_new cannot be defined with permutation")
706
+ if index_order_old is not None:
707
+ raise ValueError("index_order_old cannot be defined with permutation")
708
+ return permutation
709
+
710
+ @classmethod
711
+ def _get_permutation_from_index_orders(cls, index_order_old, index_order_new, dim):
712
+ if len(set(index_order_new)) != dim:
713
+ raise ValueError("wrong number of indices in index_order_new")
714
+ if len(set(index_order_old)) != dim:
715
+ raise ValueError("wrong number of indices in index_order_old")
716
+ if len(set.symmetric_difference(set(index_order_new), set(index_order_old))) > 0:
717
+ raise ValueError("index_order_new and index_order_old must have the same indices")
718
+ permutation = [index_order_old.index(i) for i in index_order_new]
719
+ return permutation
720
+
721
+
722
+ class ArrayDiagonal(_CodegenArrayAbstract):
723
+ r"""
724
+ Class to represent the diagonal operator.
725
+
726
+ Explanation
727
+ ===========
728
+
729
+ In a 2-dimensional array it returns the diagonal, this looks like the
730
+ operation:
731
+
732
+ `A_{ij} \rightarrow A_{ii}`
733
+
734
+ The diagonal over axes 1 and 2 (the second and third) of the tensor product
735
+ of two 2-dimensional arrays `A \otimes B` is
736
+
737
+ `\Big[ A_{ab} B_{cd} \Big]_{abcd} \rightarrow \Big[ A_{ai} B_{id} \Big]_{adi}`
738
+
739
+ In this last example the array expression has been reduced from
740
+ 4-dimensional to 3-dimensional. Notice that no contraction has occurred,
741
+ rather there is a new index `i` for the diagonal, contraction would have
742
+ reduced the array to 2 dimensions.
743
+
744
+ Notice that the diagonalized out dimensions are added as new dimensions at
745
+ the end of the indices.
746
+ """
747
+
748
+ def __new__(cls, expr, *diagonal_indices, **kwargs):
749
+ expr = _sympify(expr)
750
+ diagonal_indices = [Tuple(*sorted(i)) for i in diagonal_indices]
751
+ canonicalize = kwargs.get("canonicalize", False)
752
+
753
+ shape = get_shape(expr)
754
+ if shape is not None:
755
+ cls._validate(expr, *diagonal_indices, **kwargs)
756
+ # Get new shape:
757
+ positions, shape = cls._get_positions_shape(shape, diagonal_indices)
758
+ else:
759
+ positions = None
760
+ if len(diagonal_indices) == 0:
761
+ return expr
762
+ obj = Basic.__new__(cls, expr, *diagonal_indices)
763
+ obj._positions = positions
764
+ obj._subranks = _get_subranks(expr)
765
+ obj._shape = shape
766
+ if canonicalize:
767
+ return obj._canonicalize()
768
+ return obj
769
+
770
+ def _canonicalize(self):
771
+ expr = self.expr
772
+ diagonal_indices = self.diagonal_indices
773
+ trivial_diags = [i for i in diagonal_indices if len(i) == 1]
774
+ if len(trivial_diags) > 0:
775
+ trivial_pos = {e[0]: i for i, e in enumerate(diagonal_indices) if len(e) == 1}
776
+ diag_pos = {e: i for i, e in enumerate(diagonal_indices) if len(e) > 1}
777
+ diagonal_indices_short = [i for i in diagonal_indices if len(i) > 1]
778
+ rank1 = get_rank(self)
779
+ rank2 = len(diagonal_indices)
780
+ rank3 = rank1 - rank2
781
+ inv_permutation = []
782
+ counter1 = 0
783
+ indices_down = ArrayDiagonal._push_indices_down(diagonal_indices_short, list(range(rank1)), get_rank(expr))
784
+ for i in indices_down:
785
+ if i in trivial_pos:
786
+ inv_permutation.append(rank3 + trivial_pos[i])
787
+ elif isinstance(i, (Integer, int)):
788
+ inv_permutation.append(counter1)
789
+ counter1 += 1
790
+ else:
791
+ inv_permutation.append(rank3 + diag_pos[i])
792
+ permutation = _af_invert(inv_permutation)
793
+ if len(diagonal_indices_short) > 0:
794
+ return _permute_dims(_array_diagonal(expr, *diagonal_indices_short), permutation)
795
+ else:
796
+ return _permute_dims(expr, permutation)
797
+ if isinstance(expr, ArrayAdd):
798
+ return self._ArrayDiagonal_denest_ArrayAdd(expr, *diagonal_indices)
799
+ if isinstance(expr, ArrayDiagonal):
800
+ return self._ArrayDiagonal_denest_ArrayDiagonal(expr, *diagonal_indices)
801
+ if isinstance(expr, PermuteDims):
802
+ return self._ArrayDiagonal_denest_PermuteDims(expr, *diagonal_indices)
803
+ if isinstance(expr, (ZeroArray, ZeroMatrix)):
804
+ positions, shape = self._get_positions_shape(expr.shape, diagonal_indices)
805
+ return ZeroArray(*shape)
806
+ return self.func(expr, *diagonal_indices, canonicalize=False)
807
+
808
+ @staticmethod
809
+ def _validate(expr, *diagonal_indices, **kwargs):
810
+ # Check that no diagonalization happens on indices with mismatched
811
+ # dimensions:
812
+ shape = get_shape(expr)
813
+ for i in diagonal_indices:
814
+ if any(j >= len(shape) for j in i):
815
+ raise ValueError("index is larger than expression shape")
816
+ if len({shape[j] for j in i}) != 1:
817
+ raise ValueError("diagonalizing indices of different dimensions")
818
+ if not kwargs.get("allow_trivial_diags", False) and len(i) <= 1:
819
+ raise ValueError("need at least two axes to diagonalize")
820
+ if len(set(i)) != len(i):
821
+ raise ValueError("axis index cannot be repeated")
822
+
823
+ @staticmethod
824
+ def _remove_trivial_dimensions(shape, *diagonal_indices):
825
+ return [tuple(j for j in i) for i in diagonal_indices if shape[i[0]] != 1]
826
+
827
+ @property
828
+ def expr(self):
829
+ return self.args[0]
830
+
831
+ @property
832
+ def diagonal_indices(self):
833
+ return self.args[1:]
834
+
835
+ @staticmethod
836
+ def _flatten(expr, *outer_diagonal_indices):
837
+ inner_diagonal_indices = expr.diagonal_indices
838
+ all_inner = [j for i in inner_diagonal_indices for j in i]
839
+ all_inner.sort()
840
+ # TODO: add API for total rank and cumulative rank:
841
+ total_rank = _get_subrank(expr)
842
+ inner_rank = len(all_inner)
843
+ outer_rank = total_rank - inner_rank
844
+ shifts = [0 for i in range(outer_rank)]
845
+ counter = 0
846
+ pointer = 0
847
+ for i in range(outer_rank):
848
+ while pointer < inner_rank and counter >= all_inner[pointer]:
849
+ counter += 1
850
+ pointer += 1
851
+ shifts[i] += pointer
852
+ counter += 1
853
+ outer_diagonal_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_diagonal_indices)
854
+ diagonal_indices = inner_diagonal_indices + outer_diagonal_indices
855
+ return _array_diagonal(expr.expr, *diagonal_indices)
856
+
857
+ @classmethod
858
+ def _ArrayDiagonal_denest_ArrayAdd(cls, expr, *diagonal_indices):
859
+ return _array_add(*[_array_diagonal(arg, *diagonal_indices) for arg in expr.args])
860
+
861
+ @classmethod
862
+ def _ArrayDiagonal_denest_ArrayDiagonal(cls, expr, *diagonal_indices):
863
+ return cls._flatten(expr, *diagonal_indices)
864
+
865
+ @classmethod
866
+ def _ArrayDiagonal_denest_PermuteDims(cls, expr: PermuteDims, *diagonal_indices):
867
+ back_diagonal_indices = [[expr.permutation(j) for j in i] for i in diagonal_indices]
868
+ nondiag = [i for i in range(get_rank(expr)) if not any(i in j for j in diagonal_indices)]
869
+ back_nondiag = [expr.permutation(i) for i in nondiag]
870
+ remap = {e: i for i, e in enumerate(sorted(back_nondiag))}
871
+ new_permutation1 = [remap[i] for i in back_nondiag]
872
+ shift = len(new_permutation1)
873
+ diag_block_perm = [i + shift for i in range(len(back_diagonal_indices))]
874
+ new_permutation = new_permutation1 + diag_block_perm
875
+ return _permute_dims(
876
+ _array_diagonal(
877
+ expr.expr,
878
+ *back_diagonal_indices
879
+ ),
880
+ new_permutation
881
+ )
882
+
883
+ def _push_indices_down_nonstatic(self, indices):
884
+ transform = lambda x: self._positions[x] if x < len(self._positions) else None
885
+ return _apply_recursively_over_nested_lists(transform, indices)
886
+
887
+ def _push_indices_up_nonstatic(self, indices):
888
+
889
+ def transform(x):
890
+ for i, e in enumerate(self._positions):
891
+ if (isinstance(e, int) and x == e) or (isinstance(e, tuple) and x in e):
892
+ return i
893
+
894
+ return _apply_recursively_over_nested_lists(transform, indices)
895
+
896
+ @classmethod
897
+ def _push_indices_down(cls, diagonal_indices, indices, rank):
898
+ positions, shape = cls._get_positions_shape(range(rank), diagonal_indices)
899
+ transform = lambda x: positions[x] if x < len(positions) else None
900
+ return _apply_recursively_over_nested_lists(transform, indices)
901
+
902
+ @classmethod
903
+ def _push_indices_up(cls, diagonal_indices, indices, rank):
904
+ positions, shape = cls._get_positions_shape(range(rank), diagonal_indices)
905
+
906
+ def transform(x):
907
+ for i, e in enumerate(positions):
908
+ if (isinstance(e, int) and x == e) or (isinstance(e, (tuple, Tuple)) and (x in e)):
909
+ return i
910
+
911
+ return _apply_recursively_over_nested_lists(transform, indices)
912
+
913
+ @classmethod
914
+ def _get_positions_shape(cls, shape, diagonal_indices):
915
+ data1 = tuple((i, shp) for i, shp in enumerate(shape) if not any(i in j for j in diagonal_indices))
916
+ pos1, shp1 = zip(*data1) if data1 else ((), ())
917
+ data2 = tuple((i, shape[i[0]]) for i in diagonal_indices)
918
+ pos2, shp2 = zip(*data2) if data2 else ((), ())
919
+ positions = pos1 + pos2
920
+ shape = shp1 + shp2
921
+ return positions, shape
922
+
923
+ def as_explicit(self):
924
+ expr = self.expr
925
+ if hasattr(expr, "as_explicit"):
926
+ expr = expr.as_explicit()
927
+ return tensordiagonal(expr, *self.diagonal_indices)
928
+
929
+
930
+ class ArrayElementwiseApplyFunc(_CodegenArrayAbstract):
931
+
932
+ def __new__(cls, function, element):
933
+
934
+ if not isinstance(function, Lambda):
935
+ d = Dummy('d')
936
+ function = Lambda(d, function(d))
937
+
938
+ obj = _CodegenArrayAbstract.__new__(cls, function, element)
939
+ obj._subranks = _get_subranks(element)
940
+ return obj
941
+
942
+ @property
943
+ def function(self):
944
+ return self.args[0]
945
+
946
+ @property
947
+ def expr(self):
948
+ return self.args[1]
949
+
950
+ @property
951
+ def shape(self):
952
+ return self.expr.shape
953
+
954
+ def _get_function_fdiff(self):
955
+ d = Dummy("d")
956
+ function = self.function(d)
957
+ fdiff = function.diff(d)
958
+ if isinstance(fdiff, Function):
959
+ fdiff = type(fdiff)
960
+ else:
961
+ fdiff = Lambda(d, fdiff)
962
+ return fdiff
963
+
964
+ def as_explicit(self):
965
+ expr = self.expr
966
+ if hasattr(expr, "as_explicit"):
967
+ expr = expr.as_explicit()
968
+ return expr.applyfunc(self.function)
969
+
970
+
971
+ class ArrayContraction(_CodegenArrayAbstract):
972
+ r"""
973
+ This class is meant to represent contractions of arrays in a form easily
974
+ processable by the code printers.
975
+ """
976
+
977
+ def __new__(cls, expr, *contraction_indices, **kwargs):
978
+ contraction_indices = _sort_contraction_indices(contraction_indices)
979
+ expr = _sympify(expr)
980
+
981
+ canonicalize = kwargs.get("canonicalize", False)
982
+
983
+ obj = Basic.__new__(cls, expr, *contraction_indices)
984
+ obj._subranks = _get_subranks(expr)
985
+ obj._mapping = _get_mapping_from_subranks(obj._subranks)
986
+
987
+ free_indices_to_position = {i: i for i in range(sum(obj._subranks)) if all(i not in cind for cind in contraction_indices)}
988
+ obj._free_indices_to_position = free_indices_to_position
989
+
990
+ shape = get_shape(expr)
991
+ cls._validate(expr, *contraction_indices)
992
+ if shape:
993
+ shape = tuple(shp for i, shp in enumerate(shape) if not any(i in j for j in contraction_indices))
994
+ obj._shape = shape
995
+ if canonicalize:
996
+ return obj._canonicalize()
997
+ return obj
998
+
999
+ def _canonicalize(self):
1000
+ expr = self.expr
1001
+ contraction_indices = self.contraction_indices
1002
+
1003
+ if len(contraction_indices) == 0:
1004
+ return expr
1005
+
1006
+ if isinstance(expr, ArrayContraction):
1007
+ return self._ArrayContraction_denest_ArrayContraction(expr, *contraction_indices)
1008
+
1009
+ if isinstance(expr, (ZeroArray, ZeroMatrix)):
1010
+ return self._ArrayContraction_denest_ZeroArray(expr, *contraction_indices)
1011
+
1012
+ if isinstance(expr, PermuteDims):
1013
+ return self._ArrayContraction_denest_PermuteDims(expr, *contraction_indices)
1014
+
1015
+ if isinstance(expr, ArrayTensorProduct):
1016
+ expr, contraction_indices = self._sort_fully_contracted_args(expr, contraction_indices)
1017
+ expr, contraction_indices = self._lower_contraction_to_addends(expr, contraction_indices)
1018
+ if len(contraction_indices) == 0:
1019
+ return expr
1020
+
1021
+ if isinstance(expr, ArrayDiagonal):
1022
+ return self._ArrayContraction_denest_ArrayDiagonal(expr, *contraction_indices)
1023
+
1024
+ if isinstance(expr, ArrayAdd):
1025
+ return self._ArrayContraction_denest_ArrayAdd(expr, *contraction_indices)
1026
+
1027
+ # Check single index contractions on 1-dimensional axes:
1028
+ contraction_indices = [i for i in contraction_indices if len(i) > 1 or get_shape(expr)[i[0]] != 1]
1029
+ if len(contraction_indices) == 0:
1030
+ return expr
1031
+
1032
+ return self.func(expr, *contraction_indices, canonicalize=False)
1033
+
1034
+ def __mul__(self, other):
1035
+ if other == 1:
1036
+ return self
1037
+ else:
1038
+ raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.")
1039
+
1040
+ def __rmul__(self, other):
1041
+ if other == 1:
1042
+ return self
1043
+ else:
1044
+ raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.")
1045
+
1046
+ @staticmethod
1047
+ def _validate(expr, *contraction_indices):
1048
+ shape = get_shape(expr)
1049
+ if shape is None:
1050
+ return
1051
+
1052
+ # Check that no contraction happens when the shape is mismatched:
1053
+ for i in contraction_indices:
1054
+ if len({shape[j] for j in i if shape[j] != -1}) != 1:
1055
+ raise ValueError("contracting indices of different dimensions")
1056
+
1057
+ @classmethod
1058
+ def _push_indices_down(cls, contraction_indices, indices):
1059
+ flattened_contraction_indices = [j for i in contraction_indices for j in i]
1060
+ flattened_contraction_indices.sort()
1061
+ transform = _build_push_indices_down_func_transformation(flattened_contraction_indices)
1062
+ return _apply_recursively_over_nested_lists(transform, indices)
1063
+
1064
+ @classmethod
1065
+ def _push_indices_up(cls, contraction_indices, indices):
1066
+ flattened_contraction_indices = [j for i in contraction_indices for j in i]
1067
+ flattened_contraction_indices.sort()
1068
+ transform = _build_push_indices_up_func_transformation(flattened_contraction_indices)
1069
+ return _apply_recursively_over_nested_lists(transform, indices)
1070
+
1071
+ @classmethod
1072
+ def _lower_contraction_to_addends(cls, expr, contraction_indices):
1073
+ if isinstance(expr, ArrayAdd):
1074
+ raise NotImplementedError()
1075
+ if not isinstance(expr, ArrayTensorProduct):
1076
+ return expr, contraction_indices
1077
+ subranks = expr.subranks
1078
+ cumranks = list(accumulate([0] + subranks))
1079
+ contraction_indices_remaining = []
1080
+ contraction_indices_args = [[] for i in expr.args]
1081
+ backshift = set()
1082
+ for contraction_group in contraction_indices:
1083
+ for j in range(len(expr.args)):
1084
+ if not isinstance(expr.args[j], ArrayAdd):
1085
+ continue
1086
+ if all(cumranks[j] <= k < cumranks[j+1] for k in contraction_group):
1087
+ contraction_indices_args[j].append([k - cumranks[j] for k in contraction_group])
1088
+ backshift.update(contraction_group)
1089
+ break
1090
+ else:
1091
+ contraction_indices_remaining.append(contraction_group)
1092
+ if len(contraction_indices_remaining) == len(contraction_indices):
1093
+ return expr, contraction_indices
1094
+ total_rank = get_rank(expr)
1095
+ shifts = list(accumulate([1 if i in backshift else 0 for i in range(total_rank)]))
1096
+ contraction_indices_remaining = [Tuple.fromiter(j - shifts[j] for j in i) for i in contraction_indices_remaining]
1097
+ ret = _array_tensor_product(*[
1098
+ _array_contraction(arg, *contr) for arg, contr in zip(expr.args, contraction_indices_args)
1099
+ ])
1100
+ return ret, contraction_indices_remaining
1101
+
1102
+ def split_multiple_contractions(self):
1103
+ """
1104
+ Recognize multiple contractions and attempt at rewriting them as paired-contractions.
1105
+
1106
+ This allows some contractions involving more than two indices to be
1107
+ rewritten as multiple contractions involving two indices, thus allowing
1108
+ the expression to be rewritten as a matrix multiplication line.
1109
+
1110
+ Examples:
1111
+
1112
+ * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C`
1113
+
1114
+ Care for:
1115
+ - matrix being diagonalized (i.e. `A_ii`)
1116
+ - vectors being diagonalized (i.e. `a_i0`)
1117
+
1118
+ Multiple contractions can be split into matrix multiplications if
1119
+ not more than two arguments are non-diagonals or non-vectors.
1120
+ Vectors get diagonalized while diagonal matrices remain diagonal.
1121
+ The non-diagonal matrices can be at the beginning or at the end
1122
+ of the final matrix multiplication line.
1123
+ """
1124
+
1125
+ editor = _EditArrayContraction(self)
1126
+
1127
+ contraction_indices = self.contraction_indices
1128
+
1129
+ onearray_insert = []
1130
+
1131
+ for indl, links in enumerate(contraction_indices):
1132
+ if len(links) <= 2:
1133
+ continue
1134
+
1135
+ # Check multiple contractions:
1136
+ #
1137
+ # Examples:
1138
+ #
1139
+ # * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C \otimes OneArray(1)` with permutation (1 2)
1140
+ #
1141
+ # Care for:
1142
+ # - matrix being diagonalized (i.e. `A_ii`)
1143
+ # - vectors being diagonalized (i.e. `a_i0`)
1144
+
1145
+ # Multiple contractions can be split into matrix multiplications if
1146
+ # not more than three arguments are non-diagonals or non-vectors.
1147
+ #
1148
+ # Vectors get diagonalized while diagonal matrices remain diagonal.
1149
+ # The non-diagonal matrices can be at the beginning or at the end
1150
+ # of the final matrix multiplication line.
1151
+
1152
+ positions = editor.get_mapping_for_index(indl)
1153
+
1154
+ # Also consider the case of diagonal matrices being contracted:
1155
+ current_dimension = self.expr.shape[links[0]]
1156
+
1157
+ not_vectors = []
1158
+ vectors = []
1159
+ for arg_ind, rel_ind in positions:
1160
+ arg = editor.args_with_ind[arg_ind]
1161
+ mat = arg.element
1162
+ abs_arg_start, abs_arg_end = editor.get_absolute_range(arg)
1163
+ other_arg_pos = 1-rel_ind
1164
+ other_arg_abs = abs_arg_start + other_arg_pos
1165
+ if ((1 not in mat.shape) or
1166
+ ((current_dimension == 1) is True and mat.shape != (1, 1)) or
1167
+ any(other_arg_abs in l for li, l in enumerate(contraction_indices) if li != indl)
1168
+ ):
1169
+ not_vectors.append((arg, rel_ind))
1170
+ else:
1171
+ vectors.append((arg, rel_ind))
1172
+ if len(not_vectors) > 2:
1173
+ # If more than two arguments in the multiple contraction are
1174
+ # non-vectors and non-diagonal matrices, we cannot find a way
1175
+ # to split this contraction into a matrix multiplication line:
1176
+ continue
1177
+ # Three cases to handle:
1178
+ # - zero non-vectors
1179
+ # - one non-vector
1180
+ # - two non-vectors
1181
+ for v, rel_ind in vectors:
1182
+ v.element = diagonalize_vector(v.element)
1183
+ vectors_to_loop = not_vectors[:1] + vectors + not_vectors[1:]
1184
+ first_not_vector, rel_ind = vectors_to_loop[0]
1185
+ new_index = first_not_vector.indices[rel_ind]
1186
+
1187
+ for v, rel_ind in vectors_to_loop[1:-1]:
1188
+ v.indices[rel_ind] = new_index
1189
+ new_index = editor.get_new_contraction_index()
1190
+ assert v.indices.index(None) == 1 - rel_ind
1191
+ v.indices[v.indices.index(None)] = new_index
1192
+ onearray_insert.append(v)
1193
+
1194
+ last_vec, rel_ind = vectors_to_loop[-1]
1195
+ last_vec.indices[rel_ind] = new_index
1196
+
1197
+ for v in onearray_insert:
1198
+ editor.insert_after(v, _ArgE(OneArray(1), [None]))
1199
+
1200
+ return editor.to_array_contraction()
1201
+
1202
+ def flatten_contraction_of_diagonal(self):
1203
+ if not isinstance(self.expr, ArrayDiagonal):
1204
+ return self
1205
+ contraction_down = self.expr._push_indices_down(self.expr.diagonal_indices, self.contraction_indices)
1206
+ new_contraction_indices = []
1207
+ diagonal_indices = self.expr.diagonal_indices[:]
1208
+ for i in contraction_down:
1209
+ contraction_group = list(i)
1210
+ for j in i:
1211
+ diagonal_with = [k for k in diagonal_indices if j in k]
1212
+ contraction_group.extend([l for k in diagonal_with for l in k])
1213
+ diagonal_indices = [k for k in diagonal_indices if k not in diagonal_with]
1214
+ new_contraction_indices.append(sorted(set(contraction_group)))
1215
+
1216
+ new_contraction_indices = ArrayDiagonal._push_indices_up(diagonal_indices, new_contraction_indices)
1217
+ return _array_contraction(
1218
+ _array_diagonal(
1219
+ self.expr.expr,
1220
+ *diagonal_indices
1221
+ ),
1222
+ *new_contraction_indices
1223
+ )
1224
+
1225
+ @staticmethod
1226
+ def _get_free_indices_to_position_map(free_indices, contraction_indices):
1227
+ free_indices_to_position = {}
1228
+ flattened_contraction_indices = [j for i in contraction_indices for j in i]
1229
+ counter = 0
1230
+ for ind in free_indices:
1231
+ while counter in flattened_contraction_indices:
1232
+ counter += 1
1233
+ free_indices_to_position[ind] = counter
1234
+ counter += 1
1235
+ return free_indices_to_position
1236
+
1237
+ @staticmethod
1238
+ def _get_index_shifts(expr):
1239
+ """
1240
+ Get the mapping of indices at the positions before the contraction
1241
+ occurs.
1242
+
1243
+ Examples
1244
+ ========
1245
+
1246
+ >>> from sympy.tensor.array import tensorproduct, tensorcontraction
1247
+ >>> from sympy import MatrixSymbol
1248
+ >>> M = MatrixSymbol("M", 3, 3)
1249
+ >>> N = MatrixSymbol("N", 3, 3)
1250
+ >>> cg = tensorcontraction(tensorproduct(M, N), [1, 2])
1251
+ >>> cg._get_index_shifts(cg)
1252
+ [0, 2]
1253
+
1254
+ Indeed, ``cg`` after the contraction has two dimensions, 0 and 1. They
1255
+ need to be shifted by 0 and 2 to get the corresponding positions before
1256
+ the contraction (that is, 0 and 3).
1257
+ """
1258
+ inner_contraction_indices = expr.contraction_indices
1259
+ all_inner = [j for i in inner_contraction_indices for j in i]
1260
+ all_inner.sort()
1261
+ # TODO: add API for total rank and cumulative rank:
1262
+ total_rank = _get_subrank(expr)
1263
+ inner_rank = len(all_inner)
1264
+ outer_rank = total_rank - inner_rank
1265
+ shifts = [0 for i in range(outer_rank)]
1266
+ counter = 0
1267
+ pointer = 0
1268
+ for i in range(outer_rank):
1269
+ while pointer < inner_rank and counter >= all_inner[pointer]:
1270
+ counter += 1
1271
+ pointer += 1
1272
+ shifts[i] += pointer
1273
+ counter += 1
1274
+ return shifts
1275
+
1276
+ @staticmethod
1277
+ def _convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices):
1278
+ shifts = ArrayContraction._get_index_shifts(expr)
1279
+ outer_contraction_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_contraction_indices)
1280
+ return outer_contraction_indices
1281
+
1282
+ @staticmethod
1283
+ def _flatten(expr, *outer_contraction_indices):
1284
+ inner_contraction_indices = expr.contraction_indices
1285
+ outer_contraction_indices = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices)
1286
+ contraction_indices = inner_contraction_indices + outer_contraction_indices
1287
+ return _array_contraction(expr.expr, *contraction_indices)
1288
+
1289
+ @classmethod
1290
+ def _ArrayContraction_denest_ArrayContraction(cls, expr, *contraction_indices):
1291
+ return cls._flatten(expr, *contraction_indices)
1292
+
1293
+ @classmethod
1294
+ def _ArrayContraction_denest_ZeroArray(cls, expr, *contraction_indices):
1295
+ contraction_indices_flat = [j for i in contraction_indices for j in i]
1296
+ shape = [e for i, e in enumerate(expr.shape) if i not in contraction_indices_flat]
1297
+ return ZeroArray(*shape)
1298
+
1299
+ @classmethod
1300
+ def _ArrayContraction_denest_ArrayAdd(cls, expr, *contraction_indices):
1301
+ return _array_add(*[_array_contraction(i, *contraction_indices) for i in expr.args])
1302
+
1303
+ @classmethod
1304
+ def _ArrayContraction_denest_PermuteDims(cls, expr, *contraction_indices):
1305
+ permutation = expr.permutation
1306
+ plist = permutation.array_form
1307
+ new_contraction_indices = [tuple(permutation(j) for j in i) for i in contraction_indices]
1308
+ new_plist = [i for i in plist if not any(i in j for j in new_contraction_indices)]
1309
+ new_plist = cls._push_indices_up(new_contraction_indices, new_plist)
1310
+ return _permute_dims(
1311
+ _array_contraction(expr.expr, *new_contraction_indices),
1312
+ Permutation(new_plist)
1313
+ )
1314
+
1315
+ @classmethod
1316
+ def _ArrayContraction_denest_ArrayDiagonal(cls, expr: 'ArrayDiagonal', *contraction_indices):
1317
+ diagonal_indices = list(expr.diagonal_indices)
1318
+ down_contraction_indices = expr._push_indices_down(expr.diagonal_indices, contraction_indices, get_rank(expr.expr))
1319
+ # Flatten diagonally contracted indices:
1320
+ down_contraction_indices = [[k for j in i for k in (j if isinstance(j, (tuple, Tuple)) else [j])] for i in down_contraction_indices]
1321
+ new_contraction_indices = []
1322
+ for contr_indgrp in down_contraction_indices:
1323
+ ind = contr_indgrp[:]
1324
+ for j, diag_indgrp in enumerate(diagonal_indices):
1325
+ if diag_indgrp is None:
1326
+ continue
1327
+ if any(i in diag_indgrp for i in contr_indgrp):
1328
+ ind.extend(diag_indgrp)
1329
+ diagonal_indices[j] = None
1330
+ new_contraction_indices.append(sorted(set(ind)))
1331
+
1332
+ new_diagonal_indices_down = [i for i in diagonal_indices if i is not None]
1333
+ new_diagonal_indices = ArrayContraction._push_indices_up(new_contraction_indices, new_diagonal_indices_down)
1334
+ return _array_diagonal(
1335
+ _array_contraction(expr.expr, *new_contraction_indices),
1336
+ *new_diagonal_indices
1337
+ )
1338
+
1339
+ @classmethod
1340
+ def _sort_fully_contracted_args(cls, expr, contraction_indices):
1341
+ if expr.shape is None:
1342
+ return expr, contraction_indices
1343
+ cumul = list(accumulate([0] + expr.subranks))
1344
+ index_blocks = [list(range(cumul[i], cumul[i+1])) for i in range(len(expr.args))]
1345
+ contraction_indices_flat = {j for i in contraction_indices for j in i}
1346
+ fully_contracted = [all(j in contraction_indices_flat for j in range(cumul[i], cumul[i+1])) for i, arg in enumerate(expr.args)]
1347
+ new_pos = sorted(range(len(expr.args)), key=lambda x: (0, default_sort_key(expr.args[x])) if fully_contracted[x] else (1,))
1348
+ new_args = [expr.args[i] for i in new_pos]
1349
+ new_index_blocks_flat = [j for i in new_pos for j in index_blocks[i]]
1350
+ index_permutation_array_form = _af_invert(new_index_blocks_flat)
1351
+ new_contraction_indices = [tuple(index_permutation_array_form[j] for j in i) for i in contraction_indices]
1352
+ new_contraction_indices = _sort_contraction_indices(new_contraction_indices)
1353
+ return _array_tensor_product(*new_args), new_contraction_indices
1354
+
1355
+ def _get_contraction_tuples(self):
1356
+ r"""
1357
+ Return tuples containing the argument index and position within the
1358
+ argument of the index position.
1359
+
1360
+ Examples
1361
+ ========
1362
+
1363
+ >>> from sympy import MatrixSymbol
1364
+ >>> from sympy.abc import N
1365
+ >>> from sympy.tensor.array import tensorproduct, tensorcontraction
1366
+ >>> A = MatrixSymbol("A", N, N)
1367
+ >>> B = MatrixSymbol("B", N, N)
1368
+
1369
+ >>> cg = tensorcontraction(tensorproduct(A, B), (1, 2))
1370
+ >>> cg._get_contraction_tuples()
1371
+ [[(0, 1), (1, 0)]]
1372
+
1373
+ Notes
1374
+ =====
1375
+
1376
+ Here the contraction pair `(1, 2)` meaning that the 2nd and 3rd indices
1377
+ of the tensor product `A\otimes B` are contracted, has been transformed
1378
+ into `(0, 1)` and `(1, 0)`, identifying the same indices in a different
1379
+ notation. `(0, 1)` is the second index (1) of the first argument (i.e.
1380
+ 0 or `A`). `(1, 0)` is the first index (i.e. 0) of the second
1381
+ argument (i.e. 1 or `B`).
1382
+ """
1383
+ mapping = self._mapping
1384
+ return [[mapping[j] for j in i] for i in self.contraction_indices]
1385
+
1386
+ @staticmethod
1387
+ def _contraction_tuples_to_contraction_indices(expr, contraction_tuples):
1388
+ # TODO: check that `expr` has `.subranks`:
1389
+ ranks = expr.subranks
1390
+ cumulative_ranks = [0] + list(accumulate(ranks))
1391
+ return [tuple(cumulative_ranks[j]+k for j, k in i) for i in contraction_tuples]
1392
+
1393
+ @property
1394
+ def free_indices(self):
1395
+ return self._free_indices[:]
1396
+
1397
+ @property
1398
+ def free_indices_to_position(self):
1399
+ return dict(self._free_indices_to_position)
1400
+
1401
+ @property
1402
+ def expr(self):
1403
+ return self.args[0]
1404
+
1405
+ @property
1406
+ def contraction_indices(self):
1407
+ return self.args[1:]
1408
+
1409
+ def _contraction_indices_to_components(self):
1410
+ expr = self.expr
1411
+ if not isinstance(expr, ArrayTensorProduct):
1412
+ raise NotImplementedError("only for contractions of tensor products")
1413
+ ranks = expr.subranks
1414
+ mapping = {}
1415
+ counter = 0
1416
+ for i, rank in enumerate(ranks):
1417
+ for j in range(rank):
1418
+ mapping[counter] = (i, j)
1419
+ counter += 1
1420
+ return mapping
1421
+
1422
+ def sort_args_by_name(self):
1423
+ """
1424
+ Sort arguments in the tensor product so that their order is lexicographical.
1425
+
1426
+ Examples
1427
+ ========
1428
+
1429
+ >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
1430
+ >>> from sympy import MatrixSymbol
1431
+ >>> from sympy.abc import N
1432
+ >>> A = MatrixSymbol("A", N, N)
1433
+ >>> B = MatrixSymbol("B", N, N)
1434
+ >>> C = MatrixSymbol("C", N, N)
1435
+ >>> D = MatrixSymbol("D", N, N)
1436
+
1437
+ >>> cg = convert_matrix_to_array(C*D*A*B)
1438
+ >>> cg
1439
+ ArrayContraction(ArrayTensorProduct(A, D, C, B), (0, 3), (1, 6), (2, 5))
1440
+ >>> cg.sort_args_by_name()
1441
+ ArrayContraction(ArrayTensorProduct(A, D, B, C), (0, 3), (1, 4), (2, 7))
1442
+ """
1443
+ expr = self.expr
1444
+ if not isinstance(expr, ArrayTensorProduct):
1445
+ return self
1446
+ args = expr.args
1447
+ sorted_data = sorted(enumerate(args), key=lambda x: default_sort_key(x[1]))
1448
+ pos_sorted, args_sorted = zip(*sorted_data)
1449
+ reordering_map = {i: pos_sorted.index(i) for i, arg in enumerate(args)}
1450
+ contraction_tuples = self._get_contraction_tuples()
1451
+ contraction_tuples = [[(reordering_map[j], k) for j, k in i] for i in contraction_tuples]
1452
+ c_tp = _array_tensor_product(*args_sorted)
1453
+ new_contr_indices = self._contraction_tuples_to_contraction_indices(
1454
+ c_tp,
1455
+ contraction_tuples
1456
+ )
1457
+ return _array_contraction(c_tp, *new_contr_indices)
1458
+
1459
+ def _get_contraction_links(self):
1460
+ r"""
1461
+ Returns a dictionary of links between arguments in the tensor product
1462
+ being contracted.
1463
+
1464
+ See the example for an explanation of the values.
1465
+
1466
+ Examples
1467
+ ========
1468
+
1469
+ >>> from sympy import MatrixSymbol
1470
+ >>> from sympy.abc import N
1471
+ >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
1472
+ >>> A = MatrixSymbol("A", N, N)
1473
+ >>> B = MatrixSymbol("B", N, N)
1474
+ >>> C = MatrixSymbol("C", N, N)
1475
+ >>> D = MatrixSymbol("D", N, N)
1476
+
1477
+ Matrix multiplications are pairwise contractions between neighboring
1478
+ matrices:
1479
+
1480
+ `A_{ij} B_{jk} C_{kl} D_{lm}`
1481
+
1482
+ >>> cg = convert_matrix_to_array(A*B*C*D)
1483
+ >>> cg
1484
+ ArrayContraction(ArrayTensorProduct(B, C, A, D), (0, 5), (1, 2), (3, 6))
1485
+
1486
+ >>> cg._get_contraction_links()
1487
+ {0: {0: (2, 1), 1: (1, 0)}, 1: {0: (0, 1), 1: (3, 0)}, 2: {1: (0, 0)}, 3: {0: (1, 1)}}
1488
+
1489
+ This dictionary is interpreted as follows: argument in position 0 (i.e.
1490
+ matrix `A`) has its second index (i.e. 1) contracted to `(1, 0)`, that
1491
+ is argument in position 1 (matrix `B`) on the first index slot of `B`,
1492
+ this is the contraction provided by the index `j` from `A`.
1493
+
1494
+ The argument in position 1 (that is, matrix `B`) has two contractions,
1495
+ the ones provided by the indices `j` and `k`, respectively the first
1496
+ and second indices (0 and 1 in the sub-dict). The link `(0, 1)` and
1497
+ `(2, 0)` respectively. `(0, 1)` is the index slot 1 (the 2nd) of
1498
+ argument in position 0 (that is, `A_{\ldot j}`), and so on.
1499
+ """
1500
+ args, dlinks = _get_contraction_links([self], self.subranks, *self.contraction_indices)
1501
+ return dlinks
1502
+
1503
+ def as_explicit(self):
1504
+ expr = self.expr
1505
+ if hasattr(expr, "as_explicit"):
1506
+ expr = expr.as_explicit()
1507
+ return tensorcontraction(expr, *self.contraction_indices)
1508
+
1509
+
1510
+ class Reshape(_CodegenArrayAbstract):
1511
+ """
1512
+ Reshape the dimensions of an array expression.
1513
+
1514
+ Examples
1515
+ ========
1516
+
1517
+ >>> from sympy.tensor.array.expressions import ArraySymbol, Reshape
1518
+ >>> A = ArraySymbol("A", (6,))
1519
+ >>> A.shape
1520
+ (6,)
1521
+ >>> Reshape(A, (3, 2)).shape
1522
+ (3, 2)
1523
+
1524
+ Check the component-explicit forms:
1525
+
1526
+ >>> A.as_explicit()
1527
+ [A[0], A[1], A[2], A[3], A[4], A[5]]
1528
+ >>> Reshape(A, (3, 2)).as_explicit()
1529
+ [[A[0], A[1]], [A[2], A[3]], [A[4], A[5]]]
1530
+
1531
+ """
1532
+
1533
+ def __new__(cls, expr, shape):
1534
+ expr = _sympify(expr)
1535
+ if not isinstance(shape, Tuple):
1536
+ shape = Tuple(*shape)
1537
+ if Equality(Mul.fromiter(expr.shape), Mul.fromiter(shape)) == False:
1538
+ raise ValueError("shape mismatch")
1539
+ obj = Expr.__new__(cls, expr, shape)
1540
+ obj._shape = tuple(shape)
1541
+ obj._expr = expr
1542
+ return obj
1543
+
1544
+ @property
1545
+ def shape(self):
1546
+ return self._shape
1547
+
1548
+ @property
1549
+ def expr(self):
1550
+ return self._expr
1551
+
1552
+ def doit(self, *args, **kwargs):
1553
+ if kwargs.get("deep", True):
1554
+ expr = self.expr.doit(*args, **kwargs)
1555
+ else:
1556
+ expr = self.expr
1557
+ if isinstance(expr, (MatrixBase, NDimArray)):
1558
+ return expr.reshape(*self.shape)
1559
+ return Reshape(expr, self.shape)
1560
+
1561
+ def as_explicit(self):
1562
+ ee = self.expr
1563
+ if hasattr(ee, "as_explicit"):
1564
+ ee = ee.as_explicit()
1565
+ if isinstance(ee, MatrixBase):
1566
+ from sympy import Array
1567
+ ee = Array(ee)
1568
+ elif isinstance(ee, MatrixExpr):
1569
+ return self
1570
+ return ee.reshape(*self.shape)
1571
+
1572
+
1573
+ class _ArgE:
1574
+ """
1575
+ The ``_ArgE`` object contains references to the array expression
1576
+ (``.element``) and a list containing the information about index
1577
+ contractions (``.indices``).
1578
+
1579
+ Index contractions are numbered and contracted indices show the number of
1580
+ the contraction. Uncontracted indices have ``None`` value.
1581
+
1582
+ For example:
1583
+ ``_ArgE(M, [None, 3])``
1584
+ This object means that expression ``M`` is part of an array contraction
1585
+ and has two indices, the first is not contracted (value ``None``),
1586
+ the second index is contracted to the 4th (i.e. number ``3``) group of the
1587
+ array contraction object.
1588
+ """
1589
+ indices: List[Optional[int]]
1590
+
1591
+ def __init__(self, element, indices: Optional[List[Optional[int]]] = None):
1592
+ self.element = element
1593
+ if indices is None:
1594
+ self.indices = [None for i in range(get_rank(element))]
1595
+ else:
1596
+ self.indices = indices
1597
+
1598
+ def __str__(self):
1599
+ return "_ArgE(%s, %s)" % (self.element, self.indices)
1600
+
1601
+ __repr__ = __str__
1602
+
1603
+
1604
+ class _IndPos:
1605
+ """
1606
+ Index position, requiring two integers in the constructor:
1607
+
1608
+ - arg: the position of the argument in the tensor product,
1609
+ - rel: the relative position of the index inside the argument.
1610
+ """
1611
+ def __init__(self, arg: int, rel: int):
1612
+ self.arg = arg
1613
+ self.rel = rel
1614
+
1615
+ def __str__(self):
1616
+ return "_IndPos(%i, %i)" % (self.arg, self.rel)
1617
+
1618
+ __repr__ = __str__
1619
+
1620
+ def __iter__(self):
1621
+ yield from [self.arg, self.rel]
1622
+
1623
+
1624
+ class _EditArrayContraction:
1625
+ """
1626
+ Utility class to help manipulate array contraction objects.
1627
+
1628
+ This class takes as input an ``ArrayContraction`` object and turns it into
1629
+ an editable object.
1630
+
1631
+ The field ``args_with_ind`` of this class is a list of ``_ArgE`` objects
1632
+ which can be used to easily edit the contraction structure of the
1633
+ expression.
1634
+
1635
+ Once editing is finished, the ``ArrayContraction`` object may be recreated
1636
+ by calling the ``.to_array_contraction()`` method.
1637
+ """
1638
+
1639
+ def __init__(self, base_array: typing.Union[ArrayContraction, ArrayDiagonal, ArrayTensorProduct]):
1640
+
1641
+ expr: Basic
1642
+ diagonalized: tTuple[tTuple[int, ...], ...]
1643
+ contraction_indices: List[tTuple[int]]
1644
+ if isinstance(base_array, ArrayContraction):
1645
+ mapping = _get_mapping_from_subranks(base_array.subranks)
1646
+ expr = base_array.expr
1647
+ contraction_indices = base_array.contraction_indices
1648
+ diagonalized = ()
1649
+ elif isinstance(base_array, ArrayDiagonal):
1650
+
1651
+ if isinstance(base_array.expr, ArrayContraction):
1652
+ mapping = _get_mapping_from_subranks(base_array.expr.subranks)
1653
+ expr = base_array.expr.expr
1654
+ diagonalized = ArrayContraction._push_indices_down(base_array.expr.contraction_indices, base_array.diagonal_indices)
1655
+ contraction_indices = base_array.expr.contraction_indices
1656
+ elif isinstance(base_array.expr, ArrayTensorProduct):
1657
+ mapping = {}
1658
+ expr = base_array.expr
1659
+ diagonalized = base_array.diagonal_indices
1660
+ contraction_indices = []
1661
+ else:
1662
+ mapping = {}
1663
+ expr = base_array.expr
1664
+ diagonalized = base_array.diagonal_indices
1665
+ contraction_indices = []
1666
+
1667
+ elif isinstance(base_array, ArrayTensorProduct):
1668
+ expr = base_array
1669
+ contraction_indices = []
1670
+ diagonalized = ()
1671
+ else:
1672
+ raise NotImplementedError()
1673
+
1674
+ if isinstance(expr, ArrayTensorProduct):
1675
+ args = list(expr.args)
1676
+ else:
1677
+ args = [expr]
1678
+
1679
+ args_with_ind: List[_ArgE] = [_ArgE(arg) for arg in args]
1680
+ for i, contraction_tuple in enumerate(contraction_indices):
1681
+ for j in contraction_tuple:
1682
+ arg_pos, rel_pos = mapping[j]
1683
+ args_with_ind[arg_pos].indices[rel_pos] = i
1684
+ self.args_with_ind: List[_ArgE] = args_with_ind
1685
+ self.number_of_contraction_indices: int = len(contraction_indices)
1686
+ self._track_permutation: Optional[List[List[int]]] = None
1687
+
1688
+ mapping = _get_mapping_from_subranks(base_array.subranks)
1689
+
1690
+ # Trick: add diagonalized indices as negative indices into the editor object:
1691
+ for i, e in enumerate(diagonalized):
1692
+ for j in e:
1693
+ arg_pos, rel_pos = mapping[j]
1694
+ self.args_with_ind[arg_pos].indices[rel_pos] = -1 - i
1695
+
1696
+ def insert_after(self, arg: _ArgE, new_arg: _ArgE):
1697
+ pos = self.args_with_ind.index(arg)
1698
+ self.args_with_ind.insert(pos + 1, new_arg)
1699
+
1700
+ def get_new_contraction_index(self):
1701
+ self.number_of_contraction_indices += 1
1702
+ return self.number_of_contraction_indices - 1
1703
+
1704
+ def refresh_indices(self):
1705
+ updates = {}
1706
+ for arg_with_ind in self.args_with_ind:
1707
+ updates.update({i: -1 for i in arg_with_ind.indices if i is not None})
1708
+ for i, e in enumerate(sorted(updates)):
1709
+ updates[e] = i
1710
+ self.number_of_contraction_indices = len(updates)
1711
+ for arg_with_ind in self.args_with_ind:
1712
+ arg_with_ind.indices = [updates.get(i, None) for i in arg_with_ind.indices]
1713
+
1714
+ def merge_scalars(self):
1715
+ scalars = []
1716
+ for arg_with_ind in self.args_with_ind:
1717
+ if len(arg_with_ind.indices) == 0:
1718
+ scalars.append(arg_with_ind)
1719
+ for i in scalars:
1720
+ self.args_with_ind.remove(i)
1721
+ scalar = Mul.fromiter([i.element for i in scalars])
1722
+ if len(self.args_with_ind) == 0:
1723
+ self.args_with_ind.append(_ArgE(scalar))
1724
+ else:
1725
+ from sympy.tensor.array.expressions.from_array_to_matrix import _a2m_tensor_product
1726
+ self.args_with_ind[0].element = _a2m_tensor_product(scalar, self.args_with_ind[0].element)
1727
+
1728
+ def to_array_contraction(self):
1729
+
1730
+ # Count the ranks of the arguments:
1731
+ counter = 0
1732
+ # Create a collector for the new diagonal indices:
1733
+ diag_indices = defaultdict(list)
1734
+
1735
+ count_index_freq = Counter()
1736
+ for arg_with_ind in self.args_with_ind:
1737
+ count_index_freq.update(Counter(arg_with_ind.indices))
1738
+
1739
+ free_index_count = count_index_freq[None]
1740
+
1741
+ # Construct the inverse permutation:
1742
+ inv_perm1 = []
1743
+ inv_perm2 = []
1744
+ # Keep track of which diagonal indices have already been processed:
1745
+ done = set()
1746
+
1747
+ # Counter for the diagonal indices:
1748
+ counter4 = 0
1749
+
1750
+ for arg_with_ind in self.args_with_ind:
1751
+ # If some diagonalization axes have been removed, they should be
1752
+ # permuted in order to keep the permutation.
1753
+ # Add permutation here
1754
+ counter2 = 0 # counter for the indices
1755
+ for i in arg_with_ind.indices:
1756
+ if i is None:
1757
+ inv_perm1.append(counter4)
1758
+ counter2 += 1
1759
+ counter4 += 1
1760
+ continue
1761
+ if i >= 0:
1762
+ continue
1763
+ # Reconstruct the diagonal indices:
1764
+ diag_indices[-1 - i].append(counter + counter2)
1765
+ if count_index_freq[i] == 1 and i not in done:
1766
+ inv_perm1.append(free_index_count - 1 - i)
1767
+ done.add(i)
1768
+ elif i not in done:
1769
+ inv_perm2.append(free_index_count - 1 - i)
1770
+ done.add(i)
1771
+ counter2 += 1
1772
+ # Remove negative indices to restore a proper editor object:
1773
+ arg_with_ind.indices = [i if i is not None and i >= 0 else None for i in arg_with_ind.indices]
1774
+ counter += len([i for i in arg_with_ind.indices if i is None or i < 0])
1775
+
1776
+ inverse_permutation = inv_perm1 + inv_perm2
1777
+ permutation = _af_invert(inverse_permutation)
1778
+
1779
+ # Get the diagonal indices after the detection of HadamardProduct in the expression:
1780
+ diag_indices_filtered = [tuple(v) for v in diag_indices.values() if len(v) > 1]
1781
+
1782
+ self.merge_scalars()
1783
+ self.refresh_indices()
1784
+ args = [arg.element for arg in self.args_with_ind]
1785
+ contraction_indices = self.get_contraction_indices()
1786
+ expr = _array_contraction(_array_tensor_product(*args), *contraction_indices)
1787
+ expr2 = _array_diagonal(expr, *diag_indices_filtered)
1788
+ if self._track_permutation is not None:
1789
+ permutation2 = _af_invert([j for i in self._track_permutation for j in i])
1790
+ expr2 = _permute_dims(expr2, permutation2)
1791
+
1792
+ expr3 = _permute_dims(expr2, permutation)
1793
+ return expr3
1794
+
1795
+ def get_contraction_indices(self) -> List[List[int]]:
1796
+ contraction_indices: List[List[int]] = [[] for i in range(self.number_of_contraction_indices)]
1797
+ current_position: int = 0
1798
+ for arg_with_ind in self.args_with_ind:
1799
+ for j in arg_with_ind.indices:
1800
+ if j is not None:
1801
+ contraction_indices[j].append(current_position)
1802
+ current_position += 1
1803
+ return contraction_indices
1804
+
1805
+ def get_mapping_for_index(self, ind) -> List[_IndPos]:
1806
+ if ind >= self.number_of_contraction_indices:
1807
+ raise ValueError("index value exceeding the index range")
1808
+ positions: List[_IndPos] = []
1809
+ for i, arg_with_ind in enumerate(self.args_with_ind):
1810
+ for j, arg_ind in enumerate(arg_with_ind.indices):
1811
+ if ind == arg_ind:
1812
+ positions.append(_IndPos(i, j))
1813
+ return positions
1814
+
1815
+ def get_contraction_indices_to_ind_rel_pos(self) -> List[List[_IndPos]]:
1816
+ contraction_indices: List[List[_IndPos]] = [[] for i in range(self.number_of_contraction_indices)]
1817
+ for i, arg_with_ind in enumerate(self.args_with_ind):
1818
+ for j, ind in enumerate(arg_with_ind.indices):
1819
+ if ind is not None:
1820
+ contraction_indices[ind].append(_IndPos(i, j))
1821
+ return contraction_indices
1822
+
1823
+ def count_args_with_index(self, index: int) -> int:
1824
+ """
1825
+ Count the number of arguments that have the given index.
1826
+ """
1827
+ counter: int = 0
1828
+ for arg_with_ind in self.args_with_ind:
1829
+ if index in arg_with_ind.indices:
1830
+ counter += 1
1831
+ return counter
1832
+
1833
+ def get_args_with_index(self, index: int) -> List[_ArgE]:
1834
+ """
1835
+ Get a list of arguments having the given index.
1836
+ """
1837
+ ret: List[_ArgE] = [i for i in self.args_with_ind if index in i.indices]
1838
+ return ret
1839
+
1840
+ @property
1841
+ def number_of_diagonal_indices(self):
1842
+ data = set()
1843
+ for arg in self.args_with_ind:
1844
+ data.update({i for i in arg.indices if i is not None and i < 0})
1845
+ return len(data)
1846
+
1847
+ def track_permutation_start(self):
1848
+ permutation = []
1849
+ perm_diag = []
1850
+ counter = 0
1851
+ counter2 = -1
1852
+ for arg_with_ind in self.args_with_ind:
1853
+ perm = []
1854
+ for i in arg_with_ind.indices:
1855
+ if i is not None:
1856
+ if i < 0:
1857
+ perm_diag.append(counter2)
1858
+ counter2 -= 1
1859
+ continue
1860
+ perm.append(counter)
1861
+ counter += 1
1862
+ permutation.append(perm)
1863
+ max_ind = max(max(i) if i else -1 for i in permutation) if permutation else -1
1864
+ perm_diag = [max_ind - i for i in perm_diag]
1865
+ self._track_permutation = permutation + [perm_diag]
1866
+
1867
+ def track_permutation_merge(self, destination: _ArgE, from_element: _ArgE):
1868
+ index_destination = self.args_with_ind.index(destination)
1869
+ index_element = self.args_with_ind.index(from_element)
1870
+ self._track_permutation[index_destination].extend(self._track_permutation[index_element]) # type: ignore
1871
+ self._track_permutation.pop(index_element) # type: ignore
1872
+
1873
+ def get_absolute_free_range(self, arg: _ArgE) -> typing.Tuple[int, int]:
1874
+ """
1875
+ Return the range of the free indices of the arg as absolute positions
1876
+ among all free indices.
1877
+ """
1878
+ counter = 0
1879
+ for arg_with_ind in self.args_with_ind:
1880
+ number_free_indices = len([i for i in arg_with_ind.indices if i is None])
1881
+ if arg_with_ind == arg:
1882
+ return counter, counter + number_free_indices
1883
+ counter += number_free_indices
1884
+ raise IndexError("argument not found")
1885
+
1886
+ def get_absolute_range(self, arg: _ArgE) -> typing.Tuple[int, int]:
1887
+ """
1888
+ Return the absolute range of indices for arg, disregarding dummy
1889
+ indices.
1890
+ """
1891
+ counter = 0
1892
+ for arg_with_ind in self.args_with_ind:
1893
+ number_indices = len(arg_with_ind.indices)
1894
+ if arg_with_ind == arg:
1895
+ return counter, counter + number_indices
1896
+ counter += number_indices
1897
+ raise IndexError("argument not found")
1898
+
1899
+
1900
+ def get_rank(expr):
1901
+ if isinstance(expr, (MatrixExpr, MatrixElement)):
1902
+ return 2
1903
+ if isinstance(expr, _CodegenArrayAbstract):
1904
+ return len(expr.shape)
1905
+ if isinstance(expr, NDimArray):
1906
+ return expr.rank()
1907
+ if isinstance(expr, Indexed):
1908
+ return expr.rank
1909
+ if isinstance(expr, IndexedBase):
1910
+ shape = expr.shape
1911
+ if shape is None:
1912
+ return -1
1913
+ else:
1914
+ return len(shape)
1915
+ if hasattr(expr, "shape"):
1916
+ return len(expr.shape)
1917
+ return 0
1918
+
1919
+
1920
+ def _get_subrank(expr):
1921
+ if isinstance(expr, _CodegenArrayAbstract):
1922
+ return expr.subrank()
1923
+ return get_rank(expr)
1924
+
1925
+
1926
+ def _get_subranks(expr):
1927
+ if isinstance(expr, _CodegenArrayAbstract):
1928
+ return expr.subranks
1929
+ else:
1930
+ return [get_rank(expr)]
1931
+
1932
+
1933
+ def get_shape(expr):
1934
+ if hasattr(expr, "shape"):
1935
+ return expr.shape
1936
+ return ()
1937
+
1938
+
1939
+ def nest_permutation(expr):
1940
+ if isinstance(expr, PermuteDims):
1941
+ return expr.nest_permutation()
1942
+ else:
1943
+ return expr
1944
+
1945
+
1946
+ def _array_tensor_product(*args, **kwargs):
1947
+ return ArrayTensorProduct(*args, canonicalize=True, **kwargs)
1948
+
1949
+
1950
+ def _array_contraction(expr, *contraction_indices, **kwargs):
1951
+ return ArrayContraction(expr, *contraction_indices, canonicalize=True, **kwargs)
1952
+
1953
+
1954
+ def _array_diagonal(expr, *diagonal_indices, **kwargs):
1955
+ return ArrayDiagonal(expr, *diagonal_indices, canonicalize=True, **kwargs)
1956
+
1957
+
1958
+ def _permute_dims(expr, permutation, **kwargs):
1959
+ return PermuteDims(expr, permutation, canonicalize=True, **kwargs)
1960
+
1961
+
1962
+ def _array_add(*args, **kwargs):
1963
+ return ArrayAdd(*args, canonicalize=True, **kwargs)
1964
+
1965
+
1966
+ def _get_array_element_or_slice(expr, indices):
1967
+ return ArrayElement(expr, indices)
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+ from functools import reduce, singledispatch
3
+
4
+ from sympy.core.expr import Expr
5
+ from sympy.core.singleton import S
6
+ from sympy.matrices.expressions.hadamard import HadamardProduct
7
+ from sympy.matrices.expressions.inverse import Inverse
8
+ from sympy.matrices.expressions.matexpr import (MatrixExpr, MatrixSymbol)
9
+ from sympy.matrices.expressions.special import Identity, OneMatrix
10
+ from sympy.matrices.expressions.transpose import Transpose
11
+ from sympy.combinatorics.permutations import _af_invert
12
+ from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
13
+ from sympy.tensor.array.expressions.array_expressions import (
14
+ _ArrayExpr, ZeroArray, ArraySymbol, ArrayTensorProduct, ArrayAdd,
15
+ PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, get_rank,
16
+ get_shape, ArrayContraction, _array_tensor_product, _array_contraction,
17
+ _array_diagonal, _array_add, _permute_dims, Reshape)
18
+ from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
19
+
20
+
21
+ @singledispatch
22
+ def array_derive(expr, x):
23
+ """
24
+ Derivatives (gradients) for array expressions.
25
+ """
26
+ raise NotImplementedError(f"not implemented for type {type(expr)}")
27
+
28
+
29
+ @array_derive.register(Expr)
30
+ def _(expr: Expr, x: _ArrayExpr):
31
+ return ZeroArray(*x.shape)
32
+
33
+
34
+ @array_derive.register(ArrayTensorProduct)
35
+ def _(expr: ArrayTensorProduct, x: Expr):
36
+ args = expr.args
37
+ addend_list = []
38
+ for i, arg in enumerate(expr.args):
39
+ darg = array_derive(arg, x)
40
+ if darg == 0:
41
+ continue
42
+ args_prev = args[:i]
43
+ args_succ = args[i+1:]
44
+ shape_prev = reduce(operator.add, map(get_shape, args_prev), ())
45
+ shape_succ = reduce(operator.add, map(get_shape, args_succ), ())
46
+ addend = _array_tensor_product(*args_prev, darg, *args_succ)
47
+ tot1 = len(get_shape(x))
48
+ tot2 = tot1 + len(shape_prev)
49
+ tot3 = tot2 + len(get_shape(arg))
50
+ tot4 = tot3 + len(shape_succ)
51
+ perm = list(range(tot1, tot2)) + \
52
+ list(range(tot1)) + list(range(tot2, tot3)) + \
53
+ list(range(tot3, tot4))
54
+ addend = _permute_dims(addend, _af_invert(perm))
55
+ addend_list.append(addend)
56
+ if len(addend_list) == 1:
57
+ return addend_list[0]
58
+ elif len(addend_list) == 0:
59
+ return S.Zero
60
+ else:
61
+ return _array_add(*addend_list)
62
+
63
+
64
+ @array_derive.register(ArraySymbol)
65
+ def _(expr: ArraySymbol, x: _ArrayExpr):
66
+ if expr == x:
67
+ return _permute_dims(
68
+ ArrayTensorProduct.fromiter(Identity(i) for i in expr.shape),
69
+ [2*i for i in range(len(expr.shape))] + [2*i+1 for i in range(len(expr.shape))]
70
+ )
71
+ return ZeroArray(*(x.shape + expr.shape))
72
+
73
+
74
+ @array_derive.register(MatrixSymbol)
75
+ def _(expr: MatrixSymbol, x: _ArrayExpr):
76
+ m, n = expr.shape
77
+ if expr == x:
78
+ return _permute_dims(
79
+ _array_tensor_product(Identity(m), Identity(n)),
80
+ [0, 2, 1, 3]
81
+ )
82
+ return ZeroArray(*(x.shape + expr.shape))
83
+
84
+
85
+ @array_derive.register(Identity)
86
+ def _(expr: Identity, x: _ArrayExpr):
87
+ return ZeroArray(*(x.shape + expr.shape))
88
+
89
+
90
+ @array_derive.register(OneMatrix)
91
+ def _(expr: OneMatrix, x: _ArrayExpr):
92
+ return ZeroArray(*(x.shape + expr.shape))
93
+
94
+
95
+ @array_derive.register(Transpose)
96
+ def _(expr: Transpose, x: Expr):
97
+ # D(A.T, A) ==> (m,n,i,j) ==> D(A_ji, A_mn) = d_mj d_ni
98
+ # D(B.T, A) ==> (m,n,i,j) ==> D(B_ji, A_mn)
99
+ fd = array_derive(expr.arg, x)
100
+ return _permute_dims(fd, [0, 1, 3, 2])
101
+
102
+
103
+ @array_derive.register(Inverse)
104
+ def _(expr: Inverse, x: Expr):
105
+ mat = expr.I
106
+ dexpr = array_derive(mat, x)
107
+ tp = _array_tensor_product(-expr, dexpr, expr)
108
+ mp = _array_contraction(tp, (1, 4), (5, 6))
109
+ pp = _permute_dims(mp, [1, 2, 0, 3])
110
+ return pp
111
+
112
+
113
+ @array_derive.register(ElementwiseApplyFunction)
114
+ def _(expr: ElementwiseApplyFunction, x: Expr):
115
+ assert get_rank(expr) == 2
116
+ assert get_rank(x) == 2
117
+ fdiff = expr._get_function_fdiff()
118
+ dexpr = array_derive(expr.expr, x)
119
+ tp = _array_tensor_product(
120
+ ElementwiseApplyFunction(fdiff, expr.expr),
121
+ dexpr
122
+ )
123
+ td = _array_diagonal(
124
+ tp, (0, 4), (1, 5)
125
+ )
126
+ return td
127
+
128
+
129
+ @array_derive.register(ArrayElementwiseApplyFunc)
130
+ def _(expr: ArrayElementwiseApplyFunc, x: Expr):
131
+ fdiff = expr._get_function_fdiff()
132
+ subexpr = expr.expr
133
+ dsubexpr = array_derive(subexpr, x)
134
+ tp = _array_tensor_product(
135
+ dsubexpr,
136
+ ArrayElementwiseApplyFunc(fdiff, subexpr)
137
+ )
138
+ b = get_rank(x)
139
+ c = get_rank(expr)
140
+ diag_indices = [(b + i, b + c + i) for i in range(c)]
141
+ return _array_diagonal(tp, *diag_indices)
142
+
143
+
144
+ @array_derive.register(MatrixExpr)
145
+ def _(expr: MatrixExpr, x: Expr):
146
+ cg = convert_matrix_to_array(expr)
147
+ return array_derive(cg, x)
148
+
149
+
150
+ @array_derive.register(HadamardProduct)
151
+ def _(expr: HadamardProduct, x: Expr):
152
+ raise NotImplementedError()
153
+
154
+
155
+ @array_derive.register(ArrayContraction)
156
+ def _(expr: ArrayContraction, x: Expr):
157
+ fd = array_derive(expr.expr, x)
158
+ rank_x = len(get_shape(x))
159
+ contraction_indices = expr.contraction_indices
160
+ new_contraction_indices = [tuple(j + rank_x for j in i) for i in contraction_indices]
161
+ return _array_contraction(fd, *new_contraction_indices)
162
+
163
+
164
+ @array_derive.register(ArrayDiagonal)
165
+ def _(expr: ArrayDiagonal, x: Expr):
166
+ dsubexpr = array_derive(expr.expr, x)
167
+ rank_x = len(get_shape(x))
168
+ diag_indices = [[j + rank_x for j in i] for i in expr.diagonal_indices]
169
+ return _array_diagonal(dsubexpr, *diag_indices)
170
+
171
+
172
+ @array_derive.register(ArrayAdd)
173
+ def _(expr: ArrayAdd, x: Expr):
174
+ return _array_add(*[array_derive(arg, x) for arg in expr.args])
175
+
176
+
177
+ @array_derive.register(PermuteDims)
178
+ def _(expr: PermuteDims, x: Expr):
179
+ de = array_derive(expr.expr, x)
180
+ perm = [0, 1] + [i + 2 for i in expr.permutation.array_form]
181
+ return _permute_dims(de, perm)
182
+
183
+
184
+ @array_derive.register(Reshape)
185
+ def _(expr: Reshape, x: Expr):
186
+ de = array_derive(expr.expr, x)
187
+ return Reshape(de, get_shape(x) + expr.shape)
188
+
189
+
190
+ def matrix_derive(expr, x):
191
+ from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
192
+ ce = convert_matrix_to_array(expr)
193
+ dce = array_derive(ce, x)
194
+ return convert_array_to_matrix(dce).doit()
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.tensor.array.expressions import from_array_to_indexed
2
+ from sympy.utilities.decorator import deprecated
3
+
4
+
5
+ _conv_to_from_decorator = deprecated(
6
+ "module has been renamed by replacing 'conv_' with 'from_' in its name",
7
+ deprecated_since_version="1.11",
8
+ active_deprecations_target="deprecated-conv-array-expr-module-names",
9
+ )
10
+
11
+
12
+ convert_array_to_indexed = _conv_to_from_decorator(from_array_to_indexed.convert_array_to_indexed)
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from sympy.tensor.array.expressions import from_array_to_matrix
2
+ from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator
3
+
4
+ convert_array_to_matrix = _conv_to_from_decorator(from_array_to_matrix.convert_array_to_matrix)
5
+ _array2matrix = _conv_to_from_decorator(from_array_to_matrix._array2matrix)
6
+ _remove_trivial_dims = _conv_to_from_decorator(from_array_to_matrix._remove_trivial_dims)
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from sympy.tensor.array.expressions import from_indexed_to_array
2
+ from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator
3
+
4
+ convert_indexed_to_array = _conv_to_from_decorator(from_indexed_to_array.convert_indexed_to_array)
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from sympy.tensor.array.expressions import from_matrix_to_array
2
+ from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator
3
+
4
+ convert_matrix_to_array = _conv_to_from_decorator(from_matrix_to_array.convert_matrix_to_array)
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ import operator
3
+ from itertools import accumulate
4
+
5
+ from sympy import Mul, Sum, Dummy, Add
6
+ from sympy.tensor.array.expressions import PermuteDims, ArrayAdd, ArrayElementwiseApplyFunc, Reshape
7
+ from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, get_rank, ArrayContraction, \
8
+ ArrayDiagonal, get_shape, _get_array_element_or_slice, _ArrayExpr
9
+ from sympy.tensor.array.expressions.utils import _apply_permutation_to_list
10
+
11
+
12
+ def convert_array_to_indexed(expr, indices):
13
+ return _ConvertArrayToIndexed().do_convert(expr, indices)
14
+
15
+
16
+ class _ConvertArrayToIndexed:
17
+
18
+ def __init__(self):
19
+ self.count_dummies = 0
20
+
21
+ def do_convert(self, expr, indices):
22
+ if isinstance(expr, ArrayTensorProduct):
23
+ cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args]))
24
+ indices_grp = [indices[cumul[i]:cumul[i+1]] for i in range(len(expr.args))]
25
+ return Mul.fromiter(self.do_convert(arg, ind) for arg, ind in zip(expr.args, indices_grp))
26
+ if isinstance(expr, ArrayContraction):
27
+ new_indices = [None for i in range(get_rank(expr.expr))]
28
+ limits = []
29
+ bottom_shape = get_shape(expr.expr)
30
+ for contraction_index_grp in expr.contraction_indices:
31
+ d = Dummy(f"d{self.count_dummies}")
32
+ self.count_dummies += 1
33
+ dim = bottom_shape[contraction_index_grp[0]]
34
+ limits.append((d, 0, dim-1))
35
+ for i in contraction_index_grp:
36
+ new_indices[i] = d
37
+ j = 0
38
+ for i in range(len(new_indices)):
39
+ if new_indices[i] is None:
40
+ new_indices[i] = indices[j]
41
+ j += 1
42
+ newexpr = self.do_convert(expr.expr, new_indices)
43
+ return Sum(newexpr, *limits)
44
+ if isinstance(expr, ArrayDiagonal):
45
+ new_indices = [None for i in range(get_rank(expr.expr))]
46
+ ind_pos = expr._push_indices_down(expr.diagonal_indices, list(range(len(indices))), get_rank(expr))
47
+ for i, index in zip(ind_pos, indices):
48
+ if isinstance(i, collections.abc.Iterable):
49
+ for j in i:
50
+ new_indices[j] = index
51
+ else:
52
+ new_indices[i] = index
53
+ newexpr = self.do_convert(expr.expr, new_indices)
54
+ return newexpr
55
+ if isinstance(expr, PermuteDims):
56
+ permuted_indices = _apply_permutation_to_list(expr.permutation, indices)
57
+ return self.do_convert(expr.expr, permuted_indices)
58
+ if isinstance(expr, ArrayAdd):
59
+ return Add.fromiter(self.do_convert(arg, indices) for arg in expr.args)
60
+ if isinstance(expr, _ArrayExpr):
61
+ return expr.__getitem__(tuple(indices))
62
+ if isinstance(expr, ArrayElementwiseApplyFunc):
63
+ return expr.function(self.do_convert(expr.expr, indices))
64
+ if isinstance(expr, Reshape):
65
+ shape_up = expr.shape
66
+ shape_down = get_shape(expr.expr)
67
+ cumul = list(accumulate([1] + list(reversed(shape_up)), operator.mul))
68
+ one_index = Add.fromiter(i*s for i, s in zip(reversed(indices), cumul))
69
+ dest_indices = [None for _ in shape_down]
70
+ c = 1
71
+ for i, e in enumerate(reversed(shape_down)):
72
+ if c == 1:
73
+ if i == len(shape_down) - 1:
74
+ dest_indices[i] = one_index
75
+ else:
76
+ dest_indices[i] = one_index % e
77
+ elif i == len(shape_down) - 1:
78
+ dest_indices[i] = one_index // c
79
+ else:
80
+ dest_indices[i] = one_index // c % e
81
+ c *= e
82
+ dest_indices.reverse()
83
+ return self.do_convert(expr.expr, dest_indices)
84
+ return _get_array_element_or_slice(expr, indices)
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py ADDED
@@ -0,0 +1,1003 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from collections import defaultdict
3
+ from typing import Tuple as tTuple, Union as tUnion, FrozenSet, Dict as tDict, List, Optional
4
+ from functools import singledispatch
5
+ from itertools import accumulate
6
+
7
+ from sympy import MatMul, Basic, Wild, KroneckerProduct
8
+ from sympy.assumptions.ask import (Q, ask)
9
+ from sympy.core.mul import Mul
10
+ from sympy.core.singleton import S
11
+ from sympy.matrices.expressions.diagonal import DiagMatrix
12
+ from sympy.matrices.expressions.hadamard import hadamard_product, HadamardPower
13
+ from sympy.matrices.expressions.matexpr import MatrixExpr
14
+ from sympy.matrices.expressions.special import (Identity, ZeroMatrix, OneMatrix)
15
+ from sympy.matrices.expressions.trace import Trace
16
+ from sympy.matrices.expressions.transpose import Transpose
17
+ from sympy.combinatorics.permutations import _af_invert, Permutation
18
+ from sympy.matrices.matrixbase import MatrixBase
19
+ from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
20
+ from sympy.matrices.expressions.matexpr import MatrixElement
21
+ from sympy.tensor.array.expressions.array_expressions import PermuteDims, ArrayDiagonal, \
22
+ ArrayTensorProduct, OneArray, get_rank, _get_subrank, ZeroArray, ArrayContraction, \
23
+ ArrayAdd, _CodegenArrayAbstract, get_shape, ArrayElementwiseApplyFunc, _ArrayExpr, _EditArrayContraction, _ArgE, \
24
+ ArrayElement, _array_tensor_product, _array_contraction, _array_diagonal, _array_add, _permute_dims
25
+ from sympy.tensor.array.expressions.utils import _get_mapping_from_subranks
26
+
27
+
28
+ def _get_candidate_for_matmul_from_contraction(scan_indices: List[Optional[int]], remaining_args: List[_ArgE]) -> tTuple[Optional[_ArgE], bool, int]:
29
+
30
+ scan_indices_int: List[int] = [i for i in scan_indices if i is not None]
31
+ if len(scan_indices_int) == 0:
32
+ return None, False, -1
33
+
34
+ transpose: bool = False
35
+ candidate: Optional[_ArgE] = None
36
+ candidate_index: int = -1
37
+ for arg_with_ind2 in remaining_args:
38
+ if not isinstance(arg_with_ind2.element, MatrixExpr):
39
+ continue
40
+ for index in scan_indices_int:
41
+ if candidate_index != -1 and candidate_index != index:
42
+ # A candidate index has already been selected, check
43
+ # repetitions only for that index:
44
+ continue
45
+ if index in arg_with_ind2.indices:
46
+ if set(arg_with_ind2.indices) == {index}:
47
+ # Index repeated twice in arg_with_ind2
48
+ candidate = None
49
+ break
50
+ if candidate is None:
51
+ candidate = arg_with_ind2
52
+ candidate_index = index
53
+ transpose = (index == arg_with_ind2.indices[1])
54
+ else:
55
+ # Index repeated more than twice, break
56
+ candidate = None
57
+ break
58
+ return candidate, transpose, candidate_index
59
+
60
+
61
+ def _insert_candidate_into_editor(editor: _EditArrayContraction, arg_with_ind: _ArgE, candidate: _ArgE, transpose1: bool, transpose2: bool):
62
+ other = candidate.element
63
+ other_index: Optional[int]
64
+ if transpose2:
65
+ other = Transpose(other)
66
+ other_index = candidate.indices[0]
67
+ else:
68
+ other_index = candidate.indices[1]
69
+ new_element = (Transpose(arg_with_ind.element) if transpose1 else arg_with_ind.element) * other
70
+ editor.args_with_ind.remove(candidate)
71
+ new_arge = _ArgE(new_element)
72
+ return new_arge, other_index
73
+
74
+
75
+ def _support_function_tp1_recognize(contraction_indices, args):
76
+ if len(contraction_indices) == 0:
77
+ return _a2m_tensor_product(*args)
78
+
79
+ ac = _array_contraction(_array_tensor_product(*args), *contraction_indices)
80
+ editor = _EditArrayContraction(ac)
81
+ editor.track_permutation_start()
82
+
83
+ while True:
84
+ flag_stop = True
85
+ for i, arg_with_ind in enumerate(editor.args_with_ind):
86
+ if not isinstance(arg_with_ind.element, MatrixExpr):
87
+ continue
88
+
89
+ first_index = arg_with_ind.indices[0]
90
+ second_index = arg_with_ind.indices[1]
91
+
92
+ first_frequency = editor.count_args_with_index(first_index)
93
+ second_frequency = editor.count_args_with_index(second_index)
94
+
95
+ if first_index is not None and first_frequency == 1 and first_index == second_index:
96
+ flag_stop = False
97
+ arg_with_ind.element = Trace(arg_with_ind.element)._normalize()
98
+ arg_with_ind.indices = []
99
+ break
100
+
101
+ scan_indices = []
102
+ if first_frequency == 2:
103
+ scan_indices.append(first_index)
104
+ if second_frequency == 2:
105
+ scan_indices.append(second_index)
106
+
107
+ candidate, transpose, found_index = _get_candidate_for_matmul_from_contraction(scan_indices, editor.args_with_ind[i+1:])
108
+ if candidate is not None:
109
+ flag_stop = False
110
+ editor.track_permutation_merge(arg_with_ind, candidate)
111
+ transpose1 = found_index == first_index
112
+ new_arge, other_index = _insert_candidate_into_editor(editor, arg_with_ind, candidate, transpose1, transpose)
113
+ if found_index == first_index:
114
+ new_arge.indices = [second_index, other_index]
115
+ else:
116
+ new_arge.indices = [first_index, other_index]
117
+ set_indices = set(new_arge.indices)
118
+ if len(set_indices) == 1 and set_indices != {None}:
119
+ # This is a trace:
120
+ new_arge.element = Trace(new_arge.element)._normalize()
121
+ new_arge.indices = []
122
+ editor.args_with_ind[i] = new_arge
123
+ # TODO: is this break necessary?
124
+ break
125
+
126
+ if flag_stop:
127
+ break
128
+
129
+ editor.refresh_indices()
130
+ return editor.to_array_contraction()
131
+
132
+
133
+ def _find_trivial_matrices_rewrite(expr: ArrayTensorProduct):
134
+ # If there are matrices of trivial shape in the tensor product (i.e. shape
135
+ # (1, 1)), try to check if there is a suitable non-trivial MatMul where the
136
+ # expression can be inserted.
137
+
138
+ # For example, if "a" has shape (1, 1) and "b" has shape (k, 1), the
139
+ # expressions "_array_tensor_product(a, b*b.T)" can be rewritten as
140
+ # "b*a*b.T"
141
+
142
+ trivial_matrices = []
143
+ pos: Optional[int] = None
144
+ first: Optional[MatrixExpr] = None
145
+ second: Optional[MatrixExpr] = None
146
+ removed: List[int] = []
147
+ counter: int = 0
148
+ args: List[Optional[Basic]] = list(expr.args)
149
+ for i, arg in enumerate(expr.args):
150
+ if isinstance(arg, MatrixExpr):
151
+ if arg.shape == (1, 1):
152
+ trivial_matrices.append(arg)
153
+ args[i] = None
154
+ removed.extend([counter, counter+1])
155
+ elif pos is None and isinstance(arg, MatMul):
156
+ margs = arg.args
157
+ for j, e in enumerate(margs):
158
+ if isinstance(e, MatrixExpr) and e.shape[1] == 1:
159
+ pos = i
160
+ first = MatMul.fromiter(margs[:j+1])
161
+ second = MatMul.fromiter(margs[j+1:])
162
+ break
163
+ counter += get_rank(arg)
164
+ if pos is None:
165
+ return expr, []
166
+ args[pos] = (first*MatMul.fromiter(i for i in trivial_matrices)*second).doit()
167
+ return _array_tensor_product(*[i for i in args if i is not None]), removed
168
+
169
+
170
+ def _find_trivial_kronecker_products_broadcast(expr: ArrayTensorProduct):
171
+ newargs: List[Basic] = []
172
+ removed = []
173
+ count_dims = 0
174
+ for arg in expr.args:
175
+ count_dims += get_rank(arg)
176
+ shape = get_shape(arg)
177
+ current_range = [count_dims-i for i in range(len(shape), 0, -1)]
178
+ if (shape == (1, 1) and len(newargs) > 0 and 1 not in get_shape(newargs[-1]) and
179
+ isinstance(newargs[-1], MatrixExpr) and isinstance(arg, MatrixExpr)):
180
+ # KroneckerProduct object allows the trick of broadcasting:
181
+ newargs[-1] = KroneckerProduct(newargs[-1], arg)
182
+ removed.extend(current_range)
183
+ elif 1 not in shape and len(newargs) > 0 and get_shape(newargs[-1]) == (1, 1):
184
+ # Broadcast:
185
+ newargs[-1] = KroneckerProduct(newargs[-1], arg)
186
+ prev_range = [i for i in range(min(current_range)) if i not in removed]
187
+ removed.extend(prev_range[-2:])
188
+ else:
189
+ newargs.append(arg)
190
+ return _array_tensor_product(*newargs), removed
191
+
192
+
193
+ @singledispatch
194
+ def _array2matrix(expr):
195
+ return expr
196
+
197
+
198
+ @_array2matrix.register(ZeroArray)
199
+ def _(expr: ZeroArray):
200
+ if get_rank(expr) == 2:
201
+ return ZeroMatrix(*expr.shape)
202
+ else:
203
+ return expr
204
+
205
+
206
+ @_array2matrix.register(ArrayTensorProduct)
207
+ def _(expr: ArrayTensorProduct):
208
+ return _a2m_tensor_product(*[_array2matrix(arg) for arg in expr.args])
209
+
210
+
211
+ @_array2matrix.register(ArrayContraction)
212
+ def _(expr: ArrayContraction):
213
+ expr = expr.flatten_contraction_of_diagonal()
214
+ expr = identify_removable_identity_matrices(expr)
215
+ expr = expr.split_multiple_contractions()
216
+ expr = identify_hadamard_products(expr)
217
+ if not isinstance(expr, ArrayContraction):
218
+ return _array2matrix(expr)
219
+ subexpr = expr.expr
220
+ contraction_indices: tTuple[tTuple[int]] = expr.contraction_indices
221
+ if contraction_indices == ((0,), (1,)) or (
222
+ contraction_indices == ((0,),) and subexpr.shape[1] == 1
223
+ ) or (
224
+ contraction_indices == ((1,),) and subexpr.shape[0] == 1
225
+ ):
226
+ shape = subexpr.shape
227
+ subexpr = _array2matrix(subexpr)
228
+ if isinstance(subexpr, MatrixExpr):
229
+ return OneMatrix(1, shape[0])*subexpr*OneMatrix(shape[1], 1)
230
+ if isinstance(subexpr, ArrayTensorProduct):
231
+ newexpr = _array_contraction(_array2matrix(subexpr), *contraction_indices)
232
+ contraction_indices = newexpr.contraction_indices
233
+ if any(i > 2 for i in newexpr.subranks):
234
+ addends = _array_add(*[_a2m_tensor_product(*j) for j in itertools.product(*[i.args if isinstance(i,
235
+ ArrayAdd) else [i] for i in expr.expr.args])])
236
+ newexpr = _array_contraction(addends, *contraction_indices)
237
+ if isinstance(newexpr, ArrayAdd):
238
+ ret = _array2matrix(newexpr)
239
+ return ret
240
+ assert isinstance(newexpr, ArrayContraction)
241
+ ret = _support_function_tp1_recognize(contraction_indices, list(newexpr.expr.args))
242
+ return ret
243
+ elif not isinstance(subexpr, _CodegenArrayAbstract):
244
+ ret = _array2matrix(subexpr)
245
+ if isinstance(ret, MatrixExpr):
246
+ assert expr.contraction_indices == ((0, 1),)
247
+ return _a2m_trace(ret)
248
+ else:
249
+ return _array_contraction(ret, *expr.contraction_indices)
250
+
251
+
252
+ @_array2matrix.register(ArrayDiagonal)
253
+ def _(expr: ArrayDiagonal):
254
+ pexpr = _array_diagonal(_array2matrix(expr.expr), *expr.diagonal_indices)
255
+ pexpr = identify_hadamard_products(pexpr)
256
+ if isinstance(pexpr, ArrayDiagonal):
257
+ pexpr = _array_diag2contr_diagmatrix(pexpr)
258
+ if expr == pexpr:
259
+ return expr
260
+ return _array2matrix(pexpr)
261
+
262
+
263
+ @_array2matrix.register(PermuteDims)
264
+ def _(expr: PermuteDims):
265
+ if expr.permutation.array_form == [1, 0]:
266
+ return _a2m_transpose(_array2matrix(expr.expr))
267
+ elif isinstance(expr.expr, ArrayTensorProduct):
268
+ ranks = expr.expr.subranks
269
+ inv_permutation = expr.permutation**(-1)
270
+ newrange = [inv_permutation(i) for i in range(sum(ranks))]
271
+ newpos = []
272
+ counter = 0
273
+ for rank in ranks:
274
+ newpos.append(newrange[counter:counter+rank])
275
+ counter += rank
276
+ newargs = []
277
+ newperm = []
278
+ scalars = []
279
+ for pos, arg in zip(newpos, expr.expr.args):
280
+ if len(pos) == 0:
281
+ scalars.append(_array2matrix(arg))
282
+ elif pos == sorted(pos):
283
+ newargs.append((_array2matrix(arg), pos[0]))
284
+ newperm.extend(pos)
285
+ elif len(pos) == 2:
286
+ newargs.append((_a2m_transpose(_array2matrix(arg)), pos[0]))
287
+ newperm.extend(reversed(pos))
288
+ else:
289
+ raise NotImplementedError()
290
+ newargs = [i[0] for i in newargs]
291
+ return _permute_dims(_a2m_tensor_product(*scalars, *newargs), _af_invert(newperm))
292
+ elif isinstance(expr.expr, ArrayContraction):
293
+ mat_mul_lines = _array2matrix(expr.expr)
294
+ if not isinstance(mat_mul_lines, ArrayTensorProduct):
295
+ return _permute_dims(mat_mul_lines, expr.permutation)
296
+ # TODO: this assumes that all arguments are matrices, it may not be the case:
297
+ permutation = Permutation(2*len(mat_mul_lines.args)-1)*expr.permutation
298
+ permuted = [permutation(i) for i in range(2*len(mat_mul_lines.args))]
299
+ args_array = [None for i in mat_mul_lines.args]
300
+ for i in range(len(mat_mul_lines.args)):
301
+ p1 = permuted[2*i]
302
+ p2 = permuted[2*i+1]
303
+ if p1 // 2 != p2 // 2:
304
+ return _permute_dims(mat_mul_lines, permutation)
305
+ if p1 > p2:
306
+ args_array[i] = _a2m_transpose(mat_mul_lines.args[p1 // 2])
307
+ else:
308
+ args_array[i] = mat_mul_lines.args[p1 // 2]
309
+ return _a2m_tensor_product(*args_array)
310
+ else:
311
+ return expr
312
+
313
+
314
+ @_array2matrix.register(ArrayAdd)
315
+ def _(expr: ArrayAdd):
316
+ addends = [_array2matrix(arg) for arg in expr.args]
317
+ return _a2m_add(*addends)
318
+
319
+
320
+ @_array2matrix.register(ArrayElementwiseApplyFunc)
321
+ def _(expr: ArrayElementwiseApplyFunc):
322
+ subexpr = _array2matrix(expr.expr)
323
+ if isinstance(subexpr, MatrixExpr):
324
+ if subexpr.shape != (1, 1):
325
+ d = expr.function.bound_symbols[0]
326
+ w = Wild("w", exclude=[d])
327
+ p = Wild("p", exclude=[d])
328
+ m = expr.function.expr.match(w*d**p)
329
+ if m is not None:
330
+ return m[w]*HadamardPower(subexpr, m[p])
331
+ return ElementwiseApplyFunction(expr.function, subexpr)
332
+ else:
333
+ return ArrayElementwiseApplyFunc(expr.function, subexpr)
334
+
335
+
336
+ @_array2matrix.register(ArrayElement)
337
+ def _(expr: ArrayElement):
338
+ ret = _array2matrix(expr.name)
339
+ if isinstance(ret, MatrixExpr):
340
+ return MatrixElement(ret, *expr.indices)
341
+ return ArrayElement(ret, expr.indices)
342
+
343
+
344
+ @singledispatch
345
+ def _remove_trivial_dims(expr):
346
+ return expr, []
347
+
348
+
349
+ @_remove_trivial_dims.register(ArrayTensorProduct)
350
+ def _(expr: ArrayTensorProduct):
351
+ # Recognize expressions like [x, y] with shape (k, 1, k, 1) as `x*y.T`.
352
+ # The matrix expression has to be equivalent to the tensor product of the
353
+ # matrices, with trivial dimensions (i.e. dim=1) dropped.
354
+ # That is, add contractions over trivial dimensions:
355
+
356
+ removed = []
357
+ newargs = []
358
+ cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args]))
359
+ pending = None
360
+ prev_i = None
361
+ for i, arg in enumerate(expr.args):
362
+ current_range = list(range(cumul[i], cumul[i+1]))
363
+ if isinstance(arg, OneArray):
364
+ removed.extend(current_range)
365
+ continue
366
+ if not isinstance(arg, (MatrixExpr, MatrixBase)):
367
+ rarg, rem = _remove_trivial_dims(arg)
368
+ removed.extend(rem)
369
+ newargs.append(rarg)
370
+ continue
371
+ elif getattr(arg, "is_Identity", False) and arg.shape == (1, 1):
372
+ if arg.shape == (1, 1):
373
+ # Ignore identity matrices of shape (1, 1) - they are equivalent to scalar 1.
374
+ removed.extend(current_range)
375
+ continue
376
+ elif arg.shape == (1, 1):
377
+ arg, _ = _remove_trivial_dims(arg)
378
+ # Matrix is equivalent to scalar:
379
+ if len(newargs) == 0:
380
+ newargs.append(arg)
381
+ elif 1 in get_shape(newargs[-1]):
382
+ if newargs[-1].shape[1] == 1:
383
+ newargs[-1] = newargs[-1]*arg
384
+ else:
385
+ newargs[-1] = arg*newargs[-1]
386
+ removed.extend(current_range)
387
+ else:
388
+ newargs.append(arg)
389
+ elif 1 in arg.shape:
390
+ k = [i for i in arg.shape if i != 1][0]
391
+ if pending is None:
392
+ pending = k
393
+ prev_i = i
394
+ newargs.append(arg)
395
+ elif pending == k:
396
+ prev = newargs[-1]
397
+ if prev.shape[0] == 1:
398
+ d1 = cumul[prev_i]
399
+ prev = _a2m_transpose(prev)
400
+ else:
401
+ d1 = cumul[prev_i] + 1
402
+ if arg.shape[1] == 1:
403
+ d2 = cumul[i] + 1
404
+ arg = _a2m_transpose(arg)
405
+ else:
406
+ d2 = cumul[i]
407
+ newargs[-1] = prev*arg
408
+ pending = None
409
+ removed.extend([d1, d2])
410
+ else:
411
+ newargs.append(arg)
412
+ pending = k
413
+ prev_i = i
414
+ else:
415
+ newargs.append(arg)
416
+ pending = None
417
+ newexpr, newremoved = _a2m_tensor_product(*newargs), sorted(removed)
418
+ if isinstance(newexpr, ArrayTensorProduct):
419
+ newexpr, newremoved2 = _find_trivial_matrices_rewrite(newexpr)
420
+ newremoved = _combine_removed(-1, newremoved, newremoved2)
421
+ if isinstance(newexpr, ArrayTensorProduct):
422
+ newexpr, newremoved2 = _find_trivial_kronecker_products_broadcast(newexpr)
423
+ newremoved = _combine_removed(-1, newremoved, newremoved2)
424
+ return newexpr, newremoved
425
+
426
+
427
+ @_remove_trivial_dims.register(ArrayAdd)
428
+ def _(expr: ArrayAdd):
429
+ rec = [_remove_trivial_dims(arg) for arg in expr.args]
430
+ newargs, removed = zip(*rec)
431
+ if len({get_shape(i) for i in newargs}) > 1:
432
+ return expr, []
433
+ if len(removed) == 0:
434
+ return expr, removed
435
+ removed1 = removed[0]
436
+ return _a2m_add(*newargs), removed1
437
+
438
+
439
+ @_remove_trivial_dims.register(PermuteDims)
440
+ def _(expr: PermuteDims):
441
+ subexpr, subremoved = _remove_trivial_dims(expr.expr)
442
+ p = expr.permutation.array_form
443
+ pinv = _af_invert(expr.permutation.array_form)
444
+ shift = list(accumulate([1 if i in subremoved else 0 for i in range(len(p))]))
445
+ premoved = [pinv[i] for i in subremoved]
446
+ p2 = [e - shift[e] for e in p if e not in subremoved]
447
+ # TODO: check if subremoved should be permuted as well...
448
+ newexpr = _permute_dims(subexpr, p2)
449
+ premoved = sorted(premoved)
450
+ if newexpr != expr:
451
+ newexpr, removed2 = _remove_trivial_dims(_array2matrix(newexpr))
452
+ premoved = _combine_removed(-1, premoved, removed2)
453
+ return newexpr, premoved
454
+
455
+
456
+ @_remove_trivial_dims.register(ArrayContraction)
457
+ def _(expr: ArrayContraction):
458
+ new_expr, removed0 = _array_contraction_to_diagonal_multiple_identity(expr)
459
+ if new_expr != expr:
460
+ new_expr2, removed1 = _remove_trivial_dims(_array2matrix(new_expr))
461
+ removed = _combine_removed(-1, removed0, removed1)
462
+ return new_expr2, removed
463
+ rank1 = get_rank(expr)
464
+ expr, removed1 = remove_identity_matrices(expr)
465
+ if not isinstance(expr, ArrayContraction):
466
+ expr2, removed2 = _remove_trivial_dims(expr)
467
+ return expr2, _combine_removed(rank1, removed1, removed2)
468
+ newexpr, removed2 = _remove_trivial_dims(expr.expr)
469
+ shifts = list(accumulate([1 if i in removed2 else 0 for i in range(get_rank(expr.expr))]))
470
+ new_contraction_indices = [tuple(j for j in i if j not in removed2) for i in expr.contraction_indices]
471
+ # Remove possible empty tuples "()":
472
+ new_contraction_indices = [i for i in new_contraction_indices if len(i) > 0]
473
+ contraction_indices_flat = [j for i in expr.contraction_indices for j in i]
474
+ removed2 = [i for i in removed2 if i not in contraction_indices_flat]
475
+ new_contraction_indices = [tuple(j - shifts[j] for j in i) for i in new_contraction_indices]
476
+ # Shift removed2:
477
+ removed2 = ArrayContraction._push_indices_up(expr.contraction_indices, removed2)
478
+ removed = _combine_removed(rank1, removed1, removed2)
479
+ return _array_contraction(newexpr, *new_contraction_indices), list(removed)
480
+
481
+
482
+ def _remove_diagonalized_identity_matrices(expr: ArrayDiagonal):
483
+ assert isinstance(expr, ArrayDiagonal)
484
+ editor = _EditArrayContraction(expr)
485
+ mapping = {i: {j for j in editor.args_with_ind if i in j.indices} for i in range(-1, -1-editor.number_of_diagonal_indices, -1)}
486
+ removed = []
487
+ counter: int = 0
488
+ for i, arg_with_ind in enumerate(editor.args_with_ind):
489
+ counter += len(arg_with_ind.indices)
490
+ if isinstance(arg_with_ind.element, Identity):
491
+ if None in arg_with_ind.indices and any(i is not None and (i < 0) == True for i in arg_with_ind.indices):
492
+ diag_ind = [j for j in arg_with_ind.indices if j is not None][0]
493
+ other = [j for j in mapping[diag_ind] if j != arg_with_ind][0]
494
+ if not isinstance(other.element, MatrixExpr):
495
+ continue
496
+ if 1 not in other.element.shape:
497
+ continue
498
+ if None not in other.indices:
499
+ continue
500
+ editor.args_with_ind[i].element = None
501
+ none_index = other.indices.index(None)
502
+ other.element = DiagMatrix(other.element)
503
+ other_range = editor.get_absolute_range(other)
504
+ removed.extend([other_range[0] + none_index])
505
+ editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None]
506
+ removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, get_rank(expr.expr))
507
+ return editor.to_array_contraction(), removed
508
+
509
+
510
+ @_remove_trivial_dims.register(ArrayDiagonal)
511
+ def _(expr: ArrayDiagonal):
512
+ newexpr, removed = _remove_trivial_dims(expr.expr)
513
+ shifts = list(accumulate([0] + [1 if i in removed else 0 for i in range(get_rank(expr.expr))]))
514
+ new_diag_indices_map = {i: tuple(j for j in i if j not in removed) for i in expr.diagonal_indices}
515
+ for old_diag_tuple, new_diag_tuple in new_diag_indices_map.items():
516
+ if len(new_diag_tuple) == 1:
517
+ removed = [i for i in removed if i not in old_diag_tuple]
518
+ new_diag_indices = [tuple(j - shifts[j] for j in i) for i in new_diag_indices_map.values()]
519
+ rank = get_rank(expr.expr)
520
+ removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, rank)
521
+ removed = sorted(set(removed))
522
+ # If there are single axes to diagonalize remaining, it means that their
523
+ # corresponding dimension has been removed, they no longer need diagonalization:
524
+ new_diag_indices = [i for i in new_diag_indices if len(i) > 0]
525
+ if len(new_diag_indices) > 0:
526
+ newexpr2 = _array_diagonal(newexpr, *new_diag_indices, allow_trivial_diags=True)
527
+ else:
528
+ newexpr2 = newexpr
529
+ if isinstance(newexpr2, ArrayDiagonal):
530
+ newexpr3, removed2 = _remove_diagonalized_identity_matrices(newexpr2)
531
+ removed = _combine_removed(-1, removed, removed2)
532
+ return newexpr3, removed
533
+ else:
534
+ return newexpr2, removed
535
+
536
+
537
+ @_remove_trivial_dims.register(ElementwiseApplyFunction)
538
+ def _(expr: ElementwiseApplyFunction):
539
+ subexpr, removed = _remove_trivial_dims(expr.expr)
540
+ if subexpr.shape == (1, 1):
541
+ # TODO: move this to ElementwiseApplyFunction
542
+ return expr.function(subexpr), removed + [0, 1]
543
+ return ElementwiseApplyFunction(expr.function, subexpr), []
544
+
545
+
546
+ @_remove_trivial_dims.register(ArrayElementwiseApplyFunc)
547
+ def _(expr: ArrayElementwiseApplyFunc):
548
+ subexpr, removed = _remove_trivial_dims(expr.expr)
549
+ return ArrayElementwiseApplyFunc(expr.function, subexpr), removed
550
+
551
+
552
+ def convert_array_to_matrix(expr):
553
+ r"""
554
+ Recognize matrix expressions in codegen objects.
555
+
556
+ If more than one matrix multiplication line have been detected, return a
557
+ list with the matrix expressions.
558
+
559
+ Examples
560
+ ========
561
+
562
+ >>> from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array
563
+ >>> from sympy.tensor.array import tensorcontraction, tensorproduct
564
+ >>> from sympy import MatrixSymbol, Sum
565
+ >>> from sympy.abc import i, j, k, l, N
566
+ >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
567
+ >>> from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
568
+ >>> A = MatrixSymbol("A", N, N)
569
+ >>> B = MatrixSymbol("B", N, N)
570
+ >>> C = MatrixSymbol("C", N, N)
571
+ >>> D = MatrixSymbol("D", N, N)
572
+
573
+ >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
574
+ >>> cg = convert_indexed_to_array(expr)
575
+ >>> convert_array_to_matrix(cg)
576
+ A*B
577
+ >>> cg = convert_indexed_to_array(expr, first_indices=[k])
578
+ >>> convert_array_to_matrix(cg)
579
+ B.T*A.T
580
+
581
+ Transposition is detected:
582
+
583
+ >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
584
+ >>> cg = convert_indexed_to_array(expr)
585
+ >>> convert_array_to_matrix(cg)
586
+ A.T*B
587
+ >>> cg = convert_indexed_to_array(expr, first_indices=[k])
588
+ >>> convert_array_to_matrix(cg)
589
+ B.T*A
590
+
591
+ Detect the trace:
592
+
593
+ >>> expr = Sum(A[i, i], (i, 0, N-1))
594
+ >>> cg = convert_indexed_to_array(expr)
595
+ >>> convert_array_to_matrix(cg)
596
+ Trace(A)
597
+
598
+ Recognize some more complex traces:
599
+
600
+ >>> expr = Sum(A[i, j]*B[j, i], (i, 0, N-1), (j, 0, N-1))
601
+ >>> cg = convert_indexed_to_array(expr)
602
+ >>> convert_array_to_matrix(cg)
603
+ Trace(A*B)
604
+
605
+ More complicated expressions:
606
+
607
+ >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
608
+ >>> cg = convert_indexed_to_array(expr)
609
+ >>> convert_array_to_matrix(cg)
610
+ A*B.T*A.T
611
+
612
+ Expressions constructed from matrix expressions do not contain literal
613
+ indices, the positions of free indices are returned instead:
614
+
615
+ >>> expr = A*B
616
+ >>> cg = convert_matrix_to_array(expr)
617
+ >>> convert_array_to_matrix(cg)
618
+ A*B
619
+
620
+ If more than one line of matrix multiplications is detected, return
621
+ separate matrix multiplication factors embedded in a tensor product object:
622
+
623
+ >>> cg = tensorcontraction(tensorproduct(A, B, C, D), (1, 2), (5, 6))
624
+ >>> convert_array_to_matrix(cg)
625
+ ArrayTensorProduct(A*B, C*D)
626
+
627
+ The two lines have free indices at axes 0, 3 and 4, 7, respectively.
628
+ """
629
+ rec = _array2matrix(expr)
630
+ rec, removed = _remove_trivial_dims(rec)
631
+ return rec
632
+
633
+
634
+ def _array_diag2contr_diagmatrix(expr: ArrayDiagonal):
635
+ if isinstance(expr.expr, ArrayTensorProduct):
636
+ args = list(expr.expr.args)
637
+ diag_indices = list(expr.diagonal_indices)
638
+ mapping = _get_mapping_from_subranks([_get_subrank(arg) for arg in args])
639
+ tuple_links = [[mapping[j] for j in i] for i in diag_indices]
640
+ contr_indices = []
641
+ total_rank = get_rank(expr)
642
+ replaced = [False for arg in args]
643
+ for i, (abs_pos, rel_pos) in enumerate(zip(diag_indices, tuple_links)):
644
+ if len(abs_pos) != 2:
645
+ continue
646
+ (pos1_outer, pos1_inner), (pos2_outer, pos2_inner) = rel_pos
647
+ arg1 = args[pos1_outer]
648
+ arg2 = args[pos2_outer]
649
+ if get_rank(arg1) != 2 or get_rank(arg2) != 2:
650
+ if replaced[pos1_outer]:
651
+ diag_indices[i] = None
652
+ if replaced[pos2_outer]:
653
+ diag_indices[i] = None
654
+ continue
655
+ pos1_in2 = 1 - pos1_inner
656
+ pos2_in2 = 1 - pos2_inner
657
+ if arg1.shape[pos1_in2] == 1:
658
+ if arg1.shape[pos1_inner] != 1:
659
+ darg1 = DiagMatrix(arg1)
660
+ else:
661
+ darg1 = arg1
662
+ args.append(darg1)
663
+ contr_indices.append(((pos2_outer, pos2_inner), (len(args)-1, pos1_inner)))
664
+ total_rank += 1
665
+ diag_indices[i] = None
666
+ args[pos1_outer] = OneArray(arg1.shape[pos1_in2])
667
+ replaced[pos1_outer] = True
668
+ elif arg2.shape[pos2_in2] == 1:
669
+ if arg2.shape[pos2_inner] != 1:
670
+ darg2 = DiagMatrix(arg2)
671
+ else:
672
+ darg2 = arg2
673
+ args.append(darg2)
674
+ contr_indices.append(((pos1_outer, pos1_inner), (len(args)-1, pos2_inner)))
675
+ total_rank += 1
676
+ diag_indices[i] = None
677
+ args[pos2_outer] = OneArray(arg2.shape[pos2_in2])
678
+ replaced[pos2_outer] = True
679
+ diag_indices_new = [i for i in diag_indices if i is not None]
680
+ cumul = list(accumulate([0] + [get_rank(arg) for arg in args]))
681
+ contr_indices2 = [tuple(cumul[a] + b for a, b in i) for i in contr_indices]
682
+ tc = _array_contraction(
683
+ _array_tensor_product(*args), *contr_indices2
684
+ )
685
+ td = _array_diagonal(tc, *diag_indices_new)
686
+ return td
687
+ return expr
688
+
689
+
690
+ def _a2m_mul(*args):
691
+ if not any(isinstance(i, _CodegenArrayAbstract) for i in args):
692
+ from sympy.matrices.expressions.matmul import MatMul
693
+ return MatMul(*args).doit()
694
+ else:
695
+ return _array_contraction(
696
+ _array_tensor_product(*args),
697
+ *[(2*i-1, 2*i) for i in range(1, len(args))]
698
+ )
699
+
700
+
701
+ def _a2m_tensor_product(*args):
702
+ scalars = []
703
+ arrays = []
704
+ for arg in args:
705
+ if isinstance(arg, (MatrixExpr, _ArrayExpr, _CodegenArrayAbstract)):
706
+ arrays.append(arg)
707
+ else:
708
+ scalars.append(arg)
709
+ scalar = Mul.fromiter(scalars)
710
+ if len(arrays) == 0:
711
+ return scalar
712
+ if scalar != 1:
713
+ if isinstance(arrays[0], _CodegenArrayAbstract):
714
+ arrays = [scalar] + arrays
715
+ else:
716
+ arrays[0] *= scalar
717
+ return _array_tensor_product(*arrays)
718
+
719
+
720
+ def _a2m_add(*args):
721
+ if not any(isinstance(i, _CodegenArrayAbstract) for i in args):
722
+ from sympy.matrices.expressions.matadd import MatAdd
723
+ return MatAdd(*args).doit()
724
+ else:
725
+ return _array_add(*args)
726
+
727
+
728
+ def _a2m_trace(arg):
729
+ if isinstance(arg, _CodegenArrayAbstract):
730
+ return _array_contraction(arg, (0, 1))
731
+ else:
732
+ from sympy.matrices.expressions.trace import Trace
733
+ return Trace(arg)
734
+
735
+
736
+ def _a2m_transpose(arg):
737
+ if isinstance(arg, _CodegenArrayAbstract):
738
+ return _permute_dims(arg, [1, 0])
739
+ else:
740
+ from sympy.matrices.expressions.transpose import Transpose
741
+ return Transpose(arg).doit()
742
+
743
+
744
+ def identify_hadamard_products(expr: tUnion[ArrayContraction, ArrayDiagonal]):
745
+
746
+ editor: _EditArrayContraction = _EditArrayContraction(expr)
747
+
748
+ map_contr_to_args: tDict[FrozenSet, List[_ArgE]] = defaultdict(list)
749
+ map_ind_to_inds: tDict[Optional[int], int] = defaultdict(int)
750
+ for arg_with_ind in editor.args_with_ind:
751
+ for ind in arg_with_ind.indices:
752
+ map_ind_to_inds[ind] += 1
753
+ if None in arg_with_ind.indices:
754
+ continue
755
+ map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind)
756
+
757
+ k: FrozenSet[int]
758
+ v: List[_ArgE]
759
+ for k, v in map_contr_to_args.items():
760
+ make_trace: bool = False
761
+ if len(k) == 1 and next(iter(k)) >= 0 and sum(next(iter(k)) in i for i in map_contr_to_args) == 1:
762
+ # This is a trace: the arguments are fully contracted with only one
763
+ # index, and the index isn't used anywhere else:
764
+ make_trace = True
765
+ first_element = S.One
766
+ elif len(k) != 2:
767
+ # Hadamard product only defined for matrices:
768
+ continue
769
+ if len(v) == 1:
770
+ # Hadamard product with a single argument makes no sense:
771
+ continue
772
+ for ind in k:
773
+ if map_ind_to_inds[ind] <= 2:
774
+ # There is no other contraction, skip:
775
+ continue
776
+
777
+ def check_transpose(x):
778
+ x = [i if i >= 0 else -1-i for i in x]
779
+ return x == sorted(x)
780
+
781
+ # Check if expression is a trace:
782
+ if all(map_ind_to_inds[j] == len(v) and j >= 0 for j in k) and all(j >= 0 for j in k):
783
+ # This is a trace
784
+ make_trace = True
785
+ first_element = v[0].element
786
+ if not check_transpose(v[0].indices):
787
+ first_element = first_element.T
788
+ hadamard_factors = v[1:]
789
+ else:
790
+ hadamard_factors = v
791
+
792
+ # This is a Hadamard product:
793
+
794
+ hp = hadamard_product(*[i.element if check_transpose(i.indices) else Transpose(i.element) for i in hadamard_factors])
795
+ hp_indices = v[0].indices
796
+ if not check_transpose(hadamard_factors[0].indices):
797
+ hp_indices = list(reversed(hp_indices))
798
+ if make_trace:
799
+ hp = Trace(first_element*hp.T)._normalize()
800
+ hp_indices = []
801
+ editor.insert_after(v[0], _ArgE(hp, hp_indices))
802
+ for i in v:
803
+ editor.args_with_ind.remove(i)
804
+
805
+ return editor.to_array_contraction()
806
+
807
+
808
+ def identify_removable_identity_matrices(expr):
809
+ editor = _EditArrayContraction(expr)
810
+
811
+ flag = True
812
+ while flag:
813
+ flag = False
814
+ for arg_with_ind in editor.args_with_ind:
815
+ if isinstance(arg_with_ind.element, Identity):
816
+ k = arg_with_ind.element.shape[0]
817
+ # Candidate for removal:
818
+ if arg_with_ind.indices == [None, None]:
819
+ # Free identity matrix, will be cleared by _remove_trivial_dims:
820
+ continue
821
+ elif None in arg_with_ind.indices:
822
+ ind = [j for j in arg_with_ind.indices if j is not None][0]
823
+ counted = editor.count_args_with_index(ind)
824
+ if counted == 1:
825
+ # Identity matrix contracted only on one index with itself,
826
+ # transform to a OneArray(k) element:
827
+ editor.insert_after(arg_with_ind, OneArray(k))
828
+ editor.args_with_ind.remove(arg_with_ind)
829
+ flag = True
830
+ break
831
+ elif counted > 2:
832
+ # Case counted = 2 is a matrix multiplication by identity matrix, skip it.
833
+ # Case counted > 2 is a multiple contraction,
834
+ # this is a case where the contraction becomes a diagonalization if the
835
+ # identity matrix is dropped.
836
+ continue
837
+ elif arg_with_ind.indices[0] == arg_with_ind.indices[1]:
838
+ ind = arg_with_ind.indices[0]
839
+ counted = editor.count_args_with_index(ind)
840
+ if counted > 1:
841
+ editor.args_with_ind.remove(arg_with_ind)
842
+ flag = True
843
+ break
844
+ else:
845
+ # This is a trace, skip it as it will be recognized somewhere else:
846
+ pass
847
+ elif ask(Q.diagonal(arg_with_ind.element)):
848
+ if arg_with_ind.indices == [None, None]:
849
+ continue
850
+ elif None in arg_with_ind.indices:
851
+ pass
852
+ elif arg_with_ind.indices[0] == arg_with_ind.indices[1]:
853
+ ind = arg_with_ind.indices[0]
854
+ counted = editor.count_args_with_index(ind)
855
+ if counted == 3:
856
+ # A_ai B_bi D_ii ==> A_ai D_ij B_bj
857
+ ind_new = editor.get_new_contraction_index()
858
+ other_args = [j for j in editor.args_with_ind if j != arg_with_ind]
859
+ other_args[1].indices = [ind_new if j == ind else j for j in other_args[1].indices]
860
+ arg_with_ind.indices = [ind, ind_new]
861
+ flag = True
862
+ break
863
+
864
+ return editor.to_array_contraction()
865
+
866
+
867
+ def remove_identity_matrices(expr: ArrayContraction):
868
+ editor = _EditArrayContraction(expr)
869
+ removed: List[int] = []
870
+
871
+ permutation_map = {}
872
+
873
+ free_indices = list(accumulate([0] + [sum(i is None for i in arg.indices) for arg in editor.args_with_ind]))
874
+ free_map = dict(zip(editor.args_with_ind, free_indices[:-1]))
875
+
876
+ update_pairs = {}
877
+
878
+ for ind in range(editor.number_of_contraction_indices):
879
+ args = editor.get_args_with_index(ind)
880
+ identity_matrices = [i for i in args if isinstance(i.element, Identity)]
881
+ number_identity_matrices = len(identity_matrices)
882
+ # If the contraction involves a non-identity matrix and multiple identity matrices:
883
+ if number_identity_matrices != len(args) - 1 or number_identity_matrices == 0:
884
+ continue
885
+ # Get the non-identity element:
886
+ non_identity = [i for i in args if not isinstance(i.element, Identity)][0]
887
+ # Check that all identity matrices have at least one free index
888
+ # (otherwise they would be contractions to some other elements)
889
+ if any(None not in i.indices for i in identity_matrices):
890
+ continue
891
+ # Mark the identity matrices for removal:
892
+ for i in identity_matrices:
893
+ i.element = None
894
+ removed.extend(range(free_map[i], free_map[i] + len([j for j in i.indices if j is None])))
895
+ last_removed = removed.pop(-1)
896
+ update_pairs[last_removed, ind] = non_identity.indices[:]
897
+ # Remove the indices from the non-identity matrix, as the contraction
898
+ # no longer exists:
899
+ non_identity.indices = [None if i == ind else i for i in non_identity.indices]
900
+
901
+ removed.sort()
902
+
903
+ shifts = list(accumulate([1 if i in removed else 0 for i in range(get_rank(expr))]))
904
+ for (last_removed, ind), non_identity_indices in update_pairs.items():
905
+ pos = [free_map[non_identity] + i for i, e in enumerate(non_identity_indices) if e == ind]
906
+ assert len(pos) == 1
907
+ for j in pos:
908
+ permutation_map[j] = last_removed
909
+
910
+ editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None]
911
+ ret_expr = editor.to_array_contraction()
912
+ permutation = []
913
+ counter = 0
914
+ counter2 = 0
915
+ for j in range(get_rank(expr)):
916
+ if j in removed:
917
+ continue
918
+ if counter2 in permutation_map:
919
+ target = permutation_map[counter2]
920
+ permutation.append(target - shifts[target])
921
+ counter2 += 1
922
+ else:
923
+ while counter in permutation_map.values():
924
+ counter += 1
925
+ permutation.append(counter)
926
+ counter += 1
927
+ counter2 += 1
928
+ ret_expr2 = _permute_dims(ret_expr, _af_invert(permutation))
929
+ return ret_expr2, removed
930
+
931
+
932
+ def _combine_removed(dim: int, removed1: List[int], removed2: List[int]) -> List[int]:
933
+ # Concatenate two axis removal operations as performed by
934
+ # _remove_trivial_dims,
935
+ removed1 = sorted(removed1)
936
+ removed2 = sorted(removed2)
937
+ i = 0
938
+ j = 0
939
+ removed = []
940
+ while True:
941
+ if j >= len(removed2):
942
+ while i < len(removed1):
943
+ removed.append(removed1[i])
944
+ i += 1
945
+ break
946
+ elif i < len(removed1) and removed1[i] <= i + removed2[j]:
947
+ removed.append(removed1[i])
948
+ i += 1
949
+ else:
950
+ removed.append(i + removed2[j])
951
+ j += 1
952
+ return removed
953
+
954
+
955
+ def _array_contraction_to_diagonal_multiple_identity(expr: ArrayContraction):
956
+ editor = _EditArrayContraction(expr)
957
+ editor.track_permutation_start()
958
+ removed: List[int] = []
959
+ diag_index_counter: int = 0
960
+ for i in range(editor.number_of_contraction_indices):
961
+ identities = []
962
+ args = []
963
+ for j, arg in enumerate(editor.args_with_ind):
964
+ if i not in arg.indices:
965
+ continue
966
+ if isinstance(arg.element, Identity):
967
+ identities.append(arg)
968
+ else:
969
+ args.append(arg)
970
+ if len(identities) == 0:
971
+ continue
972
+ if len(args) + len(identities) < 3:
973
+ continue
974
+ new_diag_ind = -1 - diag_index_counter
975
+ diag_index_counter += 1
976
+ # Variable "flag" to control whether to skip this contraction set:
977
+ flag: bool = True
978
+ for i1, id1 in enumerate(identities):
979
+ if None not in id1.indices:
980
+ flag = True
981
+ break
982
+ free_pos = list(range(*editor.get_absolute_free_range(id1)))[0]
983
+ editor._track_permutation[-1].append(free_pos) # type: ignore
984
+ id1.element = None
985
+ flag = False
986
+ break
987
+ if flag:
988
+ continue
989
+ for arg in identities[:i1] + identities[i1+1:]:
990
+ arg.element = None
991
+ removed.extend(range(*editor.get_absolute_free_range(arg)))
992
+ for arg in args:
993
+ arg.indices = [new_diag_ind if j == i else j for j in arg.indices]
994
+ for j, e in enumerate(editor.args_with_ind):
995
+ if e.element is None:
996
+ editor._track_permutation[j] = None # type: ignore
997
+ editor._track_permutation = [i for i in editor._track_permutation if i is not None] # type: ignore
998
+ # Renumber permutation array form in order to deal with deleted positions:
999
+ remap = {e: i for i, e in enumerate(sorted({k for j in editor._track_permutation for k in j}))}
1000
+ editor._track_permutation = [[remap[j] for j in i] for i in editor._track_permutation]
1001
+ editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None]
1002
+ new_expr = editor.to_array_contraction()
1003
+ return new_expr, removed
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+
3
+ from sympy import Function
4
+ from sympy.combinatorics.permutations import _af_invert
5
+ from sympy.concrete.summations import Sum
6
+ from sympy.core.add import Add
7
+ from sympy.core.mul import Mul
8
+ from sympy.core.numbers import Integer
9
+ from sympy.core.power import Pow
10
+ from sympy.core.sorting import default_sort_key
11
+ from sympy.functions.special.tensor_functions import KroneckerDelta
12
+ from sympy.tensor.array.expressions import ArrayElementwiseApplyFunc
13
+ from sympy.tensor.indexed import (Indexed, IndexedBase)
14
+ from sympy.combinatorics import Permutation
15
+ from sympy.matrices.expressions.matexpr import MatrixElement
16
+ from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal, \
17
+ get_shape, ArrayElement, _array_tensor_product, _array_diagonal, _array_contraction, _array_add, \
18
+ _permute_dims, OneArray, ArrayAdd
19
+ from sympy.tensor.array.expressions.utils import _get_argindex, _get_diagonal_indices
20
+
21
+
22
+ def convert_indexed_to_array(expr, first_indices=None):
23
+ r"""
24
+ Parse indexed expression into a form useful for code generation.
25
+
26
+ Examples
27
+ ========
28
+
29
+ >>> from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array
30
+ >>> from sympy import MatrixSymbol, Sum, symbols
31
+
32
+ >>> i, j, k, d = symbols("i j k d")
33
+ >>> M = MatrixSymbol("M", d, d)
34
+ >>> N = MatrixSymbol("N", d, d)
35
+
36
+ Recognize the trace in summation form:
37
+
38
+ >>> expr = Sum(M[i, i], (i, 0, d-1))
39
+ >>> convert_indexed_to_array(expr)
40
+ ArrayContraction(M, (0, 1))
41
+
42
+ Recognize the extraction of the diagonal by using the same index `i` on
43
+ both axes of the matrix:
44
+
45
+ >>> expr = M[i, i]
46
+ >>> convert_indexed_to_array(expr)
47
+ ArrayDiagonal(M, (0, 1))
48
+
49
+ This function can help perform the transformation expressed in two
50
+ different mathematical notations as:
51
+
52
+ `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}`
53
+
54
+ Recognize the matrix multiplication in summation form:
55
+
56
+ >>> expr = Sum(M[i, j]*N[j, k], (j, 0, d-1))
57
+ >>> convert_indexed_to_array(expr)
58
+ ArrayContraction(ArrayTensorProduct(M, N), (1, 2))
59
+
60
+ Specify that ``k`` has to be the starting index:
61
+
62
+ >>> convert_indexed_to_array(expr, first_indices=[k])
63
+ ArrayContraction(ArrayTensorProduct(N, M), (0, 3))
64
+ """
65
+
66
+ result, indices = _convert_indexed_to_array(expr)
67
+
68
+ if any(isinstance(i, (int, Integer)) for i in indices):
69
+ result = ArrayElement(result, indices)
70
+ indices = []
71
+
72
+ if not first_indices:
73
+ return result
74
+
75
+ def _check_is_in(elem, indices):
76
+ if elem in indices:
77
+ return True
78
+ if any(elem in i for i in indices if isinstance(i, frozenset)):
79
+ return True
80
+ return False
81
+
82
+ repl = {j: i for i in indices if isinstance(i, frozenset) for j in i}
83
+ first_indices = [repl.get(i, i) for i in first_indices]
84
+ for i in first_indices:
85
+ if not _check_is_in(i, indices):
86
+ first_indices.remove(i)
87
+ first_indices.extend([i for i in indices if not _check_is_in(i, first_indices)])
88
+
89
+ def _get_pos(elem, indices):
90
+ if elem in indices:
91
+ return indices.index(elem)
92
+ for i, e in enumerate(indices):
93
+ if not isinstance(e, frozenset):
94
+ continue
95
+ if elem in e:
96
+ return i
97
+ raise ValueError("not found")
98
+
99
+ permutation = _af_invert([_get_pos(i, first_indices) for i in indices])
100
+ if isinstance(result, ArrayAdd):
101
+ return _array_add(*[_permute_dims(arg, permutation) for arg in result.args])
102
+ else:
103
+ return _permute_dims(result, permutation)
104
+
105
+
106
+ def _convert_indexed_to_array(expr):
107
+ if isinstance(expr, Sum):
108
+ function = expr.function
109
+ summation_indices = expr.variables
110
+ subexpr, subindices = _convert_indexed_to_array(function)
111
+ subindicessets = {j: i for i in subindices if isinstance(i, frozenset) for j in i}
112
+ summation_indices = sorted({subindicessets.get(i, i) for i in summation_indices}, key=default_sort_key)
113
+ # TODO: check that Kronecker delta is only contracted to one other element:
114
+ kronecker_indices = set()
115
+ if isinstance(function, Mul):
116
+ for arg in function.args:
117
+ if not isinstance(arg, KroneckerDelta):
118
+ continue
119
+ arg_indices = sorted(set(arg.indices), key=default_sort_key)
120
+ if len(arg_indices) == 2:
121
+ kronecker_indices.update(arg_indices)
122
+ kronecker_indices = sorted(kronecker_indices, key=default_sort_key)
123
+ # Check dimensional consistency:
124
+ shape = get_shape(subexpr)
125
+ if shape:
126
+ for ind, istart, iend in expr.limits:
127
+ i = _get_argindex(subindices, ind)
128
+ if istart != 0 or iend+1 != shape[i]:
129
+ raise ValueError("summation index and array dimension mismatch: %s" % ind)
130
+ contraction_indices = []
131
+ subindices = list(subindices)
132
+ if isinstance(subexpr, ArrayDiagonal):
133
+ diagonal_indices = list(subexpr.diagonal_indices)
134
+ dindices = subindices[-len(diagonal_indices):]
135
+ subindices = subindices[:-len(diagonal_indices)]
136
+ for index in summation_indices:
137
+ if index in dindices:
138
+ position = dindices.index(index)
139
+ contraction_indices.append(diagonal_indices[position])
140
+ diagonal_indices[position] = None
141
+ diagonal_indices = [i for i in diagonal_indices if i is not None]
142
+ for i, ind in enumerate(subindices):
143
+ if ind in summation_indices:
144
+ pass
145
+ if diagonal_indices:
146
+ subexpr = _array_diagonal(subexpr.expr, *diagonal_indices)
147
+ else:
148
+ subexpr = subexpr.expr
149
+
150
+ axes_contraction = defaultdict(list)
151
+ for i, ind in enumerate(subindices):
152
+ include = all(j not in kronecker_indices for j in ind) if isinstance(ind, frozenset) else ind not in kronecker_indices
153
+ if ind in summation_indices and include:
154
+ axes_contraction[ind].append(i)
155
+ subindices[i] = None
156
+ for k, v in axes_contraction.items():
157
+ if any(i in kronecker_indices for i in k) if isinstance(k, frozenset) else k in kronecker_indices:
158
+ continue
159
+ contraction_indices.append(tuple(v))
160
+ free_indices = [i for i in subindices if i is not None]
161
+ indices_ret = list(free_indices)
162
+ indices_ret.sort(key=lambda x: free_indices.index(x))
163
+ return _array_contraction(
164
+ subexpr,
165
+ *contraction_indices,
166
+ free_indices=free_indices
167
+ ), tuple(indices_ret)
168
+ if isinstance(expr, Mul):
169
+ args, indices = zip(*[_convert_indexed_to_array(arg) for arg in expr.args])
170
+ # Check if there are KroneckerDelta objects:
171
+ kronecker_delta_repl = {}
172
+ for arg in args:
173
+ if not isinstance(arg, KroneckerDelta):
174
+ continue
175
+ # Diagonalize two indices:
176
+ i, j = arg.indices
177
+ kindices = set(arg.indices)
178
+ if i in kronecker_delta_repl:
179
+ kindices.update(kronecker_delta_repl[i])
180
+ if j in kronecker_delta_repl:
181
+ kindices.update(kronecker_delta_repl[j])
182
+ kindices = frozenset(kindices)
183
+ for index in kindices:
184
+ kronecker_delta_repl[index] = kindices
185
+ # Remove KroneckerDelta objects, their relations should be handled by
186
+ # ArrayDiagonal:
187
+ newargs = []
188
+ newindices = []
189
+ for arg, loc_indices in zip(args, indices):
190
+ if isinstance(arg, KroneckerDelta):
191
+ continue
192
+ newargs.append(arg)
193
+ newindices.append(loc_indices)
194
+ flattened_indices = [kronecker_delta_repl.get(j, j) for i in newindices for j in i]
195
+ diagonal_indices, ret_indices = _get_diagonal_indices(flattened_indices)
196
+ tp = _array_tensor_product(*newargs)
197
+ if diagonal_indices:
198
+ return _array_diagonal(tp, *diagonal_indices), ret_indices
199
+ else:
200
+ return tp, ret_indices
201
+ if isinstance(expr, MatrixElement):
202
+ indices = expr.args[1:]
203
+ diagonal_indices, ret_indices = _get_diagonal_indices(indices)
204
+ if diagonal_indices:
205
+ return _array_diagonal(expr.args[0], *diagonal_indices), ret_indices
206
+ else:
207
+ return expr.args[0], ret_indices
208
+ if isinstance(expr, ArrayElement):
209
+ indices = expr.indices
210
+ diagonal_indices, ret_indices = _get_diagonal_indices(indices)
211
+ if diagonal_indices:
212
+ return _array_diagonal(expr.name, *diagonal_indices), ret_indices
213
+ else:
214
+ return expr.name, ret_indices
215
+ if isinstance(expr, Indexed):
216
+ indices = expr.indices
217
+ diagonal_indices, ret_indices = _get_diagonal_indices(indices)
218
+ if diagonal_indices:
219
+ return _array_diagonal(expr.base, *diagonal_indices), ret_indices
220
+ else:
221
+ return expr.args[0], ret_indices
222
+ if isinstance(expr, IndexedBase):
223
+ raise NotImplementedError
224
+ if isinstance(expr, KroneckerDelta):
225
+ return expr, expr.indices
226
+ if isinstance(expr, Add):
227
+ args, indices = zip(*[_convert_indexed_to_array(arg) for arg in expr.args])
228
+ args = list(args)
229
+ # Check if all indices are compatible. Otherwise expand the dimensions:
230
+ index0 = []
231
+ shape0 = []
232
+ for arg, arg_indices in zip(args, indices):
233
+ arg_indices_set = set(arg_indices)
234
+ arg_indices_missing = arg_indices_set.difference(index0)
235
+ index0.extend([i for i in arg_indices if i in arg_indices_missing])
236
+ arg_shape = get_shape(arg)
237
+ shape0.extend([arg_shape[i] for i, e in enumerate(arg_indices) if e in arg_indices_missing])
238
+ for i, (arg, arg_indices) in enumerate(zip(args, indices)):
239
+ if len(arg_indices) < len(index0):
240
+ missing_indices_pos = [i for i, e in enumerate(index0) if e not in arg_indices]
241
+ missing_shape = [shape0[i] for i in missing_indices_pos]
242
+ arg_indices = tuple(index0[j] for j in missing_indices_pos) + arg_indices
243
+ args[i] = _array_tensor_product(OneArray(*missing_shape), args[i])
244
+ permutation = Permutation([arg_indices.index(j) for j in index0])
245
+ # Perform index permutations:
246
+ args[i] = _permute_dims(args[i], permutation)
247
+ return _array_add(*args), tuple(index0)
248
+ if isinstance(expr, Pow):
249
+ subexpr, subindices = _convert_indexed_to_array(expr.base)
250
+ if isinstance(expr.exp, (int, Integer)):
251
+ diags = zip(*[(2*i, 2*i + 1) for i in range(expr.exp)])
252
+ arr = _array_diagonal(_array_tensor_product(*[subexpr for i in range(expr.exp)]), *diags)
253
+ return arr, subindices
254
+ if isinstance(expr, Function):
255
+ subexpr, subindices = _convert_indexed_to_array(expr.args[0])
256
+ return ArrayElementwiseApplyFunc(type(expr), subexpr), subindices
257
+ return expr, ()
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy import KroneckerProduct
2
+ from sympy.core.basic import Basic
3
+ from sympy.core.function import Lambda
4
+ from sympy.core.mul import Mul
5
+ from sympy.core.numbers import Integer
6
+ from sympy.core.power import Pow
7
+ from sympy.core.singleton import S
8
+ from sympy.core.symbol import (Dummy, symbols)
9
+ from sympy.matrices.expressions.hadamard import (HadamardPower, HadamardProduct)
10
+ from sympy.matrices.expressions.matadd import MatAdd
11
+ from sympy.matrices.expressions.matmul import MatMul
12
+ from sympy.matrices.expressions.matpow import MatPow
13
+ from sympy.matrices.expressions.trace import Trace
14
+ from sympy.matrices.expressions.transpose import Transpose
15
+ from sympy.matrices.expressions.matexpr import MatrixExpr
16
+ from sympy.tensor.array.expressions.array_expressions import \
17
+ ArrayElementwiseApplyFunc, _array_tensor_product, _array_contraction, \
18
+ _array_diagonal, _array_add, _permute_dims, Reshape
19
+
20
+
21
+ def convert_matrix_to_array(expr: Basic) -> Basic:
22
+ if isinstance(expr, MatMul):
23
+ args_nonmat = []
24
+ args = []
25
+ for arg in expr.args:
26
+ if isinstance(arg, MatrixExpr):
27
+ args.append(arg)
28
+ else:
29
+ args_nonmat.append(convert_matrix_to_array(arg))
30
+ contractions = [(2*i+1, 2*i+2) for i in range(len(args)-1)]
31
+ scalar = _array_tensor_product(*args_nonmat) if args_nonmat else S.One
32
+ if scalar == 1:
33
+ tprod = _array_tensor_product(
34
+ *[convert_matrix_to_array(arg) for arg in args])
35
+ else:
36
+ tprod = _array_tensor_product(
37
+ scalar,
38
+ *[convert_matrix_to_array(arg) for arg in args])
39
+ return _array_contraction(
40
+ tprod,
41
+ *contractions
42
+ )
43
+ elif isinstance(expr, MatAdd):
44
+ return _array_add(
45
+ *[convert_matrix_to_array(arg) for arg in expr.args]
46
+ )
47
+ elif isinstance(expr, Transpose):
48
+ return _permute_dims(
49
+ convert_matrix_to_array(expr.args[0]), [1, 0]
50
+ )
51
+ elif isinstance(expr, Trace):
52
+ inner_expr: MatrixExpr = convert_matrix_to_array(expr.arg) # type: ignore
53
+ return _array_contraction(inner_expr, (0, len(inner_expr.shape) - 1))
54
+ elif isinstance(expr, Mul):
55
+ return _array_tensor_product(*[convert_matrix_to_array(i) for i in expr.args])
56
+ elif isinstance(expr, Pow):
57
+ base = convert_matrix_to_array(expr.base)
58
+ if (expr.exp > 0) == True:
59
+ return _array_tensor_product(*[base for i in range(expr.exp)])
60
+ else:
61
+ return expr
62
+ elif isinstance(expr, MatPow):
63
+ base = convert_matrix_to_array(expr.base)
64
+ if expr.exp.is_Integer != True:
65
+ b = symbols("b", cls=Dummy)
66
+ return ArrayElementwiseApplyFunc(Lambda(b, b**expr.exp), convert_matrix_to_array(base))
67
+ elif (expr.exp > 0) == True:
68
+ return convert_matrix_to_array(MatMul.fromiter(base for i in range(expr.exp)))
69
+ else:
70
+ return expr
71
+ elif isinstance(expr, HadamardProduct):
72
+ tp = _array_tensor_product(*[convert_matrix_to_array(arg) for arg in expr.args])
73
+ diag = [[2*i for i in range(len(expr.args))], [2*i+1 for i in range(len(expr.args))]]
74
+ return _array_diagonal(tp, *diag)
75
+ elif isinstance(expr, HadamardPower):
76
+ base, exp = expr.args
77
+ if isinstance(exp, Integer) and exp > 0:
78
+ return convert_matrix_to_array(HadamardProduct.fromiter(base for i in range(exp)))
79
+ else:
80
+ d = Dummy("d")
81
+ return ArrayElementwiseApplyFunc(Lambda(d, d**exp), base)
82
+ elif isinstance(expr, KroneckerProduct):
83
+ kp_args = [convert_matrix_to_array(arg) for arg in expr.args]
84
+ permutation = [2*i for i in range(len(kp_args))] + [2*i + 1 for i in range(len(kp_args))]
85
+ return Reshape(_permute_dims(_array_tensor_product(*kp_args), permutation), expr.shape)
86
+ else:
87
+ return expr
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (209 Bytes). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_array_expressions.cpython-311.pyc ADDED
Binary file (63.7 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_arrayexpr_derivatives.cpython-311.pyc ADDED
Binary file (5.03 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_as_explicit.cpython-311.pyc ADDED
Binary file (6.31 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_indexed.cpython-311.pyc ADDED
Binary file (5.87 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_matrix.cpython-311.pyc ADDED
Binary file (57.5 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_indexed_to_array.cpython-311.pyc ADDED
Binary file (22.1 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_matrix_to_array.cpython-311.pyc ADDED
Binary file (8.6 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_deprecated_conv_modules.cpython-311.pyc ADDED
Binary file (2.65 kB). View file
 
.venv/lib/python3.11/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from sympy import tensordiagonal, eye, KroneckerDelta, Array
4
+ from sympy.core.symbol import symbols
5
+ from sympy.functions.elementary.trigonometric import (cos, sin)
6
+ from sympy.matrices.expressions.diagonal import DiagMatrix
7
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
8
+ from sympy.matrices.expressions.special import ZeroMatrix
9
+ from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensorproduct)
10
+ from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray
11
+ from sympy.combinatorics import Permutation
12
+ from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, ArraySymbol, ArrayElement, \
13
+ PermuteDims, ArrayContraction, ArrayTensorProduct, ArrayDiagonal, \
14
+ ArrayAdd, nest_permutation, ArrayElementwiseApplyFunc, _EditArrayContraction, _ArgE, _array_tensor_product, \
15
+ _array_contraction, _array_diagonal, _array_add, _permute_dims, Reshape
16
+ from sympy.testing.pytest import raises
17
+
18
+ i, j, k, l, m, n = symbols("i j k l m n")
19
+
20
+
21
+ M = ArraySymbol("M", (k, k))
22
+ N = ArraySymbol("N", (k, k))
23
+ P = ArraySymbol("P", (k, k))
24
+ Q = ArraySymbol("Q", (k, k))
25
+
26
+ A = ArraySymbol("A", (k, k))
27
+ B = ArraySymbol("B", (k, k))
28
+ C = ArraySymbol("C", (k, k))
29
+ D = ArraySymbol("D", (k, k))
30
+
31
+ X = ArraySymbol("X", (k, k))
32
+ Y = ArraySymbol("Y", (k, k))
33
+
34
+ a = ArraySymbol("a", (k, 1))
35
+ b = ArraySymbol("b", (k, 1))
36
+ c = ArraySymbol("c", (k, 1))
37
+ d = ArraySymbol("d", (k, 1))
38
+
39
+
40
+ def test_array_symbol_and_element():
41
+ A = ArraySymbol("A", (2,))
42
+ A0 = ArrayElement(A, (0,))
43
+ A1 = ArrayElement(A, (1,))
44
+ assert A[0] == A0
45
+ assert A[1] != A0
46
+ assert A.as_explicit() == ImmutableDenseNDimArray([A0, A1])
47
+
48
+ A2 = tensorproduct(A, A)
49
+ assert A2.shape == (2, 2)
50
+ # TODO: not yet supported:
51
+ # assert A2.as_explicit() == Array([[A[0]*A[0], A[1]*A[0]], [A[0]*A[1], A[1]*A[1]]])
52
+ A3 = tensorcontraction(A2, (0, 1))
53
+ assert A3.shape == ()
54
+ # TODO: not yet supported:
55
+ # assert A3.as_explicit() == Array([])
56
+
57
+ A = ArraySymbol("A", (2, 3, 4))
58
+ Ae = A.as_explicit()
59
+ assert Ae == ImmutableDenseNDimArray(
60
+ [[[ArrayElement(A, (i, j, k)) for k in range(4)] for j in range(3)] for i in range(2)])
61
+
62
+ p = _permute_dims(A, Permutation(0, 2, 1))
63
+ assert isinstance(p, PermuteDims)
64
+
65
+ A = ArraySymbol("A", (2,))
66
+ raises(IndexError, lambda: A[()])
67
+ raises(IndexError, lambda: A[0, 1])
68
+ raises(ValueError, lambda: A[-1])
69
+ raises(ValueError, lambda: A[2])
70
+
71
+ O = OneArray(3, 4)
72
+ Z = ZeroArray(m, n)
73
+
74
+ raises(IndexError, lambda: O[()])
75
+ raises(IndexError, lambda: O[1, 2, 3])
76
+ raises(ValueError, lambda: O[3, 0])
77
+ raises(ValueError, lambda: O[0, 4])
78
+
79
+ assert O[1, 2] == 1
80
+ assert Z[1, 2] == 0
81
+
82
+
83
+ def test_zero_array():
84
+ assert ZeroArray() == 0
85
+ assert ZeroArray().is_Integer
86
+
87
+ za = ZeroArray(3, 2, 4)
88
+ assert za.shape == (3, 2, 4)
89
+ za_e = za.as_explicit()
90
+ assert za_e.shape == (3, 2, 4)
91
+
92
+ m, n, k = symbols("m n k")
93
+ za = ZeroArray(m, n, k, 2)
94
+ assert za.shape == (m, n, k, 2)
95
+ raises(ValueError, lambda: za.as_explicit())
96
+
97
+
98
+ def test_one_array():
99
+ assert OneArray() == 1
100
+ assert OneArray().is_Integer
101
+
102
+ oa = OneArray(3, 2, 4)
103
+ assert oa.shape == (3, 2, 4)
104
+ oa_e = oa.as_explicit()
105
+ assert oa_e.shape == (3, 2, 4)
106
+
107
+ m, n, k = symbols("m n k")
108
+ oa = OneArray(m, n, k, 2)
109
+ assert oa.shape == (m, n, k, 2)
110
+ raises(ValueError, lambda: oa.as_explicit())
111
+
112
+
113
+ def test_arrayexpr_contraction_construction():
114
+
115
+ cg = _array_contraction(A)
116
+ assert cg == A
117
+
118
+ cg = _array_contraction(_array_tensor_product(A, B), (1, 0))
119
+ assert cg == _array_contraction(_array_tensor_product(A, B), (0, 1))
120
+
121
+ cg = _array_contraction(_array_tensor_product(M, N), (0, 1))
122
+ indtup = cg._get_contraction_tuples()
123
+ assert indtup == [[(0, 0), (0, 1)]]
124
+ assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(0, 1)]
125
+
126
+ cg = _array_contraction(_array_tensor_product(M, N), (1, 2))
127
+ indtup = cg._get_contraction_tuples()
128
+ assert indtup == [[(0, 1), (1, 0)]]
129
+ assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(1, 2)]
130
+
131
+ cg = _array_contraction(_array_tensor_product(M, M, N), (1, 4), (2, 5))
132
+ indtup = cg._get_contraction_tuples()
133
+ assert indtup == [[(0, 0), (1, 1)], [(0, 1), (2, 0)]]
134
+ assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(0, 3), (1, 4)]
135
+
136
+ # Test removal of trivial contraction:
137
+ assert _array_contraction(a, (1,)) == a
138
+ assert _array_contraction(
139
+ _array_tensor_product(a, b), (0, 2), (1,), (3,)) == _array_contraction(
140
+ _array_tensor_product(a, b), (0, 2))
141
+
142
+
143
+ def test_arrayexpr_array_flatten():
144
+
145
+ # Flatten nested ArrayTensorProduct objects:
146
+ expr1 = _array_tensor_product(M, N)
147
+ expr2 = _array_tensor_product(P, Q)
148
+ expr = _array_tensor_product(expr1, expr2)
149
+ assert expr == _array_tensor_product(M, N, P, Q)
150
+ assert expr.args == (M, N, P, Q)
151
+
152
+ # Flatten mixed ArrayTensorProduct and ArrayContraction objects:
153
+ cg1 = _array_contraction(expr1, (1, 2))
154
+ cg2 = _array_contraction(expr2, (0, 3))
155
+
156
+ expr = _array_tensor_product(cg1, cg2)
157
+ assert expr == _array_contraction(_array_tensor_product(M, N, P, Q), (1, 2), (4, 7))
158
+
159
+ expr = _array_tensor_product(M, cg1)
160
+ assert expr == _array_contraction(_array_tensor_product(M, M, N), (3, 4))
161
+
162
+ # Flatten nested ArrayContraction objects:
163
+ cgnested = _array_contraction(cg1, (0, 1))
164
+ assert cgnested == _array_contraction(_array_tensor_product(M, N), (0, 3), (1, 2))
165
+
166
+ cgnested = _array_contraction(_array_tensor_product(cg1, cg2), (0, 3))
167
+ assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 6), (1, 2), (4, 7))
168
+
169
+ cg3 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4))
170
+ cgnested = _array_contraction(cg3, (0, 1))
171
+ assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 5), (1, 3), (2, 4))
172
+
173
+ cgnested = _array_contraction(cg3, (0, 3), (1, 2))
174
+ assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 7), (1, 3), (2, 4), (5, 6))
175
+
176
+ cg4 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7))
177
+ cgnested = _array_contraction(cg4, (0, 1))
178
+ assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7))
179
+
180
+ cgnested = _array_contraction(cg4, (0, 1), (2, 3))
181
+ assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7), (4, 6))
182
+
183
+ cg = _array_diagonal(cg4)
184
+ assert cg == cg4
185
+ assert isinstance(cg, type(cg4))
186
+
187
+ # Flatten nested ArrayDiagonal objects:
188
+ cg1 = _array_diagonal(expr1, (1, 2))
189
+ cg2 = _array_diagonal(expr2, (0, 3))
190
+ cg3 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4))
191
+ cg4 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7))
192
+
193
+ cgnested = _array_diagonal(cg1, (0, 1))
194
+ assert cgnested == _array_diagonal(_array_tensor_product(M, N), (1, 2), (0, 3))
195
+
196
+ cgnested = _array_diagonal(cg3, (1, 2))
197
+ assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4), (5, 6))
198
+
199
+ cgnested = _array_diagonal(cg4, (1, 2))
200
+ assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7), (2, 4))
201
+
202
+ cg = _array_add(M, N)
203
+ cg2 = _array_add(cg, P)
204
+ assert isinstance(cg2, ArrayAdd)
205
+ assert cg2.args == (M, N, P)
206
+ assert cg2.shape == (k, k)
207
+
208
+ expr = _array_tensor_product(_array_diagonal(X, (0, 1)), _array_diagonal(A, (0, 1)))
209
+ assert expr == _array_diagonal(_array_tensor_product(X, A), (0, 1), (2, 3))
210
+
211
+ expr1 = _array_diagonal(_array_tensor_product(X, A), (1, 2))
212
+ expr2 = _array_tensor_product(expr1, a)
213
+ assert expr2 == _permute_dims(_array_diagonal(_array_tensor_product(X, A, a), (1, 2)), [0, 1, 4, 2, 3])
214
+
215
+ expr1 = _array_contraction(_array_tensor_product(X, A), (1, 2))
216
+ expr2 = _array_tensor_product(expr1, a)
217
+ assert isinstance(expr2, ArrayContraction)
218
+ assert isinstance(expr2.expr, ArrayTensorProduct)
219
+
220
+ cg = _array_tensor_product(_array_diagonal(_array_tensor_product(A, X, Y), (0, 3), (1, 5)), a, b)
221
+ assert cg == _permute_dims(_array_diagonal(_array_tensor_product(A, X, Y, a, b), (0, 3), (1, 5)), [0, 1, 6, 7, 2, 3, 4, 5])
222
+
223
+
224
+ def test_arrayexpr_array_diagonal():
225
+ cg = _array_diagonal(M, (1, 0))
226
+ assert cg == _array_diagonal(M, (0, 1))
227
+
228
+ cg = _array_diagonal(_array_tensor_product(M, N, P), (4, 1), (2, 0))
229
+ assert cg == _array_diagonal(_array_tensor_product(M, N, P), (1, 4), (0, 2))
230
+
231
+ cg = _array_diagonal(_array_tensor_product(M, N), (1, 2), (3,), allow_trivial_diags=True)
232
+ assert cg == _permute_dims(_array_diagonal(_array_tensor_product(M, N), (1, 2)), [0, 2, 1])
233
+
234
+ Ax = ArraySymbol("Ax", shape=(1, 2, 3, 4, 3, 5, 6, 2, 7))
235
+ cg = _array_diagonal(Ax, (1, 7), (3,), (2, 4), (6,), allow_trivial_diags=True)
236
+ assert cg == _permute_dims(_array_diagonal(Ax, (1, 7), (2, 4)), [0, 2, 4, 5, 1, 6, 3])
237
+
238
+ cg = _array_diagonal(M, (0,), allow_trivial_diags=True)
239
+ assert cg == _permute_dims(M, [1, 0])
240
+
241
+ raises(ValueError, lambda: _array_diagonal(M, (0, 0)))
242
+
243
+
244
+ def test_arrayexpr_array_shape():
245
+ expr = _array_tensor_product(M, N, P, Q)
246
+ assert expr.shape == (k, k, k, k, k, k, k, k)
247
+ Z = MatrixSymbol("Z", m, n)
248
+ expr = _array_tensor_product(M, Z)
249
+ assert expr.shape == (k, k, m, n)
250
+ expr2 = _array_contraction(expr, (0, 1))
251
+ assert expr2.shape == (m, n)
252
+ expr2 = _array_diagonal(expr, (0, 1))
253
+ assert expr2.shape == (m, n, k)
254
+ exprp = _permute_dims(expr, [2, 1, 3, 0])
255
+ assert exprp.shape == (m, k, n, k)
256
+ expr3 = _array_tensor_product(N, Z)
257
+ expr2 = _array_add(expr, expr3)
258
+ assert expr2.shape == (k, k, m, n)
259
+
260
+ # Contraction along axes with discordant dimensions:
261
+ raises(ValueError, lambda: _array_contraction(expr, (1, 2)))
262
+ # Also diagonal needs the same dimensions:
263
+ raises(ValueError, lambda: _array_diagonal(expr, (1, 2)))
264
+ # Diagonal requires at least to axes to compute the diagonal:
265
+ raises(ValueError, lambda: _array_diagonal(expr, (1,)))
266
+
267
+
268
+ def test_arrayexpr_permutedims_sink():
269
+
270
+ cg = _permute_dims(_array_tensor_product(M, N), [0, 1, 3, 2], nest_permutation=False)
271
+ sunk = nest_permutation(cg)
272
+ assert sunk == _array_tensor_product(M, _permute_dims(N, [1, 0]))
273
+
274
+ cg = _permute_dims(_array_tensor_product(M, N), [1, 0, 3, 2], nest_permutation=False)
275
+ sunk = nest_permutation(cg)
276
+ assert sunk == _array_tensor_product(_permute_dims(M, [1, 0]), _permute_dims(N, [1, 0]))
277
+
278
+ cg = _permute_dims(_array_tensor_product(M, N), [3, 2, 1, 0], nest_permutation=False)
279
+ sunk = nest_permutation(cg)
280
+ assert sunk == _array_tensor_product(_permute_dims(N, [1, 0]), _permute_dims(M, [1, 0]))
281
+
282
+ cg = _permute_dims(_array_contraction(_array_tensor_product(M, N), (1, 2)), [1, 0], nest_permutation=False)
283
+ sunk = nest_permutation(cg)
284
+ assert sunk == _array_contraction(_permute_dims(_array_tensor_product(M, N), [[0, 3]]), (1, 2))
285
+
286
+ cg = _permute_dims(_array_tensor_product(M, N), [1, 0, 3, 2], nest_permutation=False)
287
+ sunk = nest_permutation(cg)
288
+ assert sunk == _array_tensor_product(_permute_dims(M, [1, 0]), _permute_dims(N, [1, 0]))
289
+
290
+ cg = _permute_dims(_array_contraction(_array_tensor_product(M, N, P), (1, 2), (3, 4)), [1, 0], nest_permutation=False)
291
+ sunk = nest_permutation(cg)
292
+ assert sunk == _array_contraction(_permute_dims(_array_tensor_product(M, N, P), [[0, 5]]), (1, 2), (3, 4))
293
+
294
+
295
+ def test_arrayexpr_push_indices_up_and_down():
296
+
297
+ indices = list(range(12))
298
+
299
+ contr_diag_indices = [(0, 6), (2, 8)]
300
+ assert ArrayContraction._push_indices_down(contr_diag_indices, indices) == (1, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14, 15)
301
+ assert ArrayContraction._push_indices_up(contr_diag_indices, indices) == (None, 0, None, 1, 2, 3, None, 4, None, 5, 6, 7)
302
+
303
+ assert ArrayDiagonal._push_indices_down(contr_diag_indices, indices, 10) == (1, 3, 4, 5, 7, 9, (0, 6), (2, 8), None, None, None, None)
304
+ assert ArrayDiagonal._push_indices_up(contr_diag_indices, indices, 10) == (6, 0, 7, 1, 2, 3, 6, 4, 7, 5, None, None)
305
+
306
+ contr_diag_indices = [(1, 2), (7, 8)]
307
+ assert ArrayContraction._push_indices_down(contr_diag_indices, indices) == (0, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15)
308
+ assert ArrayContraction._push_indices_up(contr_diag_indices, indices) == (0, None, None, 1, 2, 3, 4, None, None, 5, 6, 7)
309
+
310
+ assert ArrayDiagonal._push_indices_down(contr_diag_indices, indices, 10) == (0, 3, 4, 5, 6, 9, (1, 2), (7, 8), None, None, None, None)
311
+ assert ArrayDiagonal._push_indices_up(contr_diag_indices, indices, 10) == (0, 6, 6, 1, 2, 3, 4, 7, 7, 5, None, None)
312
+
313
+
314
+ def test_arrayexpr_split_multiple_contractions():
315
+ a = MatrixSymbol("a", k, 1)
316
+ b = MatrixSymbol("b", k, 1)
317
+ A = MatrixSymbol("A", k, k)
318
+ B = MatrixSymbol("B", k, k)
319
+ C = MatrixSymbol("C", k, k)
320
+ X = MatrixSymbol("X", k, k)
321
+
322
+ cg = _array_contraction(_array_tensor_product(A.T, a, b, b.T, (A*X*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9))
323
+ expected = _array_contraction(_array_tensor_product(A.T, DiagMatrix(a), OneArray(1), b, b.T, (A*X*b).applyfunc(cos)), (1, 3), (2, 9), (6, 7, 10))
324
+ assert cg.split_multiple_contractions().dummy_eq(expected)
325
+
326
+ # Check no overlap of lines:
327
+
328
+ cg = _array_contraction(_array_tensor_product(A, a, C, a, B), (1, 2, 4), (5, 6, 8), (3, 7))
329
+ assert cg.split_multiple_contractions() == cg
330
+
331
+ cg = _array_contraction(_array_tensor_product(a, b, A), (0, 2, 4), (1, 3))
332
+ assert cg.split_multiple_contractions() == cg
333
+
334
+
335
+ def test_arrayexpr_nested_permutations():
336
+
337
+ cg = _permute_dims(_permute_dims(M, (1, 0)), (1, 0))
338
+ assert cg == M
339
+
340
+ times = 3
341
+ plist1 = [list(range(6)) for i in range(times)]
342
+ plist2 = [list(range(6)) for i in range(times)]
343
+
344
+ for i in range(times):
345
+ random.shuffle(plist1[i])
346
+ random.shuffle(plist2[i])
347
+
348
+ plist1.append([2, 5, 4, 1, 0, 3])
349
+ plist2.append([3, 5, 0, 4, 1, 2])
350
+
351
+ plist1.append([2, 5, 4, 0, 3, 1])
352
+ plist2.append([3, 0, 5, 1, 2, 4])
353
+
354
+ plist1.append([5, 4, 2, 0, 3, 1])
355
+ plist2.append([4, 5, 0, 2, 3, 1])
356
+
357
+ Me = M.subs(k, 3).as_explicit()
358
+ Ne = N.subs(k, 3).as_explicit()
359
+ Pe = P.subs(k, 3).as_explicit()
360
+ cge = tensorproduct(Me, Ne, Pe)
361
+
362
+ for permutation_array1, permutation_array2 in zip(plist1, plist2):
363
+ p1 = Permutation(permutation_array1)
364
+ p2 = Permutation(permutation_array2)
365
+
366
+ cg = _permute_dims(
367
+ _permute_dims(
368
+ _array_tensor_product(M, N, P),
369
+ p1),
370
+ p2
371
+ )
372
+ result = _permute_dims(
373
+ _array_tensor_product(M, N, P),
374
+ p2*p1
375
+ )
376
+ assert cg == result
377
+
378
+ # Check that `permutedims` behaves the same way with explicit-component arrays:
379
+ result1 = _permute_dims(_permute_dims(cge, p1), p2)
380
+ result2 = _permute_dims(cge, p2*p1)
381
+ assert result1 == result2
382
+
383
+
384
+ def test_arrayexpr_contraction_permutation_mix():
385
+
386
+ Me = M.subs(k, 3).as_explicit()
387
+ Ne = N.subs(k, 3).as_explicit()
388
+
389
+ cg1 = _array_contraction(PermuteDims(_array_tensor_product(M, N), Permutation([0, 2, 1, 3])), (2, 3))
390
+ cg2 = _array_contraction(_array_tensor_product(M, N), (1, 3))
391
+ assert cg1 == cg2
392
+ cge1 = tensorcontraction(permutedims(tensorproduct(Me, Ne), Permutation([0, 2, 1, 3])), (2, 3))
393
+ cge2 = tensorcontraction(tensorproduct(Me, Ne), (1, 3))
394
+ assert cge1 == cge2
395
+
396
+ cg1 = _permute_dims(_array_tensor_product(M, N), Permutation([0, 1, 3, 2]))
397
+ cg2 = _array_tensor_product(M, _permute_dims(N, Permutation([1, 0])))
398
+ assert cg1 == cg2
399
+
400
+ cg1 = _array_contraction(
401
+ _permute_dims(
402
+ _array_tensor_product(M, N, P, Q), Permutation([0, 2, 3, 1, 4, 5, 7, 6])),
403
+ (1, 2), (3, 5)
404
+ )
405
+ cg2 = _array_contraction(
406
+ _array_tensor_product(M, N, P, _permute_dims(Q, Permutation([1, 0]))),
407
+ (1, 5), (2, 3)
408
+ )
409
+ assert cg1 == cg2
410
+
411
+ cg1 = _array_contraction(
412
+ _permute_dims(
413
+ _array_tensor_product(M, N, P, Q), Permutation([1, 0, 4, 6, 2, 7, 5, 3])),
414
+ (0, 1), (2, 6), (3, 7)
415
+ )
416
+ cg2 = _permute_dims(
417
+ _array_contraction(
418
+ _array_tensor_product(M, P, Q, N),
419
+ (0, 1), (2, 3), (4, 7)),
420
+ [1, 0]
421
+ )
422
+ assert cg1 == cg2
423
+
424
+ cg1 = _array_contraction(
425
+ _permute_dims(
426
+ _array_tensor_product(M, N, P, Q), Permutation([1, 0, 4, 6, 7, 2, 5, 3])),
427
+ (0, 1), (2, 6), (3, 7)
428
+ )
429
+ cg2 = _permute_dims(
430
+ _array_contraction(
431
+ _array_tensor_product(_permute_dims(M, [1, 0]), N, P, Q),
432
+ (0, 1), (3, 6), (4, 5)
433
+ ),
434
+ Permutation([1, 0])
435
+ )
436
+ assert cg1 == cg2
437
+
438
+
439
+ def test_arrayexpr_permute_tensor_product():
440
+ cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 1, 0, 5, 4, 6, 7]))
441
+ cg2 = _array_tensor_product(N, _permute_dims(M, [1, 0]),
442
+ _permute_dims(P, [1, 0]), Q)
443
+ assert cg1 == cg2
444
+
445
+ # TODO: reverse operation starting with `PermuteDims` and getting down to `bb`...
446
+ cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 4, 5, 0, 1, 6, 7]))
447
+ cg2 = _array_tensor_product(N, P, M, Q)
448
+ assert cg1 == cg2
449
+
450
+ cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 4, 6, 5, 7, 0, 1]))
451
+ assert cg1.expr == _array_tensor_product(N, P, Q, M)
452
+ assert cg1.permutation == Permutation([0, 1, 2, 4, 3, 5, 6, 7])
453
+
454
+ cg1 = _array_contraction(
455
+ _permute_dims(
456
+ _array_tensor_product(N, Q, Q, M),
457
+ [2, 1, 5, 4, 0, 3, 6, 7]),
458
+ [1, 2, 6])
459
+ cg2 = _permute_dims(_array_contraction(_array_tensor_product(Q, Q, N, M), (3, 5, 6)), [0, 2, 3, 1, 4])
460
+ assert cg1 == cg2
461
+
462
+ cg1 = _array_contraction(
463
+ _array_contraction(
464
+ _array_contraction(
465
+ _array_contraction(
466
+ _permute_dims(
467
+ _array_tensor_product(N, Q, Q, M),
468
+ [2, 1, 5, 4, 0, 3, 6, 7]),
469
+ [1, 2, 6]),
470
+ [1, 3, 4]),
471
+ [1]),
472
+ [0])
473
+ cg2 = _array_contraction(_array_tensor_product(M, N, Q, Q), (0, 3, 5), (1, 4, 7), (2,), (6,))
474
+ assert cg1 == cg2
475
+
476
+
477
+ def test_arrayexpr_canonicalize_diagonal__permute_dims():
478
+ tp = _array_tensor_product(M, Q, N, P)
479
+ expr = _array_diagonal(
480
+ _permute_dims(tp, [0, 1, 2, 4, 7, 6, 3, 5]), (2, 4, 5), (6, 7),
481
+ (0, 3))
482
+ result = _array_diagonal(tp, (2, 6, 7), (3, 5), (0, 4))
483
+ assert expr == result
484
+
485
+ tp = _array_tensor_product(M, N, P, Q)
486
+ expr = _array_diagonal(_permute_dims(tp, [0, 5, 2, 4, 1, 6, 3, 7]), (1, 2, 6), (3, 4))
487
+ result = _array_diagonal(_array_tensor_product(M, P, N, Q), (3, 4, 5), (1, 2))
488
+ assert expr == result
489
+
490
+
491
+ def test_arrayexpr_canonicalize_diagonal_contraction():
492
+ tp = _array_tensor_product(M, N, P, Q)
493
+ expr = _array_contraction(_array_diagonal(tp, (1, 3, 4)), (0, 3))
494
+ result = _array_diagonal(_array_contraction(_array_tensor_product(M, N, P, Q), (0, 6)), (0, 2, 3))
495
+ assert expr == result
496
+
497
+ expr = _array_contraction(_array_diagonal(tp, (0, 1, 2, 3, 7)), (1, 2, 3))
498
+ result = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 1, 2, 3, 5, 6, 7))
499
+ assert expr == result
500
+
501
+ expr = _array_contraction(_array_diagonal(tp, (0, 2, 6, 7)), (1, 2, 3))
502
+ result = _array_diagonal(_array_contraction(tp, (3, 4, 5)), (0, 2, 3, 4))
503
+ assert expr == result
504
+
505
+ td = _array_diagonal(_array_tensor_product(M, N, P, Q), (0, 3))
506
+ expr = _array_contraction(td, (2, 1), (0, 4, 6, 5, 3))
507
+ result = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 1, 3, 5, 6, 7), (2, 4))
508
+ assert expr == result
509
+
510
+
511
+ def test_arrayexpr_array_wrong_permutation_size():
512
+ cg = _array_tensor_product(M, N)
513
+ raises(ValueError, lambda: _permute_dims(cg, [1, 0]))
514
+ raises(ValueError, lambda: _permute_dims(cg, [1, 0, 2, 3, 5, 4]))
515
+
516
+
517
+ def test_arrayexpr_nested_array_elementwise_add():
518
+ cg = _array_contraction(_array_add(
519
+ _array_tensor_product(M, N),
520
+ _array_tensor_product(N, M)
521
+ ), (1, 2))
522
+ result = _array_add(
523
+ _array_contraction(_array_tensor_product(M, N), (1, 2)),
524
+ _array_contraction(_array_tensor_product(N, M), (1, 2))
525
+ )
526
+ assert cg == result
527
+
528
+ cg = _array_diagonal(_array_add(
529
+ _array_tensor_product(M, N),
530
+ _array_tensor_product(N, M)
531
+ ), (1, 2))
532
+ result = _array_add(
533
+ _array_diagonal(_array_tensor_product(M, N), (1, 2)),
534
+ _array_diagonal(_array_tensor_product(N, M), (1, 2))
535
+ )
536
+ assert cg == result
537
+
538
+
539
+ def test_arrayexpr_array_expr_zero_array():
540
+ za1 = ZeroArray(k, l, m, n)
541
+ zm1 = ZeroMatrix(m, n)
542
+
543
+ za2 = ZeroArray(k, m, m, n)
544
+ zm2 = ZeroMatrix(m, m)
545
+ zm3 = ZeroMatrix(k, k)
546
+
547
+ assert _array_tensor_product(M, N, za1) == ZeroArray(k, k, k, k, k, l, m, n)
548
+ assert _array_tensor_product(M, N, zm1) == ZeroArray(k, k, k, k, m, n)
549
+
550
+ assert _array_contraction(za1, (3,)) == ZeroArray(k, l, m)
551
+ assert _array_contraction(zm1, (1,)) == ZeroArray(m)
552
+ assert _array_contraction(za2, (1, 2)) == ZeroArray(k, n)
553
+ assert _array_contraction(zm2, (0, 1)) == 0
554
+
555
+ assert _array_diagonal(za2, (1, 2)) == ZeroArray(k, n, m)
556
+ assert _array_diagonal(zm2, (0, 1)) == ZeroArray(m)
557
+
558
+ assert _permute_dims(za1, [2, 1, 3, 0]) == ZeroArray(m, l, n, k)
559
+ assert _permute_dims(zm1, [1, 0]) == ZeroArray(n, m)
560
+
561
+ assert _array_add(za1) == za1
562
+ assert _array_add(zm1) == ZeroArray(m, n)
563
+ tp1 = _array_tensor_product(MatrixSymbol("A", k, l), MatrixSymbol("B", m, n))
564
+ assert _array_add(tp1, za1) == tp1
565
+ tp2 = _array_tensor_product(MatrixSymbol("C", k, l), MatrixSymbol("D", m, n))
566
+ assert _array_add(tp1, za1, tp2) == _array_add(tp1, tp2)
567
+ assert _array_add(M, zm3) == M
568
+ assert _array_add(M, N, zm3) == _array_add(M, N)
569
+
570
+
571
+ def test_arrayexpr_array_expr_applyfunc():
572
+
573
+ A = ArraySymbol("A", (3, k, 2))
574
+ aaf = ArrayElementwiseApplyFunc(sin, A)
575
+ assert aaf.shape == (3, k, 2)
576
+
577
+
578
+ def test_edit_array_contraction():
579
+ cg = _array_contraction(_array_tensor_product(A, B, C, D), (1, 2, 5))
580
+ ecg = _EditArrayContraction(cg)
581
+ assert ecg.to_array_contraction() == cg
582
+
583
+ ecg.args_with_ind[1], ecg.args_with_ind[2] = ecg.args_with_ind[2], ecg.args_with_ind[1]
584
+ assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, C, B, D), (1, 3, 4))
585
+
586
+ ci = ecg.get_new_contraction_index()
587
+ new_arg = _ArgE(X)
588
+ new_arg.indices = [ci, ci]
589
+ ecg.args_with_ind.insert(2, new_arg)
590
+ assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, C, X, B, D), (1, 3, 6), (4, 5))
591
+
592
+ assert ecg.get_contraction_indices() == [[1, 3, 6], [4, 5]]
593
+ assert [[tuple(j) for j in i] for i in ecg.get_contraction_indices_to_ind_rel_pos()] == [[(0, 1), (1, 1), (3, 0)], [(2, 0), (2, 1)]]
594
+ assert [list(i) for i in ecg.get_mapping_for_index(0)] == [[0, 1], [1, 1], [3, 0]]
595
+ assert [list(i) for i in ecg.get_mapping_for_index(1)] == [[2, 0], [2, 1]]
596
+ raises(ValueError, lambda: ecg.get_mapping_for_index(2))
597
+
598
+ ecg.args_with_ind.pop(1)
599
+ assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, B, D), (1, 4), (2, 3))
600
+
601
+ ecg.args_with_ind[0].indices[1] = ecg.args_with_ind[1].indices[0]
602
+ ecg.args_with_ind[1].indices[1] = ecg.args_with_ind[2].indices[0]
603
+ assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, B, D), (1, 2), (3, 4))
604
+
605
+ ecg.insert_after(ecg.args_with_ind[1], _ArgE(C))
606
+ assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, C, B, D), (1, 2), (3, 6))
607
+
608
+
609
+ def test_array_expressions_no_canonicalization():
610
+
611
+ tp = _array_tensor_product(M, N, P)
612
+
613
+ # ArrayTensorProduct:
614
+
615
+ expr = ArrayTensorProduct(tp, N)
616
+ assert str(expr) == "ArrayTensorProduct(ArrayTensorProduct(M, N, P), N)"
617
+ assert expr.doit() == ArrayTensorProduct(M, N, P, N)
618
+
619
+ expr = ArrayTensorProduct(ArrayContraction(M, (0, 1)), N)
620
+ assert str(expr) == "ArrayTensorProduct(ArrayContraction(M, (0, 1)), N)"
621
+ assert expr.doit() == ArrayContraction(ArrayTensorProduct(M, N), (0, 1))
622
+
623
+ expr = ArrayTensorProduct(ArrayDiagonal(M, (0, 1)), N)
624
+ assert str(expr) == "ArrayTensorProduct(ArrayDiagonal(M, (0, 1)), N)"
625
+ assert expr.doit() == PermuteDims(ArrayDiagonal(ArrayTensorProduct(M, N), (0, 1)), [2, 0, 1])
626
+
627
+ expr = ArrayTensorProduct(PermuteDims(M, [1, 0]), N)
628
+ assert str(expr) == "ArrayTensorProduct(PermuteDims(M, (0 1)), N)"
629
+ assert expr.doit() == PermuteDims(ArrayTensorProduct(M, N), [1, 0, 2, 3])
630
+
631
+ # ArrayContraction:
632
+
633
+ expr = ArrayContraction(_array_contraction(tp, (0, 2)), (0, 1))
634
+ assert isinstance(expr, ArrayContraction)
635
+ assert isinstance(expr.expr, ArrayContraction)
636
+ assert str(expr) == "ArrayContraction(ArrayContraction(ArrayTensorProduct(M, N, P), (0, 2)), (0, 1))"
637
+ assert expr.doit() == ArrayContraction(tp, (0, 2), (1, 3))
638
+
639
+ expr = ArrayContraction(ArrayContraction(ArrayContraction(tp, (0, 1)), (0, 1)), (0, 1))
640
+ assert expr.doit() == ArrayContraction(tp, (0, 1), (2, 3), (4, 5))
641
+ # assert expr._canonicalize() == ArrayContraction(ArrayContraction(tp, (0, 1)), (0, 1), (2, 3))
642
+
643
+ expr = ArrayContraction(ArrayDiagonal(tp, (0, 1)), (0, 1))
644
+ assert str(expr) == "ArrayContraction(ArrayDiagonal(ArrayTensorProduct(M, N, P), (0, 1)), (0, 1))"
645
+ assert expr.doit() == ArrayDiagonal(ArrayContraction(ArrayTensorProduct(N, M, P), (0, 1)), (0, 1))
646
+
647
+ expr = ArrayContraction(PermuteDims(M, [1, 0]), (0, 1))
648
+ assert str(expr) == "ArrayContraction(PermuteDims(M, (0 1)), (0, 1))"
649
+ assert expr.doit() == ArrayContraction(M, (0, 1))
650
+
651
+ # ArrayDiagonal:
652
+
653
+ expr = ArrayDiagonal(ArrayDiagonal(tp, (0, 2)), (0, 1))
654
+ assert str(expr) == "ArrayDiagonal(ArrayDiagonal(ArrayTensorProduct(M, N, P), (0, 2)), (0, 1))"
655
+ assert expr.doit() == ArrayDiagonal(tp, (0, 2), (1, 3))
656
+
657
+ expr = ArrayDiagonal(ArrayDiagonal(ArrayDiagonal(tp, (0, 1)), (0, 1)), (0, 1))
658
+ assert expr.doit() == ArrayDiagonal(tp, (0, 1), (2, 3), (4, 5))
659
+ assert expr._canonicalize() == expr.doit()
660
+
661
+ expr = ArrayDiagonal(ArrayContraction(tp, (0, 1)), (0, 1))
662
+ assert str(expr) == "ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, P), (0, 1)), (0, 1))"
663
+ assert expr.doit() == expr
664
+
665
+ expr = ArrayDiagonal(PermuteDims(M, [1, 0]), (0, 1))
666
+ assert str(expr) == "ArrayDiagonal(PermuteDims(M, (0 1)), (0, 1))"
667
+ assert expr.doit() == ArrayDiagonal(M, (0, 1))
668
+
669
+ # ArrayAdd:
670
+
671
+ expr = ArrayAdd(M)
672
+ assert isinstance(expr, ArrayAdd)
673
+ assert expr.doit() == M
674
+
675
+ expr = ArrayAdd(ArrayAdd(M, N), P)
676
+ assert str(expr) == "ArrayAdd(ArrayAdd(M, N), P)"
677
+ assert expr.doit() == ArrayAdd(M, N, P)
678
+
679
+ expr = ArrayAdd(M, ArrayAdd(N, ArrayAdd(P, M)))
680
+ assert expr.doit() == ArrayAdd(M, N, P, M)
681
+ assert expr._canonicalize() == ArrayAdd(M, N, ArrayAdd(P, M))
682
+
683
+ expr = ArrayAdd(M, ZeroArray(k, k), N)
684
+ assert str(expr) == "ArrayAdd(M, ZeroArray(k, k), N)"
685
+ assert expr.doit() == ArrayAdd(M, N)
686
+
687
+ # PermuteDims:
688
+
689
+ expr = PermuteDims(PermuteDims(M, [1, 0]), [1, 0])
690
+ assert str(expr) == "PermuteDims(PermuteDims(M, (0 1)), (0 1))"
691
+ assert expr.doit() == M
692
+
693
+ expr = PermuteDims(PermuteDims(PermuteDims(M, [1, 0]), [1, 0]), [1, 0])
694
+ assert expr.doit() == PermuteDims(M, [1, 0])
695
+ assert expr._canonicalize() == expr.doit()
696
+
697
+ # Reshape
698
+
699
+ expr = Reshape(A, (k**2,))
700
+ assert expr.shape == (k**2,)
701
+ assert isinstance(expr, Reshape)
702
+
703
+
704
+ def test_array_expr_construction_with_functions():
705
+
706
+ tp = tensorproduct(M, N)
707
+ assert tp == ArrayTensorProduct(M, N)
708
+
709
+ expr = tensorproduct(A, eye(2))
710
+ assert expr == ArrayTensorProduct(A, eye(2))
711
+
712
+ # Contraction:
713
+
714
+ expr = tensorcontraction(M, (0, 1))
715
+ assert expr == ArrayContraction(M, (0, 1))
716
+
717
+ expr = tensorcontraction(tp, (1, 2))
718
+ assert expr == ArrayContraction(tp, (1, 2))
719
+
720
+ expr = tensorcontraction(tensorcontraction(tp, (1, 2)), (0, 1))
721
+ assert expr == ArrayContraction(tp, (0, 3), (1, 2))
722
+
723
+ # Diagonalization:
724
+
725
+ expr = tensordiagonal(M, (0, 1))
726
+ assert expr == ArrayDiagonal(M, (0, 1))
727
+
728
+ expr = tensordiagonal(tensordiagonal(tp, (0, 1)), (0, 1))
729
+ assert expr == ArrayDiagonal(tp, (0, 1), (2, 3))
730
+
731
+ # Permutation of dimensions:
732
+
733
+ expr = permutedims(M, [1, 0])
734
+ assert expr == PermuteDims(M, [1, 0])
735
+
736
+ expr = permutedims(PermuteDims(tp, [1, 0, 2, 3]), [0, 1, 3, 2])
737
+ assert expr == PermuteDims(tp, [1, 0, 3, 2])
738
+
739
+ expr = PermuteDims(tp, index_order_new=["a", "b", "c", "d"], index_order_old=["d", "c", "b", "a"])
740
+ assert expr == PermuteDims(tp, [3, 2, 1, 0])
741
+
742
+ arr = Array(range(32)).reshape(2, 2, 2, 2, 2)
743
+ expr = PermuteDims(arr, index_order_new=["a", "b", "c", "d", "e"], index_order_old=['b', 'e', 'a', 'd', 'c'])
744
+ assert expr == PermuteDims(arr, [2, 0, 4, 3, 1])
745
+ assert expr.as_explicit() == permutedims(arr, index_order_new=["a", "b", "c", "d", "e"], index_order_old=['b', 'e', 'a', 'd', 'c'])
746
+
747
+
748
+ def test_array_element_expressions():
749
+ # Check commutative property:
750
+ assert M[0, 0]*N[0, 0] == N[0, 0]*M[0, 0]
751
+
752
+ # Check derivatives:
753
+ assert M[0, 0].diff(M[0, 0]) == 1
754
+ assert M[0, 0].diff(M[1, 0]) == 0
755
+ assert M[0, 0].diff(N[0, 0]) == 0
756
+ assert M[0, 1].diff(M[i, j]) == KroneckerDelta(i, 0)*KroneckerDelta(j, 1)
757
+ assert M[0, 1].diff(N[i, j]) == 0
758
+
759
+ K4 = ArraySymbol("K4", shape=(k, k, k, k))
760
+
761
+ assert K4[i, j, k, l].diff(K4[1, 2, 3, 4]) == (
762
+ KroneckerDelta(i, 1)*KroneckerDelta(j, 2)*KroneckerDelta(k, 3)*KroneckerDelta(l, 4)
763
+ )
764
+
765
+
766
+ def test_array_expr_reshape():
767
+
768
+ A = MatrixSymbol("A", 2, 2)
769
+ B = ArraySymbol("B", (2, 2, 2))
770
+ C = Array([1, 2, 3, 4])
771
+
772
+ expr = Reshape(A, (4,))
773
+ assert expr.expr == A
774
+ assert expr.shape == (4,)
775
+ assert expr.as_explicit() == Array([A[0, 0], A[0, 1], A[1, 0], A[1, 1]])
776
+
777
+ expr = Reshape(B, (2, 4))
778
+ assert expr.expr == B
779
+ assert expr.shape == (2, 4)
780
+ ee = expr.as_explicit()
781
+ assert isinstance(ee, ImmutableDenseNDimArray)
782
+ assert ee.shape == (2, 4)
783
+ assert ee == Array([[B[0, 0, 0], B[0, 0, 1], B[0, 1, 0], B[0, 1, 1]], [B[1, 0, 0], B[1, 0, 1], B[1, 1, 0], B[1, 1, 1]]])
784
+
785
+ expr = Reshape(A, (k, 2))
786
+ assert expr.shape == (k, 2)
787
+
788
+ raises(ValueError, lambda: Reshape(A, (2, 3)))
789
+ raises(ValueError, lambda: Reshape(A, (3,)))
790
+
791
+ expr = Reshape(C, (2, 2))
792
+ assert expr.expr == C
793
+ assert expr.shape == (2, 2)
794
+ assert expr.doit() == Array([[1, 2], [3, 4]])
795
+
796
+
797
+ def test_array_expr_as_explicit_with_explicit_component_arrays():
798
+ # Test if .as_explicit() works with explicit-component arrays
799
+ # nested in array expressions:
800
+ from sympy.abc import x, y, z, t
801
+ A = Array([[x, y], [z, t]])
802
+ assert ArrayTensorProduct(A, A).as_explicit() == tensorproduct(A, A)
803
+ assert ArrayDiagonal(A, (0, 1)).as_explicit() == tensordiagonal(A, (0, 1))
804
+ assert ArrayContraction(A, (0, 1)).as_explicit() == tensorcontraction(A, (0, 1))
805
+ assert ArrayAdd(A, A).as_explicit() == A + A
806
+ assert ArrayElementwiseApplyFunc(sin, A).as_explicit() == A.applyfunc(sin)
807
+ assert PermuteDims(A, [1, 0]).as_explicit() == permutedims(A, [1, 0])
808
+ assert Reshape(A, [4]).as_explicit() == A.reshape(4)