MTerryJack commited on
Commit
b80a530
·
verified ·
1 Parent(s): f95072e

chore: remove stray .venv files (1400-1482)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.13/site-packages/sympy/unify/tests/__init__.py +0 -0
  2. .venv/lib/python3.13/site-packages/sympy/unify/tests/test_rewrite.py +0 -74
  3. .venv/lib/python3.13/site-packages/sympy/unify/tests/test_sympy.py +0 -162
  4. .venv/lib/python3.13/site-packages/sympy/unify/tests/test_unify.py +0 -88
  5. .venv/lib/python3.13/site-packages/sympy/utilities/__init__.py +0 -30
  6. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/__init__.py +0 -22
  7. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/availability.py +0 -77
  8. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/compilation.py +0 -657
  9. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/runners.py +0 -301
  10. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/__init__.py +0 -0
  11. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/test_compilation.py +0 -104
  12. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/util.py +0 -312
  13. .venv/lib/python3.13/site-packages/sympy/utilities/autowrap.py +0 -1178
  14. .venv/lib/python3.13/site-packages/sympy/utilities/codegen.py +0 -2237
  15. .venv/lib/python3.13/site-packages/sympy/utilities/decorator.py +0 -339
  16. .venv/lib/python3.13/site-packages/sympy/utilities/enumerative.py +0 -1155
  17. .venv/lib/python3.13/site-packages/sympy/utilities/exceptions.py +0 -271
  18. .venv/lib/python3.13/site-packages/sympy/utilities/iterables.py +0 -3179
  19. .venv/lib/python3.13/site-packages/sympy/utilities/lambdify.py +0 -1592
  20. .venv/lib/python3.13/site-packages/sympy/utilities/magic.py +0 -12
  21. .venv/lib/python3.13/site-packages/sympy/utilities/matchpy_connector.py +0 -340
  22. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/__init__.py +0 -122
  23. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/__init__.py +0 -0
  24. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmlctop.xsl +0 -0
  25. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmltex.xsl +0 -0
  26. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/simple_mmlctop.xsl +0 -0
  27. .venv/lib/python3.13/site-packages/sympy/utilities/memoization.py +0 -76
  28. .venv/lib/python3.13/site-packages/sympy/utilities/misc.py +0 -564
  29. .venv/lib/python3.13/site-packages/sympy/utilities/pkgdata.py +0 -33
  30. .venv/lib/python3.13/site-packages/sympy/utilities/pytest.py +0 -12
  31. .venv/lib/python3.13/site-packages/sympy/utilities/randtest.py +0 -12
  32. .venv/lib/python3.13/site-packages/sympy/utilities/runtests.py +0 -13
  33. .venv/lib/python3.13/site-packages/sympy/utilities/source.py +0 -40
  34. .venv/lib/python3.13/site-packages/sympy/utilities/tests/__init__.py +0 -0
  35. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_autowrap.py +0 -467
  36. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen.py +0 -1632
  37. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_julia.py +0 -620
  38. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_octave.py +0 -589
  39. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_rust.py +0 -401
  40. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_decorator.py +0 -129
  41. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_deprecated.py +0 -13
  42. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_enumerative.py +0 -179
  43. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_exceptions.py +0 -12
  44. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_iterables.py +0 -945
  45. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_lambdify.py +0 -2263
  46. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_matchpy_connector.py +0 -164
  47. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_mathml.py +0 -33
  48. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_misc.py +0 -148
  49. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_pickling.py +0 -723
  50. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_source.py +0 -11
.venv/lib/python3.13/site-packages/sympy/unify/tests/__init__.py DELETED
File without changes
.venv/lib/python3.13/site-packages/sympy/unify/tests/test_rewrite.py DELETED
@@ -1,74 +0,0 @@
1
- from sympy.unify.rewrite import rewriterule
2
- from sympy.core.basic import Basic
3
- from sympy.core.singleton import S
4
- from sympy.core.symbol import Symbol
5
- from sympy.functions.elementary.trigonometric import sin
6
- from sympy.abc import x, y
7
- from sympy.strategies.rl import rebuild
8
- from sympy.assumptions import Q
9
-
10
- p, q = Symbol('p'), Symbol('q')
11
-
12
- def test_simple():
13
- rl = rewriterule(Basic(p, S(1)), Basic(p, S(2)), variables=(p,))
14
- assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))]
15
-
16
- p1 = p**2
17
- p2 = p**3
18
- rl = rewriterule(p1, p2, variables=(p,))
19
-
20
- expr = x**2
21
- assert list(rl(expr)) == [x**3]
22
-
23
- def test_simple_variables():
24
- rl = rewriterule(Basic(x, S(1)), Basic(x, S(2)), variables=(x,))
25
- assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))]
26
-
27
- rl = rewriterule(x**2, x**3, variables=(x,))
28
- assert list(rl(y**2)) == [y**3]
29
-
30
- def test_moderate():
31
- p1 = p**2 + q**3
32
- p2 = (p*q)**4
33
- rl = rewriterule(p1, p2, (p, q))
34
-
35
- expr = x**2 + y**3
36
- assert list(rl(expr)) == [(x*y)**4]
37
-
38
- def test_sincos():
39
- p1 = sin(p)**2 + sin(p)**2
40
- p2 = 1
41
- rl = rewriterule(p1, p2, (p, q))
42
-
43
- assert list(rl(sin(x)**2 + sin(x)**2)) == [1]
44
- assert list(rl(sin(y)**2 + sin(y)**2)) == [1]
45
-
46
- def test_Exprs_ok():
47
- rl = rewriterule(p+q, q+p, (p, q))
48
- next(rl(x+y)).is_commutative
49
- str(next(rl(x+y)))
50
-
51
- def test_condition_simple():
52
- rl = rewriterule(x, x+1, [x], lambda x: x < 10)
53
- assert not list(rl(S(15)))
54
- assert rebuild(next(rl(S(5)))) == 6
55
-
56
-
57
- def test_condition_multiple():
58
- rl = rewriterule(x + y, x**y, [x,y], lambda x, y: x.is_integer)
59
-
60
- a = Symbol('a')
61
- b = Symbol('b', integer=True)
62
- expr = a + b
63
- assert list(rl(expr)) == [b**a]
64
-
65
- c = Symbol('c', integer=True)
66
- d = Symbol('d', integer=True)
67
- assert set(rl(c + d)) == {c**d, d**c}
68
-
69
- def test_assumptions():
70
- rl = rewriterule(x + y, x**y, [x, y], assume=Q.integer(x))
71
-
72
- a, b = map(Symbol, 'ab')
73
- expr = a + b
74
- assert list(rl(expr, Q.integer(b))) == [b**a]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/unify/tests/test_sympy.py DELETED
@@ -1,162 +0,0 @@
1
- from sympy.core.add import Add
2
- from sympy.core.basic import Basic
3
- from sympy.core.containers import Tuple
4
- from sympy.core.singleton import S
5
- from sympy.core.symbol import (Symbol, symbols)
6
- from sympy.logic.boolalg import And
7
- from sympy.core.symbol import Str
8
- from sympy.unify.core import Compound, Variable
9
- from sympy.unify.usympy import (deconstruct, construct, unify, is_associative,
10
- is_commutative)
11
- from sympy.abc import x, y, z, n
12
-
13
- def test_deconstruct():
14
- expr = Basic(S(1), S(2), S(3))
15
- expected = Compound(Basic, (1, 2, 3))
16
- assert deconstruct(expr) == expected
17
-
18
- assert deconstruct(1) == 1
19
- assert deconstruct(x) == x
20
- assert deconstruct(x, variables=(x,)) == Variable(x)
21
- assert deconstruct(Add(1, x, evaluate=False)) == Compound(Add, (1, x))
22
- assert deconstruct(Add(1, x, evaluate=False), variables=(x,)) == \
23
- Compound(Add, (1, Variable(x)))
24
-
25
- def test_construct():
26
- expr = Compound(Basic, (S(1), S(2), S(3)))
27
- expected = Basic(S(1), S(2), S(3))
28
- assert construct(expr) == expected
29
-
30
- def test_nested():
31
- expr = Basic(S(1), Basic(S(2)), S(3))
32
- cmpd = Compound(Basic, (S(1), Compound(Basic, Tuple(2)), S(3)))
33
- assert deconstruct(expr) == cmpd
34
- assert construct(cmpd) == expr
35
-
36
- def test_unify():
37
- expr = Basic(S(1), S(2), S(3))
38
- a, b, c = map(Symbol, 'abc')
39
- pattern = Basic(a, b, c)
40
- assert list(unify(expr, pattern, {}, (a, b, c))) == [{a: 1, b: 2, c: 3}]
41
- assert list(unify(expr, pattern, variables=(a, b, c))) == \
42
- [{a: 1, b: 2, c: 3}]
43
-
44
- def test_unify_variables():
45
- assert list(unify(Basic(S(1), S(2)), Basic(S(1), x), {}, variables=(x,))) == [{x: 2}]
46
-
47
- def test_s_input():
48
- expr = Basic(S(1), S(2))
49
- a, b = map(Symbol, 'ab')
50
- pattern = Basic(a, b)
51
- assert list(unify(expr, pattern, {}, (a, b))) == [{a: 1, b: 2}]
52
- assert list(unify(expr, pattern, {a: 5}, (a, b))) == []
53
-
54
- def iterdicteq(a, b):
55
- a = tuple(a)
56
- b = tuple(b)
57
- return len(a) == len(b) and all(x in b for x in a)
58
-
59
- def test_unify_commutative():
60
- expr = Add(1, 2, 3, evaluate=False)
61
- a, b, c = map(Symbol, 'abc')
62
- pattern = Add(a, b, c, evaluate=False)
63
-
64
- result = tuple(unify(expr, pattern, {}, (a, b, c)))
65
- expected = ({a: 1, b: 2, c: 3},
66
- {a: 1, b: 3, c: 2},
67
- {a: 2, b: 1, c: 3},
68
- {a: 2, b: 3, c: 1},
69
- {a: 3, b: 1, c: 2},
70
- {a: 3, b: 2, c: 1})
71
-
72
- assert iterdicteq(result, expected)
73
-
74
- def test_unify_iter():
75
- expr = Add(1, 2, 3, evaluate=False)
76
- a, b, c = map(Symbol, 'abc')
77
- pattern = Add(a, c, evaluate=False)
78
- assert is_associative(deconstruct(pattern))
79
- assert is_commutative(deconstruct(pattern))
80
-
81
- result = list(unify(expr, pattern, {}, (a, c)))
82
- expected = [{a: 1, c: Add(2, 3, evaluate=False)},
83
- {a: 1, c: Add(3, 2, evaluate=False)},
84
- {a: 2, c: Add(1, 3, evaluate=False)},
85
- {a: 2, c: Add(3, 1, evaluate=False)},
86
- {a: 3, c: Add(1, 2, evaluate=False)},
87
- {a: 3, c: Add(2, 1, evaluate=False)},
88
- {a: Add(1, 2, evaluate=False), c: 3},
89
- {a: Add(2, 1, evaluate=False), c: 3},
90
- {a: Add(1, 3, evaluate=False), c: 2},
91
- {a: Add(3, 1, evaluate=False), c: 2},
92
- {a: Add(2, 3, evaluate=False), c: 1},
93
- {a: Add(3, 2, evaluate=False), c: 1}]
94
-
95
- assert iterdicteq(result, expected)
96
-
97
- def test_hard_match():
98
- from sympy.functions.elementary.trigonometric import (cos, sin)
99
- expr = sin(x) + cos(x)**2
100
- p, q = map(Symbol, 'pq')
101
- pattern = sin(p) + cos(p)**2
102
- assert list(unify(expr, pattern, {}, (p, q))) == [{p: x}]
103
-
104
- def test_matrix():
105
- from sympy.matrices.expressions.matexpr import MatrixSymbol
106
- X = MatrixSymbol('X', n, n)
107
- Y = MatrixSymbol('Y', 2, 2)
108
- Z = MatrixSymbol('Z', 2, 3)
109
- assert list(unify(X, Y, {}, variables=[n, Str('X')])) == [{Str('X'): Str('Y'), n: 2}]
110
- assert list(unify(X, Z, {}, variables=[n, Str('X')])) == []
111
-
112
- def test_non_frankenAdds():
113
- # the is_commutative property used to fail because of Basic.__new__
114
- # This caused is_commutative and str calls to fail
115
- expr = x+y*2
116
- rebuilt = construct(deconstruct(expr))
117
- # Ensure that we can run these commands without causing an error
118
- str(rebuilt)
119
- rebuilt.is_commutative
120
-
121
- def test_FiniteSet_commutivity():
122
- from sympy.sets.sets import FiniteSet
123
- a, b, c, x, y = symbols('a,b,c,x,y')
124
- s = FiniteSet(a, b, c)
125
- t = FiniteSet(x, y)
126
- variables = (x, y)
127
- assert {x: FiniteSet(a, c), y: b} in tuple(unify(s, t, variables=variables))
128
-
129
- def test_FiniteSet_complex():
130
- from sympy.sets.sets import FiniteSet
131
- a, b, c, x, y, z = symbols('a,b,c,x,y,z')
132
- expr = FiniteSet(Basic(S(1), x), y, Basic(x, z))
133
- pattern = FiniteSet(a, Basic(x, b))
134
- variables = a, b
135
- expected = ({b: 1, a: FiniteSet(y, Basic(x, z))},
136
- {b: z, a: FiniteSet(y, Basic(S(1), x))})
137
- assert iterdicteq(unify(expr, pattern, variables=variables), expected)
138
-
139
-
140
- def test_and():
141
- variables = x, y
142
- expected = ({x: z > 0, y: n < 3},)
143
- assert iterdicteq(unify((z>0) & (n<3), And(x, y), variables=variables),
144
- expected)
145
-
146
- def test_Union():
147
- from sympy.sets.sets import Interval
148
- assert list(unify(Interval(0, 1) + Interval(10, 11),
149
- Interval(0, 1) + Interval(12, 13),
150
- variables=(Interval(12, 13),)))
151
-
152
- def test_is_commutative():
153
- assert is_commutative(deconstruct(x+y))
154
- assert is_commutative(deconstruct(x*y))
155
- assert not is_commutative(deconstruct(x**y))
156
-
157
- def test_commutative_in_commutative():
158
- from sympy.abc import a,b,c,d
159
- from sympy.functions.elementary.trigonometric import (cos, sin)
160
- eq = sin(3)*sin(4)*sin(5) + 4*cos(3)*cos(4)
161
- pat = a*cos(b)*cos(c) + d*sin(b)*sin(c)
162
- assert next(unify(eq, pat, variables=(a,b,c,d)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/unify/tests/test_unify.py DELETED
@@ -1,88 +0,0 @@
1
- from sympy.unify.core import Compound, Variable, CondVariable, allcombinations
2
- from sympy.unify import core
3
-
4
- a,b,c = 'a', 'b', 'c'
5
- w,x,y,z = map(Variable, 'wxyz')
6
-
7
- C = Compound
8
-
9
- def is_associative(x):
10
- return isinstance(x, Compound) and (x.op in ('Add', 'Mul', 'CAdd', 'CMul'))
11
- def is_commutative(x):
12
- return isinstance(x, Compound) and (x.op in ('CAdd', 'CMul'))
13
-
14
-
15
- def unify(a, b, s={}):
16
- return core.unify(a, b, s=s, is_associative=is_associative,
17
- is_commutative=is_commutative)
18
-
19
- def test_basic():
20
- assert list(unify(a, x, {})) == [{x: a}]
21
- assert list(unify(a, x, {x: 10})) == []
22
- assert list(unify(1, x, {})) == [{x: 1}]
23
- assert list(unify(a, a, {})) == [{}]
24
- assert list(unify((w, x), (y, z), {})) == [{w: y, x: z}]
25
- assert list(unify(x, (a, b), {})) == [{x: (a, b)}]
26
-
27
- assert list(unify((a, b), (x, x), {})) == []
28
- assert list(unify((y, z), (x, x), {}))!= []
29
- assert list(unify((a, (b, c)), (a, (x, y)), {})) == [{x: b, y: c}]
30
-
31
- def test_ops():
32
- assert list(unify(C('Add', (a,b,c)), C('Add', (a,x,y)), {})) == \
33
- [{x:b, y:c}]
34
- assert list(unify(C('Add', (C('Mul', (1,2)), b,c)), C('Add', (x,y,c)), {})) == \
35
- [{x: C('Mul', (1,2)), y:b}]
36
-
37
- def test_associative():
38
- c1 = C('Add', (1,2,3))
39
- c2 = C('Add', (x,y))
40
- assert tuple(unify(c1, c2, {})) == ({x: 1, y: C('Add', (2, 3))},
41
- {x: C('Add', (1, 2)), y: 3})
42
-
43
- def test_commutative():
44
- c1 = C('CAdd', (1,2,3))
45
- c2 = C('CAdd', (x,y))
46
- result = list(unify(c1, c2, {}))
47
- assert {x: 1, y: C('CAdd', (2, 3))} in result
48
- assert ({x: 2, y: C('CAdd', (1, 3))} in result or
49
- {x: 2, y: C('CAdd', (3, 1))} in result)
50
-
51
- def _test_combinations_assoc():
52
- assert set(allcombinations((1,2,3), (a,b), True)) == \
53
- {(((1, 2), (3,)), (a, b)), (((1,), (2, 3)), (a, b))}
54
-
55
- def _test_combinations_comm():
56
- assert set(allcombinations((1,2,3), (a,b), None)) == \
57
- {(((1,), (2, 3)), ('a', 'b')), (((2,), (3, 1)), ('a', 'b')),
58
- (((3,), (1, 2)), ('a', 'b')), (((1, 2), (3,)), ('a', 'b')),
59
- (((2, 3), (1,)), ('a', 'b')), (((3, 1), (2,)), ('a', 'b'))}
60
-
61
- def test_allcombinations():
62
- assert set(allcombinations((1,2), (1,2), 'commutative')) ==\
63
- {(((1,),(2,)), ((1,),(2,))), (((1,),(2,)), ((2,),(1,)))}
64
-
65
-
66
- def test_commutativity():
67
- c1 = Compound('CAdd', (a, b))
68
- c2 = Compound('CAdd', (x, y))
69
- assert is_commutative(c1) and is_commutative(c2)
70
- assert len(list(unify(c1, c2, {}))) == 2
71
-
72
-
73
- def test_CondVariable():
74
- expr = C('CAdd', (1, 2))
75
- x = Variable('x')
76
- y = CondVariable('y', lambda a: a % 2 == 0)
77
- z = CondVariable('z', lambda a: a > 3)
78
- pattern = C('CAdd', (x, y))
79
- assert list(unify(expr, pattern, {})) == \
80
- [{x: 1, y: 2}]
81
-
82
- z = CondVariable('z', lambda a: a > 3)
83
- pattern = C('CAdd', (z, y))
84
-
85
- assert list(unify(expr, pattern, {})) == []
86
-
87
- def test_defaultdict():
88
- assert next(unify(Variable('x'), 'foo')) == {Variable('x'): 'foo'}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/__init__.py DELETED
@@ -1,30 +0,0 @@
1
- """This module contains some general purpose utilities that are used across
2
- SymPy.
3
- """
4
- from .iterables import (flatten, group, take, subsets,
5
- variations, numbered_symbols, cartes, capture, dict_merge,
6
- prefixes, postfixes, sift, topological_sort, unflatten,
7
- has_dups, has_variety, reshape, rotations)
8
-
9
- from .misc import filldedent
10
-
11
- from .lambdify import lambdify
12
-
13
- from .decorator import threaded, xthreaded, public, memoize_property
14
-
15
- from .timeutils import timed
16
-
17
- __all__ = [
18
- 'flatten', 'group', 'take', 'subsets', 'variations', 'numbered_symbols',
19
- 'cartes', 'capture', 'dict_merge', 'prefixes', 'postfixes', 'sift',
20
- 'topological_sort', 'unflatten', 'has_dups', 'has_variety', 'reshape',
21
- 'rotations',
22
-
23
- 'filldedent',
24
-
25
- 'lambdify',
26
-
27
- 'threaded', 'xthreaded', 'public', 'memoize_property',
28
-
29
- 'timed',
30
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/__init__.py DELETED
@@ -1,22 +0,0 @@
1
- """ This sub-module is private, i.e. external code should not depend on it.
2
-
3
- These functions are used by tests run as part of continuous integration.
4
- Once the implementation is mature (it should support the major
5
- platforms: Windows, OS X & Linux) it may become official API which
6
- may be relied upon by downstream libraries. Until then API may break
7
- without prior notice.
8
-
9
- TODO:
10
- - (optionally) clean up after tempfile.mkdtemp()
11
- - cross-platform testing
12
- - caching of compiler choice and intermediate files
13
-
14
- """
15
-
16
- from .compilation import compile_link_import_strings, compile_run_strings
17
- from .availability import has_fortran, has_c, has_cxx
18
-
19
- __all__ = [
20
- 'compile_link_import_strings', 'compile_run_strings',
21
- 'has_fortran', 'has_c', 'has_cxx',
22
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/availability.py DELETED
@@ -1,77 +0,0 @@
1
- import os
2
- from .compilation import compile_run_strings
3
- from .util import CompilerNotFoundError
4
-
5
- def has_fortran():
6
- if not hasattr(has_fortran, 'result'):
7
- try:
8
- (stdout, stderr), info = compile_run_strings(
9
- [('main.f90', (
10
- 'program foo\n'
11
- 'print *, "hello world"\n'
12
- 'end program'
13
- ))], clean=True
14
- )
15
- except CompilerNotFoundError:
16
- has_fortran.result = False
17
- if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
18
- raise
19
- else:
20
- if info['exit_status'] != os.EX_OK or 'hello world' not in stdout:
21
- if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
22
- raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr))
23
- has_fortran.result = False
24
- else:
25
- has_fortran.result = True
26
- return has_fortran.result
27
-
28
-
29
- def has_c():
30
- if not hasattr(has_c, 'result'):
31
- try:
32
- (stdout, stderr), info = compile_run_strings(
33
- [('main.c', (
34
- '#include <stdio.h>\n'
35
- 'int main(){\n'
36
- 'printf("hello world\\n");\n'
37
- 'return 0;\n'
38
- '}'
39
- ))], clean=True
40
- )
41
- except CompilerNotFoundError:
42
- has_c.result = False
43
- if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
44
- raise
45
- else:
46
- if info['exit_status'] != os.EX_OK or 'hello world' not in stdout:
47
- if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
48
- raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr))
49
- has_c.result = False
50
- else:
51
- has_c.result = True
52
- return has_c.result
53
-
54
-
55
- def has_cxx():
56
- if not hasattr(has_cxx, 'result'):
57
- try:
58
- (stdout, stderr), info = compile_run_strings(
59
- [('main.cxx', (
60
- '#include <iostream>\n'
61
- 'int main(){\n'
62
- 'std::cout << "hello world" << std::endl;\n'
63
- '}'
64
- ))], clean=True
65
- )
66
- except CompilerNotFoundError:
67
- has_cxx.result = False
68
- if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
69
- raise
70
- else:
71
- if info['exit_status'] != os.EX_OK or 'hello world' not in stdout:
72
- if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
73
- raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr))
74
- has_cxx.result = False
75
- else:
76
- has_cxx.result = True
77
- return has_cxx.result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/compilation.py DELETED
@@ -1,657 +0,0 @@
1
- import glob
2
- import os
3
- import shutil
4
- import subprocess
5
- import sys
6
- import tempfile
7
- import warnings
8
- from pathlib import Path
9
- from sysconfig import get_config_var, get_config_vars, get_path
10
-
11
- from .runners import (
12
- CCompilerRunner,
13
- CppCompilerRunner,
14
- FortranCompilerRunner
15
- )
16
- from .util import (
17
- get_abspath, make_dirs, copy, Glob, ArbitraryDepthGlob,
18
- glob_at_depth, import_module_from_file, pyx_is_cplus,
19
- sha256_of_string, sha256_of_file, CompileError
20
- )
21
-
22
- if os.name == 'posix':
23
- objext = '.o'
24
- elif os.name == 'nt':
25
- objext = '.obj'
26
- else:
27
- warnings.warn("Unknown os.name: {}".format(os.name))
28
- objext = '.o'
29
-
30
-
31
- def compile_sources(files, Runner=None, destdir=None, cwd=None, keep_dir_struct=False,
32
- per_file_kwargs=None, **kwargs):
33
- """ Compile source code files to object files.
34
-
35
- Parameters
36
- ==========
37
-
38
- files : iterable of str
39
- Paths to source files, if ``cwd`` is given, the paths are taken as relative.
40
- Runner: CompilerRunner subclass (optional)
41
- Could be e.g. ``FortranCompilerRunner``. Will be inferred from filename
42
- extensions if missing.
43
- destdir: str
44
- Output directory, if cwd is given, the path is taken as relative.
45
- cwd: str
46
- Working directory. Specify to have compiler run in other directory.
47
- also used as root of relative paths.
48
- keep_dir_struct: bool
49
- Reproduce directory structure in `destdir`. default: ``False``
50
- per_file_kwargs: dict
51
- Dict mapping instances in ``files`` to keyword arguments.
52
- \\*\\*kwargs: dict
53
- Default keyword arguments to pass to ``Runner``.
54
-
55
- Returns
56
- =======
57
- List of strings (paths of object files).
58
- """
59
- _per_file_kwargs = {}
60
-
61
- if per_file_kwargs is not None:
62
- for k, v in per_file_kwargs.items():
63
- if isinstance(k, Glob):
64
- for path in glob.glob(k.pathname):
65
- _per_file_kwargs[path] = v
66
- elif isinstance(k, ArbitraryDepthGlob):
67
- for path in glob_at_depth(k.filename, cwd):
68
- _per_file_kwargs[path] = v
69
- else:
70
- _per_file_kwargs[k] = v
71
-
72
- # Set up destination directory
73
- destdir = destdir or '.'
74
- if not os.path.isdir(destdir):
75
- if os.path.exists(destdir):
76
- raise OSError("{} is not a directory".format(destdir))
77
- else:
78
- make_dirs(destdir)
79
- if cwd is None:
80
- cwd = '.'
81
- for f in files:
82
- copy(f, destdir, only_update=True, dest_is_dir=True)
83
-
84
- # Compile files and return list of paths to the objects
85
- dstpaths = []
86
- for f in files:
87
- if keep_dir_struct:
88
- name, ext = os.path.splitext(f)
89
- else:
90
- name, ext = os.path.splitext(os.path.basename(f))
91
- file_kwargs = kwargs.copy()
92
- file_kwargs.update(_per_file_kwargs.get(f, {}))
93
- dstpaths.append(src2obj(f, Runner, cwd=cwd, **file_kwargs))
94
- return dstpaths
95
-
96
-
97
- def get_mixed_fort_c_linker(vendor=None, cplus=False, cwd=None):
98
- vendor = vendor or os.environ.get('SYMPY_COMPILER_VENDOR', 'gnu')
99
-
100
- if vendor.lower() == 'intel':
101
- if cplus:
102
- return (FortranCompilerRunner,
103
- {'flags': ['-nofor_main', '-cxxlib']}, vendor)
104
- else:
105
- return (FortranCompilerRunner,
106
- {'flags': ['-nofor_main']}, vendor)
107
- elif vendor.lower() == 'gnu' or 'llvm':
108
- if cplus:
109
- return (CppCompilerRunner,
110
- {'lib_options': ['fortran']}, vendor)
111
- else:
112
- return (FortranCompilerRunner,
113
- {}, vendor)
114
- else:
115
- raise ValueError("No vendor found.")
116
-
117
-
118
- def link(obj_files, out_file=None, shared=False, Runner=None,
119
- cwd=None, cplus=False, fort=False, extra_objs=None, **kwargs):
120
- """ Link object files.
121
-
122
- Parameters
123
- ==========
124
-
125
- obj_files: iterable of str
126
- Paths to object files.
127
- out_file: str (optional)
128
- Path to executable/shared library, if ``None`` it will be
129
- deduced from the last item in obj_files.
130
- shared: bool
131
- Generate a shared library?
132
- Runner: CompilerRunner subclass (optional)
133
- If not given the ``cplus`` and ``fort`` flags will be inspected
134
- (fallback is the C compiler).
135
- cwd: str
136
- Path to the root of relative paths and working directory for compiler.
137
- cplus: bool
138
- C++ objects? default: ``False``.
139
- fort: bool
140
- Fortran objects? default: ``False``.
141
- extra_objs: list
142
- List of paths to extra object files / static libraries.
143
- \\*\\*kwargs: dict
144
- Keyword arguments passed to ``Runner``.
145
-
146
- Returns
147
- =======
148
-
149
- The absolute path to the generated shared object / executable.
150
-
151
- """
152
- if out_file is None:
153
- out_file, ext = os.path.splitext(os.path.basename(obj_files[-1]))
154
- if shared:
155
- out_file += get_config_var('EXT_SUFFIX')
156
-
157
- if not Runner:
158
- if fort:
159
- Runner, extra_kwargs, vendor = \
160
- get_mixed_fort_c_linker(
161
- vendor=kwargs.get('vendor', None),
162
- cplus=cplus,
163
- cwd=cwd,
164
- )
165
- for k, v in extra_kwargs.items():
166
- if k in kwargs:
167
- kwargs[k].expand(v)
168
- else:
169
- kwargs[k] = v
170
- else:
171
- if cplus:
172
- Runner = CppCompilerRunner
173
- else:
174
- Runner = CCompilerRunner
175
-
176
- flags = kwargs.pop('flags', [])
177
- if shared:
178
- if '-shared' not in flags:
179
- flags.append('-shared')
180
- run_linker = kwargs.pop('run_linker', True)
181
- if not run_linker:
182
- raise ValueError("run_linker was set to False (nonsensical).")
183
-
184
- out_file = get_abspath(out_file, cwd=cwd)
185
- runner = Runner(obj_files+(extra_objs or []), out_file, flags, cwd=cwd, **kwargs)
186
- runner.run()
187
- return out_file
188
-
189
-
190
- def link_py_so(obj_files, so_file=None, cwd=None, libraries=None,
191
- cplus=False, fort=False, extra_objs=None, **kwargs):
192
- """ Link Python extension module (shared object) for importing
193
-
194
- Parameters
195
- ==========
196
-
197
- obj_files: iterable of str
198
- Paths to object files to be linked.
199
- so_file: str
200
- Name (path) of shared object file to create. If not specified it will
201
- have the basname of the last object file in `obj_files` but with the
202
- extension '.so' (Unix).
203
- cwd: path string
204
- Root of relative paths and working directory of linker.
205
- libraries: iterable of strings
206
- Libraries to link against, e.g. ['m'].
207
- cplus: bool
208
- Any C++ objects? default: ``False``.
209
- fort: bool
210
- Any Fortran objects? default: ``False``.
211
- extra_objs: list
212
- List of paths of extra object files / static libraries to link against.
213
- kwargs**: dict
214
- Keyword arguments passed to ``link(...)``.
215
-
216
- Returns
217
- =======
218
-
219
- Absolute path to the generate shared object.
220
- """
221
- libraries = libraries or []
222
-
223
- include_dirs = kwargs.pop('include_dirs', [])
224
- library_dirs = kwargs.pop('library_dirs', [])
225
-
226
- # Add Python include and library directories
227
- # PY_LDFLAGS does not available on all python implementations
228
- # e.g. when with pypy, so it's LDFLAGS we need to use
229
- if sys.platform == "win32":
230
- warnings.warn("Windows not yet supported.")
231
- elif sys.platform == 'darwin':
232
- cfgDict = get_config_vars()
233
- kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']]
234
- library_dirs += [cfgDict['LIBDIR']]
235
-
236
- # In macOS, linker needs to compile frameworks
237
- # e.g. "-framework CoreFoundation"
238
- is_framework = False
239
- for opt in cfgDict['LIBS'].split():
240
- if is_framework:
241
- kwargs['linkline'] = kwargs.get('linkline', []) + ['-framework', opt]
242
- is_framework = False
243
- elif opt.startswith('-l'):
244
- libraries.append(opt[2:])
245
- elif opt.startswith('-framework'):
246
- is_framework = True
247
- # The python library is not included in LIBS
248
- libfile = cfgDict['LIBRARY']
249
- libname = ".".join(libfile.split('.')[:-1])[3:]
250
- libraries.append(libname)
251
-
252
- elif sys.platform[:3] == 'aix':
253
- # Don't use the default code below
254
- pass
255
- else:
256
- if get_config_var('Py_ENABLE_SHARED'):
257
- cfgDict = get_config_vars()
258
- kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']]
259
- library_dirs += [cfgDict['LIBDIR']]
260
- for opt in cfgDict['BLDLIBRARY'].split():
261
- if opt.startswith('-l'):
262
- libraries += [opt[2:]]
263
- else:
264
- pass
265
-
266
- flags = kwargs.pop('flags', [])
267
- needed_flags = ('-pthread',)
268
- for flag in needed_flags:
269
- if flag not in flags:
270
- flags.append(flag)
271
-
272
- return link(obj_files, shared=True, flags=flags, cwd=cwd, cplus=cplus, fort=fort,
273
- include_dirs=include_dirs, libraries=libraries,
274
- library_dirs=library_dirs, extra_objs=extra_objs, **kwargs)
275
-
276
-
277
- def simple_cythonize(src, destdir=None, cwd=None, **cy_kwargs):
278
- """ Generates a C file from a Cython source file.
279
-
280
- Parameters
281
- ==========
282
-
283
- src: str
284
- Path to Cython source.
285
- destdir: str (optional)
286
- Path to output directory (default: '.').
287
- cwd: path string (optional)
288
- Root of relative paths (default: '.').
289
- **cy_kwargs:
290
- Second argument passed to cy_compile. Generates a .cpp file if ``cplus=True`` in ``cy_kwargs``,
291
- else a .c file.
292
- """
293
- from Cython.Compiler.Main import (
294
- default_options, CompilationOptions
295
- )
296
- from Cython.Compiler.Main import compile as cy_compile
297
-
298
- assert src.lower().endswith('.pyx') or src.lower().endswith('.py')
299
- cwd = cwd or '.'
300
- destdir = destdir or '.'
301
-
302
- ext = '.cpp' if cy_kwargs.get('cplus', False) else '.c'
303
- c_name = os.path.splitext(os.path.basename(src))[0] + ext
304
-
305
- dstfile = os.path.join(destdir, c_name)
306
-
307
- if cwd:
308
- ori_dir = os.getcwd()
309
- else:
310
- ori_dir = '.'
311
- os.chdir(cwd)
312
- try:
313
- cy_options = CompilationOptions(default_options)
314
- cy_options.__dict__.update(cy_kwargs)
315
- # Set language_level if not set by cy_kwargs
316
- # as not setting it is deprecated
317
- if 'language_level' not in cy_kwargs:
318
- cy_options.__dict__['language_level'] = 3
319
- cy_result = cy_compile([src], cy_options)
320
- if cy_result.num_errors > 0:
321
- raise ValueError("Cython compilation failed.")
322
-
323
- # Move generated C file to destination
324
- # In macOS, the generated C file is in the same directory as the source
325
- # but the /var is a symlink to /private/var, so we need to use realpath
326
- if os.path.realpath(os.path.dirname(src)) != os.path.realpath(destdir):
327
- if os.path.exists(dstfile):
328
- os.unlink(dstfile)
329
- shutil.move(os.path.join(os.path.dirname(src), c_name), destdir)
330
- finally:
331
- os.chdir(ori_dir)
332
- return dstfile
333
-
334
-
335
- extension_mapping = {
336
- '.c': (CCompilerRunner, None),
337
- '.cpp': (CppCompilerRunner, None),
338
- '.cxx': (CppCompilerRunner, None),
339
- '.f': (FortranCompilerRunner, None),
340
- '.for': (FortranCompilerRunner, None),
341
- '.ftn': (FortranCompilerRunner, None),
342
- '.f90': (FortranCompilerRunner, None), # ifort only knows about .f90
343
- '.f95': (FortranCompilerRunner, 'f95'),
344
- '.f03': (FortranCompilerRunner, 'f2003'),
345
- '.f08': (FortranCompilerRunner, 'f2008'),
346
- }
347
-
348
-
349
- def src2obj(srcpath, Runner=None, objpath=None, cwd=None, inc_py=False, **kwargs):
350
- """ Compiles a source code file to an object file.
351
-
352
- Files ending with '.pyx' assumed to be cython files and
353
- are dispatched to pyx2obj.
354
-
355
- Parameters
356
- ==========
357
-
358
- srcpath: str
359
- Path to source file.
360
- Runner: CompilerRunner subclass (optional)
361
- If ``None``: deduced from extension of srcpath.
362
- objpath : str (optional)
363
- Path to generated object. If ``None``: deduced from ``srcpath``.
364
- cwd: str (optional)
365
- Working directory and root of relative paths. If ``None``: current dir.
366
- inc_py: bool
367
- Add Python include path to kwarg "include_dirs". Default: False
368
- \\*\\*kwargs: dict
369
- keyword arguments passed to Runner or pyx2obj
370
-
371
- """
372
- name, ext = os.path.splitext(os.path.basename(srcpath))
373
- if objpath is None:
374
- if os.path.isabs(srcpath):
375
- objpath = '.'
376
- else:
377
- objpath = os.path.dirname(srcpath)
378
- objpath = objpath or '.' # avoid objpath == ''
379
-
380
- if os.path.isdir(objpath):
381
- objpath = os.path.join(objpath, name + objext)
382
-
383
- include_dirs = kwargs.pop('include_dirs', [])
384
- if inc_py:
385
- py_inc_dir = get_path('include')
386
- if py_inc_dir not in include_dirs:
387
- include_dirs.append(py_inc_dir)
388
-
389
- if ext.lower() == '.pyx':
390
- return pyx2obj(srcpath, objpath=objpath, include_dirs=include_dirs, cwd=cwd,
391
- **kwargs)
392
-
393
- if Runner is None:
394
- Runner, std = extension_mapping[ext.lower()]
395
- if 'std' not in kwargs:
396
- kwargs['std'] = std
397
-
398
- flags = kwargs.pop('flags', [])
399
- needed_flags = ('-fPIC',)
400
- for flag in needed_flags:
401
- if flag not in flags:
402
- flags.append(flag)
403
-
404
- # src2obj implies not running the linker...
405
- run_linker = kwargs.pop('run_linker', False)
406
- if run_linker:
407
- raise CompileError("src2obj called with run_linker=True")
408
-
409
- runner = Runner([srcpath], objpath, include_dirs=include_dirs,
410
- run_linker=run_linker, cwd=cwd, flags=flags, **kwargs)
411
- runner.run()
412
- return objpath
413
-
414
-
415
- def pyx2obj(pyxpath, objpath=None, destdir=None, cwd=None,
416
- include_dirs=None, cy_kwargs=None, cplus=None, **kwargs):
417
- """
418
- Convenience function
419
-
420
- If cwd is specified, pyxpath and dst are taken to be relative
421
- If only_update is set to `True` the modification time is checked
422
- and compilation is only run if the source is newer than the
423
- destination
424
-
425
- Parameters
426
- ==========
427
-
428
- pyxpath: str
429
- Path to Cython source file.
430
- objpath: str (optional)
431
- Path to object file to generate.
432
- destdir: str (optional)
433
- Directory to put generated C file. When ``None``: directory of ``objpath``.
434
- cwd: str (optional)
435
- Working directory and root of relative paths.
436
- include_dirs: iterable of path strings (optional)
437
- Passed onto src2obj and via cy_kwargs['include_path']
438
- to simple_cythonize.
439
- cy_kwargs: dict (optional)
440
- Keyword arguments passed onto `simple_cythonize`
441
- cplus: bool (optional)
442
- Indicate whether C++ is used. default: auto-detect using ``.util.pyx_is_cplus``.
443
- compile_kwargs: dict
444
- keyword arguments passed onto src2obj
445
-
446
- Returns
447
- =======
448
-
449
- Absolute path of generated object file.
450
-
451
- """
452
- assert pyxpath.endswith('.pyx')
453
- cwd = cwd or '.'
454
- objpath = objpath or '.'
455
- destdir = destdir or os.path.dirname(objpath)
456
-
457
- abs_objpath = get_abspath(objpath, cwd=cwd)
458
-
459
- if os.path.isdir(abs_objpath):
460
- pyx_fname = os.path.basename(pyxpath)
461
- name, ext = os.path.splitext(pyx_fname)
462
- objpath = os.path.join(objpath, name + objext)
463
-
464
- cy_kwargs = cy_kwargs or {}
465
- cy_kwargs['output_dir'] = cwd
466
- if cplus is None:
467
- cplus = pyx_is_cplus(pyxpath)
468
- cy_kwargs['cplus'] = cplus
469
-
470
- interm_c_file = simple_cythonize(pyxpath, destdir=destdir, cwd=cwd, **cy_kwargs)
471
-
472
- include_dirs = include_dirs or []
473
- flags = kwargs.pop('flags', [])
474
- needed_flags = ('-fwrapv', '-pthread', '-fPIC')
475
- for flag in needed_flags:
476
- if flag not in flags:
477
- flags.append(flag)
478
-
479
- options = kwargs.pop('options', [])
480
-
481
- if kwargs.pop('strict_aliasing', False):
482
- raise CompileError("Cython requires strict aliasing to be disabled.")
483
-
484
- # Let's be explicit about standard
485
- if cplus:
486
- std = kwargs.pop('std', 'c++98')
487
- else:
488
- std = kwargs.pop('std', 'c99')
489
-
490
- return src2obj(interm_c_file, objpath=objpath, cwd=cwd,
491
- include_dirs=include_dirs, flags=flags, std=std,
492
- options=options, inc_py=True, strict_aliasing=False,
493
- **kwargs)
494
-
495
-
496
- def _any_X(srcs, cls):
497
- for src in srcs:
498
- name, ext = os.path.splitext(src)
499
- key = ext.lower()
500
- if key in extension_mapping:
501
- if extension_mapping[key][0] == cls:
502
- return True
503
- return False
504
-
505
-
506
- def any_fortran_src(srcs):
507
- return _any_X(srcs, FortranCompilerRunner)
508
-
509
-
510
- def any_cplus_src(srcs):
511
- return _any_X(srcs, CppCompilerRunner)
512
-
513
-
514
- def compile_link_import_py_ext(sources, extname=None, build_dir='.', compile_kwargs=None,
515
- link_kwargs=None, extra_objs=None):
516
- """ Compiles sources to a shared object (Python extension) and imports it
517
-
518
- Sources in ``sources`` which is imported. If shared object is newer than the sources, they
519
- are not recompiled but instead it is imported.
520
-
521
- Parameters
522
- ==========
523
-
524
- sources : list of strings
525
- List of paths to sources.
526
- extname : string
527
- Name of extension (default: ``None``).
528
- If ``None``: taken from the last file in ``sources`` without extension.
529
- build_dir: str
530
- Path to directory in which objects files etc. are generated.
531
- compile_kwargs: dict
532
- keyword arguments passed to ``compile_sources``
533
- link_kwargs: dict
534
- keyword arguments passed to ``link_py_so``
535
- extra_objs: list
536
- List of paths to (prebuilt) object files / static libraries to link against.
537
-
538
- Returns
539
- =======
540
-
541
- The imported module from of the Python extension.
542
- """
543
- if extname is None:
544
- extname = os.path.splitext(os.path.basename(sources[-1]))[0]
545
-
546
- compile_kwargs = compile_kwargs or {}
547
- link_kwargs = link_kwargs or {}
548
-
549
- try:
550
- mod = import_module_from_file(os.path.join(build_dir, extname), sources)
551
- except ImportError:
552
- objs = compile_sources(list(map(get_abspath, sources)), destdir=build_dir,
553
- cwd=build_dir, **compile_kwargs)
554
- so = link_py_so(objs, cwd=build_dir, fort=any_fortran_src(sources),
555
- cplus=any_cplus_src(sources), extra_objs=extra_objs, **link_kwargs)
556
- mod = import_module_from_file(so)
557
- return mod
558
-
559
-
560
- def _write_sources_to_build_dir(sources, build_dir):
561
- build_dir = build_dir or tempfile.mkdtemp()
562
- if not os.path.isdir(build_dir):
563
- raise OSError("Non-existent directory: ", build_dir)
564
-
565
- source_files = []
566
- for name, src in sources:
567
- dest = os.path.join(build_dir, name)
568
- differs = True
569
- sha256_in_mem = sha256_of_string(src.encode('utf-8')).hexdigest()
570
- if os.path.exists(dest):
571
- if os.path.exists(dest + '.sha256'):
572
- sha256_on_disk = Path(dest + '.sha256').read_text()
573
- else:
574
- sha256_on_disk = sha256_of_file(dest).hexdigest()
575
-
576
- differs = sha256_on_disk != sha256_in_mem
577
- if differs:
578
- with open(dest, 'wt') as fh:
579
- fh.write(src)
580
- with open(dest + '.sha256', 'wt') as fh:
581
- fh.write(sha256_in_mem)
582
- source_files.append(dest)
583
- return source_files, build_dir
584
-
585
-
586
- def compile_link_import_strings(sources, build_dir=None, **kwargs):
587
- """ Compiles, links and imports extension module from source.
588
-
589
- Parameters
590
- ==========
591
-
592
- sources : iterable of name/source pair tuples
593
- build_dir : string (default: None)
594
- Path. ``None`` implies use a temporary directory.
595
- **kwargs:
596
- Keyword arguments passed onto `compile_link_import_py_ext`.
597
-
598
- Returns
599
- =======
600
-
601
- mod : module
602
- The compiled and imported extension module.
603
- info : dict
604
- Containing ``build_dir`` as 'build_dir'.
605
-
606
- """
607
- source_files, build_dir = _write_sources_to_build_dir(sources, build_dir)
608
- mod = compile_link_import_py_ext(source_files, build_dir=build_dir, **kwargs)
609
- info = {"build_dir": build_dir}
610
- return mod, info
611
-
612
-
613
- def compile_run_strings(sources, build_dir=None, clean=False, compile_kwargs=None, link_kwargs=None):
614
- """ Compiles, links and runs a program built from sources.
615
-
616
- Parameters
617
- ==========
618
-
619
- sources : iterable of name/source pair tuples
620
- build_dir : string (default: None)
621
- Path. ``None`` implies use a temporary directory.
622
- clean : bool
623
- Whether to remove build_dir after use. This will only have an
624
- effect if ``build_dir`` is ``None`` (which creates a temporary directory).
625
- Passing ``clean == True`` and ``build_dir != None`` raises a ``ValueError``.
626
- This will also set ``build_dir`` in returned info dictionary to ``None``.
627
- compile_kwargs: dict
628
- Keyword arguments passed onto ``compile_sources``
629
- link_kwargs: dict
630
- Keyword arguments passed onto ``link``
631
-
632
- Returns
633
- =======
634
-
635
- (stdout, stderr): pair of strings
636
- info: dict
637
- Containing exit status as 'exit_status' and ``build_dir`` as 'build_dir'
638
-
639
- """
640
- if clean and build_dir is not None:
641
- raise ValueError("Automatic removal of build_dir is only available for temporary directory.")
642
- try:
643
- source_files, build_dir = _write_sources_to_build_dir(sources, build_dir)
644
- objs = compile_sources(list(map(get_abspath, source_files)), destdir=build_dir,
645
- cwd=build_dir, **(compile_kwargs or {}))
646
- prog = link(objs, cwd=build_dir,
647
- fort=any_fortran_src(source_files),
648
- cplus=any_cplus_src(source_files), **(link_kwargs or {}))
649
- p = subprocess.Popen([prog], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
650
- exit_status = p.wait()
651
- stdout, stderr = [txt.decode('utf-8') for txt in p.communicate()]
652
- finally:
653
- if clean and os.path.isdir(build_dir):
654
- shutil.rmtree(build_dir)
655
- build_dir = None
656
- info = {"exit_status": exit_status, "build_dir": build_dir}
657
- return (stdout, stderr), info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/runners.py DELETED
@@ -1,301 +0,0 @@
1
- from __future__ import annotations
2
- from typing import Callable, Optional
3
-
4
- from collections import OrderedDict
5
- import os
6
- import re
7
- import subprocess
8
- import warnings
9
-
10
- from .util import (
11
- find_binary_of_command, unique_list, CompileError
12
- )
13
-
14
-
15
- class CompilerRunner:
16
- """ CompilerRunner base class.
17
-
18
- Parameters
19
- ==========
20
-
21
- sources : list of str
22
- Paths to sources.
23
- out : str
24
- flags : iterable of str
25
- Compiler flags.
26
- run_linker : bool
27
- compiler_name_exe : (str, str) tuple
28
- Tuple of compiler name & command to call.
29
- cwd : str
30
- Path of root of relative paths.
31
- include_dirs : list of str
32
- Include directories.
33
- libraries : list of str
34
- Libraries to link against.
35
- library_dirs : list of str
36
- Paths to search for shared libraries.
37
- std : str
38
- Standard string, e.g. ``'c++11'``, ``'c99'``, ``'f2003'``.
39
- define: iterable of strings
40
- macros to define
41
- undef : iterable of strings
42
- macros to undefine
43
- preferred_vendor : string
44
- name of preferred vendor e.g. 'gnu' or 'intel'
45
-
46
- Methods
47
- =======
48
-
49
- run():
50
- Invoke compilation as a subprocess.
51
-
52
- """
53
-
54
- environ_key_compiler: str # e.g. 'CC', 'CXX', ...
55
- environ_key_flags: str # e.g. 'CFLAGS', 'CXXFLAGS', ...
56
- environ_key_ldflags: str = "LDFLAGS" # typically 'LDFLAGS'
57
-
58
- # Subclass to vendor/binary dict
59
- compiler_dict: dict[str, str]
60
-
61
- # Standards should be a tuple of supported standards
62
- # (first one will be the default)
63
- standards: tuple[None | str, ...]
64
-
65
- # Subclass to dict of binary/formater-callback
66
- std_formater: dict[str, Callable[[Optional[str]], str]]
67
-
68
- # subclass to be e.g. {'gcc': 'gnu', ...}
69
- compiler_name_vendor_mapping: dict[str, str]
70
-
71
- def __init__(self, sources, out, flags=None, run_linker=True, compiler=None, cwd='.',
72
- include_dirs=None, libraries=None, library_dirs=None, std=None, define=None,
73
- undef=None, strict_aliasing=None, preferred_vendor=None, linkline=None, **kwargs):
74
- if isinstance(sources, str):
75
- raise ValueError("Expected argument sources to be a list of strings.")
76
- self.sources = list(sources)
77
- self.out = out
78
- self.flags = flags or []
79
- if os.environ.get(self.environ_key_flags):
80
- self.flags += os.environ[self.environ_key_flags].split()
81
- self.cwd = cwd
82
- if compiler:
83
- self.compiler_name, self.compiler_binary = compiler
84
- elif os.environ.get(self.environ_key_compiler):
85
- self.compiler_binary = os.environ[self.environ_key_compiler]
86
- for k, v in self.compiler_dict.items():
87
- if k in self.compiler_binary:
88
- self.compiler_vendor = k
89
- self.compiler_name = v
90
- break
91
- else:
92
- self.compiler_vendor, self.compiler_name = list(self.compiler_dict.items())[0]
93
- warnings.warn("failed to determine what kind of compiler %s is, assuming %s" %
94
- (self.compiler_binary, self.compiler_name))
95
- else:
96
- # Find a compiler
97
- if preferred_vendor is None:
98
- preferred_vendor = os.environ.get('SYMPY_COMPILER_VENDOR', None)
99
- self.compiler_name, self.compiler_binary, self.compiler_vendor = self.find_compiler(preferred_vendor)
100
- if self.compiler_binary is None:
101
- raise ValueError("No compiler found (searched: {})".format(', '.join(self.compiler_dict.values())))
102
- self.define = define or []
103
- self.undef = undef or []
104
- self.include_dirs = include_dirs or []
105
- self.libraries = libraries or []
106
- self.library_dirs = library_dirs or []
107
- self.std = std or self.standards[0]
108
- self.run_linker = run_linker
109
- if self.run_linker:
110
- # both gnu and intel compilers use '-c' for disabling linker
111
- self.flags = list(filter(lambda x: x != '-c', self.flags))
112
- else:
113
- if '-c' not in self.flags:
114
- self.flags.append('-c')
115
-
116
- if self.std:
117
- self.flags.append(self.std_formater[
118
- self.compiler_name](self.std))
119
-
120
- self.linkline = (linkline or []) + [lf for lf in map(
121
- str.strip, os.environ.get(self.environ_key_ldflags, "").split()
122
- ) if lf != ""]
123
-
124
- if strict_aliasing is not None:
125
- nsa_re = re.compile("no-strict-aliasing$")
126
- sa_re = re.compile("strict-aliasing$")
127
- if strict_aliasing is True:
128
- if any(map(nsa_re.match, flags)):
129
- raise CompileError("Strict aliasing cannot be both enforced and disabled")
130
- elif any(map(sa_re.match, flags)):
131
- pass # already enforced
132
- else:
133
- flags.append('-fstrict-aliasing')
134
- elif strict_aliasing is False:
135
- if any(map(nsa_re.match, flags)):
136
- pass # already disabled
137
- else:
138
- if any(map(sa_re.match, flags)):
139
- raise CompileError("Strict aliasing cannot be both enforced and disabled")
140
- else:
141
- flags.append('-fno-strict-aliasing')
142
- else:
143
- msg = "Expected argument strict_aliasing to be True/False, got {}"
144
- raise ValueError(msg.format(strict_aliasing))
145
-
146
- @classmethod
147
- def find_compiler(cls, preferred_vendor=None):
148
- """ Identify a suitable C/fortran/other compiler. """
149
- candidates = list(cls.compiler_dict.keys())
150
- if preferred_vendor:
151
- if preferred_vendor in candidates:
152
- candidates = [preferred_vendor]+candidates
153
- else:
154
- raise ValueError("Unknown vendor {}".format(preferred_vendor))
155
- name, path = find_binary_of_command([cls.compiler_dict[x] for x in candidates])
156
- return name, path, cls.compiler_name_vendor_mapping[name]
157
-
158
- def cmd(self):
159
- """ List of arguments (str) to be passed to e.g. ``subprocess.Popen``. """
160
- cmd = (
161
- [self.compiler_binary] +
162
- self.flags +
163
- ['-U'+x for x in self.undef] +
164
- ['-D'+x for x in self.define] +
165
- ['-I'+x for x in self.include_dirs] +
166
- self.sources
167
- )
168
- if self.run_linker:
169
- cmd += (['-L'+x for x in self.library_dirs] +
170
- ['-l'+x for x in self.libraries] +
171
- self.linkline)
172
- counted = []
173
- for envvar in re.findall(r'\$\{(\w+)\}', ' '.join(cmd)):
174
- if os.getenv(envvar) is None:
175
- if envvar not in counted:
176
- counted.append(envvar)
177
- msg = "Environment variable '{}' undefined.".format(envvar)
178
- raise CompileError(msg)
179
- return cmd
180
-
181
- def run(self):
182
- self.flags = unique_list(self.flags)
183
-
184
- # Append output flag and name to tail of flags
185
- self.flags.extend(['-o', self.out])
186
- env = os.environ.copy()
187
- env['PWD'] = self.cwd
188
-
189
- # NOTE: intel compilers seems to need shell=True
190
- p = subprocess.Popen(' '.join(self.cmd()),
191
- shell=True,
192
- cwd=self.cwd,
193
- stdin=subprocess.PIPE,
194
- stdout=subprocess.PIPE,
195
- stderr=subprocess.STDOUT,
196
- env=env)
197
- comm = p.communicate()
198
- try:
199
- self.cmd_outerr = comm[0].decode('utf-8')
200
- except UnicodeDecodeError:
201
- self.cmd_outerr = comm[0].decode('iso-8859-1') # win32
202
- self.cmd_returncode = p.returncode
203
-
204
- # Error handling
205
- if self.cmd_returncode != 0:
206
- msg = "Error executing '{}' in {} (exited status {}):\n {}\n".format(
207
- ' '.join(self.cmd()), self.cwd, str(self.cmd_returncode), self.cmd_outerr
208
- )
209
- raise CompileError(msg)
210
-
211
- return self.cmd_outerr, self.cmd_returncode
212
-
213
-
214
- class CCompilerRunner(CompilerRunner):
215
-
216
- environ_key_compiler = 'CC'
217
- environ_key_flags = 'CFLAGS'
218
-
219
- compiler_dict = OrderedDict([
220
- ('gnu', 'gcc'),
221
- ('intel', 'icc'),
222
- ('llvm', 'clang'),
223
- ])
224
-
225
- standards = ('c89', 'c90', 'c99', 'c11') # First is default
226
-
227
- std_formater = {
228
- 'gcc': '-std={}'.format,
229
- 'icc': '-std={}'.format,
230
- 'clang': '-std={}'.format,
231
- }
232
-
233
- compiler_name_vendor_mapping = {
234
- 'gcc': 'gnu',
235
- 'icc': 'intel',
236
- 'clang': 'llvm'
237
- }
238
-
239
-
240
- def _mk_flag_filter(cmplr_name): # helper for class initialization
241
- not_welcome = {'g++': ("Wimplicit-interface",)} # "Wstrict-prototypes",)}
242
- if cmplr_name in not_welcome:
243
- def fltr(x):
244
- for nw in not_welcome[cmplr_name]:
245
- if nw in x:
246
- return False
247
- return True
248
- else:
249
- def fltr(x):
250
- return True
251
- return fltr
252
-
253
-
254
- class CppCompilerRunner(CompilerRunner):
255
-
256
- environ_key_compiler = 'CXX'
257
- environ_key_flags = 'CXXFLAGS'
258
-
259
- compiler_dict = OrderedDict([
260
- ('gnu', 'g++'),
261
- ('intel', 'icpc'),
262
- ('llvm', 'clang++'),
263
- ])
264
-
265
- # First is the default, c++0x == c++11
266
- standards = ('c++98', 'c++0x')
267
-
268
- std_formater = {
269
- 'g++': '-std={}'.format,
270
- 'icpc': '-std={}'.format,
271
- 'clang++': '-std={}'.format,
272
- }
273
-
274
- compiler_name_vendor_mapping = {
275
- 'g++': 'gnu',
276
- 'icpc': 'intel',
277
- 'clang++': 'llvm'
278
- }
279
-
280
-
281
- class FortranCompilerRunner(CompilerRunner):
282
-
283
- environ_key_compiler = 'FC'
284
- environ_key_flags = 'FFLAGS'
285
-
286
- standards = (None, 'f77', 'f95', 'f2003', 'f2008')
287
-
288
- std_formater = {
289
- 'gfortran': lambda x: '-std=gnu' if x is None else '-std=legacy' if x == 'f77' else '-std={}'.format(x),
290
- 'ifort': lambda x: '-stand f08' if x is None else '-stand f{}'.format(x[-2:]), # f2008 => f08
291
- }
292
-
293
- compiler_dict = OrderedDict([
294
- ('gnu', 'gfortran'),
295
- ('intel', 'ifort'),
296
- ])
297
-
298
- compiler_name_vendor_mapping = {
299
- 'gfortran': 'gnu',
300
- 'ifort': 'intel',
301
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/__init__.py DELETED
File without changes
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/test_compilation.py DELETED
@@ -1,104 +0,0 @@
1
- import shutil
2
- import os
3
- import subprocess
4
- import tempfile
5
- from sympy.external import import_module
6
- from sympy.testing.pytest import skip, skip_under_pyodide
7
-
8
- from sympy.utilities._compilation.compilation import compile_link_import_py_ext, compile_link_import_strings, compile_sources, get_abspath
9
-
10
- numpy = import_module('numpy')
11
- cython = import_module('cython')
12
-
13
- _sources1 = [
14
- ('sigmoid.c', r"""
15
- #include <math.h>
16
-
17
- void sigmoid(int n, const double * const restrict in,
18
- double * const restrict out, double lim){
19
- for (int i=0; i<n; ++i){
20
- const double x = in[i];
21
- out[i] = x*pow(pow(x/lim, 8)+1, -1./8.);
22
- }
23
- }
24
- """),
25
- ('_sigmoid.pyx', r"""
26
- import numpy as np
27
- cimport numpy as cnp
28
-
29
- cdef extern void c_sigmoid "sigmoid" (int, const double * const,
30
- double * const, double)
31
-
32
- def sigmoid(double [:] inp, double lim=350.0):
33
- cdef cnp.ndarray[cnp.float64_t, ndim=1] out = np.empty(
34
- inp.size, dtype=np.float64)
35
- c_sigmoid(inp.size, &inp[0], &out[0], lim)
36
- return out
37
- """)
38
- ]
39
-
40
-
41
- def npy(data, lim=350.0):
42
- return data/((data/lim)**8+1)**(1/8.)
43
-
44
-
45
- def test_compile_link_import_strings():
46
- if not numpy:
47
- skip("numpy not installed.")
48
- if not cython:
49
- skip("cython not installed.")
50
-
51
- from sympy.utilities._compilation import has_c
52
- if not has_c():
53
- skip("No C compiler found.")
54
-
55
- compile_kw = {"std": 'c99', "include_dirs": [numpy.get_include()]}
56
- info = None
57
- try:
58
- mod, info = compile_link_import_strings(_sources1, compile_kwargs=compile_kw)
59
- data = numpy.random.random(1024*1024*8) # 64 MB of RAM needed..
60
- res_mod = mod.sigmoid(data)
61
- res_npy = npy(data)
62
- assert numpy.allclose(res_mod, res_npy)
63
- finally:
64
- if info and info['build_dir']:
65
- shutil.rmtree(info['build_dir'])
66
-
67
-
68
- @skip_under_pyodide("Emscripten does not support subprocesses")
69
- def test_compile_sources():
70
- tmpdir = tempfile.mkdtemp()
71
-
72
- from sympy.utilities._compilation import has_c
73
- if not has_c():
74
- skip("No C compiler found.")
75
-
76
- build_dir = str(tmpdir)
77
- _handle, file_path = tempfile.mkstemp('.c', dir=build_dir)
78
- with open(file_path, 'wt') as ofh:
79
- ofh.write("""
80
- int foo(int bar) {
81
- return 2*bar;
82
- }
83
- """)
84
- obj, = compile_sources([file_path], cwd=build_dir)
85
- obj_path = get_abspath(obj, cwd=build_dir)
86
- assert os.path.exists(obj_path)
87
- try:
88
- _ = subprocess.check_output(["nm", "--help"])
89
- except subprocess.CalledProcessError:
90
- pass # we cannot test contents of object file
91
- else:
92
- nm_out = subprocess.check_output(["nm", obj_path])
93
- assert 'foo' in nm_out.decode('utf-8')
94
-
95
- if not cython:
96
- return # the final (optional) part of the test below requires Cython.
97
-
98
- _handle, pyx_path = tempfile.mkstemp('.pyx', dir=build_dir)
99
- with open(pyx_path, 'wt') as ofh:
100
- ofh.write(("cdef extern int foo(int)\n"
101
- "def _foo(arg):\n"
102
- " return foo(arg)"))
103
- mod = compile_link_import_py_ext([pyx_path], extra_objs=[obj_path], build_dir=build_dir)
104
- assert mod._foo(21) == 42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/util.py DELETED
@@ -1,312 +0,0 @@
1
- from collections import namedtuple
2
- from hashlib import sha256
3
- import os
4
- import shutil
5
- import sys
6
- import fnmatch
7
-
8
- from sympy.testing.pytest import XFAIL
9
-
10
-
11
- def may_xfail(func):
12
- if sys.platform.lower() == 'darwin' or os.name == 'nt':
13
- # sympy.utilities._compilation needs more testing on Windows and macOS
14
- # once those two platforms are reliably supported this xfail decorator
15
- # may be removed.
16
- return XFAIL(func)
17
- else:
18
- return func
19
-
20
-
21
- class CompilerNotFoundError(FileNotFoundError):
22
- pass
23
-
24
-
25
- class CompileError (Exception):
26
- """Failure to compile one or more C/C++ source files."""
27
-
28
-
29
- def get_abspath(path, cwd='.'):
30
- """ Returns the absolute path.
31
-
32
- Parameters
33
- ==========
34
-
35
- path : str
36
- (relative) path.
37
- cwd : str
38
- Path to root of relative path.
39
- """
40
- if os.path.isabs(path):
41
- return path
42
- else:
43
- if not os.path.isabs(cwd):
44
- cwd = os.path.abspath(cwd)
45
- return os.path.abspath(
46
- os.path.join(cwd, path)
47
- )
48
-
49
-
50
- def make_dirs(path):
51
- """ Create directories (equivalent of ``mkdir -p``). """
52
- if path[-1] == '/':
53
- parent = os.path.dirname(path[:-1])
54
- else:
55
- parent = os.path.dirname(path)
56
-
57
- if len(parent) > 0:
58
- if not os.path.exists(parent):
59
- make_dirs(parent)
60
-
61
- if not os.path.exists(path):
62
- os.mkdir(path, 0o777)
63
- else:
64
- assert os.path.isdir(path)
65
-
66
- def missing_or_other_newer(path, other_path, cwd=None):
67
- """
68
- Investigate if path is non-existent or older than provided reference
69
- path.
70
-
71
- Parameters
72
- ==========
73
- path: string
74
- path to path which might be missing or too old
75
- other_path: string
76
- reference path
77
- cwd: string
78
- working directory (root of relative paths)
79
-
80
- Returns
81
- =======
82
- True if path is older or missing.
83
- """
84
- cwd = cwd or '.'
85
- path = get_abspath(path, cwd=cwd)
86
- other_path = get_abspath(other_path, cwd=cwd)
87
- if not os.path.exists(path):
88
- return True
89
- if os.path.getmtime(other_path) - 1e-6 >= os.path.getmtime(path):
90
- # 1e-6 is needed because http://stackoverflow.com/questions/17086426/
91
- return True
92
- return False
93
-
94
- def copy(src, dst, only_update=False, copystat=True, cwd=None,
95
- dest_is_dir=False, create_dest_dirs=False):
96
- """ Variation of ``shutil.copy`` with extra options.
97
-
98
- Parameters
99
- ==========
100
-
101
- src : str
102
- Path to source file.
103
- dst : str
104
- Path to destination.
105
- only_update : bool
106
- Only copy if source is newer than destination
107
- (returns None if it was newer), default: ``False``.
108
- copystat : bool
109
- See ``shutil.copystat``. default: ``True``.
110
- cwd : str
111
- Path to working directory (root of relative paths).
112
- dest_is_dir : bool
113
- Ensures that dst is treated as a directory. default: ``False``
114
- create_dest_dirs : bool
115
- Creates directories if needed.
116
-
117
- Returns
118
- =======
119
-
120
- Path to the copied file.
121
-
122
- """
123
- if cwd: # Handle working directory
124
- if not os.path.isabs(src):
125
- src = os.path.join(cwd, src)
126
- if not os.path.isabs(dst):
127
- dst = os.path.join(cwd, dst)
128
-
129
- if not os.path.exists(src): # Make sure source file exists
130
- raise FileNotFoundError("Source: `{}` does not exist".format(src))
131
-
132
- # We accept both (re)naming destination file _or_
133
- # passing a (possible non-existent) destination directory
134
- if dest_is_dir:
135
- if not dst[-1] == '/':
136
- dst = dst+'/'
137
- else:
138
- if os.path.exists(dst) and os.path.isdir(dst):
139
- dest_is_dir = True
140
-
141
- if dest_is_dir:
142
- dest_dir = dst
143
- dest_fname = os.path.basename(src)
144
- dst = os.path.join(dest_dir, dest_fname)
145
- else:
146
- dest_dir = os.path.dirname(dst)
147
-
148
- if not os.path.exists(dest_dir):
149
- if create_dest_dirs:
150
- make_dirs(dest_dir)
151
- else:
152
- raise FileNotFoundError("You must create directory first.")
153
-
154
- if only_update:
155
- if not missing_or_other_newer(dst, src):
156
- return
157
-
158
- if os.path.islink(dst):
159
- dst = os.path.abspath(os.path.realpath(dst), cwd=cwd)
160
-
161
- shutil.copy(src, dst)
162
- if copystat:
163
- shutil.copystat(src, dst)
164
-
165
- return dst
166
-
167
- Glob = namedtuple('Glob', 'pathname')
168
- ArbitraryDepthGlob = namedtuple('ArbitraryDepthGlob', 'filename')
169
-
170
- def glob_at_depth(filename_glob, cwd=None):
171
- if cwd is not None:
172
- cwd = '.'
173
- globbed = []
174
- for root, dirs, filenames in os.walk(cwd):
175
- for fn in filenames:
176
- # This is not tested:
177
- if fnmatch.fnmatch(fn, filename_glob):
178
- globbed.append(os.path.join(root, fn))
179
- return globbed
180
-
181
- def sha256_of_file(path, nblocks=128):
182
- """ Computes the SHA256 hash of a file.
183
-
184
- Parameters
185
- ==========
186
-
187
- path : string
188
- Path to file to compute hash of.
189
- nblocks : int
190
- Number of blocks to read per iteration.
191
-
192
- Returns
193
- =======
194
-
195
- hashlib sha256 hash object. Use ``.digest()`` or ``.hexdigest()``
196
- on returned object to get binary or hex encoded string.
197
- """
198
- sh = sha256()
199
- with open(path, 'rb') as f:
200
- for chunk in iter(lambda: f.read(nblocks*sh.block_size), b''):
201
- sh.update(chunk)
202
- return sh
203
-
204
-
205
- def sha256_of_string(string):
206
- """ Computes the SHA256 hash of a string. """
207
- sh = sha256()
208
- sh.update(string)
209
- return sh
210
-
211
-
212
- def pyx_is_cplus(path):
213
- """
214
- Inspect a Cython source file (.pyx) and look for comment line like:
215
-
216
- # distutils: language = c++
217
-
218
- Returns True if such a file is present in the file, else False.
219
- """
220
- with open(path) as fh:
221
- for line in fh:
222
- if line.startswith('#') and '=' in line:
223
- splitted = line.split('=')
224
- if len(splitted) != 2:
225
- continue
226
- lhs, rhs = splitted
227
- if lhs.strip().split()[-1].lower() == 'language' and \
228
- rhs.strip().split()[0].lower() == 'c++':
229
- return True
230
- return False
231
-
232
- def import_module_from_file(filename, only_if_newer_than=None):
233
- """ Imports Python extension (from shared object file)
234
-
235
- Provide a list of paths in `only_if_newer_than` to check
236
- timestamps of dependencies. import_ raises an ImportError
237
- if any is newer.
238
-
239
- Word of warning: The OS may cache shared objects which makes
240
- reimporting same path of an shared object file very problematic.
241
-
242
- It will not detect the new time stamp, nor new checksum, but will
243
- instead silently use old module. Use unique names for this reason.
244
-
245
- Parameters
246
- ==========
247
-
248
- filename : str
249
- Path to shared object.
250
- only_if_newer_than : iterable of strings
251
- Paths to dependencies of the shared object.
252
-
253
- Raises
254
- ======
255
-
256
- ``ImportError`` if any of the files specified in ``only_if_newer_than`` are newer
257
- than the file given by filename.
258
- """
259
- path, name = os.path.split(filename)
260
- name, ext = os.path.splitext(name)
261
- name = name.split('.')[0]
262
- if sys.version_info[0] == 2:
263
- from imp import find_module, load_module
264
- fobj, filename, data = find_module(name, [path])
265
- if only_if_newer_than:
266
- for dep in only_if_newer_than:
267
- if os.path.getmtime(filename) < os.path.getmtime(dep):
268
- raise ImportError("{} is newer than {}".format(dep, filename))
269
- mod = load_module(name, fobj, filename, data)
270
- else:
271
- import importlib.util
272
- spec = importlib.util.spec_from_file_location(name, filename)
273
- if spec is None:
274
- raise ImportError("Failed to import: '%s'" % filename)
275
- mod = importlib.util.module_from_spec(spec)
276
- spec.loader.exec_module(mod)
277
- return mod
278
-
279
-
280
- def find_binary_of_command(candidates):
281
- """ Finds binary first matching name among candidates.
282
-
283
- Calls ``which`` from shutils for provided candidates and returns
284
- first hit.
285
-
286
- Parameters
287
- ==========
288
-
289
- candidates : iterable of str
290
- Names of candidate commands
291
-
292
- Raises
293
- ======
294
-
295
- CompilerNotFoundError if no candidates match.
296
- """
297
- from shutil import which
298
- for c in candidates:
299
- binary_path = which(c)
300
- if c and binary_path:
301
- return c, binary_path
302
-
303
- raise CompilerNotFoundError('No binary located for candidates: {}'.format(candidates))
304
-
305
-
306
- def unique_list(l):
307
- """ Uniquify a list (skip duplicate items). """
308
- result = []
309
- for x in l:
310
- if x not in result:
311
- result.append(x)
312
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/autowrap.py DELETED
@@ -1,1178 +0,0 @@
1
- """Module for compiling codegen output, and wrap the binary for use in
2
- python.
3
-
4
- .. note:: To use the autowrap module it must first be imported
5
-
6
- >>> from sympy.utilities.autowrap import autowrap
7
-
8
- This module provides a common interface for different external backends, such
9
- as f2py, fwrap, Cython, SWIG(?) etc. (Currently only f2py and Cython are
10
- implemented) The goal is to provide access to compiled binaries of acceptable
11
- performance with a one-button user interface, e.g.,
12
-
13
- >>> from sympy.abc import x,y
14
- >>> expr = (x - y)**25
15
- >>> flat = expr.expand()
16
- >>> binary_callable = autowrap(flat)
17
- >>> binary_callable(2, 3)
18
- -1.0
19
-
20
- Although a SymPy user might primarily be interested in working with
21
- mathematical expressions and not in the details of wrapping tools
22
- needed to evaluate such expressions efficiently in numerical form,
23
- the user cannot do so without some understanding of the
24
- limits in the target language. For example, the expanded expression
25
- contains large coefficients which result in loss of precision when
26
- computing the expression:
27
-
28
- >>> binary_callable(3, 2)
29
- 0.0
30
- >>> binary_callable(4, 5), binary_callable(5, 4)
31
- (-22925376.0, 25165824.0)
32
-
33
- Wrapping the unexpanded expression gives the expected behavior:
34
-
35
- >>> e = autowrap(expr)
36
- >>> e(4, 5), e(5, 4)
37
- (-1.0, 1.0)
38
-
39
- The callable returned from autowrap() is a binary Python function, not a
40
- SymPy object. If it is desired to use the compiled function in symbolic
41
- expressions, it is better to use binary_function() which returns a SymPy
42
- Function object. The binary callable is attached as the _imp_ attribute and
43
- invoked when a numerical evaluation is requested with evalf(), or with
44
- lambdify().
45
-
46
- >>> from sympy.utilities.autowrap import binary_function
47
- >>> f = binary_function('f', expr)
48
- >>> 2*f(x, y) + y
49
- y + 2*f(x, y)
50
- >>> (2*f(x, y) + y).evalf(2, subs={x: 1, y:2})
51
- 0.e-110
52
-
53
- When is this useful?
54
-
55
- 1) For computations on large arrays, Python iterations may be too slow,
56
- and depending on the mathematical expression, it may be difficult to
57
- exploit the advanced index operations provided by NumPy.
58
-
59
- 2) For *really* long expressions that will be called repeatedly, the
60
- compiled binary should be significantly faster than SymPy's .evalf()
61
-
62
- 3) If you are generating code with the codegen utility in order to use
63
- it in another project, the automatic Python wrappers let you test the
64
- binaries immediately from within SymPy.
65
-
66
- 4) To create customized ufuncs for use with numpy arrays.
67
- See *ufuncify*.
68
-
69
- When is this module NOT the best approach?
70
-
71
- 1) If you are really concerned about speed or memory optimizations,
72
- you will probably get better results by working directly with the
73
- wrapper tools and the low level code. However, the files generated
74
- by this utility may provide a useful starting point and reference
75
- code. Temporary files will be left intact if you supply the keyword
76
- tempdir="path/to/files/".
77
-
78
- 2) If the array computation can be handled easily by numpy, and you
79
- do not need the binaries for another project.
80
-
81
- """
82
-
83
- import sys
84
- import os
85
- import shutil
86
- import tempfile
87
- from pathlib import Path
88
- from subprocess import STDOUT, CalledProcessError, check_output
89
- from string import Template
90
- from warnings import warn
91
-
92
- from sympy.core.cache import cacheit
93
- from sympy.core.function import Lambda
94
- from sympy.core.relational import Eq
95
- from sympy.core.symbol import Dummy, Symbol
96
- from sympy.tensor.indexed import Idx, IndexedBase
97
- from sympy.utilities.codegen import (make_routine, get_code_generator,
98
- OutputArgument, InOutArgument,
99
- InputArgument, CodeGenArgumentListError,
100
- Result, ResultBase, C99CodeGen)
101
- from sympy.utilities.iterables import iterable
102
- from sympy.utilities.lambdify import implemented_function
103
- from sympy.utilities.decorator import doctest_depends_on
104
-
105
- _doctest_depends_on = {'exe': ('f2py', 'gfortran', 'gcc'),
106
- 'modules': ('numpy',)}
107
-
108
-
109
- class CodeWrapError(Exception):
110
- pass
111
-
112
-
113
- class CodeWrapper:
114
- """Base Class for code wrappers"""
115
- _filename = "wrapped_code"
116
- _module_basename = "wrapper_module"
117
- _module_counter = 0
118
-
119
- @property
120
- def filename(self):
121
- return "%s_%s" % (self._filename, CodeWrapper._module_counter)
122
-
123
- @property
124
- def module_name(self):
125
- return "%s_%s" % (self._module_basename, CodeWrapper._module_counter)
126
-
127
- def __init__(self, generator, filepath=None, flags=[], verbose=False):
128
- """
129
- generator -- the code generator to use
130
- """
131
- self.generator = generator
132
- self.filepath = filepath
133
- self.flags = flags
134
- self.quiet = not verbose
135
-
136
- @property
137
- def include_header(self):
138
- return bool(self.filepath)
139
-
140
- @property
141
- def include_empty(self):
142
- return bool(self.filepath)
143
-
144
- def _generate_code(self, main_routine, routines):
145
- routines.append(main_routine)
146
- self.generator.write(
147
- routines, self.filename, True, self.include_header,
148
- self.include_empty)
149
-
150
- def wrap_code(self, routine, helpers=None):
151
- helpers = helpers or []
152
- if self.filepath:
153
- workdir = os.path.abspath(self.filepath)
154
- else:
155
- workdir = tempfile.mkdtemp("_sympy_compile")
156
- if not os.access(workdir, os.F_OK):
157
- os.mkdir(workdir)
158
- oldwork = os.getcwd()
159
- os.chdir(workdir)
160
- try:
161
- sys.path.append(workdir)
162
- self._generate_code(routine, helpers)
163
- self._prepare_files(routine)
164
- self._process_files(routine)
165
- mod = __import__(self.module_name)
166
- finally:
167
- sys.path.remove(workdir)
168
- CodeWrapper._module_counter += 1
169
- os.chdir(oldwork)
170
- if not self.filepath:
171
- try:
172
- shutil.rmtree(workdir)
173
- except OSError:
174
- # Could be some issues on Windows
175
- pass
176
-
177
- return self._get_wrapped_function(mod, routine.name)
178
-
179
- def _process_files(self, routine):
180
- command = self.command
181
- command.extend(self.flags)
182
- try:
183
- retoutput = check_output(command, stderr=STDOUT)
184
- except CalledProcessError as e:
185
- raise CodeWrapError(
186
- "Error while executing command: %s. Command output is:\n%s" % (
187
- " ".join(command), e.output.decode('utf-8')))
188
- if not self.quiet:
189
- print(retoutput)
190
-
191
-
192
- class DummyWrapper(CodeWrapper):
193
- """Class used for testing independent of backends """
194
-
195
- template = """# dummy module for testing of SymPy
196
- def %(name)s():
197
- return "%(expr)s"
198
- %(name)s.args = "%(args)s"
199
- %(name)s.returns = "%(retvals)s"
200
- """
201
-
202
- def _prepare_files(self, routine):
203
- return
204
-
205
- def _generate_code(self, routine, helpers):
206
- with open('%s.py' % self.module_name, 'w') as f:
207
- printed = ", ".join(
208
- [str(res.expr) for res in routine.result_variables])
209
- # convert OutputArguments to return value like f2py
210
- args = filter(lambda x: not isinstance(
211
- x, OutputArgument), routine.arguments)
212
- retvals = []
213
- for val in routine.result_variables:
214
- if isinstance(val, Result):
215
- retvals.append('nameless')
216
- else:
217
- retvals.append(val.result_var)
218
-
219
- print(DummyWrapper.template % {
220
- 'name': routine.name,
221
- 'expr': printed,
222
- 'args': ", ".join([str(a.name) for a in args]),
223
- 'retvals': ", ".join([str(val) for val in retvals])
224
- }, end="", file=f)
225
-
226
- def _process_files(self, routine):
227
- return
228
-
229
- @classmethod
230
- def _get_wrapped_function(cls, mod, name):
231
- return getattr(mod, name)
232
-
233
-
234
- class CythonCodeWrapper(CodeWrapper):
235
- """Wrapper that uses Cython"""
236
-
237
- setup_template = """\
238
- from setuptools import setup
239
- from setuptools import Extension
240
- from Cython.Build import cythonize
241
- cy_opts = {cythonize_options}
242
- {np_import}
243
- ext_mods = [Extension(
244
- {ext_args},
245
- include_dirs={include_dirs},
246
- library_dirs={library_dirs},
247
- libraries={libraries},
248
- extra_compile_args={extra_compile_args},
249
- extra_link_args={extra_link_args}
250
- )]
251
- setup(ext_modules=cythonize(ext_mods, **cy_opts))
252
- """
253
-
254
- _cythonize_options = {'compiler_directives':{'language_level' : "3"}}
255
-
256
- pyx_imports = (
257
- "import numpy as np\n"
258
- "cimport numpy as np\n\n")
259
-
260
- pyx_header = (
261
- "cdef extern from '{header_file}.h':\n"
262
- " {prototype}\n\n")
263
-
264
- pyx_func = (
265
- "def {name}_c({arg_string}):\n"
266
- "\n"
267
- "{declarations}"
268
- "{body}")
269
-
270
- std_compile_flag = '-std=c99'
271
-
272
- def __init__(self, *args, **kwargs):
273
- """Instantiates a Cython code wrapper.
274
-
275
- The following optional parameters get passed to ``setuptools.Extension``
276
- for building the Python extension module. Read its documentation to
277
- learn more.
278
-
279
- Parameters
280
- ==========
281
- include_dirs : [list of strings]
282
- A list of directories to search for C/C++ header files (in Unix
283
- form for portability).
284
- library_dirs : [list of strings]
285
- A list of directories to search for C/C++ libraries at link time.
286
- libraries : [list of strings]
287
- A list of library names (not filenames or paths) to link against.
288
- extra_compile_args : [list of strings]
289
- Any extra platform- and compiler-specific information to use when
290
- compiling the source files in 'sources'. For platforms and
291
- compilers where "command line" makes sense, this is typically a
292
- list of command-line arguments, but for other platforms it could be
293
- anything. Note that the attribute ``std_compile_flag`` will be
294
- appended to this list.
295
- extra_link_args : [list of strings]
296
- Any extra platform- and compiler-specific information to use when
297
- linking object files together to create the extension (or to create
298
- a new static Python interpreter). Similar interpretation as for
299
- 'extra_compile_args'.
300
- cythonize_options : [dictionary]
301
- Keyword arguments passed on to cythonize.
302
-
303
- """
304
-
305
- self._include_dirs = kwargs.pop('include_dirs', [])
306
- self._library_dirs = kwargs.pop('library_dirs', [])
307
- self._libraries = kwargs.pop('libraries', [])
308
- self._extra_compile_args = kwargs.pop('extra_compile_args', [])
309
- self._extra_compile_args.append(self.std_compile_flag)
310
- self._extra_link_args = kwargs.pop('extra_link_args', [])
311
- self._cythonize_options = kwargs.pop('cythonize_options', self._cythonize_options)
312
-
313
- self._need_numpy = False
314
-
315
- super().__init__(*args, **kwargs)
316
-
317
- @property
318
- def command(self):
319
- command = [sys.executable, "setup.py", "build_ext", "--inplace"]
320
- return command
321
-
322
- def _prepare_files(self, routine, build_dir=os.curdir):
323
- # NOTE : build_dir is used for testing purposes.
324
- pyxfilename = self.module_name + '.pyx'
325
- codefilename = "%s.%s" % (self.filename, self.generator.code_extension)
326
-
327
- # pyx
328
- with open(os.path.join(build_dir, pyxfilename), 'w') as f:
329
- self.dump_pyx([routine], f, self.filename)
330
-
331
- # setup.py
332
- ext_args = [repr(self.module_name), repr([pyxfilename, codefilename])]
333
- if self._need_numpy:
334
- np_import = 'import numpy as np\n'
335
- self._include_dirs.append('np.get_include()')
336
- else:
337
- np_import = ''
338
-
339
- includes = str(self._include_dirs).replace("'np.get_include()'",
340
- 'np.get_include()')
341
- code = self.setup_template.format(
342
- ext_args=", ".join(ext_args),
343
- np_import=np_import,
344
- include_dirs=includes,
345
- library_dirs=self._library_dirs,
346
- libraries=self._libraries,
347
- extra_compile_args=self._extra_compile_args,
348
- extra_link_args=self._extra_link_args,
349
- cythonize_options=self._cythonize_options)
350
- Path(os.path.join(build_dir, 'setup.py')).write_text(code)
351
-
352
- @classmethod
353
- def _get_wrapped_function(cls, mod, name):
354
- return getattr(mod, name + '_c')
355
-
356
- def dump_pyx(self, routines, f, prefix):
357
- """Write a Cython file with Python wrappers
358
-
359
- This file contains all the definitions of the routines in c code and
360
- refers to the header file.
361
-
362
- Arguments
363
- ---------
364
- routines
365
- List of Routine instances
366
- f
367
- File-like object to write the file to
368
- prefix
369
- The filename prefix, used to refer to the proper header file.
370
- Only the basename of the prefix is used.
371
- """
372
- headers = []
373
- functions = []
374
- for routine in routines:
375
- prototype = self.generator.get_prototype(routine)
376
-
377
- # C Function Header Import
378
- headers.append(self.pyx_header.format(header_file=prefix,
379
- prototype=prototype))
380
-
381
- # Partition the C function arguments into categories
382
- py_rets, py_args, py_loc, py_inf = self._partition_args(routine.arguments)
383
-
384
- # Function prototype
385
- name = routine.name
386
- arg_string = ", ".join(self._prototype_arg(arg) for arg in py_args)
387
-
388
- # Local Declarations
389
- local_decs = []
390
- for arg, val in py_inf.items():
391
- proto = self._prototype_arg(arg)
392
- mat, ind = [self._string_var(v) for v in val]
393
- local_decs.append(" cdef {} = {}.shape[{}]".format(proto, mat, ind))
394
- local_decs.extend([" cdef {}".format(self._declare_arg(a)) for a in py_loc])
395
- declarations = "\n".join(local_decs)
396
- if declarations:
397
- declarations = declarations + "\n"
398
-
399
- # Function Body
400
- args_c = ", ".join([self._call_arg(a) for a in routine.arguments])
401
- rets = ", ".join([self._string_var(r.name) for r in py_rets])
402
- if routine.results:
403
- body = ' return %s(%s)' % (routine.name, args_c)
404
- if rets:
405
- body = body + ', ' + rets
406
- else:
407
- body = ' %s(%s)\n' % (routine.name, args_c)
408
- body = body + ' return ' + rets
409
-
410
- functions.append(self.pyx_func.format(name=name, arg_string=arg_string,
411
- declarations=declarations, body=body))
412
-
413
- # Write text to file
414
- if self._need_numpy:
415
- # Only import numpy if required
416
- f.write(self.pyx_imports)
417
- f.write('\n'.join(headers))
418
- f.write('\n'.join(functions))
419
-
420
- def _partition_args(self, args):
421
- """Group function arguments into categories."""
422
- py_args = []
423
- py_returns = []
424
- py_locals = []
425
- py_inferred = {}
426
- for arg in args:
427
- if isinstance(arg, OutputArgument):
428
- py_returns.append(arg)
429
- py_locals.append(arg)
430
- elif isinstance(arg, InOutArgument):
431
- py_returns.append(arg)
432
- py_args.append(arg)
433
- else:
434
- py_args.append(arg)
435
- # Find arguments that are array dimensions. These can be inferred
436
- # locally in the Cython code.
437
- if isinstance(arg, (InputArgument, InOutArgument)) and arg.dimensions:
438
- dims = [d[1] + 1 for d in arg.dimensions]
439
- sym_dims = [(i, d) for (i, d) in enumerate(dims) if
440
- isinstance(d, Symbol)]
441
- for (i, d) in sym_dims:
442
- py_inferred[d] = (arg.name, i)
443
- for arg in args:
444
- if arg.name in py_inferred:
445
- py_inferred[arg] = py_inferred.pop(arg.name)
446
- # Filter inferred arguments from py_args
447
- py_args = [a for a in py_args if a not in py_inferred]
448
- return py_returns, py_args, py_locals, py_inferred
449
-
450
- def _prototype_arg(self, arg):
451
- mat_dec = "np.ndarray[{mtype}, ndim={ndim}] {name}"
452
- np_types = {'double': 'np.double_t',
453
- 'int': 'np.int_t'}
454
- t = arg.get_datatype('c')
455
- if arg.dimensions:
456
- self._need_numpy = True
457
- ndim = len(arg.dimensions)
458
- mtype = np_types[t]
459
- return mat_dec.format(mtype=mtype, ndim=ndim, name=self._string_var(arg.name))
460
- else:
461
- return "%s %s" % (t, self._string_var(arg.name))
462
-
463
- def _declare_arg(self, arg):
464
- proto = self._prototype_arg(arg)
465
- if arg.dimensions:
466
- shape = '(' + ','.join(self._string_var(i[1] + 1) for i in arg.dimensions) + ')'
467
- return proto + " = np.empty({shape})".format(shape=shape)
468
- else:
469
- return proto + " = 0"
470
-
471
- def _call_arg(self, arg):
472
- if arg.dimensions:
473
- t = arg.get_datatype('c')
474
- return "<{}*> {}.data".format(t, self._string_var(arg.name))
475
- elif isinstance(arg, ResultBase):
476
- return "&{}".format(self._string_var(arg.name))
477
- else:
478
- return self._string_var(arg.name)
479
-
480
- def _string_var(self, var):
481
- printer = self.generator.printer.doprint
482
- return printer(var)
483
-
484
-
485
- class F2PyCodeWrapper(CodeWrapper):
486
- """Wrapper that uses f2py"""
487
-
488
- def __init__(self, *args, **kwargs):
489
-
490
- ext_keys = ['include_dirs', 'library_dirs', 'libraries',
491
- 'extra_compile_args', 'extra_link_args']
492
- msg = ('The compilation option kwarg {} is not supported with the f2py '
493
- 'backend.')
494
-
495
- for k in ext_keys:
496
- if k in kwargs.keys():
497
- warn(msg.format(k))
498
- kwargs.pop(k, None)
499
-
500
- super().__init__(*args, **kwargs)
501
-
502
- @property
503
- def command(self):
504
- filename = self.filename + '.' + self.generator.code_extension
505
- args = ['-c', '-m', self.module_name, filename]
506
- command = [sys.executable, "-c", "import numpy.f2py as f2py2e;f2py2e.main()"]+args
507
- return command
508
-
509
- def _prepare_files(self, routine):
510
- pass
511
-
512
- @classmethod
513
- def _get_wrapped_function(cls, mod, name):
514
- return getattr(mod, name)
515
-
516
-
517
- # Here we define a lookup of backends -> tuples of languages. For now, each
518
- # tuple is of length 1, but if a backend supports more than one language,
519
- # the most preferable language is listed first.
520
- _lang_lookup = {'CYTHON': ('C99', 'C89', 'C'),
521
- 'F2PY': ('F95',),
522
- 'NUMPY': ('C99', 'C89', 'C'),
523
- 'DUMMY': ('F95',)} # Dummy here just for testing
524
-
525
-
526
- def _infer_language(backend):
527
- """For a given backend, return the top choice of language"""
528
- langs = _lang_lookup.get(backend.upper(), False)
529
- if not langs:
530
- raise ValueError("Unrecognized backend: " + backend)
531
- return langs[0]
532
-
533
-
534
- def _validate_backend_language(backend, language):
535
- """Throws error if backend and language are incompatible"""
536
- langs = _lang_lookup.get(backend.upper(), False)
537
- if not langs:
538
- raise ValueError("Unrecognized backend: " + backend)
539
- if language.upper() not in langs:
540
- raise ValueError(("Backend {} and language {} are "
541
- "incompatible").format(backend, language))
542
-
543
-
544
- @cacheit
545
- @doctest_depends_on(exe=('f2py', 'gfortran'), modules=('numpy',))
546
- def autowrap(expr, language=None, backend='f2py', tempdir=None, args=None,
547
- flags=None, verbose=False, helpers=None, code_gen=None, **kwargs):
548
- """Generates Python callable binaries based on the math expression.
549
-
550
- Parameters
551
- ==========
552
-
553
- expr
554
- The SymPy expression that should be wrapped as a binary routine.
555
- language : string, optional
556
- If supplied, (options: 'C' or 'F95'), specifies the language of the
557
- generated code. If ``None`` [default], the language is inferred based
558
- upon the specified backend.
559
- backend : string, optional
560
- Backend used to wrap the generated code. Either 'f2py' [default],
561
- or 'cython'.
562
- tempdir : string, optional
563
- Path to directory for temporary files. If this argument is supplied,
564
- the generated code and the wrapper input files are left intact in the
565
- specified path.
566
- args : iterable, optional
567
- An ordered iterable of symbols. Specifies the argument sequence for the
568
- function.
569
- flags : iterable, optional
570
- Additional option flags that will be passed to the backend.
571
- verbose : bool, optional
572
- If True, autowrap will not mute the command line backends. This can be
573
- helpful for debugging.
574
- helpers : 3-tuple or iterable of 3-tuples, optional
575
- Used to define auxiliary functions needed for the main expression.
576
- Each tuple should be of the form (name, expr, args) where:
577
-
578
- - name : str, the function name
579
- - expr : sympy expression, the function
580
- - args : iterable, the function arguments (can be any iterable of symbols)
581
-
582
- code_gen : CodeGen instance
583
- An instance of a CodeGen subclass. Overrides ``language``.
584
- include_dirs : [string]
585
- A list of directories to search for C/C++ header files (in Unix form
586
- for portability).
587
- library_dirs : [string]
588
- A list of directories to search for C/C++ libraries at link time.
589
- libraries : [string]
590
- A list of library names (not filenames or paths) to link against.
591
- extra_compile_args : [string]
592
- Any extra platform- and compiler-specific information to use when
593
- compiling the source files in 'sources'. For platforms and compilers
594
- where "command line" makes sense, this is typically a list of
595
- command-line arguments, but for other platforms it could be anything.
596
- extra_link_args : [string]
597
- Any extra platform- and compiler-specific information to use when
598
- linking object files together to create the extension (or to create a
599
- new static Python interpreter). Similar interpretation as for
600
- 'extra_compile_args'.
601
-
602
- Examples
603
- ========
604
-
605
- Basic usage:
606
-
607
- >>> from sympy.abc import x, y, z
608
- >>> from sympy.utilities.autowrap import autowrap
609
- >>> expr = ((x - y + z)**(13)).expand()
610
- >>> binary_func = autowrap(expr)
611
- >>> binary_func(1, 4, 2)
612
- -1.0
613
-
614
- Using helper functions:
615
-
616
- >>> from sympy.abc import x, t
617
- >>> from sympy import Function
618
- >>> helper_func = Function('helper_func') # Define symbolic function
619
- >>> expr = 3*x + helper_func(t) # Main expression using helper function
620
- >>> # Define helper_func(x) = 4*x using f2py backend
621
- >>> binary_func = autowrap(expr, args=[x, t],
622
- ... helpers=('helper_func', 4*x, [x]))
623
- >>> binary_func(2, 5) # 3*2 + helper_func(5) = 6 + 20
624
- 26.0
625
- >>> # Same example using cython backend
626
- >>> binary_func = autowrap(expr, args=[x, t], backend='cython',
627
- ... helpers=[('helper_func', 4*x, [x])])
628
- >>> binary_func(2, 5) # 3*2 + helper_func(5) = 6 + 20
629
- 26.0
630
-
631
- Type handling example:
632
-
633
- >>> import numpy as np
634
- >>> expr = x + y
635
- >>> f_cython = autowrap(expr, backend='cython')
636
- >>> f_cython(1, 2) # doctest: +ELLIPSIS
637
- Traceback (most recent call last):
638
- ...
639
- TypeError: Argument '_x' has incorrect type (expected numpy.ndarray, got int)
640
- >>> f_cython(np.array([1.0]), np.array([2.0]))
641
- array([ 3.])
642
-
643
- """
644
- if language:
645
- if not isinstance(language, type):
646
- _validate_backend_language(backend, language)
647
- else:
648
- language = _infer_language(backend)
649
-
650
- # two cases 1) helpers is an iterable of 3-tuples and 2) helpers is a
651
- # 3-tuple
652
- if iterable(helpers) and len(helpers) != 0 and iterable(helpers[0]):
653
- helpers = helpers if helpers else ()
654
- else:
655
- helpers = [helpers] if helpers else ()
656
- args = list(args) if iterable(args, exclude=set) else args
657
-
658
- if code_gen is None:
659
- code_gen = get_code_generator(language, "autowrap")
660
-
661
- CodeWrapperClass = {
662
- 'F2PY': F2PyCodeWrapper,
663
- 'CYTHON': CythonCodeWrapper,
664
- 'DUMMY': DummyWrapper
665
- }[backend.upper()]
666
- code_wrapper = CodeWrapperClass(code_gen, tempdir, flags if flags else (),
667
- verbose, **kwargs)
668
-
669
- helps = []
670
- for name_h, expr_h, args_h in helpers:
671
- helps.append(code_gen.routine(name_h, expr_h, args_h))
672
-
673
- for name_h, expr_h, args_h in helpers:
674
- if expr.has(expr_h):
675
- name_h = binary_function(name_h, expr_h, backend='dummy')
676
- expr = expr.subs(expr_h, name_h(*args_h))
677
- try:
678
- routine = code_gen.routine('autofunc', expr, args)
679
- except CodeGenArgumentListError as e:
680
- # if all missing arguments are for pure output, we simply attach them
681
- # at the end and try again, because the wrappers will silently convert
682
- # them to return values anyway.
683
- new_args = []
684
- for missing in e.missing_args:
685
- if not isinstance(missing, OutputArgument):
686
- raise
687
- new_args.append(missing.name)
688
- routine = code_gen.routine('autofunc', expr, args + new_args)
689
-
690
- return code_wrapper.wrap_code(routine, helpers=helps)
691
-
692
-
693
- @doctest_depends_on(exe=('f2py', 'gfortran'), modules=('numpy',))
694
- def binary_function(symfunc, expr, **kwargs):
695
- """Returns a SymPy function with expr as binary implementation
696
-
697
- This is a convenience function that automates the steps needed to
698
- autowrap the SymPy expression and attaching it to a Function object
699
- with implemented_function().
700
-
701
- Parameters
702
- ==========
703
-
704
- symfunc : SymPy Function
705
- The function to bind the callable to.
706
- expr : SymPy Expression
707
- The expression used to generate the function.
708
- kwargs : dict
709
- Any kwargs accepted by autowrap.
710
-
711
- Examples
712
- ========
713
-
714
- >>> from sympy.abc import x, y
715
- >>> from sympy.utilities.autowrap import binary_function
716
- >>> expr = ((x - y)**(25)).expand()
717
- >>> f = binary_function('f', expr)
718
- >>> type(f)
719
- <class 'sympy.core.function.UndefinedFunction'>
720
- >>> 2*f(x, y)
721
- 2*f(x, y)
722
- >>> f(x, y).evalf(2, subs={x: 1, y: 2})
723
- -1.0
724
-
725
- """
726
- binary = autowrap(expr, **kwargs)
727
- return implemented_function(symfunc, binary)
728
-
729
- #################################################################
730
- # UFUNCIFY #
731
- #################################################################
732
-
733
- _ufunc_top = Template("""\
734
- #include "Python.h"
735
- #include "math.h"
736
- #include "numpy/ndarraytypes.h"
737
- #include "numpy/ufuncobject.h"
738
- #include "numpy/halffloat.h"
739
- #include ${include_file}
740
-
741
- static PyMethodDef ${module}Methods[] = {
742
- {NULL, NULL, 0, NULL}
743
- };""")
744
-
745
- _ufunc_outcalls = Template("*((double *)out${outnum}) = ${funcname}(${call_args});")
746
-
747
- _ufunc_body = Template("""\
748
- #ifdef NPY_1_19_API_VERSION
749
- static void ${funcname}_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
750
- #else
751
- static void ${funcname}_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
752
- #endif
753
- {
754
- npy_intp i;
755
- npy_intp n = dimensions[0];
756
- ${declare_args}
757
- ${declare_steps}
758
- for (i = 0; i < n; i++) {
759
- ${outcalls}
760
- ${step_increments}
761
- }
762
- }
763
- PyUFuncGenericFunction ${funcname}_funcs[1] = {&${funcname}_ufunc};
764
- static char ${funcname}_types[${n_types}] = ${types}
765
- static void *${funcname}_data[1] = {NULL};""")
766
-
767
- _ufunc_bottom = Template("""\
768
- #if PY_VERSION_HEX >= 0x03000000
769
- static struct PyModuleDef moduledef = {
770
- PyModuleDef_HEAD_INIT,
771
- "${module}",
772
- NULL,
773
- -1,
774
- ${module}Methods,
775
- NULL,
776
- NULL,
777
- NULL,
778
- NULL
779
- };
780
-
781
- PyMODINIT_FUNC PyInit_${module}(void)
782
- {
783
- PyObject *m, *d;
784
- ${function_creation}
785
- m = PyModule_Create(&moduledef);
786
- if (!m) {
787
- return NULL;
788
- }
789
- import_array();
790
- import_umath();
791
- d = PyModule_GetDict(m);
792
- ${ufunc_init}
793
- return m;
794
- }
795
- #else
796
- PyMODINIT_FUNC init${module}(void)
797
- {
798
- PyObject *m, *d;
799
- ${function_creation}
800
- m = Py_InitModule("${module}", ${module}Methods);
801
- if (m == NULL) {
802
- return;
803
- }
804
- import_array();
805
- import_umath();
806
- d = PyModule_GetDict(m);
807
- ${ufunc_init}
808
- }
809
- #endif\
810
- """)
811
-
812
- _ufunc_init_form = Template("""\
813
- ufunc${ind} = PyUFunc_FromFuncAndData(${funcname}_funcs, ${funcname}_data, ${funcname}_types, 1, ${n_in}, ${n_out},
814
- PyUFunc_None, "${module}", ${docstring}, 0);
815
- PyDict_SetItemString(d, "${funcname}", ufunc${ind});
816
- Py_DECREF(ufunc${ind});""")
817
-
818
- _ufunc_setup = Template("""\
819
- from setuptools.extension import Extension
820
- from setuptools import setup
821
-
822
- from numpy import get_include
823
-
824
- if __name__ == "__main__":
825
- setup(ext_modules=[
826
- Extension('${module}',
827
- sources=['${module}.c', '${filename}.c'],
828
- include_dirs=[get_include()])])
829
- """)
830
-
831
-
832
- class UfuncifyCodeWrapper(CodeWrapper):
833
- """Wrapper for Ufuncify"""
834
-
835
- def __init__(self, *args, **kwargs):
836
-
837
- ext_keys = ['include_dirs', 'library_dirs', 'libraries',
838
- 'extra_compile_args', 'extra_link_args']
839
- msg = ('The compilation option kwarg {} is not supported with the numpy'
840
- ' backend.')
841
-
842
- for k in ext_keys:
843
- if k in kwargs.keys():
844
- warn(msg.format(k))
845
- kwargs.pop(k, None)
846
-
847
- super().__init__(*args, **kwargs)
848
-
849
- @property
850
- def command(self):
851
- command = [sys.executable, "setup.py", "build_ext", "--inplace"]
852
- return command
853
-
854
- def wrap_code(self, routines, helpers=None):
855
- # This routine overrides CodeWrapper because we can't assume funcname == routines[0].name
856
- # Therefore we have to break the CodeWrapper private API.
857
- # There isn't an obvious way to extend multi-expr support to
858
- # the other autowrap backends, so we limit this change to ufuncify.
859
- helpers = helpers if helpers is not None else []
860
- # We just need a consistent name
861
- funcname = 'wrapped_' + str(id(routines) + id(helpers))
862
-
863
- workdir = self.filepath or tempfile.mkdtemp("_sympy_compile")
864
- if not os.access(workdir, os.F_OK):
865
- os.mkdir(workdir)
866
- oldwork = os.getcwd()
867
- os.chdir(workdir)
868
- try:
869
- sys.path.append(workdir)
870
- self._generate_code(routines, helpers)
871
- self._prepare_files(routines, funcname)
872
- self._process_files(routines)
873
- mod = __import__(self.module_name)
874
- finally:
875
- sys.path.remove(workdir)
876
- CodeWrapper._module_counter += 1
877
- os.chdir(oldwork)
878
- if not self.filepath:
879
- try:
880
- shutil.rmtree(workdir)
881
- except OSError:
882
- # Could be some issues on Windows
883
- pass
884
-
885
- return self._get_wrapped_function(mod, funcname)
886
-
887
- def _generate_code(self, main_routines, helper_routines):
888
- all_routines = main_routines + helper_routines
889
- self.generator.write(
890
- all_routines, self.filename, True, self.include_header,
891
- self.include_empty)
892
-
893
- def _prepare_files(self, routines, funcname):
894
-
895
- # C
896
- codefilename = self.module_name + '.c'
897
- with open(codefilename, 'w') as f:
898
- self.dump_c(routines, f, self.filename, funcname=funcname)
899
-
900
- # setup.py
901
- with open('setup.py', 'w') as f:
902
- self.dump_setup(f)
903
-
904
- @classmethod
905
- def _get_wrapped_function(cls, mod, name):
906
- return getattr(mod, name)
907
-
908
- def dump_setup(self, f):
909
- setup = _ufunc_setup.substitute(module=self.module_name,
910
- filename=self.filename)
911
- f.write(setup)
912
-
913
- def dump_c(self, routines, f, prefix, funcname=None):
914
- """Write a C file with Python wrappers
915
-
916
- This file contains all the definitions of the routines in c code.
917
-
918
- Arguments
919
- ---------
920
- routines
921
- List of Routine instances
922
- f
923
- File-like object to write the file to
924
- prefix
925
- The filename prefix, used to name the imported module.
926
- funcname
927
- Name of the main function to be returned.
928
- """
929
- if funcname is None:
930
- if len(routines) == 1:
931
- funcname = routines[0].name
932
- else:
933
- msg = 'funcname must be specified for multiple output routines'
934
- raise ValueError(msg)
935
- functions = []
936
- function_creation = []
937
- ufunc_init = []
938
- module = self.module_name
939
- include_file = "\"{}.h\"".format(prefix)
940
- top = _ufunc_top.substitute(include_file=include_file, module=module)
941
-
942
- name = funcname
943
-
944
- # Partition the C function arguments into categories
945
- # Here we assume all routines accept the same arguments
946
- r_index = 0
947
- py_in, _ = self._partition_args(routines[0].arguments)
948
- n_in = len(py_in)
949
- n_out = len(routines)
950
-
951
- # Declare Args
952
- form = "char *{0}{1} = args[{2}];"
953
- arg_decs = [form.format('in', i, i) for i in range(n_in)]
954
- arg_decs.extend([form.format('out', i, i+n_in) for i in range(n_out)])
955
- declare_args = '\n '.join(arg_decs)
956
-
957
- # Declare Steps
958
- form = "npy_intp {0}{1}_step = steps[{2}];"
959
- step_decs = [form.format('in', i, i) for i in range(n_in)]
960
- step_decs.extend([form.format('out', i, i+n_in) for i in range(n_out)])
961
- declare_steps = '\n '.join(step_decs)
962
-
963
- # Call Args
964
- form = "*(double *)in{0}"
965
- call_args = ', '.join([form.format(a) for a in range(n_in)])
966
-
967
- # Step Increments
968
- form = "{0}{1} += {0}{1}_step;"
969
- step_incs = [form.format('in', i) for i in range(n_in)]
970
- step_incs.extend([form.format('out', i, i) for i in range(n_out)])
971
- step_increments = '\n '.join(step_incs)
972
-
973
- # Types
974
- n_types = n_in + n_out
975
- types = "{" + ', '.join(["NPY_DOUBLE"]*n_types) + "};"
976
-
977
- # Docstring
978
- docstring = '"Created in SymPy with Ufuncify"'
979
-
980
- # Function Creation
981
- function_creation.append("PyObject *ufunc{};".format(r_index))
982
-
983
- # Ufunc initialization
984
- init_form = _ufunc_init_form.substitute(module=module,
985
- funcname=name,
986
- docstring=docstring,
987
- n_in=n_in, n_out=n_out,
988
- ind=r_index)
989
- ufunc_init.append(init_form)
990
-
991
- outcalls = [_ufunc_outcalls.substitute(
992
- outnum=i, call_args=call_args, funcname=routines[i].name) for i in
993
- range(n_out)]
994
-
995
- body = _ufunc_body.substitute(module=module, funcname=name,
996
- declare_args=declare_args,
997
- declare_steps=declare_steps,
998
- call_args=call_args,
999
- step_increments=step_increments,
1000
- n_types=n_types, types=types,
1001
- outcalls='\n '.join(outcalls))
1002
- functions.append(body)
1003
-
1004
- body = '\n\n'.join(functions)
1005
- ufunc_init = '\n '.join(ufunc_init)
1006
- function_creation = '\n '.join(function_creation)
1007
- bottom = _ufunc_bottom.substitute(module=module,
1008
- ufunc_init=ufunc_init,
1009
- function_creation=function_creation)
1010
- text = [top, body, bottom]
1011
- f.write('\n\n'.join(text))
1012
-
1013
- def _partition_args(self, args):
1014
- """Group function arguments into categories."""
1015
- py_in = []
1016
- py_out = []
1017
- for arg in args:
1018
- if isinstance(arg, OutputArgument):
1019
- py_out.append(arg)
1020
- elif isinstance(arg, InOutArgument):
1021
- raise ValueError("Ufuncify doesn't support InOutArguments")
1022
- else:
1023
- py_in.append(arg)
1024
- return py_in, py_out
1025
-
1026
-
1027
- @cacheit
1028
- @doctest_depends_on(exe=('f2py', 'gfortran', 'gcc'), modules=('numpy',))
1029
- def ufuncify(args, expr, language=None, backend='numpy', tempdir=None,
1030
- flags=None, verbose=False, helpers=None, **kwargs):
1031
- """Generates a binary function that supports broadcasting on numpy arrays.
1032
-
1033
- Parameters
1034
- ==========
1035
-
1036
- args : iterable
1037
- Either a Symbol or an iterable of symbols. Specifies the argument
1038
- sequence for the function.
1039
- expr
1040
- A SymPy expression that defines the element wise operation.
1041
- language : string, optional
1042
- If supplied, (options: 'C' or 'F95'), specifies the language of the
1043
- generated code. If ``None`` [default], the language is inferred based
1044
- upon the specified backend.
1045
- backend : string, optional
1046
- Backend used to wrap the generated code. Either 'numpy' [default],
1047
- 'cython', or 'f2py'.
1048
- tempdir : string, optional
1049
- Path to directory for temporary files. If this argument is supplied,
1050
- the generated code and the wrapper input files are left intact in
1051
- the specified path.
1052
- flags : iterable, optional
1053
- Additional option flags that will be passed to the backend.
1054
- verbose : bool, optional
1055
- If True, autowrap will not mute the command line backends. This can
1056
- be helpful for debugging.
1057
- helpers : 3-tuple or iterable of 3-tuples, optional
1058
- Used to define auxiliary functions needed for the main expression.
1059
- Each tuple should be of the form (name, expr, args) where:
1060
-
1061
- - name : str, the function name
1062
- - expr : sympy expression, the function
1063
- - args : iterable, the function arguments (can be any iterable of symbols)
1064
-
1065
- kwargs : dict
1066
- These kwargs will be passed to autowrap if the `f2py` or `cython`
1067
- backend is used and ignored if the `numpy` backend is used.
1068
-
1069
- Notes
1070
- =====
1071
-
1072
- The default backend ('numpy') will create actual instances of
1073
- ``numpy.ufunc``. These support ndimensional broadcasting, and implicit type
1074
- conversion. Use of the other backends will result in a "ufunc-like"
1075
- function, which requires equal length 1-dimensional arrays for all
1076
- arguments, and will not perform any type conversions.
1077
-
1078
- References
1079
- ==========
1080
-
1081
- .. [1] https://numpy.org/doc/stable/reference/ufuncs.html
1082
-
1083
- Examples
1084
- ========
1085
-
1086
- Basic usage:
1087
-
1088
- >>> from sympy.utilities.autowrap import ufuncify
1089
- >>> from sympy.abc import x, y
1090
- >>> import numpy as np
1091
- >>> f = ufuncify((x, y), y + x**2)
1092
- >>> type(f)
1093
- <class 'numpy.ufunc'>
1094
- >>> f([1, 2, 3], 2)
1095
- array([ 3., 6., 11.])
1096
- >>> f(np.arange(5), 3)
1097
- array([ 3., 4., 7., 12., 19.])
1098
-
1099
- Using helper functions:
1100
-
1101
- >>> from sympy import Function
1102
- >>> helper_func = Function('helper_func') # Define symbolic function
1103
- >>> expr = x**2 + y*helper_func(x) # Main expression using helper function
1104
- >>> # Define helper_func(x) = x**3
1105
- >>> f = ufuncify((x, y), expr, helpers=[('helper_func', x**3, [x])])
1106
- >>> f([1, 2], [3, 4])
1107
- array([ 4., 36.])
1108
-
1109
- Type handling with different backends:
1110
-
1111
- For the 'f2py' and 'cython' backends, inputs are required to be equal length
1112
- 1-dimensional arrays. The 'f2py' backend will perform type conversion, but
1113
- the Cython backend will error if the inputs are not of the expected type.
1114
-
1115
- >>> f_fortran = ufuncify((x, y), y + x**2, backend='f2py')
1116
- >>> f_fortran(1, 2)
1117
- array([ 3.])
1118
- >>> f_fortran(np.array([1, 2, 3]), np.array([1.0, 2.0, 3.0]))
1119
- array([ 2., 6., 12.])
1120
- >>> f_cython = ufuncify((x, y), y + x**2, backend='Cython')
1121
- >>> f_cython(1, 2) # doctest: +ELLIPSIS
1122
- Traceback (most recent call last):
1123
- ...
1124
- TypeError: Argument '_x' has incorrect type (expected numpy.ndarray, got int)
1125
- >>> f_cython(np.array([1.0]), np.array([2.0]))
1126
- array([ 3.])
1127
-
1128
- """
1129
-
1130
- if isinstance(args, Symbol):
1131
- args = (args,)
1132
- else:
1133
- args = tuple(args)
1134
-
1135
- if language:
1136
- _validate_backend_language(backend, language)
1137
- else:
1138
- language = _infer_language(backend)
1139
-
1140
- helpers = helpers if helpers else ()
1141
- flags = flags if flags else ()
1142
-
1143
- if backend.upper() == 'NUMPY':
1144
- # maxargs is set by numpy compile-time constant NPY_MAXARGS
1145
- # If a future version of numpy modifies or removes this restriction
1146
- # this variable should be changed or removed
1147
- maxargs = 32
1148
- helps = []
1149
- for name, expr, args in helpers:
1150
- helps.append(make_routine(name, expr, args))
1151
- code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"), tempdir,
1152
- flags, verbose)
1153
- if not isinstance(expr, (list, tuple)):
1154
- expr = [expr]
1155
- if len(expr) == 0:
1156
- raise ValueError('Expression iterable has zero length')
1157
- if len(expr) + len(args) > maxargs:
1158
- msg = ('Cannot create ufunc with more than {0} total arguments: '
1159
- 'got {1} in, {2} out')
1160
- raise ValueError(msg.format(maxargs, len(args), len(expr)))
1161
- routines = [make_routine('autofunc{}'.format(idx), exprx, args) for
1162
- idx, exprx in enumerate(expr)]
1163
- return code_wrapper.wrap_code(routines, helpers=helps)
1164
- else:
1165
- # Dummies are used for all added expressions to prevent name clashes
1166
- # within the original expression.
1167
- y = IndexedBase(Dummy('y'))
1168
- m = Dummy('m', integer=True)
1169
- i = Idx(Dummy('i', integer=True), m)
1170
- f_dummy = Dummy('f')
1171
- f = implemented_function('%s_%d' % (f_dummy.name, f_dummy.dummy_index), Lambda(args, expr))
1172
- # For each of the args create an indexed version.
1173
- indexed_args = [IndexedBase(Dummy(str(a))) for a in args]
1174
- # Order the arguments (out, args, dim)
1175
- args = [y] + indexed_args + [m]
1176
- args_with_indices = [a[i] for a in indexed_args]
1177
- return autowrap(Eq(y[i], f(*args_with_indices)), language, backend,
1178
- tempdir, args, flags, verbose, helpers, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/codegen.py DELETED
@@ -1,2237 +0,0 @@
1
- """
2
- module for generating C, C++, Fortran77, Fortran90, Julia, Rust
3
- and Octave/Matlab routines that evaluate SymPy expressions.
4
- This module is work in progress.
5
- Only the milestones with a '+' character in the list below have been completed.
6
-
7
- --- How is sympy.utilities.codegen different from sympy.printing.ccode? ---
8
-
9
- We considered the idea to extend the printing routines for SymPy functions in
10
- such a way that it prints complete compilable code, but this leads to a few
11
- unsurmountable issues that can only be tackled with dedicated code generator:
12
-
13
- - For C, one needs both a code and a header file, while the printing routines
14
- generate just one string. This code generator can be extended to support
15
- .pyf files for f2py.
16
-
17
- - SymPy functions are not concerned with programming-technical issues, such
18
- as input, output and input-output arguments. Other examples are contiguous
19
- or non-contiguous arrays, including headers of other libraries such as gsl
20
- or others.
21
-
22
- - It is highly interesting to evaluate several SymPy functions in one C
23
- routine, eventually sharing common intermediate results with the help
24
- of the cse routine. This is more than just printing.
25
-
26
- - From the programming perspective, expressions with constants should be
27
- evaluated in the code generator as much as possible. This is different
28
- for printing.
29
-
30
- --- Basic assumptions ---
31
-
32
- * A generic Routine data structure describes the routine that must be
33
- translated into C/Fortran/... code. This data structure covers all
34
- features present in one or more of the supported languages.
35
-
36
- * Descendants from the CodeGen class transform multiple Routine instances
37
- into compilable code. Each derived class translates into a specific
38
- language.
39
-
40
- * In many cases, one wants a simple workflow. The friendly functions in the
41
- last part are a simple api on top of the Routine/CodeGen stuff. They are
42
- easier to use, but are less powerful.
43
-
44
- --- Milestones ---
45
-
46
- + First working version with scalar input arguments, generating C code,
47
- tests
48
- + Friendly functions that are easier to use than the rigorous
49
- Routine/CodeGen workflow.
50
- + Integer and Real numbers as input and output
51
- + Output arguments
52
- + InputOutput arguments
53
- + Sort input/output arguments properly
54
- + Contiguous array arguments (numpy matrices)
55
- + Also generate .pyf code for f2py (in autowrap module)
56
- + Isolate constants and evaluate them beforehand in double precision
57
- + Fortran 90
58
- + Octave/Matlab
59
-
60
- - Common Subexpression Elimination
61
- - User defined comments in the generated code
62
- - Optional extra include lines for libraries/objects that can eval special
63
- functions
64
- - Test other C compilers and libraries: gcc, tcc, libtcc, gcc+gsl, ...
65
- - Contiguous array arguments (SymPy matrices)
66
- - Non-contiguous array arguments (SymPy matrices)
67
- - ccode must raise an error when it encounters something that cannot be
68
- translated into c. ccode(integrate(sin(x)/x, x)) does not make sense.
69
- - Complex numbers as input and output
70
- - A default complex datatype
71
- - Include extra information in the header: date, user, hostname, sha1
72
- hash, ...
73
- - Fortran 77
74
- - C++
75
- - Python
76
- - Julia
77
- - Rust
78
- - ...
79
-
80
- """
81
-
82
- import os
83
- import textwrap
84
- from io import StringIO
85
-
86
- from sympy import __version__ as sympy_version
87
- from sympy.core import Symbol, S, Tuple, Equality, Function, Basic
88
- from sympy.printing.c import c_code_printers
89
- from sympy.printing.codeprinter import AssignmentError
90
- from sympy.printing.fortran import FCodePrinter
91
- from sympy.printing.julia import JuliaCodePrinter
92
- from sympy.printing.octave import OctaveCodePrinter
93
- from sympy.printing.rust import RustCodePrinter
94
- from sympy.tensor import Idx, Indexed, IndexedBase
95
- from sympy.matrices import (MatrixSymbol, ImmutableMatrix, MatrixBase,
96
- MatrixExpr, MatrixSlice)
97
- from sympy.utilities.iterables import is_sequence
98
-
99
-
100
- __all__ = [
101
- # description of routines
102
- "Routine", "DataType", "default_datatypes", "get_default_datatype",
103
- "Argument", "InputArgument", "OutputArgument", "Result",
104
- # routines -> code
105
- "CodeGen", "CCodeGen", "FCodeGen", "JuliaCodeGen", "OctaveCodeGen",
106
- "RustCodeGen",
107
- # friendly functions
108
- "codegen", "make_routine",
109
- ]
110
-
111
-
112
- #
113
- # Description of routines
114
- #
115
-
116
-
117
- class Routine:
118
- """Generic description of evaluation routine for set of expressions.
119
-
120
- A CodeGen class can translate instances of this class into code in a
121
- particular language. The routine specification covers all the features
122
- present in these languages. The CodeGen part must raise an exception
123
- when certain features are not present in the target language. For
124
- example, multiple return values are possible in Python, but not in C or
125
- Fortran. Another example: Fortran and Python support complex numbers,
126
- while C does not.
127
-
128
- """
129
-
130
- def __init__(self, name, arguments, results, local_vars, global_vars):
131
- """Initialize a Routine instance.
132
-
133
- Parameters
134
- ==========
135
-
136
- name : string
137
- Name of the routine.
138
-
139
- arguments : list of Arguments
140
- These are things that appear in arguments of a routine, often
141
- appearing on the right-hand side of a function call. These are
142
- commonly InputArguments but in some languages, they can also be
143
- OutputArguments or InOutArguments (e.g., pass-by-reference in C
144
- code).
145
-
146
- results : list of Results
147
- These are the return values of the routine, often appearing on
148
- the left-hand side of a function call. The difference between
149
- Results and OutputArguments and when you should use each is
150
- language-specific.
151
-
152
- local_vars : list of Results
153
- These are variables that will be defined at the beginning of the
154
- function.
155
-
156
- global_vars : list of Symbols
157
- Variables which will not be passed into the function.
158
-
159
- """
160
-
161
- # extract all input symbols and all symbols appearing in an expression
162
- input_symbols = set()
163
- symbols = set()
164
- for arg in arguments:
165
- if isinstance(arg, OutputArgument):
166
- symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed))
167
- elif isinstance(arg, InputArgument):
168
- input_symbols.add(arg.name)
169
- elif isinstance(arg, InOutArgument):
170
- input_symbols.add(arg.name)
171
- symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed))
172
- else:
173
- raise ValueError("Unknown Routine argument: %s" % arg)
174
-
175
- for r in results:
176
- if not isinstance(r, Result):
177
- raise ValueError("Unknown Routine result: %s" % r)
178
- symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed))
179
-
180
- local_symbols = set()
181
- for r in local_vars:
182
- if isinstance(r, Result):
183
- symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed))
184
- local_symbols.add(r.name)
185
- else:
186
- local_symbols.add(r)
187
-
188
- symbols = {s.label if isinstance(s, Idx) else s for s in symbols}
189
-
190
- # Check that all symbols in the expressions are covered by
191
- # InputArguments/InOutArguments---subset because user could
192
- # specify additional (unused) InputArguments or local_vars.
193
- notcovered = symbols.difference(
194
- input_symbols.union(local_symbols).union(global_vars))
195
- if notcovered != set():
196
- raise ValueError("Symbols needed for output are not in input " +
197
- ", ".join([str(x) for x in notcovered]))
198
-
199
- self.name = name
200
- self.arguments = arguments
201
- self.results = results
202
- self.local_vars = local_vars
203
- self.global_vars = global_vars
204
-
205
- def __str__(self):
206
- return self.__class__.__name__ + "({name!r}, {arguments}, {results}, {local_vars}, {global_vars})".format(**self.__dict__)
207
-
208
- __repr__ = __str__
209
-
210
- @property
211
- def variables(self):
212
- """Returns a set of all variables possibly used in the routine.
213
-
214
- For routines with unnamed return values, the dummies that may or
215
- may not be used will be included in the set.
216
-
217
- """
218
- v = set(self.local_vars)
219
- v.update(arg.name for arg in self.arguments)
220
- v.update(res.result_var for res in self.results)
221
- return v
222
-
223
- @property
224
- def result_variables(self):
225
- """Returns a list of OutputArgument, InOutArgument and Result.
226
-
227
- If return values are present, they are at the end of the list.
228
- """
229
- args = [arg for arg in self.arguments if isinstance(
230
- arg, (OutputArgument, InOutArgument))]
231
- args.extend(self.results)
232
- return args
233
-
234
-
235
- class DataType:
236
- """Holds strings for a certain datatype in different languages."""
237
- def __init__(self, cname, fname, pyname, jlname, octname, rsname):
238
- self.cname = cname
239
- self.fname = fname
240
- self.pyname = pyname
241
- self.jlname = jlname
242
- self.octname = octname
243
- self.rsname = rsname
244
-
245
-
246
- default_datatypes = {
247
- "int": DataType("int", "INTEGER*4", "int", "", "", "i32"),
248
- "float": DataType("double", "REAL*8", "float", "", "", "f64"),
249
- "complex": DataType("double", "COMPLEX*16", "complex", "", "", "float") #FIXME:
250
- # complex is only supported in fortran, python, julia, and octave.
251
- # So to not break c or rust code generation, we stick with double or
252
- # float, respectively (but actually should raise an exception for
253
- # explicitly complex variables (x.is_complex==True))
254
- }
255
-
256
-
257
- COMPLEX_ALLOWED = False
258
- def get_default_datatype(expr, complex_allowed=None):
259
- """Derives an appropriate datatype based on the expression."""
260
- if complex_allowed is None:
261
- complex_allowed = COMPLEX_ALLOWED
262
- if complex_allowed:
263
- final_dtype = "complex"
264
- else:
265
- final_dtype = "float"
266
- if expr.is_integer:
267
- return default_datatypes["int"]
268
- elif expr.is_real:
269
- return default_datatypes["float"]
270
- elif isinstance(expr, MatrixBase):
271
- #check all entries
272
- dt = "int"
273
- for element in expr:
274
- if dt == "int" and not element.is_integer:
275
- dt = "float"
276
- if dt == "float" and not element.is_real:
277
- return default_datatypes[final_dtype]
278
- return default_datatypes[dt]
279
- else:
280
- return default_datatypes[final_dtype]
281
-
282
-
283
- class Variable:
284
- """Represents a typed variable."""
285
-
286
- def __init__(self, name, datatype=None, dimensions=None, precision=None):
287
- """Return a new variable.
288
-
289
- Parameters
290
- ==========
291
-
292
- name : Symbol or MatrixSymbol
293
-
294
- datatype : optional
295
- When not given, the data type will be guessed based on the
296
- assumptions on the symbol argument.
297
-
298
- dimensions : sequence containing tuples, optional
299
- If present, the argument is interpreted as an array, where this
300
- sequence of tuples specifies (lower, upper) bounds for each
301
- index of the array.
302
-
303
- precision : int, optional
304
- Controls the precision of floating point constants.
305
-
306
- """
307
- if not isinstance(name, (Symbol, MatrixSymbol)):
308
- raise TypeError("The first argument must be a SymPy symbol.")
309
- if datatype is None:
310
- datatype = get_default_datatype(name)
311
- elif not isinstance(datatype, DataType):
312
- raise TypeError("The (optional) `datatype' argument must be an "
313
- "instance of the DataType class.")
314
- if dimensions and not isinstance(dimensions, (tuple, list)):
315
- raise TypeError(
316
- "The dimensions argument must be a sequence of tuples")
317
-
318
- self._name = name
319
- self._datatype = {
320
- 'C': datatype.cname,
321
- 'FORTRAN': datatype.fname,
322
- 'JULIA': datatype.jlname,
323
- 'OCTAVE': datatype.octname,
324
- 'PYTHON': datatype.pyname,
325
- 'RUST': datatype.rsname,
326
- }
327
- self.dimensions = dimensions
328
- self.precision = precision
329
-
330
- def __str__(self):
331
- return "%s(%r)" % (self.__class__.__name__, self.name)
332
-
333
- __repr__ = __str__
334
-
335
- @property
336
- def name(self):
337
- return self._name
338
-
339
- def get_datatype(self, language):
340
- """Returns the datatype string for the requested language.
341
-
342
- Examples
343
- ========
344
-
345
- >>> from sympy import Symbol
346
- >>> from sympy.utilities.codegen import Variable
347
- >>> x = Variable(Symbol('x'))
348
- >>> x.get_datatype('c')
349
- 'double'
350
- >>> x.get_datatype('fortran')
351
- 'REAL*8'
352
-
353
- """
354
- try:
355
- return self._datatype[language.upper()]
356
- except KeyError:
357
- raise CodeGenError("Has datatypes for languages: %s" %
358
- ", ".join(self._datatype))
359
-
360
-
361
- class Argument(Variable):
362
- """An abstract Argument data structure: a name and a data type.
363
-
364
- This structure is refined in the descendants below.
365
-
366
- """
367
- pass
368
-
369
-
370
- class InputArgument(Argument):
371
- pass
372
-
373
-
374
- class ResultBase:
375
- """Base class for all "outgoing" information from a routine.
376
-
377
- Objects of this class stores a SymPy expression, and a SymPy object
378
- representing a result variable that will be used in the generated code
379
- only if necessary.
380
-
381
- """
382
- def __init__(self, expr, result_var):
383
- self.expr = expr
384
- self.result_var = result_var
385
-
386
- def __str__(self):
387
- return "%s(%r, %r)" % (self.__class__.__name__, self.expr,
388
- self.result_var)
389
-
390
- __repr__ = __str__
391
-
392
-
393
- class OutputArgument(Argument, ResultBase):
394
- """OutputArgument are always initialized in the routine."""
395
-
396
- def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None):
397
- """Return a new variable.
398
-
399
- Parameters
400
- ==========
401
-
402
- name : Symbol, MatrixSymbol
403
- The name of this variable. When used for code generation, this
404
- might appear, for example, in the prototype of function in the
405
- argument list.
406
-
407
- result_var : Symbol, Indexed
408
- Something that can be used to assign a value to this variable.
409
- Typically the same as `name` but for Indexed this should be e.g.,
410
- "y[i]" whereas `name` should be the Symbol "y".
411
-
412
- expr : object
413
- The expression that should be output, typically a SymPy
414
- expression.
415
-
416
- datatype : optional
417
- When not given, the data type will be guessed based on the
418
- assumptions on the symbol argument.
419
-
420
- dimensions : sequence containing tuples, optional
421
- If present, the argument is interpreted as an array, where this
422
- sequence of tuples specifies (lower, upper) bounds for each
423
- index of the array.
424
-
425
- precision : int, optional
426
- Controls the precision of floating point constants.
427
-
428
- """
429
-
430
- Argument.__init__(self, name, datatype, dimensions, precision)
431
- ResultBase.__init__(self, expr, result_var)
432
-
433
- def __str__(self):
434
- return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.result_var, self.expr)
435
-
436
- __repr__ = __str__
437
-
438
-
439
- class InOutArgument(Argument, ResultBase):
440
- """InOutArgument are never initialized in the routine."""
441
-
442
- def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None):
443
- if not datatype:
444
- datatype = get_default_datatype(expr)
445
- Argument.__init__(self, name, datatype, dimensions, precision)
446
- ResultBase.__init__(self, expr, result_var)
447
- __init__.__doc__ = OutputArgument.__init__.__doc__
448
-
449
-
450
- def __str__(self):
451
- return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.expr,
452
- self.result_var)
453
-
454
- __repr__ = __str__
455
-
456
-
457
- class Result(Variable, ResultBase):
458
- """An expression for a return value.
459
-
460
- The name result is used to avoid conflicts with the reserved word
461
- "return" in the Python language. It is also shorter than ReturnValue.
462
-
463
- These may or may not need a name in the destination (e.g., "return(x*y)"
464
- might return a value without ever naming it).
465
-
466
- """
467
-
468
- def __init__(self, expr, name=None, result_var=None, datatype=None,
469
- dimensions=None, precision=None):
470
- """Initialize a return value.
471
-
472
- Parameters
473
- ==========
474
-
475
- expr : SymPy expression
476
-
477
- name : Symbol, MatrixSymbol, optional
478
- The name of this return variable. When used for code generation,
479
- this might appear, for example, in the prototype of function in a
480
- list of return values. A dummy name is generated if omitted.
481
-
482
- result_var : Symbol, Indexed, optional
483
- Something that can be used to assign a value to this variable.
484
- Typically the same as `name` but for Indexed this should be e.g.,
485
- "y[i]" whereas `name` should be the Symbol "y". Defaults to
486
- `name` if omitted.
487
-
488
- datatype : optional
489
- When not given, the data type will be guessed based on the
490
- assumptions on the expr argument.
491
-
492
- dimensions : sequence containing tuples, optional
493
- If present, this variable is interpreted as an array,
494
- where this sequence of tuples specifies (lower, upper)
495
- bounds for each index of the array.
496
-
497
- precision : int, optional
498
- Controls the precision of floating point constants.
499
-
500
- """
501
- # Basic because it is the base class for all types of expressions
502
- if not isinstance(expr, (Basic, MatrixBase)):
503
- raise TypeError("The first argument must be a SymPy expression.")
504
-
505
- if name is None:
506
- name = 'result_%d' % abs(hash(expr))
507
-
508
- if datatype is None:
509
- #try to infer data type from the expression
510
- datatype = get_default_datatype(expr)
511
-
512
- if isinstance(name, str):
513
- if isinstance(expr, (MatrixBase, MatrixExpr)):
514
- name = MatrixSymbol(name, *expr.shape)
515
- else:
516
- name = Symbol(name)
517
-
518
- if result_var is None:
519
- result_var = name
520
-
521
- Variable.__init__(self, name, datatype=datatype,
522
- dimensions=dimensions, precision=precision)
523
- ResultBase.__init__(self, expr, result_var)
524
-
525
- def __str__(self):
526
- return "%s(%r, %r, %r)" % (self.__class__.__name__, self.expr, self.name,
527
- self.result_var)
528
-
529
- __repr__ = __str__
530
-
531
-
532
- #
533
- # Transformation of routine objects into code
534
- #
535
-
536
- class CodeGen:
537
- """Abstract class for the code generators."""
538
-
539
- printer = None # will be set to an instance of a CodePrinter subclass
540
-
541
- def _indent_code(self, codelines):
542
- return self.printer.indent_code(codelines)
543
-
544
- def _printer_method_with_settings(self, method, settings=None, *args, **kwargs):
545
- settings = settings or {}
546
- ori = {k: self.printer._settings[k] for k in settings}
547
- for k, v in settings.items():
548
- self.printer._settings[k] = v
549
- result = getattr(self.printer, method)(*args, **kwargs)
550
- for k, v in ori.items():
551
- self.printer._settings[k] = v
552
- return result
553
-
554
- def _get_symbol(self, s):
555
- """Returns the symbol as fcode prints it."""
556
- if self.printer._settings['human']:
557
- expr_str = self.printer.doprint(s)
558
- else:
559
- constants, not_supported, expr_str = self.printer.doprint(s)
560
- if constants or not_supported:
561
- raise ValueError("Failed to print %s" % str(s))
562
- return expr_str.strip()
563
-
564
- def __init__(self, project="project", cse=False):
565
- """Initialize a code generator.
566
-
567
- Derived classes will offer more options that affect the generated
568
- code.
569
-
570
- """
571
- self.project = project
572
- self.cse = cse
573
-
574
- def routine(self, name, expr, argument_sequence=None, global_vars=None):
575
- """Creates an Routine object that is appropriate for this language.
576
-
577
- This implementation is appropriate for at least C/Fortran. Subclasses
578
- can override this if necessary.
579
-
580
- Here, we assume at most one return value (the l-value) which must be
581
- scalar. Additional outputs are OutputArguments (e.g., pointers on
582
- right-hand-side or pass-by-reference). Matrices are always returned
583
- via OutputArguments. If ``argument_sequence`` is None, arguments will
584
- be ordered alphabetically, but with all InputArguments first, and then
585
- OutputArgument and InOutArguments.
586
-
587
- """
588
-
589
- if self.cse:
590
- from sympy.simplify.cse_main import cse
591
-
592
- if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
593
- if not expr:
594
- raise ValueError("No expression given")
595
- for e in expr:
596
- if not e.is_Equality:
597
- raise CodeGenError("Lists of expressions must all be Equalities. {} is not.".format(e))
598
-
599
- # create a list of right hand sides and simplify them
600
- rhs = [e.rhs for e in expr]
601
- common, simplified = cse(rhs)
602
-
603
- # pack the simplified expressions back up with their left hand sides
604
- expr = [Equality(e.lhs, rhs) for e, rhs in zip(expr, simplified)]
605
- else:
606
- if isinstance(expr, Equality):
607
- common, simplified = cse(expr.rhs) #, ignore=in_out_args)
608
- expr = Equality(expr.lhs, simplified[0])
609
- else:
610
- common, simplified = cse(expr)
611
- expr = simplified
612
-
613
- local_vars = [Result(b,a) for a,b in common]
614
- local_symbols = {a for a,_ in common}
615
- local_expressions = Tuple(*[b for _,b in common])
616
- else:
617
- local_expressions = Tuple()
618
-
619
- if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
620
- if not expr:
621
- raise ValueError("No expression given")
622
- expressions = Tuple(*expr)
623
- else:
624
- expressions = Tuple(expr)
625
-
626
- if self.cse:
627
- if {i.label for i in expressions.atoms(Idx)} != set():
628
- raise CodeGenError("CSE and Indexed expressions do not play well together yet")
629
- else:
630
- # local variables for indexed expressions
631
- local_vars = {i.label for i in expressions.atoms(Idx)}
632
- local_symbols = local_vars
633
-
634
- # global variables
635
- global_vars = set() if global_vars is None else set(global_vars)
636
-
637
- # symbols that should be arguments
638
- symbols = (expressions.free_symbols | local_expressions.free_symbols) - local_symbols - global_vars
639
- new_symbols = set()
640
- new_symbols.update(symbols)
641
-
642
- for symbol in symbols:
643
- if isinstance(symbol, Idx):
644
- new_symbols.remove(symbol)
645
- new_symbols.update(symbol.args[1].free_symbols)
646
- if isinstance(symbol, Indexed):
647
- new_symbols.remove(symbol)
648
- symbols = new_symbols
649
-
650
- # Decide whether to use output argument or return value
651
- return_val = []
652
- output_args = []
653
- for expr in expressions:
654
- if isinstance(expr, Equality):
655
- out_arg = expr.lhs
656
- expr = expr.rhs
657
- if isinstance(out_arg, Indexed):
658
- dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape])
659
- symbol = out_arg.base.label
660
- elif isinstance(out_arg, Symbol):
661
- dims = []
662
- symbol = out_arg
663
- elif isinstance(out_arg, MatrixSymbol):
664
- dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape])
665
- symbol = out_arg
666
- else:
667
- raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
668
- "can define output arguments.")
669
-
670
- if expr.has(symbol):
671
- output_args.append(
672
- InOutArgument(symbol, out_arg, expr, dimensions=dims))
673
- else:
674
- output_args.append(
675
- OutputArgument(symbol, out_arg, expr, dimensions=dims))
676
-
677
- # remove duplicate arguments when they are not local variables
678
- if symbol not in local_vars:
679
- # avoid duplicate arguments
680
- symbols.remove(symbol)
681
- elif isinstance(expr, (ImmutableMatrix, MatrixSlice)):
682
- # Create a "dummy" MatrixSymbol to use as the Output arg
683
- out_arg = MatrixSymbol('out_%s' % abs(hash(expr)), *expr.shape)
684
- dims = tuple([(S.Zero, dim - 1) for dim in out_arg.shape])
685
- output_args.append(
686
- OutputArgument(out_arg, out_arg, expr, dimensions=dims))
687
- else:
688
- return_val.append(Result(expr))
689
-
690
- arg_list = []
691
-
692
- # setup input argument list
693
-
694
- # helper to get dimensions for data for array-like args
695
- def dimensions(s):
696
- return [(S.Zero, dim - 1) for dim in s.shape]
697
-
698
- array_symbols = {}
699
- for array in expressions.atoms(Indexed) | local_expressions.atoms(Indexed):
700
- array_symbols[array.base.label] = array
701
- for array in expressions.atoms(MatrixSymbol) | local_expressions.atoms(MatrixSymbol):
702
- array_symbols[array] = array
703
-
704
- for symbol in sorted(symbols, key=str):
705
- if symbol in array_symbols:
706
- array = array_symbols[symbol]
707
- metadata = {'dimensions': dimensions(array)}
708
- else:
709
- metadata = {}
710
-
711
- arg_list.append(InputArgument(symbol, **metadata))
712
-
713
- output_args.sort(key=lambda x: str(x.name))
714
- arg_list.extend(output_args)
715
-
716
- if argument_sequence is not None:
717
- # if the user has supplied IndexedBase instances, we'll accept that
718
- new_sequence = []
719
- for arg in argument_sequence:
720
- if isinstance(arg, IndexedBase):
721
- new_sequence.append(arg.label)
722
- else:
723
- new_sequence.append(arg)
724
- argument_sequence = new_sequence
725
-
726
- missing = [x for x in arg_list if x.name not in argument_sequence]
727
- if missing:
728
- msg = "Argument list didn't specify: {0} "
729
- msg = msg.format(", ".join([str(m.name) for m in missing]))
730
- raise CodeGenArgumentListError(msg, missing)
731
-
732
- # create redundant arguments to produce the requested sequence
733
- name_arg_dict = {x.name: x for x in arg_list}
734
- new_args = []
735
- for symbol in argument_sequence:
736
- try:
737
- new_args.append(name_arg_dict[symbol])
738
- except KeyError:
739
- if isinstance(symbol, (IndexedBase, MatrixSymbol)):
740
- metadata = {'dimensions': dimensions(symbol)}
741
- else:
742
- metadata = {}
743
- new_args.append(InputArgument(symbol, **metadata))
744
- arg_list = new_args
745
-
746
- return Routine(name, arg_list, return_val, local_vars, global_vars)
747
-
748
- def write(self, routines, prefix, to_files=False, header=True, empty=True):
749
- """Writes all the source code files for the given routines.
750
-
751
- The generated source is returned as a list of (filename, contents)
752
- tuples, or is written to files (see below). Each filename consists
753
- of the given prefix, appended with an appropriate extension.
754
-
755
- Parameters
756
- ==========
757
-
758
- routines : list
759
- A list of Routine instances to be written
760
-
761
- prefix : string
762
- The prefix for the output files
763
-
764
- to_files : bool, optional
765
- When True, the output is written to files. Otherwise, a list
766
- of (filename, contents) tuples is returned. [default: False]
767
-
768
- header : bool, optional
769
- When True, a header comment is included on top of each source
770
- file. [default: True]
771
-
772
- empty : bool, optional
773
- When True, empty lines are included to structure the source
774
- files. [default: True]
775
-
776
- """
777
- if to_files:
778
- for dump_fn in self.dump_fns:
779
- filename = "%s.%s" % (prefix, dump_fn.extension)
780
- with open(filename, "w") as f:
781
- dump_fn(self, routines, f, prefix, header, empty)
782
- else:
783
- result = []
784
- for dump_fn in self.dump_fns:
785
- filename = "%s.%s" % (prefix, dump_fn.extension)
786
- contents = StringIO()
787
- dump_fn(self, routines, contents, prefix, header, empty)
788
- result.append((filename, contents.getvalue()))
789
- return result
790
-
791
- def dump_code(self, routines, f, prefix, header=True, empty=True):
792
- """Write the code by calling language specific methods.
793
-
794
- The generated file contains all the definitions of the routines in
795
- low-level code and refers to the header file if appropriate.
796
-
797
- Parameters
798
- ==========
799
-
800
- routines : list
801
- A list of Routine instances.
802
-
803
- f : file-like
804
- Where to write the file.
805
-
806
- prefix : string
807
- The filename prefix, used to refer to the proper header file.
808
- Only the basename of the prefix is used.
809
-
810
- header : bool, optional
811
- When True, a header comment is included on top of each source
812
- file. [default : True]
813
-
814
- empty : bool, optional
815
- When True, empty lines are included to structure the source
816
- files. [default : True]
817
-
818
- """
819
-
820
- code_lines = self._preprocessor_statements(prefix)
821
-
822
- for routine in routines:
823
- if empty:
824
- code_lines.append("\n")
825
- code_lines.extend(self._get_routine_opening(routine))
826
- code_lines.extend(self._declare_arguments(routine))
827
- code_lines.extend(self._declare_globals(routine))
828
- code_lines.extend(self._declare_locals(routine))
829
- if empty:
830
- code_lines.append("\n")
831
- code_lines.extend(self._call_printer(routine))
832
- if empty:
833
- code_lines.append("\n")
834
- code_lines.extend(self._get_routine_ending(routine))
835
-
836
- code_lines = self._indent_code(''.join(code_lines))
837
-
838
- if header:
839
- code_lines = ''.join(self._get_header() + [code_lines])
840
-
841
- if code_lines:
842
- f.write(code_lines)
843
-
844
-
845
- class CodeGenError(Exception):
846
- pass
847
-
848
-
849
- class CodeGenArgumentListError(Exception):
850
- @property
851
- def missing_args(self):
852
- return self.args[1]
853
-
854
-
855
- header_comment = """Code generated with SymPy %(version)s
856
-
857
- See http://www.sympy.org/ for more information.
858
-
859
- This file is part of '%(project)s'
860
- """
861
-
862
-
863
- class CCodeGen(CodeGen):
864
- """Generator for C code.
865
-
866
- The .write() method inherited from CodeGen will output a code file and
867
- an interface file, <prefix>.c and <prefix>.h respectively.
868
-
869
- """
870
-
871
- code_extension = "c"
872
- interface_extension = "h"
873
- standard = 'c99'
874
-
875
- def __init__(self, project="project", printer=None,
876
- preprocessor_statements=None, cse=False):
877
- super().__init__(project=project, cse=cse)
878
- self.printer = printer or c_code_printers[self.standard.lower()]()
879
-
880
- self.preprocessor_statements = preprocessor_statements
881
- if preprocessor_statements is None:
882
- self.preprocessor_statements = ['#include <math.h>']
883
-
884
- def _get_header(self):
885
- """Writes a common header for the generated files."""
886
- code_lines = []
887
- code_lines.append("/" + "*"*78 + '\n')
888
- tmp = header_comment % {"version": sympy_version,
889
- "project": self.project}
890
- for line in tmp.splitlines():
891
- code_lines.append(" *%s*\n" % line.center(76))
892
- code_lines.append(" " + "*"*78 + "/\n")
893
- return code_lines
894
-
895
- def get_prototype(self, routine):
896
- """Returns a string for the function prototype of the routine.
897
-
898
- If the routine has multiple result objects, an CodeGenError is
899
- raised.
900
-
901
- See: https://en.wikipedia.org/wiki/Function_prototype
902
-
903
- """
904
- if len(routine.results) > 1:
905
- raise CodeGenError("C only supports a single or no return value.")
906
- elif len(routine.results) == 1:
907
- ctype = routine.results[0].get_datatype('C')
908
- else:
909
- ctype = "void"
910
-
911
- type_args = []
912
- for arg in routine.arguments:
913
- name = self.printer.doprint(arg.name)
914
- if arg.dimensions or isinstance(arg, ResultBase):
915
- type_args.append((arg.get_datatype('C'), "*%s" % name))
916
- else:
917
- type_args.append((arg.get_datatype('C'), name))
918
- arguments = ", ".join([ "%s %s" % t for t in type_args])
919
- return "%s %s(%s)" % (ctype, routine.name, arguments)
920
-
921
- def _preprocessor_statements(self, prefix):
922
- code_lines = []
923
- code_lines.append('#include "{}.h"'.format(os.path.basename(prefix)))
924
- code_lines.extend(self.preprocessor_statements)
925
- code_lines = ['{}\n'.format(l) for l in code_lines]
926
- return code_lines
927
-
928
- def _get_routine_opening(self, routine):
929
- prototype = self.get_prototype(routine)
930
- return ["%s {\n" % prototype]
931
-
932
- def _declare_arguments(self, routine):
933
- # arguments are declared in prototype
934
- return []
935
-
936
- def _declare_globals(self, routine):
937
- # global variables are not explicitly declared within C functions
938
- return []
939
-
940
- def _declare_locals(self, routine):
941
-
942
- # Compose a list of symbols to be dereferenced in the function
943
- # body. These are the arguments that were passed by a reference
944
- # pointer, excluding arrays.
945
- dereference = []
946
- for arg in routine.arguments:
947
- if isinstance(arg, ResultBase) and not arg.dimensions:
948
- dereference.append(arg.name)
949
-
950
- code_lines = []
951
- for result in routine.local_vars:
952
-
953
- # local variables that are simple symbols such as those used as indices into
954
- # for loops are defined declared elsewhere.
955
- if not isinstance(result, Result):
956
- continue
957
-
958
- if result.name != result.result_var:
959
- raise CodeGen("Result variable and name should match: {}".format(result))
960
- assign_to = result.name
961
- t = result.get_datatype('c')
962
- if isinstance(result.expr, (MatrixBase, MatrixExpr)):
963
- dims = result.expr.shape
964
- code_lines.append("{} {}[{}];\n".format(t, str(assign_to), dims[0]*dims[1]))
965
- prefix = ""
966
- else:
967
- prefix = "const {} ".format(t)
968
-
969
- constants, not_c, c_expr = self._printer_method_with_settings(
970
- 'doprint', {"human": False, "dereference": dereference, "strict": False},
971
- result.expr, assign_to=assign_to)
972
-
973
- for name, value in sorted(constants, key=str):
974
- code_lines.append("double const %s = %s;\n" % (name, value))
975
-
976
- code_lines.append("{}{}\n".format(prefix, c_expr))
977
-
978
- return code_lines
979
-
980
- def _call_printer(self, routine):
981
- code_lines = []
982
-
983
- # Compose a list of symbols to be dereferenced in the function
984
- # body. These are the arguments that were passed by a reference
985
- # pointer, excluding arrays.
986
- dereference = []
987
- for arg in routine.arguments:
988
- if isinstance(arg, ResultBase) and not arg.dimensions:
989
- dereference.append(arg.name)
990
-
991
- return_val = None
992
- for result in routine.result_variables:
993
- if isinstance(result, Result):
994
- assign_to = routine.name + "_result"
995
- t = result.get_datatype('c')
996
- code_lines.append("{} {};\n".format(t, str(assign_to)))
997
- return_val = assign_to
998
- else:
999
- assign_to = result.result_var
1000
-
1001
- try:
1002
- constants, not_c, c_expr = self._printer_method_with_settings(
1003
- 'doprint', {"human": False, "dereference": dereference, "strict": False},
1004
- result.expr, assign_to=assign_to)
1005
- except AssignmentError:
1006
- assign_to = result.result_var
1007
- code_lines.append(
1008
- "%s %s;\n" % (result.get_datatype('c'), str(assign_to)))
1009
- constants, not_c, c_expr = self._printer_method_with_settings(
1010
- 'doprint', {"human": False, "dereference": dereference, "strict": False},
1011
- result.expr, assign_to=assign_to)
1012
-
1013
- for name, value in sorted(constants, key=str):
1014
- code_lines.append("double const %s = %s;\n" % (name, value))
1015
- code_lines.append("%s\n" % c_expr)
1016
-
1017
- if return_val:
1018
- code_lines.append(" return %s;\n" % return_val)
1019
- return code_lines
1020
-
1021
- def _get_routine_ending(self, routine):
1022
- return ["}\n"]
1023
-
1024
- def dump_c(self, routines, f, prefix, header=True, empty=True):
1025
- self.dump_code(routines, f, prefix, header, empty)
1026
- dump_c.extension = code_extension # type: ignore
1027
- dump_c.__doc__ = CodeGen.dump_code.__doc__
1028
-
1029
- def dump_h(self, routines, f, prefix, header=True, empty=True):
1030
- """Writes the C header file.
1031
-
1032
- This file contains all the function declarations.
1033
-
1034
- Parameters
1035
- ==========
1036
-
1037
- routines : list
1038
- A list of Routine instances.
1039
-
1040
- f : file-like
1041
- Where to write the file.
1042
-
1043
- prefix : string
1044
- The filename prefix, used to construct the include guards.
1045
- Only the basename of the prefix is used.
1046
-
1047
- header : bool, optional
1048
- When True, a header comment is included on top of each source
1049
- file. [default : True]
1050
-
1051
- empty : bool, optional
1052
- When True, empty lines are included to structure the source
1053
- files. [default : True]
1054
-
1055
- """
1056
- if header:
1057
- print(''.join(self._get_header()), file=f)
1058
- guard_name = "%s__%s__H" % (self.project.replace(
1059
- " ", "_").upper(), prefix.replace("/", "_").upper())
1060
- # include guards
1061
- if empty:
1062
- print(file=f)
1063
- print("#ifndef %s" % guard_name, file=f)
1064
- print("#define %s" % guard_name, file=f)
1065
- if empty:
1066
- print(file=f)
1067
- # declaration of the function prototypes
1068
- for routine in routines:
1069
- prototype = self.get_prototype(routine)
1070
- print("%s;" % prototype, file=f)
1071
- # end if include guards
1072
- if empty:
1073
- print(file=f)
1074
- print("#endif", file=f)
1075
- if empty:
1076
- print(file=f)
1077
- dump_h.extension = interface_extension # type: ignore
1078
-
1079
- # This list of dump functions is used by CodeGen.write to know which dump
1080
- # functions it has to call.
1081
- dump_fns = [dump_c, dump_h]
1082
-
1083
- class C89CodeGen(CCodeGen):
1084
- standard = 'C89'
1085
-
1086
- class C99CodeGen(CCodeGen):
1087
- standard = 'C99'
1088
-
1089
- class FCodeGen(CodeGen):
1090
- """Generator for Fortran 95 code
1091
-
1092
- The .write() method inherited from CodeGen will output a code file and
1093
- an interface file, <prefix>.f90 and <prefix>.h respectively.
1094
-
1095
- """
1096
-
1097
- code_extension = "f90"
1098
- interface_extension = "h"
1099
-
1100
- def __init__(self, project='project', printer=None):
1101
- super().__init__(project)
1102
- self.printer = printer or FCodePrinter()
1103
-
1104
- def _get_header(self):
1105
- """Writes a common header for the generated files."""
1106
- code_lines = []
1107
- code_lines.append("!" + "*"*78 + '\n')
1108
- tmp = header_comment % {"version": sympy_version,
1109
- "project": self.project}
1110
- for line in tmp.splitlines():
1111
- code_lines.append("!*%s*\n" % line.center(76))
1112
- code_lines.append("!" + "*"*78 + '\n')
1113
- return code_lines
1114
-
1115
- def _preprocessor_statements(self, prefix):
1116
- return []
1117
-
1118
- def _get_routine_opening(self, routine):
1119
- """Returns the opening statements of the fortran routine."""
1120
- code_list = []
1121
- if len(routine.results) > 1:
1122
- raise CodeGenError(
1123
- "Fortran only supports a single or no return value.")
1124
- elif len(routine.results) == 1:
1125
- result = routine.results[0]
1126
- code_list.append(result.get_datatype('fortran'))
1127
- code_list.append("function")
1128
- else:
1129
- code_list.append("subroutine")
1130
-
1131
- args = ", ".join("%s" % self._get_symbol(arg.name)
1132
- for arg in routine.arguments)
1133
-
1134
- call_sig = "{}({})\n".format(routine.name, args)
1135
- # Fortran 95 requires all lines be less than 132 characters, so wrap
1136
- # this line before appending.
1137
- call_sig = ' &\n'.join(textwrap.wrap(call_sig,
1138
- width=60,
1139
- break_long_words=False)) + '\n'
1140
- code_list.append(call_sig)
1141
- code_list = [' '.join(code_list)]
1142
- code_list.append('implicit none\n')
1143
- return code_list
1144
-
1145
- def _declare_arguments(self, routine):
1146
- # argument type declarations
1147
- code_list = []
1148
- array_list = []
1149
- scalar_list = []
1150
- for arg in routine.arguments:
1151
-
1152
- if isinstance(arg, InputArgument):
1153
- typeinfo = "%s, intent(in)" % arg.get_datatype('fortran')
1154
- elif isinstance(arg, InOutArgument):
1155
- typeinfo = "%s, intent(inout)" % arg.get_datatype('fortran')
1156
- elif isinstance(arg, OutputArgument):
1157
- typeinfo = "%s, intent(out)" % arg.get_datatype('fortran')
1158
- else:
1159
- raise CodeGenError("Unknown Argument type: %s" % type(arg))
1160
-
1161
- fprint = self._get_symbol
1162
-
1163
- if arg.dimensions:
1164
- # fortran arrays start at 1
1165
- dimstr = ", ".join(["%s:%s" % (
1166
- fprint(dim[0] + 1), fprint(dim[1] + 1))
1167
- for dim in arg.dimensions])
1168
- typeinfo += ", dimension(%s)" % dimstr
1169
- array_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name)))
1170
- else:
1171
- scalar_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name)))
1172
-
1173
- # scalars first, because they can be used in array declarations
1174
- code_list.extend(scalar_list)
1175
- code_list.extend(array_list)
1176
-
1177
- return code_list
1178
-
1179
- def _declare_globals(self, routine):
1180
- # Global variables not explicitly declared within Fortran 90 functions.
1181
- # Note: a future F77 mode may need to generate "common" blocks.
1182
- return []
1183
-
1184
- def _declare_locals(self, routine):
1185
- code_list = []
1186
- for var in sorted(routine.local_vars, key=str):
1187
- typeinfo = get_default_datatype(var)
1188
- code_list.append("%s :: %s\n" % (
1189
- typeinfo.fname, self._get_symbol(var)))
1190
- return code_list
1191
-
1192
- def _get_routine_ending(self, routine):
1193
- """Returns the closing statements of the fortran routine."""
1194
- if len(routine.results) == 1:
1195
- return ["end function\n"]
1196
- else:
1197
- return ["end subroutine\n"]
1198
-
1199
- def get_interface(self, routine):
1200
- """Returns a string for the function interface.
1201
-
1202
- The routine should have a single result object, which can be None.
1203
- If the routine has multiple result objects, a CodeGenError is
1204
- raised.
1205
-
1206
- See: https://en.wikipedia.org/wiki/Function_prototype
1207
-
1208
- """
1209
- prototype = [ "interface\n" ]
1210
- prototype.extend(self._get_routine_opening(routine))
1211
- prototype.extend(self._declare_arguments(routine))
1212
- prototype.extend(self._get_routine_ending(routine))
1213
- prototype.append("end interface\n")
1214
-
1215
- return "".join(prototype)
1216
-
1217
- def _call_printer(self, routine):
1218
- declarations = []
1219
- code_lines = []
1220
- for result in routine.result_variables:
1221
- if isinstance(result, Result):
1222
- assign_to = routine.name
1223
- elif isinstance(result, (OutputArgument, InOutArgument)):
1224
- assign_to = result.result_var
1225
-
1226
- constants, not_fortran, f_expr = self._printer_method_with_settings(
1227
- 'doprint', {"human": False, "source_format": 'free', "standard": 95, "strict": False},
1228
- result.expr, assign_to=assign_to)
1229
-
1230
- for obj, v in sorted(constants, key=str):
1231
- t = get_default_datatype(obj)
1232
- declarations.append(
1233
- "%s, parameter :: %s = %s\n" % (t.fname, obj, v))
1234
- for obj in sorted(not_fortran, key=str):
1235
- t = get_default_datatype(obj)
1236
- if isinstance(obj, Function):
1237
- name = obj.func
1238
- else:
1239
- name = obj
1240
- declarations.append("%s :: %s\n" % (t.fname, name))
1241
-
1242
- code_lines.append("%s\n" % f_expr)
1243
- return declarations + code_lines
1244
-
1245
- def _indent_code(self, codelines):
1246
- return self._printer_method_with_settings(
1247
- 'indent_code', {"human": False, "source_format": 'free', "strict": False}, codelines)
1248
-
1249
- def dump_f95(self, routines, f, prefix, header=True, empty=True):
1250
- # check that symbols are unique with ignorecase
1251
- for r in routines:
1252
- lowercase = {str(x).lower() for x in r.variables}
1253
- orig_case = {str(x) for x in r.variables}
1254
- if len(lowercase) < len(orig_case):
1255
- raise CodeGenError("Fortran ignores case. Got symbols: %s" %
1256
- (", ".join([str(var) for var in r.variables])))
1257
- self.dump_code(routines, f, prefix, header, empty)
1258
- dump_f95.extension = code_extension # type: ignore
1259
- dump_f95.__doc__ = CodeGen.dump_code.__doc__
1260
-
1261
- def dump_h(self, routines, f, prefix, header=True, empty=True):
1262
- """Writes the interface to a header file.
1263
-
1264
- This file contains all the function declarations.
1265
-
1266
- Parameters
1267
- ==========
1268
-
1269
- routines : list
1270
- A list of Routine instances.
1271
-
1272
- f : file-like
1273
- Where to write the file.
1274
-
1275
- prefix : string
1276
- The filename prefix.
1277
-
1278
- header : bool, optional
1279
- When True, a header comment is included on top of each source
1280
- file. [default : True]
1281
-
1282
- empty : bool, optional
1283
- When True, empty lines are included to structure the source
1284
- files. [default : True]
1285
-
1286
- """
1287
- if header:
1288
- print(''.join(self._get_header()), file=f)
1289
- if empty:
1290
- print(file=f)
1291
- # declaration of the function prototypes
1292
- for routine in routines:
1293
- prototype = self.get_interface(routine)
1294
- f.write(prototype)
1295
- if empty:
1296
- print(file=f)
1297
- dump_h.extension = interface_extension # type: ignore
1298
-
1299
- # This list of dump functions is used by CodeGen.write to know which dump
1300
- # functions it has to call.
1301
- dump_fns = [dump_f95, dump_h]
1302
-
1303
-
1304
- class JuliaCodeGen(CodeGen):
1305
- """Generator for Julia code.
1306
-
1307
- The .write() method inherited from CodeGen will output a code file
1308
- <prefix>.jl.
1309
-
1310
- """
1311
-
1312
- code_extension = "jl"
1313
-
1314
- def __init__(self, project='project', printer=None):
1315
- super().__init__(project)
1316
- self.printer = printer or JuliaCodePrinter()
1317
-
1318
- def routine(self, name, expr, argument_sequence, global_vars):
1319
- """Specialized Routine creation for Julia."""
1320
-
1321
- if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
1322
- if not expr:
1323
- raise ValueError("No expression given")
1324
- expressions = Tuple(*expr)
1325
- else:
1326
- expressions = Tuple(expr)
1327
-
1328
- # local variables
1329
- local_vars = {i.label for i in expressions.atoms(Idx)}
1330
-
1331
- # global variables
1332
- global_vars = set() if global_vars is None else set(global_vars)
1333
-
1334
- # symbols that should be arguments
1335
- old_symbols = expressions.free_symbols - local_vars - global_vars
1336
- symbols = set()
1337
- for s in old_symbols:
1338
- if isinstance(s, Idx):
1339
- symbols.update(s.args[1].free_symbols)
1340
- elif not isinstance(s, Indexed):
1341
- symbols.add(s)
1342
-
1343
- # Julia supports multiple return values
1344
- return_vals = []
1345
- output_args = []
1346
- for (i, expr) in enumerate(expressions):
1347
- if isinstance(expr, Equality):
1348
- out_arg = expr.lhs
1349
- expr = expr.rhs
1350
- symbol = out_arg
1351
- if isinstance(out_arg, Indexed):
1352
- dims = tuple([ (S.One, dim) for dim in out_arg.shape])
1353
- symbol = out_arg.base.label
1354
- output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims))
1355
- if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)):
1356
- raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
1357
- "can define output arguments.")
1358
-
1359
- return_vals.append(Result(expr, name=symbol, result_var=out_arg))
1360
- if not expr.has(symbol):
1361
- # this is a pure output: remove from the symbols list, so
1362
- # it doesn't become an input.
1363
- symbols.remove(symbol)
1364
-
1365
- else:
1366
- # we have no name for this output
1367
- return_vals.append(Result(expr, name='out%d' % (i+1)))
1368
-
1369
- # setup input argument list
1370
- output_args.sort(key=lambda x: str(x.name))
1371
- arg_list = list(output_args)
1372
- array_symbols = {}
1373
- for array in expressions.atoms(Indexed):
1374
- array_symbols[array.base.label] = array
1375
- for array in expressions.atoms(MatrixSymbol):
1376
- array_symbols[array] = array
1377
-
1378
- for symbol in sorted(symbols, key=str):
1379
- arg_list.append(InputArgument(symbol))
1380
-
1381
- if argument_sequence is not None:
1382
- # if the user has supplied IndexedBase instances, we'll accept that
1383
- new_sequence = []
1384
- for arg in argument_sequence:
1385
- if isinstance(arg, IndexedBase):
1386
- new_sequence.append(arg.label)
1387
- else:
1388
- new_sequence.append(arg)
1389
- argument_sequence = new_sequence
1390
-
1391
- missing = [x for x in arg_list if x.name not in argument_sequence]
1392
- if missing:
1393
- msg = "Argument list didn't specify: {0} "
1394
- msg = msg.format(", ".join([str(m.name) for m in missing]))
1395
- raise CodeGenArgumentListError(msg, missing)
1396
-
1397
- # create redundant arguments to produce the requested sequence
1398
- name_arg_dict = {x.name: x for x in arg_list}
1399
- new_args = []
1400
- for symbol in argument_sequence:
1401
- try:
1402
- new_args.append(name_arg_dict[symbol])
1403
- except KeyError:
1404
- new_args.append(InputArgument(symbol))
1405
- arg_list = new_args
1406
-
1407
- return Routine(name, arg_list, return_vals, local_vars, global_vars)
1408
-
1409
- def _get_header(self):
1410
- """Writes a common header for the generated files."""
1411
- code_lines = []
1412
- tmp = header_comment % {"version": sympy_version,
1413
- "project": self.project}
1414
- for line in tmp.splitlines():
1415
- if line == '':
1416
- code_lines.append("#\n")
1417
- else:
1418
- code_lines.append("# %s\n" % line)
1419
- return code_lines
1420
-
1421
- def _preprocessor_statements(self, prefix):
1422
- return []
1423
-
1424
- def _get_routine_opening(self, routine):
1425
- """Returns the opening statements of the routine."""
1426
- code_list = []
1427
- code_list.append("function ")
1428
-
1429
- # Inputs
1430
- args = []
1431
- for arg in routine.arguments:
1432
- if isinstance(arg, OutputArgument):
1433
- raise CodeGenError("Julia: invalid argument of type %s" %
1434
- str(type(arg)))
1435
- if isinstance(arg, (InputArgument, InOutArgument)):
1436
- args.append("%s" % self._get_symbol(arg.name))
1437
- args = ", ".join(args)
1438
- code_list.append("%s(%s)\n" % (routine.name, args))
1439
- code_list = [ "".join(code_list) ]
1440
-
1441
- return code_list
1442
-
1443
- def _declare_arguments(self, routine):
1444
- return []
1445
-
1446
- def _declare_globals(self, routine):
1447
- return []
1448
-
1449
- def _declare_locals(self, routine):
1450
- return []
1451
-
1452
- def _get_routine_ending(self, routine):
1453
- outs = []
1454
- for result in routine.results:
1455
- if isinstance(result, Result):
1456
- # Note: name not result_var; want `y` not `y[i]` for Indexed
1457
- s = self._get_symbol(result.name)
1458
- else:
1459
- raise CodeGenError("unexpected object in Routine results")
1460
- outs.append(s)
1461
- return ["return " + ", ".join(outs) + "\nend\n"]
1462
-
1463
- def _call_printer(self, routine):
1464
- declarations = []
1465
- code_lines = []
1466
- for result in routine.results:
1467
- if isinstance(result, Result):
1468
- assign_to = result.result_var
1469
- else:
1470
- raise CodeGenError("unexpected object in Routine results")
1471
-
1472
- constants, not_supported, jl_expr = self._printer_method_with_settings(
1473
- 'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to)
1474
-
1475
- for obj, v in sorted(constants, key=str):
1476
- declarations.append(
1477
- "%s = %s\n" % (obj, v))
1478
- for obj in sorted(not_supported, key=str):
1479
- if isinstance(obj, Function):
1480
- name = obj.func
1481
- else:
1482
- name = obj
1483
- declarations.append(
1484
- "# unsupported: %s\n" % (name))
1485
- code_lines.append("%s\n" % (jl_expr))
1486
- return declarations + code_lines
1487
-
1488
- def _indent_code(self, codelines):
1489
- # Note that indenting seems to happen twice, first
1490
- # statement-by-statement by JuliaPrinter then again here.
1491
- p = JuliaCodePrinter({'human': False, "strict": False})
1492
- return p.indent_code(codelines)
1493
-
1494
- def dump_jl(self, routines, f, prefix, header=True, empty=True):
1495
- self.dump_code(routines, f, prefix, header, empty)
1496
-
1497
- dump_jl.extension = code_extension # type: ignore
1498
- dump_jl.__doc__ = CodeGen.dump_code.__doc__
1499
-
1500
- # This list of dump functions is used by CodeGen.write to know which dump
1501
- # functions it has to call.
1502
- dump_fns = [dump_jl]
1503
-
1504
-
1505
- class OctaveCodeGen(CodeGen):
1506
- """Generator for Octave code.
1507
-
1508
- The .write() method inherited from CodeGen will output a code file
1509
- <prefix>.m.
1510
-
1511
- Octave .m files usually contain one function. That function name should
1512
- match the filename (``prefix``). If you pass multiple ``name_expr`` pairs,
1513
- the latter ones are presumed to be private functions accessed by the
1514
- primary function.
1515
-
1516
- You should only pass inputs to ``argument_sequence``: outputs are ordered
1517
- according to their order in ``name_expr``.
1518
-
1519
- """
1520
-
1521
- code_extension = "m"
1522
-
1523
- def __init__(self, project='project', printer=None):
1524
- super().__init__(project)
1525
- self.printer = printer or OctaveCodePrinter()
1526
-
1527
- def routine(self, name, expr, argument_sequence, global_vars):
1528
- """Specialized Routine creation for Octave."""
1529
-
1530
- # FIXME: this is probably general enough for other high-level
1531
- # languages, perhaps its the C/Fortran one that is specialized!
1532
-
1533
- if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
1534
- if not expr:
1535
- raise ValueError("No expression given")
1536
- expressions = Tuple(*expr)
1537
- else:
1538
- expressions = Tuple(expr)
1539
-
1540
- # local variables
1541
- local_vars = {i.label for i in expressions.atoms(Idx)}
1542
-
1543
- # global variables
1544
- global_vars = set() if global_vars is None else set(global_vars)
1545
-
1546
- # symbols that should be arguments
1547
- old_symbols = expressions.free_symbols - local_vars - global_vars
1548
- symbols = set()
1549
- for s in old_symbols:
1550
- if isinstance(s, Idx):
1551
- symbols.update(s.args[1].free_symbols)
1552
- elif not isinstance(s, Indexed):
1553
- symbols.add(s)
1554
-
1555
- # Octave supports multiple return values
1556
- return_vals = []
1557
- for (i, expr) in enumerate(expressions):
1558
- if isinstance(expr, Equality):
1559
- out_arg = expr.lhs
1560
- expr = expr.rhs
1561
- symbol = out_arg
1562
- if isinstance(out_arg, Indexed):
1563
- symbol = out_arg.base.label
1564
- if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)):
1565
- raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
1566
- "can define output arguments.")
1567
-
1568
- return_vals.append(Result(expr, name=symbol, result_var=out_arg))
1569
- if not expr.has(symbol):
1570
- # this is a pure output: remove from the symbols list, so
1571
- # it doesn't become an input.
1572
- symbols.remove(symbol)
1573
-
1574
- else:
1575
- # we have no name for this output
1576
- return_vals.append(Result(expr, name='out%d' % (i+1)))
1577
-
1578
- # setup input argument list
1579
- arg_list = []
1580
- array_symbols = {}
1581
- for array in expressions.atoms(Indexed):
1582
- array_symbols[array.base.label] = array
1583
- for array in expressions.atoms(MatrixSymbol):
1584
- array_symbols[array] = array
1585
-
1586
- for symbol in sorted(symbols, key=str):
1587
- arg_list.append(InputArgument(symbol))
1588
-
1589
- if argument_sequence is not None:
1590
- # if the user has supplied IndexedBase instances, we'll accept that
1591
- new_sequence = []
1592
- for arg in argument_sequence:
1593
- if isinstance(arg, IndexedBase):
1594
- new_sequence.append(arg.label)
1595
- else:
1596
- new_sequence.append(arg)
1597
- argument_sequence = new_sequence
1598
-
1599
- missing = [x for x in arg_list if x.name not in argument_sequence]
1600
- if missing:
1601
- msg = "Argument list didn't specify: {0} "
1602
- msg = msg.format(", ".join([str(m.name) for m in missing]))
1603
- raise CodeGenArgumentListError(msg, missing)
1604
-
1605
- # create redundant arguments to produce the requested sequence
1606
- name_arg_dict = {x.name: x for x in arg_list}
1607
- new_args = []
1608
- for symbol in argument_sequence:
1609
- try:
1610
- new_args.append(name_arg_dict[symbol])
1611
- except KeyError:
1612
- new_args.append(InputArgument(symbol))
1613
- arg_list = new_args
1614
-
1615
- return Routine(name, arg_list, return_vals, local_vars, global_vars)
1616
-
1617
- def _get_header(self):
1618
- """Writes a common header for the generated files."""
1619
- code_lines = []
1620
- tmp = header_comment % {"version": sympy_version,
1621
- "project": self.project}
1622
- for line in tmp.splitlines():
1623
- if line == '':
1624
- code_lines.append("%\n")
1625
- else:
1626
- code_lines.append("%% %s\n" % line)
1627
- return code_lines
1628
-
1629
- def _preprocessor_statements(self, prefix):
1630
- return []
1631
-
1632
- def _get_routine_opening(self, routine):
1633
- """Returns the opening statements of the routine."""
1634
- code_list = []
1635
- code_list.append("function ")
1636
-
1637
- # Outputs
1638
- outs = []
1639
- for result in routine.results:
1640
- if isinstance(result, Result):
1641
- # Note: name not result_var; want `y` not `y(i)` for Indexed
1642
- s = self._get_symbol(result.name)
1643
- else:
1644
- raise CodeGenError("unexpected object in Routine results")
1645
- outs.append(s)
1646
- if len(outs) > 1:
1647
- code_list.append("[" + (", ".join(outs)) + "]")
1648
- else:
1649
- code_list.append("".join(outs))
1650
- code_list.append(" = ")
1651
-
1652
- # Inputs
1653
- args = []
1654
- for arg in routine.arguments:
1655
- if isinstance(arg, (OutputArgument, InOutArgument)):
1656
- raise CodeGenError("Octave: invalid argument of type %s" %
1657
- str(type(arg)))
1658
- if isinstance(arg, InputArgument):
1659
- args.append("%s" % self._get_symbol(arg.name))
1660
- args = ", ".join(args)
1661
- code_list.append("%s(%s)\n" % (routine.name, args))
1662
- code_list = [ "".join(code_list) ]
1663
-
1664
- return code_list
1665
-
1666
- def _declare_arguments(self, routine):
1667
- return []
1668
-
1669
- def _declare_globals(self, routine):
1670
- if not routine.global_vars:
1671
- return []
1672
- s = " ".join(sorted([self._get_symbol(g) for g in routine.global_vars]))
1673
- return ["global " + s + "\n"]
1674
-
1675
- def _declare_locals(self, routine):
1676
- return []
1677
-
1678
- def _get_routine_ending(self, routine):
1679
- return ["end\n"]
1680
-
1681
- def _call_printer(self, routine):
1682
- declarations = []
1683
- code_lines = []
1684
- for result in routine.results:
1685
- if isinstance(result, Result):
1686
- assign_to = result.result_var
1687
- else:
1688
- raise CodeGenError("unexpected object in Routine results")
1689
-
1690
- constants, not_supported, oct_expr = self._printer_method_with_settings(
1691
- 'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to)
1692
-
1693
- for obj, v in sorted(constants, key=str):
1694
- declarations.append(
1695
- " %s = %s; %% constant\n" % (obj, v))
1696
- for obj in sorted(not_supported, key=str):
1697
- if isinstance(obj, Function):
1698
- name = obj.func
1699
- else:
1700
- name = obj
1701
- declarations.append(
1702
- " %% unsupported: %s\n" % (name))
1703
- code_lines.append("%s\n" % (oct_expr))
1704
- return declarations + code_lines
1705
-
1706
- def _indent_code(self, codelines):
1707
- return self._printer_method_with_settings(
1708
- 'indent_code', {"human": False, "strict": False}, codelines)
1709
-
1710
- def dump_m(self, routines, f, prefix, header=True, empty=True, inline=True):
1711
- # Note used to call self.dump_code() but we need more control for header
1712
-
1713
- code_lines = self._preprocessor_statements(prefix)
1714
-
1715
- for i, routine in enumerate(routines):
1716
- if i > 0:
1717
- if empty:
1718
- code_lines.append("\n")
1719
- code_lines.extend(self._get_routine_opening(routine))
1720
- if i == 0:
1721
- if routine.name != prefix:
1722
- raise ValueError('Octave function name should match prefix')
1723
- if header:
1724
- code_lines.append("%" + prefix.upper() +
1725
- " Autogenerated by SymPy\n")
1726
- code_lines.append(''.join(self._get_header()))
1727
- code_lines.extend(self._declare_arguments(routine))
1728
- code_lines.extend(self._declare_globals(routine))
1729
- code_lines.extend(self._declare_locals(routine))
1730
- if empty:
1731
- code_lines.append("\n")
1732
- code_lines.extend(self._call_printer(routine))
1733
- if empty:
1734
- code_lines.append("\n")
1735
- code_lines.extend(self._get_routine_ending(routine))
1736
-
1737
- code_lines = self._indent_code(''.join(code_lines))
1738
-
1739
- if code_lines:
1740
- f.write(code_lines)
1741
-
1742
- dump_m.extension = code_extension # type: ignore
1743
- dump_m.__doc__ = CodeGen.dump_code.__doc__
1744
-
1745
- # This list of dump functions is used by CodeGen.write to know which dump
1746
- # functions it has to call.
1747
- dump_fns = [dump_m]
1748
-
1749
- class RustCodeGen(CodeGen):
1750
- """Generator for Rust code.
1751
-
1752
- The .write() method inherited from CodeGen will output a code file
1753
- <prefix>.rs
1754
-
1755
- """
1756
-
1757
- code_extension = "rs"
1758
-
1759
- def __init__(self, project="project", printer=None):
1760
- super().__init__(project=project)
1761
- self.printer = printer or RustCodePrinter()
1762
-
1763
- def routine(self, name, expr, argument_sequence, global_vars):
1764
- """Specialized Routine creation for Rust."""
1765
-
1766
- if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
1767
- if not expr:
1768
- raise ValueError("No expression given")
1769
- expressions = Tuple(*expr)
1770
- else:
1771
- expressions = Tuple(expr)
1772
-
1773
- # local variables
1774
- local_vars = {i.label for i in expressions.atoms(Idx)}
1775
-
1776
- # global variables
1777
- global_vars = set() if global_vars is None else set(global_vars)
1778
-
1779
- # symbols that should be arguments
1780
- symbols = expressions.free_symbols - local_vars - global_vars - expressions.atoms(Indexed)
1781
-
1782
- # Rust supports multiple return values
1783
- return_vals = []
1784
- output_args = []
1785
- for (i, expr) in enumerate(expressions):
1786
- if isinstance(expr, Equality):
1787
- out_arg = expr.lhs
1788
- expr = expr.rhs
1789
- symbol = out_arg
1790
- if isinstance(out_arg, Indexed):
1791
- dims = tuple([ (S.One, dim) for dim in out_arg.shape])
1792
- symbol = out_arg.base.label
1793
- output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims))
1794
- if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)):
1795
- raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
1796
- "can define output arguments.")
1797
-
1798
- return_vals.append(Result(expr, name=symbol, result_var=out_arg))
1799
- if not expr.has(symbol):
1800
- # this is a pure output: remove from the symbols list, so
1801
- # it doesn't become an input.
1802
- symbols.remove(symbol)
1803
-
1804
- else:
1805
- # we have no name for this output
1806
- return_vals.append(Result(expr, name='out%d' % (i+1)))
1807
-
1808
- # setup input argument list
1809
- output_args.sort(key=lambda x: str(x.name))
1810
- arg_list = list(output_args)
1811
- array_symbols = {}
1812
- for array in expressions.atoms(Indexed):
1813
- array_symbols[array.base.label] = array
1814
- for array in expressions.atoms(MatrixSymbol):
1815
- array_symbols[array] = array
1816
-
1817
- for symbol in sorted(symbols, key=str):
1818
- arg_list.append(InputArgument(symbol))
1819
-
1820
- if argument_sequence is not None:
1821
- # if the user has supplied IndexedBase instances, we'll accept that
1822
- new_sequence = []
1823
- for arg in argument_sequence:
1824
- if isinstance(arg, IndexedBase):
1825
- new_sequence.append(arg.label)
1826
- else:
1827
- new_sequence.append(arg)
1828
- argument_sequence = new_sequence
1829
-
1830
- missing = [x for x in arg_list if x.name not in argument_sequence]
1831
- if missing:
1832
- msg = "Argument list didn't specify: {0} "
1833
- msg = msg.format(", ".join([str(m.name) for m in missing]))
1834
- raise CodeGenArgumentListError(msg, missing)
1835
-
1836
- # create redundant arguments to produce the requested sequence
1837
- name_arg_dict = {x.name: x for x in arg_list}
1838
- new_args = []
1839
- for symbol in argument_sequence:
1840
- try:
1841
- new_args.append(name_arg_dict[symbol])
1842
- except KeyError:
1843
- new_args.append(InputArgument(symbol))
1844
- arg_list = new_args
1845
-
1846
- return Routine(name, arg_list, return_vals, local_vars, global_vars)
1847
-
1848
-
1849
- def _get_header(self):
1850
- """Writes a common header for the generated files."""
1851
- code_lines = []
1852
- code_lines.append("/*\n")
1853
- tmp = header_comment % {"version": sympy_version,
1854
- "project": self.project}
1855
- for line in tmp.splitlines():
1856
- code_lines.append((" *%s" % line.center(76)).rstrip() + "\n")
1857
- code_lines.append(" */\n")
1858
- return code_lines
1859
-
1860
- def get_prototype(self, routine):
1861
- """Returns a string for the function prototype of the routine.
1862
-
1863
- If the routine has multiple result objects, an CodeGenError is
1864
- raised.
1865
-
1866
- See: https://en.wikipedia.org/wiki/Function_prototype
1867
-
1868
- """
1869
- results = [i.get_datatype('Rust') for i in routine.results]
1870
-
1871
- if len(results) == 1:
1872
- rstype = " -> " + results[0]
1873
- elif len(routine.results) > 1:
1874
- rstype = " -> (" + ", ".join(results) + ")"
1875
- else:
1876
- rstype = ""
1877
-
1878
- type_args = []
1879
- for arg in routine.arguments:
1880
- name = self.printer.doprint(arg.name)
1881
- if arg.dimensions or isinstance(arg, ResultBase):
1882
- type_args.append(("*%s" % name, arg.get_datatype('Rust')))
1883
- else:
1884
- type_args.append((name, arg.get_datatype('Rust')))
1885
- arguments = ", ".join([ "%s: %s" % t for t in type_args])
1886
- return "fn %s(%s)%s" % (routine.name, arguments, rstype)
1887
-
1888
- def _preprocessor_statements(self, prefix):
1889
- code_lines = []
1890
- # code_lines.append("use std::f64::consts::*;\n")
1891
- return code_lines
1892
-
1893
- def _get_routine_opening(self, routine):
1894
- prototype = self.get_prototype(routine)
1895
- return ["%s {\n" % prototype]
1896
-
1897
- def _declare_arguments(self, routine):
1898
- # arguments are declared in prototype
1899
- return []
1900
-
1901
- def _declare_globals(self, routine):
1902
- # global variables are not explicitly declared within C functions
1903
- return []
1904
-
1905
- def _declare_locals(self, routine):
1906
- # loop variables are declared in loop statement
1907
- return []
1908
-
1909
- def _call_printer(self, routine):
1910
-
1911
- code_lines = []
1912
- declarations = []
1913
- returns = []
1914
-
1915
- # Compose a list of symbols to be dereferenced in the function
1916
- # body. These are the arguments that were passed by a reference
1917
- # pointer, excluding arrays.
1918
- dereference = []
1919
- for arg in routine.arguments:
1920
- if isinstance(arg, ResultBase) and not arg.dimensions:
1921
- dereference.append(arg.name)
1922
-
1923
- for result in routine.results:
1924
- if isinstance(result, Result):
1925
- assign_to = result.result_var
1926
- returns.append(str(result.result_var))
1927
- else:
1928
- raise CodeGenError("unexpected object in Routine results")
1929
-
1930
- constants, not_supported, rs_expr = self._printer_method_with_settings(
1931
- 'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to)
1932
-
1933
- for name, value in sorted(constants, key=str):
1934
- declarations.append("const %s: f64 = %s;\n" % (name, value))
1935
-
1936
- for obj in sorted(not_supported, key=str):
1937
- if isinstance(obj, Function):
1938
- name = obj.func
1939
- else:
1940
- name = obj
1941
- declarations.append("// unsupported: %s\n" % (name))
1942
-
1943
- code_lines.append("let %s\n" % rs_expr)
1944
-
1945
- if len(returns) > 1:
1946
- returns = ['(' + ', '.join(returns) + ')']
1947
-
1948
- returns.append('\n')
1949
-
1950
- return declarations + code_lines + returns
1951
-
1952
- def _get_routine_ending(self, routine):
1953
- return ["}\n"]
1954
-
1955
- def dump_rs(self, routines, f, prefix, header=True, empty=True):
1956
- self.dump_code(routines, f, prefix, header, empty)
1957
-
1958
- dump_rs.extension = code_extension # type: ignore
1959
- dump_rs.__doc__ = CodeGen.dump_code.__doc__
1960
-
1961
- # This list of dump functions is used by CodeGen.write to know which dump
1962
- # functions it has to call.
1963
- dump_fns = [dump_rs]
1964
-
1965
-
1966
-
1967
-
1968
- def get_code_generator(language, project=None, standard=None, printer = None):
1969
- if language == 'C':
1970
- if standard is None:
1971
- pass
1972
- elif standard.lower() == 'c89':
1973
- language = 'C89'
1974
- elif standard.lower() == 'c99':
1975
- language = 'C99'
1976
- CodeGenClass = {"C": CCodeGen, "C89": C89CodeGen, "C99": C99CodeGen,
1977
- "F95": FCodeGen, "JULIA": JuliaCodeGen,
1978
- "OCTAVE": OctaveCodeGen,
1979
- "RUST": RustCodeGen}.get(language.upper())
1980
- if CodeGenClass is None:
1981
- raise ValueError("Language '%s' is not supported." % language)
1982
- return CodeGenClass(project, printer)
1983
-
1984
-
1985
- #
1986
- # Friendly functions
1987
- #
1988
-
1989
-
1990
- def codegen(name_expr, language=None, prefix=None, project="project",
1991
- to_files=False, header=True, empty=True, argument_sequence=None,
1992
- global_vars=None, standard=None, code_gen=None, printer=None):
1993
- """Generate source code for expressions in a given language.
1994
-
1995
- Parameters
1996
- ==========
1997
-
1998
- name_expr : tuple, or list of tuples
1999
- A single (name, expression) tuple or a list of (name, expression)
2000
- tuples. Each tuple corresponds to a routine. If the expression is
2001
- an equality (an instance of class Equality) the left hand side is
2002
- considered an output argument. If expression is an iterable, then
2003
- the routine will have multiple outputs.
2004
-
2005
- language : string,
2006
- A string that indicates the source code language. This is case
2007
- insensitive. Currently, 'C', 'F95' and 'Octave' are supported.
2008
- 'Octave' generates code compatible with both Octave and Matlab.
2009
-
2010
- prefix : string, optional
2011
- A prefix for the names of the files that contain the source code.
2012
- Language-dependent suffixes will be appended. If omitted, the name
2013
- of the first name_expr tuple is used.
2014
-
2015
- project : string, optional
2016
- A project name, used for making unique preprocessor instructions.
2017
- [default: "project"]
2018
-
2019
- to_files : bool, optional
2020
- When True, the code will be written to one or more files with the
2021
- given prefix, otherwise strings with the names and contents of
2022
- these files are returned. [default: False]
2023
-
2024
- header : bool, optional
2025
- When True, a header is written on top of each source file.
2026
- [default: True]
2027
-
2028
- empty : bool, optional
2029
- When True, empty lines are used to structure the code.
2030
- [default: True]
2031
-
2032
- argument_sequence : iterable, optional
2033
- Sequence of arguments for the routine in a preferred order. A
2034
- CodeGenError is raised if required arguments are missing.
2035
- Redundant arguments are used without warning. If omitted,
2036
- arguments will be ordered alphabetically, but with all input
2037
- arguments first, and then output or in-out arguments.
2038
-
2039
- global_vars : iterable, optional
2040
- Sequence of global variables used by the routine. Variables
2041
- listed here will not show up as function arguments.
2042
-
2043
- standard : string, optional
2044
-
2045
- code_gen : CodeGen instance, optional
2046
- An instance of a CodeGen subclass. Overrides ``language``.
2047
-
2048
- printer : Printer instance, optional
2049
- An instance of a Printer subclass.
2050
-
2051
- Examples
2052
- ========
2053
-
2054
- >>> from sympy.utilities.codegen import codegen
2055
- >>> from sympy.abc import x, y, z
2056
- >>> [(c_name, c_code), (h_name, c_header)] = codegen(
2057
- ... ("f", x+y*z), "C89", "test", header=False, empty=False)
2058
- >>> print(c_name)
2059
- test.c
2060
- >>> print(c_code)
2061
- #include "test.h"
2062
- #include <math.h>
2063
- double f(double x, double y, double z) {
2064
- double f_result;
2065
- f_result = x + y*z;
2066
- return f_result;
2067
- }
2068
- <BLANKLINE>
2069
- >>> print(h_name)
2070
- test.h
2071
- >>> print(c_header)
2072
- #ifndef PROJECT__TEST__H
2073
- #define PROJECT__TEST__H
2074
- double f(double x, double y, double z);
2075
- #endif
2076
- <BLANKLINE>
2077
-
2078
- Another example using Equality objects to give named outputs. Here the
2079
- filename (prefix) is taken from the first (name, expr) pair.
2080
-
2081
- >>> from sympy.abc import f, g
2082
- >>> from sympy import Eq
2083
- >>> [(c_name, c_code), (h_name, c_header)] = codegen(
2084
- ... [("myfcn", x + y), ("fcn2", [Eq(f, 2*x), Eq(g, y)])],
2085
- ... "C99", header=False, empty=False)
2086
- >>> print(c_name)
2087
- myfcn.c
2088
- >>> print(c_code)
2089
- #include "myfcn.h"
2090
- #include <math.h>
2091
- double myfcn(double x, double y) {
2092
- double myfcn_result;
2093
- myfcn_result = x + y;
2094
- return myfcn_result;
2095
- }
2096
- void fcn2(double x, double y, double *f, double *g) {
2097
- (*f) = 2*x;
2098
- (*g) = y;
2099
- }
2100
- <BLANKLINE>
2101
-
2102
- If the generated function(s) will be part of a larger project where various
2103
- global variables have been defined, the 'global_vars' option can be used
2104
- to remove the specified variables from the function signature
2105
-
2106
- >>> from sympy.utilities.codegen import codegen
2107
- >>> from sympy.abc import x, y, z
2108
- >>> [(f_name, f_code), header] = codegen(
2109
- ... ("f", x+y*z), "F95", header=False, empty=False,
2110
- ... argument_sequence=(x, y), global_vars=(z,))
2111
- >>> print(f_code)
2112
- REAL*8 function f(x, y)
2113
- implicit none
2114
- REAL*8, intent(in) :: x
2115
- REAL*8, intent(in) :: y
2116
- f = x + y*z
2117
- end function
2118
- <BLANKLINE>
2119
-
2120
- """
2121
-
2122
- # Initialize the code generator.
2123
- if language is None:
2124
- if code_gen is None:
2125
- raise ValueError("Need either language or code_gen")
2126
- else:
2127
- if code_gen is not None:
2128
- raise ValueError("You cannot specify both language and code_gen.")
2129
- code_gen = get_code_generator(language, project, standard, printer)
2130
-
2131
- if isinstance(name_expr[0], str):
2132
- # single tuple is given, turn it into a singleton list with a tuple.
2133
- name_expr = [name_expr]
2134
-
2135
- if prefix is None:
2136
- prefix = name_expr[0][0]
2137
-
2138
- # Construct Routines appropriate for this code_gen from (name, expr) pairs.
2139
- routines = []
2140
- for name, expr in name_expr:
2141
- routines.append(code_gen.routine(name, expr, argument_sequence,
2142
- global_vars))
2143
-
2144
- # Write the code.
2145
- return code_gen.write(routines, prefix, to_files, header, empty)
2146
-
2147
-
2148
- def make_routine(name, expr, argument_sequence=None,
2149
- global_vars=None, language="F95"):
2150
- """A factory that makes an appropriate Routine from an expression.
2151
-
2152
- Parameters
2153
- ==========
2154
-
2155
- name : string
2156
- The name of this routine in the generated code.
2157
-
2158
- expr : expression or list/tuple of expressions
2159
- A SymPy expression that the Routine instance will represent. If
2160
- given a list or tuple of expressions, the routine will be
2161
- considered to have multiple return values and/or output arguments.
2162
-
2163
- argument_sequence : list or tuple, optional
2164
- List arguments for the routine in a preferred order. If omitted,
2165
- the results are language dependent, for example, alphabetical order
2166
- or in the same order as the given expressions.
2167
-
2168
- global_vars : iterable, optional
2169
- Sequence of global variables used by the routine. Variables
2170
- listed here will not show up as function arguments.
2171
-
2172
- language : string, optional
2173
- Specify a target language. The Routine itself should be
2174
- language-agnostic but the precise way one is created, error
2175
- checking, etc depend on the language. [default: "F95"].
2176
-
2177
- Notes
2178
- =====
2179
-
2180
- A decision about whether to use output arguments or return values is made
2181
- depending on both the language and the particular mathematical expressions.
2182
- For an expression of type Equality, the left hand side is typically made
2183
- into an OutputArgument (or perhaps an InOutArgument if appropriate).
2184
- Otherwise, typically, the calculated expression is made a return values of
2185
- the routine.
2186
-
2187
- Examples
2188
- ========
2189
-
2190
- >>> from sympy.utilities.codegen import make_routine
2191
- >>> from sympy.abc import x, y, f, g
2192
- >>> from sympy import Eq
2193
- >>> r = make_routine('test', [Eq(f, 2*x), Eq(g, x + y)])
2194
- >>> [arg.result_var for arg in r.results]
2195
- []
2196
- >>> [arg.name for arg in r.arguments]
2197
- [x, y, f, g]
2198
- >>> [arg.name for arg in r.result_variables]
2199
- [f, g]
2200
- >>> r.local_vars
2201
- set()
2202
-
2203
- Another more complicated example with a mixture of specified and
2204
- automatically-assigned names. Also has Matrix output.
2205
-
2206
- >>> from sympy import Matrix
2207
- >>> r = make_routine('fcn', [x*y, Eq(f, 1), Eq(g, x + g), Matrix([[x, 2]])])
2208
- >>> [arg.result_var for arg in r.results] # doctest: +SKIP
2209
- [result_5397460570204848505]
2210
- >>> [arg.expr for arg in r.results]
2211
- [x*y]
2212
- >>> [arg.name for arg in r.arguments] # doctest: +SKIP
2213
- [x, y, f, g, out_8598435338387848786]
2214
-
2215
- We can examine the various arguments more closely:
2216
-
2217
- >>> from sympy.utilities.codegen import (InputArgument, OutputArgument,
2218
- ... InOutArgument)
2219
- >>> [a.name for a in r.arguments if isinstance(a, InputArgument)]
2220
- [x, y]
2221
-
2222
- >>> [a.name for a in r.arguments if isinstance(a, OutputArgument)] # doctest: +SKIP
2223
- [f, out_8598435338387848786]
2224
- >>> [a.expr for a in r.arguments if isinstance(a, OutputArgument)]
2225
- [1, Matrix([[x, 2]])]
2226
-
2227
- >>> [a.name for a in r.arguments if isinstance(a, InOutArgument)]
2228
- [g]
2229
- >>> [a.expr for a in r.arguments if isinstance(a, InOutArgument)]
2230
- [g + x]
2231
-
2232
- """
2233
-
2234
- # initialize a new code generator
2235
- code_gen = get_code_generator(language)
2236
-
2237
- return code_gen.routine(name, expr, argument_sequence, global_vars)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/decorator.py DELETED
@@ -1,339 +0,0 @@
1
- """Useful utility decorators. """
2
-
3
- from typing import TypeVar
4
- import sys
5
- import types
6
- import inspect
7
- from functools import wraps, update_wrapper
8
-
9
- from sympy.utilities.exceptions import sympy_deprecation_warning
10
-
11
-
12
- T = TypeVar('T')
13
- """A generic type"""
14
-
15
-
16
- def threaded_factory(func, use_add):
17
- """A factory for ``threaded`` decorators. """
18
- from sympy.core import sympify
19
- from sympy.matrices import MatrixBase
20
- from sympy.utilities.iterables import iterable
21
-
22
- @wraps(func)
23
- def threaded_func(expr, *args, **kwargs):
24
- if isinstance(expr, MatrixBase):
25
- return expr.applyfunc(lambda f: func(f, *args, **kwargs))
26
- elif iterable(expr):
27
- try:
28
- return expr.__class__([func(f, *args, **kwargs) for f in expr])
29
- except TypeError:
30
- return expr
31
- else:
32
- expr = sympify(expr)
33
-
34
- if use_add and expr.is_Add:
35
- return expr.__class__(*[ func(f, *args, **kwargs) for f in expr.args ])
36
- elif expr.is_Relational:
37
- return expr.__class__(func(expr.lhs, *args, **kwargs),
38
- func(expr.rhs, *args, **kwargs))
39
- else:
40
- return func(expr, *args, **kwargs)
41
-
42
- return threaded_func
43
-
44
-
45
- def threaded(func):
46
- """Apply ``func`` to sub--elements of an object, including :class:`~.Add`.
47
-
48
- This decorator is intended to make it uniformly possible to apply a
49
- function to all elements of composite objects, e.g. matrices, lists, tuples
50
- and other iterable containers, or just expressions.
51
-
52
- This version of :func:`threaded` decorator allows threading over
53
- elements of :class:`~.Add` class. If this behavior is not desirable
54
- use :func:`xthreaded` decorator.
55
-
56
- Functions using this decorator must have the following signature::
57
-
58
- @threaded
59
- def function(expr, *args, **kwargs):
60
-
61
- """
62
- return threaded_factory(func, True)
63
-
64
-
65
- def xthreaded(func):
66
- """Apply ``func`` to sub--elements of an object, excluding :class:`~.Add`.
67
-
68
- This decorator is intended to make it uniformly possible to apply a
69
- function to all elements of composite objects, e.g. matrices, lists, tuples
70
- and other iterable containers, or just expressions.
71
-
72
- This version of :func:`threaded` decorator disallows threading over
73
- elements of :class:`~.Add` class. If this behavior is not desirable
74
- use :func:`threaded` decorator.
75
-
76
- Functions using this decorator must have the following signature::
77
-
78
- @xthreaded
79
- def function(expr, *args, **kwargs):
80
-
81
- """
82
- return threaded_factory(func, False)
83
-
84
-
85
- def conserve_mpmath_dps(func):
86
- """After the function finishes, resets the value of ``mpmath.mp.dps`` to
87
- the value it had before the function was run."""
88
- import mpmath
89
-
90
- def func_wrapper(*args, **kwargs):
91
- dps = mpmath.mp.dps
92
- try:
93
- return func(*args, **kwargs)
94
- finally:
95
- mpmath.mp.dps = dps
96
-
97
- func_wrapper = update_wrapper(func_wrapper, func)
98
- return func_wrapper
99
-
100
-
101
- class no_attrs_in_subclass:
102
- """Don't 'inherit' certain attributes from a base class
103
-
104
- >>> from sympy.utilities.decorator import no_attrs_in_subclass
105
-
106
- >>> class A(object):
107
- ... x = 'test'
108
-
109
- >>> A.x = no_attrs_in_subclass(A, A.x)
110
-
111
- >>> class B(A):
112
- ... pass
113
-
114
- >>> hasattr(A, 'x')
115
- True
116
- >>> hasattr(B, 'x')
117
- False
118
-
119
- """
120
- def __init__(self, cls, f):
121
- self.cls = cls
122
- self.f = f
123
-
124
- def __get__(self, instance, owner=None):
125
- if owner == self.cls:
126
- if hasattr(self.f, '__get__'):
127
- return self.f.__get__(instance, owner)
128
- return self.f
129
- raise AttributeError
130
-
131
-
132
- def doctest_depends_on(exe=None, modules=None, disable_viewers=None,
133
- python_version=None, ground_types=None):
134
- """
135
- Adds metadata about the dependencies which need to be met for doctesting
136
- the docstrings of the decorated objects.
137
-
138
- ``exe`` should be a list of executables
139
-
140
- ``modules`` should be a list of modules
141
-
142
- ``disable_viewers`` should be a list of viewers for :func:`~sympy.printing.preview.preview` to disable
143
-
144
- ``python_version`` should be the minimum Python version required, as a tuple
145
- (like ``(3, 0)``)
146
- """
147
- dependencies = {}
148
- if exe is not None:
149
- dependencies['executables'] = exe
150
- if modules is not None:
151
- dependencies['modules'] = modules
152
- if disable_viewers is not None:
153
- dependencies['disable_viewers'] = disable_viewers
154
- if python_version is not None:
155
- dependencies['python_version'] = python_version
156
- if ground_types is not None:
157
- dependencies['ground_types'] = ground_types
158
-
159
- def skiptests():
160
- from sympy.testing.runtests import DependencyError, SymPyDocTests, PyTestReporter # lazy import
161
- r = PyTestReporter()
162
- t = SymPyDocTests(r, None)
163
- try:
164
- t._check_dependencies(**dependencies)
165
- except DependencyError:
166
- return True # Skip doctests
167
- else:
168
- return False # Run doctests
169
-
170
- def depends_on_deco(fn):
171
- fn._doctest_depends_on = dependencies
172
- fn.__doctest_skip__ = skiptests
173
-
174
- if inspect.isclass(fn):
175
- fn._doctest_depdends_on = no_attrs_in_subclass(
176
- fn, fn._doctest_depends_on)
177
- fn.__doctest_skip__ = no_attrs_in_subclass(
178
- fn, fn.__doctest_skip__)
179
- return fn
180
-
181
- return depends_on_deco
182
-
183
-
184
- def public(obj: T) -> T:
185
- """
186
- Append ``obj``'s name to global ``__all__`` variable (call site).
187
-
188
- By using this decorator on functions or classes you achieve the same goal
189
- as by filling ``__all__`` variables manually, you just do not have to repeat
190
- yourself (object's name). You also know if object is public at definition
191
- site, not at some random location (where ``__all__`` was set).
192
-
193
- Note that in multiple decorator setup (in almost all cases) ``@public``
194
- decorator must be applied before any other decorators, because it relies
195
- on the pointer to object's global namespace. If you apply other decorators
196
- first, ``@public`` may end up modifying the wrong namespace.
197
-
198
- Examples
199
- ========
200
-
201
- >>> from sympy.utilities.decorator import public
202
-
203
- >>> __all__ # noqa: F821
204
- Traceback (most recent call last):
205
- ...
206
- NameError: name '__all__' is not defined
207
-
208
- >>> @public
209
- ... def some_function():
210
- ... pass
211
-
212
- >>> __all__ # noqa: F821
213
- ['some_function']
214
-
215
- """
216
- if isinstance(obj, types.FunctionType):
217
- ns = obj.__globals__
218
- name = obj.__name__
219
- elif isinstance(obj, (type(type), type)):
220
- ns = sys.modules[obj.__module__].__dict__
221
- name = obj.__name__
222
- else:
223
- raise TypeError("expected a function or a class, got %s" % obj)
224
-
225
- if "__all__" not in ns:
226
- ns["__all__"] = [name]
227
- else:
228
- ns["__all__"].append(name)
229
-
230
- return obj
231
-
232
-
233
- def memoize_property(propfunc):
234
- """Property decorator that caches the value of potentially expensive
235
- ``propfunc`` after the first evaluation. The cached value is stored in
236
- the corresponding property name with an attached underscore."""
237
- attrname = '_' + propfunc.__name__
238
- sentinel = object()
239
-
240
- @wraps(propfunc)
241
- def accessor(self):
242
- val = getattr(self, attrname, sentinel)
243
- if val is sentinel:
244
- val = propfunc(self)
245
- setattr(self, attrname, val)
246
- return val
247
-
248
- return property(accessor)
249
-
250
-
251
- def deprecated(message, *, deprecated_since_version,
252
- active_deprecations_target, stacklevel=3):
253
- '''
254
- Mark a function as deprecated.
255
-
256
- This decorator should be used if an entire function or class is
257
- deprecated. If only a certain functionality is deprecated, you should use
258
- :func:`~.warns_deprecated_sympy` directly. This decorator is just a
259
- convenience. There is no functional difference between using this
260
- decorator and calling ``warns_deprecated_sympy()`` at the top of the
261
- function.
262
-
263
- The decorator takes the same arguments as
264
- :func:`~.warns_deprecated_sympy`. See its
265
- documentation for details on what the keywords to this decorator do.
266
-
267
- See the :ref:`deprecation-policy` document for details on when and how
268
- things should be deprecated in SymPy.
269
-
270
- Examples
271
- ========
272
-
273
- >>> from sympy.utilities.decorator import deprecated
274
- >>> from sympy import simplify
275
- >>> @deprecated("""\
276
- ... The simplify_this(expr) function is deprecated. Use simplify(expr)
277
- ... instead.""", deprecated_since_version="1.1",
278
- ... active_deprecations_target='simplify-this-deprecation')
279
- ... def simplify_this(expr):
280
- ... """
281
- ... Simplify ``expr``.
282
- ...
283
- ... .. deprecated:: 1.1
284
- ...
285
- ... The ``simplify_this`` function is deprecated. Use :func:`simplify`
286
- ... instead. See its documentation for more information. See
287
- ... :ref:`simplify-this-deprecation` for details.
288
- ...
289
- ... """
290
- ... return simplify(expr)
291
- >>> from sympy.abc import x
292
- >>> simplify_this(x*(x + 1) - x**2) # doctest: +SKIP
293
- <stdin>:1: SymPyDeprecationWarning:
294
- <BLANKLINE>
295
- The simplify_this(expr) function is deprecated. Use simplify(expr)
296
- instead.
297
- <BLANKLINE>
298
- See https://docs.sympy.org/latest/explanation/active-deprecations.html#simplify-this-deprecation
299
- for details.
300
- <BLANKLINE>
301
- This has been deprecated since SymPy version 1.1. It
302
- will be removed in a future version of SymPy.
303
- <BLANKLINE>
304
- simplify_this(x)
305
- x
306
-
307
- See Also
308
- ========
309
- sympy.utilities.exceptions.SymPyDeprecationWarning
310
- sympy.utilities.exceptions.sympy_deprecation_warning
311
- sympy.utilities.exceptions.ignore_warnings
312
- sympy.testing.pytest.warns_deprecated_sympy
313
-
314
- '''
315
- decorator_kwargs = {"deprecated_since_version": deprecated_since_version,
316
- "active_deprecations_target": active_deprecations_target}
317
- def deprecated_decorator(wrapped):
318
- if hasattr(wrapped, '__mro__'): # wrapped is actually a class
319
- class wrapper(wrapped):
320
- __doc__ = wrapped.__doc__
321
- __module__ = wrapped.__module__
322
- _sympy_deprecated_func = wrapped
323
- if '__new__' in wrapped.__dict__:
324
- def __new__(cls, *args, **kwargs):
325
- sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel)
326
- return super().__new__(cls, *args, **kwargs)
327
- else:
328
- def __init__(self, *args, **kwargs):
329
- sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel)
330
- super().__init__(*args, **kwargs)
331
- wrapper.__name__ = wrapped.__name__
332
- else:
333
- @wraps(wrapped)
334
- def wrapper(*args, **kwargs):
335
- sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel)
336
- return wrapped(*args, **kwargs)
337
- wrapper._sympy_deprecated_func = wrapped
338
- return wrapper
339
- return deprecated_decorator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/enumerative.py DELETED
@@ -1,1155 +0,0 @@
1
- """
2
- Algorithms and classes to support enumerative combinatorics.
3
-
4
- Currently just multiset partitions, but more could be added.
5
-
6
- Terminology (following Knuth, algorithm 7.1.2.5M TAOCP)
7
- *multiset* aaabbcccc has a *partition* aaabc | bccc
8
-
9
- The submultisets, aaabc and bccc of the partition are called
10
- *parts*, or sometimes *vectors*. (Knuth notes that multiset
11
- partitions can be thought of as partitions of vectors of integers,
12
- where the ith element of the vector gives the multiplicity of
13
- element i.)
14
-
15
- The values a, b and c are *components* of the multiset. These
16
- correspond to elements of a set, but in a multiset can be present
17
- with a multiplicity greater than 1.
18
-
19
- The algorithm deserves some explanation.
20
-
21
- Think of the part aaabc from the multiset above. If we impose an
22
- ordering on the components of the multiset, we can represent a part
23
- with a vector, in which the value of the first element of the vector
24
- corresponds to the multiplicity of the first component in that
25
- part. Thus, aaabc can be represented by the vector [3, 1, 1]. We
26
- can also define an ordering on parts, based on the lexicographic
27
- ordering of the vector (leftmost vector element, i.e., the element
28
- with the smallest component number, is the most significant), so
29
- that [3, 1, 1] > [3, 1, 0] and [3, 1, 1] > [2, 1, 4]. The ordering
30
- on parts can be extended to an ordering on partitions: First, sort
31
- the parts in each partition, left-to-right in decreasing order. Then
32
- partition A is greater than partition B if A's leftmost/greatest
33
- part is greater than B's leftmost part. If the leftmost parts are
34
- equal, compare the second parts, and so on.
35
-
36
- In this ordering, the greatest partition of a given multiset has only
37
- one part. The least partition is the one in which the components
38
- are spread out, one per part.
39
-
40
- The enumeration algorithms in this file yield the partitions of the
41
- argument multiset in decreasing order. The main data structure is a
42
- stack of parts, corresponding to the current partition. An
43
- important invariant is that the parts on the stack are themselves in
44
- decreasing order. This data structure is decremented to find the
45
- next smaller partition. Most often, decrementing the partition will
46
- only involve adjustments to the smallest parts at the top of the
47
- stack, much as adjacent integers *usually* differ only in their last
48
- few digits.
49
-
50
- Knuth's algorithm uses two main operations on parts:
51
-
52
- Decrement - change the part so that it is smaller in the
53
- (vector) lexicographic order, but reduced by the smallest amount possible.
54
- For example, if the multiset has vector [5,
55
- 3, 1], and the bottom/greatest part is [4, 2, 1], this part would
56
- decrement to [4, 2, 0], while [4, 0, 0] would decrement to [3, 3,
57
- 1]. A singleton part is never decremented -- [1, 0, 0] is not
58
- decremented to [0, 3, 1]. Instead, the decrement operator needs
59
- to fail for this case. In Knuth's pseudocode, the decrement
60
- operator is step m5.
61
-
62
- Spread unallocated multiplicity - Once a part has been decremented,
63
- it cannot be the rightmost part in the partition. There is some
64
- multiplicity that has not been allocated, and new parts must be
65
- created above it in the stack to use up this multiplicity. To
66
- maintain the invariant that the parts on the stack are in
67
- decreasing order, these new parts must be less than or equal to
68
- the decremented part.
69
- For example, if the multiset is [5, 3, 1], and its most
70
- significant part has just been decremented to [5, 3, 0], the
71
- spread operation will add a new part so that the stack becomes
72
- [[5, 3, 0], [0, 0, 1]]. If the most significant part (for the
73
- same multiset) has been decremented to [2, 0, 0] the stack becomes
74
- [[2, 0, 0], [2, 0, 0], [1, 3, 1]]. In the pseudocode, the spread
75
- operation for one part is step m2. The complete spread operation
76
- is a loop of steps m2 and m3.
77
-
78
- In order to facilitate the spread operation, Knuth stores, for each
79
- component of each part, not just the multiplicity of that component
80
- in the part, but also the total multiplicity available for this
81
- component in this part or any lesser part above it on the stack.
82
-
83
- One added twist is that Knuth does not represent the part vectors as
84
- arrays. Instead, he uses a sparse representation, in which a
85
- component of a part is represented as a component number (c), plus
86
- the multiplicity of the component in that part (v) as well as the
87
- total multiplicity available for that component (u). This saves
88
- time that would be spent skipping over zeros.
89
-
90
- """
91
-
92
- class PartComponent:
93
- """Internal class used in support of the multiset partitions
94
- enumerators and the associated visitor functions.
95
-
96
- Represents one component of one part of the current partition.
97
-
98
- A stack of these, plus an auxiliary frame array, f, represents a
99
- partition of the multiset.
100
-
101
- Knuth's pseudocode makes c, u, and v separate arrays.
102
- """
103
-
104
- __slots__ = ('c', 'u', 'v')
105
-
106
- def __init__(self):
107
- self.c = 0 # Component number
108
- self.u = 0 # The as yet unpartitioned amount in component c
109
- # *before* it is allocated by this triple
110
- self.v = 0 # Amount of c component in the current part
111
- # (v<=u). An invariant of the representation is
112
- # that the next higher triple for this component
113
- # (if there is one) will have a value of u-v in
114
- # its u attribute.
115
-
116
- def __repr__(self):
117
- "for debug/algorithm animation purposes"
118
- return 'c:%d u:%d v:%d' % (self.c, self.u, self.v)
119
-
120
- def __eq__(self, other):
121
- """Define value oriented equality, which is useful for testers"""
122
- return (isinstance(other, self.__class__) and
123
- self.c == other.c and
124
- self.u == other.u and
125
- self.v == other.v)
126
-
127
- def __ne__(self, other):
128
- """Defined for consistency with __eq__"""
129
- return not self == other
130
-
131
-
132
- # This function tries to be a faithful implementation of algorithm
133
- # 7.1.2.5M in Volume 4A, Combinatoral Algorithms, Part 1, of The Art
134
- # of Computer Programming, by Donald Knuth. This includes using
135
- # (mostly) the same variable names, etc. This makes for rather
136
- # low-level Python.
137
-
138
- # Changes from Knuth's pseudocode include
139
- # - use PartComponent struct/object instead of 3 arrays
140
- # - make the function a generator
141
- # - map (with some difficulty) the GOTOs to Python control structures.
142
- # - Knuth uses 1-based numbering for components, this code is 0-based
143
- # - renamed variable l to lpart.
144
- # - flag variable x takes on values True/False instead of 1/0
145
- #
146
- def multiset_partitions_taocp(multiplicities):
147
- """Enumerates partitions of a multiset.
148
-
149
- Parameters
150
- ==========
151
-
152
- multiplicities
153
- list of integer multiplicities of the components of the multiset.
154
-
155
- Yields
156
- ======
157
-
158
- state
159
- Internal data structure which encodes a particular partition.
160
- This output is then usually processed by a visitor function
161
- which combines the information from this data structure with
162
- the components themselves to produce an actual partition.
163
-
164
- Unless they wish to create their own visitor function, users will
165
- have little need to look inside this data structure. But, for
166
- reference, it is a 3-element list with components:
167
-
168
- f
169
- is a frame array, which is used to divide pstack into parts.
170
-
171
- lpart
172
- points to the base of the topmost part.
173
-
174
- pstack
175
- is an array of PartComponent objects.
176
-
177
- The ``state`` output offers a peek into the internal data
178
- structures of the enumeration function. The client should
179
- treat this as read-only; any modification of the data
180
- structure will cause unpredictable (and almost certainly
181
- incorrect) results. Also, the components of ``state`` are
182
- modified in place at each iteration. Hence, the visitor must
183
- be called at each loop iteration. Accumulating the ``state``
184
- instances and processing them later will not work.
185
-
186
- Examples
187
- ========
188
-
189
- >>> from sympy.utilities.enumerative import list_visitor
190
- >>> from sympy.utilities.enumerative import multiset_partitions_taocp
191
- >>> # variables components and multiplicities represent the multiset 'abb'
192
- >>> components = 'ab'
193
- >>> multiplicities = [1, 2]
194
- >>> states = multiset_partitions_taocp(multiplicities)
195
- >>> list(list_visitor(state, components) for state in states)
196
- [[['a', 'b', 'b']],
197
- [['a', 'b'], ['b']],
198
- [['a'], ['b', 'b']],
199
- [['a'], ['b'], ['b']]]
200
-
201
- See Also
202
- ========
203
-
204
- sympy.utilities.iterables.multiset_partitions: Takes a multiset
205
- as input and directly yields multiset partitions. It
206
- dispatches to a number of functions, including this one, for
207
- implementation. Most users will find it more convenient to
208
- use than multiset_partitions_taocp.
209
-
210
- """
211
-
212
- # Important variables.
213
- # m is the number of components, i.e., number of distinct elements
214
- m = len(multiplicities)
215
- # n is the cardinality, total number of elements whether or not distinct
216
- n = sum(multiplicities)
217
-
218
- # The main data structure, f segments pstack into parts. See
219
- # list_visitor() for example code indicating how this internal
220
- # state corresponds to a partition.
221
-
222
- # Note: allocation of space for stack is conservative. Knuth's
223
- # exercise 7.2.1.5.68 gives some indication of how to tighten this
224
- # bound, but this is not implemented.
225
- pstack = [PartComponent() for i in range(n * m + 1)]
226
- f = [0] * (n + 1)
227
-
228
- # Step M1 in Knuth (Initialize)
229
- # Initial state - entire multiset in one part.
230
- for j in range(m):
231
- ps = pstack[j]
232
- ps.c = j
233
- ps.u = multiplicities[j]
234
- ps.v = multiplicities[j]
235
-
236
- # Other variables
237
- f[0] = 0
238
- a = 0
239
- lpart = 0
240
- f[1] = m
241
- b = m # in general, current stack frame is from a to b - 1
242
-
243
- while True:
244
- while True:
245
- # Step M2 (Subtract v from u)
246
- k = b
247
- x = False
248
- for j in range(a, b):
249
- pstack[k].u = pstack[j].u - pstack[j].v
250
- if pstack[k].u == 0:
251
- x = True
252
- elif not x:
253
- pstack[k].c = pstack[j].c
254
- pstack[k].v = min(pstack[j].v, pstack[k].u)
255
- x = pstack[k].u < pstack[j].v
256
- k = k + 1
257
- else: # x is True
258
- pstack[k].c = pstack[j].c
259
- pstack[k].v = pstack[k].u
260
- k = k + 1
261
- # Note: x is True iff v has changed
262
-
263
- # Step M3 (Push if nonzero.)
264
- if k > b:
265
- a = b
266
- b = k
267
- lpart = lpart + 1
268
- f[lpart + 1] = b
269
- # Return to M2
270
- else:
271
- break # Continue to M4
272
-
273
- # M4 Visit a partition
274
- state = [f, lpart, pstack]
275
- yield state
276
-
277
- # M5 (Decrease v)
278
- while True:
279
- j = b-1
280
- while (pstack[j].v == 0):
281
- j = j - 1
282
- if j == a and pstack[j].v == 1:
283
- # M6 (Backtrack)
284
- if lpart == 0:
285
- return
286
- lpart = lpart - 1
287
- b = a
288
- a = f[lpart]
289
- # Return to M5
290
- else:
291
- pstack[j].v = pstack[j].v - 1
292
- for k in range(j + 1, b):
293
- pstack[k].v = pstack[k].u
294
- break # GOTO M2
295
-
296
- # --------------- Visitor functions for multiset partitions ---------------
297
- # A visitor takes the partition state generated by
298
- # multiset_partitions_taocp or other enumerator, and produces useful
299
- # output (such as the actual partition).
300
-
301
-
302
- def factoring_visitor(state, primes):
303
- """Use with multiset_partitions_taocp to enumerate the ways a
304
- number can be expressed as a product of factors. For this usage,
305
- the exponents of the prime factors of a number are arguments to
306
- the partition enumerator, while the corresponding prime factors
307
- are input here.
308
-
309
- Examples
310
- ========
311
-
312
- To enumerate the factorings of a number we can think of the elements of the
313
- partition as being the prime factors and the multiplicities as being their
314
- exponents.
315
-
316
- >>> from sympy.utilities.enumerative import factoring_visitor
317
- >>> from sympy.utilities.enumerative import multiset_partitions_taocp
318
- >>> from sympy import factorint
319
- >>> primes, multiplicities = zip(*factorint(24).items())
320
- >>> primes
321
- (2, 3)
322
- >>> multiplicities
323
- (3, 1)
324
- >>> states = multiset_partitions_taocp(multiplicities)
325
- >>> list(factoring_visitor(state, primes) for state in states)
326
- [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3], [6, 2, 2], [2, 2, 2, 3]]
327
- """
328
- f, lpart, pstack = state
329
- factoring = []
330
- for i in range(lpart + 1):
331
- factor = 1
332
- for ps in pstack[f[i]: f[i + 1]]:
333
- if ps.v > 0:
334
- factor *= primes[ps.c] ** ps.v
335
- factoring.append(factor)
336
- return factoring
337
-
338
-
339
- def list_visitor(state, components):
340
- """Return a list of lists to represent the partition.
341
-
342
- Examples
343
- ========
344
-
345
- >>> from sympy.utilities.enumerative import list_visitor
346
- >>> from sympy.utilities.enumerative import multiset_partitions_taocp
347
- >>> states = multiset_partitions_taocp([1, 2, 1])
348
- >>> s = next(states)
349
- >>> list_visitor(s, 'abc') # for multiset 'a b b c'
350
- [['a', 'b', 'b', 'c']]
351
- >>> s = next(states)
352
- >>> list_visitor(s, [1, 2, 3]) # for multiset '1 2 2 3
353
- [[1, 2, 2], [3]]
354
- """
355
- f, lpart, pstack = state
356
-
357
- partition = []
358
- for i in range(lpart+1):
359
- part = []
360
- for ps in pstack[f[i]:f[i+1]]:
361
- if ps.v > 0:
362
- part.extend([components[ps.c]] * ps.v)
363
- partition.append(part)
364
-
365
- return partition
366
-
367
-
368
- class MultisetPartitionTraverser():
369
- """
370
- Has methods to ``enumerate`` and ``count`` the partitions of a multiset.
371
-
372
- This implements a refactored and extended version of Knuth's algorithm
373
- 7.1.2.5M [AOCP]_."
374
-
375
- The enumeration methods of this class are generators and return
376
- data structures which can be interpreted by the same visitor
377
- functions used for the output of ``multiset_partitions_taocp``.
378
-
379
- Examples
380
- ========
381
-
382
- >>> from sympy.utilities.enumerative import MultisetPartitionTraverser
383
- >>> m = MultisetPartitionTraverser()
384
- >>> m.count_partitions([4,4,4,2])
385
- 127750
386
- >>> m.count_partitions([3,3,3])
387
- 686
388
-
389
- See Also
390
- ========
391
-
392
- multiset_partitions_taocp
393
- sympy.utilities.iterables.multiset_partitions
394
-
395
- References
396
- ==========
397
-
398
- .. [AOCP] Algorithm 7.1.2.5M in Volume 4A, Combinatoral Algorithms,
399
- Part 1, of The Art of Computer Programming, by Donald Knuth.
400
-
401
- .. [Factorisatio] On a Problem of Oppenheim concerning
402
- "Factorisatio Numerorum" E. R. Canfield, Paul Erdos, Carl
403
- Pomerance, JOURNAL OF NUMBER THEORY, Vol. 17, No. 1. August
404
- 1983. See section 7 for a description of an algorithm
405
- similar to Knuth's.
406
-
407
- .. [Yorgey] Generating Multiset Partitions, Brent Yorgey, The
408
- Monad.Reader, Issue 8, September 2007.
409
-
410
- """
411
-
412
- def __init__(self):
413
- self.debug = False
414
- # TRACING variables. These are useful for gathering
415
- # statistics on the algorithm itself, but have no particular
416
- # benefit to a user of the code.
417
- self.k1 = 0
418
- self.k2 = 0
419
- self.p1 = 0
420
- self.pstack = None
421
- self.f = None
422
- self.lpart = 0
423
- self.discarded = 0
424
- # dp_stack is list of lists of (part_key, start_count) pairs
425
- self.dp_stack = []
426
-
427
- # dp_map is map part_key-> count, where count represents the
428
- # number of multiset which are descendants of a part with this
429
- # key, **or any of its decrements**
430
-
431
- # Thus, when we find a part in the map, we add its count
432
- # value to the running total, cut off the enumeration, and
433
- # backtrack
434
-
435
- if not hasattr(self, 'dp_map'):
436
- self.dp_map = {}
437
-
438
- def db_trace(self, msg):
439
- """Useful for understanding/debugging the algorithms. Not
440
- generally activated in end-user code."""
441
- if self.debug:
442
- # XXX: animation_visitor is undefined... Clearly this does not
443
- # work and was not tested. Previous code in comments below.
444
- raise RuntimeError
445
- #letters = 'abcdefghijklmnopqrstuvwxyz'
446
- #state = [self.f, self.lpart, self.pstack]
447
- #print("DBG:", msg,
448
- # ["".join(part) for part in list_visitor(state, letters)],
449
- # animation_visitor(state))
450
-
451
- #
452
- # Helper methods for enumeration
453
- #
454
- def _initialize_enumeration(self, multiplicities):
455
- """Allocates and initializes the partition stack.
456
-
457
- This is called from the enumeration/counting routines, so
458
- there is no need to call it separately."""
459
-
460
- num_components = len(multiplicities)
461
- # cardinality is the total number of elements, whether or not distinct
462
- cardinality = sum(multiplicities)
463
-
464
- # pstack is the partition stack, which is segmented by
465
- # f into parts.
466
- self.pstack = [PartComponent() for i in
467
- range(num_components * cardinality + 1)]
468
- self.f = [0] * (cardinality + 1)
469
-
470
- # Initial state - entire multiset in one part.
471
- for j in range(num_components):
472
- ps = self.pstack[j]
473
- ps.c = j
474
- ps.u = multiplicities[j]
475
- ps.v = multiplicities[j]
476
-
477
- self.f[0] = 0
478
- self.f[1] = num_components
479
- self.lpart = 0
480
-
481
- # The decrement_part() method corresponds to step M5 in Knuth's
482
- # algorithm. This is the base version for enum_all(). Modified
483
- # versions of this method are needed if we want to restrict
484
- # sizes of the partitions produced.
485
- def decrement_part(self, part):
486
- """Decrements part (a subrange of pstack), if possible, returning
487
- True iff the part was successfully decremented.
488
-
489
- If you think of the v values in the part as a multi-digit
490
- integer (least significant digit on the right) this is
491
- basically decrementing that integer, but with the extra
492
- constraint that the leftmost digit cannot be decremented to 0.
493
-
494
- Parameters
495
- ==========
496
-
497
- part
498
- The part, represented as a list of PartComponent objects,
499
- which is to be decremented.
500
-
501
- """
502
- plen = len(part)
503
- for j in range(plen - 1, -1, -1):
504
- if j == 0 and part[j].v > 1 or j > 0 and part[j].v > 0:
505
- # found val to decrement
506
- part[j].v -= 1
507
- # Reset trailing parts back to maximum
508
- for k in range(j + 1, plen):
509
- part[k].v = part[k].u
510
- return True
511
- return False
512
-
513
- # Version to allow number of parts to be bounded from above.
514
- # Corresponds to (a modified) step M5.
515
- def decrement_part_small(self, part, ub):
516
- """Decrements part (a subrange of pstack), if possible, returning
517
- True iff the part was successfully decremented.
518
-
519
- Parameters
520
- ==========
521
-
522
- part
523
- part to be decremented (topmost part on the stack)
524
-
525
- ub
526
- the maximum number of parts allowed in a partition
527
- returned by the calling traversal.
528
-
529
- Notes
530
- =====
531
-
532
- The goal of this modification of the ordinary decrement method
533
- is to fail (meaning that the subtree rooted at this part is to
534
- be skipped) when it can be proved that this part can only have
535
- child partitions which are larger than allowed by ``ub``. If a
536
- decision is made to fail, it must be accurate, otherwise the
537
- enumeration will miss some partitions. But, it is OK not to
538
- capture all the possible failures -- if a part is passed that
539
- should not be, the resulting too-large partitions are filtered
540
- by the enumeration one level up. However, as is usual in
541
- constrained enumerations, failing early is advantageous.
542
-
543
- The tests used by this method catch the most common cases,
544
- although this implementation is by no means the last word on
545
- this problem. The tests include:
546
-
547
- 1) ``lpart`` must be less than ``ub`` by at least 2. This is because
548
- once a part has been decremented, the partition
549
- will gain at least one child in the spread step.
550
-
551
- 2) If the leading component of the part is about to be
552
- decremented, check for how many parts will be added in
553
- order to use up the unallocated multiplicity in that
554
- leading component, and fail if this number is greater than
555
- allowed by ``ub``. (See code for the exact expression.) This
556
- test is given in the answer to Knuth's problem 7.2.1.5.69.
557
-
558
- 3) If there is *exactly* enough room to expand the leading
559
- component by the above test, check the next component (if
560
- it exists) once decrementing has finished. If this has
561
- ``v == 0``, this next component will push the expansion over the
562
- limit by 1, so fail.
563
- """
564
- if self.lpart >= ub - 1:
565
- self.p1 += 1 # increment to keep track of usefulness of tests
566
- return False
567
- plen = len(part)
568
- for j in range(plen - 1, -1, -1):
569
- # Knuth's mod, (answer to problem 7.2.1.5.69)
570
- if j == 0 and (part[0].v - 1)*(ub - self.lpart) < part[0].u:
571
- self.k1 += 1
572
- return False
573
-
574
- if j == 0 and part[j].v > 1 or j > 0 and part[j].v > 0:
575
- # found val to decrement
576
- part[j].v -= 1
577
- # Reset trailing parts back to maximum
578
- for k in range(j + 1, plen):
579
- part[k].v = part[k].u
580
-
581
- # Have now decremented part, but are we doomed to
582
- # failure when it is expanded? Check one oddball case
583
- # that turns out to be surprisingly common - exactly
584
- # enough room to expand the leading component, but no
585
- # room for the second component, which has v=0.
586
- if (plen > 1 and part[1].v == 0 and
587
- (part[0].u - part[0].v) ==
588
- ((ub - self.lpart - 1) * part[0].v)):
589
- self.k2 += 1
590
- self.db_trace("Decrement fails test 3")
591
- return False
592
- return True
593
- return False
594
-
595
- def decrement_part_large(self, part, amt, lb):
596
- """Decrements part, while respecting size constraint.
597
-
598
- A part can have no children which are of sufficient size (as
599
- indicated by ``lb``) unless that part has sufficient
600
- unallocated multiplicity. When enforcing the size constraint,
601
- this method will decrement the part (if necessary) by an
602
- amount needed to ensure sufficient unallocated multiplicity.
603
-
604
- Returns True iff the part was successfully decremented.
605
-
606
- Parameters
607
- ==========
608
-
609
- part
610
- part to be decremented (topmost part on the stack)
611
-
612
- amt
613
- Can only take values 0 or 1. A value of 1 means that the
614
- part must be decremented, and then the size constraint is
615
- enforced. A value of 0 means just to enforce the ``lb``
616
- size constraint.
617
-
618
- lb
619
- The partitions produced by the calling enumeration must
620
- have more parts than this value.
621
-
622
- """
623
-
624
- if amt == 1:
625
- # In this case we always need to decrement, *before*
626
- # enforcing the "sufficient unallocated multiplicity"
627
- # constraint. Easiest for this is just to call the
628
- # regular decrement method.
629
- if not self.decrement_part(part):
630
- return False
631
-
632
- # Next, perform any needed additional decrementing to respect
633
- # "sufficient unallocated multiplicity" (or fail if this is
634
- # not possible).
635
- min_unalloc = lb - self.lpart
636
- if min_unalloc <= 0:
637
- return True
638
- total_mult = sum(pc.u for pc in part)
639
- total_alloc = sum(pc.v for pc in part)
640
- if total_mult <= min_unalloc:
641
- return False
642
-
643
- deficit = min_unalloc - (total_mult - total_alloc)
644
- if deficit <= 0:
645
- return True
646
-
647
- for i in range(len(part) - 1, -1, -1):
648
- if i == 0:
649
- if part[0].v > deficit:
650
- part[0].v -= deficit
651
- return True
652
- else:
653
- return False # This shouldn't happen, due to above check
654
- else:
655
- if part[i].v >= deficit:
656
- part[i].v -= deficit
657
- return True
658
- else:
659
- deficit -= part[i].v
660
- part[i].v = 0
661
-
662
- def decrement_part_range(self, part, lb, ub):
663
- """Decrements part (a subrange of pstack), if possible, returning
664
- True iff the part was successfully decremented.
665
-
666
- Parameters
667
- ==========
668
-
669
- part
670
- part to be decremented (topmost part on the stack)
671
-
672
- ub
673
- the maximum number of parts allowed in a partition
674
- returned by the calling traversal.
675
-
676
- lb
677
- The partitions produced by the calling enumeration must
678
- have more parts than this value.
679
-
680
- Notes
681
- =====
682
-
683
- Combines the constraints of _small and _large decrement
684
- methods. If returns success, part has been decremented at
685
- least once, but perhaps by quite a bit more if needed to meet
686
- the lb constraint.
687
- """
688
-
689
- # Constraint in the range case is just enforcing both the
690
- # constraints from _small and _large cases. Note the 0 as the
691
- # second argument to the _large call -- this is the signal to
692
- # decrement only as needed to for constraint enforcement. The
693
- # short circuiting and left-to-right order of the 'and'
694
- # operator is important for this to work correctly.
695
- return self.decrement_part_small(part, ub) and \
696
- self.decrement_part_large(part, 0, lb)
697
-
698
- def spread_part_multiplicity(self):
699
- """Returns True if a new part has been created, and
700
- adjusts pstack, f and lpart as needed.
701
-
702
- Notes
703
- =====
704
-
705
- Spreads unallocated multiplicity from the current top part
706
- into a new part created above the current on the stack. This
707
- new part is constrained to be less than or equal to the old in
708
- terms of the part ordering.
709
-
710
- This call does nothing (and returns False) if the current top
711
- part has no unallocated multiplicity.
712
-
713
- """
714
- j = self.f[self.lpart] # base of current top part
715
- k = self.f[self.lpart + 1] # ub of current; potential base of next
716
- base = k # save for later comparison
717
-
718
- changed = False # Set to true when the new part (so far) is
719
- # strictly less than (as opposed to less than
720
- # or equal) to the old.
721
- for j in range(self.f[self.lpart], self.f[self.lpart + 1]):
722
- self.pstack[k].u = self.pstack[j].u - self.pstack[j].v
723
- if self.pstack[k].u == 0:
724
- changed = True
725
- else:
726
- self.pstack[k].c = self.pstack[j].c
727
- if changed: # Put all available multiplicity in this part
728
- self.pstack[k].v = self.pstack[k].u
729
- else: # Still maintaining ordering constraint
730
- if self.pstack[k].u < self.pstack[j].v:
731
- self.pstack[k].v = self.pstack[k].u
732
- changed = True
733
- else:
734
- self.pstack[k].v = self.pstack[j].v
735
- k = k + 1
736
- if k > base:
737
- # Adjust for the new part on stack
738
- self.lpart = self.lpart + 1
739
- self.f[self.lpart + 1] = k
740
- return True
741
- return False
742
-
743
- def top_part(self):
744
- """Return current top part on the stack, as a slice of pstack.
745
-
746
- """
747
- return self.pstack[self.f[self.lpart]:self.f[self.lpart + 1]]
748
-
749
- # Same interface and functionality as multiset_partitions_taocp(),
750
- # but some might find this refactored version easier to follow.
751
- def enum_all(self, multiplicities):
752
- """Enumerate the partitions of a multiset.
753
-
754
- Examples
755
- ========
756
-
757
- >>> from sympy.utilities.enumerative import list_visitor
758
- >>> from sympy.utilities.enumerative import MultisetPartitionTraverser
759
- >>> m = MultisetPartitionTraverser()
760
- >>> states = m.enum_all([2,2])
761
- >>> list(list_visitor(state, 'ab') for state in states)
762
- [[['a', 'a', 'b', 'b']],
763
- [['a', 'a', 'b'], ['b']],
764
- [['a', 'a'], ['b', 'b']],
765
- [['a', 'a'], ['b'], ['b']],
766
- [['a', 'b', 'b'], ['a']],
767
- [['a', 'b'], ['a', 'b']],
768
- [['a', 'b'], ['a'], ['b']],
769
- [['a'], ['a'], ['b', 'b']],
770
- [['a'], ['a'], ['b'], ['b']]]
771
-
772
- See Also
773
- ========
774
-
775
- multiset_partitions_taocp:
776
- which provides the same result as this method, but is
777
- about twice as fast. Hence, enum_all is primarily useful
778
- for testing. Also see the function for a discussion of
779
- states and visitors.
780
-
781
- """
782
- self._initialize_enumeration(multiplicities)
783
- while True:
784
- while self.spread_part_multiplicity():
785
- pass
786
-
787
- # M4 Visit a partition
788
- state = [self.f, self.lpart, self.pstack]
789
- yield state
790
-
791
- # M5 (Decrease v)
792
- while not self.decrement_part(self.top_part()):
793
- # M6 (Backtrack)
794
- if self.lpart == 0:
795
- return
796
- self.lpart -= 1
797
-
798
- def enum_small(self, multiplicities, ub):
799
- """Enumerate multiset partitions with no more than ``ub`` parts.
800
-
801
- Equivalent to enum_range(multiplicities, 0, ub)
802
-
803
- Parameters
804
- ==========
805
-
806
- multiplicities
807
- list of multiplicities of the components of the multiset.
808
-
809
- ub
810
- Maximum number of parts
811
-
812
- Examples
813
- ========
814
-
815
- >>> from sympy.utilities.enumerative import list_visitor
816
- >>> from sympy.utilities.enumerative import MultisetPartitionTraverser
817
- >>> m = MultisetPartitionTraverser()
818
- >>> states = m.enum_small([2,2], 2)
819
- >>> list(list_visitor(state, 'ab') for state in states)
820
- [[['a', 'a', 'b', 'b']],
821
- [['a', 'a', 'b'], ['b']],
822
- [['a', 'a'], ['b', 'b']],
823
- [['a', 'b', 'b'], ['a']],
824
- [['a', 'b'], ['a', 'b']]]
825
-
826
- The implementation is based, in part, on the answer given to
827
- exercise 69, in Knuth [AOCP]_.
828
-
829
- See Also
830
- ========
831
-
832
- enum_all, enum_large, enum_range
833
-
834
- """
835
-
836
- # Keep track of iterations which do not yield a partition.
837
- # Clearly, we would like to keep this number small.
838
- self.discarded = 0
839
- if ub <= 0:
840
- return
841
- self._initialize_enumeration(multiplicities)
842
- while True:
843
- while self.spread_part_multiplicity():
844
- self.db_trace('spread 1')
845
- if self.lpart >= ub:
846
- self.discarded += 1
847
- self.db_trace(' Discarding')
848
- self.lpart = ub - 2
849
- break
850
- else:
851
- # M4 Visit a partition
852
- state = [self.f, self.lpart, self.pstack]
853
- yield state
854
-
855
- # M5 (Decrease v)
856
- while not self.decrement_part_small(self.top_part(), ub):
857
- self.db_trace("Failed decrement, going to backtrack")
858
- # M6 (Backtrack)
859
- if self.lpart == 0:
860
- return
861
- self.lpart -= 1
862
- self.db_trace("Backtracked to")
863
- self.db_trace("decrement ok, about to expand")
864
-
865
- def enum_large(self, multiplicities, lb):
866
- """Enumerate the partitions of a multiset with lb < num(parts)
867
-
868
- Equivalent to enum_range(multiplicities, lb, sum(multiplicities))
869
-
870
- Parameters
871
- ==========
872
-
873
- multiplicities
874
- list of multiplicities of the components of the multiset.
875
-
876
- lb
877
- Number of parts in the partition must be greater than
878
- this lower bound.
879
-
880
-
881
- Examples
882
- ========
883
-
884
- >>> from sympy.utilities.enumerative import list_visitor
885
- >>> from sympy.utilities.enumerative import MultisetPartitionTraverser
886
- >>> m = MultisetPartitionTraverser()
887
- >>> states = m.enum_large([2,2], 2)
888
- >>> list(list_visitor(state, 'ab') for state in states)
889
- [[['a', 'a'], ['b'], ['b']],
890
- [['a', 'b'], ['a'], ['b']],
891
- [['a'], ['a'], ['b', 'b']],
892
- [['a'], ['a'], ['b'], ['b']]]
893
-
894
- See Also
895
- ========
896
-
897
- enum_all, enum_small, enum_range
898
-
899
- """
900
- self.discarded = 0
901
- if lb >= sum(multiplicities):
902
- return
903
- self._initialize_enumeration(multiplicities)
904
- self.decrement_part_large(self.top_part(), 0, lb)
905
- while True:
906
- good_partition = True
907
- while self.spread_part_multiplicity():
908
- if not self.decrement_part_large(self.top_part(), 0, lb):
909
- # Failure here should be rare/impossible
910
- self.discarded += 1
911
- good_partition = False
912
- break
913
-
914
- # M4 Visit a partition
915
- if good_partition:
916
- state = [self.f, self.lpart, self.pstack]
917
- yield state
918
-
919
- # M5 (Decrease v)
920
- while not self.decrement_part_large(self.top_part(), 1, lb):
921
- # M6 (Backtrack)
922
- if self.lpart == 0:
923
- return
924
- self.lpart -= 1
925
-
926
- def enum_range(self, multiplicities, lb, ub):
927
-
928
- """Enumerate the partitions of a multiset with
929
- ``lb < num(parts) <= ub``.
930
-
931
- In particular, if partitions with exactly ``k`` parts are
932
- desired, call with ``(multiplicities, k - 1, k)``. This
933
- method generalizes enum_all, enum_small, and enum_large.
934
-
935
- Examples
936
- ========
937
-
938
- >>> from sympy.utilities.enumerative import list_visitor
939
- >>> from sympy.utilities.enumerative import MultisetPartitionTraverser
940
- >>> m = MultisetPartitionTraverser()
941
- >>> states = m.enum_range([2,2], 1, 2)
942
- >>> list(list_visitor(state, 'ab') for state in states)
943
- [[['a', 'a', 'b'], ['b']],
944
- [['a', 'a'], ['b', 'b']],
945
- [['a', 'b', 'b'], ['a']],
946
- [['a', 'b'], ['a', 'b']]]
947
-
948
- """
949
- # combine the constraints of the _large and _small
950
- # enumerations.
951
- self.discarded = 0
952
- if ub <= 0 or lb >= sum(multiplicities):
953
- return
954
- self._initialize_enumeration(multiplicities)
955
- self.decrement_part_large(self.top_part(), 0, lb)
956
- while True:
957
- good_partition = True
958
- while self.spread_part_multiplicity():
959
- self.db_trace("spread 1")
960
- if not self.decrement_part_large(self.top_part(), 0, lb):
961
- # Failure here - possible in range case?
962
- self.db_trace(" Discarding (large cons)")
963
- self.discarded += 1
964
- good_partition = False
965
- break
966
- elif self.lpart >= ub:
967
- self.discarded += 1
968
- good_partition = False
969
- self.db_trace(" Discarding small cons")
970
- self.lpart = ub - 2
971
- break
972
-
973
- # M4 Visit a partition
974
- if good_partition:
975
- state = [self.f, self.lpart, self.pstack]
976
- yield state
977
-
978
- # M5 (Decrease v)
979
- while not self.decrement_part_range(self.top_part(), lb, ub):
980
- self.db_trace("Failed decrement, going to backtrack")
981
- # M6 (Backtrack)
982
- if self.lpart == 0:
983
- return
984
- self.lpart -= 1
985
- self.db_trace("Backtracked to")
986
- self.db_trace("decrement ok, about to expand")
987
-
988
- def count_partitions_slow(self, multiplicities):
989
- """Returns the number of partitions of a multiset whose elements
990
- have the multiplicities given in ``multiplicities``.
991
-
992
- Primarily for comparison purposes. It follows the same path as
993
- enumerate, and counts, rather than generates, the partitions.
994
-
995
- See Also
996
- ========
997
-
998
- count_partitions
999
- Has the same calling interface, but is much faster.
1000
-
1001
- """
1002
- # number of partitions so far in the enumeration
1003
- self.pcount = 0
1004
- self._initialize_enumeration(multiplicities)
1005
- while True:
1006
- while self.spread_part_multiplicity():
1007
- pass
1008
-
1009
- # M4 Visit (count) a partition
1010
- self.pcount += 1
1011
-
1012
- # M5 (Decrease v)
1013
- while not self.decrement_part(self.top_part()):
1014
- # M6 (Backtrack)
1015
- if self.lpart == 0:
1016
- return self.pcount
1017
- self.lpart -= 1
1018
-
1019
- def count_partitions(self, multiplicities):
1020
- """Returns the number of partitions of a multiset whose components
1021
- have the multiplicities given in ``multiplicities``.
1022
-
1023
- For larger counts, this method is much faster than calling one
1024
- of the enumerators and counting the result. Uses dynamic
1025
- programming to cut down on the number of nodes actually
1026
- explored. The dictionary used in order to accelerate the
1027
- counting process is stored in the ``MultisetPartitionTraverser``
1028
- object and persists across calls. If the user does not
1029
- expect to call ``count_partitions`` for any additional
1030
- multisets, the object should be cleared to save memory. On
1031
- the other hand, the cache built up from one count run can
1032
- significantly speed up subsequent calls to ``count_partitions``,
1033
- so it may be advantageous not to clear the object.
1034
-
1035
- Examples
1036
- ========
1037
-
1038
- >>> from sympy.utilities.enumerative import MultisetPartitionTraverser
1039
- >>> m = MultisetPartitionTraverser()
1040
- >>> m.count_partitions([9,8,2])
1041
- 288716
1042
- >>> m.count_partitions([2,2])
1043
- 9
1044
- >>> del m
1045
-
1046
- Notes
1047
- =====
1048
-
1049
- If one looks at the workings of Knuth's algorithm M [AOCP]_, it
1050
- can be viewed as a traversal of a binary tree of parts. A
1051
- part has (up to) two children, the left child resulting from
1052
- the spread operation, and the right child from the decrement
1053
- operation. The ordinary enumeration of multiset partitions is
1054
- an in-order traversal of this tree, and with the partitions
1055
- corresponding to paths from the root to the leaves. The
1056
- mapping from paths to partitions is a little complicated,
1057
- since the partition would contain only those parts which are
1058
- leaves or the parents of a spread link, not those which are
1059
- parents of a decrement link.
1060
-
1061
- For counting purposes, it is sufficient to count leaves, and
1062
- this can be done with a recursive in-order traversal. The
1063
- number of leaves of a subtree rooted at a particular part is a
1064
- function only of that part itself, so memoizing has the
1065
- potential to speed up the counting dramatically.
1066
-
1067
- This method follows a computational approach which is similar
1068
- to the hypothetical memoized recursive function, but with two
1069
- differences:
1070
-
1071
- 1) This method is iterative, borrowing its structure from the
1072
- other enumerations and maintaining an explicit stack of
1073
- parts which are in the process of being counted. (There
1074
- may be multisets which can be counted reasonably quickly by
1075
- this implementation, but which would overflow the default
1076
- Python recursion limit with a recursive implementation.)
1077
-
1078
- 2) Instead of using the part data structure directly, a more
1079
- compact key is constructed. This saves space, but more
1080
- importantly coalesces some parts which would remain
1081
- separate with physical keys.
1082
-
1083
- Unlike the enumeration functions, there is currently no _range
1084
- version of count_partitions. If someone wants to stretch
1085
- their brain, it should be possible to construct one by
1086
- memoizing with a histogram of counts rather than a single
1087
- count, and combining the histograms.
1088
- """
1089
- # number of partitions so far in the enumeration
1090
- self.pcount = 0
1091
-
1092
- # dp_stack is list of lists of (part_key, start_count) pairs
1093
- self.dp_stack = []
1094
-
1095
- self._initialize_enumeration(multiplicities)
1096
- pkey = part_key(self.top_part())
1097
- self.dp_stack.append([(pkey, 0), ])
1098
- while True:
1099
- while self.spread_part_multiplicity():
1100
- pkey = part_key(self.top_part())
1101
- if pkey in self.dp_map:
1102
- # Already have a cached value for the count of the
1103
- # subtree rooted at this part. Add it to the
1104
- # running counter, and break out of the spread
1105
- # loop. The -1 below is to compensate for the
1106
- # leaf that this code path would otherwise find,
1107
- # and which gets incremented for below.
1108
-
1109
- self.pcount += (self.dp_map[pkey] - 1)
1110
- self.lpart -= 1
1111
- break
1112
- else:
1113
- self.dp_stack.append([(pkey, self.pcount), ])
1114
-
1115
- # M4 count a leaf partition
1116
- self.pcount += 1
1117
-
1118
- # M5 (Decrease v)
1119
- while not self.decrement_part(self.top_part()):
1120
- # M6 (Backtrack)
1121
- for key, oldcount in self.dp_stack.pop():
1122
- self.dp_map[key] = self.pcount - oldcount
1123
- if self.lpart == 0:
1124
- return self.pcount
1125
- self.lpart -= 1
1126
-
1127
- # At this point have successfully decremented the part on
1128
- # the stack and it does not appear in the cache. It needs
1129
- # to be added to the list at the top of dp_stack
1130
- pkey = part_key(self.top_part())
1131
- self.dp_stack[-1].append((pkey, self.pcount),)
1132
-
1133
-
1134
- def part_key(part):
1135
- """Helper for MultisetPartitionTraverser.count_partitions that
1136
- creates a key for ``part``, that only includes information which can
1137
- affect the count for that part. (Any irrelevant information just
1138
- reduces the effectiveness of dynamic programming.)
1139
-
1140
- Notes
1141
- =====
1142
-
1143
- This member function is a candidate for future exploration. There
1144
- are likely symmetries that can be exploited to coalesce some
1145
- ``part_key`` values, and thereby save space and improve
1146
- performance.
1147
-
1148
- """
1149
- # The component number is irrelevant for counting partitions, so
1150
- # leave it out of the memo key.
1151
- rval = []
1152
- for ps in part:
1153
- rval.append(ps.u)
1154
- rval.append(ps.v)
1155
- return tuple(rval)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/exceptions.py DELETED
@@ -1,271 +0,0 @@
1
- """
2
- General SymPy exceptions and warnings.
3
- """
4
-
5
- import warnings
6
- import contextlib
7
-
8
- from textwrap import dedent
9
-
10
-
11
- class SymPyDeprecationWarning(DeprecationWarning):
12
- r"""
13
- A warning for deprecated features of SymPy.
14
-
15
- See the :ref:`deprecation-policy` document for details on when and how
16
- things should be deprecated in SymPy.
17
-
18
- Note that simply constructing this class will not cause a warning to be
19
- issued. To do that, you must call the :func`sympy_deprecation_warning`
20
- function. For this reason, it is not recommended to ever construct this
21
- class directly.
22
-
23
- Explanation
24
- ===========
25
-
26
- The ``SymPyDeprecationWarning`` class is a subclass of
27
- ``DeprecationWarning`` that is used for all deprecations in SymPy. A
28
- special subclass is used so that we can automatically augment the warning
29
- message with additional metadata about the version the deprecation was
30
- introduced in and a link to the documentation. This also allows users to
31
- explicitly filter deprecation warnings from SymPy using ``warnings``
32
- filters (see :ref:`silencing-sympy-deprecation-warnings`).
33
-
34
- Additionally, ``SymPyDeprecationWarning`` is enabled to be shown by
35
- default, unlike normal ``DeprecationWarning``\s, which are only shown by
36
- default in interactive sessions. This ensures that deprecation warnings in
37
- SymPy will actually be seen by users.
38
-
39
- See the documentation of :func:`sympy_deprecation_warning` for a
40
- description of the parameters to this function.
41
-
42
- To mark a function as deprecated, you can use the :func:`@deprecated
43
- <sympy.utilities.decorator.deprecated>` decorator.
44
-
45
- See Also
46
- ========
47
- sympy.utilities.exceptions.sympy_deprecation_warning
48
- sympy.utilities.exceptions.ignore_warnings
49
- sympy.utilities.decorator.deprecated
50
- sympy.testing.pytest.warns_deprecated_sympy
51
-
52
- """
53
- def __init__(self, message, *, deprecated_since_version, active_deprecations_target):
54
-
55
- super().__init__(message, deprecated_since_version,
56
- active_deprecations_target)
57
- self.message = message
58
- if not isinstance(deprecated_since_version, str):
59
- raise TypeError(f"'deprecated_since_version' should be a string, got {deprecated_since_version!r}")
60
- self.deprecated_since_version = deprecated_since_version
61
- self.active_deprecations_target = active_deprecations_target
62
- if any(i in active_deprecations_target for i in '()='):
63
- raise ValueError("active_deprecations_target be the part inside of the '(...)='")
64
-
65
- self.full_message = f"""
66
-
67
- {dedent(message).strip()}
68
-
69
- See https://docs.sympy.org/latest/explanation/active-deprecations.html#{active_deprecations_target}
70
- for details.
71
-
72
- This has been deprecated since SymPy version {deprecated_since_version}. It
73
- will be removed in a future version of SymPy.
74
- """
75
-
76
- def __str__(self):
77
- return self.full_message
78
-
79
- def __repr__(self):
80
- return f"{self.__class__.__name__}({self.message!r}, deprecated_since_version={self.deprecated_since_version!r}, active_deprecations_target={self.active_deprecations_target!r})"
81
-
82
- def __eq__(self, other):
83
- return isinstance(other, SymPyDeprecationWarning) and self.args == other.args
84
-
85
- # Make pickling work. The by default, it tries to recreate the expression
86
- # from its args, but this doesn't work because of our keyword-only
87
- # arguments.
88
- @classmethod
89
- def _new(cls, message, deprecated_since_version,
90
- active_deprecations_target):
91
- return cls(message, deprecated_since_version=deprecated_since_version, active_deprecations_target=active_deprecations_target)
92
-
93
- def __reduce__(self):
94
- return (self._new, (self.message, self.deprecated_since_version, self.active_deprecations_target))
95
-
96
- # Python by default hides DeprecationWarnings, which we do not want.
97
- warnings.simplefilter("once", SymPyDeprecationWarning)
98
-
99
- def sympy_deprecation_warning(message, *, deprecated_since_version,
100
- active_deprecations_target, stacklevel=3):
101
- r'''
102
- Warn that a feature is deprecated in SymPy.
103
-
104
- See the :ref:`deprecation-policy` document for details on when and how
105
- things should be deprecated in SymPy.
106
-
107
- To mark an entire function or class as deprecated, you can use the
108
- :func:`@deprecated <sympy.utilities.decorator.deprecated>` decorator.
109
-
110
- Parameters
111
- ==========
112
-
113
- message : str
114
- The deprecation message. This may span multiple lines and contain
115
- code examples. Messages should be wrapped to 80 characters. The
116
- message is automatically dedented and leading and trailing whitespace
117
- stripped. Messages may include dynamic content based on the user
118
- input, but avoid using ``str(expression)`` if an expression can be
119
- arbitrary, as it might be huge and make the warning message
120
- unreadable.
121
-
122
- deprecated_since_version : str
123
- The version of SymPy the feature has been deprecated since. For new
124
- deprecations, this should be the version in `sympy/release.py
125
- <https://github.com/sympy/sympy/blob/master/sympy/release.py>`_
126
- without the ``.dev``. If the next SymPy version ends up being
127
- different from this, the release manager will need to update any
128
- ``SymPyDeprecationWarning``\s using the incorrect version. This
129
- argument is required and must be passed as a keyword argument.
130
- (example: ``deprecated_since_version="1.10"``).
131
-
132
- active_deprecations_target : str
133
- The Sphinx target corresponding to the section for the deprecation in
134
- the :ref:`active-deprecations` document (see
135
- ``doc/src/explanation/active-deprecations.md``). This is used to
136
- automatically generate a URL to the page in the warning message. This
137
- argument is required and must be passed as a keyword argument.
138
- (example: ``active_deprecations_target="deprecated-feature-abc"``)
139
-
140
- stacklevel : int, default: 3
141
- The ``stacklevel`` parameter that is passed to ``warnings.warn``. If
142
- you create a wrapper that calls this function, this should be
143
- increased so that the warning message shows the user line of code that
144
- produced the warning. Note that in some cases there will be multiple
145
- possible different user code paths that could result in the warning.
146
- In that case, just choose the smallest common stacklevel.
147
-
148
- Examples
149
- ========
150
-
151
- >>> from sympy.utilities.exceptions import sympy_deprecation_warning
152
- >>> def is_this_zero(x, y=0):
153
- ... """
154
- ... Determine if x = 0.
155
- ...
156
- ... Parameters
157
- ... ==========
158
- ...
159
- ... x : Expr
160
- ... The expression to check.
161
- ...
162
- ... y : Expr, optional
163
- ... If provided, check if x = y.
164
- ...
165
- ... .. deprecated:: 1.1
166
- ...
167
- ... The ``y`` argument to ``is_this_zero`` is deprecated. Use
168
- ... ``is_this_zero(x - y)`` instead.
169
- ...
170
- ... """
171
- ... from sympy import simplify
172
- ...
173
- ... if y != 0:
174
- ... sympy_deprecation_warning("""
175
- ... The y argument to is_zero() is deprecated. Use is_zero(x - y) instead.""",
176
- ... deprecated_since_version="1.1",
177
- ... active_deprecations_target='is-this-zero-y-deprecation')
178
- ... return simplify(x - y) == 0
179
- >>> is_this_zero(0)
180
- True
181
- >>> is_this_zero(1, 1) # doctest: +SKIP
182
- <stdin>:1: SymPyDeprecationWarning:
183
- <BLANKLINE>
184
- The y argument to is_zero() is deprecated. Use is_zero(x - y) instead.
185
- <BLANKLINE>
186
- See https://docs.sympy.org/latest/explanation/active-deprecations.html#is-this-zero-y-deprecation
187
- for details.
188
- <BLANKLINE>
189
- This has been deprecated since SymPy version 1.1. It
190
- will be removed in a future version of SymPy.
191
- <BLANKLINE>
192
- is_this_zero(1, 1)
193
- True
194
-
195
- See Also
196
- ========
197
-
198
- sympy.utilities.exceptions.SymPyDeprecationWarning
199
- sympy.utilities.exceptions.ignore_warnings
200
- sympy.utilities.decorator.deprecated
201
- sympy.testing.pytest.warns_deprecated_sympy
202
-
203
- '''
204
- w = SymPyDeprecationWarning(message,
205
- deprecated_since_version=deprecated_since_version,
206
- active_deprecations_target=active_deprecations_target)
207
- warnings.warn(w, stacklevel=stacklevel)
208
-
209
-
210
- @contextlib.contextmanager
211
- def ignore_warnings(warningcls):
212
- '''
213
- Context manager to suppress warnings during tests.
214
-
215
- .. note::
216
-
217
- Do not use this with SymPyDeprecationWarning in the tests.
218
- warns_deprecated_sympy() should be used instead.
219
-
220
- This function is useful for suppressing warnings during tests. The warns
221
- function should be used to assert that a warning is raised. The
222
- ignore_warnings function is useful in situation when the warning is not
223
- guaranteed to be raised (e.g. on importing a module) or if the warning
224
- comes from third-party code.
225
-
226
- This function is also useful to prevent the same or similar warnings from
227
- being issue twice due to recursive calls.
228
-
229
- When the warning is coming (reliably) from SymPy the warns function should
230
- be preferred to ignore_warnings.
231
-
232
- >>> from sympy.utilities.exceptions import ignore_warnings
233
- >>> import warnings
234
-
235
- Here's a warning:
236
-
237
- >>> with warnings.catch_warnings(): # reset warnings in doctest
238
- ... warnings.simplefilter('error')
239
- ... warnings.warn('deprecated', UserWarning)
240
- Traceback (most recent call last):
241
- ...
242
- UserWarning: deprecated
243
-
244
- Let's suppress it with ignore_warnings:
245
-
246
- >>> with warnings.catch_warnings(): # reset warnings in doctest
247
- ... warnings.simplefilter('error')
248
- ... with ignore_warnings(UserWarning):
249
- ... warnings.warn('deprecated', UserWarning)
250
-
251
- (No warning emitted)
252
-
253
- See Also
254
- ========
255
- sympy.utilities.exceptions.SymPyDeprecationWarning
256
- sympy.utilities.exceptions.sympy_deprecation_warning
257
- sympy.utilities.decorator.deprecated
258
- sympy.testing.pytest.warns_deprecated_sympy
259
-
260
- '''
261
- # Absorbs all warnings in warnrec
262
- with warnings.catch_warnings(record=True) as warnrec:
263
- # Make sure our warning doesn't get filtered
264
- warnings.simplefilter("always", warningcls)
265
- # Now run the test
266
- yield
267
-
268
- # Reissue any warnings that we aren't testing for
269
- for w in warnrec:
270
- if not issubclass(w.category, warningcls):
271
- warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/iterables.py DELETED
@@ -1,3179 +0,0 @@
1
- from collections import Counter, defaultdict, OrderedDict
2
- from itertools import (
3
- chain, combinations, combinations_with_replacement, cycle, islice,
4
- permutations, product, groupby
5
- )
6
- # For backwards compatibility
7
- from itertools import product as cartes # noqa: F401
8
- from operator import gt
9
-
10
-
11
-
12
- # this is the logical location of these functions
13
- from sympy.utilities.enumerative import (
14
- multiset_partitions_taocp, list_visitor, MultisetPartitionTraverser)
15
-
16
- from sympy.utilities.misc import as_int
17
- from sympy.utilities.decorator import deprecated
18
-
19
-
20
- def is_palindromic(s, i=0, j=None):
21
- """
22
- Return True if the sequence is the same from left to right as it
23
- is from right to left in the whole sequence (default) or in the
24
- Python slice ``s[i: j]``; else False.
25
-
26
- Examples
27
- ========
28
-
29
- >>> from sympy.utilities.iterables import is_palindromic
30
- >>> is_palindromic([1, 0, 1])
31
- True
32
- >>> is_palindromic('abcbb')
33
- False
34
- >>> is_palindromic('abcbb', 1)
35
- False
36
-
37
- Normal Python slicing is performed in place so there is no need to
38
- create a slice of the sequence for testing:
39
-
40
- >>> is_palindromic('abcbb', 1, -1)
41
- True
42
- >>> is_palindromic('abcbb', -4, -1)
43
- True
44
-
45
- See Also
46
- ========
47
-
48
- sympy.ntheory.digits.is_palindromic: tests integers
49
-
50
- """
51
- i, j, _ = slice(i, j).indices(len(s))
52
- m = (j - i)//2
53
- # if length is odd, middle element will be ignored
54
- return all(s[i + k] == s[j - 1 - k] for k in range(m))
55
-
56
-
57
- def flatten(iterable, levels=None, cls=None): # noqa: F811
58
- """
59
- Recursively denest iterable containers.
60
-
61
- >>> from sympy import flatten
62
-
63
- >>> flatten([1, 2, 3])
64
- [1, 2, 3]
65
- >>> flatten([1, 2, [3]])
66
- [1, 2, 3]
67
- >>> flatten([1, [2, 3], [4, 5]])
68
- [1, 2, 3, 4, 5]
69
- >>> flatten([1.0, 2, (1, None)])
70
- [1.0, 2, 1, None]
71
-
72
- If you want to denest only a specified number of levels of
73
- nested containers, then set ``levels`` flag to the desired
74
- number of levels::
75
-
76
- >>> ls = [[(-2, -1), (1, 2)], [(0, 0)]]
77
-
78
- >>> flatten(ls, levels=1)
79
- [(-2, -1), (1, 2), (0, 0)]
80
-
81
- If cls argument is specified, it will only flatten instances of that
82
- class, for example:
83
-
84
- >>> from sympy import Basic, S
85
- >>> class MyOp(Basic):
86
- ... pass
87
- ...
88
- >>> flatten([MyOp(S(1), MyOp(S(2), S(3)))], cls=MyOp)
89
- [1, 2, 3]
90
-
91
- adapted from https://kogs-www.informatik.uni-hamburg.de/~meine/python_tricks
92
- """
93
- from sympy.tensor.array import NDimArray
94
- if levels is not None:
95
- if not levels:
96
- return iterable
97
- elif levels > 0:
98
- levels -= 1
99
- else:
100
- raise ValueError(
101
- "expected non-negative number of levels, got %s" % levels)
102
-
103
- if cls is None:
104
- def reducible(x):
105
- return is_sequence(x, set)
106
- else:
107
- def reducible(x):
108
- return isinstance(x, cls)
109
-
110
- result = []
111
-
112
- for el in iterable:
113
- if reducible(el):
114
- if hasattr(el, 'args') and not isinstance(el, NDimArray):
115
- el = el.args
116
- result.extend(flatten(el, levels=levels, cls=cls))
117
- else:
118
- result.append(el)
119
-
120
- return result
121
-
122
-
123
- def unflatten(iter, n=2):
124
- """Group ``iter`` into tuples of length ``n``. Raise an error if
125
- the length of ``iter`` is not a multiple of ``n``.
126
- """
127
- if n < 1 or len(iter) % n:
128
- raise ValueError('iter length is not a multiple of %i' % n)
129
- return list(zip(*(iter[i::n] for i in range(n))))
130
-
131
-
132
- def reshape(seq, how):
133
- """Reshape the sequence according to the template in ``how``.
134
-
135
- Examples
136
- ========
137
-
138
- >>> from sympy.utilities import reshape
139
- >>> seq = list(range(1, 9))
140
-
141
- >>> reshape(seq, [4]) # lists of 4
142
- [[1, 2, 3, 4], [5, 6, 7, 8]]
143
-
144
- >>> reshape(seq, (4,)) # tuples of 4
145
- [(1, 2, 3, 4), (5, 6, 7, 8)]
146
-
147
- >>> reshape(seq, (2, 2)) # tuples of 4
148
- [(1, 2, 3, 4), (5, 6, 7, 8)]
149
-
150
- >>> reshape(seq, (2, [2])) # (i, i, [i, i])
151
- [(1, 2, [3, 4]), (5, 6, [7, 8])]
152
-
153
- >>> reshape(seq, ((2,), [2])) # etc....
154
- [((1, 2), [3, 4]), ((5, 6), [7, 8])]
155
-
156
- >>> reshape(seq, (1, [2], 1))
157
- [(1, [2, 3], 4), (5, [6, 7], 8)]
158
-
159
- >>> reshape(tuple(seq), ([[1], 1, (2,)],))
160
- (([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],))
161
-
162
- >>> reshape(tuple(seq), ([1], 1, (2,)))
163
- (([1], 2, (3, 4)), ([5], 6, (7, 8)))
164
-
165
- >>> reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)])
166
- [[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]]
167
-
168
- """
169
- m = sum(flatten(how))
170
- n, rem = divmod(len(seq), m)
171
- if m < 0 or rem:
172
- raise ValueError('template must sum to positive number '
173
- 'that divides the length of the sequence')
174
- i = 0
175
- container = type(how)
176
- rv = [None]*n
177
- for k in range(len(rv)):
178
- _rv = []
179
- for hi in how:
180
- if isinstance(hi, int):
181
- _rv.extend(seq[i: i + hi])
182
- i += hi
183
- else:
184
- n = sum(flatten(hi))
185
- hi_type = type(hi)
186
- _rv.append(hi_type(reshape(seq[i: i + n], hi)[0]))
187
- i += n
188
- rv[k] = container(_rv)
189
- return type(seq)(rv)
190
-
191
-
192
- def group(seq, multiple=True):
193
- """
194
- Splits a sequence into a list of lists of equal, adjacent elements.
195
-
196
- Examples
197
- ========
198
-
199
- >>> from sympy import group
200
-
201
- >>> group([1, 1, 1, 2, 2, 3])
202
- [[1, 1, 1], [2, 2], [3]]
203
- >>> group([1, 1, 1, 2, 2, 3], multiple=False)
204
- [(1, 3), (2, 2), (3, 1)]
205
- >>> group([1, 1, 3, 2, 2, 1], multiple=False)
206
- [(1, 2), (3, 1), (2, 2), (1, 1)]
207
-
208
- See Also
209
- ========
210
-
211
- multiset
212
-
213
- """
214
- if multiple:
215
- return [(list(g)) for _, g in groupby(seq)]
216
- return [(k, len(list(g))) for k, g in groupby(seq)]
217
-
218
-
219
- def _iproduct2(iterable1, iterable2):
220
- '''Cartesian product of two possibly infinite iterables'''
221
-
222
- it1 = iter(iterable1)
223
- it2 = iter(iterable2)
224
-
225
- elems1 = []
226
- elems2 = []
227
-
228
- sentinel = object()
229
- def append(it, elems):
230
- e = next(it, sentinel)
231
- if e is not sentinel:
232
- elems.append(e)
233
-
234
- n = 0
235
- append(it1, elems1)
236
- append(it2, elems2)
237
-
238
- while n <= len(elems1) + len(elems2):
239
- for m in range(n-len(elems1)+1, len(elems2)):
240
- yield (elems1[n-m], elems2[m])
241
- n += 1
242
- append(it1, elems1)
243
- append(it2, elems2)
244
-
245
-
246
- def iproduct(*iterables):
247
- '''
248
- Cartesian product of iterables.
249
-
250
- Generator of the Cartesian product of iterables. This is analogous to
251
- itertools.product except that it works with infinite iterables and will
252
- yield any item from the infinite product eventually.
253
-
254
- Examples
255
- ========
256
-
257
- >>> from sympy.utilities.iterables import iproduct
258
- >>> sorted(iproduct([1,2], [3,4]))
259
- [(1, 3), (1, 4), (2, 3), (2, 4)]
260
-
261
- With an infinite iterator:
262
-
263
- >>> from sympy import S
264
- >>> (3,) in iproduct(S.Integers)
265
- True
266
- >>> (3, 4) in iproduct(S.Integers, S.Integers)
267
- True
268
-
269
- .. seealso::
270
-
271
- `itertools.product
272
- <https://docs.python.org/3/library/itertools.html#itertools.product>`_
273
- '''
274
- if len(iterables) == 0:
275
- yield ()
276
- return
277
- elif len(iterables) == 1:
278
- for e in iterables[0]:
279
- yield (e,)
280
- elif len(iterables) == 2:
281
- yield from _iproduct2(*iterables)
282
- else:
283
- first, others = iterables[0], iterables[1:]
284
- for ef, eo in _iproduct2(first, iproduct(*others)):
285
- yield (ef,) + eo
286
-
287
-
288
- def multiset(seq):
289
- """Return the hashable sequence in multiset form with values being the
290
- multiplicity of the item in the sequence.
291
-
292
- Examples
293
- ========
294
-
295
- >>> from sympy.utilities.iterables import multiset
296
- >>> multiset('mississippi')
297
- {'i': 4, 'm': 1, 'p': 2, 's': 4}
298
-
299
- See Also
300
- ========
301
-
302
- group
303
-
304
- """
305
- return dict(Counter(seq).items())
306
-
307
-
308
-
309
-
310
- def ibin(n, bits=None, str=False):
311
- """Return a list of length ``bits`` corresponding to the binary value
312
- of ``n`` with small bits to the right (last). If bits is omitted, the
313
- length will be the number required to represent ``n``. If the bits are
314
- desired in reversed order, use the ``[::-1]`` slice of the returned list.
315
-
316
- If a sequence of all bits-length lists starting from ``[0, 0,..., 0]``
317
- through ``[1, 1, ..., 1]`` are desired, pass a non-integer for bits, e.g.
318
- ``'all'``.
319
-
320
- If the bit *string* is desired pass ``str=True``.
321
-
322
- Examples
323
- ========
324
-
325
- >>> from sympy.utilities.iterables import ibin
326
- >>> ibin(2)
327
- [1, 0]
328
- >>> ibin(2, 4)
329
- [0, 0, 1, 0]
330
-
331
- If all lists corresponding to 0 to 2**n - 1, pass a non-integer
332
- for bits:
333
-
334
- >>> bits = 2
335
- >>> for i in ibin(2, 'all'):
336
- ... print(i)
337
- (0, 0)
338
- (0, 1)
339
- (1, 0)
340
- (1, 1)
341
-
342
- If a bit string is desired of a given length, use str=True:
343
-
344
- >>> n = 123
345
- >>> bits = 10
346
- >>> ibin(n, bits, str=True)
347
- '0001111011'
348
- >>> ibin(n, bits, str=True)[::-1] # small bits left
349
- '1101111000'
350
- >>> list(ibin(3, 'all', str=True))
351
- ['000', '001', '010', '011', '100', '101', '110', '111']
352
-
353
- """
354
- if n < 0:
355
- raise ValueError("negative numbers are not allowed")
356
- n = as_int(n)
357
-
358
- if bits is None:
359
- bits = 0
360
- else:
361
- try:
362
- bits = as_int(bits)
363
- except ValueError:
364
- bits = -1
365
- else:
366
- if n.bit_length() > bits:
367
- raise ValueError(
368
- "`bits` must be >= {}".format(n.bit_length()))
369
-
370
- if not str:
371
- if bits >= 0:
372
- return [1 if i == "1" else 0 for i in bin(n)[2:].rjust(bits, "0")]
373
- else:
374
- return variations(range(2), n, repetition=True)
375
- else:
376
- if bits >= 0:
377
- return bin(n)[2:].rjust(bits, "0")
378
- else:
379
- return (bin(i)[2:].rjust(n, "0") for i in range(2**n))
380
-
381
-
382
- def variations(seq, n, repetition=False):
383
- r"""Returns an iterator over the n-sized variations of ``seq`` (size N).
384
- ``repetition`` controls whether items in ``seq`` can appear more than once;
385
-
386
- Examples
387
- ========
388
-
389
- ``variations(seq, n)`` will return `\frac{N!}{(N - n)!}` permutations without
390
- repetition of ``seq``'s elements:
391
-
392
- >>> from sympy import variations
393
- >>> list(variations([1, 2], 2))
394
- [(1, 2), (2, 1)]
395
-
396
- ``variations(seq, n, True)`` will return the `N^n` permutations obtained
397
- by allowing repetition of elements:
398
-
399
- >>> list(variations([1, 2], 2, repetition=True))
400
- [(1, 1), (1, 2), (2, 1), (2, 2)]
401
-
402
- If you ask for more items than are in the set you get the empty set unless
403
- you allow repetitions:
404
-
405
- >>> list(variations([0, 1], 3, repetition=False))
406
- []
407
- >>> list(variations([0, 1], 3, repetition=True))[:4]
408
- [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)]
409
-
410
- .. seealso::
411
-
412
- `itertools.permutations
413
- <https://docs.python.org/3/library/itertools.html#itertools.permutations>`_,
414
- `itertools.product
415
- <https://docs.python.org/3/library/itertools.html#itertools.product>`_
416
- """
417
- if not repetition:
418
- seq = tuple(seq)
419
- if len(seq) < n:
420
- return iter(()) # 0 length iterator
421
- return permutations(seq, n)
422
- else:
423
- if n == 0:
424
- return iter(((),)) # yields 1 empty tuple
425
- else:
426
- return product(seq, repeat=n)
427
-
428
-
429
- def subsets(seq, k=None, repetition=False):
430
- r"""Generates all `k`-subsets (combinations) from an `n`-element set, ``seq``.
431
-
432
- A `k`-subset of an `n`-element set is any subset of length exactly `k`. The
433
- number of `k`-subsets of an `n`-element set is given by ``binomial(n, k)``,
434
- whereas there are `2^n` subsets all together. If `k` is ``None`` then all
435
- `2^n` subsets will be returned from shortest to longest.
436
-
437
- Examples
438
- ========
439
-
440
- >>> from sympy import subsets
441
-
442
- ``subsets(seq, k)`` will return the
443
- `\frac{n!}{k!(n - k)!}` `k`-subsets (combinations)
444
- without repetition, i.e. once an item has been removed, it can no
445
- longer be "taken":
446
-
447
- >>> list(subsets([1, 2], 2))
448
- [(1, 2)]
449
- >>> list(subsets([1, 2]))
450
- [(), (1,), (2,), (1, 2)]
451
- >>> list(subsets([1, 2, 3], 2))
452
- [(1, 2), (1, 3), (2, 3)]
453
-
454
-
455
- ``subsets(seq, k, repetition=True)`` will return the
456
- `\frac{(n - 1 + k)!}{k!(n - 1)!}`
457
- combinations *with* repetition:
458
-
459
- >>> list(subsets([1, 2], 2, repetition=True))
460
- [(1, 1), (1, 2), (2, 2)]
461
-
462
- If you ask for more items than are in the set you get the empty set unless
463
- you allow repetitions:
464
-
465
- >>> list(subsets([0, 1], 3, repetition=False))
466
- []
467
- >>> list(subsets([0, 1], 3, repetition=True))
468
- [(0, 0, 0), (0, 0, 1), (0, 1, 1), (1, 1, 1)]
469
-
470
- """
471
- if k is None:
472
- if not repetition:
473
- return chain.from_iterable((combinations(seq, k)
474
- for k in range(len(seq) + 1)))
475
- else:
476
- return chain.from_iterable((combinations_with_replacement(seq, k)
477
- for k in range(len(seq) + 1)))
478
- else:
479
- if not repetition:
480
- return combinations(seq, k)
481
- else:
482
- return combinations_with_replacement(seq, k)
483
-
484
-
485
- def filter_symbols(iterator, exclude):
486
- """
487
- Only yield elements from `iterator` that do not occur in `exclude`.
488
-
489
- Parameters
490
- ==========
491
-
492
- iterator : iterable
493
- iterator to take elements from
494
-
495
- exclude : iterable
496
- elements to exclude
497
-
498
- Returns
499
- =======
500
-
501
- iterator : iterator
502
- filtered iterator
503
- """
504
- exclude = set(exclude)
505
- for s in iterator:
506
- if s not in exclude:
507
- yield s
508
-
509
- def numbered_symbols(prefix='x', cls=None, start=0, exclude=(), *args, **assumptions):
510
- """
511
- Generate an infinite stream of Symbols consisting of a prefix and
512
- increasing subscripts provided that they do not occur in ``exclude``.
513
-
514
- Parameters
515
- ==========
516
-
517
- prefix : str, optional
518
- The prefix to use. By default, this function will generate symbols of
519
- the form "x0", "x1", etc.
520
-
521
- cls : class, optional
522
- The class to use. By default, it uses ``Symbol``, but you can also use ``Wild``
523
- or ``Dummy``.
524
-
525
- start : int, optional
526
- The start number. By default, it is 0.
527
-
528
- exclude : list, tuple, set of cls, optional
529
- Symbols to be excluded.
530
-
531
- *args, **kwargs
532
- Additional positional and keyword arguments are passed to the *cls* class.
533
-
534
- Returns
535
- =======
536
-
537
- sym : Symbol
538
- The subscripted symbols.
539
- """
540
- exclude = set(exclude or [])
541
- if cls is None:
542
- # We can't just make the default cls=Symbol because it isn't
543
- # imported yet.
544
- from sympy.core import Symbol
545
- cls = Symbol
546
-
547
- while True:
548
- name = '%s%s' % (prefix, start)
549
- s = cls(name, *args, **assumptions)
550
- if s not in exclude:
551
- yield s
552
- start += 1
553
-
554
-
555
- def capture(func):
556
- """Return the printed output of func().
557
-
558
- ``func`` should be a function without arguments that produces output with
559
- print statements.
560
-
561
- >>> from sympy.utilities.iterables import capture
562
- >>> from sympy import pprint
563
- >>> from sympy.abc import x
564
- >>> def foo():
565
- ... print('hello world!')
566
- ...
567
- >>> 'hello' in capture(foo) # foo, not foo()
568
- True
569
- >>> capture(lambda: pprint(2/x))
570
- '2\\n-\\nx\\n'
571
-
572
- """
573
- from io import StringIO
574
- import sys
575
-
576
- stdout = sys.stdout
577
- sys.stdout = file = StringIO()
578
- try:
579
- func()
580
- finally:
581
- sys.stdout = stdout
582
- return file.getvalue()
583
-
584
-
585
- def sift(seq, keyfunc, binary=False):
586
- """
587
- Sift the sequence, ``seq`` according to ``keyfunc``.
588
-
589
- Returns
590
- =======
591
-
592
- When ``binary`` is ``False`` (default), the output is a dictionary
593
- where elements of ``seq`` are stored in a list keyed to the value
594
- of keyfunc for that element. If ``binary`` is True then a tuple
595
- with lists ``T`` and ``F`` are returned where ``T`` is a list
596
- containing elements of seq for which ``keyfunc`` was ``True`` and
597
- ``F`` containing those elements for which ``keyfunc`` was ``False``;
598
- a ValueError is raised if the ``keyfunc`` is not binary.
599
-
600
- Examples
601
- ========
602
-
603
- >>> from sympy.utilities import sift
604
- >>> from sympy.abc import x, y
605
- >>> from sympy import sqrt, exp, pi, Tuple
606
-
607
- >>> sift(range(5), lambda x: x % 2)
608
- {0: [0, 2, 4], 1: [1, 3]}
609
-
610
- sift() returns a defaultdict() object, so any key that has no matches will
611
- give [].
612
-
613
- >>> sift([x], lambda x: x.is_commutative)
614
- {True: [x]}
615
- >>> _[False]
616
- []
617
-
618
- Sometimes you will not know how many keys you will get:
619
-
620
- >>> sift([sqrt(x), exp(x), (y**x)**2],
621
- ... lambda x: x.as_base_exp()[0])
622
- {E: [exp(x)], x: [sqrt(x)], y: [y**(2*x)]}
623
-
624
- Sometimes you expect the results to be binary; the
625
- results can be unpacked by setting ``binary`` to True:
626
-
627
- >>> sift(range(4), lambda x: x % 2, binary=True)
628
- ([1, 3], [0, 2])
629
- >>> sift(Tuple(1, pi), lambda x: x.is_rational, binary=True)
630
- ([1], [pi])
631
-
632
- A ValueError is raised if the predicate was not actually binary
633
- (which is a good test for the logic where sifting is used and
634
- binary results were expected):
635
-
636
- >>> unknown = exp(1) - pi # the rationality of this is unknown
637
- >>> args = Tuple(1, pi, unknown)
638
- >>> sift(args, lambda x: x.is_rational, binary=True)
639
- Traceback (most recent call last):
640
- ...
641
- ValueError: keyfunc gave non-binary output
642
-
643
- The non-binary sifting shows that there were 3 keys generated:
644
-
645
- >>> set(sift(args, lambda x: x.is_rational).keys())
646
- {None, False, True}
647
-
648
- If you need to sort the sifted items it might be better to use
649
- ``ordered`` which can economically apply multiple sort keys
650
- to a sequence while sorting.
651
-
652
- See Also
653
- ========
654
-
655
- ordered
656
-
657
- """
658
- if not binary:
659
- m = defaultdict(list)
660
- for i in seq:
661
- m[keyfunc(i)].append(i)
662
- return m
663
- sift = F, T = [], []
664
- for i in seq:
665
- try:
666
- sift[keyfunc(i)].append(i)
667
- except (IndexError, TypeError):
668
- raise ValueError('keyfunc gave non-binary output')
669
- return T, F
670
-
671
-
672
- def take(iter, n):
673
- """Return ``n`` items from ``iter`` iterator. """
674
- return [ value for _, value in zip(range(n), iter) ]
675
-
676
-
677
- def dict_merge(*dicts):
678
- """Merge dictionaries into a single dictionary. """
679
- merged = {}
680
-
681
- for dict in dicts:
682
- merged.update(dict)
683
-
684
- return merged
685
-
686
-
687
- def common_prefix(*seqs):
688
- """Return the subsequence that is a common start of sequences in ``seqs``.
689
-
690
- >>> from sympy.utilities.iterables import common_prefix
691
- >>> common_prefix(list(range(3)))
692
- [0, 1, 2]
693
- >>> common_prefix(list(range(3)), list(range(4)))
694
- [0, 1, 2]
695
- >>> common_prefix([1, 2, 3], [1, 2, 5])
696
- [1, 2]
697
- >>> common_prefix([1, 2, 3], [1, 3, 5])
698
- [1]
699
- """
700
- if not all(seqs):
701
- return []
702
- elif len(seqs) == 1:
703
- return seqs[0]
704
- i = 0
705
- for i in range(min(len(s) for s in seqs)):
706
- if not all(seqs[j][i] == seqs[0][i] for j in range(len(seqs))):
707
- break
708
- else:
709
- i += 1
710
- return seqs[0][:i]
711
-
712
-
713
- def common_suffix(*seqs):
714
- """Return the subsequence that is a common ending of sequences in ``seqs``.
715
-
716
- >>> from sympy.utilities.iterables import common_suffix
717
- >>> common_suffix(list(range(3)))
718
- [0, 1, 2]
719
- >>> common_suffix(list(range(3)), list(range(4)))
720
- []
721
- >>> common_suffix([1, 2, 3], [9, 2, 3])
722
- [2, 3]
723
- >>> common_suffix([1, 2, 3], [9, 7, 3])
724
- [3]
725
- """
726
-
727
- if not all(seqs):
728
- return []
729
- elif len(seqs) == 1:
730
- return seqs[0]
731
- i = 0
732
- for i in range(-1, -min(len(s) for s in seqs) - 1, -1):
733
- if not all(seqs[j][i] == seqs[0][i] for j in range(len(seqs))):
734
- break
735
- else:
736
- i -= 1
737
- if i == -1:
738
- return []
739
- else:
740
- return seqs[0][i + 1:]
741
-
742
-
743
- def prefixes(seq):
744
- """
745
- Generate all prefixes of a sequence.
746
-
747
- Examples
748
- ========
749
-
750
- >>> from sympy.utilities.iterables import prefixes
751
-
752
- >>> list(prefixes([1,2,3,4]))
753
- [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]
754
-
755
- """
756
- n = len(seq)
757
-
758
- for i in range(n):
759
- yield seq[:i + 1]
760
-
761
-
762
- def postfixes(seq):
763
- """
764
- Generate all postfixes of a sequence.
765
-
766
- Examples
767
- ========
768
-
769
- >>> from sympy.utilities.iterables import postfixes
770
-
771
- >>> list(postfixes([1,2,3,4]))
772
- [[4], [3, 4], [2, 3, 4], [1, 2, 3, 4]]
773
-
774
- """
775
- n = len(seq)
776
-
777
- for i in range(n):
778
- yield seq[n - i - 1:]
779
-
780
-
781
- def topological_sort(graph, key=None):
782
- r"""
783
- Topological sort of graph's vertices.
784
-
785
- Parameters
786
- ==========
787
-
788
- graph : tuple[list, list[tuple[T, T]]
789
- A tuple consisting of a list of vertices and a list of edges of
790
- a graph to be sorted topologically.
791
-
792
- key : callable[T] (optional)
793
- Ordering key for vertices on the same level. By default the natural
794
- (e.g. lexicographic) ordering is used (in this case the base type
795
- must implement ordering relations).
796
-
797
- Examples
798
- ========
799
-
800
- Consider a graph::
801
-
802
- +---+ +---+ +---+
803
- | 7 |\ | 5 | | 3 |
804
- +---+ \ +---+ +---+
805
- | _\___/ ____ _/ |
806
- | / \___/ \ / |
807
- V V V V |
808
- +----+ +---+ |
809
- | 11 | | 8 | |
810
- +----+ +---+ |
811
- | | \____ ___/ _ |
812
- | \ \ / / \ |
813
- V \ V V / V V
814
- +---+ \ +---+ | +----+
815
- | 2 | | | 9 | | | 10 |
816
- +---+ | +---+ | +----+
817
- \________/
818
-
819
- where vertices are integers. This graph can be encoded using
820
- elementary Python's data structures as follows::
821
-
822
- >>> V = [2, 3, 5, 7, 8, 9, 10, 11]
823
- >>> E = [(7, 11), (7, 8), (5, 11), (3, 8), (3, 10),
824
- ... (11, 2), (11, 9), (11, 10), (8, 9)]
825
-
826
- To compute a topological sort for graph ``(V, E)`` issue::
827
-
828
- >>> from sympy.utilities.iterables import topological_sort
829
-
830
- >>> topological_sort((V, E))
831
- [3, 5, 7, 8, 11, 2, 9, 10]
832
-
833
- If specific tie breaking approach is needed, use ``key`` parameter::
834
-
835
- >>> topological_sort((V, E), key=lambda v: -v)
836
- [7, 5, 11, 3, 10, 8, 9, 2]
837
-
838
- Only acyclic graphs can be sorted. If the input graph has a cycle,
839
- then ``ValueError`` will be raised::
840
-
841
- >>> topological_sort((V, E + [(10, 7)]))
842
- Traceback (most recent call last):
843
- ...
844
- ValueError: cycle detected
845
-
846
- References
847
- ==========
848
-
849
- .. [1] https://en.wikipedia.org/wiki/Topological_sorting
850
-
851
- """
852
- V, E = graph
853
-
854
- L = []
855
- S = set(V)
856
- E = list(E)
857
-
858
- S.difference_update(u for v, u in E)
859
-
860
- if key is None:
861
- def key(value):
862
- return value
863
-
864
- S = sorted(S, key=key, reverse=True)
865
-
866
- while S:
867
- node = S.pop()
868
- L.append(node)
869
-
870
- for u, v in list(E):
871
- if u == node:
872
- E.remove((u, v))
873
-
874
- for _u, _v in E:
875
- if v == _v:
876
- break
877
- else:
878
- kv = key(v)
879
-
880
- for i, s in enumerate(S):
881
- ks = key(s)
882
-
883
- if kv > ks:
884
- S.insert(i, v)
885
- break
886
- else:
887
- S.append(v)
888
-
889
- if E:
890
- raise ValueError("cycle detected")
891
- else:
892
- return L
893
-
894
-
895
- def strongly_connected_components(G):
896
- r"""
897
- Strongly connected components of a directed graph in reverse topological
898
- order.
899
-
900
-
901
- Parameters
902
- ==========
903
-
904
- G : tuple[list, list[tuple[T, T]]
905
- A tuple consisting of a list of vertices and a list of edges of
906
- a graph whose strongly connected components are to be found.
907
-
908
-
909
- Examples
910
- ========
911
-
912
- Consider a directed graph (in dot notation)::
913
-
914
- digraph {
915
- A -> B
916
- A -> C
917
- B -> C
918
- C -> B
919
- B -> D
920
- }
921
-
922
- .. graphviz::
923
-
924
- digraph {
925
- A -> B
926
- A -> C
927
- B -> C
928
- C -> B
929
- B -> D
930
- }
931
-
932
- where vertices are the letters A, B, C and D. This graph can be encoded
933
- using Python's elementary data structures as follows::
934
-
935
- >>> V = ['A', 'B', 'C', 'D']
936
- >>> E = [('A', 'B'), ('A', 'C'), ('B', 'C'), ('C', 'B'), ('B', 'D')]
937
-
938
- The strongly connected components of this graph can be computed as
939
-
940
- >>> from sympy.utilities.iterables import strongly_connected_components
941
-
942
- >>> strongly_connected_components((V, E))
943
- [['D'], ['B', 'C'], ['A']]
944
-
945
- This also gives the components in reverse topological order.
946
-
947
- Since the subgraph containing B and C has a cycle they must be together in
948
- a strongly connected component. A and D are connected to the rest of the
949
- graph but not in a cyclic manner so they appear as their own strongly
950
- connected components.
951
-
952
-
953
- Notes
954
- =====
955
-
956
- The vertices of the graph must be hashable for the data structures used.
957
- If the vertices are unhashable replace them with integer indices.
958
-
959
- This function uses Tarjan's algorithm to compute the strongly connected
960
- components in `O(|V|+|E|)` (linear) time.
961
-
962
-
963
- References
964
- ==========
965
-
966
- .. [1] https://en.wikipedia.org/wiki/Strongly_connected_component
967
- .. [2] https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
968
-
969
-
970
- See Also
971
- ========
972
-
973
- sympy.utilities.iterables.connected_components
974
-
975
- """
976
- # Map from a vertex to its neighbours
977
- V, E = G
978
- Gmap = {vi: [] for vi in V}
979
- for v1, v2 in E:
980
- Gmap[v1].append(v2)
981
- return _strongly_connected_components(V, Gmap)
982
-
983
-
984
- def _strongly_connected_components(V, Gmap):
985
- """More efficient internal routine for strongly_connected_components"""
986
- #
987
- # Here V is an iterable of vertices and Gmap is a dict mapping each vertex
988
- # to a list of neighbours e.g.:
989
- #
990
- # V = [0, 1, 2, 3]
991
- # Gmap = {0: [2, 3], 1: [0]}
992
- #
993
- # For a large graph these data structures can often be created more
994
- # efficiently then those expected by strongly_connected_components() which
995
- # in this case would be
996
- #
997
- # V = [0, 1, 2, 3]
998
- # Gmap = [(0, 2), (0, 3), (1, 0)]
999
- #
1000
- # XXX: Maybe this should be the recommended function to use instead...
1001
- #
1002
-
1003
- # Non-recursive Tarjan's algorithm:
1004
- lowlink = {}
1005
- indices = {}
1006
- stack = OrderedDict()
1007
- callstack = []
1008
- components = []
1009
- nomore = object()
1010
-
1011
- def start(v):
1012
- index = len(stack)
1013
- indices[v] = lowlink[v] = index
1014
- stack[v] = None
1015
- callstack.append((v, iter(Gmap[v])))
1016
-
1017
- def finish(v1):
1018
- # Finished a component?
1019
- if lowlink[v1] == indices[v1]:
1020
- component = [stack.popitem()[0]]
1021
- while component[-1] is not v1:
1022
- component.append(stack.popitem()[0])
1023
- components.append(component[::-1])
1024
- v2, _ = callstack.pop()
1025
- if callstack:
1026
- v1, _ = callstack[-1]
1027
- lowlink[v1] = min(lowlink[v1], lowlink[v2])
1028
-
1029
- for v in V:
1030
- if v in indices:
1031
- continue
1032
- start(v)
1033
- while callstack:
1034
- v1, it1 = callstack[-1]
1035
- v2 = next(it1, nomore)
1036
- # Finished children of v1?
1037
- if v2 is nomore:
1038
- finish(v1)
1039
- # Recurse on v2
1040
- elif v2 not in indices:
1041
- start(v2)
1042
- elif v2 in stack:
1043
- lowlink[v1] = min(lowlink[v1], indices[v2])
1044
-
1045
- # Reverse topological sort order:
1046
- return components
1047
-
1048
-
1049
- def connected_components(G):
1050
- r"""
1051
- Connected components of an undirected graph or weakly connected components
1052
- of a directed graph.
1053
-
1054
-
1055
- Parameters
1056
- ==========
1057
-
1058
- G : tuple[list, list[tuple[T, T]]
1059
- A tuple consisting of a list of vertices and a list of edges of
1060
- a graph whose connected components are to be found.
1061
-
1062
-
1063
- Examples
1064
- ========
1065
-
1066
-
1067
- Given an undirected graph::
1068
-
1069
- graph {
1070
- A -- B
1071
- C -- D
1072
- }
1073
-
1074
- .. graphviz::
1075
-
1076
- graph {
1077
- A -- B
1078
- C -- D
1079
- }
1080
-
1081
- We can find the connected components using this function if we include
1082
- each edge in both directions::
1083
-
1084
- >>> from sympy.utilities.iterables import connected_components
1085
-
1086
- >>> V = ['A', 'B', 'C', 'D']
1087
- >>> E = [('A', 'B'), ('B', 'A'), ('C', 'D'), ('D', 'C')]
1088
- >>> connected_components((V, E))
1089
- [['A', 'B'], ['C', 'D']]
1090
-
1091
- The weakly connected components of a directed graph can found the same
1092
- way.
1093
-
1094
-
1095
- Notes
1096
- =====
1097
-
1098
- The vertices of the graph must be hashable for the data structures used.
1099
- If the vertices are unhashable replace them with integer indices.
1100
-
1101
- This function uses Tarjan's algorithm to compute the connected components
1102
- in `O(|V|+|E|)` (linear) time.
1103
-
1104
-
1105
- References
1106
- ==========
1107
-
1108
- .. [1] https://en.wikipedia.org/wiki/Component_%28graph_theory%29
1109
- .. [2] https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
1110
-
1111
-
1112
- See Also
1113
- ========
1114
-
1115
- sympy.utilities.iterables.strongly_connected_components
1116
-
1117
- """
1118
- # Duplicate edges both ways so that the graph is effectively undirected
1119
- # and return the strongly connected components:
1120
- V, E = G
1121
- E_undirected = []
1122
- for v1, v2 in E:
1123
- E_undirected.extend([(v1, v2), (v2, v1)])
1124
- return strongly_connected_components((V, E_undirected))
1125
-
1126
-
1127
- def rotate_left(x, y):
1128
- """
1129
- Left rotates a list x by the number of steps specified
1130
- in y.
1131
-
1132
- Examples
1133
- ========
1134
-
1135
- >>> from sympy.utilities.iterables import rotate_left
1136
- >>> a = [0, 1, 2]
1137
- >>> rotate_left(a, 1)
1138
- [1, 2, 0]
1139
- """
1140
- if len(x) == 0:
1141
- return []
1142
- y = y % len(x)
1143
- return x[y:] + x[:y]
1144
-
1145
-
1146
- def rotate_right(x, y):
1147
- """
1148
- Right rotates a list x by the number of steps specified
1149
- in y.
1150
-
1151
- Examples
1152
- ========
1153
-
1154
- >>> from sympy.utilities.iterables import rotate_right
1155
- >>> a = [0, 1, 2]
1156
- >>> rotate_right(a, 1)
1157
- [2, 0, 1]
1158
- """
1159
- if len(x) == 0:
1160
- return []
1161
- y = len(x) - y % len(x)
1162
- return x[y:] + x[:y]
1163
-
1164
-
1165
- def least_rotation(x, key=None):
1166
- '''
1167
- Returns the number of steps of left rotation required to
1168
- obtain lexicographically minimal string/list/tuple, etc.
1169
-
1170
- Examples
1171
- ========
1172
-
1173
- >>> from sympy.utilities.iterables import least_rotation, rotate_left
1174
- >>> a = [3, 1, 5, 1, 2]
1175
- >>> least_rotation(a)
1176
- 3
1177
- >>> rotate_left(a, _)
1178
- [1, 2, 3, 1, 5]
1179
-
1180
- References
1181
- ==========
1182
-
1183
- .. [1] https://en.wikipedia.org/wiki/Lexicographically_minimal_string_rotation
1184
-
1185
- '''
1186
- from sympy.functions.elementary.miscellaneous import Id
1187
- if key is None: key = Id
1188
- S = x + x # Concatenate string to it self to avoid modular arithmetic
1189
- f = [-1] * len(S) # Failure function
1190
- k = 0 # Least rotation of string found so far
1191
- for j in range(1,len(S)):
1192
- sj = S[j]
1193
- i = f[j-k-1]
1194
- while i != -1 and sj != S[k+i+1]:
1195
- if key(sj) < key(S[k+i+1]):
1196
- k = j-i-1
1197
- i = f[i]
1198
- if sj != S[k+i+1]:
1199
- if key(sj) < key(S[k]):
1200
- k = j
1201
- f[j-k] = -1
1202
- else:
1203
- f[j-k] = i+1
1204
- return k
1205
-
1206
-
1207
- def multiset_combinations(m, n, g=None):
1208
- """
1209
- Return the unique combinations of size ``n`` from multiset ``m``.
1210
-
1211
- Examples
1212
- ========
1213
-
1214
- >>> from sympy.utilities.iterables import multiset_combinations
1215
- >>> from itertools import combinations
1216
- >>> [''.join(i) for i in multiset_combinations('baby', 3)]
1217
- ['abb', 'aby', 'bby']
1218
-
1219
- >>> def count(f, s): return len(list(f(s, 3)))
1220
-
1221
- The number of combinations depends on the number of letters; the
1222
- number of unique combinations depends on how the letters are
1223
- repeated.
1224
-
1225
- >>> s1 = 'abracadabra'
1226
- >>> s2 = 'banana tree'
1227
- >>> count(combinations, s1), count(multiset_combinations, s1)
1228
- (165, 23)
1229
- >>> count(combinations, s2), count(multiset_combinations, s2)
1230
- (165, 54)
1231
-
1232
- """
1233
- from sympy.core.sorting import ordered
1234
- if g is None:
1235
- if isinstance(m, dict):
1236
- if any(as_int(v) < 0 for v in m.values()):
1237
- raise ValueError('counts cannot be negative')
1238
- N = sum(m.values())
1239
- if n > N:
1240
- return
1241
- g = [[k, m[k]] for k in ordered(m)]
1242
- else:
1243
- m = list(m)
1244
- N = len(m)
1245
- if n > N:
1246
- return
1247
- try:
1248
- m = multiset(m)
1249
- g = [(k, m[k]) for k in ordered(m)]
1250
- except TypeError:
1251
- m = list(ordered(m))
1252
- g = [list(i) for i in group(m, multiple=False)]
1253
- del m
1254
- else:
1255
- # not checking counts since g is intended for internal use
1256
- N = sum(v for k, v in g)
1257
- if n > N or not n:
1258
- yield []
1259
- else:
1260
- for i, (k, v) in enumerate(g):
1261
- if v >= n:
1262
- yield [k]*n
1263
- v = n - 1
1264
- for v in range(min(n, v), 0, -1):
1265
- for j in multiset_combinations(None, n - v, g[i + 1:]):
1266
- rv = [k]*v + j
1267
- if len(rv) == n:
1268
- yield rv
1269
-
1270
- def multiset_permutations(m, size=None, g=None):
1271
- """
1272
- Return the unique permutations of multiset ``m``.
1273
-
1274
- Examples
1275
- ========
1276
-
1277
- >>> from sympy.utilities.iterables import multiset_permutations
1278
- >>> from sympy import factorial
1279
- >>> [''.join(i) for i in multiset_permutations('aab')]
1280
- ['aab', 'aba', 'baa']
1281
- >>> factorial(len('banana'))
1282
- 720
1283
- >>> len(list(multiset_permutations('banana')))
1284
- 60
1285
- """
1286
- from sympy.core.sorting import ordered
1287
- if g is None:
1288
- if isinstance(m, dict):
1289
- if any(as_int(v) < 0 for v in m.values()):
1290
- raise ValueError('counts cannot be negative')
1291
- g = [[k, m[k]] for k in ordered(m)]
1292
- else:
1293
- m = list(ordered(m))
1294
- g = [list(i) for i in group(m, multiple=False)]
1295
- del m
1296
- do = [gi for gi in g if gi[1] > 0]
1297
- SUM = sum(gi[1] for gi in do)
1298
- if not do or size is not None and (size > SUM or size < 1):
1299
- if not do and size is None or size == 0:
1300
- yield []
1301
- return
1302
- elif size == 1:
1303
- for k, v in do:
1304
- yield [k]
1305
- elif len(do) == 1:
1306
- k, v = do[0]
1307
- v = v if size is None else (size if size <= v else 0)
1308
- yield [k for i in range(v)]
1309
- elif all(v == 1 for k, v in do):
1310
- for p in permutations([k for k, v in do], size):
1311
- yield list(p)
1312
- else:
1313
- size = size if size is not None else SUM
1314
- for i, (k, v) in enumerate(do):
1315
- do[i][1] -= 1
1316
- for j in multiset_permutations(None, size - 1, do):
1317
- if j:
1318
- yield [k] + j
1319
- do[i][1] += 1
1320
-
1321
-
1322
- def _partition(seq, vector, m=None):
1323
- """
1324
- Return the partition of seq as specified by the partition vector.
1325
-
1326
- Examples
1327
- ========
1328
-
1329
- >>> from sympy.utilities.iterables import _partition
1330
- >>> _partition('abcde', [1, 0, 1, 2, 0])
1331
- [['b', 'e'], ['a', 'c'], ['d']]
1332
-
1333
- Specifying the number of bins in the partition is optional:
1334
-
1335
- >>> _partition('abcde', [1, 0, 1, 2, 0], 3)
1336
- [['b', 'e'], ['a', 'c'], ['d']]
1337
-
1338
- The output of _set_partitions can be passed as follows:
1339
-
1340
- >>> output = (3, [1, 0, 1, 2, 0])
1341
- >>> _partition('abcde', *output)
1342
- [['b', 'e'], ['a', 'c'], ['d']]
1343
-
1344
- See Also
1345
- ========
1346
-
1347
- combinatorics.partitions.Partition.from_rgs
1348
-
1349
- """
1350
- if m is None:
1351
- m = max(vector) + 1
1352
- elif isinstance(vector, int): # entered as m, vector
1353
- vector, m = m, vector
1354
- p = [[] for i in range(m)]
1355
- for i, v in enumerate(vector):
1356
- p[v].append(seq[i])
1357
- return p
1358
-
1359
-
1360
- def _set_partitions(n):
1361
- """Cycle through all partitions of n elements, yielding the
1362
- current number of partitions, ``m``, and a mutable list, ``q``
1363
- such that ``element[i]`` is in part ``q[i]`` of the partition.
1364
-
1365
- NOTE: ``q`` is modified in place and generally should not be changed
1366
- between function calls.
1367
-
1368
- Examples
1369
- ========
1370
-
1371
- >>> from sympy.utilities.iterables import _set_partitions, _partition
1372
- >>> for m, q in _set_partitions(3):
1373
- ... print('%s %s %s' % (m, q, _partition('abc', q, m)))
1374
- 1 [0, 0, 0] [['a', 'b', 'c']]
1375
- 2 [0, 0, 1] [['a', 'b'], ['c']]
1376
- 2 [0, 1, 0] [['a', 'c'], ['b']]
1377
- 2 [0, 1, 1] [['a'], ['b', 'c']]
1378
- 3 [0, 1, 2] [['a'], ['b'], ['c']]
1379
-
1380
- Notes
1381
- =====
1382
-
1383
- This algorithm is similar to, and solves the same problem as,
1384
- Algorithm 7.2.1.5H, from volume 4A of Knuth's The Art of Computer
1385
- Programming. Knuth uses the term "restricted growth string" where
1386
- this code refers to a "partition vector". In each case, the meaning is
1387
- the same: the value in the ith element of the vector specifies to
1388
- which part the ith set element is to be assigned.
1389
-
1390
- At the lowest level, this code implements an n-digit big-endian
1391
- counter (stored in the array q) which is incremented (with carries) to
1392
- get the next partition in the sequence. A special twist is that a
1393
- digit is constrained to be at most one greater than the maximum of all
1394
- the digits to the left of it. The array p maintains this maximum, so
1395
- that the code can efficiently decide when a digit can be incremented
1396
- in place or whether it needs to be reset to 0 and trigger a carry to
1397
- the next digit. The enumeration starts with all the digits 0 (which
1398
- corresponds to all the set elements being assigned to the same 0th
1399
- part), and ends with 0123...n, which corresponds to each set element
1400
- being assigned to a different, singleton, part.
1401
-
1402
- This routine was rewritten to use 0-based lists while trying to
1403
- preserve the beauty and efficiency of the original algorithm.
1404
-
1405
- References
1406
- ==========
1407
-
1408
- .. [1] Nijenhuis, Albert and Wilf, Herbert. (1978) Combinatorial Algorithms,
1409
- 2nd Ed, p 91, algorithm "nexequ". Available online from
1410
- https://www.math.upenn.edu/~wilf/website/CombAlgDownld.html (viewed
1411
- November 17, 2012).
1412
-
1413
- """
1414
- p = [0]*n
1415
- q = [0]*n
1416
- nc = 1
1417
- yield nc, q
1418
- while nc != n:
1419
- m = n
1420
- while 1:
1421
- m -= 1
1422
- i = q[m]
1423
- if p[i] != 1:
1424
- break
1425
- q[m] = 0
1426
- i += 1
1427
- q[m] = i
1428
- m += 1
1429
- nc += m - n
1430
- p[0] += n - m
1431
- if i == nc:
1432
- p[nc] = 0
1433
- nc += 1
1434
- p[i - 1] -= 1
1435
- p[i] += 1
1436
- yield nc, q
1437
-
1438
-
1439
- def multiset_partitions(multiset, m=None):
1440
- """
1441
- Return unique partitions of the given multiset (in list form).
1442
- If ``m`` is None, all multisets will be returned, otherwise only
1443
- partitions with ``m`` parts will be returned.
1444
-
1445
- If ``multiset`` is an integer, a range [0, 1, ..., multiset - 1]
1446
- will be supplied.
1447
-
1448
- Examples
1449
- ========
1450
-
1451
- >>> from sympy.utilities.iterables import multiset_partitions
1452
- >>> list(multiset_partitions([1, 2, 3, 4], 2))
1453
- [[[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]],
1454
- [[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]],
1455
- [[1], [2, 3, 4]]]
1456
- >>> list(multiset_partitions([1, 2, 3, 4], 1))
1457
- [[[1, 2, 3, 4]]]
1458
-
1459
- Only unique partitions are returned and these will be returned in a
1460
- canonical order regardless of the order of the input:
1461
-
1462
- >>> a = [1, 2, 2, 1]
1463
- >>> ans = list(multiset_partitions(a, 2))
1464
- >>> a.sort()
1465
- >>> list(multiset_partitions(a, 2)) == ans
1466
- True
1467
- >>> a = range(3, 1, -1)
1468
- >>> (list(multiset_partitions(a)) ==
1469
- ... list(multiset_partitions(sorted(a))))
1470
- True
1471
-
1472
- If m is omitted then all partitions will be returned:
1473
-
1474
- >>> list(multiset_partitions([1, 1, 2]))
1475
- [[[1, 1, 2]], [[1, 1], [2]], [[1, 2], [1]], [[1], [1], [2]]]
1476
- >>> list(multiset_partitions([1]*3))
1477
- [[[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]]
1478
-
1479
- Counting
1480
- ========
1481
-
1482
- The number of partitions of a set is given by the bell number:
1483
-
1484
- >>> from sympy import bell
1485
- >>> len(list(multiset_partitions(5))) == bell(5) == 52
1486
- True
1487
-
1488
- The number of partitions of length k from a set of size n is given by the
1489
- Stirling Number of the 2nd kind:
1490
-
1491
- >>> from sympy.functions.combinatorial.numbers import stirling
1492
- >>> stirling(5, 2) == len(list(multiset_partitions(5, 2))) == 15
1493
- True
1494
-
1495
- These comments on counting apply to *sets*, not multisets.
1496
-
1497
- Notes
1498
- =====
1499
-
1500
- When all the elements are the same in the multiset, the order
1501
- of the returned partitions is determined by the ``partitions``
1502
- routine. If one is counting partitions then it is better to use
1503
- the ``nT`` function.
1504
-
1505
- See Also
1506
- ========
1507
-
1508
- partitions
1509
- sympy.combinatorics.partitions.Partition
1510
- sympy.combinatorics.partitions.IntegerPartition
1511
- sympy.functions.combinatorial.numbers.nT
1512
-
1513
- """
1514
- # This function looks at the supplied input and dispatches to
1515
- # several special-case routines as they apply.
1516
- if isinstance(multiset, int):
1517
- n = multiset
1518
- if m and m > n:
1519
- return
1520
- multiset = list(range(n))
1521
- if m == 1:
1522
- yield [multiset[:]]
1523
- return
1524
-
1525
- # If m is not None, it can sometimes be faster to use
1526
- # MultisetPartitionTraverser.enum_range() even for inputs
1527
- # which are sets. Since the _set_partitions code is quite
1528
- # fast, this is only advantageous when the overall set
1529
- # partitions outnumber those with the desired number of parts
1530
- # by a large factor. (At least 60.) Such a switch is not
1531
- # currently implemented.
1532
- for nc, q in _set_partitions(n):
1533
- if m is None or nc == m:
1534
- rv = [[] for i in range(nc)]
1535
- for i in range(n):
1536
- rv[q[i]].append(multiset[i])
1537
- yield rv
1538
- return
1539
-
1540
- if len(multiset) == 1 and isinstance(multiset, str):
1541
- multiset = [multiset]
1542
-
1543
- if not has_variety(multiset):
1544
- # Only one component, repeated n times. The resulting
1545
- # partitions correspond to partitions of integer n.
1546
- n = len(multiset)
1547
- if m and m > n:
1548
- return
1549
- if m == 1:
1550
- yield [multiset[:]]
1551
- return
1552
- x = multiset[:1]
1553
- for size, p in partitions(n, m, size=True):
1554
- if m is None or size == m:
1555
- rv = []
1556
- for k in sorted(p):
1557
- rv.extend([x*k]*p[k])
1558
- yield rv
1559
- else:
1560
- from sympy.core.sorting import ordered
1561
- multiset = list(ordered(multiset))
1562
- n = len(multiset)
1563
- if m and m > n:
1564
- return
1565
- if m == 1:
1566
- yield [multiset[:]]
1567
- return
1568
-
1569
- # Split the information of the multiset into two lists -
1570
- # one of the elements themselves, and one (of the same length)
1571
- # giving the number of repeats for the corresponding element.
1572
- elements, multiplicities = zip(*group(multiset, False))
1573
-
1574
- if len(elements) < len(multiset):
1575
- # General case - multiset with more than one distinct element
1576
- # and at least one element repeated more than once.
1577
- if m:
1578
- mpt = MultisetPartitionTraverser()
1579
- for state in mpt.enum_range(multiplicities, m-1, m):
1580
- yield list_visitor(state, elements)
1581
- else:
1582
- for state in multiset_partitions_taocp(multiplicities):
1583
- yield list_visitor(state, elements)
1584
- else:
1585
- # Set partitions case - no repeated elements. Pretty much
1586
- # same as int argument case above, with same possible, but
1587
- # currently unimplemented optimization for some cases when
1588
- # m is not None
1589
- for nc, q in _set_partitions(n):
1590
- if m is None or nc == m:
1591
- rv = [[] for i in range(nc)]
1592
- for i in range(n):
1593
- rv[q[i]].append(i)
1594
- yield [[multiset[j] for j in i] for i in rv]
1595
-
1596
-
1597
- def partitions(n, m=None, k=None, size=False):
1598
- """Generate all partitions of positive integer, n.
1599
-
1600
- Each partition is represented as a dictionary, mapping an integer
1601
- to the number of copies of that integer in the partition. For example,
1602
- the first partition of 4 returned is {4: 1}, "4: one of them".
1603
-
1604
- Parameters
1605
- ==========
1606
- n : int
1607
- m : int, optional
1608
- limits number of parts in partition (mnemonic: m, maximum parts)
1609
- k : int, optional
1610
- limits the numbers that are kept in the partition (mnemonic: k, keys)
1611
- size : bool, default: False
1612
- If ``True``, (M, P) is returned where M is the sum of the
1613
- multiplicities and P is the generated partition.
1614
- If ``False``, only the generated partition is returned.
1615
-
1616
- Examples
1617
- ========
1618
-
1619
- >>> from sympy.utilities.iterables import partitions
1620
-
1621
- The numbers appearing in the partition (the key of the returned dict)
1622
- are limited with k:
1623
-
1624
- >>> for p in partitions(6, k=2): # doctest: +SKIP
1625
- ... print(p)
1626
- {2: 3}
1627
- {1: 2, 2: 2}
1628
- {1: 4, 2: 1}
1629
- {1: 6}
1630
-
1631
- The maximum number of parts in the partition (the sum of the values in
1632
- the returned dict) are limited with m (default value, None, gives
1633
- partitions from 1 through n):
1634
-
1635
- >>> for p in partitions(6, m=2): # doctest: +SKIP
1636
- ... print(p)
1637
- ...
1638
- {6: 1}
1639
- {1: 1, 5: 1}
1640
- {2: 1, 4: 1}
1641
- {3: 2}
1642
-
1643
- References
1644
- ==========
1645
-
1646
- .. [1] modified from Tim Peter's version to allow for k and m values:
1647
- https://code.activestate.com/recipes/218332-generator-for-integer-partitions/
1648
-
1649
- See Also
1650
- ========
1651
-
1652
- sympy.combinatorics.partitions.Partition
1653
- sympy.combinatorics.partitions.IntegerPartition
1654
-
1655
- """
1656
- if (n <= 0 or
1657
- m is not None and m < 1 or
1658
- k is not None and k < 1 or
1659
- m and k and m*k < n):
1660
- # the empty set is the only way to handle these inputs
1661
- # and returning {} to represent it is consistent with
1662
- # the counting convention, e.g. nT(0) == 1.
1663
- if size:
1664
- yield 0, {}
1665
- else:
1666
- yield {}
1667
- return
1668
-
1669
- if m is None:
1670
- m = n
1671
- else:
1672
- m = min(m, n)
1673
- k = min(k or n, n)
1674
-
1675
- n, m, k = as_int(n), as_int(m), as_int(k)
1676
- q, r = divmod(n, k)
1677
- ms = {k: q}
1678
- keys = [k] # ms.keys(), from largest to smallest
1679
- if r:
1680
- ms[r] = 1
1681
- keys.append(r)
1682
- room = m - q - bool(r)
1683
- if size:
1684
- yield sum(ms.values()), ms.copy()
1685
- else:
1686
- yield ms.copy()
1687
-
1688
- while keys != [1]:
1689
- # Reuse any 1's.
1690
- if keys[-1] == 1:
1691
- del keys[-1]
1692
- reuse = ms.pop(1)
1693
- room += reuse
1694
- else:
1695
- reuse = 0
1696
-
1697
- while 1:
1698
- # Let i be the smallest key larger than 1. Reuse one
1699
- # instance of i.
1700
- i = keys[-1]
1701
- newcount = ms[i] = ms[i] - 1
1702
- reuse += i
1703
- if newcount == 0:
1704
- del keys[-1], ms[i]
1705
- room += 1
1706
-
1707
- # Break the remainder into pieces of size i-1.
1708
- i -= 1
1709
- q, r = divmod(reuse, i)
1710
- need = q + bool(r)
1711
- if need > room:
1712
- if not keys:
1713
- return
1714
- continue
1715
-
1716
- ms[i] = q
1717
- keys.append(i)
1718
- if r:
1719
- ms[r] = 1
1720
- keys.append(r)
1721
- break
1722
- room -= need
1723
- if size:
1724
- yield sum(ms.values()), ms.copy()
1725
- else:
1726
- yield ms.copy()
1727
-
1728
-
1729
- def ordered_partitions(n, m=None, sort=True):
1730
- """Generates ordered partitions of integer *n*.
1731
-
1732
- Parameters
1733
- ==========
1734
- n : int
1735
- m : int, optional
1736
- The default value gives partitions of all sizes else only
1737
- those with size m. In addition, if *m* is not None then
1738
- partitions are generated *in place* (see examples).
1739
- sort : bool, default: True
1740
- Controls whether partitions are
1741
- returned in sorted order when *m* is not None; when False,
1742
- the partitions are returned as fast as possible with elements
1743
- sorted, but when m|n the partitions will not be in
1744
- ascending lexicographical order.
1745
-
1746
- Examples
1747
- ========
1748
-
1749
- >>> from sympy.utilities.iterables import ordered_partitions
1750
-
1751
- All partitions of 5 in ascending lexicographical:
1752
-
1753
- >>> for p in ordered_partitions(5):
1754
- ... print(p)
1755
- [1, 1, 1, 1, 1]
1756
- [1, 1, 1, 2]
1757
- [1, 1, 3]
1758
- [1, 2, 2]
1759
- [1, 4]
1760
- [2, 3]
1761
- [5]
1762
-
1763
- Only partitions of 5 with two parts:
1764
-
1765
- >>> for p in ordered_partitions(5, 2):
1766
- ... print(p)
1767
- [1, 4]
1768
- [2, 3]
1769
-
1770
- When ``m`` is given, a given list objects will be used more than
1771
- once for speed reasons so you will not see the correct partitions
1772
- unless you make a copy of each as it is generated:
1773
-
1774
- >>> [p for p in ordered_partitions(7, 3)]
1775
- [[1, 1, 1], [1, 1, 1], [1, 1, 1], [2, 2, 2]]
1776
- >>> [list(p) for p in ordered_partitions(7, 3)]
1777
- [[1, 1, 5], [1, 2, 4], [1, 3, 3], [2, 2, 3]]
1778
-
1779
- When ``n`` is a multiple of ``m``, the elements are still sorted
1780
- but the partitions themselves will be *unordered* if sort is False;
1781
- the default is to return them in ascending lexicographical order.
1782
-
1783
- >>> for p in ordered_partitions(6, 2):
1784
- ... print(p)
1785
- [1, 5]
1786
- [2, 4]
1787
- [3, 3]
1788
-
1789
- But if speed is more important than ordering, sort can be set to
1790
- False:
1791
-
1792
- >>> for p in ordered_partitions(6, 2, sort=False):
1793
- ... print(p)
1794
- [1, 5]
1795
- [3, 3]
1796
- [2, 4]
1797
-
1798
- References
1799
- ==========
1800
-
1801
- .. [1] Generating Integer Partitions, [online],
1802
- Available: https://jeromekelleher.net/generating-integer-partitions.html
1803
- .. [2] Jerome Kelleher and Barry O'Sullivan, "Generating All
1804
- Partitions: A Comparison Of Two Encodings", [online],
1805
- Available: https://arxiv.org/pdf/0909.2331v2.pdf
1806
- """
1807
- if n < 1 or m is not None and m < 1:
1808
- # the empty set is the only way to handle these inputs
1809
- # and returning {} to represent it is consistent with
1810
- # the counting convention, e.g. nT(0) == 1.
1811
- yield []
1812
- return
1813
-
1814
- if m is None:
1815
- # The list `a`'s leading elements contain the partition in which
1816
- # y is the biggest element and x is either the same as y or the
1817
- # 2nd largest element; v and w are adjacent element indices
1818
- # to which x and y are being assigned, respectively.
1819
- a = [1]*n
1820
- y = -1
1821
- v = n
1822
- while v > 0:
1823
- v -= 1
1824
- x = a[v] + 1
1825
- while y >= 2 * x:
1826
- a[v] = x
1827
- y -= x
1828
- v += 1
1829
- w = v + 1
1830
- while x <= y:
1831
- a[v] = x
1832
- a[w] = y
1833
- yield a[:w + 1]
1834
- x += 1
1835
- y -= 1
1836
- a[v] = x + y
1837
- y = a[v] - 1
1838
- yield a[:w]
1839
- elif m == 1:
1840
- yield [n]
1841
- elif n == m:
1842
- yield [1]*n
1843
- else:
1844
- # recursively generate partitions of size m
1845
- for b in range(1, n//m + 1):
1846
- a = [b]*m
1847
- x = n - b*m
1848
- if not x:
1849
- if sort:
1850
- yield a
1851
- elif not sort and x <= m:
1852
- for ax in ordered_partitions(x, sort=False):
1853
- mi = len(ax)
1854
- a[-mi:] = [i + b for i in ax]
1855
- yield a
1856
- a[-mi:] = [b]*mi
1857
- else:
1858
- for mi in range(1, m):
1859
- for ax in ordered_partitions(x, mi, sort=True):
1860
- a[-mi:] = [i + b for i in ax]
1861
- yield a
1862
- a[-mi:] = [b]*mi
1863
-
1864
-
1865
- def binary_partitions(n):
1866
- """
1867
- Generates the binary partition of *n*.
1868
-
1869
- A binary partition consists only of numbers that are
1870
- powers of two. Each step reduces a `2^{k+1}` to `2^k` and
1871
- `2^k`. Thus 16 is converted to 8 and 8.
1872
-
1873
- Examples
1874
- ========
1875
-
1876
- >>> from sympy.utilities.iterables import binary_partitions
1877
- >>> for i in binary_partitions(5):
1878
- ... print(i)
1879
- ...
1880
- [4, 1]
1881
- [2, 2, 1]
1882
- [2, 1, 1, 1]
1883
- [1, 1, 1, 1, 1]
1884
-
1885
- References
1886
- ==========
1887
-
1888
- .. [1] TAOCP 4, section 7.2.1.5, problem 64
1889
-
1890
- """
1891
- from math import ceil, log2
1892
- power = int(2**(ceil(log2(n))))
1893
- acc = 0
1894
- partition = []
1895
- while power:
1896
- if acc + power <= n:
1897
- partition.append(power)
1898
- acc += power
1899
- power >>= 1
1900
-
1901
- last_num = len(partition) - 1 - (n & 1)
1902
- while last_num >= 0:
1903
- yield partition
1904
- if partition[last_num] == 2:
1905
- partition[last_num] = 1
1906
- partition.append(1)
1907
- last_num -= 1
1908
- continue
1909
- partition.append(1)
1910
- partition[last_num] >>= 1
1911
- x = partition[last_num + 1] = partition[last_num]
1912
- last_num += 1
1913
- while x > 1:
1914
- if x <= len(partition) - last_num - 1:
1915
- del partition[-x + 1:]
1916
- last_num += 1
1917
- partition[last_num] = x
1918
- else:
1919
- x >>= 1
1920
- yield [1]*n
1921
-
1922
-
1923
- def has_dups(seq):
1924
- """Return True if there are any duplicate elements in ``seq``.
1925
-
1926
- Examples
1927
- ========
1928
-
1929
- >>> from sympy import has_dups, Dict, Set
1930
- >>> has_dups((1, 2, 1))
1931
- True
1932
- >>> has_dups(range(3))
1933
- False
1934
- >>> all(has_dups(c) is False for c in (set(), Set(), dict(), Dict()))
1935
- True
1936
- """
1937
- from sympy.core.containers import Dict
1938
- from sympy.sets.sets import Set
1939
- if isinstance(seq, (dict, set, Dict, Set)):
1940
- return False
1941
- unique = set()
1942
- try:
1943
- return any(True for s in seq if s in unique or unique.add(s))
1944
- except TypeError:
1945
- return len(seq) != len(list(uniq(seq)))
1946
-
1947
-
1948
- def has_variety(seq):
1949
- """Return True if there are any different elements in ``seq``.
1950
-
1951
- Examples
1952
- ========
1953
-
1954
- >>> from sympy import has_variety
1955
-
1956
- >>> has_variety((1, 2, 1))
1957
- True
1958
- >>> has_variety((1, 1, 1))
1959
- False
1960
- """
1961
- for i, s in enumerate(seq):
1962
- if i == 0:
1963
- sentinel = s
1964
- else:
1965
- if s != sentinel:
1966
- return True
1967
- return False
1968
-
1969
-
1970
- def uniq(seq, result=None):
1971
- """
1972
- Yield unique elements from ``seq`` as an iterator. The second
1973
- parameter ``result`` is used internally; it is not necessary
1974
- to pass anything for this.
1975
-
1976
- Note: changing the sequence during iteration will raise a
1977
- RuntimeError if the size of the sequence is known; if you pass
1978
- an iterator and advance the iterator you will change the
1979
- output of this routine but there will be no warning.
1980
-
1981
- Examples
1982
- ========
1983
-
1984
- >>> from sympy.utilities.iterables import uniq
1985
- >>> dat = [1, 4, 1, 5, 4, 2, 1, 2]
1986
- >>> type(uniq(dat)) in (list, tuple)
1987
- False
1988
-
1989
- >>> list(uniq(dat))
1990
- [1, 4, 5, 2]
1991
- >>> list(uniq(x for x in dat))
1992
- [1, 4, 5, 2]
1993
- >>> list(uniq([[1], [2, 1], [1]]))
1994
- [[1], [2, 1]]
1995
- """
1996
- try:
1997
- n = len(seq)
1998
- except TypeError:
1999
- n = None
2000
- def check():
2001
- # check that size of seq did not change during iteration;
2002
- # if n == None the object won't support size changing, e.g.
2003
- # an iterator can't be changed
2004
- if n is not None and len(seq) != n:
2005
- raise RuntimeError('sequence changed size during iteration')
2006
- try:
2007
- seen = set()
2008
- result = result or []
2009
- for i, s in enumerate(seq):
2010
- if not (s in seen or seen.add(s)):
2011
- yield s
2012
- check()
2013
- except TypeError:
2014
- if s not in result:
2015
- yield s
2016
- check()
2017
- result.append(s)
2018
- if hasattr(seq, '__getitem__'):
2019
- yield from uniq(seq[i + 1:], result)
2020
- else:
2021
- yield from uniq(seq, result)
2022
-
2023
-
2024
- def generate_bell(n):
2025
- """Return permutations of [0, 1, ..., n - 1] such that each permutation
2026
- differs from the last by the exchange of a single pair of neighbors.
2027
- The ``n!`` permutations are returned as an iterator. In order to obtain
2028
- the next permutation from a random starting permutation, use the
2029
- ``next_trotterjohnson`` method of the Permutation class (which generates
2030
- the same sequence in a different manner).
2031
-
2032
- Examples
2033
- ========
2034
-
2035
- >>> from itertools import permutations
2036
- >>> from sympy.utilities.iterables import generate_bell
2037
- >>> from sympy import zeros, Matrix
2038
-
2039
- This is the sort of permutation used in the ringing of physical bells,
2040
- and does not produce permutations in lexicographical order. Rather, the
2041
- permutations differ from each other by exactly one inversion, and the
2042
- position at which the swapping occurs varies periodically in a simple
2043
- fashion. Consider the first few permutations of 4 elements generated
2044
- by ``permutations`` and ``generate_bell``:
2045
-
2046
- >>> list(permutations(range(4)))[:5]
2047
- [(0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3), (0, 2, 3, 1), (0, 3, 1, 2)]
2048
- >>> list(generate_bell(4))[:5]
2049
- [(0, 1, 2, 3), (0, 1, 3, 2), (0, 3, 1, 2), (3, 0, 1, 2), (3, 0, 2, 1)]
2050
-
2051
- Notice how the 2nd and 3rd lexicographical permutations have 3 elements
2052
- out of place whereas each "bell" permutation always has only two
2053
- elements out of place relative to the previous permutation (and so the
2054
- signature (+/-1) of a permutation is opposite of the signature of the
2055
- previous permutation).
2056
-
2057
- How the position of inversion varies across the elements can be seen
2058
- by tracing out where the largest number appears in the permutations:
2059
-
2060
- >>> m = zeros(4, 24)
2061
- >>> for i, p in enumerate(generate_bell(4)):
2062
- ... m[:, i] = Matrix([j - 3 for j in list(p)]) # make largest zero
2063
- >>> m.print_nonzero('X')
2064
- [XXX XXXXXX XXXXXX XXX]
2065
- [XX XX XXXX XX XXXX XX XX]
2066
- [X XXXX XX XXXX XX XXXX X]
2067
- [ XXXXXX XXXXXX XXXXXX ]
2068
-
2069
- See Also
2070
- ========
2071
-
2072
- sympy.combinatorics.permutations.Permutation.next_trotterjohnson
2073
-
2074
- References
2075
- ==========
2076
-
2077
- .. [1] https://en.wikipedia.org/wiki/Method_ringing
2078
-
2079
- .. [2] https://stackoverflow.com/questions/4856615/recursive-permutation/4857018
2080
-
2081
- .. [3] https://web.archive.org/web/20160313023044/http://programminggeeks.com/bell-algorithm-for-permutation/
2082
-
2083
- .. [4] https://en.wikipedia.org/wiki/Steinhaus%E2%80%93Johnson%E2%80%93Trotter_algorithm
2084
-
2085
- .. [5] Generating involutions, derangements, and relatives by ECO
2086
- Vincent Vajnovszki, DMTCS vol 1 issue 12, 2010
2087
-
2088
- """
2089
- n = as_int(n)
2090
- if n < 1:
2091
- raise ValueError('n must be a positive integer')
2092
- if n == 1:
2093
- yield (0,)
2094
- elif n == 2:
2095
- yield (0, 1)
2096
- yield (1, 0)
2097
- elif n == 3:
2098
- yield from [(0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)]
2099
- else:
2100
- m = n - 1
2101
- op = [0] + [-1]*m
2102
- l = list(range(n))
2103
- while True:
2104
- yield tuple(l)
2105
- # find biggest element with op
2106
- big = None, -1 # idx, value
2107
- for i in range(n):
2108
- if op[i] and l[i] > big[1]:
2109
- big = i, l[i]
2110
- i, _ = big
2111
- if i is None:
2112
- break # there are no ops left
2113
- # swap it with neighbor in the indicated direction
2114
- j = i + op[i]
2115
- l[i], l[j] = l[j], l[i]
2116
- op[i], op[j] = op[j], op[i]
2117
- # if it landed at the end or if the neighbor in the same
2118
- # direction is bigger then turn off op
2119
- if j == 0 or j == m or l[j + op[j]] > l[j]:
2120
- op[j] = 0
2121
- # any element bigger to the left gets +1 op
2122
- for i in range(j):
2123
- if l[i] > l[j]:
2124
- op[i] = 1
2125
- # any element bigger to the right gets -1 op
2126
- for i in range(j + 1, n):
2127
- if l[i] > l[j]:
2128
- op[i] = -1
2129
-
2130
-
2131
- def generate_involutions(n):
2132
- """
2133
- Generates involutions.
2134
-
2135
- An involution is a permutation that when multiplied
2136
- by itself equals the identity permutation. In this
2137
- implementation the involutions are generated using
2138
- Fixed Points.
2139
-
2140
- Alternatively, an involution can be considered as
2141
- a permutation that does not contain any cycles with
2142
- a length that is greater than two.
2143
-
2144
- Examples
2145
- ========
2146
-
2147
- >>> from sympy.utilities.iterables import generate_involutions
2148
- >>> list(generate_involutions(3))
2149
- [(0, 1, 2), (0, 2, 1), (1, 0, 2), (2, 1, 0)]
2150
- >>> len(list(generate_involutions(4)))
2151
- 10
2152
-
2153
- References
2154
- ==========
2155
-
2156
- .. [1] https://mathworld.wolfram.com/PermutationInvolution.html
2157
-
2158
- """
2159
- idx = list(range(n))
2160
- for p in permutations(idx):
2161
- for i in idx:
2162
- if p[p[i]] != i:
2163
- break
2164
- else:
2165
- yield p
2166
-
2167
-
2168
- def multiset_derangements(s):
2169
- """Generate derangements of the elements of s *in place*.
2170
-
2171
- Examples
2172
- ========
2173
-
2174
- >>> from sympy.utilities.iterables import multiset_derangements, uniq
2175
-
2176
- Because the derangements of multisets (not sets) are generated
2177
- in place, copies of the return value must be made if a collection
2178
- of derangements is desired or else all values will be the same:
2179
-
2180
- >>> list(uniq([i for i in multiset_derangements('1233')]))
2181
- [[None, None, None, None]]
2182
- >>> [i.copy() for i in multiset_derangements('1233')]
2183
- [['3', '3', '1', '2'], ['3', '3', '2', '1']]
2184
- >>> [''.join(i) for i in multiset_derangements('1233')]
2185
- ['3312', '3321']
2186
- """
2187
- from sympy.core.sorting import ordered
2188
- # create multiset dictionary of hashable elements or else
2189
- # remap elements to integers
2190
- try:
2191
- ms = multiset(s)
2192
- except TypeError:
2193
- # give each element a canonical integer value
2194
- key = dict(enumerate(ordered(uniq(s))))
2195
- h = []
2196
- for si in s:
2197
- for k in key:
2198
- if key[k] == si:
2199
- h.append(k)
2200
- break
2201
- for i in multiset_derangements(h):
2202
- yield [key[j] for j in i]
2203
- return
2204
-
2205
- mx = max(ms.values()) # max repetition of any element
2206
- n = len(s) # the number of elements
2207
-
2208
- ## special cases
2209
-
2210
- # 1) one element has more than half the total cardinality of s: no
2211
- # derangements are possible.
2212
- if mx*2 > n:
2213
- return
2214
-
2215
- # 2) all elements appear once: singletons
2216
- if len(ms) == n:
2217
- yield from _set_derangements(s)
2218
- return
2219
-
2220
- # find the first element that is repeated the most to place
2221
- # in the following two special cases where the selection
2222
- # is unambiguous: either there are two elements with multiplicity
2223
- # of mx or else there is only one with multiplicity mx
2224
- for M in ms:
2225
- if ms[M] == mx:
2226
- break
2227
-
2228
- inonM = [i for i in range(n) if s[i] != M] # location of non-M
2229
- iM = [i for i in range(n) if s[i] == M] # locations of M
2230
- rv = [None]*n
2231
-
2232
- # 3) half are the same
2233
- if 2*mx == n:
2234
- # M goes into non-M locations
2235
- for i in inonM:
2236
- rv[i] = M
2237
- # permutations of non-M go to M locations
2238
- for p in multiset_permutations([s[i] for i in inonM]):
2239
- for i, pi in zip(iM, p):
2240
- rv[i] = pi
2241
- yield rv
2242
- # clean-up (and encourages proper use of routine)
2243
- rv[:] = [None]*n
2244
- return
2245
-
2246
- # 4) single repeat covers all but 1 of the non-repeats:
2247
- # if there is one repeat then the multiset of the values
2248
- # of ms would be {mx: 1, 1: n - mx}, i.e. there would
2249
- # be n - mx + 1 values with the condition that n - 2*mx = 1
2250
- if n - 2*mx == 1 and len(ms.values()) == n - mx + 1:
2251
- for i, i1 in enumerate(inonM):
2252
- ifill = inonM[:i] + inonM[i+1:]
2253
- for j in ifill:
2254
- rv[j] = M
2255
- for p in permutations([s[j] for j in ifill]):
2256
- rv[i1] = s[i1]
2257
- for j, pi in zip(iM, p):
2258
- rv[j] = pi
2259
- k = i1
2260
- for j in iM:
2261
- rv[j], rv[k] = rv[k], rv[j]
2262
- yield rv
2263
- k = j
2264
- # clean-up (and encourages proper use of routine)
2265
- rv[:] = [None]*n
2266
- return
2267
-
2268
- ## general case is handled with 3 helpers:
2269
- # 1) `finish_derangements` will place the last two elements
2270
- # which have arbitrary multiplicities, e.g. for multiset
2271
- # {c: 3, a: 2, b: 2}, the last two elements are a and b
2272
- # 2) `iopen` will tell where a given element can be placed
2273
- # 3) `do` will recursively place elements into subsets of
2274
- # valid locations
2275
-
2276
- def finish_derangements():
2277
- """Place the last two elements into the partially completed
2278
- derangement, and yield the results.
2279
- """
2280
-
2281
- a = take[1][0] # penultimate element
2282
- a_ct = take[1][1]
2283
- b = take[0][0] # last element to be placed
2284
- b_ct = take[0][1]
2285
-
2286
- # split the indexes of the not-already-assigned elements of rv into
2287
- # three categories
2288
- forced_a = [] # positions which must have an a
2289
- forced_b = [] # positions which must have a b
2290
- open_free = [] # positions which could take either
2291
- for i in range(len(s)):
2292
- if rv[i] is None:
2293
- if s[i] == a:
2294
- forced_b.append(i)
2295
- elif s[i] == b:
2296
- forced_a.append(i)
2297
- else:
2298
- open_free.append(i)
2299
-
2300
- if len(forced_a) > a_ct or len(forced_b) > b_ct:
2301
- # No derangement possible
2302
- return
2303
-
2304
- for i in forced_a:
2305
- rv[i] = a
2306
- for i in forced_b:
2307
- rv[i] = b
2308
- for a_place in combinations(open_free, a_ct - len(forced_a)):
2309
- for a_pos in a_place:
2310
- rv[a_pos] = a
2311
- for i in open_free:
2312
- if rv[i] is None: # anything not in the subset is set to b
2313
- rv[i] = b
2314
- yield rv
2315
- # Clean up/undo the final placements
2316
- for i in open_free:
2317
- rv[i] = None
2318
-
2319
- # additional cleanup - clear forced_a, forced_b
2320
- for i in forced_a:
2321
- rv[i] = None
2322
- for i in forced_b:
2323
- rv[i] = None
2324
-
2325
- def iopen(v):
2326
- # return indices at which element v can be placed in rv:
2327
- # locations which are not already occupied if that location
2328
- # does not already contain v in the same location of s
2329
- return [i for i in range(n) if rv[i] is None and s[i] != v]
2330
-
2331
- def do(j):
2332
- if j == 1:
2333
- # handle the last two elements (regardless of multiplicity)
2334
- # with a special method
2335
- yield from finish_derangements()
2336
- else:
2337
- # place the mx elements of M into a subset of places
2338
- # into which it can be replaced
2339
- M, mx = take[j]
2340
- for i in combinations(iopen(M), mx):
2341
- # place M
2342
- for ii in i:
2343
- rv[ii] = M
2344
- # recursively place the next element
2345
- yield from do(j - 1)
2346
- # mark positions where M was placed as once again
2347
- # open for placement of other elements
2348
- for ii in i:
2349
- rv[ii] = None
2350
-
2351
- # process elements in order of canonically decreasing multiplicity
2352
- take = sorted(ms.items(), key=lambda x:(x[1], x[0]))
2353
- yield from do(len(take) - 1)
2354
- rv[:] = [None]*n
2355
-
2356
-
2357
- def random_derangement(t, choice=None, strict=True):
2358
- """Return a list of elements in which none are in the same positions
2359
- as they were originally. If an element fills more than half of the positions
2360
- then an error will be raised since no derangement is possible. To obtain
2361
- a derangement of as many items as possible--with some of the most numerous
2362
- remaining in their original positions--pass `strict=False`. To produce a
2363
- pseudorandom derangment, pass a pseudorandom selector like `choice` (see
2364
- below).
2365
-
2366
- Examples
2367
- ========
2368
-
2369
- >>> from sympy.utilities.iterables import random_derangement
2370
- >>> t = 'SymPy: a CAS in pure Python'
2371
- >>> d = random_derangement(t)
2372
- >>> all(i != j for i, j in zip(d, t))
2373
- True
2374
-
2375
- A predictable result can be obtained by using a pseudorandom
2376
- generator for the choice:
2377
-
2378
- >>> from sympy.core.random import seed, choice as c
2379
- >>> seed(1)
2380
- >>> d = [''.join(random_derangement(t, c)) for i in range(5)]
2381
- >>> assert len(set(d)) != 1 # we got different values
2382
-
2383
- By reseeding, the same sequence can be obtained:
2384
-
2385
- >>> seed(1)
2386
- >>> d2 = [''.join(random_derangement(t, c)) for i in range(5)]
2387
- >>> assert d == d2
2388
- """
2389
- if choice is None:
2390
- import secrets
2391
- choice = secrets.choice
2392
- def shuffle(rv):
2393
- '''Knuth shuffle'''
2394
- for i in range(len(rv) - 1, 0, -1):
2395
- x = choice(rv[:i + 1])
2396
- j = rv.index(x)
2397
- rv[i], rv[j] = rv[j], rv[i]
2398
- def pick(rv, n):
2399
- '''shuffle rv and return the first n values
2400
- '''
2401
- shuffle(rv)
2402
- return rv[:n]
2403
- ms = multiset(t)
2404
- tot = len(t)
2405
- ms = sorted(ms.items(), key=lambda x: x[1])
2406
- # if there are not enough spaces for the most
2407
- # plentiful element to move to then some of them
2408
- # will have to stay in place
2409
- M, mx = ms[-1]
2410
- n = len(t)
2411
- xs = 2*mx - tot
2412
- if xs > 0:
2413
- if strict:
2414
- raise ValueError('no derangement possible')
2415
- opts = [i for (i, c) in enumerate(t) if c == ms[-1][0]]
2416
- pick(opts, xs)
2417
- stay = sorted(opts[:xs])
2418
- rv = list(t)
2419
- for i in reversed(stay):
2420
- rv.pop(i)
2421
- rv = random_derangement(rv, choice)
2422
- for i in stay:
2423
- rv.insert(i, ms[-1][0])
2424
- return ''.join(rv) if type(t) is str else rv
2425
- # the normal derangement calculated from here
2426
- if n == len(ms):
2427
- # approx 1/3 will succeed
2428
- rv = list(t)
2429
- while True:
2430
- shuffle(rv)
2431
- if all(i != j for i,j in zip(rv, t)):
2432
- break
2433
- else:
2434
- # general case
2435
- rv = [None]*n
2436
- while True:
2437
- j = 0
2438
- while j > -len(ms): # do most numerous first
2439
- j -= 1
2440
- e, c = ms[j]
2441
- opts = [i for i in range(n) if rv[i] is None and t[i] != e]
2442
- if len(opts) < c:
2443
- for i in range(n):
2444
- rv[i] = None
2445
- break # try again
2446
- pick(opts, c)
2447
- for i in range(c):
2448
- rv[opts[i]] = e
2449
- else:
2450
- return rv
2451
- return rv
2452
-
2453
-
2454
- def _set_derangements(s):
2455
- """
2456
- yield derangements of items in ``s`` which are assumed to contain
2457
- no repeated elements
2458
- """
2459
- if len(s) < 2:
2460
- return
2461
- if len(s) == 2:
2462
- yield [s[1], s[0]]
2463
- return
2464
- if len(s) == 3:
2465
- yield [s[1], s[2], s[0]]
2466
- yield [s[2], s[0], s[1]]
2467
- return
2468
- for p in permutations(s):
2469
- if not any(i == j for i, j in zip(p, s)):
2470
- yield list(p)
2471
-
2472
-
2473
- def generate_derangements(s):
2474
- """
2475
- Return unique derangements of the elements of iterable ``s``.
2476
-
2477
- Examples
2478
- ========
2479
-
2480
- >>> from sympy.utilities.iterables import generate_derangements
2481
- >>> list(generate_derangements([0, 1, 2]))
2482
- [[1, 2, 0], [2, 0, 1]]
2483
- >>> list(generate_derangements([0, 1, 2, 2]))
2484
- [[2, 2, 0, 1], [2, 2, 1, 0]]
2485
- >>> list(generate_derangements([0, 1, 1]))
2486
- []
2487
-
2488
- See Also
2489
- ========
2490
-
2491
- sympy.functions.combinatorial.factorials.subfactorial
2492
-
2493
- """
2494
- if not has_dups(s):
2495
- yield from _set_derangements(s)
2496
- else:
2497
- for p in multiset_derangements(s):
2498
- yield list(p)
2499
-
2500
-
2501
- def necklaces(n, k, free=False):
2502
- """
2503
- A routine to generate necklaces that may (free=True) or may not
2504
- (free=False) be turned over to be viewed. The "necklaces" returned
2505
- are comprised of ``n`` integers (beads) with ``k`` different
2506
- values (colors). Only unique necklaces are returned.
2507
-
2508
- Examples
2509
- ========
2510
-
2511
- >>> from sympy.utilities.iterables import necklaces, bracelets
2512
- >>> def show(s, i):
2513
- ... return ''.join(s[j] for j in i)
2514
-
2515
- The "unrestricted necklace" is sometimes also referred to as a
2516
- "bracelet" (an object that can be turned over, a sequence that can
2517
- be reversed) and the term "necklace" is used to imply a sequence
2518
- that cannot be reversed. So ACB == ABC for a bracelet (rotate and
2519
- reverse) while the two are different for a necklace since rotation
2520
- alone cannot make the two sequences the same.
2521
-
2522
- (mnemonic: Bracelets can be viewed Backwards, but Not Necklaces.)
2523
-
2524
- >>> B = [show('ABC', i) for i in bracelets(3, 3)]
2525
- >>> N = [show('ABC', i) for i in necklaces(3, 3)]
2526
- >>> set(N) - set(B)
2527
- {'ACB'}
2528
-
2529
- >>> list(necklaces(4, 2))
2530
- [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 1),
2531
- (0, 1, 0, 1), (0, 1, 1, 1), (1, 1, 1, 1)]
2532
-
2533
- >>> [show('.o', i) for i in bracelets(4, 2)]
2534
- ['....', '...o', '..oo', '.o.o', '.ooo', 'oooo']
2535
-
2536
- References
2537
- ==========
2538
-
2539
- .. [1] https://mathworld.wolfram.com/Necklace.html
2540
-
2541
- .. [2] Frank Ruskey, Carla Savage, and Terry Min Yih Wang,
2542
- Generating necklaces, Journal of Algorithms 13 (1992), 414-430;
2543
- https://doi.org/10.1016/0196-6774(92)90047-G
2544
-
2545
- """
2546
- # The FKM algorithm
2547
- if k == 0 and n > 0:
2548
- return
2549
- a = [0]*n
2550
- yield tuple(a)
2551
- if n == 0:
2552
- return
2553
- while True:
2554
- i = n - 1
2555
- while a[i] == k - 1:
2556
- i -= 1
2557
- if i == -1:
2558
- return
2559
- a[i] += 1
2560
- for j in range(n - i - 1):
2561
- a[j + i + 1] = a[j]
2562
- if n % (i + 1) == 0 and (not free or all(a <= a[j::-1] + a[-1:j:-1] for j in range(n - 1))):
2563
- # No need to test j = n - 1.
2564
- yield tuple(a)
2565
-
2566
-
2567
- def bracelets(n, k):
2568
- """Wrapper to necklaces to return a free (unrestricted) necklace."""
2569
- return necklaces(n, k, free=True)
2570
-
2571
-
2572
- def generate_oriented_forest(n):
2573
- """
2574
- This algorithm generates oriented forests.
2575
-
2576
- An oriented graph is a directed graph having no symmetric pair of directed
2577
- edges. A forest is an acyclic graph, i.e., it has no cycles. A forest can
2578
- also be described as a disjoint union of trees, which are graphs in which
2579
- any two vertices are connected by exactly one simple path.
2580
-
2581
- Examples
2582
- ========
2583
-
2584
- >>> from sympy.utilities.iterables import generate_oriented_forest
2585
- >>> list(generate_oriented_forest(4))
2586
- [[0, 1, 2, 3], [0, 1, 2, 2], [0, 1, 2, 1], [0, 1, 2, 0], \
2587
- [0, 1, 1, 1], [0, 1, 1, 0], [0, 1, 0, 1], [0, 1, 0, 0], [0, 0, 0, 0]]
2588
-
2589
- References
2590
- ==========
2591
-
2592
- .. [1] T. Beyer and S.M. Hedetniemi: constant time generation of
2593
- rooted trees, SIAM J. Computing Vol. 9, No. 4, November 1980
2594
-
2595
- .. [2] https://stackoverflow.com/questions/1633833/oriented-forest-taocp-algorithm-in-python
2596
-
2597
- """
2598
- P = list(range(-1, n))
2599
- while True:
2600
- yield P[1:]
2601
- if P[n] > 0:
2602
- P[n] = P[P[n]]
2603
- else:
2604
- for p in range(n - 1, 0, -1):
2605
- if P[p] != 0:
2606
- target = P[p] - 1
2607
- for q in range(p - 1, 0, -1):
2608
- if P[q] == target:
2609
- break
2610
- offset = p - q
2611
- for i in range(p, n + 1):
2612
- P[i] = P[i - offset]
2613
- break
2614
- else:
2615
- break
2616
-
2617
-
2618
- def minlex(seq, directed=True, key=None):
2619
- r"""
2620
- Return the rotation of the sequence in which the lexically smallest
2621
- elements appear first, e.g. `cba \rightarrow acb`.
2622
-
2623
- The sequence returned is a tuple, unless the input sequence is a string
2624
- in which case a string is returned.
2625
-
2626
- If ``directed`` is False then the smaller of the sequence and the
2627
- reversed sequence is returned, e.g. `cba \rightarrow abc`.
2628
-
2629
- If ``key`` is not None then it is used to extract a comparison key from each element in iterable.
2630
-
2631
- Examples
2632
- ========
2633
-
2634
- >>> from sympy.combinatorics.polyhedron import minlex
2635
- >>> minlex((1, 2, 0))
2636
- (0, 1, 2)
2637
- >>> minlex((1, 0, 2))
2638
- (0, 2, 1)
2639
- >>> minlex((1, 0, 2), directed=False)
2640
- (0, 1, 2)
2641
-
2642
- >>> minlex('11010011000', directed=True)
2643
- '00011010011'
2644
- >>> minlex('11010011000', directed=False)
2645
- '00011001011'
2646
-
2647
- >>> minlex(('bb', 'aaa', 'c', 'a'))
2648
- ('a', 'bb', 'aaa', 'c')
2649
- >>> minlex(('bb', 'aaa', 'c', 'a'), key=len)
2650
- ('c', 'a', 'bb', 'aaa')
2651
-
2652
- """
2653
- from sympy.functions.elementary.miscellaneous import Id
2654
- if key is None: key = Id
2655
- best = rotate_left(seq, least_rotation(seq, key=key))
2656
- if not directed:
2657
- rseq = seq[::-1]
2658
- rbest = rotate_left(rseq, least_rotation(rseq, key=key))
2659
- best = min(best, rbest, key=key)
2660
-
2661
- # Convert to tuple, unless we started with a string.
2662
- return tuple(best) if not isinstance(seq, str) else best
2663
-
2664
-
2665
- def runs(seq, op=gt):
2666
- """Group the sequence into lists in which successive elements
2667
- all compare the same with the comparison operator, ``op``:
2668
- op(seq[i + 1], seq[i]) is True from all elements in a run.
2669
-
2670
- Examples
2671
- ========
2672
-
2673
- >>> from sympy.utilities.iterables import runs
2674
- >>> from operator import ge
2675
- >>> runs([0, 1, 2, 2, 1, 4, 3, 2, 2])
2676
- [[0, 1, 2], [2], [1, 4], [3], [2], [2]]
2677
- >>> runs([0, 1, 2, 2, 1, 4, 3, 2, 2], op=ge)
2678
- [[0, 1, 2, 2], [1, 4], [3], [2, 2]]
2679
- """
2680
- cycles = []
2681
- seq = iter(seq)
2682
- try:
2683
- run = [next(seq)]
2684
- except StopIteration:
2685
- return []
2686
- while True:
2687
- try:
2688
- ei = next(seq)
2689
- except StopIteration:
2690
- break
2691
- if op(ei, run[-1]):
2692
- run.append(ei)
2693
- continue
2694
- else:
2695
- cycles.append(run)
2696
- run = [ei]
2697
- if run:
2698
- cycles.append(run)
2699
- return cycles
2700
-
2701
-
2702
- def sequence_partitions(l, n, /):
2703
- r"""Returns the partition of sequence $l$ into $n$ bins
2704
-
2705
- Explanation
2706
- ===========
2707
-
2708
- Given the sequence $l_1 \cdots l_m \in V^+$ where
2709
- $V^+$ is the Kleene plus of $V$
2710
-
2711
- The set of $n$ partitions of $l$ is defined as:
2712
-
2713
- .. math::
2714
- \{(s_1, \cdots, s_n) | s_1 \in V^+, \cdots, s_n \in V^+,
2715
- s_1 \cdots s_n = l_1 \cdots l_m\}
2716
-
2717
- Parameters
2718
- ==========
2719
-
2720
- l : Sequence[T]
2721
- A nonempty sequence of any Python objects
2722
-
2723
- n : int
2724
- A positive integer
2725
-
2726
- Yields
2727
- ======
2728
-
2729
- out : list[Sequence[T]]
2730
- A list of sequences with concatenation equals $l$.
2731
- This should conform with the type of $l$.
2732
-
2733
- Examples
2734
- ========
2735
-
2736
- >>> from sympy.utilities.iterables import sequence_partitions
2737
- >>> for out in sequence_partitions([1, 2, 3, 4], 2):
2738
- ... print(out)
2739
- [[1], [2, 3, 4]]
2740
- [[1, 2], [3, 4]]
2741
- [[1, 2, 3], [4]]
2742
-
2743
- Notes
2744
- =====
2745
-
2746
- This is modified version of EnricoGiampieri's partition generator
2747
- from https://stackoverflow.com/questions/13131491/partition-n-items-into-k-bins-in-python-lazily
2748
-
2749
- See Also
2750
- ========
2751
-
2752
- sequence_partitions_empty
2753
- """
2754
- # Asserting l is nonempty is done only for sanity check
2755
- if n == 1 and l:
2756
- yield [l]
2757
- return
2758
- for i in range(1, len(l)):
2759
- for part in sequence_partitions(l[i:], n - 1):
2760
- yield [l[:i]] + part
2761
-
2762
-
2763
- def sequence_partitions_empty(l, n, /):
2764
- r"""Returns the partition of sequence $l$ into $n$ bins with
2765
- empty sequence
2766
-
2767
- Explanation
2768
- ===========
2769
-
2770
- Given the sequence $l_1 \cdots l_m \in V^*$ where
2771
- $V^*$ is the Kleene star of $V$
2772
-
2773
- The set of $n$ partitions of $l$ is defined as:
2774
-
2775
- .. math::
2776
- \{(s_1, \cdots, s_n) | s_1 \in V^*, \cdots, s_n \in V^*,
2777
- s_1 \cdots s_n = l_1 \cdots l_m\}
2778
-
2779
- There are more combinations than :func:`sequence_partitions` because
2780
- empty sequence can fill everywhere, so we try to provide different
2781
- utility for this.
2782
-
2783
- Parameters
2784
- ==========
2785
-
2786
- l : Sequence[T]
2787
- A sequence of any Python objects (can be possibly empty)
2788
-
2789
- n : int
2790
- A positive integer
2791
-
2792
- Yields
2793
- ======
2794
-
2795
- out : list[Sequence[T]]
2796
- A list of sequences with concatenation equals $l$.
2797
- This should conform with the type of $l$.
2798
-
2799
- Examples
2800
- ========
2801
-
2802
- >>> from sympy.utilities.iterables import sequence_partitions_empty
2803
- >>> for out in sequence_partitions_empty([1, 2, 3, 4], 2):
2804
- ... print(out)
2805
- [[], [1, 2, 3, 4]]
2806
- [[1], [2, 3, 4]]
2807
- [[1, 2], [3, 4]]
2808
- [[1, 2, 3], [4]]
2809
- [[1, 2, 3, 4], []]
2810
-
2811
- See Also
2812
- ========
2813
-
2814
- sequence_partitions
2815
- """
2816
- if n < 1:
2817
- return
2818
- if n == 1:
2819
- yield [l]
2820
- return
2821
- for i in range(0, len(l) + 1):
2822
- for part in sequence_partitions_empty(l[i:], n - 1):
2823
- yield [l[:i]] + part
2824
-
2825
-
2826
- def kbins(l, k, ordered=None):
2827
- """
2828
- Return sequence ``l`` partitioned into ``k`` bins.
2829
-
2830
- Examples
2831
- ========
2832
-
2833
- The default is to give the items in the same order, but grouped
2834
- into k partitions without any reordering:
2835
-
2836
- >>> from sympy.utilities.iterables import kbins
2837
- >>> for p in kbins(list(range(5)), 2):
2838
- ... print(p)
2839
- ...
2840
- [[0], [1, 2, 3, 4]]
2841
- [[0, 1], [2, 3, 4]]
2842
- [[0, 1, 2], [3, 4]]
2843
- [[0, 1, 2, 3], [4]]
2844
-
2845
- The ``ordered`` flag is either None (to give the simple partition
2846
- of the elements) or is a 2 digit integer indicating whether the order of
2847
- the bins and the order of the items in the bins matters. Given::
2848
-
2849
- A = [[0], [1, 2]]
2850
- B = [[1, 2], [0]]
2851
- C = [[2, 1], [0]]
2852
- D = [[0], [2, 1]]
2853
-
2854
- the following values for ``ordered`` have the shown meanings::
2855
-
2856
- 00 means A == B == C == D
2857
- 01 means A == B
2858
- 10 means A == D
2859
- 11 means A == A
2860
-
2861
- >>> for ordered_flag in [None, 0, 1, 10, 11]:
2862
- ... print('ordered = %s' % ordered_flag)
2863
- ... for p in kbins(list(range(3)), 2, ordered=ordered_flag):
2864
- ... print(' %s' % p)
2865
- ...
2866
- ordered = None
2867
- [[0], [1, 2]]
2868
- [[0, 1], [2]]
2869
- ordered = 0
2870
- [[0, 1], [2]]
2871
- [[0, 2], [1]]
2872
- [[0], [1, 2]]
2873
- ordered = 1
2874
- [[0], [1, 2]]
2875
- [[0], [2, 1]]
2876
- [[1], [0, 2]]
2877
- [[1], [2, 0]]
2878
- [[2], [0, 1]]
2879
- [[2], [1, 0]]
2880
- ordered = 10
2881
- [[0, 1], [2]]
2882
- [[2], [0, 1]]
2883
- [[0, 2], [1]]
2884
- [[1], [0, 2]]
2885
- [[0], [1, 2]]
2886
- [[1, 2], [0]]
2887
- ordered = 11
2888
- [[0], [1, 2]]
2889
- [[0, 1], [2]]
2890
- [[0], [2, 1]]
2891
- [[0, 2], [1]]
2892
- [[1], [0, 2]]
2893
- [[1, 0], [2]]
2894
- [[1], [2, 0]]
2895
- [[1, 2], [0]]
2896
- [[2], [0, 1]]
2897
- [[2, 0], [1]]
2898
- [[2], [1, 0]]
2899
- [[2, 1], [0]]
2900
-
2901
- See Also
2902
- ========
2903
-
2904
- partitions, multiset_partitions
2905
-
2906
- """
2907
- if ordered is None:
2908
- yield from sequence_partitions(l, k)
2909
- elif ordered == 11:
2910
- for pl in multiset_permutations(l):
2911
- pl = list(pl)
2912
- yield from sequence_partitions(pl, k)
2913
- elif ordered == 00:
2914
- yield from multiset_partitions(l, k)
2915
- elif ordered == 10:
2916
- for p in multiset_partitions(l, k):
2917
- for perm in permutations(p):
2918
- yield list(perm)
2919
- elif ordered == 1:
2920
- for kgot, p in partitions(len(l), k, size=True):
2921
- if kgot != k:
2922
- continue
2923
- for li in multiset_permutations(l):
2924
- rv = []
2925
- i = j = 0
2926
- li = list(li)
2927
- for size, multiplicity in sorted(p.items()):
2928
- for m in range(multiplicity):
2929
- j = i + size
2930
- rv.append(li[i: j])
2931
- i = j
2932
- yield rv
2933
- else:
2934
- raise ValueError(
2935
- 'ordered must be one of 00, 01, 10 or 11, not %s' % ordered)
2936
-
2937
-
2938
- def permute_signs(t):
2939
- """Return iterator in which the signs of non-zero elements
2940
- of t are permuted.
2941
-
2942
- Examples
2943
- ========
2944
-
2945
- >>> from sympy.utilities.iterables import permute_signs
2946
- >>> list(permute_signs((0, 1, 2)))
2947
- [(0, 1, 2), (0, -1, 2), (0, 1, -2), (0, -1, -2)]
2948
- """
2949
- for signs in product(*[(1, -1)]*(len(t) - t.count(0))):
2950
- signs = list(signs)
2951
- yield type(t)([i*signs.pop() if i else i for i in t])
2952
-
2953
-
2954
- def signed_permutations(t):
2955
- """Return iterator in which the signs of non-zero elements
2956
- of t and the order of the elements are permuted and all
2957
- returned values are unique.
2958
-
2959
- Examples
2960
- ========
2961
-
2962
- >>> from sympy.utilities.iterables import signed_permutations
2963
- >>> list(signed_permutations((0, 1, 2)))
2964
- [(0, 1, 2), (0, -1, 2), (0, 1, -2), (0, -1, -2), (0, 2, 1),
2965
- (0, -2, 1), (0, 2, -1), (0, -2, -1), (1, 0, 2), (-1, 0, 2),
2966
- (1, 0, -2), (-1, 0, -2), (1, 2, 0), (-1, 2, 0), (1, -2, 0),
2967
- (-1, -2, 0), (2, 0, 1), (-2, 0, 1), (2, 0, -1), (-2, 0, -1),
2968
- (2, 1, 0), (-2, 1, 0), (2, -1, 0), (-2, -1, 0)]
2969
- """
2970
- return (type(t)(i) for j in multiset_permutations(t)
2971
- for i in permute_signs(j))
2972
-
2973
-
2974
- def rotations(s, dir=1):
2975
- """Return a generator giving the items in s as list where
2976
- each subsequent list has the items rotated to the left (default)
2977
- or right (``dir=-1``) relative to the previous list.
2978
-
2979
- Examples
2980
- ========
2981
-
2982
- >>> from sympy import rotations
2983
- >>> list(rotations([1,2,3]))
2984
- [[1, 2, 3], [2, 3, 1], [3, 1, 2]]
2985
- >>> list(rotations([1,2,3], -1))
2986
- [[1, 2, 3], [3, 1, 2], [2, 3, 1]]
2987
- """
2988
- seq = list(s)
2989
- for i in range(len(seq)):
2990
- yield seq
2991
- seq = rotate_left(seq, dir)
2992
-
2993
-
2994
- def roundrobin(*iterables):
2995
- """roundrobin recipe taken from itertools documentation:
2996
- https://docs.python.org/3/library/itertools.html#itertools-recipes
2997
-
2998
- roundrobin('ABC', 'D', 'EF') --> A D E B F C
2999
-
3000
- Recipe credited to George Sakkis
3001
- """
3002
- nexts = cycle(iter(it).__next__ for it in iterables)
3003
-
3004
- pending = len(iterables)
3005
- while pending:
3006
- try:
3007
- for nxt in nexts:
3008
- yield nxt()
3009
- except StopIteration:
3010
- pending -= 1
3011
- nexts = cycle(islice(nexts, pending))
3012
-
3013
-
3014
-
3015
- class NotIterable:
3016
- """
3017
- Use this as mixin when creating a class which is not supposed to
3018
- return true when iterable() is called on its instances because
3019
- calling list() on the instance, for example, would result in
3020
- an infinite loop.
3021
- """
3022
- pass
3023
-
3024
-
3025
- def iterable(i, exclude=(str, dict, NotIterable)):
3026
- """
3027
- Return a boolean indicating whether ``i`` is SymPy iterable.
3028
- True also indicates that the iterator is finite, e.g. you can
3029
- call list(...) on the instance.
3030
-
3031
- When SymPy is working with iterables, it is almost always assuming
3032
- that the iterable is not a string or a mapping, so those are excluded
3033
- by default. If you want a pure Python definition, make exclude=None. To
3034
- exclude multiple items, pass them as a tuple.
3035
-
3036
- You can also set the _iterable attribute to True or False on your class,
3037
- which will override the checks here, including the exclude test.
3038
-
3039
- As a rule of thumb, some SymPy functions use this to check if they should
3040
- recursively map over an object. If an object is technically iterable in
3041
- the Python sense but does not desire this behavior (e.g., because its
3042
- iteration is not finite, or because iteration might induce an unwanted
3043
- computation), it should disable it by setting the _iterable attribute to False.
3044
-
3045
- See also: is_sequence
3046
-
3047
- Examples
3048
- ========
3049
-
3050
- >>> from sympy.utilities.iterables import iterable
3051
- >>> from sympy import Tuple
3052
- >>> things = [[1], (1,), set([1]), Tuple(1), (j for j in [1, 2]), {1:2}, '1', 1]
3053
- >>> for i in things:
3054
- ... print('%s %s' % (iterable(i), type(i)))
3055
- True <... 'list'>
3056
- True <... 'tuple'>
3057
- True <... 'set'>
3058
- True <class 'sympy.core.containers.Tuple'>
3059
- True <... 'generator'>
3060
- False <... 'dict'>
3061
- False <... 'str'>
3062
- False <... 'int'>
3063
-
3064
- >>> iterable({}, exclude=None)
3065
- True
3066
- >>> iterable({}, exclude=str)
3067
- True
3068
- >>> iterable("no", exclude=str)
3069
- False
3070
-
3071
- """
3072
- if hasattr(i, '_iterable'):
3073
- return i._iterable
3074
- try:
3075
- iter(i)
3076
- except TypeError:
3077
- return False
3078
- if exclude:
3079
- return not isinstance(i, exclude)
3080
- return True
3081
-
3082
-
3083
- def is_sequence(i, include=None):
3084
- """
3085
- Return a boolean indicating whether ``i`` is a sequence in the SymPy
3086
- sense. If anything that fails the test below should be included as
3087
- being a sequence for your application, set 'include' to that object's
3088
- type; multiple types should be passed as a tuple of types.
3089
-
3090
- Note: although generators can generate a sequence, they often need special
3091
- handling to make sure their elements are captured before the generator is
3092
- exhausted, so these are not included by default in the definition of a
3093
- sequence.
3094
-
3095
- See also: iterable
3096
-
3097
- Examples
3098
- ========
3099
-
3100
- >>> from sympy.utilities.iterables import is_sequence
3101
- >>> from types import GeneratorType
3102
- >>> is_sequence([])
3103
- True
3104
- >>> is_sequence(set())
3105
- False
3106
- >>> is_sequence('abc')
3107
- False
3108
- >>> is_sequence('abc', include=str)
3109
- True
3110
- >>> generator = (c for c in 'abc')
3111
- >>> is_sequence(generator)
3112
- False
3113
- >>> is_sequence(generator, include=(str, GeneratorType))
3114
- True
3115
-
3116
- """
3117
- return (hasattr(i, '__getitem__') and
3118
- iterable(i) or
3119
- bool(include) and
3120
- isinstance(i, include))
3121
-
3122
-
3123
- @deprecated(
3124
- """
3125
- Using postorder_traversal from the sympy.utilities.iterables submodule is
3126
- deprecated.
3127
-
3128
- Instead, use postorder_traversal from the top-level sympy namespace, like
3129
-
3130
- sympy.postorder_traversal
3131
- """,
3132
- deprecated_since_version="1.10",
3133
- active_deprecations_target="deprecated-traversal-functions-moved")
3134
- def postorder_traversal(node, keys=None):
3135
- from sympy.core.traversal import postorder_traversal as _postorder_traversal
3136
- return _postorder_traversal(node, keys=keys)
3137
-
3138
-
3139
- @deprecated(
3140
- """
3141
- Using interactive_traversal from the sympy.utilities.iterables submodule
3142
- is deprecated.
3143
-
3144
- Instead, use interactive_traversal from the top-level sympy namespace,
3145
- like
3146
-
3147
- sympy.interactive_traversal
3148
- """,
3149
- deprecated_since_version="1.10",
3150
- active_deprecations_target="deprecated-traversal-functions-moved")
3151
- def interactive_traversal(expr):
3152
- from sympy.interactive.traversal import interactive_traversal as _interactive_traversal
3153
- return _interactive_traversal(expr)
3154
-
3155
-
3156
- @deprecated(
3157
- """
3158
- Importing default_sort_key from sympy.utilities.iterables is deprecated.
3159
- Use from sympy import default_sort_key instead.
3160
- """,
3161
- deprecated_since_version="1.10",
3162
- active_deprecations_target="deprecated-sympy-core-compatibility",
3163
- )
3164
- def default_sort_key(*args, **kwargs):
3165
- from sympy import default_sort_key as _default_sort_key
3166
- return _default_sort_key(*args, **kwargs)
3167
-
3168
-
3169
- @deprecated(
3170
- """
3171
- Importing default_sort_key from sympy.utilities.iterables is deprecated.
3172
- Use from sympy import default_sort_key instead.
3173
- """,
3174
- deprecated_since_version="1.10",
3175
- active_deprecations_target="deprecated-sympy-core-compatibility",
3176
- )
3177
- def ordered(*args, **kwargs):
3178
- from sympy import ordered as _ordered
3179
- return _ordered(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/lambdify.py DELETED
@@ -1,1592 +0,0 @@
1
- """
2
- This module provides convenient functions to transform SymPy expressions to
3
- lambda functions which can be used to calculate numerical values very fast.
4
- """
5
-
6
- from __future__ import annotations
7
- from typing import Any
8
-
9
- import builtins
10
- import inspect
11
- import keyword
12
- import textwrap
13
- import linecache
14
- import weakref
15
-
16
- # Required despite static analysis claiming it is not used
17
- from sympy.external import import_module # noqa:F401
18
- from sympy.utilities.exceptions import sympy_deprecation_warning
19
- from sympy.utilities.decorator import doctest_depends_on
20
- from sympy.utilities.iterables import (is_sequence, iterable,
21
- NotIterable, flatten)
22
- from sympy.utilities.misc import filldedent
23
-
24
-
25
- __doctest_requires__ = {('lambdify',): ['numpy', 'tensorflow']}
26
-
27
-
28
- # Default namespaces, letting us define translations that can't be defined
29
- # by simple variable maps, like I => 1j
30
- MATH_DEFAULT: dict[str, Any] = {}
31
- CMATH_DEFAULT: dict[str,Any] = {}
32
- MPMATH_DEFAULT: dict[str, Any] = {}
33
- NUMPY_DEFAULT: dict[str, Any] = {"I": 1j}
34
- SCIPY_DEFAULT: dict[str, Any] = {"I": 1j}
35
- CUPY_DEFAULT: dict[str, Any] = {"I": 1j}
36
- JAX_DEFAULT: dict[str, Any] = {"I": 1j}
37
- TENSORFLOW_DEFAULT: dict[str, Any] = {}
38
- TORCH_DEFAULT: dict[str, Any] = {"I": 1j}
39
- SYMPY_DEFAULT: dict[str, Any] = {}
40
- NUMEXPR_DEFAULT: dict[str, Any] = {}
41
-
42
- # These are the namespaces the lambda functions will use.
43
- # These are separate from the names above because they are modified
44
- # throughout this file, whereas the defaults should remain unmodified.
45
-
46
- MATH = MATH_DEFAULT.copy()
47
- CMATH = CMATH_DEFAULT.copy()
48
- MPMATH = MPMATH_DEFAULT.copy()
49
- NUMPY = NUMPY_DEFAULT.copy()
50
- SCIPY = SCIPY_DEFAULT.copy()
51
- CUPY = CUPY_DEFAULT.copy()
52
- JAX = JAX_DEFAULT.copy()
53
- TENSORFLOW = TENSORFLOW_DEFAULT.copy()
54
- TORCH = TORCH_DEFAULT.copy()
55
- SYMPY = SYMPY_DEFAULT.copy()
56
- NUMEXPR = NUMEXPR_DEFAULT.copy()
57
-
58
-
59
- # Mappings between SymPy and other modules function names.
60
- MATH_TRANSLATIONS = {
61
- "ceiling": "ceil",
62
- "E": "e",
63
- "ln": "log",
64
- }
65
-
66
- CMATH_TRANSLATIONS: dict[str, str] = {}
67
-
68
- # NOTE: This dictionary is reused in Function._eval_evalf to allow subclasses
69
- # of Function to automatically evalf.
70
- MPMATH_TRANSLATIONS = {
71
- "Abs": "fabs",
72
- "elliptic_k": "ellipk",
73
- "elliptic_f": "ellipf",
74
- "elliptic_e": "ellipe",
75
- "elliptic_pi": "ellippi",
76
- "ceiling": "ceil",
77
- "chebyshevt": "chebyt",
78
- "chebyshevu": "chebyu",
79
- "assoc_legendre": "legenp",
80
- "E": "e",
81
- "I": "j",
82
- "ln": "log",
83
- #"lowergamma":"lower_gamma",
84
- "oo": "inf",
85
- #"uppergamma":"upper_gamma",
86
- "LambertW": "lambertw",
87
- "MutableDenseMatrix": "matrix",
88
- "ImmutableDenseMatrix": "matrix",
89
- "conjugate": "conj",
90
- "dirichlet_eta": "altzeta",
91
- "Ei": "ei",
92
- "Shi": "shi",
93
- "Chi": "chi",
94
- "Si": "si",
95
- "Ci": "ci",
96
- "RisingFactorial": "rf",
97
- "FallingFactorial": "ff",
98
- "betainc_regularized": "betainc",
99
- }
100
-
101
- NUMPY_TRANSLATIONS: dict[str, str] = {
102
- "Heaviside": "heaviside",
103
- }
104
- SCIPY_TRANSLATIONS: dict[str, str] = {
105
- "jn" : "spherical_jn",
106
- "yn" : "spherical_yn"
107
- }
108
- CUPY_TRANSLATIONS: dict[str, str] = {}
109
- JAX_TRANSLATIONS: dict[str, str] = {}
110
-
111
- TENSORFLOW_TRANSLATIONS: dict[str, str] = {}
112
- TORCH_TRANSLATIONS: dict[str, str] = {}
113
-
114
- NUMEXPR_TRANSLATIONS: dict[str, str] = {}
115
-
116
- # Available modules:
117
- MODULES = {
118
- "math": (MATH, MATH_DEFAULT, MATH_TRANSLATIONS, ("from math import *",)),
119
- "cmath": (CMATH, CMATH_DEFAULT, CMATH_TRANSLATIONS, ("import cmath; from cmath import *",)),
120
- "mpmath": (MPMATH, MPMATH_DEFAULT, MPMATH_TRANSLATIONS, ("from mpmath import *",)),
121
- "numpy": (NUMPY, NUMPY_DEFAULT, NUMPY_TRANSLATIONS, ("import numpy; from numpy import *; from numpy.linalg import *",)),
122
- "scipy": (SCIPY, SCIPY_DEFAULT, SCIPY_TRANSLATIONS, ("import scipy; import numpy; from scipy.special import *",)),
123
- "cupy": (CUPY, CUPY_DEFAULT, CUPY_TRANSLATIONS, ("import cupy",)),
124
- "jax": (JAX, JAX_DEFAULT, JAX_TRANSLATIONS, ("import jax",)),
125
- "tensorflow": (TENSORFLOW, TENSORFLOW_DEFAULT, TENSORFLOW_TRANSLATIONS, ("import tensorflow",)),
126
- "torch": (TORCH, TORCH_DEFAULT, TORCH_TRANSLATIONS, ("import torch",)),
127
- "sympy": (SYMPY, SYMPY_DEFAULT, {}, (
128
- "from sympy.functions import *",
129
- "from sympy.matrices import *",
130
- "from sympy import Integral, pi, oo, nan, zoo, E, I",)),
131
- "numexpr" : (NUMEXPR, NUMEXPR_DEFAULT, NUMEXPR_TRANSLATIONS,
132
- ("import_module('numexpr')", )),
133
- }
134
-
135
-
136
- def _import(module, reload=False):
137
- """
138
- Creates a global translation dictionary for module.
139
-
140
- The argument module has to be one of the following strings: "math","cmath"
141
- "mpmath", "numpy", "sympy", "tensorflow", "jax".
142
- These dictionaries map names of Python functions to their equivalent in
143
- other modules.
144
- """
145
- try:
146
- namespace, namespace_default, translations, import_commands = MODULES[
147
- module]
148
- except KeyError:
149
- raise NameError(
150
- "'%s' module cannot be used for lambdification" % module)
151
-
152
- # Clear namespace or exit
153
- if namespace != namespace_default:
154
- # The namespace was already generated, don't do it again if not forced.
155
- if reload:
156
- namespace.clear()
157
- namespace.update(namespace_default)
158
- else:
159
- return
160
-
161
- for import_command in import_commands:
162
- if import_command.startswith('import_module'):
163
- module = eval(import_command)
164
-
165
- if module is not None:
166
- namespace.update(module.__dict__)
167
- continue
168
- else:
169
- try:
170
- exec(import_command, {}, namespace)
171
- continue
172
- except ImportError:
173
- pass
174
-
175
- raise ImportError(
176
- "Cannot import '%s' with '%s' command" % (module, import_command))
177
-
178
- # Add translated names to namespace
179
- for sympyname, translation in translations.items():
180
- namespace[sympyname] = namespace[translation]
181
-
182
- # For computing the modulus of a SymPy expression we use the builtin abs
183
- # function, instead of the previously used fabs function for all
184
- # translation modules. This is because the fabs function in the math
185
- # module does not accept complex valued arguments. (see issue 9474). The
186
- # only exception, where we don't use the builtin abs function is the
187
- # mpmath translation module, because mpmath.fabs returns mpf objects in
188
- # contrast to abs().
189
- if 'Abs' not in namespace:
190
- namespace['Abs'] = abs
191
-
192
- # Used for dynamically generated filenames that are inserted into the
193
- # linecache.
194
- _lambdify_generated_counter = 1
195
-
196
-
197
- @doctest_depends_on(modules=('numpy', 'scipy', 'tensorflow',), python_version=(3,))
198
- def lambdify(args, expr, modules=None, printer=None, use_imps=True,
199
- dummify=False, cse=False, docstring_limit=1000):
200
- """Convert a SymPy expression into a function that allows for fast
201
- numeric evaluation.
202
-
203
- .. warning::
204
- This function uses ``exec``, and thus should not be used on
205
- unsanitized input.
206
-
207
- .. deprecated:: 1.7
208
- Passing a set for the *args* parameter is deprecated as sets are
209
- unordered. Use an ordered iterable such as a list or tuple.
210
-
211
- Explanation
212
- ===========
213
-
214
- For example, to convert the SymPy expression ``sin(x) + cos(x)`` to an
215
- equivalent NumPy function that numerically evaluates it:
216
-
217
- >>> from sympy import sin, cos, symbols, lambdify
218
- >>> import numpy as np
219
- >>> x = symbols('x')
220
- >>> expr = sin(x) + cos(x)
221
- >>> expr
222
- sin(x) + cos(x)
223
- >>> f = lambdify(x, expr, 'numpy')
224
- >>> a = np.array([1, 2])
225
- >>> f(a)
226
- [1.38177329 0.49315059]
227
-
228
- The primary purpose of this function is to provide a bridge from SymPy
229
- expressions to numerical libraries such as NumPy, SciPy, NumExpr, mpmath,
230
- and tensorflow. In general, SymPy functions do not work with objects from
231
- other libraries, such as NumPy arrays, and functions from numeric
232
- libraries like NumPy or mpmath do not work on SymPy expressions.
233
- ``lambdify`` bridges the two by converting a SymPy expression to an
234
- equivalent numeric function.
235
-
236
- The basic workflow with ``lambdify`` is to first create a SymPy expression
237
- representing whatever mathematical function you wish to evaluate. This
238
- should be done using only SymPy functions and expressions. Then, use
239
- ``lambdify`` to convert this to an equivalent function for numerical
240
- evaluation. For instance, above we created ``expr`` using the SymPy symbol
241
- ``x`` and SymPy functions ``sin`` and ``cos``, then converted it to an
242
- equivalent NumPy function ``f``, and called it on a NumPy array ``a``.
243
-
244
- Parameters
245
- ==========
246
-
247
- args : List[Symbol]
248
- A variable or a list of variables whose nesting represents the
249
- nesting of the arguments that will be passed to the function.
250
-
251
- Variables can be symbols, undefined functions, or matrix symbols.
252
-
253
- >>> from sympy import Eq
254
- >>> from sympy.abc import x, y, z
255
-
256
- The list of variables should match the structure of how the
257
- arguments will be passed to the function. Simply enclose the
258
- parameters as they will be passed in a list.
259
-
260
- To call a function like ``f(x)`` then ``[x]``
261
- should be the first argument to ``lambdify``; for this
262
- case a single ``x`` can also be used:
263
-
264
- >>> f = lambdify(x, x + 1)
265
- >>> f(1)
266
- 2
267
- >>> f = lambdify([x], x + 1)
268
- >>> f(1)
269
- 2
270
-
271
- To call a function like ``f(x, y)`` then ``[x, y]`` will
272
- be the first argument of the ``lambdify``:
273
-
274
- >>> f = lambdify([x, y], x + y)
275
- >>> f(1, 1)
276
- 2
277
-
278
- To call a function with a single 3-element tuple like
279
- ``f((x, y, z))`` then ``[(x, y, z)]`` will be the first
280
- argument of the ``lambdify``:
281
-
282
- >>> f = lambdify([(x, y, z)], Eq(z**2, x**2 + y**2))
283
- >>> f((3, 4, 5))
284
- True
285
-
286
- If two args will be passed and the first is a scalar but
287
- the second is a tuple with two arguments then the items
288
- in the list should match that structure:
289
-
290
- >>> f = lambdify([x, (y, z)], x + y + z)
291
- >>> f(1, (2, 3))
292
- 6
293
-
294
- expr : Expr
295
- An expression, list of expressions, or matrix to be evaluated.
296
-
297
- Lists may be nested.
298
- If the expression is a list, the output will also be a list.
299
-
300
- >>> f = lambdify(x, [x, [x + 1, x + 2]])
301
- >>> f(1)
302
- [1, [2, 3]]
303
-
304
- If it is a matrix, an array will be returned (for the NumPy module).
305
-
306
- >>> from sympy import Matrix
307
- >>> f = lambdify(x, Matrix([x, x + 1]))
308
- >>> f(1)
309
- [[1]
310
- [2]]
311
-
312
- Note that the argument order here (variables then expression) is used
313
- to emulate the Python ``lambda`` keyword. ``lambdify(x, expr)`` works
314
- (roughly) like ``lambda x: expr``
315
- (see :ref:`lambdify-how-it-works` below).
316
-
317
- modules : str, optional
318
- Specifies the numeric library to use.
319
-
320
- If not specified, *modules* defaults to:
321
-
322
- - ``["scipy", "numpy"]`` if SciPy is installed
323
- - ``["numpy"]`` if only NumPy is installed
324
- - ``["math","cmath", "mpmath", "sympy"]`` if neither is installed.
325
-
326
- That is, SymPy functions are replaced as far as possible by
327
- either ``scipy`` or ``numpy`` functions if available, and Python's
328
- standard library ``math`` and ``cmath``, or ``mpmath`` functions otherwise.
329
-
330
- *modules* can be one of the following types:
331
-
332
- - The strings ``"math"``, ``"cmath"``, ``"mpmath"``, ``"numpy"``, ``"numexpr"``,
333
- ``"scipy"``, ``"sympy"``, or ``"tensorflow"`` or ``"jax"``. This uses the
334
- corresponding printer and namespace mapping for that module.
335
- - A module (e.g., ``math``). This uses the global namespace of the
336
- module. If the module is one of the above known modules, it will
337
- also use the corresponding printer and namespace mapping
338
- (i.e., ``modules=numpy`` is equivalent to ``modules="numpy"``).
339
- - A dictionary that maps names of SymPy functions to arbitrary
340
- functions
341
- (e.g., ``{'sin': custom_sin}``).
342
- - A list that contains a mix of the arguments above, with higher
343
- priority given to entries appearing first
344
- (e.g., to use the NumPy module but override the ``sin`` function
345
- with a custom version, you can use
346
- ``[{'sin': custom_sin}, 'numpy']``).
347
-
348
- dummify : bool, optional
349
- Whether or not the variables in the provided expression that are not
350
- valid Python identifiers are substituted with dummy symbols.
351
-
352
- This allows for undefined functions like ``Function('f')(t)`` to be
353
- supplied as arguments. By default, the variables are only dummified
354
- if they are not valid Python identifiers.
355
-
356
- Set ``dummify=True`` to replace all arguments with dummy symbols
357
- (if ``args`` is not a string) - for example, to ensure that the
358
- arguments do not redefine any built-in names.
359
-
360
- cse : bool, or callable, optional
361
- Large expressions can be computed more efficiently when
362
- common subexpressions are identified and precomputed before
363
- being used multiple time. Finding the subexpressions will make
364
- creation of the 'lambdify' function slower, however.
365
-
366
- When ``True``, ``sympy.simplify.cse`` is used, otherwise (the default)
367
- the user may pass a function matching the ``cse`` signature.
368
-
369
- docstring_limit : int or None
370
- When lambdifying large expressions, a significant proportion of the time
371
- spent inside ``lambdify`` is spent producing a string representation of
372
- the expression for use in the automatically generated docstring of the
373
- returned function. For expressions containing hundreds or more nodes the
374
- resulting docstring often becomes so long and dense that it is difficult
375
- to read. To reduce the runtime of lambdify, the rendering of the full
376
- expression inside the docstring can be disabled.
377
-
378
- When ``None``, the full expression is rendered in the docstring. When
379
- ``0`` or a negative ``int``, an ellipsis is rendering in the docstring
380
- instead of the expression. When a strictly positive ``int``, if the
381
- number of nodes in the expression exceeds ``docstring_limit`` an
382
- ellipsis is rendered in the docstring, otherwise a string representation
383
- of the expression is rendered as normal. The default is ``1000``.
384
-
385
- Examples
386
- ========
387
-
388
- >>> from sympy.utilities.lambdify import implemented_function
389
- >>> from sympy import sqrt, sin, Matrix
390
- >>> from sympy import Function
391
- >>> from sympy.abc import w, x, y, z
392
-
393
- >>> f = lambdify(x, x**2)
394
- >>> f(2)
395
- 4
396
- >>> f = lambdify((x, y, z), [z, y, x])
397
- >>> f(1,2,3)
398
- [3, 2, 1]
399
- >>> f = lambdify(x, sqrt(x))
400
- >>> f(4)
401
- 2.0
402
- >>> f = lambdify((x, y), sin(x*y)**2)
403
- >>> f(0, 5)
404
- 0.0
405
- >>> row = lambdify((x, y), Matrix((x, x + y)).T, modules='sympy')
406
- >>> row(1, 2)
407
- Matrix([[1, 3]])
408
-
409
- ``lambdify`` can be used to translate SymPy expressions into mpmath
410
- functions. This may be preferable to using ``evalf`` (which uses mpmath on
411
- the backend) in some cases.
412
-
413
- >>> f = lambdify(x, sin(x), 'mpmath')
414
- >>> f(1)
415
- 0.8414709848078965
416
-
417
- Tuple arguments are handled and the lambdified function should
418
- be called with the same type of arguments as were used to create
419
- the function:
420
-
421
- >>> f = lambdify((x, (y, z)), x + y)
422
- >>> f(1, (2, 4))
423
- 3
424
-
425
- The ``flatten`` function can be used to always work with flattened
426
- arguments:
427
-
428
- >>> from sympy.utilities.iterables import flatten
429
- >>> args = w, (x, (y, z))
430
- >>> vals = 1, (2, (3, 4))
431
- >>> f = lambdify(flatten(args), w + x + y + z)
432
- >>> f(*flatten(vals))
433
- 10
434
-
435
- Functions present in ``expr`` can also carry their own numerical
436
- implementations, in a callable attached to the ``_imp_`` attribute. This
437
- can be used with undefined functions using the ``implemented_function``
438
- factory:
439
-
440
- >>> f = implemented_function(Function('f'), lambda x: x+1)
441
- >>> func = lambdify(x, f(x))
442
- >>> func(4)
443
- 5
444
-
445
- ``lambdify`` always prefers ``_imp_`` implementations to implementations
446
- in other namespaces, unless the ``use_imps`` input parameter is False.
447
-
448
- Usage with Tensorflow:
449
-
450
- >>> import tensorflow as tf
451
- >>> from sympy import Max, sin, lambdify
452
- >>> from sympy.abc import x
453
-
454
- >>> f = Max(x, sin(x))
455
- >>> func = lambdify(x, f, 'tensorflow')
456
-
457
- After tensorflow v2, eager execution is enabled by default.
458
- If you want to get the compatible result across tensorflow v1 and v2
459
- as same as this tutorial, run this line.
460
-
461
- >>> tf.compat.v1.enable_eager_execution()
462
-
463
- If you have eager execution enabled, you can get the result out
464
- immediately as you can use numpy.
465
-
466
- If you pass tensorflow objects, you may get an ``EagerTensor``
467
- object instead of value.
468
-
469
- >>> result = func(tf.constant(1.0))
470
- >>> print(result)
471
- tf.Tensor(1.0, shape=(), dtype=float32)
472
- >>> print(result.__class__)
473
- <class 'tensorflow.python.framework.ops.EagerTensor'>
474
-
475
- You can use ``.numpy()`` to get the numpy value of the tensor.
476
-
477
- >>> result.numpy()
478
- 1.0
479
-
480
- >>> var = tf.Variable(2.0)
481
- >>> result = func(var) # also works for tf.Variable and tf.Placeholder
482
- >>> result.numpy()
483
- 2.0
484
-
485
- And it works with any shape array.
486
-
487
- >>> tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])
488
- >>> result = func(tensor)
489
- >>> result.numpy()
490
- [[1. 2.]
491
- [3. 4.]]
492
-
493
- Notes
494
- =====
495
-
496
- - For functions involving large array calculations, numexpr can provide a
497
- significant speedup over numpy. Please note that the available functions
498
- for numexpr are more limited than numpy but can be expanded with
499
- ``implemented_function`` and user defined subclasses of Function. If
500
- specified, numexpr may be the only option in modules. The official list
501
- of numexpr functions can be found at:
502
- https://numexpr.readthedocs.io/en/latest/user_guide.html#supported-functions
503
-
504
- - In the above examples, the generated functions can accept scalar
505
- values or numpy arrays as arguments. However, in some cases
506
- the generated function relies on the input being a numpy array:
507
-
508
- >>> import numpy
509
- >>> from sympy import Piecewise
510
- >>> from sympy.testing.pytest import ignore_warnings
511
- >>> f = lambdify(x, Piecewise((x, x <= 1), (1/x, x > 1)), "numpy")
512
-
513
- >>> with ignore_warnings(RuntimeWarning):
514
- ... f(numpy.array([-1, 0, 1, 2]))
515
- [-1. 0. 1. 0.5]
516
-
517
- >>> f(0)
518
- Traceback (most recent call last):
519
- ...
520
- ZeroDivisionError: division by zero
521
-
522
- In such cases, the input should be wrapped in a numpy array:
523
-
524
- >>> with ignore_warnings(RuntimeWarning):
525
- ... float(f(numpy.array([0])))
526
- 0.0
527
-
528
- Or if numpy functionality is not required another module can be used:
529
-
530
- >>> f = lambdify(x, Piecewise((x, x <= 1), (1/x, x > 1)), "math")
531
- >>> f(0)
532
- 0
533
-
534
- .. _lambdify-how-it-works:
535
-
536
- How it works
537
- ============
538
-
539
- When using this function, it helps a great deal to have an idea of what it
540
- is doing. At its core, lambdify is nothing more than a namespace
541
- translation, on top of a special printer that makes some corner cases work
542
- properly.
543
-
544
- To understand lambdify, first we must properly understand how Python
545
- namespaces work. Say we had two files. One called ``sin_cos_sympy.py``,
546
- with
547
-
548
- .. code:: python
549
-
550
- # sin_cos_sympy.py
551
-
552
- from sympy.functions.elementary.trigonometric import (cos, sin)
553
-
554
- def sin_cos(x):
555
- return sin(x) + cos(x)
556
-
557
-
558
- and one called ``sin_cos_numpy.py`` with
559
-
560
- .. code:: python
561
-
562
- # sin_cos_numpy.py
563
-
564
- from numpy import sin, cos
565
-
566
- def sin_cos(x):
567
- return sin(x) + cos(x)
568
-
569
- The two files define an identical function ``sin_cos``. However, in the
570
- first file, ``sin`` and ``cos`` are defined as the SymPy ``sin`` and
571
- ``cos``. In the second, they are defined as the NumPy versions.
572
-
573
- If we were to import the first file and use the ``sin_cos`` function, we
574
- would get something like
575
-
576
- >>> from sin_cos_sympy import sin_cos # doctest: +SKIP
577
- >>> sin_cos(1) # doctest: +SKIP
578
- cos(1) + sin(1)
579
-
580
- On the other hand, if we imported ``sin_cos`` from the second file, we
581
- would get
582
-
583
- >>> from sin_cos_numpy import sin_cos # doctest: +SKIP
584
- >>> sin_cos(1) # doctest: +SKIP
585
- 1.38177329068
586
-
587
- In the first case we got a symbolic output, because it used the symbolic
588
- ``sin`` and ``cos`` functions from SymPy. In the second, we got a numeric
589
- result, because ``sin_cos`` used the numeric ``sin`` and ``cos`` functions
590
- from NumPy. But notice that the versions of ``sin`` and ``cos`` that were
591
- used was not inherent to the ``sin_cos`` function definition. Both
592
- ``sin_cos`` definitions are exactly the same. Rather, it was based on the
593
- names defined at the module where the ``sin_cos`` function was defined.
594
-
595
- The key point here is that when function in Python references a name that
596
- is not defined in the function, that name is looked up in the "global"
597
- namespace of the module where that function is defined.
598
-
599
- Now, in Python, we can emulate this behavior without actually writing a
600
- file to disk using the ``exec`` function. ``exec`` takes a string
601
- containing a block of Python code, and a dictionary that should contain
602
- the global variables of the module. It then executes the code "in" that
603
- dictionary, as if it were the module globals. The following is equivalent
604
- to the ``sin_cos`` defined in ``sin_cos_sympy.py``:
605
-
606
- >>> import sympy
607
- >>> module_dictionary = {'sin': sympy.sin, 'cos': sympy.cos}
608
- >>> exec('''
609
- ... def sin_cos(x):
610
- ... return sin(x) + cos(x)
611
- ... ''', module_dictionary)
612
- >>> sin_cos = module_dictionary['sin_cos']
613
- >>> sin_cos(1)
614
- cos(1) + sin(1)
615
-
616
- and similarly with ``sin_cos_numpy``:
617
-
618
- >>> import numpy
619
- >>> module_dictionary = {'sin': numpy.sin, 'cos': numpy.cos}
620
- >>> exec('''
621
- ... def sin_cos(x):
622
- ... return sin(x) + cos(x)
623
- ... ''', module_dictionary)
624
- >>> sin_cos = module_dictionary['sin_cos']
625
- >>> sin_cos(1)
626
- 1.38177329068
627
-
628
- So now we can get an idea of how ``lambdify`` works. The name "lambdify"
629
- comes from the fact that we can think of something like ``lambdify(x,
630
- sin(x) + cos(x), 'numpy')`` as ``lambda x: sin(x) + cos(x)``, where
631
- ``sin`` and ``cos`` come from the ``numpy`` namespace. This is also why
632
- the symbols argument is first in ``lambdify``, as opposed to most SymPy
633
- functions where it comes after the expression: to better mimic the
634
- ``lambda`` keyword.
635
-
636
- ``lambdify`` takes the input expression (like ``sin(x) + cos(x)``) and
637
-
638
- 1. Converts it to a string
639
- 2. Creates a module globals dictionary based on the modules that are
640
- passed in (by default, it uses the NumPy module)
641
- 3. Creates the string ``"def func({vars}): return {expr}"``, where ``{vars}`` is the
642
- list of variables separated by commas, and ``{expr}`` is the string
643
- created in step 1., then ``exec``s that string with the module globals
644
- namespace and returns ``func``.
645
-
646
- In fact, functions returned by ``lambdify`` support inspection. So you can
647
- see exactly how they are defined by using ``inspect.getsource``, or ``??`` if you
648
- are using IPython or the Jupyter notebook.
649
-
650
- >>> f = lambdify(x, sin(x) + cos(x))
651
- >>> import inspect
652
- >>> print(inspect.getsource(f))
653
- def _lambdifygenerated(x):
654
- return sin(x) + cos(x)
655
-
656
- This shows us the source code of the function, but not the namespace it
657
- was defined in. We can inspect that by looking at the ``__globals__``
658
- attribute of ``f``:
659
-
660
- >>> f.__globals__['sin']
661
- <ufunc 'sin'>
662
- >>> f.__globals__['cos']
663
- <ufunc 'cos'>
664
- >>> f.__globals__['sin'] is numpy.sin
665
- True
666
-
667
- This shows us that ``sin`` and ``cos`` in the namespace of ``f`` will be
668
- ``numpy.sin`` and ``numpy.cos``.
669
-
670
- Note that there are some convenience layers in each of these steps, but at
671
- the core, this is how ``lambdify`` works. Step 1 is done using the
672
- ``LambdaPrinter`` printers defined in the printing module (see
673
- :mod:`sympy.printing.lambdarepr`). This allows different SymPy expressions
674
- to define how they should be converted to a string for different modules.
675
- You can change which printer ``lambdify`` uses by passing a custom printer
676
- in to the ``printer`` argument.
677
-
678
- Step 2 is augmented by certain translations. There are default
679
- translations for each module, but you can provide your own by passing a
680
- list to the ``modules`` argument. For instance,
681
-
682
- >>> def mysin(x):
683
- ... print('taking the sin of', x)
684
- ... return numpy.sin(x)
685
- ...
686
- >>> f = lambdify(x, sin(x), [{'sin': mysin}, 'numpy'])
687
- >>> f(1)
688
- taking the sin of 1
689
- 0.8414709848078965
690
-
691
- The globals dictionary is generated from the list by merging the
692
- dictionary ``{'sin': mysin}`` and the module dictionary for NumPy. The
693
- merging is done so that earlier items take precedence, which is why
694
- ``mysin`` is used above instead of ``numpy.sin``.
695
-
696
- If you want to modify the way ``lambdify`` works for a given function, it
697
- is usually easiest to do so by modifying the globals dictionary as such.
698
- In more complicated cases, it may be necessary to create and pass in a
699
- custom printer.
700
-
701
- Finally, step 3 is augmented with certain convenience operations, such as
702
- the addition of a docstring.
703
-
704
- Understanding how ``lambdify`` works can make it easier to avoid certain
705
- gotchas when using it. For instance, a common mistake is to create a
706
- lambdified function for one module (say, NumPy), and pass it objects from
707
- another (say, a SymPy expression).
708
-
709
- For instance, say we create
710
-
711
- >>> from sympy.abc import x
712
- >>> f = lambdify(x, x + 1, 'numpy')
713
-
714
- Now if we pass in a NumPy array, we get that array plus 1
715
-
716
- >>> import numpy
717
- >>> a = numpy.array([1, 2])
718
- >>> f(a)
719
- [2 3]
720
-
721
- But what happens if you make the mistake of passing in a SymPy expression
722
- instead of a NumPy array:
723
-
724
- >>> f(x + 1)
725
- x + 2
726
-
727
- This worked, but it was only by accident. Now take a different lambdified
728
- function:
729
-
730
- >>> from sympy import sin
731
- >>> g = lambdify(x, x + sin(x), 'numpy')
732
-
733
- This works as expected on NumPy arrays:
734
-
735
- >>> g(a)
736
- [1.84147098 2.90929743]
737
-
738
- But if we try to pass in a SymPy expression, it fails
739
-
740
- >>> g(x + 1)
741
- Traceback (most recent call last):
742
- ...
743
- TypeError: loop of ufunc does not support argument 0 of type Add which has
744
- no callable sin method
745
-
746
- Now, let's look at what happened. The reason this fails is that ``g``
747
- calls ``numpy.sin`` on the input expression, and ``numpy.sin`` does not
748
- know how to operate on a SymPy object. **As a general rule, NumPy
749
- functions do not know how to operate on SymPy expressions, and SymPy
750
- functions do not know how to operate on NumPy arrays. This is why lambdify
751
- exists: to provide a bridge between SymPy and NumPy.**
752
-
753
- However, why is it that ``f`` did work? That's because ``f`` does not call
754
- any functions, it only adds 1. So the resulting function that is created,
755
- ``def _lambdifygenerated(x): return x + 1`` does not depend on the globals
756
- namespace it is defined in. Thus it works, but only by accident. A future
757
- version of ``lambdify`` may remove this behavior.
758
-
759
- Be aware that certain implementation details described here may change in
760
- future versions of SymPy. The API of passing in custom modules and
761
- printers will not change, but the details of how a lambda function is
762
- created may change. However, the basic idea will remain the same, and
763
- understanding it will be helpful to understanding the behavior of
764
- lambdify.
765
-
766
- **In general: you should create lambdified functions for one module (say,
767
- NumPy), and only pass it input types that are compatible with that module
768
- (say, NumPy arrays).** Remember that by default, if the ``module``
769
- argument is not provided, ``lambdify`` creates functions using the NumPy
770
- and SciPy namespaces.
771
- """
772
- from sympy.core.symbol import Symbol
773
- from sympy.core.expr import Expr
774
-
775
- # If the user hasn't specified any modules, use what is available.
776
- if modules is None:
777
- try:
778
- _import("scipy")
779
- except ImportError:
780
- try:
781
- _import("numpy")
782
- except ImportError:
783
- # Use either numpy (if available) or python.math where possible.
784
- # XXX: This leads to different behaviour on different systems and
785
- # might be the reason for irreproducible errors.
786
- modules = ["math", "mpmath", "sympy"]
787
- else:
788
- modules = ["numpy"]
789
- else:
790
- modules = ["numpy", "scipy"]
791
-
792
- # Get the needed namespaces.
793
- namespaces = []
794
- # First find any function implementations
795
- if use_imps:
796
- namespaces.append(_imp_namespace(expr))
797
- # Check for dict before iterating
798
- if isinstance(modules, (dict, str)) or not hasattr(modules, '__iter__'):
799
- namespaces.append(modules)
800
- else:
801
- # consistency check
802
- if _module_present('numexpr', modules) and len(modules) > 1:
803
- raise TypeError("numexpr must be the only item in 'modules'")
804
- namespaces += list(modules)
805
- # fill namespace with first having highest priority
806
- namespace = {}
807
- for m in namespaces[::-1]:
808
- buf = _get_namespace(m)
809
- namespace.update(buf)
810
-
811
- if hasattr(expr, "atoms"):
812
- #Try if you can extract symbols from the expression.
813
- #Move on if expr.atoms in not implemented.
814
- syms = expr.atoms(Symbol)
815
- for term in syms:
816
- namespace.update({str(term): term})
817
-
818
- if printer is None:
819
- if _module_present('mpmath', namespaces):
820
- from sympy.printing.pycode import MpmathPrinter as Printer # type: ignore
821
- elif _module_present('scipy', namespaces):
822
- from sympy.printing.numpy import SciPyPrinter as Printer # type: ignore
823
- elif _module_present('numpy', namespaces):
824
- from sympy.printing.numpy import NumPyPrinter as Printer # type: ignore
825
- elif _module_present('cupy', namespaces):
826
- from sympy.printing.numpy import CuPyPrinter as Printer # type: ignore
827
- elif _module_present('jax', namespaces):
828
- from sympy.printing.numpy import JaxPrinter as Printer # type: ignore
829
- elif _module_present('numexpr', namespaces):
830
- from sympy.printing.lambdarepr import NumExprPrinter as Printer # type: ignore
831
- elif _module_present('tensorflow', namespaces):
832
- from sympy.printing.tensorflow import TensorflowPrinter as Printer # type: ignore
833
- elif _module_present('torch', namespaces):
834
- from sympy.printing.pytorch import TorchPrinter as Printer # type: ignore
835
- elif _module_present('sympy', namespaces):
836
- from sympy.printing.pycode import SymPyPrinter as Printer # type: ignore
837
- elif _module_present('cmath', namespaces):
838
- from sympy.printing.pycode import CmathPrinter as Printer # type: ignore
839
- else:
840
- from sympy.printing.pycode import PythonCodePrinter as Printer # type: ignore
841
- user_functions = {}
842
- for m in namespaces[::-1]:
843
- if isinstance(m, dict):
844
- for k in m:
845
- user_functions[k] = k
846
- printer = Printer({'fully_qualified_modules': False, 'inline': True,
847
- 'allow_unknown_functions': True,
848
- 'user_functions': user_functions})
849
-
850
- if isinstance(args, set):
851
- sympy_deprecation_warning(
852
- """
853
- Passing the function arguments to lambdify() as a set is deprecated. This
854
- leads to unpredictable results since sets are unordered. Instead, use a list
855
- or tuple for the function arguments.
856
- """,
857
- deprecated_since_version="1.6.3",
858
- active_deprecations_target="deprecated-lambdify-arguments-set",
859
- )
860
-
861
- # Get the names of the args, for creating a docstring
862
- iterable_args = (args,) if isinstance(args, Expr) else args
863
- names = []
864
-
865
- # Grab the callers frame, for getting the names by inspection (if needed)
866
- callers_local_vars = inspect.currentframe().f_back.f_locals.items() # type: ignore
867
- for n, var in enumerate(iterable_args):
868
- if hasattr(var, 'name'):
869
- names.append(var.name)
870
- else:
871
- # It's an iterable. Try to get name by inspection of calling frame.
872
- name_list = [var_name for var_name, var_val in callers_local_vars
873
- if var_val is var]
874
- if len(name_list) == 1:
875
- names.append(name_list[0])
876
- else:
877
- # Cannot infer name with certainty. arg_# will have to do.
878
- names.append('arg_' + str(n))
879
-
880
- # Create the function definition code and execute it
881
- funcname = '_lambdifygenerated'
882
- if _module_present('tensorflow', namespaces):
883
- funcprinter = _TensorflowEvaluatorPrinter(printer, dummify)
884
- else:
885
- funcprinter = _EvaluatorPrinter(printer, dummify)
886
-
887
- if cse == True:
888
- from sympy.simplify.cse_main import cse as _cse
889
- cses, _expr = _cse(expr, list=False)
890
- elif callable(cse):
891
- cses, _expr = cse(expr)
892
- else:
893
- cses, _expr = (), expr
894
- funcstr = funcprinter.doprint(funcname, iterable_args, _expr, cses=cses)
895
-
896
- # Collect the module imports from the code printers.
897
- imp_mod_lines = []
898
- for mod, keys in (getattr(printer, 'module_imports', None) or {}).items():
899
- for k in keys:
900
- if k not in namespace:
901
- ln = "from %s import %s" % (mod, k)
902
- try:
903
- exec(ln, {}, namespace)
904
- except ImportError:
905
- # Tensorflow 2.0 has issues with importing a specific
906
- # function from its submodule.
907
- # https://github.com/tensorflow/tensorflow/issues/33022
908
- ln = "%s = %s.%s" % (k, mod, k)
909
- exec(ln, {}, namespace)
910
- imp_mod_lines.append(ln)
911
-
912
- # Provide lambda expression with builtins, and compatible implementation of range
913
- namespace.update({'builtins':builtins, 'range':range})
914
-
915
- funclocals = {}
916
- global _lambdify_generated_counter
917
- filename = '<lambdifygenerated-%s>' % _lambdify_generated_counter
918
- _lambdify_generated_counter += 1
919
- c = compile(funcstr, filename, 'exec')
920
- exec(c, namespace, funclocals)
921
- # mtime has to be None or else linecache.checkcache will remove it
922
- linecache.cache[filename] = (len(funcstr), None, funcstr.splitlines(True), filename) # type: ignore
923
-
924
- # Remove the entry from the linecache when the object is garbage collected
925
- def cleanup_linecache(filename):
926
- def _cleanup():
927
- if filename in linecache.cache:
928
- del linecache.cache[filename]
929
- return _cleanup
930
-
931
- func = funclocals[funcname]
932
-
933
- weakref.finalize(func, cleanup_linecache(filename))
934
-
935
- # Apply the docstring
936
- sig = "func({})".format(", ".join(str(i) for i in names))
937
- sig = textwrap.fill(sig, subsequent_indent=' '*8)
938
- if _too_large_for_docstring(expr, docstring_limit):
939
- expr_str = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
940
- src_str = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
941
- else:
942
- expr_str = str(expr)
943
- if len(expr_str) > 78:
944
- expr_str = textwrap.wrap(expr_str, 75)[0] + '...'
945
- src_str = funcstr
946
- func.__doc__ = (
947
- "Created with lambdify. Signature:\n\n"
948
- "{sig}\n\n"
949
- "Expression:\n\n"
950
- "{expr}\n\n"
951
- "Source code:\n\n"
952
- "{src}\n\n"
953
- "Imported modules:\n\n"
954
- "{imp_mods}"
955
- ).format(sig=sig, expr=expr_str, src=src_str, imp_mods='\n'.join(imp_mod_lines))
956
- return func
957
-
958
- def _module_present(modname, modlist):
959
- if modname in modlist:
960
- return True
961
- for m in modlist:
962
- if hasattr(m, '__name__') and m.__name__ == modname:
963
- return True
964
- return False
965
-
966
- def _get_namespace(m):
967
- """
968
- This is used by _lambdify to parse its arguments.
969
- """
970
- if isinstance(m, str):
971
- _import(m)
972
- return MODULES[m][0]
973
- elif isinstance(m, dict):
974
- return m
975
- elif hasattr(m, "__dict__"):
976
- return m.__dict__
977
- else:
978
- raise TypeError("Argument must be either a string, dict or module but it is: %s" % m)
979
-
980
-
981
- def _recursive_to_string(doprint, arg):
982
- """Functions in lambdify accept both SymPy types and non-SymPy types such as python
983
- lists and tuples. This method ensures that we only call the doprint method of the
984
- printer with SymPy types (so that the printer safely can use SymPy-methods)."""
985
- from sympy.matrices.matrixbase import MatrixBase
986
- from sympy.core.basic import Basic
987
-
988
- if isinstance(arg, (Basic, MatrixBase)):
989
- return doprint(arg)
990
- elif iterable(arg):
991
- if isinstance(arg, list):
992
- left, right = "[", "]"
993
- elif isinstance(arg, tuple):
994
- left, right = "(", ",)"
995
- if not arg:
996
- return "()"
997
- else:
998
- raise NotImplementedError("unhandled type: %s, %s" % (type(arg), arg))
999
- return left +', '.join(_recursive_to_string(doprint, e) for e in arg) + right
1000
- elif isinstance(arg, str):
1001
- return arg
1002
- else:
1003
- return doprint(arg)
1004
-
1005
-
1006
- def lambdastr(args, expr, printer=None, dummify=None):
1007
- """
1008
- Returns a string that can be evaluated to a lambda function.
1009
-
1010
- Examples
1011
- ========
1012
-
1013
- >>> from sympy.abc import x, y, z
1014
- >>> from sympy.utilities.lambdify import lambdastr
1015
- >>> lambdastr(x, x**2)
1016
- 'lambda x: (x**2)'
1017
- >>> lambdastr((x,y,z), [z,y,x])
1018
- 'lambda x,y,z: ([z, y, x])'
1019
-
1020
- Although tuples may not appear as arguments to lambda in Python 3,
1021
- lambdastr will create a lambda function that will unpack the original
1022
- arguments so that nested arguments can be handled:
1023
-
1024
- >>> lambdastr((x, (y, z)), x + y)
1025
- 'lambda _0,_1: (lambda x,y,z: (x + y))(_0,_1[0],_1[1])'
1026
- """
1027
- # Transforming everything to strings.
1028
- from sympy.matrices import DeferredVector
1029
- from sympy.core.basic import Basic
1030
- from sympy.core.function import (Derivative, Function)
1031
- from sympy.core.symbol import (Dummy, Symbol)
1032
- from sympy.core.sympify import sympify
1033
-
1034
- if printer is not None:
1035
- if inspect.isfunction(printer):
1036
- lambdarepr = printer
1037
- else:
1038
- if inspect.isclass(printer):
1039
- lambdarepr = lambda expr: printer().doprint(expr)
1040
- else:
1041
- lambdarepr = lambda expr: printer.doprint(expr)
1042
- else:
1043
- #XXX: This has to be done here because of circular imports
1044
- from sympy.printing.lambdarepr import lambdarepr
1045
-
1046
- def sub_args(args, dummies_dict):
1047
- if isinstance(args, str):
1048
- return args
1049
- elif isinstance(args, DeferredVector):
1050
- return str(args)
1051
- elif iterable(args):
1052
- dummies = flatten([sub_args(a, dummies_dict) for a in args])
1053
- return ",".join(str(a) for a in dummies)
1054
- else:
1055
- # replace these with Dummy symbols
1056
- if isinstance(args, (Function, Symbol, Derivative)):
1057
- dummies = Dummy()
1058
- dummies_dict.update({args : dummies})
1059
- return str(dummies)
1060
- else:
1061
- return str(args)
1062
-
1063
- def sub_expr(expr, dummies_dict):
1064
- expr = sympify(expr)
1065
- # dict/tuple are sympified to Basic
1066
- if isinstance(expr, Basic):
1067
- expr = expr.xreplace(dummies_dict)
1068
- # list is not sympified to Basic
1069
- elif isinstance(expr, list):
1070
- expr = [sub_expr(a, dummies_dict) for a in expr]
1071
- return expr
1072
-
1073
- # Transform args
1074
- def isiter(l):
1075
- return iterable(l, exclude=(str, DeferredVector, NotIterable))
1076
-
1077
- def flat_indexes(iterable):
1078
- n = 0
1079
-
1080
- for el in iterable:
1081
- if isiter(el):
1082
- for ndeep in flat_indexes(el):
1083
- yield (n,) + ndeep
1084
- else:
1085
- yield (n,)
1086
-
1087
- n += 1
1088
-
1089
- if dummify is None:
1090
- dummify = any(isinstance(a, Basic) and
1091
- a.atoms(Function, Derivative) for a in (
1092
- args if isiter(args) else [args]))
1093
-
1094
- if isiter(args) and any(isiter(i) for i in args):
1095
- dum_args = [str(Dummy(str(i))) for i in range(len(args))]
1096
-
1097
- indexed_args = ','.join([
1098
- dum_args[ind[0]] + ''.join(["[%s]" % k for k in ind[1:]])
1099
- for ind in flat_indexes(args)])
1100
-
1101
- lstr = lambdastr(flatten(args), expr, printer=printer, dummify=dummify)
1102
-
1103
- return 'lambda %s: (%s)(%s)' % (','.join(dum_args), lstr, indexed_args)
1104
-
1105
- dummies_dict = {}
1106
- if dummify:
1107
- args = sub_args(args, dummies_dict)
1108
- else:
1109
- if isinstance(args, str):
1110
- pass
1111
- elif iterable(args, exclude=DeferredVector):
1112
- args = ",".join(str(a) for a in args)
1113
-
1114
- # Transform expr
1115
- if dummify:
1116
- if isinstance(expr, str):
1117
- pass
1118
- else:
1119
- expr = sub_expr(expr, dummies_dict)
1120
- expr = _recursive_to_string(lambdarepr, expr)
1121
- return "lambda %s: (%s)" % (args, expr)
1122
-
1123
- class _EvaluatorPrinter:
1124
- def __init__(self, printer=None, dummify=False):
1125
- self._dummify = dummify
1126
-
1127
- #XXX: This has to be done here because of circular imports
1128
- from sympy.printing.lambdarepr import LambdaPrinter
1129
-
1130
- if printer is None:
1131
- printer = LambdaPrinter()
1132
-
1133
- if inspect.isfunction(printer):
1134
- self._exprrepr = printer
1135
- else:
1136
- if inspect.isclass(printer):
1137
- printer = printer()
1138
-
1139
- self._exprrepr = printer.doprint
1140
-
1141
- #if hasattr(printer, '_print_Symbol'):
1142
- # symbolrepr = printer._print_Symbol
1143
-
1144
- #if hasattr(printer, '_print_Dummy'):
1145
- # dummyrepr = printer._print_Dummy
1146
-
1147
- # Used to print the generated function arguments in a standard way
1148
- self._argrepr = LambdaPrinter().doprint
1149
-
1150
- def doprint(self, funcname, args, expr, *, cses=()):
1151
- """
1152
- Returns the function definition code as a string.
1153
- """
1154
- from sympy.core.symbol import Dummy
1155
-
1156
- funcbody = []
1157
-
1158
- if not iterable(args):
1159
- args = [args]
1160
-
1161
- if cses:
1162
- cses = list(cses)
1163
- subvars, subexprs = zip(*cses)
1164
- exprs = [expr] + list(subexprs)
1165
- argstrs, exprs = self._preprocess(args, exprs, cses=cses)
1166
- expr, subexprs = exprs[0], exprs[1:]
1167
- cses = zip(subvars, subexprs)
1168
- else:
1169
- argstrs, expr = self._preprocess(args, expr)
1170
-
1171
- # Generate argument unpacking and final argument list
1172
- funcargs = []
1173
- unpackings = []
1174
-
1175
- for argstr in argstrs:
1176
- if iterable(argstr):
1177
- funcargs.append(self._argrepr(Dummy()))
1178
- unpackings.extend(self._print_unpacking(argstr, funcargs[-1]))
1179
- else:
1180
- funcargs.append(argstr)
1181
-
1182
- funcsig = 'def {}({}):'.format(funcname, ', '.join(funcargs))
1183
-
1184
- # Wrap input arguments before unpacking
1185
- funcbody.extend(self._print_funcargwrapping(funcargs))
1186
-
1187
- funcbody.extend(unpackings)
1188
-
1189
- for s, e in cses:
1190
- if e is None:
1191
- funcbody.append('del {}'.format(self._exprrepr(s)))
1192
- else:
1193
- funcbody.append('{} = {}'.format(self._exprrepr(s), self._exprrepr(e)))
1194
-
1195
- # Subs may appear in expressions generated by .diff()
1196
- subs_assignments = []
1197
- expr = self._handle_Subs(expr, out=subs_assignments)
1198
- for lhs, rhs in subs_assignments:
1199
- funcbody.append('{} = {}'.format(self._exprrepr(lhs), self._exprrepr(rhs)))
1200
-
1201
- str_expr = _recursive_to_string(self._exprrepr, expr)
1202
-
1203
- if '\n' in str_expr:
1204
- str_expr = '({})'.format(str_expr)
1205
- funcbody.append('return {}'.format(str_expr))
1206
-
1207
- funclines = [funcsig]
1208
- funclines.extend([' ' + line for line in funcbody])
1209
-
1210
- return '\n'.join(funclines) + '\n'
1211
-
1212
- @classmethod
1213
- def _is_safe_ident(cls, ident):
1214
- return isinstance(ident, str) and ident.isidentifier() \
1215
- and not keyword.iskeyword(ident)
1216
-
1217
- def _preprocess(self, args, expr, cses=(), _dummies_dict=None):
1218
- """Preprocess args, expr to replace arguments that do not map
1219
- to valid Python identifiers.
1220
-
1221
- Returns string form of args, and updated expr.
1222
- """
1223
- from sympy.core.basic import Basic
1224
- from sympy.core.sorting import ordered
1225
- from sympy.core.function import (Derivative, Function)
1226
- from sympy.core.symbol import Dummy, uniquely_named_symbol
1227
- from sympy.matrices import DeferredVector
1228
- from sympy.core.expr import Expr
1229
-
1230
- # Args of type Dummy can cause name collisions with args
1231
- # of type Symbol. Force dummify of everything in this
1232
- # situation.
1233
- dummify = self._dummify or any(
1234
- isinstance(arg, Dummy) for arg in flatten(args))
1235
-
1236
- argstrs = [None]*len(args)
1237
- if _dummies_dict is None:
1238
- _dummies_dict = {}
1239
-
1240
- def update_dummies(arg, dummy):
1241
- _dummies_dict[arg] = dummy
1242
- for repl, sub in cses:
1243
- arg = arg.xreplace({sub: repl})
1244
- _dummies_dict[arg] = dummy
1245
-
1246
- for arg, i in reversed(list(ordered(zip(args, range(len(args)))))):
1247
- if iterable(arg):
1248
- s, expr = self._preprocess(arg, expr, cses=cses, _dummies_dict=_dummies_dict)
1249
- elif isinstance(arg, DeferredVector):
1250
- s = str(arg)
1251
- elif isinstance(arg, Basic) and arg.is_symbol:
1252
- s = str(arg)
1253
- if dummify or not self._is_safe_ident(s):
1254
- dummy = Dummy()
1255
- if isinstance(expr, Expr):
1256
- dummy = uniquely_named_symbol(
1257
- dummy.name, expr, modify=lambda s: '_' + s)
1258
- s = self._argrepr(dummy)
1259
- update_dummies(arg, dummy)
1260
- expr = self._subexpr(expr, _dummies_dict)
1261
- elif dummify or isinstance(arg, (Function, Derivative)):
1262
- dummy = Dummy()
1263
- s = self._argrepr(dummy)
1264
- update_dummies(arg, dummy)
1265
- expr = self._subexpr(expr, _dummies_dict)
1266
- else:
1267
- s = str(arg)
1268
- argstrs[i] = s
1269
- return argstrs, expr
1270
-
1271
- def _subexpr(self, expr, dummies_dict):
1272
- from sympy.matrices import DeferredVector
1273
- from sympy.core.sympify import sympify
1274
-
1275
- expr = sympify(expr)
1276
- xreplace = getattr(expr, 'xreplace', None)
1277
- if xreplace is not None:
1278
- expr = xreplace(dummies_dict)
1279
- else:
1280
- if isinstance(expr, DeferredVector):
1281
- pass
1282
- elif isinstance(expr, dict):
1283
- k = [self._subexpr(sympify(a), dummies_dict) for a in expr.keys()]
1284
- v = [self._subexpr(sympify(a), dummies_dict) for a in expr.values()]
1285
- expr = dict(zip(k, v))
1286
- elif isinstance(expr, tuple):
1287
- expr = tuple(self._subexpr(sympify(a), dummies_dict) for a in expr)
1288
- elif isinstance(expr, list):
1289
- expr = [self._subexpr(sympify(a), dummies_dict) for a in expr]
1290
- return expr
1291
-
1292
- def _print_funcargwrapping(self, args):
1293
- """Generate argument wrapping code.
1294
-
1295
- args is the argument list of the generated function (strings).
1296
-
1297
- Return value is a list of lines of code that will be inserted at
1298
- the beginning of the function definition.
1299
- """
1300
- return []
1301
-
1302
- def _print_unpacking(self, unpackto, arg):
1303
- """Generate argument unpacking code.
1304
-
1305
- arg is the function argument to be unpacked (a string), and
1306
- unpackto is a list or nested lists of the variable names (strings) to
1307
- unpack to.
1308
- """
1309
- def unpack_lhs(lvalues):
1310
- return '[{}]'.format(', '.join(
1311
- unpack_lhs(val) if iterable(val) else val for val in lvalues))
1312
-
1313
- return ['{} = {}'.format(unpack_lhs(unpackto), arg)]
1314
-
1315
- def _handle_Subs(self, expr, out):
1316
- """Any instance of Subs is extracted and returned as assignment pairs."""
1317
- from sympy.core.basic import Basic
1318
- from sympy.core.function import Subs
1319
- from sympy.core.symbol import Dummy
1320
- from sympy.matrices.matrixbase import MatrixBase
1321
-
1322
- def _replace(ex, variables, point):
1323
- safe = {}
1324
- for lhs, rhs in zip(variables, point):
1325
- dummy = Dummy()
1326
- safe[lhs] = dummy
1327
- out.append((dummy, rhs))
1328
- return ex.xreplace(safe)
1329
-
1330
- if isinstance(expr, (Basic, MatrixBase)):
1331
- expr = expr.replace(Subs, _replace)
1332
- elif iterable(expr):
1333
- expr = type(expr)([self._handle_Subs(e, out) for e in expr])
1334
- return expr
1335
-
1336
- class _TensorflowEvaluatorPrinter(_EvaluatorPrinter):
1337
- def _print_unpacking(self, lvalues, rvalue):
1338
- """Generate argument unpacking code.
1339
-
1340
- This method is used when the input value is not iterable,
1341
- but can be indexed (see issue #14655).
1342
- """
1343
-
1344
- def flat_indexes(elems):
1345
- n = 0
1346
-
1347
- for el in elems:
1348
- if iterable(el):
1349
- for ndeep in flat_indexes(el):
1350
- yield (n,) + ndeep
1351
- else:
1352
- yield (n,)
1353
-
1354
- n += 1
1355
-
1356
- indexed = ', '.join('{}[{}]'.format(rvalue, ']['.join(map(str, ind)))
1357
- for ind in flat_indexes(lvalues))
1358
-
1359
- return ['[{}] = [{}]'.format(', '.join(flatten(lvalues)), indexed)]
1360
-
1361
- def _imp_namespace(expr, namespace=None):
1362
- """ Return namespace dict with function implementations
1363
-
1364
- We need to search for functions in anything that can be thrown at
1365
- us - that is - anything that could be passed as ``expr``. Examples
1366
- include SymPy expressions, as well as tuples, lists and dicts that may
1367
- contain SymPy expressions.
1368
-
1369
- Parameters
1370
- ----------
1371
- expr : object
1372
- Something passed to lambdify, that will generate valid code from
1373
- ``str(expr)``.
1374
- namespace : None or mapping
1375
- Namespace to fill. None results in new empty dict
1376
-
1377
- Returns
1378
- -------
1379
- namespace : dict
1380
- dict with keys of implemented function names within ``expr`` and
1381
- corresponding values being the numerical implementation of
1382
- function
1383
-
1384
- Examples
1385
- ========
1386
-
1387
- >>> from sympy.abc import x
1388
- >>> from sympy.utilities.lambdify import implemented_function, _imp_namespace
1389
- >>> from sympy import Function
1390
- >>> f = implemented_function(Function('f'), lambda x: x+1)
1391
- >>> g = implemented_function(Function('g'), lambda x: x*10)
1392
- >>> namespace = _imp_namespace(f(g(x)))
1393
- >>> sorted(namespace.keys())
1394
- ['f', 'g']
1395
- """
1396
- # Delayed import to avoid circular imports
1397
- from sympy.core.function import FunctionClass
1398
- if namespace is None:
1399
- namespace = {}
1400
- # tuples, lists, dicts are valid expressions
1401
- if is_sequence(expr):
1402
- for arg in expr:
1403
- _imp_namespace(arg, namespace)
1404
- return namespace
1405
- elif isinstance(expr, dict):
1406
- for key, val in expr.items():
1407
- # functions can be in dictionary keys
1408
- _imp_namespace(key, namespace)
1409
- _imp_namespace(val, namespace)
1410
- return namespace
1411
- # SymPy expressions may be Functions themselves
1412
- func = getattr(expr, 'func', None)
1413
- if isinstance(func, FunctionClass):
1414
- imp = getattr(func, '_imp_', None)
1415
- if imp is not None:
1416
- name = expr.func.__name__
1417
- if name in namespace and namespace[name] != imp:
1418
- raise ValueError('We found more than one '
1419
- 'implementation with name '
1420
- '"%s"' % name)
1421
- namespace[name] = imp
1422
- # and / or they may take Functions as arguments
1423
- if hasattr(expr, 'args'):
1424
- for arg in expr.args:
1425
- _imp_namespace(arg, namespace)
1426
- return namespace
1427
-
1428
-
1429
- def implemented_function(symfunc, implementation):
1430
- """ Add numerical ``implementation`` to function ``symfunc``.
1431
-
1432
- ``symfunc`` can be an ``UndefinedFunction`` instance, or a name string.
1433
- In the latter case we create an ``UndefinedFunction`` instance with that
1434
- name.
1435
-
1436
- Be aware that this is a quick workaround, not a general method to create
1437
- special symbolic functions. If you want to create a symbolic function to be
1438
- used by all the machinery of SymPy you should subclass the ``Function``
1439
- class.
1440
-
1441
- Parameters
1442
- ----------
1443
- symfunc : ``str`` or ``UndefinedFunction`` instance
1444
- If ``str``, then create new ``UndefinedFunction`` with this as
1445
- name. If ``symfunc`` is an Undefined function, create a new function
1446
- with the same name and the implemented function attached.
1447
- implementation : callable
1448
- numerical implementation to be called by ``evalf()`` or ``lambdify``
1449
-
1450
- Returns
1451
- -------
1452
- afunc : sympy.FunctionClass instance
1453
- function with attached implementation
1454
-
1455
- Examples
1456
- ========
1457
-
1458
- >>> from sympy.abc import x
1459
- >>> from sympy.utilities.lambdify import implemented_function
1460
- >>> from sympy import lambdify
1461
- >>> f = implemented_function('f', lambda x: x+1)
1462
- >>> lam_f = lambdify(x, f(x))
1463
- >>> lam_f(4)
1464
- 5
1465
- """
1466
- # Delayed import to avoid circular imports
1467
- from sympy.core.function import UndefinedFunction
1468
- # if name, create function to hold implementation
1469
- kwargs = {}
1470
- if isinstance(symfunc, UndefinedFunction):
1471
- kwargs = symfunc._kwargs
1472
- symfunc = symfunc.__name__
1473
- if isinstance(symfunc, str):
1474
- # Keyword arguments to UndefinedFunction are added as attributes to
1475
- # the created class.
1476
- symfunc = UndefinedFunction(
1477
- symfunc, _imp_=staticmethod(implementation), **kwargs)
1478
- elif not isinstance(symfunc, UndefinedFunction):
1479
- raise ValueError(filldedent('''
1480
- symfunc should be either a string or
1481
- an UndefinedFunction instance.'''))
1482
- return symfunc
1483
-
1484
-
1485
- def _too_large_for_docstring(expr, limit):
1486
- """Decide whether an ``Expr`` is too large to be fully rendered in a
1487
- ``lambdify`` docstring.
1488
-
1489
- This is a fast alternative to ``count_ops``, which can become prohibitively
1490
- slow for large expressions, because in this instance we only care whether
1491
- ``limit`` is exceeded rather than counting the exact number of nodes in the
1492
- expression.
1493
-
1494
- Parameters
1495
- ==========
1496
- expr : ``Expr``, (nested) ``list`` of ``Expr``, or ``Matrix``
1497
- The same objects that can be passed to the ``expr`` argument of
1498
- ``lambdify``.
1499
- limit : ``int`` or ``None``
1500
- The threshold above which an expression contains too many nodes to be
1501
- usefully rendered in the docstring. If ``None`` then there is no limit.
1502
-
1503
- Returns
1504
- =======
1505
- bool
1506
- ``True`` if the number of nodes in the expression exceeds the limit,
1507
- ``False`` otherwise.
1508
-
1509
- Examples
1510
- ========
1511
-
1512
- >>> from sympy.abc import x, y, z
1513
- >>> from sympy.utilities.lambdify import _too_large_for_docstring
1514
- >>> expr = x
1515
- >>> _too_large_for_docstring(expr, None)
1516
- False
1517
- >>> _too_large_for_docstring(expr, 100)
1518
- False
1519
- >>> _too_large_for_docstring(expr, 1)
1520
- False
1521
- >>> _too_large_for_docstring(expr, 0)
1522
- True
1523
- >>> _too_large_for_docstring(expr, -1)
1524
- True
1525
-
1526
- Does this split it?
1527
-
1528
- >>> expr = [x, y, z]
1529
- >>> _too_large_for_docstring(expr, None)
1530
- False
1531
- >>> _too_large_for_docstring(expr, 100)
1532
- False
1533
- >>> _too_large_for_docstring(expr, 1)
1534
- True
1535
- >>> _too_large_for_docstring(expr, 0)
1536
- True
1537
- >>> _too_large_for_docstring(expr, -1)
1538
- True
1539
-
1540
- >>> expr = [x, [y], z, [[x+y], [x*y*z, [x+y+z]]]]
1541
- >>> _too_large_for_docstring(expr, None)
1542
- False
1543
- >>> _too_large_for_docstring(expr, 100)
1544
- False
1545
- >>> _too_large_for_docstring(expr, 1)
1546
- True
1547
- >>> _too_large_for_docstring(expr, 0)
1548
- True
1549
- >>> _too_large_for_docstring(expr, -1)
1550
- True
1551
-
1552
- >>> expr = ((x + y + z)**5).expand()
1553
- >>> _too_large_for_docstring(expr, None)
1554
- False
1555
- >>> _too_large_for_docstring(expr, 100)
1556
- True
1557
- >>> _too_large_for_docstring(expr, 1)
1558
- True
1559
- >>> _too_large_for_docstring(expr, 0)
1560
- True
1561
- >>> _too_large_for_docstring(expr, -1)
1562
- True
1563
-
1564
- >>> from sympy import Matrix
1565
- >>> expr = Matrix([[(x + y + z), ((x + y + z)**2).expand(),
1566
- ... ((x + y + z)**3).expand(), ((x + y + z)**4).expand()]])
1567
- >>> _too_large_for_docstring(expr, None)
1568
- False
1569
- >>> _too_large_for_docstring(expr, 1000)
1570
- False
1571
- >>> _too_large_for_docstring(expr, 100)
1572
- True
1573
- >>> _too_large_for_docstring(expr, 1)
1574
- True
1575
- >>> _too_large_for_docstring(expr, 0)
1576
- True
1577
- >>> _too_large_for_docstring(expr, -1)
1578
- True
1579
-
1580
- """
1581
- # Must be imported here to avoid a circular import error
1582
- from sympy.core.traversal import postorder_traversal
1583
-
1584
- if limit is None:
1585
- return False
1586
-
1587
- i = 0
1588
- for _ in postorder_traversal(expr):
1589
- i += 1
1590
- if i > limit:
1591
- return True
1592
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/magic.py DELETED
@@ -1,12 +0,0 @@
1
- """Functions that involve magic. """
2
-
3
- def pollute(names, objects):
4
- """Pollute the global namespace with symbols -> objects mapping. """
5
- from inspect import currentframe
6
- frame = currentframe().f_back.f_back
7
-
8
- try:
9
- for name, obj in zip(names, objects):
10
- frame.f_globals[name] = obj
11
- finally:
12
- del frame # break cyclic dependencies as stated in inspect docs
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/matchpy_connector.py DELETED
@@ -1,340 +0,0 @@
1
- """
2
- The objects in this module allow the usage of the MatchPy pattern matching
3
- library on SymPy expressions.
4
- """
5
- import re
6
- from typing import List, Callable, NamedTuple, Any, Dict
7
-
8
- from sympy.core.sympify import _sympify
9
- from sympy.external import import_module
10
- from sympy.functions import (log, sin, cos, tan, cot, csc, sec, erf, gamma, uppergamma)
11
- from sympy.functions.elementary.hyperbolic import acosh, asinh, atanh, acoth, acsch, asech, cosh, sinh, tanh, coth, sech, csch
12
- from sympy.functions.elementary.trigonometric import atan, acsc, asin, acot, acos, asec
13
- from sympy.functions.special.error_functions import fresnelc, fresnels, erfc, erfi, Ei
14
- from sympy.core.add import Add
15
- from sympy.core.basic import Basic
16
- from sympy.core.expr import Expr
17
- from sympy.core.mul import Mul
18
- from sympy.core.power import Pow
19
- from sympy.core.relational import (Equality, Unequality)
20
- from sympy.core.symbol import Symbol
21
- from sympy.functions.elementary.exponential import exp
22
- from sympy.integrals.integrals import Integral
23
- from sympy.printing.repr import srepr
24
- from sympy.utilities.decorator import doctest_depends_on
25
-
26
-
27
- matchpy = import_module("matchpy")
28
-
29
-
30
- __doctest_requires__ = {('*',): ['matchpy']}
31
-
32
-
33
- if matchpy:
34
- from matchpy import Operation, CommutativeOperation, AssociativeOperation, OneIdentityOperation
35
- from matchpy.expressions.functions import op_iter, create_operation_expression, op_len
36
-
37
- Operation.register(Integral)
38
- Operation.register(Pow)
39
- OneIdentityOperation.register(Pow)
40
-
41
- Operation.register(Add)
42
- OneIdentityOperation.register(Add)
43
- CommutativeOperation.register(Add)
44
- AssociativeOperation.register(Add)
45
-
46
- Operation.register(Mul)
47
- OneIdentityOperation.register(Mul)
48
- CommutativeOperation.register(Mul)
49
- AssociativeOperation.register(Mul)
50
-
51
- Operation.register(Equality)
52
- CommutativeOperation.register(Equality)
53
- Operation.register(Unequality)
54
- CommutativeOperation.register(Unequality)
55
-
56
- Operation.register(exp)
57
- Operation.register(log)
58
- Operation.register(gamma)
59
- Operation.register(uppergamma)
60
- Operation.register(fresnels)
61
- Operation.register(fresnelc)
62
- Operation.register(erf)
63
- Operation.register(Ei)
64
- Operation.register(erfc)
65
- Operation.register(erfi)
66
- Operation.register(sin)
67
- Operation.register(cos)
68
- Operation.register(tan)
69
- Operation.register(cot)
70
- Operation.register(csc)
71
- Operation.register(sec)
72
- Operation.register(sinh)
73
- Operation.register(cosh)
74
- Operation.register(tanh)
75
- Operation.register(coth)
76
- Operation.register(csch)
77
- Operation.register(sech)
78
- Operation.register(asin)
79
- Operation.register(acos)
80
- Operation.register(atan)
81
- Operation.register(acot)
82
- Operation.register(acsc)
83
- Operation.register(asec)
84
- Operation.register(asinh)
85
- Operation.register(acosh)
86
- Operation.register(atanh)
87
- Operation.register(acoth)
88
- Operation.register(acsch)
89
- Operation.register(asech)
90
-
91
- @op_iter.register(Integral) # type: ignore
92
- def _(operation):
93
- return iter((operation._args[0],) + operation._args[1])
94
-
95
- @op_iter.register(Basic) # type: ignore
96
- def _(operation):
97
- return iter(operation._args)
98
-
99
- @op_len.register(Integral) # type: ignore
100
- def _(operation):
101
- return 1 + len(operation._args[1])
102
-
103
- @op_len.register(Basic) # type: ignore
104
- def _(operation):
105
- return len(operation._args)
106
-
107
- @create_operation_expression.register(Basic)
108
- def sympy_op_factory(old_operation, new_operands, variable_name=True):
109
- return type(old_operation)(*new_operands)
110
-
111
-
112
- if matchpy:
113
- from matchpy import Wildcard
114
- else:
115
- class Wildcard: # type: ignore
116
- def __init__(self, min_length, fixed_size, variable_name, optional):
117
- self.min_count = min_length
118
- self.fixed_size = fixed_size
119
- self.variable_name = variable_name
120
- self.optional = optional
121
-
122
-
123
- @doctest_depends_on(modules=('matchpy',))
124
- class _WildAbstract(Wildcard, Symbol):
125
- min_length: int # abstract field required in subclasses
126
- fixed_size: bool # abstract field required in subclasses
127
-
128
- def __init__(self, variable_name=None, optional=None, **assumptions):
129
- min_length = self.min_length
130
- fixed_size = self.fixed_size
131
- if optional is not None:
132
- optional = _sympify(optional)
133
- Wildcard.__init__(self, min_length, fixed_size, str(variable_name), optional)
134
-
135
- def __getstate__(self):
136
- return {
137
- "min_length": self.min_length,
138
- "fixed_size": self.fixed_size,
139
- "min_count": self.min_count,
140
- "variable_name": self.variable_name,
141
- "optional": self.optional,
142
- }
143
-
144
- def __new__(cls, variable_name=None, optional=None, **assumptions):
145
- cls._sanitize(assumptions, cls)
146
- return _WildAbstract.__xnew__(cls, variable_name, optional, **assumptions)
147
-
148
- def __getnewargs__(self):
149
- return self.variable_name, self.optional
150
-
151
- @staticmethod
152
- def __xnew__(cls, variable_name=None, optional=None, **assumptions):
153
- obj = Symbol.__xnew__(cls, variable_name, **assumptions)
154
- return obj
155
-
156
- def _hashable_content(self):
157
- if self.optional:
158
- return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name, self.optional)
159
- else:
160
- return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name)
161
-
162
- def __copy__(self) -> '_WildAbstract':
163
- return type(self)(variable_name=self.variable_name, optional=self.optional)
164
-
165
- def __repr__(self):
166
- return str(self)
167
-
168
- def __str__(self):
169
- return self.name
170
-
171
-
172
- @doctest_depends_on(modules=('matchpy',))
173
- class WildDot(_WildAbstract):
174
- min_length = 1
175
- fixed_size = True
176
-
177
-
178
- @doctest_depends_on(modules=('matchpy',))
179
- class WildPlus(_WildAbstract):
180
- min_length = 1
181
- fixed_size = False
182
-
183
-
184
- @doctest_depends_on(modules=('matchpy',))
185
- class WildStar(_WildAbstract):
186
- min_length = 0
187
- fixed_size = False
188
-
189
-
190
- def _get_srepr(expr):
191
- s = srepr(expr)
192
- s = re.sub(r"WildDot\('(\w+)'\)", r"\1", s)
193
- s = re.sub(r"WildPlus\('(\w+)'\)", r"*\1", s)
194
- s = re.sub(r"WildStar\('(\w+)'\)", r"*\1", s)
195
- return s
196
-
197
-
198
- class ReplacementInfo(NamedTuple):
199
- replacement: Any
200
- info: Any
201
-
202
-
203
- @doctest_depends_on(modules=('matchpy',))
204
- class Replacer:
205
- """
206
- Replacer object to perform multiple pattern matching and subexpression
207
- replacements in SymPy expressions.
208
-
209
- Examples
210
- ========
211
-
212
- Example to construct a simple first degree equation solver:
213
-
214
- >>> from sympy.utilities.matchpy_connector import WildDot, Replacer
215
- >>> from sympy import Equality, Symbol
216
- >>> x = Symbol("x")
217
- >>> a_ = WildDot("a_", optional=1)
218
- >>> b_ = WildDot("b_", optional=0)
219
-
220
- The lines above have defined two wildcards, ``a_`` and ``b_``, the
221
- coefficients of the equation `a x + b = 0`. The optional values specified
222
- indicate which expression to return in case no match is found, they are
223
- necessary in equations like `a x = 0` and `x + b = 0`.
224
-
225
- Create two constraints to make sure that ``a_`` and ``b_`` will not match
226
- any expression containing ``x``:
227
-
228
- >>> from matchpy import CustomConstraint
229
- >>> free_x_a = CustomConstraint(lambda a_: not a_.has(x))
230
- >>> free_x_b = CustomConstraint(lambda b_: not b_.has(x))
231
-
232
- Now create the rule replacer with the constraints:
233
-
234
- >>> replacer = Replacer(common_constraints=[free_x_a, free_x_b])
235
-
236
- Add the matching rule:
237
-
238
- >>> replacer.add(Equality(a_*x + b_, 0), -b_/a_)
239
-
240
- Let's try it:
241
-
242
- >>> replacer.replace(Equality(3*x + 4, 0))
243
- -4/3
244
-
245
- Notice that it will not match equations expressed with other patterns:
246
-
247
- >>> eq = Equality(3*x, 4)
248
- >>> replacer.replace(eq)
249
- Eq(3*x, 4)
250
-
251
- In order to extend the matching patterns, define another one (we also need
252
- to clear the cache, because the previous result has already been memorized
253
- and the pattern matcher will not iterate again if given the same expression)
254
-
255
- >>> replacer.add(Equality(a_*x, b_), b_/a_)
256
- >>> replacer._matcher.clear()
257
- >>> replacer.replace(eq)
258
- 4/3
259
- """
260
-
261
- def __init__(self, common_constraints: list = [], lambdify: bool = False, info: bool = False):
262
- self._matcher = matchpy.ManyToOneMatcher()
263
- self._common_constraint = common_constraints
264
- self._lambdify = lambdify
265
- self._info = info
266
- self._wildcards: Dict[str, Wildcard] = {}
267
-
268
- def _get_lambda(self, lambda_str: str) -> Callable[..., Expr]:
269
- exec("from sympy import *")
270
- return eval(lambda_str, locals())
271
-
272
- def _get_custom_constraint(self, constraint_expr: Expr, condition_template: str) -> Callable[..., Expr]:
273
- wilds = [x.name for x in constraint_expr.atoms(_WildAbstract)]
274
- lambdaargs = ', '.join(wilds)
275
- fullexpr = _get_srepr(constraint_expr)
276
- condition = condition_template.format(fullexpr)
277
- return matchpy.CustomConstraint(
278
- self._get_lambda(f"lambda {lambdaargs}: ({condition})"))
279
-
280
- def _get_custom_constraint_nonfalse(self, constraint_expr: Expr) -> Callable[..., Expr]:
281
- return self._get_custom_constraint(constraint_expr, "({}) != False")
282
-
283
- def _get_custom_constraint_true(self, constraint_expr: Expr) -> Callable[..., Expr]:
284
- return self._get_custom_constraint(constraint_expr, "({}) == True")
285
-
286
- def add(self, expr: Expr, replacement, conditions_true: List[Expr] = [],
287
- conditions_nonfalse: List[Expr] = [], info: Any = None) -> None:
288
- expr = _sympify(expr)
289
- replacement = _sympify(replacement)
290
- constraints = self._common_constraint[:]
291
- constraint_conditions_true = [
292
- self._get_custom_constraint_true(cond) for cond in conditions_true]
293
- constraint_conditions_nonfalse = [
294
- self._get_custom_constraint_nonfalse(cond) for cond in conditions_nonfalse]
295
- constraints.extend(constraint_conditions_true)
296
- constraints.extend(constraint_conditions_nonfalse)
297
- pattern = matchpy.Pattern(expr, *constraints)
298
- if self._lambdify:
299
- lambda_str = f"lambda {', '.join((x.name for x in expr.atoms(_WildAbstract)))}: {_get_srepr(replacement)}"
300
- lambda_expr = self._get_lambda(lambda_str)
301
- replacement = lambda_expr
302
- else:
303
- self._wildcards.update({str(i): i for i in expr.atoms(Wildcard)})
304
- if self._info:
305
- replacement = ReplacementInfo(replacement, info)
306
- self._matcher.add(pattern, replacement)
307
-
308
- def replace(self, expression, max_count: int = -1):
309
- # This method partly rewrites the .replace method of ManyToOneReplacer
310
- # in MatchPy.
311
- # License: https://github.com/HPAC/matchpy/blob/master/LICENSE
312
- infos = []
313
- replaced = True
314
- replace_count = 0
315
- while replaced and (max_count < 0 or replace_count < max_count):
316
- replaced = False
317
- for subexpr, pos in matchpy.preorder_iter_with_position(expression):
318
- try:
319
- replacement_data, subst = next(iter(self._matcher.match(subexpr)))
320
- if self._info:
321
- replacement = replacement_data.replacement
322
- infos.append(replacement_data.info)
323
- else:
324
- replacement = replacement_data
325
-
326
- if self._lambdify:
327
- result = replacement(**subst)
328
- else:
329
- result = replacement.xreplace({self._wildcards[k]: v for k, v in subst.items()})
330
-
331
- expression = matchpy.functions.replace(expression, pos, result)
332
- replaced = True
333
- break
334
- except StopIteration:
335
- pass
336
- replace_count += 1
337
- if self._info:
338
- return expression, infos
339
- else:
340
- return expression
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/__init__.py DELETED
@@ -1,122 +0,0 @@
1
- """Module with some functions for MathML, like transforming MathML
2
- content in MathML presentation.
3
-
4
- To use this module, you will need lxml.
5
- """
6
-
7
- from pathlib import Path
8
-
9
- from sympy.utilities.decorator import doctest_depends_on
10
-
11
-
12
- __doctest_requires__ = {('apply_xsl', 'c2p'): ['lxml']}
13
-
14
-
15
- def add_mathml_headers(s):
16
- return """<math xmlns:mml="http://www.w3.org/1998/Math/MathML"
17
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
18
- xsi:schemaLocation="http://www.w3.org/1998/Math/MathML
19
- http://www.w3.org/Math/XMLSchema/mathml2/mathml2.xsd">""" + s + "</math>"
20
-
21
-
22
- def _read_binary(pkgname, filename):
23
- import sys
24
-
25
- if sys.version_info >= (3, 10):
26
- # files was added in Python 3.9 but only seems to work here in 3.10+
27
- from importlib.resources import files
28
- return files(pkgname).joinpath(filename).read_bytes()
29
- else:
30
- # read_binary was deprecated in Python 3.11
31
- from importlib.resources import read_binary
32
- return read_binary(pkgname, filename)
33
-
34
-
35
- def _read_xsl(xsl):
36
- # Previously these values were allowed:
37
- if xsl == 'mathml/data/simple_mmlctop.xsl':
38
- xsl = 'simple_mmlctop.xsl'
39
- elif xsl == 'mathml/data/mmlctop.xsl':
40
- xsl = 'mmlctop.xsl'
41
- elif xsl == 'mathml/data/mmltex.xsl':
42
- xsl = 'mmltex.xsl'
43
-
44
- if xsl in ['simple_mmlctop.xsl', 'mmlctop.xsl', 'mmltex.xsl']:
45
- xslbytes = _read_binary('sympy.utilities.mathml.data', xsl)
46
- else:
47
- xslbytes = Path(xsl).read_bytes()
48
-
49
- return xslbytes
50
-
51
-
52
- @doctest_depends_on(modules=('lxml',))
53
- def apply_xsl(mml, xsl):
54
- """Apply a xsl to a MathML string.
55
-
56
- Parameters
57
- ==========
58
-
59
- mml
60
- A string with MathML code.
61
- xsl
62
- A string giving the name of an xsl (xml stylesheet) file which can be
63
- found in sympy/utilities/mathml/data. The following files are supplied
64
- with SymPy:
65
-
66
- - mmlctop.xsl
67
- - mmltex.xsl
68
- - simple_mmlctop.xsl
69
-
70
- Alternatively, a full path to an xsl file can be given.
71
-
72
- Examples
73
- ========
74
-
75
- >>> from sympy.utilities.mathml import apply_xsl
76
- >>> xsl = 'simple_mmlctop.xsl'
77
- >>> mml = '<apply> <plus/> <ci>a</ci> <ci>b</ci> </apply>'
78
- >>> res = apply_xsl(mml,xsl)
79
- >>> print(res)
80
- <?xml version="1.0"?>
81
- <mrow xmlns="http://www.w3.org/1998/Math/MathML">
82
- <mi>a</mi>
83
- <mo> + </mo>
84
- <mi>b</mi>
85
- </mrow>
86
- """
87
- from lxml import etree
88
-
89
- parser = etree.XMLParser(resolve_entities=False)
90
- ac = etree.XSLTAccessControl.DENY_ALL
91
-
92
- s = etree.XML(_read_xsl(xsl), parser=parser)
93
- transform = etree.XSLT(s, access_control=ac)
94
- doc = etree.XML(mml, parser=parser)
95
- result = transform(doc)
96
- s = str(result)
97
- return s
98
-
99
-
100
- @doctest_depends_on(modules=('lxml',))
101
- def c2p(mml, simple=False):
102
- """Transforms a document in MathML content (like the one that sympy produces)
103
- in one document in MathML presentation, more suitable for printing, and more
104
- widely accepted
105
-
106
- Examples
107
- ========
108
-
109
- >>> from sympy.utilities.mathml import c2p
110
- >>> mml = '<apply> <exp/> <cn>2</cn> </apply>'
111
- >>> c2p(mml,simple=True) != c2p(mml,simple=False)
112
- True
113
-
114
- """
115
-
116
- if not mml.startswith('<math'):
117
- mml = add_mathml_headers(mml)
118
-
119
- if simple:
120
- return apply_xsl(mml, 'mathml/data/simple_mmlctop.xsl')
121
-
122
- return apply_xsl(mml, 'mathml/data/mmlctop.xsl')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/__init__.py DELETED
File without changes
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmlctop.xsl DELETED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmltex.xsl DELETED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/simple_mmlctop.xsl DELETED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/utilities/memoization.py DELETED
@@ -1,76 +0,0 @@
1
- from functools import wraps
2
-
3
-
4
- def recurrence_memo(initial):
5
- """
6
- Memo decorator for sequences defined by recurrence
7
-
8
- Examples
9
- ========
10
-
11
- >>> from sympy.utilities.memoization import recurrence_memo
12
- >>> @recurrence_memo([1]) # 0! = 1
13
- ... def factorial(n, prev):
14
- ... return n * prev[-1]
15
- >>> factorial(4)
16
- 24
17
- >>> factorial(3) # use cache values
18
- 6
19
- >>> factorial.cache_length() # cache length can be obtained
20
- 5
21
- >>> factorial.fetch_item(slice(2, 4))
22
- [2, 6]
23
-
24
- """
25
- cache = initial
26
-
27
- def decorator(f):
28
- @wraps(f)
29
- def g(n):
30
- L = len(cache)
31
- if n < L:
32
- return cache[n]
33
- for i in range(L, n + 1):
34
- cache.append(f(i, cache))
35
- return cache[-1]
36
- g.cache_length = lambda: len(cache)
37
- g.fetch_item = lambda x: cache[x]
38
- return g
39
- return decorator
40
-
41
-
42
- def assoc_recurrence_memo(base_seq):
43
- """
44
- Memo decorator for associated sequences defined by recurrence starting from base
45
-
46
- base_seq(n) -- callable to get base sequence elements
47
-
48
- XXX works only for Pn0 = base_seq(0) cases
49
- XXX works only for m <= n cases
50
- """
51
-
52
- cache = []
53
-
54
- def decorator(f):
55
- @wraps(f)
56
- def g(n, m):
57
- L = len(cache)
58
- if n < L:
59
- return cache[n][m]
60
-
61
- for i in range(L, n + 1):
62
- # get base sequence
63
- F_i0 = base_seq(i)
64
- F_i_cache = [F_i0]
65
- cache.append(F_i_cache)
66
-
67
- # XXX only works for m <= n cases
68
- # generate assoc sequence
69
- for j in range(1, i + 1):
70
- F_ij = f(i, j, cache)
71
- F_i_cache.append(F_ij)
72
-
73
- return cache[n][m]
74
-
75
- return g
76
- return decorator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/misc.py DELETED
@@ -1,564 +0,0 @@
1
- """Miscellaneous stuff that does not really fit anywhere else."""
2
-
3
- from __future__ import annotations
4
-
5
- import operator
6
- import sys
7
- import os
8
- import re as _re
9
- import struct
10
- from textwrap import fill, dedent
11
-
12
-
13
- class Undecidable(ValueError):
14
- # an error to be raised when a decision cannot be made definitively
15
- # where a definitive answer is needed
16
- pass
17
-
18
-
19
- def filldedent(s, w=70, **kwargs):
20
- """
21
- Strips leading and trailing empty lines from a copy of ``s``, then dedents,
22
- fills and returns it.
23
-
24
- Empty line stripping serves to deal with docstrings like this one that
25
- start with a newline after the initial triple quote, inserting an empty
26
- line at the beginning of the string.
27
-
28
- Additional keyword arguments will be passed to ``textwrap.fill()``.
29
-
30
- See Also
31
- ========
32
- strlines, rawlines
33
-
34
- """
35
- return '\n' + fill(dedent(str(s)).strip('\n'), width=w, **kwargs)
36
-
37
-
38
- def strlines(s, c=64, short=False):
39
- """Return a cut-and-pastable string that, when printed, is
40
- equivalent to the input. The lines will be surrounded by
41
- parentheses and no line will be longer than c (default 64)
42
- characters. If the line contains newlines characters, the
43
- `rawlines` result will be returned. If ``short`` is True
44
- (default is False) then if there is one line it will be
45
- returned without bounding parentheses.
46
-
47
- Examples
48
- ========
49
-
50
- >>> from sympy.utilities.misc import strlines
51
- >>> q = 'this is a long string that should be broken into shorter lines'
52
- >>> print(strlines(q, 40))
53
- (
54
- 'this is a long string that should be b'
55
- 'roken into shorter lines'
56
- )
57
- >>> q == (
58
- ... 'this is a long string that should be b'
59
- ... 'roken into shorter lines'
60
- ... )
61
- True
62
-
63
- See Also
64
- ========
65
- filldedent, rawlines
66
- """
67
- if not isinstance(s, str):
68
- raise ValueError('expecting string input')
69
- if '\n' in s:
70
- return rawlines(s)
71
- q = '"' if repr(s).startswith('"') else "'"
72
- q = (q,)*2
73
- if '\\' in s: # use r-string
74
- m = '(\nr%s%%s%s\n)' % q
75
- j = '%s\nr%s' % q
76
- c -= 3
77
- else:
78
- m = '(\n%s%%s%s\n)' % q
79
- j = '%s\n%s' % q
80
- c -= 2
81
- out = []
82
- while s:
83
- out.append(s[:c])
84
- s=s[c:]
85
- if short and len(out) == 1:
86
- return (m % out[0]).splitlines()[1] # strip bounding (\n...\n)
87
- return m % j.join(out)
88
-
89
-
90
- def rawlines(s):
91
- """Return a cut-and-pastable string that, when printed, is equivalent
92
- to the input. Use this when there is more than one line in the
93
- string. The string returned is formatted so it can be indented
94
- nicely within tests; in some cases it is wrapped in the dedent
95
- function which has to be imported from textwrap.
96
-
97
- Examples
98
- ========
99
-
100
- Note: because there are characters in the examples below that need
101
- to be escaped because they are themselves within a triple quoted
102
- docstring, expressions below look more complicated than they would
103
- be if they were printed in an interpreter window.
104
-
105
- >>> from sympy.utilities.misc import rawlines
106
- >>> from sympy import TableForm
107
- >>> s = str(TableForm([[1, 10]], headings=(None, ['a', 'bee'])))
108
- >>> print(rawlines(s))
109
- (
110
- 'a bee\\n'
111
- '-----\\n'
112
- '1 10 '
113
- )
114
- >>> print(rawlines('''this
115
- ... that'''))
116
- dedent('''\\
117
- this
118
- that''')
119
-
120
- >>> print(rawlines('''this
121
- ... that
122
- ... '''))
123
- dedent('''\\
124
- this
125
- that
126
- ''')
127
-
128
- >>> s = \"\"\"this
129
- ... is a triple '''
130
- ... \"\"\"
131
- >>> print(rawlines(s))
132
- dedent(\"\"\"\\
133
- this
134
- is a triple '''
135
- \"\"\")
136
-
137
- >>> print(rawlines('''this
138
- ... that
139
- ... '''))
140
- (
141
- 'this\\n'
142
- 'that\\n'
143
- ' '
144
- )
145
-
146
- See Also
147
- ========
148
- filldedent, strlines
149
- """
150
- lines = s.split('\n')
151
- if len(lines) == 1:
152
- return repr(lines[0])
153
- triple = ["'''" in s, '"""' in s]
154
- if any(li.endswith(' ') for li in lines) or '\\' in s or all(triple):
155
- rv = []
156
- # add on the newlines
157
- trailing = s.endswith('\n')
158
- last = len(lines) - 1
159
- for i, li in enumerate(lines):
160
- if i != last or trailing:
161
- rv.append(repr(li + '\n'))
162
- else:
163
- rv.append(repr(li))
164
- return '(\n %s\n)' % '\n '.join(rv)
165
- else:
166
- rv = '\n '.join(lines)
167
- if triple[0]:
168
- return 'dedent("""\\\n %s""")' % rv
169
- else:
170
- return "dedent('''\\\n %s''')" % rv
171
-
172
- ARCH = str(struct.calcsize('P') * 8) + "-bit"
173
-
174
-
175
- # XXX: PyPy does not support hash randomization
176
- HASH_RANDOMIZATION = getattr(sys.flags, 'hash_randomization', False)
177
-
178
- _debug_tmp: list[str] = []
179
- _debug_iter = 0
180
-
181
- def debug_decorator(func):
182
- """If SYMPY_DEBUG is True, it will print a nice execution tree with
183
- arguments and results of all decorated functions, else do nothing.
184
- """
185
- from sympy import SYMPY_DEBUG
186
-
187
- if not SYMPY_DEBUG:
188
- return func
189
-
190
- def maketree(f, *args, **kw):
191
- global _debug_tmp, _debug_iter
192
- oldtmp = _debug_tmp
193
- _debug_tmp = []
194
- _debug_iter += 1
195
-
196
- def tree(subtrees):
197
- def indent(s, variant=1):
198
- x = s.split("\n")
199
- r = "+-%s\n" % x[0]
200
- for a in x[1:]:
201
- if a == "":
202
- continue
203
- if variant == 1:
204
- r += "| %s\n" % a
205
- else:
206
- r += " %s\n" % a
207
- return r
208
- if len(subtrees) == 0:
209
- return ""
210
- f = []
211
- for a in subtrees[:-1]:
212
- f.append(indent(a))
213
- f.append(indent(subtrees[-1], 2))
214
- return ''.join(f)
215
-
216
- # If there is a bug and the algorithm enters an infinite loop, enable the
217
- # following lines. It will print the names and parameters of all major functions
218
- # that are called, *before* they are called
219
- #from functools import reduce
220
- #print("%s%s %s%s" % (_debug_iter, reduce(lambda x, y: x + y, \
221
- # map(lambda x: '-', range(1, 2 + _debug_iter))), f.__name__, args))
222
-
223
- r = f(*args, **kw)
224
-
225
- _debug_iter -= 1
226
- s = "%s%s = %s\n" % (f.__name__, args, r)
227
- if _debug_tmp != []:
228
- s += tree(_debug_tmp)
229
- _debug_tmp = oldtmp
230
- _debug_tmp.append(s)
231
- if _debug_iter == 0:
232
- print(_debug_tmp[0])
233
- _debug_tmp = []
234
- return r
235
-
236
- def decorated(*args, **kwargs):
237
- return maketree(func, *args, **kwargs)
238
-
239
- return decorated
240
-
241
-
242
- def debug(*args):
243
- """
244
- Print ``*args`` if SYMPY_DEBUG is True, else do nothing.
245
- """
246
- from sympy import SYMPY_DEBUG
247
- if SYMPY_DEBUG:
248
- print(*args, file=sys.stderr)
249
-
250
-
251
- def debugf(string, args):
252
- """
253
- Print ``string%args`` if SYMPY_DEBUG is True, else do nothing. This is
254
- intended for debug messages using formatted strings.
255
- """
256
- from sympy import SYMPY_DEBUG
257
- if SYMPY_DEBUG:
258
- print(string%args, file=sys.stderr)
259
-
260
-
261
- def find_executable(executable, path=None):
262
- """Try to find 'executable' in the directories listed in 'path' (a
263
- string listing directories separated by 'os.pathsep'; defaults to
264
- os.environ['PATH']). Returns the complete filename or None if not
265
- found
266
- """
267
- from .exceptions import sympy_deprecation_warning
268
- sympy_deprecation_warning(
269
- """
270
- sympy.utilities.misc.find_executable() is deprecated. Use the standard
271
- library shutil.which() function instead.
272
- """,
273
- deprecated_since_version="1.7",
274
- active_deprecations_target="deprecated-find-executable",
275
- )
276
- if path is None:
277
- path = os.environ['PATH']
278
- paths = path.split(os.pathsep)
279
- extlist = ['']
280
- if os.name == 'os2':
281
- (base, ext) = os.path.splitext(executable)
282
- # executable files on OS/2 can have an arbitrary extension, but
283
- # .exe is automatically appended if no dot is present in the name
284
- if not ext:
285
- executable = executable + ".exe"
286
- elif sys.platform == 'win32':
287
- pathext = os.environ['PATHEXT'].lower().split(os.pathsep)
288
- (base, ext) = os.path.splitext(executable)
289
- if ext.lower() not in pathext:
290
- extlist = pathext
291
- for ext in extlist:
292
- execname = executable + ext
293
- if os.path.isfile(execname):
294
- return execname
295
- else:
296
- for p in paths:
297
- f = os.path.join(p, execname)
298
- if os.path.isfile(f):
299
- return f
300
-
301
- return None
302
-
303
-
304
- def func_name(x, short=False):
305
- """Return function name of `x` (if defined) else the `type(x)`.
306
- If short is True and there is a shorter alias for the result,
307
- return the alias.
308
-
309
- Examples
310
- ========
311
-
312
- >>> from sympy.utilities.misc import func_name
313
- >>> from sympy import Matrix
314
- >>> from sympy.abc import x
315
- >>> func_name(Matrix.eye(3))
316
- 'MutableDenseMatrix'
317
- >>> func_name(x < 1)
318
- 'StrictLessThan'
319
- >>> func_name(x < 1, short=True)
320
- 'Lt'
321
- """
322
- alias = {
323
- 'GreaterThan': 'Ge',
324
- 'StrictGreaterThan': 'Gt',
325
- 'LessThan': 'Le',
326
- 'StrictLessThan': 'Lt',
327
- 'Equality': 'Eq',
328
- 'Unequality': 'Ne',
329
- }
330
- typ = type(x)
331
- if str(typ).startswith("<type '"):
332
- typ = str(typ).split("'")[1].split("'")[0]
333
- elif str(typ).startswith("<class '"):
334
- typ = str(typ).split("'")[1].split("'")[0]
335
- rv = getattr(getattr(x, 'func', x), '__name__', typ)
336
- if '.' in rv:
337
- rv = rv.split('.')[-1]
338
- if short:
339
- rv = alias.get(rv, rv)
340
- return rv
341
-
342
-
343
- def _replace(reps):
344
- """Return a function that can make the replacements, given in
345
- ``reps``, on a string. The replacements should be given as mapping.
346
-
347
- Examples
348
- ========
349
-
350
- >>> from sympy.utilities.misc import _replace
351
- >>> f = _replace(dict(foo='bar', d='t'))
352
- >>> f('food')
353
- 'bart'
354
- >>> f = _replace({})
355
- >>> f('food')
356
- 'food'
357
- """
358
- if not reps:
359
- return lambda x: x
360
- D = lambda match: reps[match.group(0)]
361
- pattern = _re.compile("|".join(
362
- [_re.escape(k) for k, v in reps.items()]), _re.MULTILINE)
363
- return lambda string: pattern.sub(D, string)
364
-
365
-
366
- def replace(string, *reps):
367
- """Return ``string`` with all keys in ``reps`` replaced with
368
- their corresponding values, longer strings first, irrespective
369
- of the order they are given. ``reps`` may be passed as tuples
370
- or a single mapping.
371
-
372
- Examples
373
- ========
374
-
375
- >>> from sympy.utilities.misc import replace
376
- >>> replace('foo', {'oo': 'ar', 'f': 'b'})
377
- 'bar'
378
- >>> replace("spamham sha", ("spam", "eggs"), ("sha","md5"))
379
- 'eggsham md5'
380
-
381
- There is no guarantee that a unique answer will be
382
- obtained if keys in a mapping overlap (i.e. are the same
383
- length and have some identical sequence at the
384
- beginning/end):
385
-
386
- >>> reps = [
387
- ... ('ab', 'x'),
388
- ... ('bc', 'y')]
389
- >>> replace('abc', *reps) in ('xc', 'ay')
390
- True
391
-
392
- References
393
- ==========
394
-
395
- .. [1] https://stackoverflow.com/questions/6116978/how-to-replace-multiple-substrings-of-a-string
396
- """
397
- if len(reps) == 1:
398
- kv = reps[0]
399
- if isinstance(kv, dict):
400
- reps = kv
401
- else:
402
- return string.replace(*kv)
403
- else:
404
- reps = dict(reps)
405
- return _replace(reps)(string)
406
-
407
-
408
- def translate(s, a, b=None, c=None):
409
- """Return ``s`` where characters have been replaced or deleted.
410
-
411
- SYNTAX
412
- ======
413
-
414
- translate(s, None, deletechars):
415
- all characters in ``deletechars`` are deleted
416
- translate(s, map [,deletechars]):
417
- all characters in ``deletechars`` (if provided) are deleted
418
- then the replacements defined by map are made; if the keys
419
- of map are strings then the longer ones are handled first.
420
- Multicharacter deletions should have a value of ''.
421
- translate(s, oldchars, newchars, deletechars)
422
- all characters in ``deletechars`` are deleted
423
- then each character in ``oldchars`` is replaced with the
424
- corresponding character in ``newchars``
425
-
426
- Examples
427
- ========
428
-
429
- >>> from sympy.utilities.misc import translate
430
- >>> abc = 'abc'
431
- >>> translate(abc, None, 'a')
432
- 'bc'
433
- >>> translate(abc, {'a': 'x'}, 'c')
434
- 'xb'
435
- >>> translate(abc, {'abc': 'x', 'a': 'y'})
436
- 'x'
437
-
438
- >>> translate('abcd', 'ac', 'AC', 'd')
439
- 'AbC'
440
-
441
- There is no guarantee that a unique answer will be
442
- obtained if keys in a mapping overlap are the same
443
- length and have some identical sequences at the
444
- beginning/end:
445
-
446
- >>> translate(abc, {'ab': 'x', 'bc': 'y'}) in ('xc', 'ay')
447
- True
448
- """
449
-
450
- mr = {}
451
- if a is None:
452
- if c is not None:
453
- raise ValueError('c should be None when a=None is passed, instead got %s' % c)
454
- if b is None:
455
- return s
456
- c = b
457
- a = b = ''
458
- else:
459
- if isinstance(a, dict):
460
- short = {}
461
- for k in list(a.keys()):
462
- if len(k) == 1 and len(a[k]) == 1:
463
- short[k] = a.pop(k)
464
- mr = a
465
- c = b
466
- if short:
467
- a, b = [''.join(i) for i in list(zip(*short.items()))]
468
- else:
469
- a = b = ''
470
- elif len(a) != len(b):
471
- raise ValueError('oldchars and newchars have different lengths')
472
-
473
- if c:
474
- val = str.maketrans('', '', c)
475
- s = s.translate(val)
476
- s = replace(s, mr)
477
- n = str.maketrans(a, b)
478
- return s.translate(n)
479
-
480
-
481
- def ordinal(num):
482
- """Return ordinal number string of num, e.g. 1 becomes 1st.
483
- """
484
- # modified from https://codereview.stackexchange.com/questions/41298/producing-ordinal-numbers
485
- n = as_int(num)
486
- k = abs(n) % 100
487
- if 11 <= k <= 13:
488
- suffix = 'th'
489
- elif k % 10 == 1:
490
- suffix = 'st'
491
- elif k % 10 == 2:
492
- suffix = 'nd'
493
- elif k % 10 == 3:
494
- suffix = 'rd'
495
- else:
496
- suffix = 'th'
497
- return str(n) + suffix
498
-
499
-
500
- def as_int(n, strict=True):
501
- """
502
- Convert the argument to a builtin integer.
503
-
504
- The return value is guaranteed to be equal to the input. ValueError is
505
- raised if the input has a non-integral value. When ``strict`` is True, this
506
- uses `__index__ <https://docs.python.org/3/reference/datamodel.html#object.__index__>`_
507
- and when it is False it uses ``int``.
508
-
509
-
510
- Examples
511
- ========
512
-
513
- >>> from sympy.utilities.misc import as_int
514
- >>> from sympy import sqrt, S
515
-
516
- The function is primarily concerned with sanitizing input for
517
- functions that need to work with builtin integers, so anything that
518
- is unambiguously an integer should be returned as an int:
519
-
520
- >>> as_int(S(3))
521
- 3
522
-
523
- Floats, being of limited precision, are not assumed to be exact and
524
- will raise an error unless the ``strict`` flag is False. This
525
- precision issue becomes apparent for large floating point numbers:
526
-
527
- >>> big = 1e23
528
- >>> type(big) is float
529
- True
530
- >>> big == int(big)
531
- True
532
- >>> as_int(big)
533
- Traceback (most recent call last):
534
- ...
535
- ValueError: ... is not an integer
536
- >>> as_int(big, strict=False)
537
- 99999999999999991611392
538
-
539
- Input that might be a complex representation of an integer value is
540
- also rejected by default:
541
-
542
- >>> one = sqrt(3 + 2*sqrt(2)) - sqrt(2)
543
- >>> int(one) == 1
544
- True
545
- >>> as_int(one)
546
- Traceback (most recent call last):
547
- ...
548
- ValueError: ... is not an integer
549
- """
550
- if strict:
551
- try:
552
- if isinstance(n, bool):
553
- raise TypeError
554
- return operator.index(n)
555
- except TypeError:
556
- raise ValueError('%s is not an integer' % (n,))
557
- else:
558
- try:
559
- result = int(n)
560
- except TypeError:
561
- raise ValueError('%s is not an integer' % (n,))
562
- if n - result:
563
- raise ValueError('%s is not an integer' % (n,))
564
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/pkgdata.py DELETED
@@ -1,33 +0,0 @@
1
- # This module is deprecated and will be removed.
2
-
3
- import sys
4
- import os
5
- from io import StringIO
6
-
7
- from sympy.utilities.decorator import deprecated
8
-
9
-
10
- @deprecated(
11
- """
12
- The sympy.utilities.pkgdata module and its get_resource function are
13
- deprecated. Use the stdlib importlib.resources module instead.
14
- """,
15
- deprecated_since_version="1.12",
16
- active_deprecations_target="pkgdata",
17
- )
18
- def get_resource(identifier, pkgname=__name__):
19
-
20
- mod = sys.modules[pkgname]
21
- fn = getattr(mod, '__file__', None)
22
- if fn is None:
23
- raise OSError("%r has no __file__!")
24
- path = os.path.join(os.path.dirname(fn), identifier)
25
- loader = getattr(mod, '__loader__', None)
26
- if loader is not None:
27
- try:
28
- data = loader.get_data(path)
29
- except (OSError, AttributeError):
30
- pass
31
- else:
32
- return StringIO(data.decode('utf-8'))
33
- return open(os.path.normpath(path), 'rb')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/pytest.py DELETED
@@ -1,12 +0,0 @@
1
- """
2
- .. deprecated:: 1.6
3
-
4
- sympy.utilities.pytest has been renamed to sympy.testing.pytest.
5
- """
6
- from sympy.utilities.exceptions import sympy_deprecation_warning
7
-
8
- sympy_deprecation_warning("The sympy.utilities.pytest submodule is deprecated. Use sympy.testing.pytest instead.",
9
- deprecated_since_version="1.6",
10
- active_deprecations_target="deprecated-sympy-utilities-submodules")
11
-
12
- from sympy.testing.pytest import * # noqa:F401,F403
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/randtest.py DELETED
@@ -1,12 +0,0 @@
1
- """
2
- .. deprecated:: 1.6
3
-
4
- sympy.utilities.randtest has been renamed to sympy.core.random.
5
- """
6
- from sympy.utilities.exceptions import sympy_deprecation_warning
7
-
8
- sympy_deprecation_warning("The sympy.utilities.randtest submodule is deprecated. Use sympy.core.random instead.",
9
- deprecated_since_version="1.6",
10
- active_deprecations_target="deprecated-sympy-utilities-submodules")
11
-
12
- from sympy.core.random import * # noqa:F401,F403
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/runtests.py DELETED
@@ -1,13 +0,0 @@
1
- """
2
- .. deprecated:: 1.6
3
-
4
- sympy.utilities.runtests has been renamed to sympy.testing.runtests.
5
- """
6
-
7
- from sympy.utilities.exceptions import sympy_deprecation_warning
8
-
9
- sympy_deprecation_warning("The sympy.utilities.runtests submodule is deprecated. Use sympy.testing.runtests instead.",
10
- deprecated_since_version="1.6",
11
- active_deprecations_target="deprecated-sympy-utilities-submodules")
12
-
13
- from sympy.testing.runtests import * # noqa: F401,F403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/source.py DELETED
@@ -1,40 +0,0 @@
1
- """
2
- This module adds several functions for interactive source code inspection.
3
- """
4
-
5
-
6
- def get_class(lookup_view):
7
- """
8
- Convert a string version of a class name to the object.
9
-
10
- For example, get_class('sympy.core.Basic') will return
11
- class Basic located in module sympy.core
12
- """
13
- if isinstance(lookup_view, str):
14
- mod_name, func_name = get_mod_func(lookup_view)
15
- if func_name != '':
16
- lookup_view = getattr(
17
- __import__(mod_name, {}, {}, ['*']), func_name)
18
- if not callable(lookup_view):
19
- raise AttributeError(
20
- "'%s.%s' is not a callable." % (mod_name, func_name))
21
- return lookup_view
22
-
23
-
24
- def get_mod_func(callback):
25
- """
26
- splits the string path to a class into a string path to the module
27
- and the name of the class.
28
-
29
- Examples
30
- ========
31
-
32
- >>> from sympy.utilities.source import get_mod_func
33
- >>> get_mod_func('sympy.core.basic.Basic')
34
- ('sympy.core.basic', 'Basic')
35
-
36
- """
37
- dot = callback.rfind('.')
38
- if dot == -1:
39
- return callback, ''
40
- return callback[:dot], callback[dot + 1:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/__init__.py DELETED
File without changes
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_autowrap.py DELETED
@@ -1,467 +0,0 @@
1
- # Tests that require installed backends go into
2
- # sympy/test_external/test_autowrap
3
-
4
- import os
5
- import tempfile
6
- import shutil
7
- from io import StringIO
8
- from pathlib import Path
9
-
10
- from sympy.core import symbols, Eq
11
- from sympy.utilities.autowrap import (autowrap, binary_function,
12
- CythonCodeWrapper, UfuncifyCodeWrapper, CodeWrapper)
13
- from sympy.utilities.codegen import (
14
- CCodeGen, C99CodeGen, CodeGenArgumentListError, make_routine
15
- )
16
- from sympy.testing.pytest import raises
17
- from sympy.testing.tmpfiles import TmpFileManager
18
-
19
-
20
- def get_string(dump_fn, routines, prefix="file", **kwargs):
21
- """Wrapper for dump_fn. dump_fn writes its results to a stream object and
22
- this wrapper returns the contents of that stream as a string. This
23
- auxiliary function is used by many tests below.
24
-
25
- The header and the empty lines are not generator to facilitate the
26
- testing of the output.
27
- """
28
- output = StringIO()
29
- dump_fn(routines, output, prefix, **kwargs)
30
- source = output.getvalue()
31
- output.close()
32
- return source
33
-
34
-
35
- def test_cython_wrapper_scalar_function():
36
- x, y, z = symbols('x,y,z')
37
- expr = (x + y)*z
38
- routine = make_routine("test", expr)
39
- code_gen = CythonCodeWrapper(CCodeGen())
40
- source = get_string(code_gen.dump_pyx, [routine])
41
-
42
- expected = (
43
- "cdef extern from 'file.h':\n"
44
- " double test(double x, double y, double z)\n"
45
- "\n"
46
- "def test_c(double x, double y, double z):\n"
47
- "\n"
48
- " return test(x, y, z)")
49
- assert source == expected
50
-
51
-
52
- def test_cython_wrapper_outarg():
53
- from sympy.core.relational import Equality
54
- x, y, z = symbols('x,y,z')
55
- code_gen = CythonCodeWrapper(C99CodeGen())
56
-
57
- routine = make_routine("test", Equality(z, x + y))
58
- source = get_string(code_gen.dump_pyx, [routine])
59
- expected = (
60
- "cdef extern from 'file.h':\n"
61
- " void test(double x, double y, double *z)\n"
62
- "\n"
63
- "def test_c(double x, double y):\n"
64
- "\n"
65
- " cdef double z = 0\n"
66
- " test(x, y, &z)\n"
67
- " return z")
68
- assert source == expected
69
-
70
-
71
- def test_cython_wrapper_inoutarg():
72
- from sympy.core.relational import Equality
73
- x, y, z = symbols('x,y,z')
74
- code_gen = CythonCodeWrapper(C99CodeGen())
75
- routine = make_routine("test", Equality(z, x + y + z))
76
- source = get_string(code_gen.dump_pyx, [routine])
77
- expected = (
78
- "cdef extern from 'file.h':\n"
79
- " void test(double x, double y, double *z)\n"
80
- "\n"
81
- "def test_c(double x, double y, double z):\n"
82
- "\n"
83
- " test(x, y, &z)\n"
84
- " return z")
85
- assert source == expected
86
-
87
-
88
- def test_cython_wrapper_compile_flags():
89
- from sympy.core.relational import Equality
90
- x, y, z = symbols('x,y,z')
91
- routine = make_routine("test", Equality(z, x + y))
92
-
93
- code_gen = CythonCodeWrapper(CCodeGen())
94
-
95
- expected = """\
96
- from setuptools import setup
97
- from setuptools import Extension
98
- from Cython.Build import cythonize
99
- cy_opts = {'compiler_directives': {'language_level': '3'}}
100
-
101
- ext_mods = [Extension(
102
- 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
103
- include_dirs=[],
104
- library_dirs=[],
105
- libraries=[],
106
- extra_compile_args=['-std=c99'],
107
- extra_link_args=[]
108
- )]
109
- setup(ext_modules=cythonize(ext_mods, **cy_opts))
110
- """ % {'num': CodeWrapper._module_counter}
111
-
112
- temp_dir = tempfile.mkdtemp()
113
- TmpFileManager.tmp_folder(temp_dir)
114
- setup_file_path = os.path.join(temp_dir, 'setup.py')
115
-
116
- code_gen._prepare_files(routine, build_dir=temp_dir)
117
- setup_text = Path(setup_file_path).read_text()
118
- assert setup_text == expected
119
-
120
- code_gen = CythonCodeWrapper(CCodeGen(),
121
- include_dirs=['/usr/local/include', '/opt/booger/include'],
122
- library_dirs=['/user/local/lib'],
123
- libraries=['thelib', 'nilib'],
124
- extra_compile_args=['-slow-math'],
125
- extra_link_args=['-lswamp', '-ltrident'],
126
- cythonize_options={'compiler_directives': {'boundscheck': False}}
127
- )
128
- expected = """\
129
- from setuptools import setup
130
- from setuptools import Extension
131
- from Cython.Build import cythonize
132
- cy_opts = {'compiler_directives': {'boundscheck': False}}
133
-
134
- ext_mods = [Extension(
135
- 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
136
- include_dirs=['/usr/local/include', '/opt/booger/include'],
137
- library_dirs=['/user/local/lib'],
138
- libraries=['thelib', 'nilib'],
139
- extra_compile_args=['-slow-math', '-std=c99'],
140
- extra_link_args=['-lswamp', '-ltrident']
141
- )]
142
- setup(ext_modules=cythonize(ext_mods, **cy_opts))
143
- """ % {'num': CodeWrapper._module_counter}
144
-
145
- code_gen._prepare_files(routine, build_dir=temp_dir)
146
- setup_text = Path(setup_file_path).read_text()
147
- assert setup_text == expected
148
-
149
- expected = """\
150
- from setuptools import setup
151
- from setuptools import Extension
152
- from Cython.Build import cythonize
153
- cy_opts = {'compiler_directives': {'boundscheck': False}}
154
- import numpy as np
155
-
156
- ext_mods = [Extension(
157
- 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
158
- include_dirs=['/usr/local/include', '/opt/booger/include', np.get_include()],
159
- library_dirs=['/user/local/lib'],
160
- libraries=['thelib', 'nilib'],
161
- extra_compile_args=['-slow-math', '-std=c99'],
162
- extra_link_args=['-lswamp', '-ltrident']
163
- )]
164
- setup(ext_modules=cythonize(ext_mods, **cy_opts))
165
- """ % {'num': CodeWrapper._module_counter}
166
-
167
- code_gen._need_numpy = True
168
- code_gen._prepare_files(routine, build_dir=temp_dir)
169
- setup_text = Path(setup_file_path).read_text()
170
- assert setup_text == expected
171
-
172
- TmpFileManager.cleanup()
173
-
174
- def test_cython_wrapper_unique_dummyvars():
175
- from sympy.core.relational import Equality
176
- from sympy.core.symbol import Dummy
177
- x, y, z = Dummy('x'), Dummy('y'), Dummy('z')
178
- x_id, y_id, z_id = [str(d.dummy_index) for d in [x, y, z]]
179
- expr = Equality(z, x + y)
180
- routine = make_routine("test", expr)
181
- code_gen = CythonCodeWrapper(CCodeGen())
182
- source = get_string(code_gen.dump_pyx, [routine])
183
- expected_template = (
184
- "cdef extern from 'file.h':\n"
185
- " void test(double x_{x_id}, double y_{y_id}, double *z_{z_id})\n"
186
- "\n"
187
- "def test_c(double x_{x_id}, double y_{y_id}):\n"
188
- "\n"
189
- " cdef double z_{z_id} = 0\n"
190
- " test(x_{x_id}, y_{y_id}, &z_{z_id})\n"
191
- " return z_{z_id}")
192
- expected = expected_template.format(x_id=x_id, y_id=y_id, z_id=z_id)
193
- assert source == expected
194
-
195
- def test_autowrap_dummy():
196
- x, y, z = symbols('x y z')
197
-
198
- # Uses DummyWrapper to test that codegen works as expected
199
-
200
- f = autowrap(x + y, backend='dummy')
201
- assert f() == str(x + y)
202
- assert f.args == "x, y"
203
- assert f.returns == "nameless"
204
- f = autowrap(Eq(z, x + y), backend='dummy')
205
- assert f() == str(x + y)
206
- assert f.args == "x, y"
207
- assert f.returns == "z"
208
- f = autowrap(Eq(z, x + y + z), backend='dummy')
209
- assert f() == str(x + y + z)
210
- assert f.args == "x, y, z"
211
- assert f.returns == "z"
212
-
213
-
214
- def test_autowrap_args():
215
- x, y, z = symbols('x y z')
216
-
217
- raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y),
218
- backend='dummy', args=[x]))
219
- f = autowrap(Eq(z, x + y), backend='dummy', args=[y, x])
220
- assert f() == str(x + y)
221
- assert f.args == "y, x"
222
- assert f.returns == "z"
223
-
224
- raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y + z),
225
- backend='dummy', args=[x, y]))
226
- f = autowrap(Eq(z, x + y + z), backend='dummy', args=[y, x, z])
227
- assert f() == str(x + y + z)
228
- assert f.args == "y, x, z"
229
- assert f.returns == "z"
230
-
231
- f = autowrap(Eq(z, x + y + z), backend='dummy', args=(y, x, z))
232
- assert f() == str(x + y + z)
233
- assert f.args == "y, x, z"
234
- assert f.returns == "z"
235
-
236
- def test_autowrap_store_files():
237
- x, y = symbols('x y')
238
- tmp = tempfile.mkdtemp()
239
- TmpFileManager.tmp_folder(tmp)
240
-
241
- f = autowrap(x + y, backend='dummy', tempdir=tmp)
242
- assert f() == str(x + y)
243
- assert os.access(tmp, os.F_OK)
244
-
245
- TmpFileManager.cleanup()
246
-
247
- def test_autowrap_store_files_issue_gh12939():
248
- x, y = symbols('x y')
249
- tmp = './tmp'
250
- saved_cwd = os.getcwd()
251
- temp_cwd = tempfile.mkdtemp()
252
- try:
253
- os.chdir(temp_cwd)
254
- f = autowrap(x + y, backend='dummy', tempdir=tmp)
255
- assert f() == str(x + y)
256
- assert os.access(tmp, os.F_OK)
257
- finally:
258
- os.chdir(saved_cwd)
259
- shutil.rmtree(temp_cwd)
260
-
261
-
262
- def test_binary_function():
263
- x, y = symbols('x y')
264
- f = binary_function('f', x + y, backend='dummy')
265
- assert f._imp_() == str(x + y)
266
-
267
-
268
- def test_ufuncify_source():
269
- x, y, z = symbols('x,y,z')
270
- code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
271
- routine = make_routine("test", x + y + z)
272
- source = get_string(code_wrapper.dump_c, [routine])
273
- expected = """\
274
- #include "Python.h"
275
- #include "math.h"
276
- #include "numpy/ndarraytypes.h"
277
- #include "numpy/ufuncobject.h"
278
- #include "numpy/halffloat.h"
279
- #include "file.h"
280
-
281
- static PyMethodDef wrapper_module_%(num)sMethods[] = {
282
- {NULL, NULL, 0, NULL}
283
- };
284
-
285
- #ifdef NPY_1_19_API_VERSION
286
- static void test_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
287
- #else
288
- static void test_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
289
- #endif
290
- {
291
- npy_intp i;
292
- npy_intp n = dimensions[0];
293
- char *in0 = args[0];
294
- char *in1 = args[1];
295
- char *in2 = args[2];
296
- char *out0 = args[3];
297
- npy_intp in0_step = steps[0];
298
- npy_intp in1_step = steps[1];
299
- npy_intp in2_step = steps[2];
300
- npy_intp out0_step = steps[3];
301
- for (i = 0; i < n; i++) {
302
- *((double *)out0) = test(*(double *)in0, *(double *)in1, *(double *)in2);
303
- in0 += in0_step;
304
- in1 += in1_step;
305
- in2 += in2_step;
306
- out0 += out0_step;
307
- }
308
- }
309
- PyUFuncGenericFunction test_funcs[1] = {&test_ufunc};
310
- static char test_types[4] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
311
- static void *test_data[1] = {NULL};
312
-
313
- #if PY_VERSION_HEX >= 0x03000000
314
- static struct PyModuleDef moduledef = {
315
- PyModuleDef_HEAD_INIT,
316
- "wrapper_module_%(num)s",
317
- NULL,
318
- -1,
319
- wrapper_module_%(num)sMethods,
320
- NULL,
321
- NULL,
322
- NULL,
323
- NULL
324
- };
325
-
326
- PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
327
- {
328
- PyObject *m, *d;
329
- PyObject *ufunc0;
330
- m = PyModule_Create(&moduledef);
331
- if (!m) {
332
- return NULL;
333
- }
334
- import_array();
335
- import_umath();
336
- d = PyModule_GetDict(m);
337
- ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
338
- PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
339
- PyDict_SetItemString(d, "test", ufunc0);
340
- Py_DECREF(ufunc0);
341
- return m;
342
- }
343
- #else
344
- PyMODINIT_FUNC initwrapper_module_%(num)s(void)
345
- {
346
- PyObject *m, *d;
347
- PyObject *ufunc0;
348
- m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
349
- if (m == NULL) {
350
- return;
351
- }
352
- import_array();
353
- import_umath();
354
- d = PyModule_GetDict(m);
355
- ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
356
- PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
357
- PyDict_SetItemString(d, "test", ufunc0);
358
- Py_DECREF(ufunc0);
359
- }
360
- #endif""" % {'num': CodeWrapper._module_counter}
361
- assert source == expected
362
-
363
-
364
- def test_ufuncify_source_multioutput():
365
- x, y, z = symbols('x,y,z')
366
- var_symbols = (x, y, z)
367
- expr = x + y**3 + 10*z**2
368
- code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
369
- routines = [make_routine("func{}".format(i), expr.diff(var_symbols[i]), var_symbols) for i in range(len(var_symbols))]
370
- source = get_string(code_wrapper.dump_c, routines, funcname='multitest')
371
- expected = """\
372
- #include "Python.h"
373
- #include "math.h"
374
- #include "numpy/ndarraytypes.h"
375
- #include "numpy/ufuncobject.h"
376
- #include "numpy/halffloat.h"
377
- #include "file.h"
378
-
379
- static PyMethodDef wrapper_module_%(num)sMethods[] = {
380
- {NULL, NULL, 0, NULL}
381
- };
382
-
383
- #ifdef NPY_1_19_API_VERSION
384
- static void multitest_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
385
- #else
386
- static void multitest_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
387
- #endif
388
- {
389
- npy_intp i;
390
- npy_intp n = dimensions[0];
391
- char *in0 = args[0];
392
- char *in1 = args[1];
393
- char *in2 = args[2];
394
- char *out0 = args[3];
395
- char *out1 = args[4];
396
- char *out2 = args[5];
397
- npy_intp in0_step = steps[0];
398
- npy_intp in1_step = steps[1];
399
- npy_intp in2_step = steps[2];
400
- npy_intp out0_step = steps[3];
401
- npy_intp out1_step = steps[4];
402
- npy_intp out2_step = steps[5];
403
- for (i = 0; i < n; i++) {
404
- *((double *)out0) = func0(*(double *)in0, *(double *)in1, *(double *)in2);
405
- *((double *)out1) = func1(*(double *)in0, *(double *)in1, *(double *)in2);
406
- *((double *)out2) = func2(*(double *)in0, *(double *)in1, *(double *)in2);
407
- in0 += in0_step;
408
- in1 += in1_step;
409
- in2 += in2_step;
410
- out0 += out0_step;
411
- out1 += out1_step;
412
- out2 += out2_step;
413
- }
414
- }
415
- PyUFuncGenericFunction multitest_funcs[1] = {&multitest_ufunc};
416
- static char multitest_types[6] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
417
- static void *multitest_data[1] = {NULL};
418
-
419
- #if PY_VERSION_HEX >= 0x03000000
420
- static struct PyModuleDef moduledef = {
421
- PyModuleDef_HEAD_INIT,
422
- "wrapper_module_%(num)s",
423
- NULL,
424
- -1,
425
- wrapper_module_%(num)sMethods,
426
- NULL,
427
- NULL,
428
- NULL,
429
- NULL
430
- };
431
-
432
- PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
433
- {
434
- PyObject *m, *d;
435
- PyObject *ufunc0;
436
- m = PyModule_Create(&moduledef);
437
- if (!m) {
438
- return NULL;
439
- }
440
- import_array();
441
- import_umath();
442
- d = PyModule_GetDict(m);
443
- ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
444
- PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
445
- PyDict_SetItemString(d, "multitest", ufunc0);
446
- Py_DECREF(ufunc0);
447
- return m;
448
- }
449
- #else
450
- PyMODINIT_FUNC initwrapper_module_%(num)s(void)
451
- {
452
- PyObject *m, *d;
453
- PyObject *ufunc0;
454
- m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
455
- if (m == NULL) {
456
- return;
457
- }
458
- import_array();
459
- import_umath();
460
- d = PyModule_GetDict(m);
461
- ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
462
- PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
463
- PyDict_SetItemString(d, "multitest", ufunc0);
464
- Py_DECREF(ufunc0);
465
- }
466
- #endif""" % {'num': CodeWrapper._module_counter}
467
- assert source == expected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen.py DELETED
@@ -1,1632 +0,0 @@
1
- from io import StringIO
2
-
3
- from sympy.core import symbols, Eq, pi, Catalan, Lambda, Dummy
4
- from sympy.core.relational import Equality
5
- from sympy.core.symbol import Symbol
6
- from sympy.functions.special.error_functions import erf
7
- from sympy.integrals.integrals import Integral
8
- from sympy.matrices import Matrix, MatrixSymbol
9
- from sympy.utilities.codegen import (
10
- codegen, make_routine, CCodeGen, C89CodeGen, C99CodeGen, InputArgument,
11
- CodeGenError, FCodeGen, CodeGenArgumentListError, OutputArgument,
12
- InOutArgument)
13
- from sympy.testing.pytest import raises
14
- from sympy.utilities.lambdify import implemented_function
15
-
16
- #FIXME: Fails due to circular import in with core
17
- # from sympy import codegen
18
-
19
-
20
- def get_string(dump_fn, routines, prefix="file", header=False, empty=False):
21
- """Wrapper for dump_fn. dump_fn writes its results to a stream object and
22
- this wrapper returns the contents of that stream as a string. This
23
- auxiliary function is used by many tests below.
24
-
25
- The header and the empty lines are not generated to facilitate the
26
- testing of the output.
27
- """
28
- output = StringIO()
29
- dump_fn(routines, output, prefix, header, empty)
30
- source = output.getvalue()
31
- output.close()
32
- return source
33
-
34
-
35
- def test_Routine_argument_order():
36
- a, x, y, z = symbols('a x y z')
37
- expr = (x + y)*z
38
- raises(CodeGenArgumentListError, lambda: make_routine("test", expr,
39
- argument_sequence=[z, x]))
40
- raises(CodeGenArgumentListError, lambda: make_routine("test", Eq(a,
41
- expr), argument_sequence=[z, x, y]))
42
- r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
43
- assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
44
- assert [ type(arg) for arg in r.arguments ] == [
45
- InputArgument, InputArgument, OutputArgument, InputArgument ]
46
- r = make_routine('test', Eq(z, expr), argument_sequence=[z, x, y])
47
- assert [ type(arg) for arg in r.arguments ] == [
48
- InOutArgument, InputArgument, InputArgument ]
49
-
50
- from sympy.tensor import IndexedBase, Idx
51
- A, B = map(IndexedBase, ['A', 'B'])
52
- m = symbols('m', integer=True)
53
- i = Idx('i', m)
54
- r = make_routine('test', Eq(A[i], B[i]), argument_sequence=[B, A, m])
55
- assert [ arg.name for arg in r.arguments ] == [B.label, A.label, m]
56
-
57
- expr = Integral(x*y*z, (x, 1, 2), (y, 1, 3))
58
- r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
59
- assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
60
-
61
-
62
- def test_empty_c_code():
63
- code_gen = C89CodeGen()
64
- source = get_string(code_gen.dump_c, [])
65
- assert source == "#include \"file.h\"\n#include <math.h>\n"
66
-
67
-
68
- def test_empty_c_code_with_comment():
69
- code_gen = C89CodeGen()
70
- source = get_string(code_gen.dump_c, [], header=True)
71
- assert source[:82] == (
72
- "/******************************************************************************\n *"
73
- )
74
- # " Code generated with SymPy 0.7.2-git "
75
- assert source[158:] == ( "*\n"
76
- " * *\n"
77
- " * See http://www.sympy.org/ for more information. *\n"
78
- " * *\n"
79
- " * This file is part of 'project' *\n"
80
- " ******************************************************************************/\n"
81
- "#include \"file.h\"\n"
82
- "#include <math.h>\n"
83
- )
84
-
85
-
86
- def test_empty_c_header():
87
- code_gen = C99CodeGen()
88
- source = get_string(code_gen.dump_h, [])
89
- assert source == "#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n#endif\n"
90
-
91
-
92
- def test_simple_c_code():
93
- x, y, z = symbols('x,y,z')
94
- expr = (x + y)*z
95
- routine = make_routine("test", expr)
96
- code_gen = C89CodeGen()
97
- source = get_string(code_gen.dump_c, [routine])
98
- expected = (
99
- "#include \"file.h\"\n"
100
- "#include <math.h>\n"
101
- "double test(double x, double y, double z) {\n"
102
- " double test_result;\n"
103
- " test_result = z*(x + y);\n"
104
- " return test_result;\n"
105
- "}\n"
106
- )
107
- assert source == expected
108
-
109
-
110
- def test_c_code_reserved_words():
111
- x, y, z = symbols('if, typedef, while')
112
- expr = (x + y) * z
113
- routine = make_routine("test", expr)
114
- code_gen = C99CodeGen()
115
- source = get_string(code_gen.dump_c, [routine])
116
- expected = (
117
- "#include \"file.h\"\n"
118
- "#include <math.h>\n"
119
- "double test(double if_, double typedef_, double while_) {\n"
120
- " double test_result;\n"
121
- " test_result = while_*(if_ + typedef_);\n"
122
- " return test_result;\n"
123
- "}\n"
124
- )
125
- assert source == expected
126
-
127
-
128
- def test_numbersymbol_c_code():
129
- routine = make_routine("test", pi**Catalan)
130
- code_gen = C89CodeGen()
131
- source = get_string(code_gen.dump_c, [routine])
132
- expected = (
133
- "#include \"file.h\"\n"
134
- "#include <math.h>\n"
135
- "double test() {\n"
136
- " double test_result;\n"
137
- " double const Catalan = %s;\n"
138
- " test_result = pow(M_PI, Catalan);\n"
139
- " return test_result;\n"
140
- "}\n"
141
- ) % Catalan.evalf(17)
142
- assert source == expected
143
-
144
-
145
- def test_c_code_argument_order():
146
- x, y, z = symbols('x,y,z')
147
- expr = x + y
148
- routine = make_routine("test", expr, argument_sequence=[z, x, y])
149
- code_gen = C89CodeGen()
150
- source = get_string(code_gen.dump_c, [routine])
151
- expected = (
152
- "#include \"file.h\"\n"
153
- "#include <math.h>\n"
154
- "double test(double z, double x, double y) {\n"
155
- " double test_result;\n"
156
- " test_result = x + y;\n"
157
- " return test_result;\n"
158
- "}\n"
159
- )
160
- assert source == expected
161
-
162
-
163
- def test_simple_c_header():
164
- x, y, z = symbols('x,y,z')
165
- expr = (x + y)*z
166
- routine = make_routine("test", expr)
167
- code_gen = C89CodeGen()
168
- source = get_string(code_gen.dump_h, [routine])
169
- expected = (
170
- "#ifndef PROJECT__FILE__H\n"
171
- "#define PROJECT__FILE__H\n"
172
- "double test(double x, double y, double z);\n"
173
- "#endif\n"
174
- )
175
- assert source == expected
176
-
177
-
178
- def test_simple_c_codegen():
179
- x, y, z = symbols('x,y,z')
180
- expr = (x + y)*z
181
- expected = [
182
- ("file.c",
183
- "#include \"file.h\"\n"
184
- "#include <math.h>\n"
185
- "double test(double x, double y, double z) {\n"
186
- " double test_result;\n"
187
- " test_result = z*(x + y);\n"
188
- " return test_result;\n"
189
- "}\n"),
190
- ("file.h",
191
- "#ifndef PROJECT__FILE__H\n"
192
- "#define PROJECT__FILE__H\n"
193
- "double test(double x, double y, double z);\n"
194
- "#endif\n")
195
- ]
196
- result = codegen(("test", expr), "C", "file", header=False, empty=False)
197
- assert result == expected
198
-
199
-
200
- def test_multiple_results_c():
201
- x, y, z = symbols('x,y,z')
202
- expr1 = (x + y)*z
203
- expr2 = (x - y)*z
204
- routine = make_routine(
205
- "test",
206
- [expr1, expr2]
207
- )
208
- code_gen = C99CodeGen()
209
- raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
210
-
211
-
212
- def test_no_results_c():
213
- raises(ValueError, lambda: make_routine("test", []))
214
-
215
-
216
- def test_ansi_math1_codegen():
217
- # not included: log10
218
- from sympy.functions.elementary.complexes import Abs
219
- from sympy.functions.elementary.exponential import log
220
- from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
221
- from sympy.functions.elementary.integers import (ceiling, floor)
222
- from sympy.functions.elementary.miscellaneous import sqrt
223
- from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
224
- x = symbols('x')
225
- name_expr = [
226
- ("test_fabs", Abs(x)),
227
- ("test_acos", acos(x)),
228
- ("test_asin", asin(x)),
229
- ("test_atan", atan(x)),
230
- ("test_ceil", ceiling(x)),
231
- ("test_cos", cos(x)),
232
- ("test_cosh", cosh(x)),
233
- ("test_floor", floor(x)),
234
- ("test_log", log(x)),
235
- ("test_ln", log(x)),
236
- ("test_sin", sin(x)),
237
- ("test_sinh", sinh(x)),
238
- ("test_sqrt", sqrt(x)),
239
- ("test_tan", tan(x)),
240
- ("test_tanh", tanh(x)),
241
- ]
242
- result = codegen(name_expr, "C89", "file", header=False, empty=False)
243
- assert result[0][0] == "file.c"
244
- assert result[0][1] == (
245
- '#include "file.h"\n#include <math.h>\n'
246
- 'double test_fabs(double x) {\n double test_fabs_result;\n test_fabs_result = fabs(x);\n return test_fabs_result;\n}\n'
247
- 'double test_acos(double x) {\n double test_acos_result;\n test_acos_result = acos(x);\n return test_acos_result;\n}\n'
248
- 'double test_asin(double x) {\n double test_asin_result;\n test_asin_result = asin(x);\n return test_asin_result;\n}\n'
249
- 'double test_atan(double x) {\n double test_atan_result;\n test_atan_result = atan(x);\n return test_atan_result;\n}\n'
250
- 'double test_ceil(double x) {\n double test_ceil_result;\n test_ceil_result = ceil(x);\n return test_ceil_result;\n}\n'
251
- 'double test_cos(double x) {\n double test_cos_result;\n test_cos_result = cos(x);\n return test_cos_result;\n}\n'
252
- 'double test_cosh(double x) {\n double test_cosh_result;\n test_cosh_result = cosh(x);\n return test_cosh_result;\n}\n'
253
- 'double test_floor(double x) {\n double test_floor_result;\n test_floor_result = floor(x);\n return test_floor_result;\n}\n'
254
- 'double test_log(double x) {\n double test_log_result;\n test_log_result = log(x);\n return test_log_result;\n}\n'
255
- 'double test_ln(double x) {\n double test_ln_result;\n test_ln_result = log(x);\n return test_ln_result;\n}\n'
256
- 'double test_sin(double x) {\n double test_sin_result;\n test_sin_result = sin(x);\n return test_sin_result;\n}\n'
257
- 'double test_sinh(double x) {\n double test_sinh_result;\n test_sinh_result = sinh(x);\n return test_sinh_result;\n}\n'
258
- 'double test_sqrt(double x) {\n double test_sqrt_result;\n test_sqrt_result = sqrt(x);\n return test_sqrt_result;\n}\n'
259
- 'double test_tan(double x) {\n double test_tan_result;\n test_tan_result = tan(x);\n return test_tan_result;\n}\n'
260
- 'double test_tanh(double x) {\n double test_tanh_result;\n test_tanh_result = tanh(x);\n return test_tanh_result;\n}\n'
261
- )
262
- assert result[1][0] == "file.h"
263
- assert result[1][1] == (
264
- '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
265
- 'double test_fabs(double x);\ndouble test_acos(double x);\n'
266
- 'double test_asin(double x);\ndouble test_atan(double x);\n'
267
- 'double test_ceil(double x);\ndouble test_cos(double x);\n'
268
- 'double test_cosh(double x);\ndouble test_floor(double x);\n'
269
- 'double test_log(double x);\ndouble test_ln(double x);\n'
270
- 'double test_sin(double x);\ndouble test_sinh(double x);\n'
271
- 'double test_sqrt(double x);\ndouble test_tan(double x);\n'
272
- 'double test_tanh(double x);\n#endif\n'
273
- )
274
-
275
-
276
- def test_ansi_math2_codegen():
277
- # not included: frexp, ldexp, modf, fmod
278
- from sympy.functions.elementary.trigonometric import atan2
279
- x, y = symbols('x,y')
280
- name_expr = [
281
- ("test_atan2", atan2(x, y)),
282
- ("test_pow", x**y),
283
- ]
284
- result = codegen(name_expr, "C89", "file", header=False, empty=False)
285
- assert result[0][0] == "file.c"
286
- assert result[0][1] == (
287
- '#include "file.h"\n#include <math.h>\n'
288
- 'double test_atan2(double x, double y) {\n double test_atan2_result;\n test_atan2_result = atan2(x, y);\n return test_atan2_result;\n}\n'
289
- 'double test_pow(double x, double y) {\n double test_pow_result;\n test_pow_result = pow(x, y);\n return test_pow_result;\n}\n'
290
- )
291
- assert result[1][0] == "file.h"
292
- assert result[1][1] == (
293
- '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
294
- 'double test_atan2(double x, double y);\n'
295
- 'double test_pow(double x, double y);\n'
296
- '#endif\n'
297
- )
298
-
299
-
300
- def test_complicated_codegen():
301
- from sympy.functions.elementary.trigonometric import (cos, sin, tan)
302
- x, y, z = symbols('x,y,z')
303
- name_expr = [
304
- ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
305
- ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
306
- ]
307
- result = codegen(name_expr, "C89", "file", header=False, empty=False)
308
- assert result[0][0] == "file.c"
309
- assert result[0][1] == (
310
- '#include "file.h"\n#include <math.h>\n'
311
- 'double test1(double x, double y, double z) {\n'
312
- ' double test1_result;\n'
313
- ' test1_result = '
314
- 'pow(sin(x), 7) + '
315
- '7*pow(sin(x), 6)*cos(y) + '
316
- '7*pow(sin(x), 6)*tan(z) + '
317
- '21*pow(sin(x), 5)*pow(cos(y), 2) + '
318
- '42*pow(sin(x), 5)*cos(y)*tan(z) + '
319
- '21*pow(sin(x), 5)*pow(tan(z), 2) + '
320
- '35*pow(sin(x), 4)*pow(cos(y), 3) + '
321
- '105*pow(sin(x), 4)*pow(cos(y), 2)*tan(z) + '
322
- '105*pow(sin(x), 4)*cos(y)*pow(tan(z), 2) + '
323
- '35*pow(sin(x), 4)*pow(tan(z), 3) + '
324
- '35*pow(sin(x), 3)*pow(cos(y), 4) + '
325
- '140*pow(sin(x), 3)*pow(cos(y), 3)*tan(z) + '
326
- '210*pow(sin(x), 3)*pow(cos(y), 2)*pow(tan(z), 2) + '
327
- '140*pow(sin(x), 3)*cos(y)*pow(tan(z), 3) + '
328
- '35*pow(sin(x), 3)*pow(tan(z), 4) + '
329
- '21*pow(sin(x), 2)*pow(cos(y), 5) + '
330
- '105*pow(sin(x), 2)*pow(cos(y), 4)*tan(z) + '
331
- '210*pow(sin(x), 2)*pow(cos(y), 3)*pow(tan(z), 2) + '
332
- '210*pow(sin(x), 2)*pow(cos(y), 2)*pow(tan(z), 3) + '
333
- '105*pow(sin(x), 2)*cos(y)*pow(tan(z), 4) + '
334
- '21*pow(sin(x), 2)*pow(tan(z), 5) + '
335
- '7*sin(x)*pow(cos(y), 6) + '
336
- '42*sin(x)*pow(cos(y), 5)*tan(z) + '
337
- '105*sin(x)*pow(cos(y), 4)*pow(tan(z), 2) + '
338
- '140*sin(x)*pow(cos(y), 3)*pow(tan(z), 3) + '
339
- '105*sin(x)*pow(cos(y), 2)*pow(tan(z), 4) + '
340
- '42*sin(x)*cos(y)*pow(tan(z), 5) + '
341
- '7*sin(x)*pow(tan(z), 6) + '
342
- 'pow(cos(y), 7) + '
343
- '7*pow(cos(y), 6)*tan(z) + '
344
- '21*pow(cos(y), 5)*pow(tan(z), 2) + '
345
- '35*pow(cos(y), 4)*pow(tan(z), 3) + '
346
- '35*pow(cos(y), 3)*pow(tan(z), 4) + '
347
- '21*pow(cos(y), 2)*pow(tan(z), 5) + '
348
- '7*cos(y)*pow(tan(z), 6) + '
349
- 'pow(tan(z), 7);\n'
350
- ' return test1_result;\n'
351
- '}\n'
352
- 'double test2(double x, double y, double z) {\n'
353
- ' double test2_result;\n'
354
- ' test2_result = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n'
355
- ' return test2_result;\n'
356
- '}\n'
357
- )
358
- assert result[1][0] == "file.h"
359
- assert result[1][1] == (
360
- '#ifndef PROJECT__FILE__H\n'
361
- '#define PROJECT__FILE__H\n'
362
- 'double test1(double x, double y, double z);\n'
363
- 'double test2(double x, double y, double z);\n'
364
- '#endif\n'
365
- )
366
-
367
-
368
- def test_loops_c():
369
- from sympy.tensor import IndexedBase, Idx
370
- from sympy.core.symbol import symbols
371
- n, m = symbols('n m', integer=True)
372
- A = IndexedBase('A')
373
- x = IndexedBase('x')
374
- y = IndexedBase('y')
375
- i = Idx('i', m)
376
- j = Idx('j', n)
377
-
378
- (f1, code), (f2, interface) = codegen(
379
- ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
380
-
381
- assert f1 == 'file.c'
382
- expected = (
383
- '#include "file.h"\n'
384
- '#include <math.h>\n'
385
- 'void matrix_vector(double *A, int m, int n, double *x, double *y) {\n'
386
- ' for (int i=0; i<m; i++){\n'
387
- ' y[i] = 0;\n'
388
- ' }\n'
389
- ' for (int i=0; i<m; i++){\n'
390
- ' for (int j=0; j<n; j++){\n'
391
- ' y[i] = %(rhs)s + y[i];\n'
392
- ' }\n'
393
- ' }\n'
394
- '}\n'
395
- )
396
-
397
- assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*n + j)} or
398
- code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*n)} or
399
- code == expected % {'rhs': 'x[j]*A[%s]' % (i*n + j)} or
400
- code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*n)})
401
- assert f2 == 'file.h'
402
- assert interface == (
403
- '#ifndef PROJECT__FILE__H\n'
404
- '#define PROJECT__FILE__H\n'
405
- 'void matrix_vector(double *A, int m, int n, double *x, double *y);\n'
406
- '#endif\n'
407
- )
408
-
409
-
410
- def test_dummy_loops_c():
411
- from sympy.tensor import IndexedBase, Idx
412
- i, m = symbols('i m', integer=True, cls=Dummy)
413
- x = IndexedBase('x')
414
- y = IndexedBase('y')
415
- i = Idx(i, m)
416
- expected = (
417
- '#include "file.h"\n'
418
- '#include <math.h>\n'
419
- 'void test_dummies(int m_%(mno)i, double *x, double *y) {\n'
420
- ' for (int i_%(ino)i=0; i_%(ino)i<m_%(mno)i; i_%(ino)i++){\n'
421
- ' y[i_%(ino)i] = x[i_%(ino)i];\n'
422
- ' }\n'
423
- '}\n'
424
- ) % {'ino': i.label.dummy_index, 'mno': m.dummy_index}
425
- r = make_routine('test_dummies', Eq(y[i], x[i]))
426
- c89 = C89CodeGen()
427
- c99 = C99CodeGen()
428
- code = get_string(c99.dump_c, [r])
429
- assert code == expected
430
- with raises(NotImplementedError):
431
- get_string(c89.dump_c, [r])
432
-
433
- def test_partial_loops_c():
434
- # check that loop boundaries are determined by Idx, and array strides
435
- # determined by shape of IndexedBase object.
436
- from sympy.tensor import IndexedBase, Idx
437
- from sympy.core.symbol import symbols
438
- n, m, o, p = symbols('n m o p', integer=True)
439
- A = IndexedBase('A', shape=(m, p))
440
- x = IndexedBase('x')
441
- y = IndexedBase('y')
442
- i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
443
- j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
444
-
445
- (f1, code), (f2, interface) = codegen(
446
- ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
447
-
448
- assert f1 == 'file.c'
449
- expected = (
450
- '#include "file.h"\n'
451
- '#include <math.h>\n'
452
- 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y) {\n'
453
- ' for (int i=o; i<%(upperi)s; i++){\n'
454
- ' y[i] = 0;\n'
455
- ' }\n'
456
- ' for (int i=o; i<%(upperi)s; i++){\n'
457
- ' for (int j=0; j<n; j++){\n'
458
- ' y[i] = %(rhs)s + y[i];\n'
459
- ' }\n'
460
- ' }\n'
461
- '}\n'
462
- ) % {'upperi': m - 4, 'rhs': '%(rhs)s'}
463
-
464
- assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*p + j)} or
465
- code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*p)} or
466
- code == expected % {'rhs': 'x[j]*A[%s]' % (i*p + j)} or
467
- code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*p)})
468
- assert f2 == 'file.h'
469
- assert interface == (
470
- '#ifndef PROJECT__FILE__H\n'
471
- '#define PROJECT__FILE__H\n'
472
- 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y);\n'
473
- '#endif\n'
474
- )
475
-
476
-
477
- def test_output_arg_c():
478
- from sympy.core.relational import Equality
479
- from sympy.functions.elementary.trigonometric import (cos, sin)
480
- x, y, z = symbols("x,y,z")
481
- r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
482
- c = C89CodeGen()
483
- result = c.write([r], "test", header=False, empty=False)
484
- assert result[0][0] == "test.c"
485
- expected = (
486
- '#include "test.h"\n'
487
- '#include <math.h>\n'
488
- 'double foo(double x, double *y) {\n'
489
- ' (*y) = sin(x);\n'
490
- ' double foo_result;\n'
491
- ' foo_result = cos(x);\n'
492
- ' return foo_result;\n'
493
- '}\n'
494
- )
495
- assert result[0][1] == expected
496
-
497
-
498
- def test_output_arg_c_reserved_words():
499
- from sympy.core.relational import Equality
500
- from sympy.functions.elementary.trigonometric import (cos, sin)
501
- x, y, z = symbols("if, while, z")
502
- r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
503
- c = C89CodeGen()
504
- result = c.write([r], "test", header=False, empty=False)
505
- assert result[0][0] == "test.c"
506
- expected = (
507
- '#include "test.h"\n'
508
- '#include <math.h>\n'
509
- 'double foo(double if_, double *while_) {\n'
510
- ' (*while_) = sin(if_);\n'
511
- ' double foo_result;\n'
512
- ' foo_result = cos(if_);\n'
513
- ' return foo_result;\n'
514
- '}\n'
515
- )
516
- assert result[0][1] == expected
517
-
518
-
519
- def test_multidim_c_argument_cse():
520
- A_sym = MatrixSymbol('A', 3, 3)
521
- b_sym = MatrixSymbol('b', 3, 1)
522
- A = Matrix(A_sym)
523
- b = Matrix(b_sym)
524
- c = A*b
525
- cgen = CCodeGen(project="test", cse=True)
526
- r = cgen.routine("c", c)
527
- r.arguments[-1].result_var = "out"
528
- r.arguments[-1]._name = "out"
529
- code = get_string(cgen.dump_c, [r], prefix="test")
530
- expected = (
531
- '#include "test.h"\n'
532
- "#include <math.h>\n"
533
- "void c(double *A, double *b, double *out) {\n"
534
- " out[0] = A[0]*b[0] + A[1]*b[1] + A[2]*b[2];\n"
535
- " out[1] = A[3]*b[0] + A[4]*b[1] + A[5]*b[2];\n"
536
- " out[2] = A[6]*b[0] + A[7]*b[1] + A[8]*b[2];\n"
537
- "}\n"
538
- )
539
- assert code == expected
540
-
541
-
542
- def test_ccode_results_named_ordered():
543
- x, y, z = symbols('x,y,z')
544
- B, C = symbols('B,C')
545
- A = MatrixSymbol('A', 1, 3)
546
- expr1 = Equality(A, Matrix([[1, 2, x]]))
547
- expr2 = Equality(C, (x + y)*z)
548
- expr3 = Equality(B, 2*x)
549
- name_expr = ("test", [expr1, expr2, expr3])
550
- expected = (
551
- '#include "test.h"\n'
552
- '#include <math.h>\n'
553
- 'void test(double x, double *C, double z, double y, double *A, double *B) {\n'
554
- ' (*C) = z*(x + y);\n'
555
- ' A[0] = 1;\n'
556
- ' A[1] = 2;\n'
557
- ' A[2] = x;\n'
558
- ' (*B) = 2*x;\n'
559
- '}\n'
560
- )
561
-
562
- result = codegen(name_expr, "c", "test", header=False, empty=False,
563
- argument_sequence=(x, C, z, y, A, B))
564
- source = result[0][1]
565
- assert source == expected
566
-
567
-
568
- def test_ccode_matrixsymbol_slice():
569
- A = MatrixSymbol('A', 5, 3)
570
- B = MatrixSymbol('B', 1, 3)
571
- C = MatrixSymbol('C', 1, 3)
572
- D = MatrixSymbol('D', 5, 1)
573
- name_expr = ("test", [Equality(B, A[0, :]),
574
- Equality(C, A[1, :]),
575
- Equality(D, A[:, 2])])
576
- result = codegen(name_expr, "c99", "test", header=False, empty=False)
577
- source = result[0][1]
578
- expected = (
579
- '#include "test.h"\n'
580
- '#include <math.h>\n'
581
- 'void test(double *A, double *B, double *C, double *D) {\n'
582
- ' B[0] = A[0];\n'
583
- ' B[1] = A[1];\n'
584
- ' B[2] = A[2];\n'
585
- ' C[0] = A[3];\n'
586
- ' C[1] = A[4];\n'
587
- ' C[2] = A[5];\n'
588
- ' D[0] = A[2];\n'
589
- ' D[1] = A[5];\n'
590
- ' D[2] = A[8];\n'
591
- ' D[3] = A[11];\n'
592
- ' D[4] = A[14];\n'
593
- '}\n'
594
- )
595
- assert source == expected
596
-
597
- def test_ccode_cse():
598
- a, b, c, d = symbols('a b c d')
599
- e = MatrixSymbol('e', 3, 1)
600
- name_expr = ("test", [Equality(e, Matrix([[a*b], [a*b + c*d], [a*b*c*d]]))])
601
- generator = CCodeGen(cse=True)
602
- result = codegen(name_expr, code_gen=generator, header=False, empty=False)
603
- source = result[0][1]
604
- expected = (
605
- '#include "test.h"\n'
606
- '#include <math.h>\n'
607
- 'void test(double a, double b, double c, double d, double *e) {\n'
608
- ' const double x0 = a*b;\n'
609
- ' const double x1 = c*d;\n'
610
- ' e[0] = x0;\n'
611
- ' e[1] = x0 + x1;\n'
612
- ' e[2] = x0*x1;\n'
613
- '}\n'
614
- )
615
- assert source == expected
616
-
617
- def test_ccode_unused_array_arg():
618
- x = MatrixSymbol('x', 2, 1)
619
- # x does not appear in output
620
- name_expr = ("test", 1.0)
621
- generator = CCodeGen()
622
- result = codegen(name_expr, code_gen=generator, header=False, empty=False, argument_sequence=(x,))
623
- source = result[0][1]
624
- # note: x should appear as (double *)
625
- expected = (
626
- '#include "test.h"\n'
627
- '#include <math.h>\n'
628
- 'double test(double *x) {\n'
629
- ' double test_result;\n'
630
- ' test_result = 1.0;\n'
631
- ' return test_result;\n'
632
- '}\n'
633
- )
634
- assert source == expected
635
-
636
- def test_ccode_unused_array_arg_func():
637
- # issue 16689
638
- X = MatrixSymbol('X',3,1)
639
- Y = MatrixSymbol('Y',3,1)
640
- z = symbols('z',integer = True)
641
- name_expr = ('testBug', X[0] + X[1])
642
- result = codegen(name_expr, language='C', header=False, empty=False, argument_sequence=(X, Y, z))
643
- source = result[0][1]
644
- expected = (
645
- '#include "testBug.h"\n'
646
- '#include <math.h>\n'
647
- 'double testBug(double *X, double *Y, int z) {\n'
648
- ' double testBug_result;\n'
649
- ' testBug_result = X[0] + X[1];\n'
650
- ' return testBug_result;\n'
651
- '}\n'
652
- )
653
- assert source == expected
654
-
655
- def test_empty_f_code():
656
- code_gen = FCodeGen()
657
- source = get_string(code_gen.dump_f95, [])
658
- assert source == ""
659
-
660
-
661
- def test_empty_f_code_with_header():
662
- code_gen = FCodeGen()
663
- source = get_string(code_gen.dump_f95, [], header=True)
664
- assert source[:82] == (
665
- "!******************************************************************************\n!*"
666
- )
667
- # " Code generated with SymPy 0.7.2-git "
668
- assert source[158:] == ( "*\n"
669
- "!* *\n"
670
- "!* See http://www.sympy.org/ for more information. *\n"
671
- "!* *\n"
672
- "!* This file is part of 'project' *\n"
673
- "!******************************************************************************\n"
674
- )
675
-
676
-
677
- def test_empty_f_header():
678
- code_gen = FCodeGen()
679
- source = get_string(code_gen.dump_h, [])
680
- assert source == ""
681
-
682
-
683
- def test_simple_f_code():
684
- x, y, z = symbols('x,y,z')
685
- expr = (x + y)*z
686
- routine = make_routine("test", expr)
687
- code_gen = FCodeGen()
688
- source = get_string(code_gen.dump_f95, [routine])
689
- expected = (
690
- "REAL*8 function test(x, y, z)\n"
691
- "implicit none\n"
692
- "REAL*8, intent(in) :: x\n"
693
- "REAL*8, intent(in) :: y\n"
694
- "REAL*8, intent(in) :: z\n"
695
- "test = z*(x + y)\n"
696
- "end function\n"
697
- )
698
- assert source == expected
699
-
700
-
701
- def test_numbersymbol_f_code():
702
- routine = make_routine("test", pi**Catalan)
703
- code_gen = FCodeGen()
704
- source = get_string(code_gen.dump_f95, [routine])
705
- expected = (
706
- "REAL*8 function test()\n"
707
- "implicit none\n"
708
- "REAL*8, parameter :: Catalan = %sd0\n"
709
- "REAL*8, parameter :: pi = %sd0\n"
710
- "test = pi**Catalan\n"
711
- "end function\n"
712
- ) % (Catalan.evalf(17), pi.evalf(17))
713
- assert source == expected
714
-
715
- def test_erf_f_code():
716
- x = symbols('x')
717
- routine = make_routine("test", erf(x) - erf(-2 * x))
718
- code_gen = FCodeGen()
719
- source = get_string(code_gen.dump_f95, [routine])
720
- expected = (
721
- "REAL*8 function test(x)\n"
722
- "implicit none\n"
723
- "REAL*8, intent(in) :: x\n"
724
- "test = erf(x) + erf(2.0d0*x)\n"
725
- "end function\n"
726
- )
727
- assert source == expected, source
728
-
729
- def test_f_code_argument_order():
730
- x, y, z = symbols('x,y,z')
731
- expr = x + y
732
- routine = make_routine("test", expr, argument_sequence=[z, x, y])
733
- code_gen = FCodeGen()
734
- source = get_string(code_gen.dump_f95, [routine])
735
- expected = (
736
- "REAL*8 function test(z, x, y)\n"
737
- "implicit none\n"
738
- "REAL*8, intent(in) :: z\n"
739
- "REAL*8, intent(in) :: x\n"
740
- "REAL*8, intent(in) :: y\n"
741
- "test = x + y\n"
742
- "end function\n"
743
- )
744
- assert source == expected
745
-
746
-
747
- def test_simple_f_header():
748
- x, y, z = symbols('x,y,z')
749
- expr = (x + y)*z
750
- routine = make_routine("test", expr)
751
- code_gen = FCodeGen()
752
- source = get_string(code_gen.dump_h, [routine])
753
- expected = (
754
- "interface\n"
755
- "REAL*8 function test(x, y, z)\n"
756
- "implicit none\n"
757
- "REAL*8, intent(in) :: x\n"
758
- "REAL*8, intent(in) :: y\n"
759
- "REAL*8, intent(in) :: z\n"
760
- "end function\n"
761
- "end interface\n"
762
- )
763
- assert source == expected
764
-
765
-
766
- def test_simple_f_codegen():
767
- x, y, z = symbols('x,y,z')
768
- expr = (x + y)*z
769
- result = codegen(
770
- ("test", expr), "F95", "file", header=False, empty=False)
771
- expected = [
772
- ("file.f90",
773
- "REAL*8 function test(x, y, z)\n"
774
- "implicit none\n"
775
- "REAL*8, intent(in) :: x\n"
776
- "REAL*8, intent(in) :: y\n"
777
- "REAL*8, intent(in) :: z\n"
778
- "test = z*(x + y)\n"
779
- "end function\n"),
780
- ("file.h",
781
- "interface\n"
782
- "REAL*8 function test(x, y, z)\n"
783
- "implicit none\n"
784
- "REAL*8, intent(in) :: x\n"
785
- "REAL*8, intent(in) :: y\n"
786
- "REAL*8, intent(in) :: z\n"
787
- "end function\n"
788
- "end interface\n")
789
- ]
790
- assert result == expected
791
-
792
-
793
- def test_multiple_results_f():
794
- x, y, z = symbols('x,y,z')
795
- expr1 = (x + y)*z
796
- expr2 = (x - y)*z
797
- routine = make_routine(
798
- "test",
799
- [expr1, expr2]
800
- )
801
- code_gen = FCodeGen()
802
- raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
803
-
804
-
805
- def test_no_results_f():
806
- raises(ValueError, lambda: make_routine("test", []))
807
-
808
-
809
- def test_intrinsic_math_codegen():
810
- # not included: log10
811
- from sympy.functions.elementary.complexes import Abs
812
- from sympy.functions.elementary.exponential import log
813
- from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
814
- from sympy.functions.elementary.miscellaneous import sqrt
815
- from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
816
- x = symbols('x')
817
- name_expr = [
818
- ("test_abs", Abs(x)),
819
- ("test_acos", acos(x)),
820
- ("test_asin", asin(x)),
821
- ("test_atan", atan(x)),
822
- ("test_cos", cos(x)),
823
- ("test_cosh", cosh(x)),
824
- ("test_log", log(x)),
825
- ("test_ln", log(x)),
826
- ("test_sin", sin(x)),
827
- ("test_sinh", sinh(x)),
828
- ("test_sqrt", sqrt(x)),
829
- ("test_tan", tan(x)),
830
- ("test_tanh", tanh(x)),
831
- ]
832
- result = codegen(name_expr, "F95", "file", header=False, empty=False)
833
- assert result[0][0] == "file.f90"
834
- expected = (
835
- 'REAL*8 function test_abs(x)\n'
836
- 'implicit none\n'
837
- 'REAL*8, intent(in) :: x\n'
838
- 'test_abs = abs(x)\n'
839
- 'end function\n'
840
- 'REAL*8 function test_acos(x)\n'
841
- 'implicit none\n'
842
- 'REAL*8, intent(in) :: x\n'
843
- 'test_acos = acos(x)\n'
844
- 'end function\n'
845
- 'REAL*8 function test_asin(x)\n'
846
- 'implicit none\n'
847
- 'REAL*8, intent(in) :: x\n'
848
- 'test_asin = asin(x)\n'
849
- 'end function\n'
850
- 'REAL*8 function test_atan(x)\n'
851
- 'implicit none\n'
852
- 'REAL*8, intent(in) :: x\n'
853
- 'test_atan = atan(x)\n'
854
- 'end function\n'
855
- 'REAL*8 function test_cos(x)\n'
856
- 'implicit none\n'
857
- 'REAL*8, intent(in) :: x\n'
858
- 'test_cos = cos(x)\n'
859
- 'end function\n'
860
- 'REAL*8 function test_cosh(x)\n'
861
- 'implicit none\n'
862
- 'REAL*8, intent(in) :: x\n'
863
- 'test_cosh = cosh(x)\n'
864
- 'end function\n'
865
- 'REAL*8 function test_log(x)\n'
866
- 'implicit none\n'
867
- 'REAL*8, intent(in) :: x\n'
868
- 'test_log = log(x)\n'
869
- 'end function\n'
870
- 'REAL*8 function test_ln(x)\n'
871
- 'implicit none\n'
872
- 'REAL*8, intent(in) :: x\n'
873
- 'test_ln = log(x)\n'
874
- 'end function\n'
875
- 'REAL*8 function test_sin(x)\n'
876
- 'implicit none\n'
877
- 'REAL*8, intent(in) :: x\n'
878
- 'test_sin = sin(x)\n'
879
- 'end function\n'
880
- 'REAL*8 function test_sinh(x)\n'
881
- 'implicit none\n'
882
- 'REAL*8, intent(in) :: x\n'
883
- 'test_sinh = sinh(x)\n'
884
- 'end function\n'
885
- 'REAL*8 function test_sqrt(x)\n'
886
- 'implicit none\n'
887
- 'REAL*8, intent(in) :: x\n'
888
- 'test_sqrt = sqrt(x)\n'
889
- 'end function\n'
890
- 'REAL*8 function test_tan(x)\n'
891
- 'implicit none\n'
892
- 'REAL*8, intent(in) :: x\n'
893
- 'test_tan = tan(x)\n'
894
- 'end function\n'
895
- 'REAL*8 function test_tanh(x)\n'
896
- 'implicit none\n'
897
- 'REAL*8, intent(in) :: x\n'
898
- 'test_tanh = tanh(x)\n'
899
- 'end function\n'
900
- )
901
- assert result[0][1] == expected
902
-
903
- assert result[1][0] == "file.h"
904
- expected = (
905
- 'interface\n'
906
- 'REAL*8 function test_abs(x)\n'
907
- 'implicit none\n'
908
- 'REAL*8, intent(in) :: x\n'
909
- 'end function\n'
910
- 'end interface\n'
911
- 'interface\n'
912
- 'REAL*8 function test_acos(x)\n'
913
- 'implicit none\n'
914
- 'REAL*8, intent(in) :: x\n'
915
- 'end function\n'
916
- 'end interface\n'
917
- 'interface\n'
918
- 'REAL*8 function test_asin(x)\n'
919
- 'implicit none\n'
920
- 'REAL*8, intent(in) :: x\n'
921
- 'end function\n'
922
- 'end interface\n'
923
- 'interface\n'
924
- 'REAL*8 function test_atan(x)\n'
925
- 'implicit none\n'
926
- 'REAL*8, intent(in) :: x\n'
927
- 'end function\n'
928
- 'end interface\n'
929
- 'interface\n'
930
- 'REAL*8 function test_cos(x)\n'
931
- 'implicit none\n'
932
- 'REAL*8, intent(in) :: x\n'
933
- 'end function\n'
934
- 'end interface\n'
935
- 'interface\n'
936
- 'REAL*8 function test_cosh(x)\n'
937
- 'implicit none\n'
938
- 'REAL*8, intent(in) :: x\n'
939
- 'end function\n'
940
- 'end interface\n'
941
- 'interface\n'
942
- 'REAL*8 function test_log(x)\n'
943
- 'implicit none\n'
944
- 'REAL*8, intent(in) :: x\n'
945
- 'end function\n'
946
- 'end interface\n'
947
- 'interface\n'
948
- 'REAL*8 function test_ln(x)\n'
949
- 'implicit none\n'
950
- 'REAL*8, intent(in) :: x\n'
951
- 'end function\n'
952
- 'end interface\n'
953
- 'interface\n'
954
- 'REAL*8 function test_sin(x)\n'
955
- 'implicit none\n'
956
- 'REAL*8, intent(in) :: x\n'
957
- 'end function\n'
958
- 'end interface\n'
959
- 'interface\n'
960
- 'REAL*8 function test_sinh(x)\n'
961
- 'implicit none\n'
962
- 'REAL*8, intent(in) :: x\n'
963
- 'end function\n'
964
- 'end interface\n'
965
- 'interface\n'
966
- 'REAL*8 function test_sqrt(x)\n'
967
- 'implicit none\n'
968
- 'REAL*8, intent(in) :: x\n'
969
- 'end function\n'
970
- 'end interface\n'
971
- 'interface\n'
972
- 'REAL*8 function test_tan(x)\n'
973
- 'implicit none\n'
974
- 'REAL*8, intent(in) :: x\n'
975
- 'end function\n'
976
- 'end interface\n'
977
- 'interface\n'
978
- 'REAL*8 function test_tanh(x)\n'
979
- 'implicit none\n'
980
- 'REAL*8, intent(in) :: x\n'
981
- 'end function\n'
982
- 'end interface\n'
983
- )
984
- assert result[1][1] == expected
985
-
986
-
987
- def test_intrinsic_math2_codegen():
988
- # not included: frexp, ldexp, modf, fmod
989
- from sympy.functions.elementary.trigonometric import atan2
990
- x, y = symbols('x,y')
991
- name_expr = [
992
- ("test_atan2", atan2(x, y)),
993
- ("test_pow", x**y),
994
- ]
995
- result = codegen(name_expr, "F95", "file", header=False, empty=False)
996
- assert result[0][0] == "file.f90"
997
- expected = (
998
- 'REAL*8 function test_atan2(x, y)\n'
999
- 'implicit none\n'
1000
- 'REAL*8, intent(in) :: x\n'
1001
- 'REAL*8, intent(in) :: y\n'
1002
- 'test_atan2 = atan2(x, y)\n'
1003
- 'end function\n'
1004
- 'REAL*8 function test_pow(x, y)\n'
1005
- 'implicit none\n'
1006
- 'REAL*8, intent(in) :: x\n'
1007
- 'REAL*8, intent(in) :: y\n'
1008
- 'test_pow = x**y\n'
1009
- 'end function\n'
1010
- )
1011
- assert result[0][1] == expected
1012
-
1013
- assert result[1][0] == "file.h"
1014
- expected = (
1015
- 'interface\n'
1016
- 'REAL*8 function test_atan2(x, y)\n'
1017
- 'implicit none\n'
1018
- 'REAL*8, intent(in) :: x\n'
1019
- 'REAL*8, intent(in) :: y\n'
1020
- 'end function\n'
1021
- 'end interface\n'
1022
- 'interface\n'
1023
- 'REAL*8 function test_pow(x, y)\n'
1024
- 'implicit none\n'
1025
- 'REAL*8, intent(in) :: x\n'
1026
- 'REAL*8, intent(in) :: y\n'
1027
- 'end function\n'
1028
- 'end interface\n'
1029
- )
1030
- assert result[1][1] == expected
1031
-
1032
-
1033
- def test_complicated_codegen_f95():
1034
- from sympy.functions.elementary.trigonometric import (cos, sin, tan)
1035
- x, y, z = symbols('x,y,z')
1036
- name_expr = [
1037
- ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
1038
- ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
1039
- ]
1040
- result = codegen(name_expr, "F95", "file", header=False, empty=False)
1041
- assert result[0][0] == "file.f90"
1042
- expected = (
1043
- 'REAL*8 function test1(x, y, z)\n'
1044
- 'implicit none\n'
1045
- 'REAL*8, intent(in) :: x\n'
1046
- 'REAL*8, intent(in) :: y\n'
1047
- 'REAL*8, intent(in) :: z\n'
1048
- 'test1 = sin(x)**7 + 7*sin(x)**6*cos(y) + 7*sin(x)**6*tan(z) + 21*sin(x) &\n'
1049
- ' **5*cos(y)**2 + 42*sin(x)**5*cos(y)*tan(z) + 21*sin(x)**5*tan(z) &\n'
1050
- ' **2 + 35*sin(x)**4*cos(y)**3 + 105*sin(x)**4*cos(y)**2*tan(z) + &\n'
1051
- ' 105*sin(x)**4*cos(y)*tan(z)**2 + 35*sin(x)**4*tan(z)**3 + 35*sin( &\n'
1052
- ' x)**3*cos(y)**4 + 140*sin(x)**3*cos(y)**3*tan(z) + 210*sin(x)**3* &\n'
1053
- ' cos(y)**2*tan(z)**2 + 140*sin(x)**3*cos(y)*tan(z)**3 + 35*sin(x) &\n'
1054
- ' **3*tan(z)**4 + 21*sin(x)**2*cos(y)**5 + 105*sin(x)**2*cos(y)**4* &\n'
1055
- ' tan(z) + 210*sin(x)**2*cos(y)**3*tan(z)**2 + 210*sin(x)**2*cos(y) &\n'
1056
- ' **2*tan(z)**3 + 105*sin(x)**2*cos(y)*tan(z)**4 + 21*sin(x)**2*tan &\n'
1057
- ' (z)**5 + 7*sin(x)*cos(y)**6 + 42*sin(x)*cos(y)**5*tan(z) + 105* &\n'
1058
- ' sin(x)*cos(y)**4*tan(z)**2 + 140*sin(x)*cos(y)**3*tan(z)**3 + 105 &\n'
1059
- ' *sin(x)*cos(y)**2*tan(z)**4 + 42*sin(x)*cos(y)*tan(z)**5 + 7*sin( &\n'
1060
- ' x)*tan(z)**6 + cos(y)**7 + 7*cos(y)**6*tan(z) + 21*cos(y)**5*tan( &\n'
1061
- ' z)**2 + 35*cos(y)**4*tan(z)**3 + 35*cos(y)**3*tan(z)**4 + 21*cos( &\n'
1062
- ' y)**2*tan(z)**5 + 7*cos(y)*tan(z)**6 + tan(z)**7\n'
1063
- 'end function\n'
1064
- 'REAL*8 function test2(x, y, z)\n'
1065
- 'implicit none\n'
1066
- 'REAL*8, intent(in) :: x\n'
1067
- 'REAL*8, intent(in) :: y\n'
1068
- 'REAL*8, intent(in) :: z\n'
1069
- 'test2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n'
1070
- 'end function\n'
1071
- )
1072
- assert result[0][1] == expected
1073
- assert result[1][0] == "file.h"
1074
- expected = (
1075
- 'interface\n'
1076
- 'REAL*8 function test1(x, y, z)\n'
1077
- 'implicit none\n'
1078
- 'REAL*8, intent(in) :: x\n'
1079
- 'REAL*8, intent(in) :: y\n'
1080
- 'REAL*8, intent(in) :: z\n'
1081
- 'end function\n'
1082
- 'end interface\n'
1083
- 'interface\n'
1084
- 'REAL*8 function test2(x, y, z)\n'
1085
- 'implicit none\n'
1086
- 'REAL*8, intent(in) :: x\n'
1087
- 'REAL*8, intent(in) :: y\n'
1088
- 'REAL*8, intent(in) :: z\n'
1089
- 'end function\n'
1090
- 'end interface\n'
1091
- )
1092
- assert result[1][1] == expected
1093
-
1094
-
1095
- def test_loops():
1096
- from sympy.tensor import IndexedBase, Idx
1097
- from sympy.core.symbol import symbols
1098
-
1099
- n, m = symbols('n,m', integer=True)
1100
- A, x, y = map(IndexedBase, 'Axy')
1101
- i = Idx('i', m)
1102
- j = Idx('j', n)
1103
-
1104
- (f1, code), (f2, interface) = codegen(
1105
- ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
1106
-
1107
- assert f1 == 'file.f90'
1108
- expected = (
1109
- 'subroutine matrix_vector(A, m, n, x, y)\n'
1110
- 'implicit none\n'
1111
- 'INTEGER*4, intent(in) :: m\n'
1112
- 'INTEGER*4, intent(in) :: n\n'
1113
- 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
1114
- 'REAL*8, intent(in), dimension(1:n) :: x\n'
1115
- 'REAL*8, intent(out), dimension(1:m) :: y\n'
1116
- 'INTEGER*4 :: i\n'
1117
- 'INTEGER*4 :: j\n'
1118
- 'do i = 1, m\n'
1119
- ' y(i) = 0\n'
1120
- 'end do\n'
1121
- 'do i = 1, m\n'
1122
- ' do j = 1, n\n'
1123
- ' y(i) = %(rhs)s + y(i)\n'
1124
- ' end do\n'
1125
- 'end do\n'
1126
- 'end subroutine\n'
1127
- )
1128
-
1129
- assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
1130
- code == expected % {'rhs': 'x(j)*A(i, j)'}
1131
- assert f2 == 'file.h'
1132
- assert interface == (
1133
- 'interface\n'
1134
- 'subroutine matrix_vector(A, m, n, x, y)\n'
1135
- 'implicit none\n'
1136
- 'INTEGER*4, intent(in) :: m\n'
1137
- 'INTEGER*4, intent(in) :: n\n'
1138
- 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
1139
- 'REAL*8, intent(in), dimension(1:n) :: x\n'
1140
- 'REAL*8, intent(out), dimension(1:m) :: y\n'
1141
- 'end subroutine\n'
1142
- 'end interface\n'
1143
- )
1144
-
1145
-
1146
- def test_dummy_loops_f95():
1147
- from sympy.tensor import IndexedBase, Idx
1148
- i, m = symbols('i m', integer=True, cls=Dummy)
1149
- x = IndexedBase('x')
1150
- y = IndexedBase('y')
1151
- i = Idx(i, m)
1152
- expected = (
1153
- 'subroutine test_dummies(m_%(mcount)i, x, y)\n'
1154
- 'implicit none\n'
1155
- 'INTEGER*4, intent(in) :: m_%(mcount)i\n'
1156
- 'REAL*8, intent(in), dimension(1:m_%(mcount)i) :: x\n'
1157
- 'REAL*8, intent(out), dimension(1:m_%(mcount)i) :: y\n'
1158
- 'INTEGER*4 :: i_%(icount)i\n'
1159
- 'do i_%(icount)i = 1, m_%(mcount)i\n'
1160
- ' y(i_%(icount)i) = x(i_%(icount)i)\n'
1161
- 'end do\n'
1162
- 'end subroutine\n'
1163
- ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
1164
- r = make_routine('test_dummies', Eq(y[i], x[i]))
1165
- c = FCodeGen()
1166
- code = get_string(c.dump_f95, [r])
1167
- assert code == expected
1168
-
1169
-
1170
- def test_loops_InOut():
1171
- from sympy.tensor import IndexedBase, Idx
1172
- from sympy.core.symbol import symbols
1173
-
1174
- i, j, n, m = symbols('i,j,n,m', integer=True)
1175
- A, x, y = symbols('A,x,y')
1176
- A = IndexedBase(A)[Idx(i, m), Idx(j, n)]
1177
- x = IndexedBase(x)[Idx(j, n)]
1178
- y = IndexedBase(y)[Idx(i, m)]
1179
-
1180
- (f1, code), (f2, interface) = codegen(
1181
- ('matrix_vector', Eq(y, y + A*x)), "F95", "file", header=False, empty=False)
1182
-
1183
- assert f1 == 'file.f90'
1184
- expected = (
1185
- 'subroutine matrix_vector(A, m, n, x, y)\n'
1186
- 'implicit none\n'
1187
- 'INTEGER*4, intent(in) :: m\n'
1188
- 'INTEGER*4, intent(in) :: n\n'
1189
- 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
1190
- 'REAL*8, intent(in), dimension(1:n) :: x\n'
1191
- 'REAL*8, intent(inout), dimension(1:m) :: y\n'
1192
- 'INTEGER*4 :: i\n'
1193
- 'INTEGER*4 :: j\n'
1194
- 'do i = 1, m\n'
1195
- ' do j = 1, n\n'
1196
- ' y(i) = %(rhs)s + y(i)\n'
1197
- ' end do\n'
1198
- 'end do\n'
1199
- 'end subroutine\n'
1200
- )
1201
-
1202
- assert (code == expected % {'rhs': 'A(i, j)*x(j)'} or
1203
- code == expected % {'rhs': 'x(j)*A(i, j)'})
1204
- assert f2 == 'file.h'
1205
- assert interface == (
1206
- 'interface\n'
1207
- 'subroutine matrix_vector(A, m, n, x, y)\n'
1208
- 'implicit none\n'
1209
- 'INTEGER*4, intent(in) :: m\n'
1210
- 'INTEGER*4, intent(in) :: n\n'
1211
- 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
1212
- 'REAL*8, intent(in), dimension(1:n) :: x\n'
1213
- 'REAL*8, intent(inout), dimension(1:m) :: y\n'
1214
- 'end subroutine\n'
1215
- 'end interface\n'
1216
- )
1217
-
1218
-
1219
- def test_partial_loops_f():
1220
- # check that loop boundaries are determined by Idx, and array strides
1221
- # determined by shape of IndexedBase object.
1222
- from sympy.tensor import IndexedBase, Idx
1223
- from sympy.core.symbol import symbols
1224
- n, m, o, p = symbols('n m o p', integer=True)
1225
- A = IndexedBase('A', shape=(m, p))
1226
- x = IndexedBase('x')
1227
- y = IndexedBase('y')
1228
- i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
1229
- j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
1230
-
1231
- (f1, code), (f2, interface) = codegen(
1232
- ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
1233
-
1234
- expected = (
1235
- 'subroutine matrix_vector(A, m, n, o, p, x, y)\n'
1236
- 'implicit none\n'
1237
- 'INTEGER*4, intent(in) :: m\n'
1238
- 'INTEGER*4, intent(in) :: n\n'
1239
- 'INTEGER*4, intent(in) :: o\n'
1240
- 'INTEGER*4, intent(in) :: p\n'
1241
- 'REAL*8, intent(in), dimension(1:m, 1:p) :: A\n'
1242
- 'REAL*8, intent(in), dimension(1:n) :: x\n'
1243
- 'REAL*8, intent(out), dimension(1:%(iup-ilow)s) :: y\n'
1244
- 'INTEGER*4 :: i\n'
1245
- 'INTEGER*4 :: j\n'
1246
- 'do i = %(ilow)s, %(iup)s\n'
1247
- ' y(i) = 0\n'
1248
- 'end do\n'
1249
- 'do i = %(ilow)s, %(iup)s\n'
1250
- ' do j = 1, n\n'
1251
- ' y(i) = %(rhs)s + y(i)\n'
1252
- ' end do\n'
1253
- 'end do\n'
1254
- 'end subroutine\n'
1255
- ) % {
1256
- 'rhs': '%(rhs)s',
1257
- 'iup': str(m - 4),
1258
- 'ilow': str(1 + o),
1259
- 'iup-ilow': str(m - 4 - o)
1260
- }
1261
-
1262
- assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
1263
- code == expected % {'rhs': 'x(j)*A(i, j)'}
1264
-
1265
-
1266
- def test_output_arg_f():
1267
- from sympy.core.relational import Equality
1268
- from sympy.functions.elementary.trigonometric import (cos, sin)
1269
- x, y, z = symbols("x,y,z")
1270
- r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
1271
- c = FCodeGen()
1272
- result = c.write([r], "test", header=False, empty=False)
1273
- assert result[0][0] == "test.f90"
1274
- assert result[0][1] == (
1275
- 'REAL*8 function foo(x, y)\n'
1276
- 'implicit none\n'
1277
- 'REAL*8, intent(in) :: x\n'
1278
- 'REAL*8, intent(out) :: y\n'
1279
- 'y = sin(x)\n'
1280
- 'foo = cos(x)\n'
1281
- 'end function\n'
1282
- )
1283
-
1284
-
1285
- def test_inline_function():
1286
- from sympy.tensor import IndexedBase, Idx
1287
- from sympy.core.symbol import symbols
1288
- n, m = symbols('n m', integer=True)
1289
- A, x, y = map(IndexedBase, 'Axy')
1290
- i = Idx('i', m)
1291
- p = FCodeGen()
1292
- func = implemented_function('func', Lambda(n, n*(n + 1)))
1293
- routine = make_routine('test_inline', Eq(y[i], func(x[i])))
1294
- code = get_string(p.dump_f95, [routine])
1295
- expected = (
1296
- 'subroutine test_inline(m, x, y)\n'
1297
- 'implicit none\n'
1298
- 'INTEGER*4, intent(in) :: m\n'
1299
- 'REAL*8, intent(in), dimension(1:m) :: x\n'
1300
- 'REAL*8, intent(out), dimension(1:m) :: y\n'
1301
- 'INTEGER*4 :: i\n'
1302
- 'do i = 1, m\n'
1303
- ' y(i) = %s*%s\n'
1304
- 'end do\n'
1305
- 'end subroutine\n'
1306
- )
1307
- args = ('x(i)', '(x(i) + 1)')
1308
- assert code == expected % args or\
1309
- code == expected % args[::-1]
1310
-
1311
-
1312
- def test_f_code_call_signature_wrap():
1313
- # Issue #7934
1314
- x = symbols('x:20')
1315
- expr = 0
1316
- for sym in x:
1317
- expr += sym
1318
- routine = make_routine("test", expr)
1319
- code_gen = FCodeGen()
1320
- source = get_string(code_gen.dump_f95, [routine])
1321
- expected = """\
1322
- REAL*8 function test(x0, x1, x10, x11, x12, x13, x14, x15, x16, x17, x18, &
1323
- x19, x2, x3, x4, x5, x6, x7, x8, x9)
1324
- implicit none
1325
- REAL*8, intent(in) :: x0
1326
- REAL*8, intent(in) :: x1
1327
- REAL*8, intent(in) :: x10
1328
- REAL*8, intent(in) :: x11
1329
- REAL*8, intent(in) :: x12
1330
- REAL*8, intent(in) :: x13
1331
- REAL*8, intent(in) :: x14
1332
- REAL*8, intent(in) :: x15
1333
- REAL*8, intent(in) :: x16
1334
- REAL*8, intent(in) :: x17
1335
- REAL*8, intent(in) :: x18
1336
- REAL*8, intent(in) :: x19
1337
- REAL*8, intent(in) :: x2
1338
- REAL*8, intent(in) :: x3
1339
- REAL*8, intent(in) :: x4
1340
- REAL*8, intent(in) :: x5
1341
- REAL*8, intent(in) :: x6
1342
- REAL*8, intent(in) :: x7
1343
- REAL*8, intent(in) :: x8
1344
- REAL*8, intent(in) :: x9
1345
- test = x0 + x1 + x10 + x11 + x12 + x13 + x14 + x15 + x16 + x17 + x18 + &
1346
- x19 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
1347
- end function
1348
- """
1349
- assert source == expected
1350
-
1351
-
1352
- def test_check_case():
1353
- x, X = symbols('x,X')
1354
- raises(CodeGenError, lambda: codegen(('test', x*X), 'f95', 'prefix'))
1355
-
1356
-
1357
- def test_check_case_false_positive():
1358
- # The upper case/lower case exception should not be triggered by SymPy
1359
- # objects that differ only because of assumptions. (It may be useful to
1360
- # have a check for that as well, but here we only want to test against
1361
- # false positives with respect to case checking.)
1362
- x1 = symbols('x')
1363
- x2 = symbols('x', my_assumption=True)
1364
- try:
1365
- codegen(('test', x1*x2), 'f95', 'prefix')
1366
- except CodeGenError as e:
1367
- if e.args[0].startswith("Fortran ignores case."):
1368
- raise AssertionError("This exception should not be raised!")
1369
-
1370
-
1371
- def test_c_fortran_omit_routine_name():
1372
- x, y = symbols("x,y")
1373
- name_expr = [("foo", 2*x)]
1374
- result = codegen(name_expr, "F95", header=False, empty=False)
1375
- expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
1376
- assert result[0][1] == expresult[0][1]
1377
-
1378
- name_expr = ("foo", x*y)
1379
- result = codegen(name_expr, "F95", header=False, empty=False)
1380
- expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
1381
- assert result[0][1] == expresult[0][1]
1382
-
1383
- name_expr = ("foo", Matrix([[x, y], [x+y, x-y]]))
1384
- result = codegen(name_expr, "C89", header=False, empty=False)
1385
- expresult = codegen(name_expr, "C89", "foo", header=False, empty=False)
1386
- assert result[0][1] == expresult[0][1]
1387
-
1388
-
1389
- def test_fcode_matrix_output():
1390
- x, y, z = symbols('x,y,z')
1391
- e1 = x + y
1392
- e2 = Matrix([[x, y], [z, 16]])
1393
- name_expr = ("test", (e1, e2))
1394
- result = codegen(name_expr, "f95", "test", header=False, empty=False)
1395
- source = result[0][1]
1396
- expected = (
1397
- "REAL*8 function test(x, y, z, out_%(hash)s)\n"
1398
- "implicit none\n"
1399
- "REAL*8, intent(in) :: x\n"
1400
- "REAL*8, intent(in) :: y\n"
1401
- "REAL*8, intent(in) :: z\n"
1402
- "REAL*8, intent(out), dimension(1:2, 1:2) :: out_%(hash)s\n"
1403
- "out_%(hash)s(1, 1) = x\n"
1404
- "out_%(hash)s(2, 1) = z\n"
1405
- "out_%(hash)s(1, 2) = y\n"
1406
- "out_%(hash)s(2, 2) = 16\n"
1407
- "test = x + y\n"
1408
- "end function\n"
1409
- )
1410
- # look for the magic number
1411
- a = source.splitlines()[5]
1412
- b = a.split('_')
1413
- out = b[1]
1414
- expected = expected % {'hash': out}
1415
- assert source == expected
1416
-
1417
-
1418
- def test_fcode_results_named_ordered():
1419
- x, y, z = symbols('x,y,z')
1420
- B, C = symbols('B,C')
1421
- A = MatrixSymbol('A', 1, 3)
1422
- expr1 = Equality(A, Matrix([[1, 2, x]]))
1423
- expr2 = Equality(C, (x + y)*z)
1424
- expr3 = Equality(B, 2*x)
1425
- name_expr = ("test", [expr1, expr2, expr3])
1426
- result = codegen(name_expr, "f95", "test", header=False, empty=False,
1427
- argument_sequence=(x, z, y, C, A, B))
1428
- source = result[0][1]
1429
- expected = (
1430
- "subroutine test(x, z, y, C, A, B)\n"
1431
- "implicit none\n"
1432
- "REAL*8, intent(in) :: x\n"
1433
- "REAL*8, intent(in) :: z\n"
1434
- "REAL*8, intent(in) :: y\n"
1435
- "REAL*8, intent(out) :: C\n"
1436
- "REAL*8, intent(out) :: B\n"
1437
- "REAL*8, intent(out), dimension(1:1, 1:3) :: A\n"
1438
- "C = z*(x + y)\n"
1439
- "A(1, 1) = 1\n"
1440
- "A(1, 2) = 2\n"
1441
- "A(1, 3) = x\n"
1442
- "B = 2*x\n"
1443
- "end subroutine\n"
1444
- )
1445
- assert source == expected
1446
-
1447
-
1448
- def test_fcode_matrixsymbol_slice():
1449
- A = MatrixSymbol('A', 2, 3)
1450
- B = MatrixSymbol('B', 1, 3)
1451
- C = MatrixSymbol('C', 1, 3)
1452
- D = MatrixSymbol('D', 2, 1)
1453
- name_expr = ("test", [Equality(B, A[0, :]),
1454
- Equality(C, A[1, :]),
1455
- Equality(D, A[:, 2])])
1456
- result = codegen(name_expr, "f95", "test", header=False, empty=False)
1457
- source = result[0][1]
1458
- expected = (
1459
- "subroutine test(A, B, C, D)\n"
1460
- "implicit none\n"
1461
- "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
1462
- "REAL*8, intent(out), dimension(1:1, 1:3) :: B\n"
1463
- "REAL*8, intent(out), dimension(1:1, 1:3) :: C\n"
1464
- "REAL*8, intent(out), dimension(1:2, 1:1) :: D\n"
1465
- "B(1, 1) = A(1, 1)\n"
1466
- "B(1, 2) = A(1, 2)\n"
1467
- "B(1, 3) = A(1, 3)\n"
1468
- "C(1, 1) = A(2, 1)\n"
1469
- "C(1, 2) = A(2, 2)\n"
1470
- "C(1, 3) = A(2, 3)\n"
1471
- "D(1, 1) = A(1, 3)\n"
1472
- "D(2, 1) = A(2, 3)\n"
1473
- "end subroutine\n"
1474
- )
1475
- assert source == expected
1476
-
1477
-
1478
- def test_fcode_matrixsymbol_slice_autoname():
1479
- # see issue #8093
1480
- A = MatrixSymbol('A', 2, 3)
1481
- name_expr = ("test", A[:, 1])
1482
- result = codegen(name_expr, "f95", "test", header=False, empty=False)
1483
- source = result[0][1]
1484
- expected = (
1485
- "subroutine test(A, out_%(hash)s)\n"
1486
- "implicit none\n"
1487
- "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
1488
- "REAL*8, intent(out), dimension(1:2, 1:1) :: out_%(hash)s\n"
1489
- "out_%(hash)s(1, 1) = A(1, 2)\n"
1490
- "out_%(hash)s(2, 1) = A(2, 2)\n"
1491
- "end subroutine\n"
1492
- )
1493
- # look for the magic number
1494
- a = source.splitlines()[3]
1495
- b = a.split('_')
1496
- out = b[1]
1497
- expected = expected % {'hash': out}
1498
- assert source == expected
1499
-
1500
-
1501
- def test_global_vars():
1502
- x, y, z, t = symbols("x y z t")
1503
- result = codegen(('f', x*y), "F95", header=False, empty=False,
1504
- global_vars=(y,))
1505
- source = result[0][1]
1506
- expected = (
1507
- "REAL*8 function f(x)\n"
1508
- "implicit none\n"
1509
- "REAL*8, intent(in) :: x\n"
1510
- "f = x*y\n"
1511
- "end function\n"
1512
- )
1513
- assert source == expected
1514
-
1515
- expected = (
1516
- '#include "f.h"\n'
1517
- '#include <math.h>\n'
1518
- 'double f(double x, double y) {\n'
1519
- ' double f_result;\n'
1520
- ' f_result = x*y + z;\n'
1521
- ' return f_result;\n'
1522
- '}\n'
1523
- )
1524
- result = codegen(('f', x*y+z), "C", header=False, empty=False,
1525
- global_vars=(z, t))
1526
- source = result[0][1]
1527
- assert source == expected
1528
-
1529
- def test_custom_codegen():
1530
- from sympy.printing.c import C99CodePrinter
1531
- from sympy.functions.elementary.exponential import exp
1532
-
1533
- printer = C99CodePrinter(settings={'user_functions': {'exp': 'fastexp'}})
1534
-
1535
- x, y = symbols('x y')
1536
- expr = exp(x + y)
1537
-
1538
- # replace math.h with a different header
1539
- gen = C99CodeGen(printer=printer,
1540
- preprocessor_statements=['#include "fastexp.h"'])
1541
-
1542
- expected = (
1543
- '#include "expr.h"\n'
1544
- '#include "fastexp.h"\n'
1545
- 'double expr(double x, double y) {\n'
1546
- ' double expr_result;\n'
1547
- ' expr_result = fastexp(x + y);\n'
1548
- ' return expr_result;\n'
1549
- '}\n'
1550
- )
1551
-
1552
- result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
1553
- source = result[0][1]
1554
- assert source == expected
1555
-
1556
- # use both math.h and an external header
1557
- gen = C99CodeGen(printer=printer)
1558
- gen.preprocessor_statements.append('#include "fastexp.h"')
1559
-
1560
- expected = (
1561
- '#include "expr.h"\n'
1562
- '#include <math.h>\n'
1563
- '#include "fastexp.h"\n'
1564
- 'double expr(double x, double y) {\n'
1565
- ' double expr_result;\n'
1566
- ' expr_result = fastexp(x + y);\n'
1567
- ' return expr_result;\n'
1568
- '}\n'
1569
- )
1570
-
1571
- result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
1572
- source = result[0][1]
1573
- assert source == expected
1574
-
1575
- def test_c_with_printer():
1576
- # issue 13586
1577
- from sympy.printing.c import C99CodePrinter
1578
- class CustomPrinter(C99CodePrinter):
1579
- def _print_Pow(self, expr):
1580
- return "fastpow({}, {})".format(self._print(expr.base),
1581
- self._print(expr.exp))
1582
-
1583
- x = symbols('x')
1584
- expr = x**3
1585
- expected =[
1586
- ("file.c",
1587
- "#include \"file.h\"\n"
1588
- "#include <math.h>\n"
1589
- "double test(double x) {\n"
1590
- " double test_result;\n"
1591
- " test_result = fastpow(x, 3);\n"
1592
- " return test_result;\n"
1593
- "}\n"),
1594
- ("file.h",
1595
- "#ifndef PROJECT__FILE__H\n"
1596
- "#define PROJECT__FILE__H\n"
1597
- "double test(double x);\n"
1598
- "#endif\n")
1599
- ]
1600
- result = codegen(("test", expr), "C","file", header=False, empty=False, printer = CustomPrinter())
1601
- assert result == expected
1602
-
1603
-
1604
- def test_fcode_complex():
1605
- import sympy.utilities.codegen
1606
- sympy.utilities.codegen.COMPLEX_ALLOWED = True
1607
- x = Symbol('x', real=True)
1608
- y = Symbol('y',real=True)
1609
- result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
1610
- source = (result[0][1])
1611
- expected = (
1612
- "REAL*8 function test(x, y)\n"
1613
- "implicit none\n"
1614
- "REAL*8, intent(in) :: x\n"
1615
- "REAL*8, intent(in) :: y\n"
1616
- "test = x + y\n"
1617
- "end function\n")
1618
- assert source == expected
1619
- x = Symbol('x')
1620
- y = Symbol('y',real=True)
1621
- result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
1622
- source = (result[0][1])
1623
- expected = (
1624
- "COMPLEX*16 function test(x, y)\n"
1625
- "implicit none\n"
1626
- "COMPLEX*16, intent(in) :: x\n"
1627
- "REAL*8, intent(in) :: y\n"
1628
- "test = x + y\n"
1629
- "end function\n"
1630
- )
1631
- assert source==expected
1632
- sympy.utilities.codegen.COMPLEX_ALLOWED = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_julia.py DELETED
@@ -1,620 +0,0 @@
1
- from io import StringIO
2
-
3
- from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function
4
- from sympy.core.relational import Equality
5
- from sympy.functions.elementary.piecewise import Piecewise
6
- from sympy.matrices import Matrix, MatrixSymbol
7
- from sympy.utilities.codegen import JuliaCodeGen, codegen, make_routine
8
- from sympy.testing.pytest import XFAIL
9
- import sympy
10
-
11
-
12
- x, y, z = symbols('x,y,z')
13
-
14
-
15
- def test_empty_jl_code():
16
- code_gen = JuliaCodeGen()
17
- output = StringIO()
18
- code_gen.dump_jl([], output, "file", header=False, empty=False)
19
- source = output.getvalue()
20
- assert source == ""
21
-
22
-
23
- def test_jl_simple_code():
24
- name_expr = ("test", (x + y)*z)
25
- result, = codegen(name_expr, "Julia", header=False, empty=False)
26
- assert result[0] == "test.jl"
27
- source = result[1]
28
- expected = (
29
- "function test(x, y, z)\n"
30
- " out1 = z .* (x + y)\n"
31
- " return out1\n"
32
- "end\n"
33
- )
34
- assert source == expected
35
-
36
-
37
- def test_jl_simple_code_with_header():
38
- name_expr = ("test", (x + y)*z)
39
- result, = codegen(name_expr, "Julia", header=True, empty=False)
40
- assert result[0] == "test.jl"
41
- source = result[1]
42
- expected = (
43
- "# Code generated with SymPy " + sympy.__version__ + "\n"
44
- "#\n"
45
- "# See http://www.sympy.org/ for more information.\n"
46
- "#\n"
47
- "# This file is part of 'project'\n"
48
- "function test(x, y, z)\n"
49
- " out1 = z .* (x + y)\n"
50
- " return out1\n"
51
- "end\n"
52
- )
53
- assert source == expected
54
-
55
-
56
- def test_jl_simple_code_nameout():
57
- expr = Equality(z, (x + y))
58
- name_expr = ("test", expr)
59
- result, = codegen(name_expr, "Julia", header=False, empty=False)
60
- source = result[1]
61
- expected = (
62
- "function test(x, y)\n"
63
- " z = x + y\n"
64
- " return z\n"
65
- "end\n"
66
- )
67
- assert source == expected
68
-
69
-
70
- def test_jl_numbersymbol():
71
- name_expr = ("test", pi**Catalan)
72
- result, = codegen(name_expr, "Julia", header=False, empty=False)
73
- source = result[1]
74
- expected = (
75
- "function test()\n"
76
- " out1 = pi ^ catalan\n"
77
- " return out1\n"
78
- "end\n"
79
- )
80
- assert source == expected
81
-
82
-
83
- @XFAIL
84
- def test_jl_numbersymbol_no_inline():
85
- # FIXME: how to pass inline=False to the JuliaCodePrinter?
86
- name_expr = ("test", [pi**Catalan, EulerGamma])
87
- result, = codegen(name_expr, "Julia", header=False,
88
- empty=False, inline=False)
89
- source = result[1]
90
- expected = (
91
- "function test()\n"
92
- " Catalan = 0.915965594177219\n"
93
- " EulerGamma = 0.5772156649015329\n"
94
- " out1 = pi ^ Catalan\n"
95
- " out2 = EulerGamma\n"
96
- " return out1, out2\n"
97
- "end\n"
98
- )
99
- assert source == expected
100
-
101
-
102
- def test_jl_code_argument_order():
103
- expr = x + y
104
- routine = make_routine("test", expr, argument_sequence=[z, x, y], language="julia")
105
- code_gen = JuliaCodeGen()
106
- output = StringIO()
107
- code_gen.dump_jl([routine], output, "test", header=False, empty=False)
108
- source = output.getvalue()
109
- expected = (
110
- "function test(z, x, y)\n"
111
- " out1 = x + y\n"
112
- " return out1\n"
113
- "end\n"
114
- )
115
- assert source == expected
116
-
117
-
118
- def test_multiple_results_m():
119
- # Here the output order is the input order
120
- expr1 = (x + y)*z
121
- expr2 = (x - y)*z
122
- name_expr = ("test", [expr1, expr2])
123
- result, = codegen(name_expr, "Julia", header=False, empty=False)
124
- source = result[1]
125
- expected = (
126
- "function test(x, y, z)\n"
127
- " out1 = z .* (x + y)\n"
128
- " out2 = z .* (x - y)\n"
129
- " return out1, out2\n"
130
- "end\n"
131
- )
132
- assert source == expected
133
-
134
-
135
- def test_results_named_unordered():
136
- # Here output order is based on name_expr
137
- A, B, C = symbols('A,B,C')
138
- expr1 = Equality(C, (x + y)*z)
139
- expr2 = Equality(A, (x - y)*z)
140
- expr3 = Equality(B, 2*x)
141
- name_expr = ("test", [expr1, expr2, expr3])
142
- result, = codegen(name_expr, "Julia", header=False, empty=False)
143
- source = result[1]
144
- expected = (
145
- "function test(x, y, z)\n"
146
- " C = z .* (x + y)\n"
147
- " A = z .* (x - y)\n"
148
- " B = 2 * x\n"
149
- " return C, A, B\n"
150
- "end\n"
151
- )
152
- assert source == expected
153
-
154
-
155
- def test_results_named_ordered():
156
- A, B, C = symbols('A,B,C')
157
- expr1 = Equality(C, (x + y)*z)
158
- expr2 = Equality(A, (x - y)*z)
159
- expr3 = Equality(B, 2*x)
160
- name_expr = ("test", [expr1, expr2, expr3])
161
- result = codegen(name_expr, "Julia", header=False, empty=False,
162
- argument_sequence=(x, z, y))
163
- assert result[0][0] == "test.jl"
164
- source = result[0][1]
165
- expected = (
166
- "function test(x, z, y)\n"
167
- " C = z .* (x + y)\n"
168
- " A = z .* (x - y)\n"
169
- " B = 2 * x\n"
170
- " return C, A, B\n"
171
- "end\n"
172
- )
173
- assert source == expected
174
-
175
-
176
- def test_complicated_jl_codegen():
177
- from sympy.functions.elementary.trigonometric import (cos, sin, tan)
178
- name_expr = ("testlong",
179
- [ ((sin(x) + cos(y) + tan(z))**3).expand(),
180
- cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
181
- ])
182
- result = codegen(name_expr, "Julia", header=False, empty=False)
183
- assert result[0][0] == "testlong.jl"
184
- source = result[0][1]
185
- expected = (
186
- "function testlong(x, y, z)\n"
187
- " out1 = sin(x) .^ 3 + 3 * sin(x) .^ 2 .* cos(y) + 3 * sin(x) .^ 2 .* tan(z)"
188
- " + 3 * sin(x) .* cos(y) .^ 2 + 6 * sin(x) .* cos(y) .* tan(z) + 3 * sin(x) .* tan(z) .^ 2"
189
- " + cos(y) .^ 3 + 3 * cos(y) .^ 2 .* tan(z) + 3 * cos(y) .* tan(z) .^ 2 + tan(z) .^ 3\n"
190
- " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n"
191
- " return out1, out2\n"
192
- "end\n"
193
- )
194
- assert source == expected
195
-
196
-
197
- def test_jl_output_arg_mixed_unordered():
198
- # named outputs are alphabetical, unnamed output appear in the given order
199
- from sympy.functions.elementary.trigonometric import (cos, sin)
200
- a = symbols("a")
201
- name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
202
- result, = codegen(name_expr, "Julia", header=False, empty=False)
203
- assert result[0] == "foo.jl"
204
- source = result[1]
205
- expected = (
206
- 'function foo(x)\n'
207
- ' out1 = cos(2 * x)\n'
208
- ' y = sin(x)\n'
209
- ' out3 = cos(x)\n'
210
- ' a = sin(2 * x)\n'
211
- ' return out1, y, out3, a\n'
212
- 'end\n'
213
- )
214
- assert source == expected
215
-
216
-
217
- def test_jl_piecewise_():
218
- pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
219
- name_expr = ("pwtest", pw)
220
- result, = codegen(name_expr, "Julia", header=False, empty=False)
221
- source = result[1]
222
- expected = (
223
- "function pwtest(x)\n"
224
- " out1 = ((x < -1) ? (0) :\n"
225
- " (x <= 1) ? (x .^ 2) :\n"
226
- " (x > 1) ? (2 - x) : (1))\n"
227
- " return out1\n"
228
- "end\n"
229
- )
230
- assert source == expected
231
-
232
-
233
- @XFAIL
234
- def test_jl_piecewise_no_inline():
235
- # FIXME: how to pass inline=False to the JuliaCodePrinter?
236
- pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
237
- name_expr = ("pwtest", pw)
238
- result, = codegen(name_expr, "Julia", header=False, empty=False,
239
- inline=False)
240
- source = result[1]
241
- expected = (
242
- "function pwtest(x)\n"
243
- " if (x < -1)\n"
244
- " out1 = 0\n"
245
- " elseif (x <= 1)\n"
246
- " out1 = x .^ 2\n"
247
- " elseif (x > 1)\n"
248
- " out1 = -x + 2\n"
249
- " else\n"
250
- " out1 = 1\n"
251
- " end\n"
252
- " return out1\n"
253
- "end\n"
254
- )
255
- assert source == expected
256
-
257
-
258
- def test_jl_multifcns_per_file():
259
- name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
260
- result = codegen(name_expr, "Julia", header=False, empty=False)
261
- assert result[0][0] == "foo.jl"
262
- source = result[0][1]
263
- expected = (
264
- "function foo(x, y)\n"
265
- " out1 = 2 * x\n"
266
- " out2 = 3 * y\n"
267
- " return out1, out2\n"
268
- "end\n"
269
- "function bar(y)\n"
270
- " out1 = y .^ 2\n"
271
- " out2 = 4 * y\n"
272
- " return out1, out2\n"
273
- "end\n"
274
- )
275
- assert source == expected
276
-
277
-
278
- def test_jl_multifcns_per_file_w_header():
279
- name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
280
- result = codegen(name_expr, "Julia", header=True, empty=False)
281
- assert result[0][0] == "foo.jl"
282
- source = result[0][1]
283
- expected = (
284
- "# Code generated with SymPy " + sympy.__version__ + "\n"
285
- "#\n"
286
- "# See http://www.sympy.org/ for more information.\n"
287
- "#\n"
288
- "# This file is part of 'project'\n"
289
- "function foo(x, y)\n"
290
- " out1 = 2 * x\n"
291
- " out2 = 3 * y\n"
292
- " return out1, out2\n"
293
- "end\n"
294
- "function bar(y)\n"
295
- " out1 = y .^ 2\n"
296
- " out2 = 4 * y\n"
297
- " return out1, out2\n"
298
- "end\n"
299
- )
300
- assert source == expected
301
-
302
-
303
- def test_jl_filename_match_prefix():
304
- name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
305
- result, = codegen(name_expr, "Julia", prefix="baz", header=False,
306
- empty=False)
307
- assert result[0] == "baz.jl"
308
-
309
-
310
- def test_jl_matrix_named():
311
- e2 = Matrix([[x, 2*y, pi*z]])
312
- name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2))
313
- result = codegen(name_expr, "Julia", header=False, empty=False)
314
- assert result[0][0] == "test.jl"
315
- source = result[0][1]
316
- expected = (
317
- "function test(x, y, z)\n"
318
- " myout1 = [x 2 * y pi * z]\n"
319
- " return myout1\n"
320
- "end\n"
321
- )
322
- assert source == expected
323
-
324
-
325
- def test_jl_matrix_named_matsym():
326
- myout1 = MatrixSymbol('myout1', 1, 3)
327
- e2 = Matrix([[x, 2*y, pi*z]])
328
- name_expr = ("test", Equality(myout1, e2, evaluate=False))
329
- result, = codegen(name_expr, "Julia", header=False, empty=False)
330
- source = result[1]
331
- expected = (
332
- "function test(x, y, z)\n"
333
- " myout1 = [x 2 * y pi * z]\n"
334
- " return myout1\n"
335
- "end\n"
336
- )
337
- assert source == expected
338
-
339
-
340
- def test_jl_matrix_output_autoname():
341
- expr = Matrix([[x, x+y, 3]])
342
- name_expr = ("test", expr)
343
- result, = codegen(name_expr, "Julia", header=False, empty=False)
344
- source = result[1]
345
- expected = (
346
- "function test(x, y)\n"
347
- " out1 = [x x + y 3]\n"
348
- " return out1\n"
349
- "end\n"
350
- )
351
- assert source == expected
352
-
353
-
354
- def test_jl_matrix_output_autoname_2():
355
- e1 = (x + y)
356
- e2 = Matrix([[2*x, 2*y, 2*z]])
357
- e3 = Matrix([[x], [y], [z]])
358
- e4 = Matrix([[x, y], [z, 16]])
359
- name_expr = ("test", (e1, e2, e3, e4))
360
- result, = codegen(name_expr, "Julia", header=False, empty=False)
361
- source = result[1]
362
- expected = (
363
- "function test(x, y, z)\n"
364
- " out1 = x + y\n"
365
- " out2 = [2 * x 2 * y 2 * z]\n"
366
- " out3 = [x, y, z]\n"
367
- " out4 = [x y;\n"
368
- " z 16]\n"
369
- " return out1, out2, out3, out4\n"
370
- "end\n"
371
- )
372
- assert source == expected
373
-
374
-
375
- def test_jl_results_matrix_named_ordered():
376
- B, C = symbols('B,C')
377
- A = MatrixSymbol('A', 1, 3)
378
- expr1 = Equality(C, (x + y)*z)
379
- expr2 = Equality(A, Matrix([[1, 2, x]]))
380
- expr3 = Equality(B, 2*x)
381
- name_expr = ("test", [expr1, expr2, expr3])
382
- result, = codegen(name_expr, "Julia", header=False, empty=False,
383
- argument_sequence=(x, z, y))
384
- source = result[1]
385
- expected = (
386
- "function test(x, z, y)\n"
387
- " C = z .* (x + y)\n"
388
- " A = [1 2 x]\n"
389
- " B = 2 * x\n"
390
- " return C, A, B\n"
391
- "end\n"
392
- )
393
- assert source == expected
394
-
395
-
396
- def test_jl_matrixsymbol_slice():
397
- A = MatrixSymbol('A', 2, 3)
398
- B = MatrixSymbol('B', 1, 3)
399
- C = MatrixSymbol('C', 1, 3)
400
- D = MatrixSymbol('D', 2, 1)
401
- name_expr = ("test", [Equality(B, A[0, :]),
402
- Equality(C, A[1, :]),
403
- Equality(D, A[:, 2])])
404
- result, = codegen(name_expr, "Julia", header=False, empty=False)
405
- source = result[1]
406
- expected = (
407
- "function test(A)\n"
408
- " B = A[1,:]\n"
409
- " C = A[2,:]\n"
410
- " D = A[:,3]\n"
411
- " return B, C, D\n"
412
- "end\n"
413
- )
414
- assert source == expected
415
-
416
-
417
- def test_jl_matrixsymbol_slice2():
418
- A = MatrixSymbol('A', 3, 4)
419
- B = MatrixSymbol('B', 2, 2)
420
- C = MatrixSymbol('C', 2, 2)
421
- name_expr = ("test", [Equality(B, A[0:2, 0:2]),
422
- Equality(C, A[0:2, 1:3])])
423
- result, = codegen(name_expr, "Julia", header=False, empty=False)
424
- source = result[1]
425
- expected = (
426
- "function test(A)\n"
427
- " B = A[1:2,1:2]\n"
428
- " C = A[1:2,2:3]\n"
429
- " return B, C\n"
430
- "end\n"
431
- )
432
- assert source == expected
433
-
434
-
435
- def test_jl_matrixsymbol_slice3():
436
- A = MatrixSymbol('A', 8, 7)
437
- B = MatrixSymbol('B', 2, 2)
438
- C = MatrixSymbol('C', 4, 2)
439
- name_expr = ("test", [Equality(B, A[6:, 1::3]),
440
- Equality(C, A[::2, ::3])])
441
- result, = codegen(name_expr, "Julia", header=False, empty=False)
442
- source = result[1]
443
- expected = (
444
- "function test(A)\n"
445
- " B = A[7:end,2:3:end]\n"
446
- " C = A[1:2:end,1:3:end]\n"
447
- " return B, C\n"
448
- "end\n"
449
- )
450
- assert source == expected
451
-
452
-
453
- def test_jl_matrixsymbol_slice_autoname():
454
- A = MatrixSymbol('A', 2, 3)
455
- B = MatrixSymbol('B', 1, 3)
456
- name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]])
457
- result, = codegen(name_expr, "Julia", header=False, empty=False)
458
- source = result[1]
459
- expected = (
460
- "function test(A)\n"
461
- " B = A[1,:]\n"
462
- " out2 = A[2,:]\n"
463
- " out3 = A[:,1]\n"
464
- " out4 = A[:,2]\n"
465
- " return B, out2, out3, out4\n"
466
- "end\n"
467
- )
468
- assert source == expected
469
-
470
-
471
- def test_jl_loops():
472
- # Note: an Julia programmer would probably vectorize this across one or
473
- # more dimensions. Also, size(A) would be used rather than passing in m
474
- # and n. Perhaps users would expect us to vectorize automatically here?
475
- # Or is it possible to represent such things using IndexedBase?
476
- from sympy.tensor import IndexedBase, Idx
477
- from sympy.core.symbol import symbols
478
- n, m = symbols('n m', integer=True)
479
- A = IndexedBase('A')
480
- x = IndexedBase('x')
481
- y = IndexedBase('y')
482
- i = Idx('i', m)
483
- j = Idx('j', n)
484
- result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Julia",
485
- header=False, empty=False)
486
- source = result[1]
487
- expected = (
488
- 'function mat_vec_mult(y, A, m, n, x)\n'
489
- ' for i = 1:m\n'
490
- ' y[i] = 0\n'
491
- ' end\n'
492
- ' for i = 1:m\n'
493
- ' for j = 1:n\n'
494
- ' y[i] = %(rhs)s + y[i]\n'
495
- ' end\n'
496
- ' end\n'
497
- ' return y\n'
498
- 'end\n'
499
- )
500
- assert (source == expected % {'rhs': 'A[%s,%s] .* x[j]' % (i, j)} or
501
- source == expected % {'rhs': 'x[j] .* A[%s,%s]' % (i, j)})
502
-
503
-
504
- def test_jl_tensor_loops_multiple_contractions():
505
- # see comments in previous test about vectorizing
506
- from sympy.tensor import IndexedBase, Idx
507
- from sympy.core.symbol import symbols
508
- n, m, o, p = symbols('n m o p', integer=True)
509
- A = IndexedBase('A')
510
- B = IndexedBase('B')
511
- y = IndexedBase('y')
512
- i = Idx('i', m)
513
- j = Idx('j', n)
514
- k = Idx('k', o)
515
- l = Idx('l', p)
516
- result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])),
517
- "Julia", header=False, empty=False)
518
- source = result[1]
519
- expected = (
520
- 'function tensorthing(y, A, B, m, n, o, p)\n'
521
- ' for i = 1:m\n'
522
- ' y[i] = 0\n'
523
- ' end\n'
524
- ' for i = 1:m\n'
525
- ' for j = 1:n\n'
526
- ' for k = 1:o\n'
527
- ' for l = 1:p\n'
528
- ' y[i] = A[i,j,k,l] .* B[j,k,l] + y[i]\n'
529
- ' end\n'
530
- ' end\n'
531
- ' end\n'
532
- ' end\n'
533
- ' return y\n'
534
- 'end\n'
535
- )
536
- assert source == expected
537
-
538
-
539
- def test_jl_InOutArgument():
540
- expr = Equality(x, x**2)
541
- name_expr = ("mysqr", expr)
542
- result, = codegen(name_expr, "Julia", header=False, empty=False)
543
- source = result[1]
544
- expected = (
545
- "function mysqr(x)\n"
546
- " x = x .^ 2\n"
547
- " return x\n"
548
- "end\n"
549
- )
550
- assert source == expected
551
-
552
-
553
- def test_jl_InOutArgument_order():
554
- # can specify the order as (x, y)
555
- expr = Equality(x, x**2 + y)
556
- name_expr = ("test", expr)
557
- result, = codegen(name_expr, "Julia", header=False,
558
- empty=False, argument_sequence=(x,y))
559
- source = result[1]
560
- expected = (
561
- "function test(x, y)\n"
562
- " x = x .^ 2 + y\n"
563
- " return x\n"
564
- "end\n"
565
- )
566
- assert source == expected
567
- # make sure it gives (x, y) not (y, x)
568
- expr = Equality(x, x**2 + y)
569
- name_expr = ("test", expr)
570
- result, = codegen(name_expr, "Julia", header=False, empty=False)
571
- source = result[1]
572
- expected = (
573
- "function test(x, y)\n"
574
- " x = x .^ 2 + y\n"
575
- " return x\n"
576
- "end\n"
577
- )
578
- assert source == expected
579
-
580
-
581
- def test_jl_not_supported():
582
- f = Function('f')
583
- name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
584
- result, = codegen(name_expr, "Julia", header=False, empty=False)
585
- source = result[1]
586
- expected = (
587
- "function test(x)\n"
588
- " # unsupported: Derivative(f(x), x)\n"
589
- " # unsupported: zoo\n"
590
- " out1 = Derivative(f(x), x)\n"
591
- " out2 = zoo\n"
592
- " return out1, out2\n"
593
- "end\n"
594
- )
595
- assert source == expected
596
-
597
-
598
- def test_global_vars_octave():
599
- x, y, z, t = symbols("x y z t")
600
- result = codegen(('f', x*y), "Julia", header=False, empty=False,
601
- global_vars=(y,))
602
- source = result[0][1]
603
- expected = (
604
- "function f(x)\n"
605
- " out1 = x .* y\n"
606
- " return out1\n"
607
- "end\n"
608
- )
609
- assert source == expected
610
-
611
- result = codegen(('f', x*y+z), "Julia", header=False, empty=False,
612
- argument_sequence=(x, y), global_vars=(z, t))
613
- source = result[0][1]
614
- expected = (
615
- "function f(x, y)\n"
616
- " out1 = x .* y + z\n"
617
- " return out1\n"
618
- "end\n"
619
- )
620
- assert source == expected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_octave.py DELETED
@@ -1,589 +0,0 @@
1
- from io import StringIO
2
-
3
- from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function
4
- from sympy.core.relational import Equality
5
- from sympy.functions.elementary.piecewise import Piecewise
6
- from sympy.matrices import Matrix, MatrixSymbol
7
- from sympy.utilities.codegen import OctaveCodeGen, codegen, make_routine
8
- from sympy.testing.pytest import raises
9
- from sympy.testing.pytest import XFAIL
10
- import sympy
11
-
12
-
13
- x, y, z = symbols('x,y,z')
14
-
15
-
16
- def test_empty_m_code():
17
- code_gen = OctaveCodeGen()
18
- output = StringIO()
19
- code_gen.dump_m([], output, "file", header=False, empty=False)
20
- source = output.getvalue()
21
- assert source == ""
22
-
23
-
24
- def test_m_simple_code():
25
- name_expr = ("test", (x + y)*z)
26
- result, = codegen(name_expr, "Octave", header=False, empty=False)
27
- assert result[0] == "test.m"
28
- source = result[1]
29
- expected = (
30
- "function out1 = test(x, y, z)\n"
31
- " out1 = z.*(x + y);\n"
32
- "end\n"
33
- )
34
- assert source == expected
35
-
36
-
37
- def test_m_simple_code_with_header():
38
- name_expr = ("test", (x + y)*z)
39
- result, = codegen(name_expr, "Octave", header=True, empty=False)
40
- assert result[0] == "test.m"
41
- source = result[1]
42
- expected = (
43
- "function out1 = test(x, y, z)\n"
44
- " %TEST Autogenerated by SymPy\n"
45
- " % Code generated with SymPy " + sympy.__version__ + "\n"
46
- " %\n"
47
- " % See http://www.sympy.org/ for more information.\n"
48
- " %\n"
49
- " % This file is part of 'project'\n"
50
- " out1 = z.*(x + y);\n"
51
- "end\n"
52
- )
53
- assert source == expected
54
-
55
-
56
- def test_m_simple_code_nameout():
57
- expr = Equality(z, (x + y))
58
- name_expr = ("test", expr)
59
- result, = codegen(name_expr, "Octave", header=False, empty=False)
60
- source = result[1]
61
- expected = (
62
- "function z = test(x, y)\n"
63
- " z = x + y;\n"
64
- "end\n"
65
- )
66
- assert source == expected
67
-
68
-
69
- def test_m_numbersymbol():
70
- name_expr = ("test", pi**Catalan)
71
- result, = codegen(name_expr, "Octave", header=False, empty=False)
72
- source = result[1]
73
- expected = (
74
- "function out1 = test()\n"
75
- " out1 = pi^%s;\n"
76
- "end\n"
77
- ) % Catalan.evalf(17)
78
- assert source == expected
79
-
80
-
81
- @XFAIL
82
- def test_m_numbersymbol_no_inline():
83
- # FIXME: how to pass inline=False to the OctaveCodePrinter?
84
- name_expr = ("test", [pi**Catalan, EulerGamma])
85
- result, = codegen(name_expr, "Octave", header=False,
86
- empty=False, inline=False)
87
- source = result[1]
88
- expected = (
89
- "function [out1, out2] = test()\n"
90
- " Catalan = 0.915965594177219; % constant\n"
91
- " EulerGamma = 0.5772156649015329; % constant\n"
92
- " out1 = pi^Catalan;\n"
93
- " out2 = EulerGamma;\n"
94
- "end\n"
95
- )
96
- assert source == expected
97
-
98
-
99
- def test_m_code_argument_order():
100
- expr = x + y
101
- routine = make_routine("test", expr, argument_sequence=[z, x, y], language="octave")
102
- code_gen = OctaveCodeGen()
103
- output = StringIO()
104
- code_gen.dump_m([routine], output, "test", header=False, empty=False)
105
- source = output.getvalue()
106
- expected = (
107
- "function out1 = test(z, x, y)\n"
108
- " out1 = x + y;\n"
109
- "end\n"
110
- )
111
- assert source == expected
112
-
113
-
114
- def test_multiple_results_m():
115
- # Here the output order is the input order
116
- expr1 = (x + y)*z
117
- expr2 = (x - y)*z
118
- name_expr = ("test", [expr1, expr2])
119
- result, = codegen(name_expr, "Octave", header=False, empty=False)
120
- source = result[1]
121
- expected = (
122
- "function [out1, out2] = test(x, y, z)\n"
123
- " out1 = z.*(x + y);\n"
124
- " out2 = z.*(x - y);\n"
125
- "end\n"
126
- )
127
- assert source == expected
128
-
129
-
130
- def test_results_named_unordered():
131
- # Here output order is based on name_expr
132
- A, B, C = symbols('A,B,C')
133
- expr1 = Equality(C, (x + y)*z)
134
- expr2 = Equality(A, (x - y)*z)
135
- expr3 = Equality(B, 2*x)
136
- name_expr = ("test", [expr1, expr2, expr3])
137
- result, = codegen(name_expr, "Octave", header=False, empty=False)
138
- source = result[1]
139
- expected = (
140
- "function [C, A, B] = test(x, y, z)\n"
141
- " C = z.*(x + y);\n"
142
- " A = z.*(x - y);\n"
143
- " B = 2*x;\n"
144
- "end\n"
145
- )
146
- assert source == expected
147
-
148
-
149
- def test_results_named_ordered():
150
- A, B, C = symbols('A,B,C')
151
- expr1 = Equality(C, (x + y)*z)
152
- expr2 = Equality(A, (x - y)*z)
153
- expr3 = Equality(B, 2*x)
154
- name_expr = ("test", [expr1, expr2, expr3])
155
- result = codegen(name_expr, "Octave", header=False, empty=False,
156
- argument_sequence=(x, z, y))
157
- assert result[0][0] == "test.m"
158
- source = result[0][1]
159
- expected = (
160
- "function [C, A, B] = test(x, z, y)\n"
161
- " C = z.*(x + y);\n"
162
- " A = z.*(x - y);\n"
163
- " B = 2*x;\n"
164
- "end\n"
165
- )
166
- assert source == expected
167
-
168
-
169
- def test_complicated_m_codegen():
170
- from sympy.functions.elementary.trigonometric import (cos, sin, tan)
171
- name_expr = ("testlong",
172
- [ ((sin(x) + cos(y) + tan(z))**3).expand(),
173
- cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
174
- ])
175
- result = codegen(name_expr, "Octave", header=False, empty=False)
176
- assert result[0][0] == "testlong.m"
177
- source = result[0][1]
178
- expected = (
179
- "function [out1, out2] = testlong(x, y, z)\n"
180
- " out1 = sin(x).^3 + 3*sin(x).^2.*cos(y) + 3*sin(x).^2.*tan(z)"
181
- " + 3*sin(x).*cos(y).^2 + 6*sin(x).*cos(y).*tan(z) + 3*sin(x).*tan(z).^2"
182
- " + cos(y).^3 + 3*cos(y).^2.*tan(z) + 3*cos(y).*tan(z).^2 + tan(z).^3;\n"
183
- " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n"
184
- "end\n"
185
- )
186
- assert source == expected
187
-
188
-
189
- def test_m_output_arg_mixed_unordered():
190
- # named outputs are alphabetical, unnamed output appear in the given order
191
- from sympy.functions.elementary.trigonometric import (cos, sin)
192
- a = symbols("a")
193
- name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
194
- result, = codegen(name_expr, "Octave", header=False, empty=False)
195
- assert result[0] == "foo.m"
196
- source = result[1]
197
- expected = (
198
- 'function [out1, y, out3, a] = foo(x)\n'
199
- ' out1 = cos(2*x);\n'
200
- ' y = sin(x);\n'
201
- ' out3 = cos(x);\n'
202
- ' a = sin(2*x);\n'
203
- 'end\n'
204
- )
205
- assert source == expected
206
-
207
-
208
- def test_m_piecewise_():
209
- pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
210
- name_expr = ("pwtest", pw)
211
- result, = codegen(name_expr, "Octave", header=False, empty=False)
212
- source = result[1]
213
- expected = (
214
- "function out1 = pwtest(x)\n"
215
- " out1 = ((x < -1).*(0) + (~(x < -1)).*( ...\n"
216
- " (x <= 1).*(x.^2) + (~(x <= 1)).*( ...\n"
217
- " (x > 1).*(2 - x) + (~(x > 1)).*(1))));\n"
218
- "end\n"
219
- )
220
- assert source == expected
221
-
222
-
223
- @XFAIL
224
- def test_m_piecewise_no_inline():
225
- # FIXME: how to pass inline=False to the OctaveCodePrinter?
226
- pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
227
- name_expr = ("pwtest", pw)
228
- result, = codegen(name_expr, "Octave", header=False, empty=False,
229
- inline=False)
230
- source = result[1]
231
- expected = (
232
- "function out1 = pwtest(x)\n"
233
- " if (x < -1)\n"
234
- " out1 = 0;\n"
235
- " elseif (x <= 1)\n"
236
- " out1 = x.^2;\n"
237
- " elseif (x > 1)\n"
238
- " out1 = -x + 2;\n"
239
- " else\n"
240
- " out1 = 1;\n"
241
- " end\n"
242
- "end\n"
243
- )
244
- assert source == expected
245
-
246
-
247
- def test_m_multifcns_per_file():
248
- name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
249
- result = codegen(name_expr, "Octave", header=False, empty=False)
250
- assert result[0][0] == "foo.m"
251
- source = result[0][1]
252
- expected = (
253
- "function [out1, out2] = foo(x, y)\n"
254
- " out1 = 2*x;\n"
255
- " out2 = 3*y;\n"
256
- "end\n"
257
- "function [out1, out2] = bar(y)\n"
258
- " out1 = y.^2;\n"
259
- " out2 = 4*y;\n"
260
- "end\n"
261
- )
262
- assert source == expected
263
-
264
-
265
- def test_m_multifcns_per_file_w_header():
266
- name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
267
- result = codegen(name_expr, "Octave", header=True, empty=False)
268
- assert result[0][0] == "foo.m"
269
- source = result[0][1]
270
- expected = (
271
- "function [out1, out2] = foo(x, y)\n"
272
- " %FOO Autogenerated by SymPy\n"
273
- " % Code generated with SymPy " + sympy.__version__ + "\n"
274
- " %\n"
275
- " % See http://www.sympy.org/ for more information.\n"
276
- " %\n"
277
- " % This file is part of 'project'\n"
278
- " out1 = 2*x;\n"
279
- " out2 = 3*y;\n"
280
- "end\n"
281
- "function [out1, out2] = bar(y)\n"
282
- " out1 = y.^2;\n"
283
- " out2 = 4*y;\n"
284
- "end\n"
285
- )
286
- assert source == expected
287
-
288
-
289
- def test_m_filename_match_first_fcn():
290
- name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
291
- raises(ValueError, lambda: codegen(name_expr,
292
- "Octave", prefix="bar", header=False, empty=False))
293
-
294
-
295
- def test_m_matrix_named():
296
- e2 = Matrix([[x, 2*y, pi*z]])
297
- name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2))
298
- result = codegen(name_expr, "Octave", header=False, empty=False)
299
- assert result[0][0] == "test.m"
300
- source = result[0][1]
301
- expected = (
302
- "function myout1 = test(x, y, z)\n"
303
- " myout1 = [x 2*y pi*z];\n"
304
- "end\n"
305
- )
306
- assert source == expected
307
-
308
-
309
- def test_m_matrix_named_matsym():
310
- myout1 = MatrixSymbol('myout1', 1, 3)
311
- e2 = Matrix([[x, 2*y, pi*z]])
312
- name_expr = ("test", Equality(myout1, e2, evaluate=False))
313
- result, = codegen(name_expr, "Octave", header=False, empty=False)
314
- source = result[1]
315
- expected = (
316
- "function myout1 = test(x, y, z)\n"
317
- " myout1 = [x 2*y pi*z];\n"
318
- "end\n"
319
- )
320
- assert source == expected
321
-
322
-
323
- def test_m_matrix_output_autoname():
324
- expr = Matrix([[x, x+y, 3]])
325
- name_expr = ("test", expr)
326
- result, = codegen(name_expr, "Octave", header=False, empty=False)
327
- source = result[1]
328
- expected = (
329
- "function out1 = test(x, y)\n"
330
- " out1 = [x x + y 3];\n"
331
- "end\n"
332
- )
333
- assert source == expected
334
-
335
-
336
- def test_m_matrix_output_autoname_2():
337
- e1 = (x + y)
338
- e2 = Matrix([[2*x, 2*y, 2*z]])
339
- e3 = Matrix([[x], [y], [z]])
340
- e4 = Matrix([[x, y], [z, 16]])
341
- name_expr = ("test", (e1, e2, e3, e4))
342
- result, = codegen(name_expr, "Octave", header=False, empty=False)
343
- source = result[1]
344
- expected = (
345
- "function [out1, out2, out3, out4] = test(x, y, z)\n"
346
- " out1 = x + y;\n"
347
- " out2 = [2*x 2*y 2*z];\n"
348
- " out3 = [x; y; z];\n"
349
- " out4 = [x y; z 16];\n"
350
- "end\n"
351
- )
352
- assert source == expected
353
-
354
-
355
- def test_m_results_matrix_named_ordered():
356
- B, C = symbols('B,C')
357
- A = MatrixSymbol('A', 1, 3)
358
- expr1 = Equality(C, (x + y)*z)
359
- expr2 = Equality(A, Matrix([[1, 2, x]]))
360
- expr3 = Equality(B, 2*x)
361
- name_expr = ("test", [expr1, expr2, expr3])
362
- result, = codegen(name_expr, "Octave", header=False, empty=False,
363
- argument_sequence=(x, z, y))
364
- source = result[1]
365
- expected = (
366
- "function [C, A, B] = test(x, z, y)\n"
367
- " C = z.*(x + y);\n"
368
- " A = [1 2 x];\n"
369
- " B = 2*x;\n"
370
- "end\n"
371
- )
372
- assert source == expected
373
-
374
-
375
- def test_m_matrixsymbol_slice():
376
- A = MatrixSymbol('A', 2, 3)
377
- B = MatrixSymbol('B', 1, 3)
378
- C = MatrixSymbol('C', 1, 3)
379
- D = MatrixSymbol('D', 2, 1)
380
- name_expr = ("test", [Equality(B, A[0, :]),
381
- Equality(C, A[1, :]),
382
- Equality(D, A[:, 2])])
383
- result, = codegen(name_expr, "Octave", header=False, empty=False)
384
- source = result[1]
385
- expected = (
386
- "function [B, C, D] = test(A)\n"
387
- " B = A(1, :);\n"
388
- " C = A(2, :);\n"
389
- " D = A(:, 3);\n"
390
- "end\n"
391
- )
392
- assert source == expected
393
-
394
-
395
- def test_m_matrixsymbol_slice2():
396
- A = MatrixSymbol('A', 3, 4)
397
- B = MatrixSymbol('B', 2, 2)
398
- C = MatrixSymbol('C', 2, 2)
399
- name_expr = ("test", [Equality(B, A[0:2, 0:2]),
400
- Equality(C, A[0:2, 1:3])])
401
- result, = codegen(name_expr, "Octave", header=False, empty=False)
402
- source = result[1]
403
- expected = (
404
- "function [B, C] = test(A)\n"
405
- " B = A(1:2, 1:2);\n"
406
- " C = A(1:2, 2:3);\n"
407
- "end\n"
408
- )
409
- assert source == expected
410
-
411
-
412
- def test_m_matrixsymbol_slice3():
413
- A = MatrixSymbol('A', 8, 7)
414
- B = MatrixSymbol('B', 2, 2)
415
- C = MatrixSymbol('C', 4, 2)
416
- name_expr = ("test", [Equality(B, A[6:, 1::3]),
417
- Equality(C, A[::2, ::3])])
418
- result, = codegen(name_expr, "Octave", header=False, empty=False)
419
- source = result[1]
420
- expected = (
421
- "function [B, C] = test(A)\n"
422
- " B = A(7:end, 2:3:end);\n"
423
- " C = A(1:2:end, 1:3:end);\n"
424
- "end\n"
425
- )
426
- assert source == expected
427
-
428
-
429
- def test_m_matrixsymbol_slice_autoname():
430
- A = MatrixSymbol('A', 2, 3)
431
- B = MatrixSymbol('B', 1, 3)
432
- name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]])
433
- result, = codegen(name_expr, "Octave", header=False, empty=False)
434
- source = result[1]
435
- expected = (
436
- "function [B, out2, out3, out4] = test(A)\n"
437
- " B = A(1, :);\n"
438
- " out2 = A(2, :);\n"
439
- " out3 = A(:, 1);\n"
440
- " out4 = A(:, 2);\n"
441
- "end\n"
442
- )
443
- assert source == expected
444
-
445
-
446
- def test_m_loops():
447
- # Note: an Octave programmer would probably vectorize this across one or
448
- # more dimensions. Also, size(A) would be used rather than passing in m
449
- # and n. Perhaps users would expect us to vectorize automatically here?
450
- # Or is it possible to represent such things using IndexedBase?
451
- from sympy.tensor import IndexedBase, Idx
452
- from sympy.core.symbol import symbols
453
- n, m = symbols('n m', integer=True)
454
- A = IndexedBase('A')
455
- x = IndexedBase('x')
456
- y = IndexedBase('y')
457
- i = Idx('i', m)
458
- j = Idx('j', n)
459
- result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Octave",
460
- header=False, empty=False)
461
- source = result[1]
462
- expected = (
463
- 'function y = mat_vec_mult(A, m, n, x)\n'
464
- ' for i = 1:m\n'
465
- ' y(i) = 0;\n'
466
- ' end\n'
467
- ' for i = 1:m\n'
468
- ' for j = 1:n\n'
469
- ' y(i) = %(rhs)s + y(i);\n'
470
- ' end\n'
471
- ' end\n'
472
- 'end\n'
473
- )
474
- assert (source == expected % {'rhs': 'A(%s, %s).*x(j)' % (i, j)} or
475
- source == expected % {'rhs': 'x(j).*A(%s, %s)' % (i, j)})
476
-
477
-
478
- def test_m_tensor_loops_multiple_contractions():
479
- # see comments in previous test about vectorizing
480
- from sympy.tensor import IndexedBase, Idx
481
- from sympy.core.symbol import symbols
482
- n, m, o, p = symbols('n m o p', integer=True)
483
- A = IndexedBase('A')
484
- B = IndexedBase('B')
485
- y = IndexedBase('y')
486
- i = Idx('i', m)
487
- j = Idx('j', n)
488
- k = Idx('k', o)
489
- l = Idx('l', p)
490
- result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])),
491
- "Octave", header=False, empty=False)
492
- source = result[1]
493
- expected = (
494
- 'function y = tensorthing(A, B, m, n, o, p)\n'
495
- ' for i = 1:m\n'
496
- ' y(i) = 0;\n'
497
- ' end\n'
498
- ' for i = 1:m\n'
499
- ' for j = 1:n\n'
500
- ' for k = 1:o\n'
501
- ' for l = 1:p\n'
502
- ' y(i) = A(i, j, k, l).*B(j, k, l) + y(i);\n'
503
- ' end\n'
504
- ' end\n'
505
- ' end\n'
506
- ' end\n'
507
- 'end\n'
508
- )
509
- assert source == expected
510
-
511
-
512
- def test_m_InOutArgument():
513
- expr = Equality(x, x**2)
514
- name_expr = ("mysqr", expr)
515
- result, = codegen(name_expr, "Octave", header=False, empty=False)
516
- source = result[1]
517
- expected = (
518
- "function x = mysqr(x)\n"
519
- " x = x.^2;\n"
520
- "end\n"
521
- )
522
- assert source == expected
523
-
524
-
525
- def test_m_InOutArgument_order():
526
- # can specify the order as (x, y)
527
- expr = Equality(x, x**2 + y)
528
- name_expr = ("test", expr)
529
- result, = codegen(name_expr, "Octave", header=False,
530
- empty=False, argument_sequence=(x,y))
531
- source = result[1]
532
- expected = (
533
- "function x = test(x, y)\n"
534
- " x = x.^2 + y;\n"
535
- "end\n"
536
- )
537
- assert source == expected
538
- # make sure it gives (x, y) not (y, x)
539
- expr = Equality(x, x**2 + y)
540
- name_expr = ("test", expr)
541
- result, = codegen(name_expr, "Octave", header=False, empty=False)
542
- source = result[1]
543
- expected = (
544
- "function x = test(x, y)\n"
545
- " x = x.^2 + y;\n"
546
- "end\n"
547
- )
548
- assert source == expected
549
-
550
-
551
- def test_m_not_supported():
552
- f = Function('f')
553
- name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
554
- result, = codegen(name_expr, "Octave", header=False, empty=False)
555
- source = result[1]
556
- expected = (
557
- "function [out1, out2] = test(x)\n"
558
- " % unsupported: Derivative(f(x), x)\n"
559
- " % unsupported: zoo\n"
560
- " out1 = Derivative(f(x), x);\n"
561
- " out2 = zoo;\n"
562
- "end\n"
563
- )
564
- assert source == expected
565
-
566
-
567
- def test_global_vars_octave():
568
- x, y, z, t = symbols("x y z t")
569
- result = codegen(('f', x*y), "Octave", header=False, empty=False,
570
- global_vars=(y,))
571
- source = result[0][1]
572
- expected = (
573
- "function out1 = f(x)\n"
574
- " global y\n"
575
- " out1 = x.*y;\n"
576
- "end\n"
577
- )
578
- assert source == expected
579
-
580
- result = codegen(('f', x*y+z), "Octave", header=False, empty=False,
581
- argument_sequence=(x, y), global_vars=(z, t))
582
- source = result[0][1]
583
- expected = (
584
- "function out1 = f(x, y)\n"
585
- " global t z\n"
586
- " out1 = x.*y + z;\n"
587
- "end\n"
588
- )
589
- assert source == expected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_rust.py DELETED
@@ -1,401 +0,0 @@
1
- from io import StringIO
2
-
3
- from sympy.core import S, symbols, pi, Catalan, EulerGamma, Function
4
- from sympy.core.relational import Equality
5
- from sympy.functions.elementary.piecewise import Piecewise
6
- from sympy.utilities.codegen import RustCodeGen, codegen, make_routine
7
- from sympy.testing.pytest import XFAIL
8
- import sympy
9
-
10
-
11
- x, y, z = symbols('x,y,z')
12
-
13
-
14
- def test_empty_rust_code():
15
- code_gen = RustCodeGen()
16
- output = StringIO()
17
- code_gen.dump_rs([], output, "file", header=False, empty=False)
18
- source = output.getvalue()
19
- assert source == ""
20
-
21
-
22
- def test_simple_rust_code():
23
- name_expr = ("test", (x + y)*z)
24
- result, = codegen(name_expr, "Rust", header=False, empty=False)
25
- assert result[0] == "test.rs"
26
- source = result[1]
27
- expected = (
28
- "fn test(x: f64, y: f64, z: f64) -> f64 {\n"
29
- " let out1 = z*(x + y);\n"
30
- " out1\n"
31
- "}\n"
32
- )
33
- assert source == expected
34
-
35
-
36
- def test_simple_code_with_header():
37
- name_expr = ("test", (x + y)*z)
38
- result, = codegen(name_expr, "Rust", header=True, empty=False)
39
- assert result[0] == "test.rs"
40
- source = result[1]
41
- version_str = "Code generated with SymPy %s" % sympy.__version__
42
- version_line = version_str.center(76).rstrip()
43
- expected = (
44
- "/*\n"
45
- " *%(version_line)s\n"
46
- " *\n"
47
- " * See http://www.sympy.org/ for more information.\n"
48
- " *\n"
49
- " * This file is part of 'project'\n"
50
- " */\n"
51
- "fn test(x: f64, y: f64, z: f64) -> f64 {\n"
52
- " let out1 = z*(x + y);\n"
53
- " out1\n"
54
- "}\n"
55
- ) % {'version_line': version_line}
56
- assert source == expected
57
-
58
-
59
- def test_simple_code_nameout():
60
- expr = Equality(z, (x + y))
61
- name_expr = ("test", expr)
62
- result, = codegen(name_expr, "Rust", header=False, empty=False)
63
- source = result[1]
64
- expected = (
65
- "fn test(x: f64, y: f64) -> f64 {\n"
66
- " let z = x + y;\n"
67
- " z\n"
68
- "}\n"
69
- )
70
- assert source == expected
71
-
72
-
73
- def test_numbersymbol():
74
- name_expr = ("test", pi**Catalan)
75
- result, = codegen(name_expr, "Rust", header=False, empty=False)
76
- source = result[1]
77
- expected = (
78
- "fn test() -> f64 {\n"
79
- " const Catalan: f64 = %s;\n"
80
- " let out1 = PI.powf(Catalan);\n"
81
- " out1\n"
82
- "}\n"
83
- ) % Catalan.evalf(17)
84
- assert source == expected
85
-
86
-
87
- @XFAIL
88
- def test_numbersymbol_inline():
89
- # FIXME: how to pass inline to the RustCodePrinter?
90
- name_expr = ("test", [pi**Catalan, EulerGamma])
91
- result, = codegen(name_expr, "Rust", header=False,
92
- empty=False, inline=True)
93
- source = result[1]
94
- expected = (
95
- "fn test() -> (f64, f64) {\n"
96
- " const Catalan: f64 = %s;\n"
97
- " const EulerGamma: f64 = %s;\n"
98
- " let out1 = PI.powf(Catalan);\n"
99
- " let out2 = EulerGamma);\n"
100
- " (out1, out2)\n"
101
- "}\n"
102
- ) % (Catalan.evalf(17), EulerGamma.evalf(17))
103
- assert source == expected
104
-
105
-
106
- def test_argument_order():
107
- expr = x + y
108
- routine = make_routine("test", expr, argument_sequence=[z, x, y], language="rust")
109
- code_gen = RustCodeGen()
110
- output = StringIO()
111
- code_gen.dump_rs([routine], output, "test", header=False, empty=False)
112
- source = output.getvalue()
113
- expected = (
114
- "fn test(z: f64, x: f64, y: f64) -> f64 {\n"
115
- " let out1 = x + y;\n"
116
- " out1\n"
117
- "}\n"
118
- )
119
- assert source == expected
120
-
121
-
122
- def test_multiple_results_rust():
123
- # Here the output order is the input order
124
- expr1 = (x + y)*z
125
- expr2 = (x - y)*z
126
- name_expr = ("test", [expr1, expr2])
127
- result, = codegen(name_expr, "Rust", header=False, empty=False)
128
- source = result[1]
129
- expected = (
130
- "fn test(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
131
- " let out1 = z*(x + y);\n"
132
- " let out2 = z*(x - y);\n"
133
- " (out1, out2)\n"
134
- "}\n"
135
- )
136
- assert source == expected
137
-
138
-
139
- def test_results_named_unordered():
140
- # Here output order is based on name_expr
141
- A, B, C = symbols('A,B,C')
142
- expr1 = Equality(C, (x + y)*z)
143
- expr2 = Equality(A, (x - y)*z)
144
- expr3 = Equality(B, 2*x)
145
- name_expr = ("test", [expr1, expr2, expr3])
146
- result, = codegen(name_expr, "Rust", header=False, empty=False)
147
- source = result[1]
148
- expected = (
149
- "fn test(x: f64, y: f64, z: f64) -> (f64, f64, f64) {\n"
150
- " let C = z*(x + y);\n"
151
- " let A = z*(x - y);\n"
152
- " let B = 2*x;\n"
153
- " (C, A, B)\n"
154
- "}\n"
155
- )
156
- assert source == expected
157
-
158
-
159
- def test_results_named_ordered():
160
- A, B, C = symbols('A,B,C')
161
- expr1 = Equality(C, (x + y)*z)
162
- expr2 = Equality(A, (x - y)*z)
163
- expr3 = Equality(B, 2*x)
164
- name_expr = ("test", [expr1, expr2, expr3])
165
- result = codegen(name_expr, "Rust", header=False, empty=False,
166
- argument_sequence=(x, z, y))
167
- assert result[0][0] == "test.rs"
168
- source = result[0][1]
169
- expected = (
170
- "fn test(x: f64, z: f64, y: f64) -> (f64, f64, f64) {\n"
171
- " let C = z*(x + y);\n"
172
- " let A = z*(x - y);\n"
173
- " let B = 2*x;\n"
174
- " (C, A, B)\n"
175
- "}\n"
176
- )
177
- assert source == expected
178
-
179
-
180
- def test_complicated_rs_codegen():
181
- from sympy.functions.elementary.trigonometric import (cos, sin, tan)
182
- name_expr = ("testlong",
183
- [ ((sin(x) + cos(y) + tan(z))**3).expand(),
184
- cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
185
- ])
186
- result = codegen(name_expr, "Rust", header=False, empty=False)
187
- assert result[0][0] == "testlong.rs"
188
- source = result[0][1]
189
- expected = (
190
- "fn testlong(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
191
- " let out1 = x.sin().powi(3) + 3*x.sin().powi(2)*y.cos()"
192
- " + 3*x.sin().powi(2)*z.tan() + 3*x.sin()*y.cos().powi(2)"
193
- " + 6*x.sin()*y.cos()*z.tan() + 3*x.sin()*z.tan().powi(2)"
194
- " + y.cos().powi(3) + 3*y.cos().powi(2)*z.tan()"
195
- " + 3*y.cos()*z.tan().powi(2) + z.tan().powi(3);\n"
196
- " let out2 = (x + y + z).cos().cos().cos().cos()"
197
- ".cos().cos().cos().cos();\n"
198
- " (out1, out2)\n"
199
- "}\n"
200
- )
201
- assert source == expected
202
-
203
-
204
- def test_output_arg_mixed_unordered():
205
- # named outputs are alphabetical, unnamed output appear in the given order
206
- from sympy.functions.elementary.trigonometric import (cos, sin)
207
- a = symbols("a")
208
- name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
209
- result, = codegen(name_expr, "Rust", header=False, empty=False)
210
- assert result[0] == "foo.rs"
211
- source = result[1]
212
- expected = (
213
- "fn foo(x: f64) -> (f64, f64, f64, f64) {\n"
214
- " let out1 = (2*x).cos();\n"
215
- " let y = x.sin();\n"
216
- " let out3 = x.cos();\n"
217
- " let a = (2*x).sin();\n"
218
- " (out1, y, out3, a)\n"
219
- "}\n"
220
- )
221
- assert source == expected
222
-
223
-
224
- def test_piecewise_():
225
- pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
226
- name_expr = ("pwtest", pw)
227
- result, = codegen(name_expr, "Rust", header=False, empty=False)
228
- source = result[1]
229
- expected = (
230
- "fn pwtest(x: f64) -> f64 {\n"
231
- " let out1 = if (x < -1.0) {\n"
232
- " 0\n"
233
- " } else if (x <= 1.0) {\n"
234
- " x.powi(2)\n"
235
- " } else if (x > 1.0) {\n"
236
- " 2 - x\n"
237
- " } else {\n"
238
- " 1\n"
239
- " };\n"
240
- " out1\n"
241
- "}\n"
242
- )
243
- assert source == expected
244
-
245
-
246
- @XFAIL
247
- def test_piecewise_inline():
248
- # FIXME: how to pass inline to the RustCodePrinter?
249
- pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
250
- name_expr = ("pwtest", pw)
251
- result, = codegen(name_expr, "Rust", header=False, empty=False,
252
- inline=True)
253
- source = result[1]
254
- expected = (
255
- "fn pwtest(x: f64) -> f64 {\n"
256
- " let out1 = if (x < -1) { 0 } else if (x <= 1) { x.powi(2) }"
257
- " else if (x > 1) { -x + 2 } else { 1 };\n"
258
- " out1\n"
259
- "}\n"
260
- )
261
- assert source == expected
262
-
263
-
264
- def test_multifcns_per_file():
265
- name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
266
- result = codegen(name_expr, "Rust", header=False, empty=False)
267
- assert result[0][0] == "foo.rs"
268
- source = result[0][1]
269
- expected = (
270
- "fn foo(x: f64, y: f64) -> (f64, f64) {\n"
271
- " let out1 = 2*x;\n"
272
- " let out2 = 3*y;\n"
273
- " (out1, out2)\n"
274
- "}\n"
275
- "fn bar(y: f64) -> (f64, f64) {\n"
276
- " let out1 = y.powi(2);\n"
277
- " let out2 = 4*y;\n"
278
- " (out1, out2)\n"
279
- "}\n"
280
- )
281
- assert source == expected
282
-
283
-
284
- def test_multifcns_per_file_w_header():
285
- name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
286
- result = codegen(name_expr, "Rust", header=True, empty=False)
287
- assert result[0][0] == "foo.rs"
288
- source = result[0][1]
289
- version_str = "Code generated with SymPy %s" % sympy.__version__
290
- version_line = version_str.center(76).rstrip()
291
- expected = (
292
- "/*\n"
293
- " *%(version_line)s\n"
294
- " *\n"
295
- " * See http://www.sympy.org/ for more information.\n"
296
- " *\n"
297
- " * This file is part of 'project'\n"
298
- " */\n"
299
- "fn foo(x: f64, y: f64) -> (f64, f64) {\n"
300
- " let out1 = 2*x;\n"
301
- " let out2 = 3*y;\n"
302
- " (out1, out2)\n"
303
- "}\n"
304
- "fn bar(y: f64) -> (f64, f64) {\n"
305
- " let out1 = y.powi(2);\n"
306
- " let out2 = 4*y;\n"
307
- " (out1, out2)\n"
308
- "}\n"
309
- ) % {'version_line': version_line}
310
- assert source == expected
311
-
312
-
313
- def test_filename_match_prefix():
314
- name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
315
- result, = codegen(name_expr, "Rust", prefix="baz", header=False,
316
- empty=False)
317
- assert result[0] == "baz.rs"
318
-
319
-
320
- def test_InOutArgument():
321
- expr = Equality(x, x**2)
322
- name_expr = ("mysqr", expr)
323
- result, = codegen(name_expr, "Rust", header=False, empty=False)
324
- source = result[1]
325
- expected = (
326
- "fn mysqr(x: f64) -> f64 {\n"
327
- " let x = x.powi(2);\n"
328
- " x\n"
329
- "}\n"
330
- )
331
- assert source == expected
332
-
333
-
334
- def test_InOutArgument_order():
335
- # can specify the order as (x, y)
336
- expr = Equality(x, x**2 + y)
337
- name_expr = ("test", expr)
338
- result, = codegen(name_expr, "Rust", header=False,
339
- empty=False, argument_sequence=(x,y))
340
- source = result[1]
341
- expected = (
342
- "fn test(x: f64, y: f64) -> f64 {\n"
343
- " let x = x.powi(2) + y;\n"
344
- " x\n"
345
- "}\n"
346
- )
347
- assert source == expected
348
- # make sure it gives (x, y) not (y, x)
349
- expr = Equality(x, x**2 + y)
350
- name_expr = ("test", expr)
351
- result, = codegen(name_expr, "Rust", header=False, empty=False)
352
- source = result[1]
353
- expected = (
354
- "fn test(x: f64, y: f64) -> f64 {\n"
355
- " let x = x.powi(2) + y;\n"
356
- " x\n"
357
- "}\n"
358
- )
359
- assert source == expected
360
-
361
-
362
- def test_not_supported():
363
- f = Function('f')
364
- name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
365
- result, = codegen(name_expr, "Rust", header=False, empty=False)
366
- source = result[1]
367
- expected = (
368
- "fn test(x: f64) -> (f64, f64) {\n"
369
- " // unsupported: Derivative(f(x), x)\n"
370
- " // unsupported: zoo\n"
371
- " let out1 = Derivative(f(x), x);\n"
372
- " let out2 = zoo;\n"
373
- " (out1, out2)\n"
374
- "}\n"
375
- )
376
- assert source == expected
377
-
378
-
379
- def test_global_vars_rust():
380
- x, y, z, t = symbols("x y z t")
381
- result = codegen(('f', x*y), "Rust", header=False, empty=False,
382
- global_vars=(y,))
383
- source = result[0][1]
384
- expected = (
385
- "fn f(x: f64) -> f64 {\n"
386
- " let out1 = x*y;\n"
387
- " out1\n"
388
- "}\n"
389
- )
390
- assert source == expected
391
-
392
- result = codegen(('f', x*y+z), "Rust", header=False, empty=False,
393
- argument_sequence=(x, y), global_vars=(z, t))
394
- source = result[0][1]
395
- expected = (
396
- "fn f(x: f64, y: f64) -> f64 {\n"
397
- " let out1 = x*y + z;\n"
398
- " out1\n"
399
- "}\n"
400
- )
401
- assert source == expected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_decorator.py DELETED
@@ -1,129 +0,0 @@
1
- from functools import wraps
2
-
3
- from sympy.utilities.decorator import threaded, xthreaded, memoize_property, deprecated
4
- from sympy.testing.pytest import warns_deprecated_sympy
5
-
6
- from sympy.core.basic import Basic
7
- from sympy.core.relational import Eq
8
- from sympy.matrices.dense import Matrix
9
-
10
- from sympy.abc import x, y
11
-
12
-
13
- def test_threaded():
14
- @threaded
15
- def function(expr, *args):
16
- return 2*expr + sum(args)
17
-
18
- assert function(Matrix([[x, y], [1, x]]), 1, 2) == \
19
- Matrix([[2*x + 3, 2*y + 3], [5, 2*x + 3]])
20
-
21
- assert function(Eq(x, y), 1, 2) == Eq(2*x + 3, 2*y + 3)
22
-
23
- assert function([x, y], 1, 2) == [2*x + 3, 2*y + 3]
24
- assert function((x, y), 1, 2) == (2*x + 3, 2*y + 3)
25
-
26
- assert function({x, y}, 1, 2) == {2*x + 3, 2*y + 3}
27
-
28
- @threaded
29
- def function(expr, n):
30
- return expr**n
31
-
32
- assert function(x + y, 2) == x**2 + y**2
33
- assert function(x, 2) == x**2
34
-
35
-
36
- def test_xthreaded():
37
- @xthreaded
38
- def function(expr, n):
39
- return expr**n
40
-
41
- assert function(x + y, 2) == (x + y)**2
42
-
43
-
44
- def test_wraps():
45
- def my_func(x):
46
- """My function. """
47
-
48
- my_func.is_my_func = True
49
-
50
- new_my_func = threaded(my_func)
51
- new_my_func = wraps(my_func)(new_my_func)
52
-
53
- assert new_my_func.__name__ == 'my_func'
54
- assert new_my_func.__doc__ == 'My function. '
55
- assert hasattr(new_my_func, 'is_my_func')
56
- assert new_my_func.is_my_func is True
57
-
58
-
59
- def test_memoize_property():
60
- class TestMemoize(Basic):
61
- @memoize_property
62
- def prop(self):
63
- return Basic()
64
-
65
- member = TestMemoize()
66
- obj1 = member.prop
67
- obj2 = member.prop
68
- assert obj1 is obj2
69
-
70
- def test_deprecated():
71
- @deprecated('deprecated_function is deprecated',
72
- deprecated_since_version='1.10',
73
- # This is the target at the top of the file, which will never
74
- # go away.
75
- active_deprecations_target='active-deprecations')
76
- def deprecated_function(x):
77
- return x
78
-
79
- with warns_deprecated_sympy():
80
- assert deprecated_function(1) == 1
81
-
82
- @deprecated('deprecated_class is deprecated',
83
- deprecated_since_version='1.10',
84
- active_deprecations_target='active-deprecations')
85
- class deprecated_class:
86
- pass
87
-
88
- with warns_deprecated_sympy():
89
- assert isinstance(deprecated_class(), deprecated_class)
90
-
91
- # Ensure the class decorator works even when the class never returns
92
- # itself
93
- @deprecated('deprecated_class_new is deprecated',
94
- deprecated_since_version='1.10',
95
- active_deprecations_target='active-deprecations')
96
- class deprecated_class_new:
97
- def __new__(cls, arg):
98
- return arg
99
-
100
- with warns_deprecated_sympy():
101
- assert deprecated_class_new(1) == 1
102
-
103
- @deprecated('deprecated_class_init is deprecated',
104
- deprecated_since_version='1.10',
105
- active_deprecations_target='active-deprecations')
106
- class deprecated_class_init:
107
- def __init__(self, arg):
108
- self.arg = 1
109
-
110
- with warns_deprecated_sympy():
111
- assert deprecated_class_init(1).arg == 1
112
-
113
- @deprecated('deprecated_class_new_init is deprecated',
114
- deprecated_since_version='1.10',
115
- active_deprecations_target='active-deprecations')
116
- class deprecated_class_new_init:
117
- def __new__(cls, arg):
118
- if arg == 0:
119
- return arg
120
- return object.__new__(cls)
121
-
122
- def __init__(self, arg):
123
- self.arg = 1
124
-
125
- with warns_deprecated_sympy():
126
- assert deprecated_class_new_init(0) == 0
127
-
128
- with warns_deprecated_sympy():
129
- assert deprecated_class_new_init(1).arg == 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_deprecated.py DELETED
@@ -1,13 +0,0 @@
1
- from sympy.testing.pytest import warns_deprecated_sympy
2
-
3
- # See https://github.com/sympy/sympy/pull/18095
4
-
5
- def test_deprecated_utilities():
6
- with warns_deprecated_sympy():
7
- import sympy.utilities.pytest # noqa:F401
8
- with warns_deprecated_sympy():
9
- import sympy.utilities.runtests # noqa:F401
10
- with warns_deprecated_sympy():
11
- import sympy.utilities.randtest # noqa:F401
12
- with warns_deprecated_sympy():
13
- import sympy.utilities.tmpfiles # noqa:F401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_enumerative.py DELETED
@@ -1,179 +0,0 @@
1
- import string
2
- from itertools import zip_longest
3
-
4
- from sympy.utilities.enumerative import (
5
- list_visitor,
6
- MultisetPartitionTraverser,
7
- multiset_partitions_taocp
8
- )
9
- from sympy.utilities.iterables import _set_partitions
10
-
11
- # first some functions only useful as test scaffolding - these provide
12
- # straightforward, but slow reference implementations against which to
13
- # compare the real versions, and also a comparison to verify that
14
- # different versions are giving identical results.
15
-
16
- def part_range_filter(partition_iterator, lb, ub):
17
- """
18
- Filters (on the number of parts) a multiset partition enumeration
19
-
20
- Arguments
21
- =========
22
-
23
- lb, and ub are a range (in the Python slice sense) on the lpart
24
- variable returned from a multiset partition enumeration. Recall
25
- that lpart is 0-based (it points to the topmost part on the part
26
- stack), so if you want to return parts of sizes 2,3,4,5 you would
27
- use lb=1 and ub=5.
28
- """
29
- for state in partition_iterator:
30
- f, lpart, pstack = state
31
- if lpart >= lb and lpart < ub:
32
- yield state
33
-
34
- def multiset_partitions_baseline(multiplicities, components):
35
- """Enumerates partitions of a multiset
36
-
37
- Parameters
38
- ==========
39
-
40
- multiplicities
41
- list of integer multiplicities of the components of the multiset.
42
-
43
- components
44
- the components (elements) themselves
45
-
46
- Returns
47
- =======
48
-
49
- Set of partitions. Each partition is tuple of parts, and each
50
- part is a tuple of components (with repeats to indicate
51
- multiplicity)
52
-
53
- Notes
54
- =====
55
-
56
- Multiset partitions can be created as equivalence classes of set
57
- partitions, and this function does just that. This approach is
58
- slow and memory intensive compared to the more advanced algorithms
59
- available, but the code is simple and easy to understand. Hence
60
- this routine is strictly for testing -- to provide a
61
- straightforward baseline against which to regress the production
62
- versions. (This code is a simplified version of an earlier
63
- production implementation.)
64
- """
65
-
66
- canon = [] # list of components with repeats
67
- for ct, elem in zip(multiplicities, components):
68
- canon.extend([elem]*ct)
69
-
70
- # accumulate the multiset partitions in a set to eliminate dups
71
- cache = set()
72
- n = len(canon)
73
- for nc, q in _set_partitions(n):
74
- rv = [[] for i in range(nc)]
75
- for i in range(n):
76
- rv[q[i]].append(canon[i])
77
- canonical = tuple(
78
- sorted([tuple(p) for p in rv]))
79
- cache.add(canonical)
80
- return cache
81
-
82
-
83
- def compare_multiset_w_baseline(multiplicities):
84
- """
85
- Enumerates the partitions of multiset with AOCP algorithm and
86
- baseline implementation, and compare the results.
87
-
88
- """
89
- letters = string.ascii_lowercase
90
- bl_partitions = multiset_partitions_baseline(multiplicities, letters)
91
-
92
- # The partitions returned by the different algorithms may have
93
- # their parts in different orders. Also, they generate partitions
94
- # in different orders. Hence the sorting, and set comparison.
95
-
96
- aocp_partitions = set()
97
- for state in multiset_partitions_taocp(multiplicities):
98
- p1 = tuple(sorted(
99
- [tuple(p) for p in list_visitor(state, letters)]))
100
- aocp_partitions.add(p1)
101
-
102
- assert bl_partitions == aocp_partitions
103
-
104
- def compare_multiset_states(s1, s2):
105
- """compare for equality two instances of multiset partition states
106
-
107
- This is useful for comparing different versions of the algorithm
108
- to verify correctness."""
109
- # Comparison is physical, the only use of semantics is to ignore
110
- # trash off the top of the stack.
111
- f1, lpart1, pstack1 = s1
112
- f2, lpart2, pstack2 = s2
113
-
114
- if (lpart1 == lpart2) and (f1[0:lpart1+1] == f2[0:lpart2+1]):
115
- if pstack1[0:f1[lpart1+1]] == pstack2[0:f2[lpart2+1]]:
116
- return True
117
- return False
118
-
119
- def test_multiset_partitions_taocp():
120
- """Compares the output of multiset_partitions_taocp with a baseline
121
- (set partition based) implementation."""
122
-
123
- # Test cases should not be too large, since the baseline
124
- # implementation is fairly slow.
125
- multiplicities = [2,2]
126
- compare_multiset_w_baseline(multiplicities)
127
-
128
- multiplicities = [4,3,1]
129
- compare_multiset_w_baseline(multiplicities)
130
-
131
- def test_multiset_partitions_versions():
132
- """Compares Knuth-based versions of multiset_partitions"""
133
- multiplicities = [5,2,2,1]
134
- m = MultisetPartitionTraverser()
135
- for s1, s2 in zip_longest(m.enum_all(multiplicities),
136
- multiset_partitions_taocp(multiplicities)):
137
- assert compare_multiset_states(s1, s2)
138
-
139
- def subrange_exercise(mult, lb, ub):
140
- """Compare filter-based and more optimized subrange implementations
141
-
142
- Helper for tests, called with both small and larger multisets.
143
- """
144
- m = MultisetPartitionTraverser()
145
- assert m.count_partitions(mult) == \
146
- m.count_partitions_slow(mult)
147
-
148
- # Note - multiple traversals from the same
149
- # MultisetPartitionTraverser object cannot execute at the same
150
- # time, hence make several instances here.
151
- ma = MultisetPartitionTraverser()
152
- mc = MultisetPartitionTraverser()
153
- md = MultisetPartitionTraverser()
154
-
155
- # Several paths to compute just the size two partitions
156
- a_it = ma.enum_range(mult, lb, ub)
157
- b_it = part_range_filter(multiset_partitions_taocp(mult), lb, ub)
158
- c_it = part_range_filter(mc.enum_small(mult, ub), lb, sum(mult))
159
- d_it = part_range_filter(md.enum_large(mult, lb), 0, ub)
160
-
161
- for sa, sb, sc, sd in zip_longest(a_it, b_it, c_it, d_it):
162
- assert compare_multiset_states(sa, sb)
163
- assert compare_multiset_states(sa, sc)
164
- assert compare_multiset_states(sa, sd)
165
-
166
- def test_subrange():
167
- # Quick, but doesn't hit some of the corner cases
168
- mult = [4,4,2,1] # mississippi
169
- lb = 1
170
- ub = 2
171
- subrange_exercise(mult, lb, ub)
172
-
173
-
174
- def test_subrange_large():
175
- # takes a second or so, depending on cpu, Python version, etc.
176
- mult = [6,3,2,1]
177
- lb = 4
178
- ub = 7
179
- subrange_exercise(mult, lb, ub)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_exceptions.py DELETED
@@ -1,12 +0,0 @@
1
- from sympy.testing.pytest import raises
2
- from sympy.utilities.exceptions import sympy_deprecation_warning
3
-
4
- # Only test exceptions here because the other cases are tested in the
5
- # warns_deprecated_sympy tests
6
- def test_sympy_deprecation_warning():
7
- raises(TypeError, lambda: sympy_deprecation_warning('test',
8
- deprecated_since_version=1.10,
9
- active_deprecations_target='active-deprecations'))
10
-
11
- raises(ValueError, lambda: sympy_deprecation_warning('test',
12
- deprecated_since_version="1.10", active_deprecations_target='(active-deprecations)='))
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_iterables.py DELETED
@@ -1,945 +0,0 @@
1
- from textwrap import dedent
2
- from itertools import islice, product
3
-
4
- from sympy.core.basic import Basic
5
- from sympy.core.numbers import Integer
6
- from sympy.core.sorting import ordered
7
- from sympy.core.symbol import (Dummy, symbols)
8
- from sympy.functions.combinatorial.factorials import factorial
9
- from sympy.matrices.dense import Matrix
10
- from sympy.combinatorics import RGS_enum, RGS_unrank, Permutation
11
- from sympy.utilities.iterables import (
12
- _partition, _set_partitions, binary_partitions, bracelets, capture,
13
- cartes, common_prefix, common_suffix, connected_components, dict_merge,
14
- filter_symbols, flatten, generate_bell, generate_derangements,
15
- generate_involutions, generate_oriented_forest, group, has_dups, ibin,
16
- iproduct, kbins, minlex, multiset, multiset_combinations,
17
- multiset_partitions, multiset_permutations, necklaces, numbered_symbols,
18
- partitions, permutations, postfixes,
19
- prefixes, reshape, rotate_left, rotate_right, runs, sift,
20
- strongly_connected_components, subsets, take, topological_sort, unflatten,
21
- uniq, variations, ordered_partitions, rotations, is_palindromic, iterable,
22
- NotIterable, multiset_derangements, signed_permutations,
23
- sequence_partitions, sequence_partitions_empty)
24
- from sympy.utilities.enumerative import (
25
- factoring_visitor, multiset_partitions_taocp )
26
-
27
- from sympy.core.singleton import S
28
- from sympy.testing.pytest import raises, warns_deprecated_sympy
29
-
30
- w, x, y, z = symbols('w,x,y,z')
31
-
32
-
33
- def test_deprecated_iterables():
34
- from sympy.utilities.iterables import default_sort_key, ordered
35
- with warns_deprecated_sympy():
36
- assert list(ordered([y, x])) == [x, y]
37
- with warns_deprecated_sympy():
38
- assert sorted([y, x], key=default_sort_key) == [x, y]
39
-
40
-
41
- def test_is_palindromic():
42
- assert is_palindromic('')
43
- assert is_palindromic('x')
44
- assert is_palindromic('xx')
45
- assert is_palindromic('xyx')
46
- assert not is_palindromic('xy')
47
- assert not is_palindromic('xyzx')
48
- assert is_palindromic('xxyzzyx', 1)
49
- assert not is_palindromic('xxyzzyx', 2)
50
- assert is_palindromic('xxyzzyx', 2, -1)
51
- assert is_palindromic('xxyzzyx', 2, 6)
52
- assert is_palindromic('xxyzyx', 1)
53
- assert not is_palindromic('xxyzyx', 2)
54
- assert is_palindromic('xxyzyx', 2, 2 + 3)
55
-
56
-
57
- def test_flatten():
58
- assert flatten((1, (1,))) == [1, 1]
59
- assert flatten((x, (x,))) == [x, x]
60
-
61
- ls = [[(-2, -1), (1, 2)], [(0, 0)]]
62
-
63
- assert flatten(ls, levels=0) == ls
64
- assert flatten(ls, levels=1) == [(-2, -1), (1, 2), (0, 0)]
65
- assert flatten(ls, levels=2) == [-2, -1, 1, 2, 0, 0]
66
- assert flatten(ls, levels=3) == [-2, -1, 1, 2, 0, 0]
67
-
68
- raises(ValueError, lambda: flatten(ls, levels=-1))
69
-
70
- class MyOp(Basic):
71
- pass
72
-
73
- assert flatten([MyOp(x, y), z]) == [MyOp(x, y), z]
74
- assert flatten([MyOp(x, y), z], cls=MyOp) == [x, y, z]
75
-
76
- assert flatten({1, 11, 2}) == list({1, 11, 2})
77
-
78
-
79
- def test_iproduct():
80
- assert list(iproduct()) == [()]
81
- assert list(iproduct([])) == []
82
- assert list(iproduct([1,2,3])) == [(1,),(2,),(3,)]
83
- assert sorted(iproduct([1, 2], [3, 4, 5])) == [
84
- (1,3),(1,4),(1,5),(2,3),(2,4),(2,5)]
85
- assert sorted(iproduct([0,1],[0,1],[0,1])) == [
86
- (0,0,0),(0,0,1),(0,1,0),(0,1,1),(1,0,0),(1,0,1),(1,1,0),(1,1,1)]
87
- assert iterable(iproduct(S.Integers)) is True
88
- assert iterable(iproduct(S.Integers, S.Integers)) is True
89
- assert (3,) in iproduct(S.Integers)
90
- assert (4, 5) in iproduct(S.Integers, S.Integers)
91
- assert (1, 2, 3) in iproduct(S.Integers, S.Integers, S.Integers)
92
- triples = set(islice(iproduct(S.Integers, S.Integers, S.Integers), 1000))
93
- for n1, n2, n3 in triples:
94
- assert isinstance(n1, Integer)
95
- assert isinstance(n2, Integer)
96
- assert isinstance(n3, Integer)
97
- for t in set(product(*([range(-2, 3)]*3))):
98
- assert t in iproduct(S.Integers, S.Integers, S.Integers)
99
-
100
-
101
- def test_group():
102
- assert group([]) == []
103
- assert group([], multiple=False) == []
104
-
105
- assert group([1]) == [[1]]
106
- assert group([1], multiple=False) == [(1, 1)]
107
-
108
- assert group([1, 1]) == [[1, 1]]
109
- assert group([1, 1], multiple=False) == [(1, 2)]
110
-
111
- assert group([1, 1, 1]) == [[1, 1, 1]]
112
- assert group([1, 1, 1], multiple=False) == [(1, 3)]
113
-
114
- assert group([1, 2, 1]) == [[1], [2], [1]]
115
- assert group([1, 2, 1], multiple=False) == [(1, 1), (2, 1), (1, 1)]
116
-
117
- assert group([1, 1, 2, 2, 2, 1, 3, 3]) == [[1, 1], [2, 2, 2], [1], [3, 3]]
118
- assert group([1, 1, 2, 2, 2, 1, 3, 3], multiple=False) == [(1, 2),
119
- (2, 3), (1, 1), (3, 2)]
120
-
121
-
122
- def test_subsets():
123
- # combinations
124
- assert list(subsets([1, 2, 3], 0)) == [()]
125
- assert list(subsets([1, 2, 3], 1)) == [(1,), (2,), (3,)]
126
- assert list(subsets([1, 2, 3], 2)) == [(1, 2), (1, 3), (2, 3)]
127
- assert list(subsets([1, 2, 3], 3)) == [(1, 2, 3)]
128
- l = list(range(4))
129
- assert list(subsets(l, 0, repetition=True)) == [()]
130
- assert list(subsets(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
131
- assert list(subsets(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
132
- (0, 3), (1, 1), (1, 2),
133
- (1, 3), (2, 2), (2, 3),
134
- (3, 3)]
135
- assert list(subsets(l, 3, repetition=True)) == [(0, 0, 0), (0, 0, 1),
136
- (0, 0, 2), (0, 0, 3),
137
- (0, 1, 1), (0, 1, 2),
138
- (0, 1, 3), (0, 2, 2),
139
- (0, 2, 3), (0, 3, 3),
140
- (1, 1, 1), (1, 1, 2),
141
- (1, 1, 3), (1, 2, 2),
142
- (1, 2, 3), (1, 3, 3),
143
- (2, 2, 2), (2, 2, 3),
144
- (2, 3, 3), (3, 3, 3)]
145
- assert len(list(subsets(l, 4, repetition=True))) == 35
146
-
147
- assert list(subsets(l[:2], 3, repetition=False)) == []
148
- assert list(subsets(l[:2], 3, repetition=True)) == [(0, 0, 0),
149
- (0, 0, 1),
150
- (0, 1, 1),
151
- (1, 1, 1)]
152
- assert list(subsets([1, 2], repetition=True)) == \
153
- [(), (1,), (2,), (1, 1), (1, 2), (2, 2)]
154
- assert list(subsets([1, 2], repetition=False)) == \
155
- [(), (1,), (2,), (1, 2)]
156
- assert list(subsets([1, 2, 3], 2)) == \
157
- [(1, 2), (1, 3), (2, 3)]
158
- assert list(subsets([1, 2, 3], 2, repetition=True)) == \
159
- [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
160
-
161
-
162
- def test_variations():
163
- # permutations
164
- l = list(range(4))
165
- assert list(variations(l, 0, repetition=False)) == [()]
166
- assert list(variations(l, 1, repetition=False)) == [(0,), (1,), (2,), (3,)]
167
- assert list(variations(l, 2, repetition=False)) == [(0, 1), (0, 2), (0, 3), (1, 0), (1, 2), (1, 3), (2, 0), (2, 1), (2, 3), (3, 0), (3, 1), (3, 2)]
168
- assert list(variations(l, 3, repetition=False)) == [(0, 1, 2), (0, 1, 3), (0, 2, 1), (0, 2, 3), (0, 3, 1), (0, 3, 2), (1, 0, 2), (1, 0, 3), (1, 2, 0), (1, 2, 3), (1, 3, 0), (1, 3, 2), (2, 0, 1), (2, 0, 3), (2, 1, 0), (2, 1, 3), (2, 3, 0), (2, 3, 1), (3, 0, 1), (3, 0, 2), (3, 1, 0), (3, 1, 2), (3, 2, 0), (3, 2, 1)]
169
- assert list(variations(l, 0, repetition=True)) == [()]
170
- assert list(variations(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
171
- assert list(variations(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
172
- (0, 3), (1, 0), (1, 1),
173
- (1, 2), (1, 3), (2, 0),
174
- (2, 1), (2, 2), (2, 3),
175
- (3, 0), (3, 1), (3, 2),
176
- (3, 3)]
177
- assert len(list(variations(l, 3, repetition=True))) == 64
178
- assert len(list(variations(l, 4, repetition=True))) == 256
179
- assert list(variations(l[:2], 3, repetition=False)) == []
180
- assert list(variations(l[:2], 3, repetition=True)) == [
181
- (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
182
- (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)
183
- ]
184
-
185
-
186
- def test_cartes():
187
- assert list(cartes([1, 2], [3, 4, 5])) == \
188
- [(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5)]
189
- assert list(cartes()) == [()]
190
- assert list(cartes('a')) == [('a',)]
191
- assert list(cartes('a', repeat=2)) == [('a', 'a')]
192
- assert list(cartes(list(range(2)))) == [(0,), (1,)]
193
-
194
-
195
- def test_filter_symbols():
196
- s = numbered_symbols()
197
- filtered = filter_symbols(s, symbols("x0 x2 x3"))
198
- assert take(filtered, 3) == list(symbols("x1 x4 x5"))
199
-
200
-
201
- def test_numbered_symbols():
202
- s = numbered_symbols(cls=Dummy)
203
- assert isinstance(next(s), Dummy)
204
- assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \
205
- symbols('C2')
206
-
207
-
208
- def test_sift():
209
- assert sift(list(range(5)), lambda _: _ % 2) == {1: [1, 3], 0: [0, 2, 4]}
210
- assert sift([x, y], lambda _: _.has(x)) == {False: [y], True: [x]}
211
- assert sift([S.One], lambda _: _.has(x)) == {False: [1]}
212
- assert sift([0, 1, 2, 3], lambda x: x % 2, binary=True) == (
213
- [1, 3], [0, 2])
214
- assert sift([0, 1, 2, 3], lambda x: x % 3 == 1, binary=True) == (
215
- [1], [0, 2, 3])
216
- raises(ValueError, lambda:
217
- sift([0, 1, 2, 3], lambda x: x % 3, binary=True))
218
-
219
-
220
- def test_take():
221
- X = numbered_symbols()
222
-
223
- assert take(X, 5) == list(symbols('x0:5'))
224
- assert take(X, 5) == list(symbols('x5:10'))
225
-
226
- assert take([1, 2, 3, 4, 5], 5) == [1, 2, 3, 4, 5]
227
-
228
-
229
- def test_dict_merge():
230
- assert dict_merge({}, {1: x, y: z}) == {1: x, y: z}
231
- assert dict_merge({1: x, y: z}, {}) == {1: x, y: z}
232
-
233
- assert dict_merge({2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
234
- assert dict_merge({1: x, y: z}, {2: z}) == {1: x, 2: z, y: z}
235
-
236
- assert dict_merge({1: y, 2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
237
- assert dict_merge({1: x, y: z}, {1: y, 2: z}) == {1: y, 2: z, y: z}
238
-
239
-
240
- def test_prefixes():
241
- assert list(prefixes([])) == []
242
- assert list(prefixes([1])) == [[1]]
243
- assert list(prefixes([1, 2])) == [[1], [1, 2]]
244
-
245
- assert list(prefixes([1, 2, 3, 4, 5])) == \
246
- [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
247
-
248
-
249
- def test_postfixes():
250
- assert list(postfixes([])) == []
251
- assert list(postfixes([1])) == [[1]]
252
- assert list(postfixes([1, 2])) == [[2], [1, 2]]
253
-
254
- assert list(postfixes([1, 2, 3, 4, 5])) == \
255
- [[5], [4, 5], [3, 4, 5], [2, 3, 4, 5], [1, 2, 3, 4, 5]]
256
-
257
-
258
- def test_topological_sort():
259
- V = [2, 3, 5, 7, 8, 9, 10, 11]
260
- E = [(7, 11), (7, 8), (5, 11),
261
- (3, 8), (3, 10), (11, 2),
262
- (11, 9), (11, 10), (8, 9)]
263
-
264
- assert topological_sort((V, E)) == [3, 5, 7, 8, 11, 2, 9, 10]
265
- assert topological_sort((V, E), key=lambda v: -v) == \
266
- [7, 5, 11, 3, 10, 8, 9, 2]
267
-
268
- raises(ValueError, lambda: topological_sort((V, E + [(10, 7)])))
269
-
270
-
271
- def test_strongly_connected_components():
272
- assert strongly_connected_components(([], [])) == []
273
- assert strongly_connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
274
-
275
- V = [1, 2, 3]
276
- E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
277
- assert strongly_connected_components((V, E)) == [[1, 2, 3]]
278
-
279
- V = [1, 2, 3, 4]
280
- E = [(1, 2), (2, 3), (3, 2), (3, 4)]
281
- assert strongly_connected_components((V, E)) == [[4], [2, 3], [1]]
282
-
283
- V = [1, 2, 3, 4]
284
- E = [(1, 2), (2, 1), (3, 4), (4, 3)]
285
- assert strongly_connected_components((V, E)) == [[1, 2], [3, 4]]
286
-
287
-
288
- def test_connected_components():
289
- assert connected_components(([], [])) == []
290
- assert connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
291
-
292
- V = [1, 2, 3]
293
- E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
294
- assert connected_components((V, E)) == [[1, 2, 3]]
295
-
296
- V = [1, 2, 3, 4]
297
- E = [(1, 2), (2, 3), (3, 2), (3, 4)]
298
- assert connected_components((V, E)) == [[1, 2, 3, 4]]
299
-
300
- V = [1, 2, 3, 4]
301
- E = [(1, 2), (3, 4)]
302
- assert connected_components((V, E)) == [[1, 2], [3, 4]]
303
-
304
-
305
- def test_rotate():
306
- A = [0, 1, 2, 3, 4]
307
-
308
- assert rotate_left(A, 2) == [2, 3, 4, 0, 1]
309
- assert rotate_right(A, 1) == [4, 0, 1, 2, 3]
310
- A = []
311
- B = rotate_right(A, 1)
312
- assert B == []
313
- B.append(1)
314
- assert A == []
315
- B = rotate_left(A, 1)
316
- assert B == []
317
- B.append(1)
318
- assert A == []
319
-
320
-
321
- def test_multiset_partitions():
322
- A = [0, 1, 2, 3, 4]
323
-
324
- assert list(multiset_partitions(A, 5)) == [[[0], [1], [2], [3], [4]]]
325
- assert len(list(multiset_partitions(A, 4))) == 10
326
- assert len(list(multiset_partitions(A, 3))) == 25
327
-
328
- assert list(multiset_partitions([1, 1, 1, 2, 2], 2)) == [
329
- [[1, 1, 1, 2], [2]], [[1, 1, 1], [2, 2]], [[1, 1, 2, 2], [1]],
330
- [[1, 1, 2], [1, 2]], [[1, 1], [1, 2, 2]]]
331
-
332
- assert list(multiset_partitions([1, 1, 2, 2], 2)) == [
333
- [[1, 1, 2], [2]], [[1, 1], [2, 2]], [[1, 2, 2], [1]],
334
- [[1, 2], [1, 2]]]
335
-
336
- assert list(multiset_partitions([1, 2, 3, 4], 2)) == [
337
- [[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]],
338
- [[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]],
339
- [[1], [2, 3, 4]]]
340
-
341
- assert list(multiset_partitions([1, 2, 2], 2)) == [
342
- [[1, 2], [2]], [[1], [2, 2]]]
343
-
344
- assert list(multiset_partitions(3)) == [
345
- [[0, 1, 2]], [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]],
346
- [[0], [1], [2]]]
347
- assert list(multiset_partitions(3, 2)) == [
348
- [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]]]
349
- assert list(multiset_partitions([1] * 3, 2)) == [[[1], [1, 1]]]
350
- assert list(multiset_partitions([1] * 3)) == [
351
- [[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]]
352
- a = [3, 2, 1]
353
- assert list(multiset_partitions(a)) == \
354
- list(multiset_partitions(sorted(a)))
355
- assert list(multiset_partitions(a, 5)) == []
356
- assert list(multiset_partitions(a, 1)) == [[[1, 2, 3]]]
357
- assert list(multiset_partitions(a + [4], 5)) == []
358
- assert list(multiset_partitions(a + [4], 1)) == [[[1, 2, 3, 4]]]
359
- assert list(multiset_partitions(2, 5)) == []
360
- assert list(multiset_partitions(2, 1)) == [[[0, 1]]]
361
- assert list(multiset_partitions('a')) == [[['a']]]
362
- assert list(multiset_partitions('a', 2)) == []
363
- assert list(multiset_partitions('ab')) == [[['a', 'b']], [['a'], ['b']]]
364
- assert list(multiset_partitions('ab', 1)) == [[['a', 'b']]]
365
- assert list(multiset_partitions('aaa', 1)) == [['aaa']]
366
- assert list(multiset_partitions([1, 1], 1)) == [[[1, 1]]]
367
- ans = [('mpsyy',), ('mpsy', 'y'), ('mps', 'yy'), ('mps', 'y', 'y'),
368
- ('mpyy', 's'), ('mpy', 'sy'), ('mpy', 's', 'y'), ('mp', 'syy'),
369
- ('mp', 'sy', 'y'), ('mp', 's', 'yy'), ('mp', 's', 'y', 'y'),
370
- ('msyy', 'p'), ('msy', 'py'), ('msy', 'p', 'y'), ('ms', 'pyy'),
371
- ('ms', 'py', 'y'), ('ms', 'p', 'yy'), ('ms', 'p', 'y', 'y'),
372
- ('myy', 'ps'), ('myy', 'p', 's'), ('my', 'psy'), ('my', 'ps', 'y'),
373
- ('my', 'py', 's'), ('my', 'p', 'sy'), ('my', 'p', 's', 'y'),
374
- ('m', 'psyy'), ('m', 'psy', 'y'), ('m', 'ps', 'yy'),
375
- ('m', 'ps', 'y', 'y'), ('m', 'pyy', 's'), ('m', 'py', 'sy'),
376
- ('m', 'py', 's', 'y'), ('m', 'p', 'syy'),
377
- ('m', 'p', 'sy', 'y'), ('m', 'p', 's', 'yy'),
378
- ('m', 'p', 's', 'y', 'y')]
379
- assert [tuple("".join(part) for part in p)
380
- for p in multiset_partitions('sympy')] == ans
381
- factorings = [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3],
382
- [6, 2, 2], [2, 2, 2, 3]]
383
- assert [factoring_visitor(p, [2,3]) for
384
- p in multiset_partitions_taocp([3, 1])] == factorings
385
-
386
-
387
- def test_multiset_combinations():
388
- ans = ['iii', 'iim', 'iip', 'iis', 'imp', 'ims', 'ipp', 'ips',
389
- 'iss', 'mpp', 'mps', 'mss', 'pps', 'pss', 'sss']
390
- assert [''.join(i) for i in
391
- list(multiset_combinations('mississippi', 3))] == ans
392
- M = multiset('mississippi')
393
- assert [''.join(i) for i in
394
- list(multiset_combinations(M, 3))] == ans
395
- assert [''.join(i) for i in multiset_combinations(M, 30)] == []
396
- assert list(multiset_combinations([[1], [2, 3]], 2)) == [[[1], [2, 3]]]
397
- assert len(list(multiset_combinations('a', 3))) == 0
398
- assert len(list(multiset_combinations('a', 0))) == 1
399
- assert list(multiset_combinations('abc', 1)) == [['a'], ['b'], ['c']]
400
- raises(ValueError, lambda: list(multiset_combinations({0: 3, 1: -1}, 2)))
401
-
402
-
403
- def test_multiset_permutations():
404
- ans = ['abby', 'abyb', 'aybb', 'baby', 'bayb', 'bbay', 'bbya', 'byab',
405
- 'byba', 'yabb', 'ybab', 'ybba']
406
- assert [''.join(i) for i in multiset_permutations('baby')] == ans
407
- assert [''.join(i) for i in multiset_permutations(multiset('baby'))] == ans
408
- assert list(multiset_permutations([0, 0, 0], 2)) == [[0, 0]]
409
- assert list(multiset_permutations([0, 2, 1], 2)) == [
410
- [0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]
411
- assert len(list(multiset_permutations('a', 0))) == 1
412
- assert len(list(multiset_permutations('a', 3))) == 0
413
- for nul in ([], {}, ''):
414
- assert list(multiset_permutations(nul)) == [[]]
415
- assert list(multiset_permutations(nul, 0)) == [[]]
416
- # impossible requests give no result
417
- assert list(multiset_permutations(nul, 1)) == []
418
- assert list(multiset_permutations(nul, -1)) == []
419
-
420
- def test():
421
- for i in range(1, 7):
422
- print(i)
423
- for p in multiset_permutations([0, 0, 1, 0, 1], i):
424
- print(p)
425
- assert capture(lambda: test()) == dedent('''\
426
- 1
427
- [0]
428
- [1]
429
- 2
430
- [0, 0]
431
- [0, 1]
432
- [1, 0]
433
- [1, 1]
434
- 3
435
- [0, 0, 0]
436
- [0, 0, 1]
437
- [0, 1, 0]
438
- [0, 1, 1]
439
- [1, 0, 0]
440
- [1, 0, 1]
441
- [1, 1, 0]
442
- 4
443
- [0, 0, 0, 1]
444
- [0, 0, 1, 0]
445
- [0, 0, 1, 1]
446
- [0, 1, 0, 0]
447
- [0, 1, 0, 1]
448
- [0, 1, 1, 0]
449
- [1, 0, 0, 0]
450
- [1, 0, 0, 1]
451
- [1, 0, 1, 0]
452
- [1, 1, 0, 0]
453
- 5
454
- [0, 0, 0, 1, 1]
455
- [0, 0, 1, 0, 1]
456
- [0, 0, 1, 1, 0]
457
- [0, 1, 0, 0, 1]
458
- [0, 1, 0, 1, 0]
459
- [0, 1, 1, 0, 0]
460
- [1, 0, 0, 0, 1]
461
- [1, 0, 0, 1, 0]
462
- [1, 0, 1, 0, 0]
463
- [1, 1, 0, 0, 0]
464
- 6\n''')
465
- raises(ValueError, lambda: list(multiset_permutations({0: 3, 1: -1})))
466
-
467
-
468
- def test_partitions():
469
- ans = [[{}], [(0, {})]]
470
- for i in range(2):
471
- assert list(partitions(0, size=i)) == ans[i]
472
- assert list(partitions(1, 0, size=i)) == ans[i]
473
- assert list(partitions(6, 2, 2, size=i)) == ans[i]
474
- assert list(partitions(6, 2, None, size=i)) != ans[i]
475
- assert list(partitions(6, None, 2, size=i)) != ans[i]
476
- assert list(partitions(6, 2, 0, size=i)) == ans[i]
477
-
478
- assert list(partitions(6, k=2)) == [
479
- {2: 3}, {1: 2, 2: 2}, {1: 4, 2: 1}, {1: 6}]
480
-
481
- assert list(partitions(6, k=3)) == [
482
- {3: 2}, {1: 1, 2: 1, 3: 1}, {1: 3, 3: 1}, {2: 3}, {1: 2, 2: 2},
483
- {1: 4, 2: 1}, {1: 6}]
484
-
485
- assert list(partitions(8, k=4, m=3)) == [
486
- {4: 2}, {1: 1, 3: 1, 4: 1}, {2: 2, 4: 1}, {2: 1, 3: 2}] == [
487
- i for i in partitions(8, k=4, m=3) if all(k <= 4 for k in i)
488
- and sum(i.values()) <=3]
489
-
490
- assert list(partitions(S(3), m=2)) == [
491
- {3: 1}, {1: 1, 2: 1}]
492
-
493
- assert list(partitions(4, k=3)) == [
494
- {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}] == [
495
- i for i in partitions(4) if all(k <= 3 for k in i)]
496
-
497
-
498
- # Consistency check on output of _partitions and RGS_unrank.
499
- # This provides a sanity test on both routines. Also verifies that
500
- # the total number of partitions is the same in each case.
501
- # (from pkrathmann2)
502
-
503
- for n in range(2, 6):
504
- i = 0
505
- for m, q in _set_partitions(n):
506
- assert q == RGS_unrank(i, n)
507
- i += 1
508
- assert i == RGS_enum(n)
509
-
510
-
511
- def test_binary_partitions():
512
- assert [i[:] for i in binary_partitions(10)] == [[8, 2], [8, 1, 1],
513
- [4, 4, 2], [4, 4, 1, 1], [4, 2, 2, 2], [4, 2, 2, 1, 1],
514
- [4, 2, 1, 1, 1, 1], [4, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2],
515
- [2, 2, 2, 2, 1, 1], [2, 2, 2, 1, 1, 1, 1], [2, 2, 1, 1, 1, 1, 1, 1],
516
- [2, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
517
-
518
- assert len([j[:] for j in binary_partitions(16)]) == 36
519
-
520
-
521
- def test_bell_perm():
522
- assert [len(set(generate_bell(i))) for i in range(1, 7)] == [
523
- factorial(i) for i in range(1, 7)]
524
- assert list(generate_bell(3)) == [
525
- (0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)]
526
- # generate_bell and trotterjohnson are advertised to return the same
527
- # permutations; this is not technically necessary so this test could
528
- # be removed
529
- for n in range(1, 5):
530
- p = Permutation(range(n))
531
- b = generate_bell(n)
532
- for bi in b:
533
- assert bi == tuple(p.array_form)
534
- p = p.next_trotterjohnson()
535
- raises(ValueError, lambda: list(generate_bell(0))) # XXX is this consistent with other permutation algorithms?
536
-
537
-
538
- def test_involutions():
539
- lengths = [1, 2, 4, 10, 26, 76]
540
- for n, N in enumerate(lengths):
541
- i = list(generate_involutions(n + 1))
542
- assert len(i) == N
543
- assert len({Permutation(j)**2 for j in i}) == 1
544
-
545
-
546
- def test_derangements():
547
- assert len(list(generate_derangements(list(range(6))))) == 265
548
- assert ''.join(''.join(i) for i in generate_derangements('abcde')) == (
549
- 'badecbaecdbcaedbcdeabceadbdaecbdeacbdecabeacdbedacbedcacabedcadebcaebd'
550
- 'cdaebcdbeacdeabcdebaceabdcebadcedabcedbadabecdaebcdaecbdcaebdcbeadceab'
551
- 'dcebadeabcdeacbdebacdebcaeabcdeadbceadcbecabdecbadecdabecdbaedabcedacb'
552
- 'edbacedbca')
553
- assert list(generate_derangements([0, 1, 2, 3])) == [
554
- [1, 0, 3, 2], [1, 2, 3, 0], [1, 3, 0, 2], [2, 0, 3, 1],
555
- [2, 3, 0, 1], [2, 3, 1, 0], [3, 0, 1, 2], [3, 2, 0, 1], [3, 2, 1, 0]]
556
- assert list(generate_derangements([0, 1, 2, 2])) == [
557
- [2, 2, 0, 1], [2, 2, 1, 0]]
558
- assert list(generate_derangements('ba')) == [list('ab')]
559
- # multiset_derangements
560
- D = multiset_derangements
561
- assert list(D('abb')) == []
562
- assert [''.join(i) for i in D('ab')] == ['ba']
563
- assert [''.join(i) for i in D('abc')] == ['bca', 'cab']
564
- assert [''.join(i) for i in D('aabb')] == ['bbaa']
565
- assert [''.join(i) for i in D('aabbcccc')] == [
566
- 'ccccaabb', 'ccccabab', 'ccccabba', 'ccccbaab', 'ccccbaba',
567
- 'ccccbbaa']
568
- assert [''.join(i) for i in D('aabbccc')] == [
569
- 'cccabba', 'cccabab', 'cccaabb', 'ccacbba', 'ccacbab',
570
- 'ccacabb', 'cbccbaa', 'cbccaba', 'cbccaab', 'bcccbaa',
571
- 'bcccaba', 'bcccaab']
572
- assert [''.join(i) for i in D('books')] == ['kbsoo', 'ksboo',
573
- 'sbkoo', 'skboo', 'oksbo', 'oskbo', 'okbso', 'obkso', 'oskob',
574
- 'oksob', 'osbok', 'obsok']
575
- assert list(generate_derangements([[3], [2], [2], [1]])) == [
576
- [[2], [1], [3], [2]], [[2], [3], [1], [2]]]
577
-
578
-
579
- def test_necklaces():
580
- def count(n, k, f):
581
- return len(list(necklaces(n, k, f)))
582
- m = []
583
- for i in range(1, 8):
584
- m.append((
585
- i, count(i, 2, 0), count(i, 2, 1), count(i, 3, 1)))
586
- assert Matrix(m) == Matrix([
587
- [1, 2, 2, 3],
588
- [2, 3, 3, 6],
589
- [3, 4, 4, 10],
590
- [4, 6, 6, 21],
591
- [5, 8, 8, 39],
592
- [6, 14, 13, 92],
593
- [7, 20, 18, 198]])
594
-
595
-
596
- def test_bracelets():
597
- bc = list(bracelets(2, 4))
598
- assert Matrix(bc) == Matrix([
599
- [0, 0],
600
- [0, 1],
601
- [0, 2],
602
- [0, 3],
603
- [1, 1],
604
- [1, 2],
605
- [1, 3],
606
- [2, 2],
607
- [2, 3],
608
- [3, 3]
609
- ])
610
- bc = list(bracelets(4, 2))
611
- assert Matrix(bc) == Matrix([
612
- [0, 0, 0, 0],
613
- [0, 0, 0, 1],
614
- [0, 0, 1, 1],
615
- [0, 1, 0, 1],
616
- [0, 1, 1, 1],
617
- [1, 1, 1, 1]
618
- ])
619
-
620
-
621
- def test_generate_oriented_forest():
622
- assert list(generate_oriented_forest(5)) == [[0, 1, 2, 3, 4],
623
- [0, 1, 2, 3, 3], [0, 1, 2, 3, 2], [0, 1, 2, 3, 1], [0, 1, 2, 3, 0],
624
- [0, 1, 2, 2, 2], [0, 1, 2, 2, 1], [0, 1, 2, 2, 0], [0, 1, 2, 1, 2],
625
- [0, 1, 2, 1, 1], [0, 1, 2, 1, 0], [0, 1, 2, 0, 1], [0, 1, 2, 0, 0],
626
- [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 1], [0, 1, 1, 0, 0],
627
- [0, 1, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]]
628
- assert len(list(generate_oriented_forest(10))) == 1842
629
-
630
-
631
- def test_unflatten():
632
- r = list(range(10))
633
- assert unflatten(r) == list(zip(r[::2], r[1::2]))
634
- assert unflatten(r, 5) == [tuple(r[:5]), tuple(r[5:])]
635
- raises(ValueError, lambda: unflatten(list(range(10)), 3))
636
- raises(ValueError, lambda: unflatten(list(range(10)), -2))
637
-
638
-
639
- def test_common_prefix_suffix():
640
- assert common_prefix([], [1]) == []
641
- assert common_prefix(list(range(3))) == [0, 1, 2]
642
- assert common_prefix(list(range(3)), list(range(4))) == [0, 1, 2]
643
- assert common_prefix([1, 2, 3], [1, 2, 5]) == [1, 2]
644
- assert common_prefix([1, 2, 3], [1, 3, 5]) == [1]
645
-
646
- assert common_suffix([], [1]) == []
647
- assert common_suffix(list(range(3))) == [0, 1, 2]
648
- assert common_suffix(list(range(3)), list(range(3))) == [0, 1, 2]
649
- assert common_suffix(list(range(3)), list(range(4))) == []
650
- assert common_suffix([1, 2, 3], [9, 2, 3]) == [2, 3]
651
- assert common_suffix([1, 2, 3], [9, 7, 3]) == [3]
652
-
653
-
654
- def test_minlex():
655
- assert minlex([1, 2, 0]) == (0, 1, 2)
656
- assert minlex((1, 2, 0)) == (0, 1, 2)
657
- assert minlex((1, 0, 2)) == (0, 2, 1)
658
- assert minlex((1, 0, 2), directed=False) == (0, 1, 2)
659
- assert minlex('aba') == 'aab'
660
- assert minlex(('bb', 'aaa', 'c', 'a'), key=len) == ('c', 'a', 'bb', 'aaa')
661
-
662
-
663
- def test_ordered():
664
- assert list(ordered((x, y), hash, default=False)) in [[x, y], [y, x]]
665
- assert list(ordered((x, y), hash, default=False)) == \
666
- list(ordered((y, x), hash, default=False))
667
- assert list(ordered((x, y))) == [x, y]
668
-
669
- seq, keys = [[[1, 2, 1], [0, 3, 1], [1, 1, 3], [2], [1]],
670
- (lambda x: len(x), lambda x: sum(x))]
671
- assert list(ordered(seq, keys, default=False, warn=False)) == \
672
- [[1], [2], [1, 2, 1], [0, 3, 1], [1, 1, 3]]
673
- raises(ValueError, lambda:
674
- list(ordered(seq, keys, default=False, warn=True)))
675
-
676
-
677
- def test_runs():
678
- assert runs([]) == []
679
- assert runs([1]) == [[1]]
680
- assert runs([1, 1]) == [[1], [1]]
681
- assert runs([1, 1, 2]) == [[1], [1, 2]]
682
- assert runs([1, 2, 1]) == [[1, 2], [1]]
683
- assert runs([2, 1, 1]) == [[2], [1], [1]]
684
- from operator import lt
685
- assert runs([2, 1, 1], lt) == [[2, 1], [1]]
686
-
687
-
688
- def test_reshape():
689
- seq = list(range(1, 9))
690
- assert reshape(seq, [4]) == \
691
- [[1, 2, 3, 4], [5, 6, 7, 8]]
692
- assert reshape(seq, (4,)) == \
693
- [(1, 2, 3, 4), (5, 6, 7, 8)]
694
- assert reshape(seq, (2, 2)) == \
695
- [(1, 2, 3, 4), (5, 6, 7, 8)]
696
- assert reshape(seq, (2, [2])) == \
697
- [(1, 2, [3, 4]), (5, 6, [7, 8])]
698
- assert reshape(seq, ((2,), [2])) == \
699
- [((1, 2), [3, 4]), ((5, 6), [7, 8])]
700
- assert reshape(seq, (1, [2], 1)) == \
701
- [(1, [2, 3], 4), (5, [6, 7], 8)]
702
- assert reshape(tuple(seq), ([[1], 1, (2,)],)) == \
703
- (([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],))
704
- assert reshape(tuple(seq), ([1], 1, (2,))) == \
705
- (([1], 2, (3, 4)), ([5], 6, (7, 8)))
706
- assert reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)]) == \
707
- [[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]]
708
- raises(ValueError, lambda: reshape([0, 1], [-1]))
709
- raises(ValueError, lambda: reshape([0, 1], [3]))
710
-
711
-
712
- def test_uniq():
713
- assert list(uniq(p for p in partitions(4))) == \
714
- [{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}]
715
- assert list(uniq(x % 2 for x in range(5))) == [0, 1]
716
- assert list(uniq('a')) == ['a']
717
- assert list(uniq('ababc')) == list('abc')
718
- assert list(uniq([[1], [2, 1], [1]])) == [[1], [2, 1]]
719
- assert list(uniq(permutations(i for i in [[1], 2, 2]))) == \
720
- [([1], 2, 2), (2, [1], 2), (2, 2, [1])]
721
- assert list(uniq([2, 3, 2, 4, [2], [1], [2], [3], [1]])) == \
722
- [2, 3, 4, [2], [1], [3]]
723
- f = [1]
724
- raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
725
- f = [[1]]
726
- raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
727
-
728
-
729
- def test_kbins():
730
- assert len(list(kbins('1123', 2, ordered=1))) == 24
731
- assert len(list(kbins('1123', 2, ordered=11))) == 36
732
- assert len(list(kbins('1123', 2, ordered=10))) == 10
733
- assert len(list(kbins('1123', 2, ordered=0))) == 5
734
- assert len(list(kbins('1123', 2, ordered=None))) == 3
735
-
736
- def test1():
737
- for orderedval in [None, 0, 1, 10, 11]:
738
- print('ordered =', orderedval)
739
- for p in kbins([0, 0, 1], 2, ordered=orderedval):
740
- print(' ', p)
741
- assert capture(lambda : test1()) == dedent('''\
742
- ordered = None
743
- [[0], [0, 1]]
744
- [[0, 0], [1]]
745
- ordered = 0
746
- [[0, 0], [1]]
747
- [[0, 1], [0]]
748
- ordered = 1
749
- [[0], [0, 1]]
750
- [[0], [1, 0]]
751
- [[1], [0, 0]]
752
- ordered = 10
753
- [[0, 0], [1]]
754
- [[1], [0, 0]]
755
- [[0, 1], [0]]
756
- [[0], [0, 1]]
757
- ordered = 11
758
- [[0], [0, 1]]
759
- [[0, 0], [1]]
760
- [[0], [1, 0]]
761
- [[0, 1], [0]]
762
- [[1], [0, 0]]
763
- [[1, 0], [0]]\n''')
764
-
765
- def test2():
766
- for orderedval in [None, 0, 1, 10, 11]:
767
- print('ordered =', orderedval)
768
- for p in kbins(list(range(3)), 2, ordered=orderedval):
769
- print(' ', p)
770
- assert capture(lambda : test2()) == dedent('''\
771
- ordered = None
772
- [[0], [1, 2]]
773
- [[0, 1], [2]]
774
- ordered = 0
775
- [[0, 1], [2]]
776
- [[0, 2], [1]]
777
- [[0], [1, 2]]
778
- ordered = 1
779
- [[0], [1, 2]]
780
- [[0], [2, 1]]
781
- [[1], [0, 2]]
782
- [[1], [2, 0]]
783
- [[2], [0, 1]]
784
- [[2], [1, 0]]
785
- ordered = 10
786
- [[0, 1], [2]]
787
- [[2], [0, 1]]
788
- [[0, 2], [1]]
789
- [[1], [0, 2]]
790
- [[0], [1, 2]]
791
- [[1, 2], [0]]
792
- ordered = 11
793
- [[0], [1, 2]]
794
- [[0, 1], [2]]
795
- [[0], [2, 1]]
796
- [[0, 2], [1]]
797
- [[1], [0, 2]]
798
- [[1, 0], [2]]
799
- [[1], [2, 0]]
800
- [[1, 2], [0]]
801
- [[2], [0, 1]]
802
- [[2, 0], [1]]
803
- [[2], [1, 0]]
804
- [[2, 1], [0]]\n''')
805
-
806
-
807
- def test_has_dups():
808
- assert has_dups(set()) is False
809
- assert has_dups(list(range(3))) is False
810
- assert has_dups([1, 2, 1]) is True
811
- assert has_dups([[1], [1]]) is True
812
- assert has_dups([[1], [2]]) is False
813
-
814
-
815
- def test__partition():
816
- assert _partition('abcde', [1, 0, 1, 2, 0]) == [
817
- ['b', 'e'], ['a', 'c'], ['d']]
818
- assert _partition('abcde', [1, 0, 1, 2, 0], 3) == [
819
- ['b', 'e'], ['a', 'c'], ['d']]
820
- output = (3, [1, 0, 1, 2, 0])
821
- assert _partition('abcde', *output) == [['b', 'e'], ['a', 'c'], ['d']]
822
-
823
-
824
- def test_ordered_partitions():
825
- from sympy.functions.combinatorial.numbers import nT
826
- f = ordered_partitions
827
- assert list(f(0, 1)) == [[]]
828
- assert list(f(1, 0)) == [[]]
829
- for i in range(1, 7):
830
- for j in [None] + list(range(1, i)):
831
- assert (
832
- sum(1 for p in f(i, j, 1)) ==
833
- sum(1 for p in f(i, j, 0)) ==
834
- nT(i, j))
835
-
836
-
837
- def test_rotations():
838
- assert list(rotations('ab')) == [['a', 'b'], ['b', 'a']]
839
- assert list(rotations(range(3))) == [[0, 1, 2], [1, 2, 0], [2, 0, 1]]
840
- assert list(rotations(range(3), dir=-1)) == [[0, 1, 2], [2, 0, 1], [1, 2, 0]]
841
-
842
-
843
- def test_ibin():
844
- assert ibin(3) == [1, 1]
845
- assert ibin(3, 3) == [0, 1, 1]
846
- assert ibin(3, str=True) == '11'
847
- assert ibin(3, 3, str=True) == '011'
848
- assert list(ibin(2, 'all')) == [(0, 0), (0, 1), (1, 0), (1, 1)]
849
- assert list(ibin(2, '', str=True)) == ['00', '01', '10', '11']
850
- raises(ValueError, lambda: ibin(-.5))
851
- raises(ValueError, lambda: ibin(2, 1))
852
-
853
-
854
- def test_iterable():
855
- assert iterable(0) is False
856
- assert iterable(1) is False
857
- assert iterable(None) is False
858
-
859
- class Test1(NotIterable):
860
- pass
861
-
862
- assert iterable(Test1()) is False
863
-
864
- class Test2(NotIterable):
865
- _iterable = True
866
-
867
- assert iterable(Test2()) is True
868
-
869
- class Test3:
870
- pass
871
-
872
- assert iterable(Test3()) is False
873
-
874
- class Test4:
875
- _iterable = True
876
-
877
- assert iterable(Test4()) is True
878
-
879
- class Test5:
880
- def __iter__(self):
881
- yield 1
882
-
883
- assert iterable(Test5()) is True
884
-
885
- class Test6(Test5):
886
- _iterable = False
887
-
888
- assert iterable(Test6()) is False
889
-
890
-
891
- def test_sequence_partitions():
892
- assert list(sequence_partitions([1], 1)) == [[[1]]]
893
- assert list(sequence_partitions([1, 2], 1)) == [[[1, 2]]]
894
- assert list(sequence_partitions([1, 2], 2)) == [[[1], [2]]]
895
- assert list(sequence_partitions([1, 2, 3], 1)) == [[[1, 2, 3]]]
896
- assert list(sequence_partitions([1, 2, 3], 2)) == \
897
- [[[1], [2, 3]], [[1, 2], [3]]]
898
- assert list(sequence_partitions([1, 2, 3], 3)) == [[[1], [2], [3]]]
899
-
900
- # Exceptional cases
901
- assert list(sequence_partitions([], 0)) == []
902
- assert list(sequence_partitions([], 1)) == []
903
- assert list(sequence_partitions([1, 2], 0)) == []
904
- assert list(sequence_partitions([1, 2], 3)) == []
905
-
906
-
907
- def test_sequence_partitions_empty():
908
- assert list(sequence_partitions_empty([], 1)) == [[[]]]
909
- assert list(sequence_partitions_empty([], 2)) == [[[], []]]
910
- assert list(sequence_partitions_empty([], 3)) == [[[], [], []]]
911
- assert list(sequence_partitions_empty([1], 1)) == [[[1]]]
912
- assert list(sequence_partitions_empty([1], 2)) == [[[], [1]], [[1], []]]
913
- assert list(sequence_partitions_empty([1], 3)) == \
914
- [[[], [], [1]], [[], [1], []], [[1], [], []]]
915
- assert list(sequence_partitions_empty([1, 2], 1)) == [[[1, 2]]]
916
- assert list(sequence_partitions_empty([1, 2], 2)) == \
917
- [[[], [1, 2]], [[1], [2]], [[1, 2], []]]
918
- assert list(sequence_partitions_empty([1, 2], 3)) == [
919
- [[], [], [1, 2]], [[], [1], [2]], [[], [1, 2], []],
920
- [[1], [], [2]], [[1], [2], []], [[1, 2], [], []]
921
- ]
922
- assert list(sequence_partitions_empty([1, 2, 3], 1)) == [[[1, 2, 3]]]
923
- assert list(sequence_partitions_empty([1, 2, 3], 2)) == \
924
- [[[], [1, 2, 3]], [[1], [2, 3]], [[1, 2], [3]], [[1, 2, 3], []]]
925
- assert list(sequence_partitions_empty([1, 2, 3], 3)) == [
926
- [[], [], [1, 2, 3]], [[], [1], [2, 3]],
927
- [[], [1, 2], [3]], [[], [1, 2, 3], []],
928
- [[1], [], [2, 3]], [[1], [2], [3]],
929
- [[1], [2, 3], []], [[1, 2], [], [3]],
930
- [[1, 2], [3], []], [[1, 2, 3], [], []]
931
- ]
932
-
933
- # Exceptional cases
934
- assert list(sequence_partitions([], 0)) == []
935
- assert list(sequence_partitions([1], 0)) == []
936
- assert list(sequence_partitions([1, 2], 0)) == []
937
-
938
-
939
- def test_signed_permutations():
940
- ans = [(0, 1, 1), (0, -1, 1), (0, 1, -1), (0, -1, -1),
941
- (1, 0, 1), (-1, 0, 1), (1, 0, -1), (-1, 0, -1),
942
- (1, 1, 0), (-1, 1, 0), (1, -1, 0), (-1, -1, 0)]
943
- assert list(signed_permutations((0, 1, 1))) == ans
944
- assert list(signed_permutations((1, 0, 1))) == ans
945
- assert list(signed_permutations((1, 1, 0))) == ans
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_lambdify.py DELETED
@@ -1,2263 +0,0 @@
1
- from itertools import product
2
- import math
3
- import inspect
4
- import linecache
5
- import gc
6
-
7
- import mpmath
8
- import cmath
9
-
10
- from sympy.testing.pytest import raises, warns_deprecated_sympy
11
- from sympy.concrete.summations import Sum
12
- from sympy.core.function import (Function, Lambda, diff)
13
- from sympy.core.numbers import (E, Float, I, Rational, all_close, oo, pi)
14
- from sympy.core.relational import Eq
15
- from sympy.core.singleton import S
16
- from sympy.core.symbol import (Dummy, symbols)
17
- from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial)
18
- from sympy.functions.combinatorial.numbers import bernoulli, harmonic
19
- from sympy.functions.elementary.complexes import Abs, sign
20
- from sympy.functions.elementary.exponential import exp, log
21
- from sympy.functions.elementary.hyperbolic import asinh,acosh,atanh
22
- from sympy.functions.elementary.integers import floor
23
- from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt)
24
- from sympy.functions.elementary.piecewise import Piecewise
25
- from sympy.functions.elementary.trigonometric import (asin, acos, atan, cos, cot, sin,
26
- sinc, tan)
27
- from sympy.functions import sinh,cosh,tanh
28
- from sympy.functions.special.bessel import (besseli, besselj, besselk, bessely, jn, yn)
29
- from sympy.functions.special.beta_functions import (beta, betainc, betainc_regularized)
30
- from sympy.functions.special.delta_functions import (Heaviside)
31
- from sympy.functions.special.error_functions import (Ei, erf, erfc, fresnelc, fresnels, Si, Ci)
32
- from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma, polygamma)
33
- from sympy.functions.special.zeta_functions import zeta
34
- from sympy.integrals.integrals import Integral
35
- from sympy.logic.boolalg import (And, false, ITE, Not, Or, true)
36
- from sympy.matrices.expressions.dotproduct import DotProduct
37
- from sympy.simplify.cse_main import cse
38
- from sympy.tensor.array import derive_by_array, Array
39
- from sympy.tensor.array.expressions import ArraySymbol
40
- from sympy.tensor.indexed import IndexedBase, Idx
41
- from sympy.utilities.lambdify import lambdify
42
- from sympy.utilities.iterables import numbered_symbols
43
- from sympy.vector import CoordSys3D
44
- from sympy.core.expr import UnevaluatedExpr
45
- from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, log10, hypot, isnan, isinf
46
- from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, amin, amax, minimum, maximum
47
- from sympy.codegen.scipy_nodes import cosm1, powm1
48
- from sympy.functions.elementary.complexes import re, im, arg
49
- from sympy.functions.special.polynomials import \
50
- chebyshevt, chebyshevu, legendre, hermite, laguerre, gegenbauer, \
51
- assoc_legendre, assoc_laguerre, jacobi
52
- from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix
53
- from sympy.printing.codeprinter import PrintMethodNotImplementedError
54
- from sympy.printing.lambdarepr import LambdaPrinter
55
- from sympy.printing.numpy import NumPyPrinter
56
- from sympy.utilities.lambdify import implemented_function, lambdastr
57
- from sympy.testing.pytest import skip
58
- from sympy.utilities.decorator import conserve_mpmath_dps
59
- from sympy.utilities.exceptions import ignore_warnings
60
- from sympy.external import import_module
61
- from sympy.functions.special.gamma_functions import uppergamma, lowergamma
62
-
63
-
64
- import sympy
65
-
66
-
67
- MutableDenseMatrix = Matrix
68
-
69
- numpy = import_module('numpy')
70
- scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']})
71
- numexpr = import_module('numexpr')
72
- tensorflow = import_module('tensorflow')
73
- cupy = import_module('cupy')
74
- jax = import_module('jax')
75
- numba = import_module('numba')
76
-
77
- if tensorflow:
78
- # Hide Tensorflow warnings
79
- import os
80
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
81
-
82
- w, x, y, z = symbols('w,x,y,z')
83
-
84
- #================== Test different arguments =======================
85
-
86
-
87
- def test_no_args():
88
- f = lambdify([], 1)
89
- raises(TypeError, lambda: f(-1))
90
- assert f() == 1
91
-
92
-
93
- def test_single_arg():
94
- f = lambdify(x, 2*x)
95
- assert f(1) == 2
96
-
97
-
98
- def test_list_args():
99
- f = lambdify([x, y], x + y)
100
- assert f(1, 2) == 3
101
-
102
-
103
- def test_nested_args():
104
- f1 = lambdify([[w, x]], [w, x])
105
- assert f1([91, 2]) == [91, 2]
106
- raises(TypeError, lambda: f1(1, 2))
107
-
108
- f2 = lambdify([(w, x), (y, z)], [w, x, y, z])
109
- assert f2((18, 12), (73, 4)) == [18, 12, 73, 4]
110
- raises(TypeError, lambda: f2(3, 4))
111
-
112
- f3 = lambdify([w, [[[x]], y], z], [w, x, y, z])
113
- assert f3(10, [[[52]], 31], 44) == [10, 52, 31, 44]
114
-
115
-
116
- def test_str_args():
117
- f = lambdify('x,y,z', 'z,y,x')
118
- assert f(3, 2, 1) == (1, 2, 3)
119
- assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
120
- # make sure correct number of args required
121
- raises(TypeError, lambda: f(0))
122
-
123
-
124
- def test_own_namespace_1():
125
- myfunc = lambda x: 1
126
- f = lambdify(x, sin(x), {"sin": myfunc})
127
- assert f(0.1) == 1
128
- assert f(100) == 1
129
-
130
-
131
- def test_own_namespace_2():
132
- def myfunc(x):
133
- return 1
134
- f = lambdify(x, sin(x), {'sin': myfunc})
135
- assert f(0.1) == 1
136
- assert f(100) == 1
137
-
138
-
139
- def test_own_module():
140
- f = lambdify(x, sin(x), math)
141
- assert f(0) == 0.0
142
-
143
- p, q, r = symbols("p q r", real=True)
144
- ae = abs(exp(p+UnevaluatedExpr(q+r)))
145
- f = lambdify([p, q, r], [ae, ae], modules=math)
146
- results = f(1.0, 1e18, -1e18)
147
- refvals = [math.exp(1.0)]*2
148
- for res, ref in zip(results, refvals):
149
- assert abs((res-ref)/ref) < 1e-15
150
-
151
-
152
- def test_bad_args():
153
- # no vargs given
154
- raises(TypeError, lambda: lambdify(1))
155
- # same with vector exprs
156
- raises(TypeError, lambda: lambdify([1, 2]))
157
-
158
-
159
- def test_atoms():
160
- # Non-Symbol atoms should not be pulled out from the expression namespace
161
- f = lambdify(x, pi + x, {"pi": 3.14})
162
- assert f(0) == 3.14
163
- f = lambdify(x, I + x, {"I": 1j})
164
- assert f(1) == 1 + 1j
165
-
166
- #================== Test different modules =========================
167
-
168
- # high precision output of sin(0.2*pi) is used to detect if precision is lost unwanted
169
-
170
-
171
- @conserve_mpmath_dps
172
- def test_sympy_lambda():
173
- mpmath.mp.dps = 50
174
- sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
175
- f = lambdify(x, sin(x), "sympy")
176
- assert f(x) == sin(x)
177
- prec = 1e-15
178
- assert -prec < f(Rational(1, 5)).evalf() - Float(str(sin02)) < prec
179
- # arctan is in numpy module and should not be available
180
- # The arctan below gives NameError. What is this supposed to test?
181
- # raises(NameError, lambda: lambdify(x, arctan(x), "sympy"))
182
-
183
-
184
- @conserve_mpmath_dps
185
- def test_math_lambda():
186
- mpmath.mp.dps = 50
187
- sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
188
- f = lambdify(x, sin(x), "math")
189
- prec = 1e-15
190
- assert -prec < f(0.2) - sin02 < prec
191
- raises(TypeError, lambda: f(x))
192
- # if this succeeds, it can't be a Python math function
193
-
194
-
195
- @conserve_mpmath_dps
196
- def test_mpmath_lambda():
197
- mpmath.mp.dps = 50
198
- sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
199
- f = lambdify(x, sin(x), "mpmath")
200
- prec = 1e-49 # mpmath precision is around 50 decimal places
201
- assert -prec < f(mpmath.mpf("0.2")) - sin02 < prec
202
- raises(TypeError, lambda: f(x))
203
- # if this succeeds, it can't be a mpmath function
204
-
205
- ref2 = (mpmath.mpf("1e-30")
206
- - mpmath.mpf("1e-45")/2
207
- + 5*mpmath.mpf("1e-60")/6
208
- - 3*mpmath.mpf("1e-75")/4
209
- + 33*mpmath.mpf("1e-90")/40
210
- )
211
- f2a = lambdify((x, y), x**y - 1, "mpmath")
212
- f2b = lambdify((x, y), powm1(x, y), "mpmath")
213
- f2c = lambdify((x,), expm1(x*log1p(x)), "mpmath")
214
- ans2a = f2a(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15"))
215
- ans2b = f2b(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15"))
216
- ans2c = f2c(mpmath.mpf("1e-15"))
217
- assert abs(ans2a - ref2) < 1e-51
218
- assert abs(ans2b - ref2) < 1e-67
219
- assert abs(ans2c - ref2) < 1e-80
220
-
221
-
222
- @conserve_mpmath_dps
223
- def test_number_precision():
224
- mpmath.mp.dps = 50
225
- sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
226
- f = lambdify(x, sin02, "mpmath")
227
- prec = 1e-49 # mpmath precision is around 50 decimal places
228
- assert -prec < f(0) - sin02 < prec
229
-
230
- @conserve_mpmath_dps
231
- def test_mpmath_precision():
232
- mpmath.mp.dps = 100
233
- assert str(lambdify((), pi.evalf(100), 'mpmath')()) == str(pi.evalf(100))
234
-
235
- #================== Test Translations ==============================
236
- # We can only check if all translated functions are valid. It has to be checked
237
- # by hand if they are complete.
238
-
239
-
240
- def test_math_transl():
241
- from sympy.utilities.lambdify import MATH_TRANSLATIONS
242
- for sym, mat in MATH_TRANSLATIONS.items():
243
- assert sym in sympy.__dict__
244
- assert mat in math.__dict__
245
-
246
-
247
- def test_mpmath_transl():
248
- from sympy.utilities.lambdify import MPMATH_TRANSLATIONS
249
- for sym, mat in MPMATH_TRANSLATIONS.items():
250
- assert sym in sympy.__dict__ or sym == 'Matrix'
251
- assert mat in mpmath.__dict__
252
-
253
-
254
- def test_numpy_transl():
255
- if not numpy:
256
- skip("numpy not installed.")
257
-
258
- from sympy.utilities.lambdify import NUMPY_TRANSLATIONS
259
- for sym, nump in NUMPY_TRANSLATIONS.items():
260
- assert sym in sympy.__dict__
261
- assert nump in numpy.__dict__
262
-
263
-
264
- def test_scipy_transl():
265
- if not scipy:
266
- skip("scipy not installed.")
267
-
268
- from sympy.utilities.lambdify import SCIPY_TRANSLATIONS
269
- for sym, scip in SCIPY_TRANSLATIONS.items():
270
- assert sym in sympy.__dict__
271
- assert scip in scipy.__dict__ or scip in scipy.special.__dict__
272
-
273
-
274
- def test_numpy_translation_abs():
275
- if not numpy:
276
- skip("numpy not installed.")
277
-
278
- f = lambdify(x, Abs(x), "numpy")
279
- assert f(-1) == 1
280
- assert f(1) == 1
281
-
282
-
283
- def test_numexpr_printer():
284
- if not numexpr:
285
- skip("numexpr not installed.")
286
-
287
- # if translation/printing is done incorrectly then evaluating
288
- # a lambdified numexpr expression will throw an exception
289
- from sympy.printing.lambdarepr import NumExprPrinter
290
-
291
- blacklist = ('where', 'complex', 'contains')
292
- arg_tuple = (x, y, z) # some functions take more than one argument
293
- for sym in NumExprPrinter._numexpr_functions.keys():
294
- if sym in blacklist:
295
- continue
296
- ssym = S(sym)
297
- if hasattr(ssym, '_nargs'):
298
- nargs = ssym._nargs[0]
299
- else:
300
- nargs = 1
301
- args = arg_tuple[:nargs]
302
- f = lambdify(args, ssym(*args), modules='numexpr')
303
- assert f(*(1, )*nargs) is not None
304
-
305
-
306
- def test_cmath_sqrt():
307
- f = lambdify(x, sqrt(x), "cmath")
308
- assert f(0) == 0
309
- assert f(1) == 1
310
- assert f(4) == 2
311
- assert abs(f(2) - 1.414) < 0.001
312
- assert f(-1) == 1j
313
- assert f(-4) == 2j
314
-
315
-
316
- def test_cmath_log():
317
- f = lambdify(x, log(x), "cmath")
318
- assert abs(f(1) - 0) < 1e-15
319
- assert abs(f(cmath.e) - 1) < 1e-15
320
- assert abs(f(-1) - cmath.log(-1)) < 1e-15
321
-
322
-
323
- def test_cmath_sinh():
324
- f = lambdify(x, sinh(x), "cmath")
325
- assert abs(f(0) - cmath.sinh(0)) < 1e-15
326
- assert abs(f(pi) - cmath.sinh(pi)) < 1e-15
327
- assert abs(f(-pi) - cmath.sinh(-pi)) < 1e-15
328
- assert abs(f(1j) - cmath.sinh(1j)) < 1e-15
329
-
330
-
331
- def test_cmath_cosh():
332
- f = lambdify(x, cosh(x), "cmath")
333
- assert abs(f(0) - cmath.cosh(0)) < 1e-15
334
- assert abs(f(pi) - cmath.cosh(pi)) < 1e-15
335
- assert abs(f(-pi) - cmath.cosh(-pi)) < 1e-15
336
- assert abs(f(1j) - cmath.cosh(1j)) < 1e-15
337
-
338
-
339
- def test_cmath_tanh():
340
- f = lambdify(x, tanh(x), "cmath")
341
- assert abs(f(0) - cmath.tanh(0)) < 1e-15
342
- assert abs(f(pi) - cmath.tanh(pi)) < 1e-15
343
- assert abs(f(-pi) - cmath.tanh(-pi)) < 1e-15
344
- assert abs(f(1j) - cmath.tanh(1j)) < 1e-15
345
-
346
-
347
- def test_cmath_sin():
348
- f = lambdify(x, sin(x), "cmath")
349
- assert abs(f(0) - cmath.sin(0)) < 1e-15
350
- assert abs(f(pi) - cmath.sin(pi)) < 1e-15
351
- assert abs(f(-pi) - cmath.sin(-pi)) < 1e-15
352
- assert abs(f(1j) - cmath.sin(1j)) < 1e-15
353
-
354
-
355
- def test_cmath_cos():
356
- f = lambdify(x, cos(x), "cmath")
357
- assert abs(f(0) - cmath.cos(0)) < 1e-15
358
- assert abs(f(pi) - cmath.cos(pi)) < 1e-15
359
- assert abs(f(-pi) - cmath.cos(-pi)) < 1e-15
360
- assert abs(f(1j) - cmath.cos(1j)) < 1e-15
361
-
362
-
363
- def test_cmath_tan():
364
- f = lambdify(x, tan(x), "cmath")
365
- assert abs(f(0) - cmath.tan(0)) < 1e-15
366
- assert abs(f(1j) - cmath.tan(1j)) < 1e-15
367
-
368
-
369
- def test_cmath_asin():
370
- f = lambdify(x, asin(x), "cmath")
371
- assert abs(f(0) - cmath.asin(0)) < 1e-15
372
- assert abs(f(1) - cmath.asin(1)) < 1e-15
373
- assert abs(f(-1) - cmath.asin(-1)) < 1e-15
374
- assert abs(f(2) - cmath.asin(2)) < 1e-15
375
- assert abs(f(1j) - cmath.asin(1j)) < 1e-15
376
-
377
-
378
- def test_cmath_acos():
379
- f = lambdify(x, acos(x), "cmath")
380
- assert abs(f(1) - cmath.acos(1)) < 1e-15
381
- assert abs(f(-1) - cmath.acos(-1)) < 1e-15
382
- assert abs(f(2) - cmath.acos(2)) < 1e-15
383
- assert abs(f(1j) - cmath.acos(1j)) < 1e-15
384
-
385
-
386
- def test_cmath_atan():
387
- f = lambdify(x, atan(x), "cmath")
388
- assert abs(f(0) - cmath.atan(0)) < 1e-15
389
- assert abs(f(1) - cmath.atan(1)) < 1e-15
390
- assert abs(f(-1) - cmath.atan(-1)) < 1e-15
391
- assert abs(f(2) - cmath.atan(2)) < 1e-15
392
- assert abs(f(2j) - cmath.atan(2j)) < 1e-15
393
-
394
-
395
- def test_cmath_asinh():
396
- f = lambdify(x, asinh(x), "cmath")
397
- assert abs(f(0) - cmath.asinh(0)) < 1e-15
398
- assert abs(f(1) - cmath.asinh(1)) < 1e-15
399
- assert abs(f(-1) - cmath.asinh(-1)) < 1e-15
400
- assert abs(f(2) - cmath.asinh(2)) < 1e-15
401
- assert abs(f(2j) - cmath.asinh(2j)) < 1e-15
402
-
403
-
404
- def test_cmath_acosh():
405
- f = lambdify(x, acosh(x), "cmath")
406
- assert abs(f(1) - cmath.acosh(1)) < 1e-15
407
- assert abs(f(2) - cmath.acosh(2)) < 1e-15
408
- assert abs(f(-1) - cmath.acosh(-1)) < 1e-15
409
- assert abs(f(2j) - cmath.acosh(2j)) < 1e-15
410
-
411
-
412
- def test_cmath_atanh():
413
- f = lambdify(x, atanh(x), "cmath")
414
- assert abs(f(0) - cmath.atanh(0)) < 1e-15
415
- assert abs(f(0.5) - cmath.atanh(0.5)) < 1e-15
416
- assert abs(f(-0.5) - cmath.atanh(-0.5)) < 1e-15
417
- assert abs(f(2) - cmath.atanh(2)) < 1e-15
418
- assert abs(f(-2) - cmath.atanh(-2)) < 1e-15
419
- assert abs(f(2j) - cmath.atanh(2j)) < 1e-15
420
-
421
-
422
- def test_cmath_complex_identities():
423
- # Define symbol
424
- z = symbols('z')
425
-
426
- # Trigonometric identity using re(z) and im(z)
427
- expr = cos(z) - cos(re(z)) * cosh(im(z)) + I * sin(re(z)) * sinh(im(z))
428
- func = lambdify([z], expr, modules=["cmath", "math"])
429
- hpi = math.pi / 2
430
- assert abs(func(hpi + 1j * hpi)) < 4e-16
431
-
432
- # Euler's Formula: e^(i*z) = cos(z) + i*sin(z)
433
- func = lambdify([z], exp(I * z) - (cos(z) + I * sin(z)), modules=["cmath", "math"])
434
- assert abs(func(hpi)) < 4e-16
435
-
436
- # Exponential Identity: e^z = e^(Re(z)) * (cos(Im(z)) + i*sin(Im(z)))
437
- func_exp = lambdify([z], exp(z) - exp(re(z)) * (cos(im(z)) + I * sin(im(z))),
438
- modules=["cmath", "math"])
439
- assert abs(func_exp(hpi + 1j * hpi)) < 4e-16
440
-
441
- # Complex Cosine Identity: cos(z) = cos(Re(z)) * cosh(Im(z)) - i*sin(Re(z)) * sinh(Im(z))
442
- func_cos = lambdify([z], cos(z) - (cos(re(z)) * cosh(im(z)) - I * sin(re(z)) * sinh(im(z))),
443
- modules=["cmath", "math"])
444
- assert abs(func_cos(hpi + 1j * hpi)) < 4e-16
445
-
446
- # Complex Sine Identity: sin(z) = sin(Re(z)) * cosh(Im(z)) + i*cos(Re(z)) * sinh(Im(z))
447
- func_sin = lambdify([z], sin(z) - (sin(re(z)) * cosh(im(z)) + I * cos(re(z)) * sinh(im(z))),
448
- modules=["cmath", "math"])
449
- assert abs(func_sin(hpi + 1j * hpi)) < 4e-16
450
-
451
- # Complex Hyperbolic Cosine Identity: cosh(z) = cosh(Re(z)) * cos(Im(z)) + i*sinh(Re(z)) * sin(Im(z))
452
- func_cosh_1 = lambdify([z], cosh(z) - (cosh(re(z)) * cos(im(z)) + I * sinh(re(z)) * sin(im(z))),
453
- modules=["cmath", "math"])
454
- assert abs(func_cosh_1(hpi + 1j * hpi)) < 4e-16
455
-
456
- # Complex Hyperbolic Sine Identity: sinh(z) = sinh(Re(z)) * cos(Im(z)) + i*cosh(Re(z)) * sin(Im(z))
457
- func_sinh = lambdify([z], sinh(z) - (sinh(re(z)) * cos(im(z)) + I * cosh(re(z)) * sin(im(z))),
458
- modules=["cmath", "math"])
459
- assert abs(func_sinh(hpi + 1j * hpi)) < 4e-16
460
-
461
- # cosh(z) = (e^z + e^(-z)) / 2
462
- func_cosh_2 = lambdify([z], cosh(z) - (exp(z) + exp(-z)) / 2, modules=["cmath", "math"])
463
- assert abs(func_cosh_2(hpi)) < 4e-16
464
-
465
- # Additional expressions testing log and exp with real and imaginary parts
466
- expr1 = log(re(z)) + log(im(z)) - log(re(z) * im(z))
467
- expr2 = exp(re(z)) * exp(im(z) * I) - exp(z)
468
- expr3 = log(exp(re(z))) - re(z)
469
- expr4 = exp(log(re(z))) - re(z)
470
- expr5 = log(exp(re(z) + im(z))) - (re(z) + im(z))
471
- expr6 = exp(log(re(z) + im(z))) - (re(z) + im(z))
472
- func1 = lambdify([z], expr1, modules=["cmath", "math"])
473
- func2 = lambdify([z], expr2, modules=["cmath", "math"])
474
- func3 = lambdify([z], expr3, modules=["cmath", "math"])
475
- func4 = lambdify([z], expr4, modules=["cmath", "math"])
476
- func5 = lambdify([z], expr5, modules=["cmath", "math"])
477
- func6 = lambdify([z], expr6, modules=["cmath", "math"])
478
- test_value = 3 + 4j
479
- assert abs(func1(test_value)) < 4e-16
480
- assert abs(func2(test_value)) < 4e-16
481
- assert abs(func3(test_value)) < 4e-16
482
- assert abs(func4(test_value)) < 4e-16
483
- assert abs(func5(test_value)) < 4e-16
484
- assert abs(func6(test_value)) < 4e-16
485
-
486
-
487
- def test_issue_9334():
488
- if not numexpr:
489
- skip("numexpr not installed.")
490
- if not numpy:
491
- skip("numpy not installed.")
492
- expr = S('b*a - sqrt(a**2)')
493
- a, b = sorted(expr.free_symbols, key=lambda s: s.name)
494
- func_numexpr = lambdify((a,b), expr, modules=[numexpr], dummify=False)
495
- foo, bar = numpy.random.random((2, 4))
496
- func_numexpr(foo, bar)
497
-
498
-
499
- def test_issue_12984():
500
- if not numexpr:
501
- skip("numexpr not installed.")
502
- func_numexpr = lambdify((x,y,z), Piecewise((y, x >= 0), (z, x > -1)), numexpr)
503
- with ignore_warnings(RuntimeWarning):
504
- assert func_numexpr(1, 24, 42) == 24
505
- assert str(func_numexpr(-1, 24, 42)) == 'nan'
506
-
507
-
508
- def test_empty_modules():
509
- x, y = symbols('x y')
510
- expr = -(x % y)
511
-
512
- no_modules = lambdify([x, y], expr)
513
- empty_modules = lambdify([x, y], expr, modules=[])
514
- assert no_modules(3, 7) == empty_modules(3, 7)
515
- assert no_modules(3, 7) == -3
516
-
517
-
518
- def test_exponentiation():
519
- f = lambdify(x, x**2)
520
- assert f(-1) == 1
521
- assert f(0) == 0
522
- assert f(1) == 1
523
- assert f(-2) == 4
524
- assert f(2) == 4
525
- assert f(2.5) == 6.25
526
-
527
-
528
- def test_sqrt():
529
- f = lambdify(x, sqrt(x))
530
- assert f(0) == 0.0
531
- assert f(1) == 1.0
532
- assert f(4) == 2.0
533
- assert abs(f(2) - 1.414) < 0.001
534
- assert f(6.25) == 2.5
535
-
536
-
537
- def test_trig():
538
- f = lambdify([x], [cos(x), sin(x)], 'math')
539
- d = f(pi)
540
- prec = 1e-11
541
- assert -prec < d[0] + 1 < prec
542
- assert -prec < d[1] < prec
543
- d = f(3.14159)
544
- prec = 1e-5
545
- assert -prec < d[0] + 1 < prec
546
- assert -prec < d[1] < prec
547
-
548
-
549
- def test_integral():
550
- if numpy and not scipy:
551
- skip("scipy not installed.")
552
- f = Lambda(x, exp(-x**2))
553
- l = lambdify(y, Integral(f(x), (x, y, oo)))
554
- d = l(-oo)
555
- assert 1.77245385 < d < 1.772453851
556
-
557
-
558
- def test_double_integral():
559
- if numpy and not scipy:
560
- skip("scipy not installed.")
561
- # example from http://mpmath.org/doc/current/calculus/integration.html
562
- i = Integral(1/(1 - x**2*y**2), (x, 0, 1), (y, 0, z))
563
- l = lambdify([z], i)
564
- d = l(1)
565
- assert 1.23370055 < d < 1.233700551
566
-
567
- def test_spherical_bessel():
568
- if numpy and not scipy:
569
- skip("scipy not installed.")
570
- test_point = 4.2 #randomly selected
571
- x = symbols("x")
572
- jtest = jn(2, x)
573
- assert abs(lambdify(x,jtest)(test_point) -
574
- jtest.subs(x,test_point).evalf()) < 1e-8
575
- ytest = yn(2, x)
576
- assert abs(lambdify(x,ytest)(test_point) -
577
- ytest.subs(x,test_point).evalf()) < 1e-8
578
-
579
-
580
- #================== Test vectors ===================================
581
-
582
-
583
- def test_vector_simple():
584
- f = lambdify((x, y, z), (z, y, x))
585
- assert f(3, 2, 1) == (1, 2, 3)
586
- assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
587
- # make sure correct number of args required
588
- raises(TypeError, lambda: f(0))
589
-
590
-
591
- def test_vector_discontinuous():
592
- f = lambdify(x, (-1/x, 1/x))
593
- raises(ZeroDivisionError, lambda: f(0))
594
- assert f(1) == (-1.0, 1.0)
595
- assert f(2) == (-0.5, 0.5)
596
- assert f(-2) == (0.5, -0.5)
597
-
598
-
599
- def test_trig_symbolic():
600
- f = lambdify([x], [cos(x), sin(x)], 'math')
601
- d = f(pi)
602
- assert abs(d[0] + 1) < 0.0001
603
- assert abs(d[1] - 0) < 0.0001
604
-
605
-
606
- def test_trig_float():
607
- f = lambdify([x], [cos(x), sin(x)])
608
- d = f(3.14159)
609
- assert abs(d[0] + 1) < 0.0001
610
- assert abs(d[1] - 0) < 0.0001
611
-
612
-
613
- def test_docs():
614
- f = lambdify(x, x**2)
615
- assert f(2) == 4
616
- f = lambdify([x, y, z], [z, y, x])
617
- assert f(1, 2, 3) == [3, 2, 1]
618
- f = lambdify(x, sqrt(x))
619
- assert f(4) == 2.0
620
- f = lambdify((x, y), sin(x*y)**2)
621
- assert f(0, 5) == 0
622
-
623
-
624
- def test_math():
625
- f = lambdify((x, y), sin(x), modules="math")
626
- assert f(0, 5) == 0
627
-
628
-
629
- def test_sin():
630
- f = lambdify(x, sin(x)**2)
631
- assert isinstance(f(2), float)
632
- f = lambdify(x, sin(x)**2, modules="math")
633
- assert isinstance(f(2), float)
634
-
635
-
636
- def test_matrix():
637
- A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
638
- sol = Matrix([[1, 2], [sin(3) + 4, 1]])
639
- f = lambdify((x, y, z), A, modules="sympy")
640
- assert f(1, 2, 3) == sol
641
- f = lambdify((x, y, z), (A, [A]), modules="sympy")
642
- assert f(1, 2, 3) == (sol, [sol])
643
- J = Matrix((x, x + y)).jacobian((x, y))
644
- v = Matrix((x, y))
645
- sol = Matrix([[1, 0], [1, 1]])
646
- assert lambdify(v, J, modules='sympy')(1, 2) == sol
647
- assert lambdify(v.T, J, modules='sympy')(1, 2) == sol
648
-
649
-
650
- def test_numpy_matrix():
651
- if not numpy:
652
- skip("numpy not installed.")
653
- A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
654
- sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
655
- #Lambdify array first, to ensure return to array as default
656
- f = lambdify((x, y, z), A, ['numpy'])
657
- numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
658
- #Check that the types are arrays and matrices
659
- assert isinstance(f(1, 2, 3), numpy.ndarray)
660
-
661
- # gh-15071
662
- class dot(Function):
663
- pass
664
- x_dot_mtx = dot(x, Matrix([[2], [1], [0]]))
665
- f_dot1 = lambdify(x, x_dot_mtx)
666
- inp = numpy.zeros((17, 3))
667
- assert numpy.all(f_dot1(inp) == 0)
668
-
669
- strict_kw = {"allow_unknown_functions": False, "inline": True, "fully_qualified_modules": False}
670
- p2 = NumPyPrinter(dict(user_functions={'dot': 'dot'}, **strict_kw))
671
- f_dot2 = lambdify(x, x_dot_mtx, printer=p2)
672
- assert numpy.all(f_dot2(inp) == 0)
673
-
674
- p3 = NumPyPrinter(strict_kw)
675
- # The line below should probably fail upon construction (before calling with "(inp)"):
676
- raises(Exception, lambda: lambdify(x, x_dot_mtx, printer=p3)(inp))
677
-
678
-
679
- def test_numpy_transpose():
680
- if not numpy:
681
- skip("numpy not installed.")
682
- A = Matrix([[1, x], [0, 1]])
683
- f = lambdify((x), A.T, modules="numpy")
684
- numpy.testing.assert_array_equal(f(2), numpy.array([[1, 0], [2, 1]]))
685
-
686
-
687
- def test_numpy_dotproduct():
688
- if not numpy:
689
- skip("numpy not installed")
690
- A = Matrix([x, y, z])
691
- f1 = lambdify([x, y, z], DotProduct(A, A), modules='numpy')
692
- f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy')
693
- f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='numpy')
694
- f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy')
695
-
696
- assert f1(1, 2, 3) == \
697
- f2(1, 2, 3) == \
698
- f3(1, 2, 3) == \
699
- f4(1, 2, 3) == \
700
- numpy.array([14])
701
-
702
-
703
- def test_numpy_inverse():
704
- if not numpy:
705
- skip("numpy not installed.")
706
- A = Matrix([[1, x], [0, 1]])
707
- f = lambdify((x), A**-1, modules="numpy")
708
- numpy.testing.assert_array_equal(f(2), numpy.array([[1, -2], [0, 1]]))
709
-
710
-
711
- def test_numpy_old_matrix():
712
- if not numpy:
713
- skip("numpy not installed.")
714
- A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
715
- sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
716
- f = lambdify((x, y, z), A, [{'ImmutableDenseMatrix': numpy.matrix}, 'numpy'])
717
- with ignore_warnings(PendingDeprecationWarning):
718
- numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
719
- assert isinstance(f(1, 2, 3), numpy.matrix)
720
-
721
-
722
- def test_scipy_sparse_matrix():
723
- if not scipy:
724
- skip("scipy not installed.")
725
- A = SparseMatrix([[x, 0], [0, y]])
726
- f = lambdify((x, y), A, modules="scipy")
727
- B = f(1, 2)
728
- assert isinstance(B, scipy.sparse.coo_matrix)
729
-
730
-
731
- def test_python_div_zero_issue_11306():
732
- if not numpy:
733
- skip("numpy not installed.")
734
- p = Piecewise((1 / x, y < -1), (x, y < 1), (1 / x, True))
735
- f = lambdify([x, y], p, modules='numpy')
736
- with numpy.errstate(divide='ignore'):
737
- assert float(f(numpy.array(0), numpy.array(0.5))) == 0
738
- assert float(f(numpy.array(0), numpy.array(1))) == float('inf')
739
-
740
-
741
- def test_issue9474():
742
- mods = [None, 'math']
743
- if numpy:
744
- mods.append('numpy')
745
- if mpmath:
746
- mods.append('mpmath')
747
- for mod in mods:
748
- f = lambdify(x, S.One/x, modules=mod)
749
- assert f(2) == 0.5
750
- f = lambdify(x, floor(S.One/x), modules=mod)
751
- assert f(2) == 0
752
-
753
- for absfunc, modules in product([Abs, abs], mods):
754
- f = lambdify(x, absfunc(x), modules=modules)
755
- assert f(-1) == 1
756
- assert f(1) == 1
757
- assert f(3+4j) == 5
758
-
759
-
760
- def test_issue_9871():
761
- if not numexpr:
762
- skip("numexpr not installed.")
763
- if not numpy:
764
- skip("numpy not installed.")
765
-
766
- r = sqrt(x**2 + y**2)
767
- expr = diff(1/r, x)
768
-
769
- xn = yn = numpy.linspace(1, 10, 16)
770
- # expr(xn, xn) = -xn/(sqrt(2)*xn)^3
771
- fv_exact = -numpy.sqrt(2.)**-3 * xn**-2
772
-
773
- fv_numpy = lambdify((x, y), expr, modules='numpy')(xn, yn)
774
- fv_numexpr = lambdify((x, y), expr, modules='numexpr')(xn, yn)
775
- numpy.testing.assert_allclose(fv_numpy, fv_exact, rtol=1e-10)
776
- numpy.testing.assert_allclose(fv_numexpr, fv_exact, rtol=1e-10)
777
-
778
-
779
- def test_numpy_piecewise():
780
- if not numpy:
781
- skip("numpy not installed.")
782
- pieces = Piecewise((x, x < 3), (x**2, x > 5), (0, True))
783
- f = lambdify(x, pieces, modules="numpy")
784
- numpy.testing.assert_array_equal(f(numpy.arange(10)),
785
- numpy.array([0, 1, 2, 0, 0, 0, 36, 49, 64, 81]))
786
- # If we evaluate somewhere all conditions are False, we should get back NaN
787
- nodef_func = lambdify(x, Piecewise((x, x > 0), (-x, x < 0)))
788
- numpy.testing.assert_array_equal(nodef_func(numpy.array([-1, 0, 1])),
789
- numpy.array([1, numpy.nan, 1]))
790
-
791
-
792
- def test_numpy_logical_ops():
793
- if not numpy:
794
- skip("numpy not installed.")
795
- and_func = lambdify((x, y), And(x, y), modules="numpy")
796
- and_func_3 = lambdify((x, y, z), And(x, y, z), modules="numpy")
797
- or_func = lambdify((x, y), Or(x, y), modules="numpy")
798
- or_func_3 = lambdify((x, y, z), Or(x, y, z), modules="numpy")
799
- not_func = lambdify((x), Not(x), modules="numpy")
800
- arr1 = numpy.array([True, True])
801
- arr2 = numpy.array([False, True])
802
- arr3 = numpy.array([True, False])
803
- numpy.testing.assert_array_equal(and_func(arr1, arr2), numpy.array([False, True]))
804
- numpy.testing.assert_array_equal(and_func_3(arr1, arr2, arr3), numpy.array([False, False]))
805
- numpy.testing.assert_array_equal(or_func(arr1, arr2), numpy.array([True, True]))
806
- numpy.testing.assert_array_equal(or_func_3(arr1, arr2, arr3), numpy.array([True, True]))
807
- numpy.testing.assert_array_equal(not_func(arr2), numpy.array([True, False]))
808
-
809
-
810
- def test_numpy_matmul():
811
- if not numpy:
812
- skip("numpy not installed.")
813
- xmat = Matrix([[x, y], [z, 1+z]])
814
- ymat = Matrix([[x**2], [Abs(x)]])
815
- mat_func = lambdify((x, y, z), xmat*ymat, modules="numpy")
816
- numpy.testing.assert_array_equal(mat_func(0.5, 3, 4), numpy.array([[1.625], [3.5]]))
817
- numpy.testing.assert_array_equal(mat_func(-0.5, 3, 4), numpy.array([[1.375], [3.5]]))
818
- # Multiple matrices chained together in multiplication
819
- f = lambdify((x, y, z), xmat*xmat*xmat, modules="numpy")
820
- numpy.testing.assert_array_equal(f(0.5, 3, 4), numpy.array([[72.125, 119.25],
821
- [159, 251]]))
822
-
823
-
824
- def test_numpy_numexpr():
825
- if not numpy:
826
- skip("numpy not installed.")
827
- if not numexpr:
828
- skip("numexpr not installed.")
829
- a, b, c = numpy.random.randn(3, 128, 128)
830
- # ensure that numpy and numexpr return same value for complicated expression
831
- expr = sin(x) + cos(y) + tan(z)**2 + Abs(z-y)*acos(sin(y*z)) + \
832
- Abs(y-z)*acosh(2+exp(y-x))- sqrt(x**2+I*y**2)
833
- npfunc = lambdify((x, y, z), expr, modules='numpy')
834
- nefunc = lambdify((x, y, z), expr, modules='numexpr')
835
- assert numpy.allclose(npfunc(a, b, c), nefunc(a, b, c))
836
-
837
-
838
- def test_numexpr_userfunctions():
839
- if not numpy:
840
- skip("numpy not installed.")
841
- if not numexpr:
842
- skip("numexpr not installed.")
843
- a, b = numpy.random.randn(2, 10)
844
- uf = type('uf', (Function, ),
845
- {'eval' : classmethod(lambda x, y : y**2+1)})
846
- func = lambdify(x, 1-uf(x), modules='numexpr')
847
- assert numpy.allclose(func(a), -(a**2))
848
-
849
- uf = implemented_function(Function('uf'), lambda x, y : 2*x*y+1)
850
- func = lambdify((x, y), uf(x, y), modules='numexpr')
851
- assert numpy.allclose(func(a, b), 2*a*b+1)
852
-
853
-
854
- def test_tensorflow_basic_math():
855
- if not tensorflow:
856
- skip("tensorflow not installed.")
857
- expr = Max(sin(x), Abs(1/(x+2)))
858
- func = lambdify(x, expr, modules="tensorflow")
859
-
860
- with tensorflow.compat.v1.Session() as s:
861
- a = tensorflow.constant(0, dtype=tensorflow.float32)
862
- assert func(a).eval(session=s) == 0.5
863
-
864
-
865
- def test_tensorflow_placeholders():
866
- if not tensorflow:
867
- skip("tensorflow not installed.")
868
- expr = Max(sin(x), Abs(1/(x+2)))
869
- func = lambdify(x, expr, modules="tensorflow")
870
-
871
- with tensorflow.compat.v1.Session() as s:
872
- a = tensorflow.compat.v1.placeholder(dtype=tensorflow.float32)
873
- assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5
874
-
875
-
876
- def test_tensorflow_variables():
877
- if not tensorflow:
878
- skip("tensorflow not installed.")
879
- expr = Max(sin(x), Abs(1/(x+2)))
880
- func = lambdify(x, expr, modules="tensorflow")
881
-
882
- with tensorflow.compat.v1.Session() as s:
883
- a = tensorflow.Variable(0, dtype=tensorflow.float32)
884
- s.run(a.initializer)
885
- assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5
886
-
887
-
888
- def test_tensorflow_logical_operations():
889
- if not tensorflow:
890
- skip("tensorflow not installed.")
891
- expr = Not(And(Or(x, y), y))
892
- func = lambdify([x, y], expr, modules="tensorflow")
893
-
894
- with tensorflow.compat.v1.Session() as s:
895
- assert func(False, True).eval(session=s) == False
896
-
897
-
898
- def test_tensorflow_piecewise():
899
- if not tensorflow:
900
- skip("tensorflow not installed.")
901
- expr = Piecewise((0, Eq(x,0)), (-1, x < 0), (1, x > 0))
902
- func = lambdify(x, expr, modules="tensorflow")
903
-
904
- with tensorflow.compat.v1.Session() as s:
905
- assert func(-1).eval(session=s) == -1
906
- assert func(0).eval(session=s) == 0
907
- assert func(1).eval(session=s) == 1
908
-
909
-
910
- def test_tensorflow_multi_max():
911
- if not tensorflow:
912
- skip("tensorflow not installed.")
913
- expr = Max(x, -x, x**2)
914
- func = lambdify(x, expr, modules="tensorflow")
915
-
916
- with tensorflow.compat.v1.Session() as s:
917
- assert func(-2).eval(session=s) == 4
918
-
919
-
920
- def test_tensorflow_multi_min():
921
- if not tensorflow:
922
- skip("tensorflow not installed.")
923
- expr = Min(x, -x, x**2)
924
- func = lambdify(x, expr, modules="tensorflow")
925
-
926
- with tensorflow.compat.v1.Session() as s:
927
- assert func(-2).eval(session=s) == -2
928
-
929
-
930
- def test_tensorflow_relational():
931
- if not tensorflow:
932
- skip("tensorflow not installed.")
933
- expr = x >= 0
934
- func = lambdify(x, expr, modules="tensorflow")
935
-
936
- with tensorflow.compat.v1.Session() as s:
937
- assert func(1).eval(session=s) == True
938
-
939
-
940
- def test_tensorflow_complexes():
941
- if not tensorflow:
942
- skip("tensorflow not installed")
943
-
944
- func1 = lambdify(x, re(x), modules="tensorflow")
945
- func2 = lambdify(x, im(x), modules="tensorflow")
946
- func3 = lambdify(x, Abs(x), modules="tensorflow")
947
- func4 = lambdify(x, arg(x), modules="tensorflow")
948
-
949
- with tensorflow.compat.v1.Session() as s:
950
- # For versions before
951
- # https://github.com/tensorflow/tensorflow/issues/30029
952
- # resolved, using Python numeric types may not work
953
- a = tensorflow.constant(1+2j)
954
- assert func1(a).eval(session=s) == 1
955
- assert func2(a).eval(session=s) == 2
956
-
957
- tensorflow_result = func3(a).eval(session=s)
958
- sympy_result = Abs(1 + 2j).evalf()
959
- assert abs(tensorflow_result-sympy_result) < 10**-6
960
-
961
- tensorflow_result = func4(a).eval(session=s)
962
- sympy_result = arg(1 + 2j).evalf()
963
- assert abs(tensorflow_result-sympy_result) < 10**-6
964
-
965
-
966
- def test_tensorflow_array_arg():
967
- # Test for issue 14655 (tensorflow part)
968
- if not tensorflow:
969
- skip("tensorflow not installed.")
970
-
971
- f = lambdify([[x, y]], x*x + y, 'tensorflow')
972
-
973
- with tensorflow.compat.v1.Session() as s:
974
- fcall = f(tensorflow.constant([2.0, 1.0]))
975
- assert fcall.eval(session=s) == 5.0
976
-
977
-
978
- #================== Test symbolic ==================================
979
-
980
-
981
- def test_sym_single_arg():
982
- f = lambdify(x, x * y)
983
- assert f(z) == z * y
984
-
985
-
986
- def test_sym_list_args():
987
- f = lambdify([x, y], x + y + z)
988
- assert f(1, 2) == 3 + z
989
-
990
-
991
- def test_sym_integral():
992
- f = Lambda(x, exp(-x**2))
993
- l = lambdify(x, Integral(f(x), (x, -oo, oo)), modules="sympy")
994
- assert l(y) == Integral(exp(-y**2), (y, -oo, oo))
995
- assert l(y).doit() == sqrt(pi)
996
-
997
-
998
- def test_namespace_order():
999
- # lambdify had a bug, such that module dictionaries or cached module
1000
- # dictionaries would pull earlier namespaces into themselves.
1001
- # Because the module dictionaries form the namespace of the
1002
- # generated lambda, this meant that the behavior of a previously
1003
- # generated lambda function could change as a result of later calls
1004
- # to lambdify.
1005
- n1 = {'f': lambda x: 'first f'}
1006
- n2 = {'f': lambda x: 'second f',
1007
- 'g': lambda x: 'function g'}
1008
- f = sympy.Function('f')
1009
- g = sympy.Function('g')
1010
- if1 = lambdify(x, f(x), modules=(n1, "sympy"))
1011
- assert if1(1) == 'first f'
1012
- if2 = lambdify(x, g(x), modules=(n2, "sympy"))
1013
- # previously gave 'second f'
1014
- assert if1(1) == 'first f'
1015
-
1016
- assert if2(1) == 'function g'
1017
-
1018
-
1019
- def test_imps():
1020
- # Here we check if the default returned functions are anonymous - in
1021
- # the sense that we can have more than one function with the same name
1022
- f = implemented_function('f', lambda x: 2*x)
1023
- g = implemented_function('f', lambda x: math.sqrt(x))
1024
- l1 = lambdify(x, f(x))
1025
- l2 = lambdify(x, g(x))
1026
- assert str(f(x)) == str(g(x))
1027
- assert l1(3) == 6
1028
- assert l2(3) == math.sqrt(3)
1029
- # check that we can pass in a Function as input
1030
- func = sympy.Function('myfunc')
1031
- assert not hasattr(func, '_imp_')
1032
- my_f = implemented_function(func, lambda x: 2*x)
1033
- assert hasattr(my_f, '_imp_')
1034
- # Error for functions with same name and different implementation
1035
- f2 = implemented_function("f", lambda x: x + 101)
1036
- raises(ValueError, lambda: lambdify(x, f(f2(x))))
1037
-
1038
-
1039
- def test_imps_errors():
1040
- # Test errors that implemented functions can return, and still be able to
1041
- # form expressions.
1042
- # See: https://github.com/sympy/sympy/issues/10810
1043
- #
1044
- # XXX: Removed AttributeError here. This test was added due to issue 10810
1045
- # but that issue was about ValueError. It doesn't seem reasonable to
1046
- # "support" catching AttributeError in the same context...
1047
- for val, error_class in product((0, 0., 2, 2.0), (TypeError, ValueError)):
1048
-
1049
- def myfunc(a):
1050
- if a == 0:
1051
- raise error_class
1052
- return 1
1053
-
1054
- f = implemented_function('f', myfunc)
1055
- expr = f(val)
1056
- assert expr == f(val)
1057
-
1058
-
1059
- def test_imps_wrong_args():
1060
- raises(ValueError, lambda: implemented_function(sin, lambda x: x))
1061
-
1062
-
1063
- def test_lambdify_imps():
1064
- # Test lambdify with implemented functions
1065
- # first test basic (sympy) lambdify
1066
- f = sympy.cos
1067
- assert lambdify(x, f(x))(0) == 1
1068
- assert lambdify(x, 1 + f(x))(0) == 2
1069
- assert lambdify((x, y), y + f(x))(0, 1) == 2
1070
- # make an implemented function and test
1071
- f = implemented_function("f", lambda x: x + 100)
1072
- assert lambdify(x, f(x))(0) == 100
1073
- assert lambdify(x, 1 + f(x))(0) == 101
1074
- assert lambdify((x, y), y + f(x))(0, 1) == 101
1075
- # Can also handle tuples, lists, dicts as expressions
1076
- lam = lambdify(x, (f(x), x))
1077
- assert lam(3) == (103, 3)
1078
- lam = lambdify(x, [f(x), x])
1079
- assert lam(3) == [103, 3]
1080
- lam = lambdify(x, [f(x), (f(x), x)])
1081
- assert lam(3) == [103, (103, 3)]
1082
- lam = lambdify(x, {f(x): x})
1083
- assert lam(3) == {103: 3}
1084
- lam = lambdify(x, {f(x): x})
1085
- assert lam(3) == {103: 3}
1086
- lam = lambdify(x, {x: f(x)})
1087
- assert lam(3) == {3: 103}
1088
- # Check that imp preferred to other namespaces by default
1089
- d = {'f': lambda x: x + 99}
1090
- lam = lambdify(x, f(x), d)
1091
- assert lam(3) == 103
1092
- # Unless flag passed
1093
- lam = lambdify(x, f(x), d, use_imps=False)
1094
- assert lam(3) == 102
1095
-
1096
-
1097
- def test_dummification():
1098
- t = symbols('t')
1099
- F = Function('F')
1100
- G = Function('G')
1101
- #"\alpha" is not a valid Python variable name
1102
- #lambdify should sub in a dummy for it, and return
1103
- #without a syntax error
1104
- alpha = symbols(r'\alpha')
1105
- some_expr = 2 * F(t)**2 / G(t)
1106
- lam = lambdify((F(t), G(t)), some_expr)
1107
- assert lam(3, 9) == 2
1108
- lam = lambdify(sin(t), 2 * sin(t)**2)
1109
- assert lam(F(t)) == 2 * F(t)**2
1110
- #Test that \alpha was properly dummified
1111
- lam = lambdify((alpha, t), 2*alpha + t)
1112
- assert lam(2, 1) == 5
1113
- raises(SyntaxError, lambda: lambdify(F(t) * G(t), F(t) * G(t) + 5))
1114
- raises(SyntaxError, lambda: lambdify(2 * F(t), 2 * F(t) + 5))
1115
- raises(SyntaxError, lambda: lambdify(2 * F(t), 4 * F(t) + 5))
1116
-
1117
-
1118
- def test_lambdify__arguments_with_invalid_python_identifiers():
1119
- # see sympy/sympy#26690
1120
- N = CoordSys3D('N')
1121
- xn, yn, zn = N.base_scalars()
1122
- expr = xn + yn
1123
- f = lambdify([xn, yn], expr)
1124
- res = f(0.2, 0.3)
1125
- ref = 0.2 + 0.3
1126
- assert abs(res-ref) < 1e-15
1127
-
1128
-
1129
- def test_curly_matrix_symbol():
1130
- # Issue #15009
1131
- curlyv = sympy.MatrixSymbol("{v}", 2, 1)
1132
- lam = lambdify(curlyv, curlyv)
1133
- assert lam(1)==1
1134
- lam = lambdify(curlyv, curlyv, dummify=True)
1135
- assert lam(1)==1
1136
-
1137
-
1138
- def test_python_keywords():
1139
- # Test for issue 7452. The automatic dummification should ensure use of
1140
- # Python reserved keywords as symbol names will create valid lambda
1141
- # functions. This is an additional regression test.
1142
- python_if = symbols('if')
1143
- expr = python_if / 2
1144
- f = lambdify(python_if, expr)
1145
- assert f(4.0) == 2.0
1146
-
1147
-
1148
- def test_lambdify_docstring():
1149
- func = lambdify((w, x, y, z), w + x + y + z)
1150
- ref = (
1151
- "Created with lambdify. Signature:\n\n"
1152
- "func(w, x, y, z)\n\n"
1153
- "Expression:\n\n"
1154
- "w + x + y + z"
1155
- ).splitlines()
1156
- assert func.__doc__.splitlines()[:len(ref)] == ref
1157
- syms = symbols('a1:26')
1158
- func = lambdify(syms, sum(syms))
1159
- ref = (
1160
- "Created with lambdify. Signature:\n\n"
1161
- "func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n"
1162
- " a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n"
1163
- "Expression:\n\n"
1164
- "a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +..."
1165
- ).splitlines()
1166
- assert func.__doc__.splitlines()[:len(ref)] == ref
1167
-
1168
-
1169
- def test_lambdify_linecache():
1170
- func = lambdify(x, x + 1)
1171
- source = 'def _lambdifygenerated(x):\n return x + 1\n'
1172
- assert inspect.getsource(func) == source
1173
- filename = inspect.getsourcefile(func)
1174
- assert filename.startswith('<lambdifygenerated-')
1175
- assert filename in linecache.cache
1176
- assert linecache.cache[filename] == (len(source), None, source.splitlines(True), filename)
1177
- del func
1178
- gc.collect()
1179
- assert filename not in linecache.cache
1180
-
1181
- #================== Test special printers ==========================
1182
-
1183
-
1184
- def test_special_printers():
1185
- from sympy.printing.lambdarepr import IntervalPrinter
1186
-
1187
- def intervalrepr(expr):
1188
- return IntervalPrinter().doprint(expr)
1189
-
1190
- expr = sqrt(sqrt(2) + sqrt(3)) + S.Half
1191
-
1192
- func0 = lambdify((), expr, modules="mpmath", printer=intervalrepr)
1193
- func1 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter)
1194
- func2 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter())
1195
-
1196
- mpi = type(mpmath.mpi(1, 2))
1197
-
1198
- assert isinstance(func0(), mpi)
1199
- assert isinstance(func1(), mpi)
1200
- assert isinstance(func2(), mpi)
1201
-
1202
- # To check Is lambdify loggamma works for mpmath or not
1203
- exp1 = lambdify(x, loggamma(x), 'mpmath')(5)
1204
- exp2 = lambdify(x, loggamma(x), 'mpmath')(1.8)
1205
- exp3 = lambdify(x, loggamma(x), 'mpmath')(15)
1206
- exp_ls = [exp1, exp2, exp3]
1207
-
1208
- sol1 = mpmath.loggamma(5)
1209
- sol2 = mpmath.loggamma(1.8)
1210
- sol3 = mpmath.loggamma(15)
1211
- sol_ls = [sol1, sol2, sol3]
1212
-
1213
- assert exp_ls == sol_ls
1214
-
1215
-
1216
- def test_true_false():
1217
- # We want exact is comparison here, not just ==
1218
- assert lambdify([], true)() is True
1219
- assert lambdify([], false)() is False
1220
-
1221
-
1222
- def test_issue_2790():
1223
- assert lambdify((x, (y, z)), x + y)(1, (2, 4)) == 3
1224
- assert lambdify((x, (y, (w, z))), w + x + y + z)(1, (2, (3, 4))) == 10
1225
- assert lambdify(x, x + 1, dummify=False)(1) == 2
1226
-
1227
-
1228
- def test_issue_12092():
1229
- f = implemented_function('f', lambda x: x**2)
1230
- assert f(f(2)).evalf() == Float(16)
1231
-
1232
-
1233
- def test_issue_14911():
1234
- class Variable(sympy.Symbol):
1235
- def _sympystr(self, printer):
1236
- return printer.doprint(self.name)
1237
-
1238
- _lambdacode = _sympystr
1239
- _numpycode = _sympystr
1240
-
1241
- x = Variable('x')
1242
- y = 2 * x
1243
- code = LambdaPrinter().doprint(y)
1244
- assert code.replace(' ', '') == '2*x'
1245
-
1246
-
1247
- def test_ITE():
1248
- assert lambdify((x, y, z), ITE(x, y, z))(True, 5, 3) == 5
1249
- assert lambdify((x, y, z), ITE(x, y, z))(False, 5, 3) == 3
1250
-
1251
-
1252
- def test_Min_Max():
1253
- # see gh-10375
1254
- assert lambdify((x, y, z), Min(x, y, z))(1, 2, 3) == 1
1255
- assert lambdify((x, y, z), Max(x, y, z))(1, 2, 3) == 3
1256
-
1257
-
1258
- def test_amin_amax_minimum_maximum():
1259
- if not numpy:
1260
- skip("numpy not installed")
1261
-
1262
- a234 = numpy.array([2, 3, 4])
1263
- a152 = numpy.array([1, 5, 2])
1264
-
1265
- a254 = numpy.array([2, 5, 4])
1266
- a132 = numpy.array([1, 3, 2])
1267
- # 2 args
1268
- assert numpy.all(lambdify((x, y), maximum(x, y))(a234, a152) == a254)
1269
- assert numpy.all(lambdify((x, y), minimum(x, y))(a234, a152) == a132)
1270
-
1271
- # 3 args
1272
- assert numpy.all(lambdify((x, y, z), maximum(x, y, z))(a234, a152, a234) == a254)
1273
- assert numpy.all(lambdify((x, y, z), minimum(x, y, z))(a234, a152, a234) == a132)
1274
-
1275
- # 1 arg
1276
- assert numpy.all(lambdify((x,), maximum(x))(a234) == a234)
1277
- assert numpy.all(lambdify((x,), minimum(x))(a234) == a234)
1278
-
1279
- # 4 args, mixed length
1280
- assert numpy.all(lambdify((x, y, z, w), maximum(x, y, z, w))(a234, a152, a234, 3) == [3, 5, 4])
1281
- assert numpy.all(lambdify((x, y, z, w), minimum(x, y, z, w))(a234, a152, a234, 2) == [1, 2, 2])
1282
-
1283
- # amin & amax
1284
- assert lambdify((x, y), [amin(x), amax(y)])(a234, a152) == [2, 5]
1285
- A = numpy.array([
1286
- [0, 4, 8],
1287
- [1, 5, 9],
1288
- [2, 6, 10],
1289
- ])
1290
- min_, max_ = lambdify((x,), [amin(x, axis=0), amax(x, axis=1)])(A)
1291
- assert numpy.all(min_ == numpy.amin(A, axis=0))
1292
- assert numpy.all(max_ == numpy.amax(A, axis=1))
1293
-
1294
- # see gh-25659
1295
- assert numpy.all(lambdify((x, y), Max(x, y))([1, 2, 3], [3, 2, 1]) == [3, 2, 3])
1296
- assert numpy.all(lambdify((x), Min(2, x))([1, 2, 3]) == [1, 2, 2])
1297
-
1298
-
1299
-
1300
- def test_Indexed():
1301
- # Issue #10934
1302
- if not numpy:
1303
- skip("numpy not installed")
1304
-
1305
- a = IndexedBase('a')
1306
- i, j = symbols('i j')
1307
- b = numpy.array([[1, 2], [3, 4]])
1308
- assert lambdify(a, Sum(a[x, y], (x, 0, 1), (y, 0, 1)))(b) == 10
1309
-
1310
- def test_Sum():
1311
- e = Sum(z, (y, 0, x), (x, 0, 10))
1312
- ref = 66*z
1313
- assert e.doit() == ref
1314
- assert lambdify([z], e)(7) == ref.subs(z, 7)
1315
-
1316
- def test_Idx():
1317
- # Issue 26888
1318
- a = IndexedBase('a')
1319
- i = Idx('i')
1320
- b = [1,2,3]
1321
- assert lambdify([a, i], a[i])(b, 2) == 3
1322
-
1323
-
1324
- def test_issue_12173():
1325
- #test for issue 12173
1326
- expr1 = lambdify((x, y), uppergamma(x, y),"mpmath")(1, 2)
1327
- expr2 = lambdify((x, y), lowergamma(x, y),"mpmath")(1, 2)
1328
- assert expr1 == uppergamma(1, 2).evalf()
1329
- assert expr2 == lowergamma(1, 2).evalf()
1330
-
1331
-
1332
- def test_issue_13642():
1333
- if not numpy:
1334
- skip("numpy not installed")
1335
- f = lambdify(x, sinc(x))
1336
- assert Abs(f(1) - sinc(1)).n() < 1e-15
1337
-
1338
-
1339
- def test_sinc_mpmath():
1340
- f = lambdify(x, sinc(x), "mpmath")
1341
- assert Abs(f(1) - sinc(1)).n() < 1e-15
1342
-
1343
-
1344
- def test_lambdify_dummy_arg():
1345
- d1 = Dummy()
1346
- f1 = lambdify(d1, d1 + 1, dummify=False)
1347
- assert f1(2) == 3
1348
- f1b = lambdify(d1, d1 + 1)
1349
- assert f1b(2) == 3
1350
- d2 = Dummy('x')
1351
- f2 = lambdify(d2, d2 + 1)
1352
- assert f2(2) == 3
1353
- f3 = lambdify([[d2]], d2 + 1)
1354
- assert f3([2]) == 3
1355
-
1356
-
1357
- def test_lambdify_mixed_symbol_dummy_args():
1358
- d = Dummy()
1359
- # Contrived example of name clash
1360
- dsym = symbols(str(d))
1361
- f = lambdify([d, dsym], d - dsym)
1362
- assert f(4, 1) == 3
1363
-
1364
-
1365
- def test_numpy_array_arg():
1366
- # Test for issue 14655 (numpy part)
1367
- if not numpy:
1368
- skip("numpy not installed")
1369
-
1370
- f = lambdify([[x, y]], x*x + y, 'numpy')
1371
-
1372
- assert f(numpy.array([2.0, 1.0])) == 5
1373
-
1374
-
1375
- def test_scipy_fns():
1376
- if not scipy:
1377
- skip("scipy not installed")
1378
-
1379
- single_arg_sympy_fns = [Ei, erf, erfc, factorial, gamma, loggamma, digamma, Si, Ci]
1380
- single_arg_scipy_fns = [scipy.special.expi, scipy.special.erf, scipy.special.erfc,
1381
- scipy.special.factorial, scipy.special.gamma, scipy.special.gammaln,
1382
- scipy.special.psi, scipy.special.sici, scipy.special.sici]
1383
- numpy.random.seed(0)
1384
- for (sympy_fn, scipy_fn) in zip(single_arg_sympy_fns, single_arg_scipy_fns):
1385
- f = lambdify(x, sympy_fn(x), modules="scipy")
1386
- for i in range(20):
1387
- tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
1388
- # SciPy thinks that factorial(z) is 0 when re(z) < 0 and
1389
- # does not support complex numbers.
1390
- # SymPy does not think so.
1391
- if sympy_fn == factorial:
1392
- tv = numpy.abs(tv)
1393
- # SciPy supports gammaln for real arguments only,
1394
- # and there is also a branch cut along the negative real axis
1395
- if sympy_fn == loggamma:
1396
- tv = numpy.abs(tv)
1397
- # SymPy's digamma evaluates as polygamma(0, z)
1398
- # which SciPy supports for real arguments only
1399
- if sympy_fn == digamma:
1400
- tv = numpy.real(tv)
1401
- sympy_result = sympy_fn(tv).evalf()
1402
- scipy_result = scipy_fn(tv)
1403
- # SciPy's sici returns a tuple with both Si and Ci present in it
1404
- # which needs to be unpacked
1405
- if sympy_fn == Si:
1406
- scipy_result = scipy_fn(tv)[0]
1407
- if sympy_fn == Ci:
1408
- scipy_result = scipy_fn(tv)[1]
1409
- assert abs(f(tv) - sympy_result) < 1e-13*(1 + abs(sympy_result))
1410
- assert abs(f(tv) - scipy_result) < 1e-13*(1 + abs(sympy_result))
1411
-
1412
- double_arg_sympy_fns = [RisingFactorial, besselj, bessely, besseli,
1413
- besselk, polygamma]
1414
- double_arg_scipy_fns = [scipy.special.poch, scipy.special.jv,
1415
- scipy.special.yv, scipy.special.iv, scipy.special.kv, scipy.special.polygamma]
1416
- for (sympy_fn, scipy_fn) in zip(double_arg_sympy_fns, double_arg_scipy_fns):
1417
- f = lambdify((x, y), sympy_fn(x, y), modules="scipy")
1418
- for i in range(20):
1419
- # SciPy supports only real orders of Bessel functions
1420
- tv1 = numpy.random.uniform(-10, 10)
1421
- tv2 = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
1422
- # SciPy requires a real valued 2nd argument for: poch, polygamma
1423
- if sympy_fn in (RisingFactorial, polygamma):
1424
- tv2 = numpy.real(tv2)
1425
- if sympy_fn == polygamma:
1426
- tv1 = abs(int(tv1)) # first argument to polygamma must be a non-negative integer.
1427
- sympy_result = sympy_fn(tv1, tv2).evalf()
1428
- assert abs(f(tv1, tv2) - sympy_result) < 1e-13*(1 + abs(sympy_result))
1429
- assert abs(f(tv1, tv2) - scipy_fn(tv1, tv2)) < 1e-13*(1 + abs(sympy_result))
1430
-
1431
-
1432
- def test_scipy_polys():
1433
- if not scipy:
1434
- skip("scipy not installed")
1435
- numpy.random.seed(0)
1436
-
1437
- params = symbols('n k a b')
1438
- # list polynomials with the number of parameters
1439
- polys = [
1440
- (chebyshevt, 1),
1441
- (chebyshevu, 1),
1442
- (legendre, 1),
1443
- (hermite, 1),
1444
- (laguerre, 1),
1445
- (gegenbauer, 2),
1446
- (assoc_legendre, 2),
1447
- (assoc_laguerre, 2),
1448
- (jacobi, 3)
1449
- ]
1450
-
1451
- msg = \
1452
- "The random test of the function {func} with the arguments " \
1453
- "{args} had failed because the SymPy result {sympy_result} " \
1454
- "and SciPy result {scipy_result} had failed to converge " \
1455
- "within the tolerance {tol} " \
1456
- "(Actual absolute difference : {diff})"
1457
-
1458
- for sympy_fn, num_params in polys:
1459
- args = params[:num_params] + (x,)
1460
- f = lambdify(args, sympy_fn(*args))
1461
- for _ in range(10):
1462
- tn = numpy.random.randint(3, 10)
1463
- tparams = tuple(numpy.random.uniform(0, 5, size=num_params-1))
1464
- tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
1465
- # SciPy supports hermite for real arguments only
1466
- if sympy_fn == hermite:
1467
- tv = numpy.real(tv)
1468
- # assoc_legendre needs x in (-1, 1) and integer param at most n
1469
- if sympy_fn == assoc_legendre:
1470
- tv = numpy.random.uniform(-1, 1)
1471
- tparams = tuple(numpy.random.randint(1, tn, size=1))
1472
-
1473
- vals = (tn,) + tparams + (tv,)
1474
- scipy_result = f(*vals)
1475
- sympy_result = sympy_fn(*vals).evalf()
1476
- atol = 1e-9*(1 + abs(sympy_result))
1477
- diff = abs(scipy_result - sympy_result)
1478
- try:
1479
- assert diff < atol
1480
- except TypeError:
1481
- raise AssertionError(
1482
- msg.format(
1483
- func=repr(sympy_fn),
1484
- args=repr(vals),
1485
- sympy_result=repr(sympy_result),
1486
- scipy_result=repr(scipy_result),
1487
- diff=diff,
1488
- tol=atol)
1489
- )
1490
-
1491
-
1492
- def test_lambdify_inspect():
1493
- f = lambdify(x, x**2)
1494
- # Test that inspect.getsource works but don't hard-code implementation
1495
- # details
1496
- assert 'x**2' in inspect.getsource(f)
1497
-
1498
-
1499
- def test_issue_14941():
1500
- x, y = Dummy(), Dummy()
1501
-
1502
- # test dict
1503
- f1 = lambdify([x, y], {x: 3, y: 3}, 'sympy')
1504
- assert f1(2, 3) == {2: 3, 3: 3}
1505
-
1506
- # test tuple
1507
- f2 = lambdify([x, y], (y, x), 'sympy')
1508
- assert f2(2, 3) == (3, 2)
1509
- f2b = lambdify([], (1,)) # gh-23224
1510
- assert f2b() == (1,)
1511
-
1512
- # test list
1513
- f3 = lambdify([x, y], [y, x], 'sympy')
1514
- assert f3(2, 3) == [3, 2]
1515
-
1516
-
1517
- def test_lambdify_Derivative_arg_issue_16468():
1518
- f = Function('f')(x)
1519
- fx = f.diff()
1520
- assert lambdify((f, fx), f + fx)(10, 5) == 15
1521
- assert eval(lambdastr((f, fx), f/fx))(10, 5) == 2
1522
- raises(Exception, lambda:
1523
- eval(lambdastr((f, fx), f/fx, dummify=False)))
1524
- assert eval(lambdastr((f, fx), f/fx, dummify=True))(10, 5) == 2
1525
- assert eval(lambdastr((fx, f), f/fx, dummify=True))(S(10), 5) == S.Half
1526
- assert lambdify(fx, 1 + fx)(41) == 42
1527
- assert eval(lambdastr(fx, 1 + fx, dummify=True))(41) == 42
1528
-
1529
-
1530
- def test_lambdify_Derivative_zeta():
1531
- # This is related to gh-11802 (and to lesser extent gh-26663)
1532
- expr1 = zeta(x).diff(x, evaluate=False)
1533
- f1 = lambdify(x, expr1, modules=['mpmath'])
1534
- ans1 = f1(2)
1535
- ref1 = (zeta(2+1e-8).evalf()-zeta(2).evalf())/1e-8
1536
- assert abs(ans1 - ref1)/abs(ref1) < 1e-7
1537
-
1538
- expr2 = zeta(x**2).diff(x)
1539
- f2 = lambdify(x, expr2, modules=['mpmath'])
1540
- ans2 = f2(2**0.5)
1541
- ref2 = 2*2**0.5*ref1
1542
- assert abs(ans2-ref2)/abs(ref2) < 1e-7
1543
-
1544
-
1545
- def test_lambdify_Derivative_custom_printer():
1546
- func1 = Function('func1')
1547
- func2 = Function('func2')
1548
-
1549
- class MyPrinter(NumPyPrinter):
1550
-
1551
- def _print_Derivative_func1(self, args, seq_orders):
1552
- arg, = args
1553
- order, = seq_orders
1554
- return '42'
1555
-
1556
- expr1 = func1(x).diff(x)
1557
- raises(PrintMethodNotImplementedError, lambda: lambdify([x], expr1))
1558
- f1 = lambdify([x], expr1, printer=MyPrinter)
1559
- assert f1(7) == 42
1560
-
1561
- expr2 = func2(x).diff(x)
1562
- raises(PrintMethodNotImplementedError, lambda: lambdify([x], expr2, printer=MyPrinter))
1563
-
1564
-
1565
- def test_lambdify_derivative_and_functions_as_arguments():
1566
- # see: https://github.com/sympy/sympy/issues/26663#issuecomment-2157179517
1567
- t, a, b = symbols('t, a, b')
1568
- f = Function('f')(t)
1569
- args = f.diff(t, 2), f.diff(t), f, a, b
1570
- expr1 = a*f.diff(t, 2) + b*f.diff(t) + a*b*f + a**2
1571
- num_args = 2.0, 3.0, 4.0, 5.0, 6.0
1572
- ref1 = 5*2 + 6*3 + 5*6*4 + 5**2
1573
-
1574
- expr2 = a*f.diff(t, 2) + b*f.diff(t) - a*b*f + b**2 - a**2
1575
- ref2 = 5*2 + 6*3 - 5*6*4 + 6**2 - 5**2
1576
-
1577
- for dummify, _cse in product([False, None, True], [False, True]):
1578
- func1 = lambdify(args, expr1, cse=_cse, dummify=dummify)
1579
- res1 = func1(*num_args)
1580
- assert abs(res1 - ref1) < 1e-12
1581
-
1582
- func12 = lambdify(args, [expr1, expr2], cse=_cse, dummify=dummify)
1583
- res12 = func12(*num_args)
1584
- assert len(res12) == 2
1585
- assert abs(res12[0] - ref1) < 1e-12
1586
- assert abs(res12[1] - ref2) < 1e-12
1587
-
1588
-
1589
- def test_imag_real():
1590
- f_re = lambdify([z], sympy.re(z))
1591
- val = 3+2j
1592
- assert f_re(val) == val.real
1593
-
1594
- f_im = lambdify([z], sympy.im(z)) # see #15400
1595
- assert f_im(val) == val.imag
1596
-
1597
-
1598
- def test_MatrixSymbol_issue_15578():
1599
- if not numpy:
1600
- skip("numpy not installed")
1601
- A = MatrixSymbol('A', 2, 2)
1602
- A0 = numpy.array([[1, 2], [3, 4]])
1603
- f = lambdify(A, A**(-1))
1604
- assert numpy.allclose(f(A0), numpy.array([[-2., 1.], [1.5, -0.5]]))
1605
- g = lambdify(A, A**3)
1606
- assert numpy.allclose(g(A0), numpy.array([[37, 54], [81, 118]]))
1607
-
1608
-
1609
- def test_issue_15654():
1610
- if not scipy:
1611
- skip("scipy not installed")
1612
- from sympy.abc import n, l, r, Z
1613
- from sympy.physics import hydrogen
1614
- nv, lv, rv, Zv = 1, 0, 3, 1
1615
- sympy_value = hydrogen.R_nl(nv, lv, rv, Zv).evalf()
1616
- f = lambdify((n, l, r, Z), hydrogen.R_nl(n, l, r, Z))
1617
- scipy_value = f(nv, lv, rv, Zv)
1618
- assert abs(sympy_value - scipy_value) < 1e-15
1619
-
1620
-
1621
- def test_issue_15827():
1622
- if not numpy:
1623
- skip("numpy not installed")
1624
- A = MatrixSymbol("A", 3, 3)
1625
- B = MatrixSymbol("B", 2, 3)
1626
- C = MatrixSymbol("C", 3, 4)
1627
- D = MatrixSymbol("D", 4, 5)
1628
- k=symbols("k")
1629
- f = lambdify(A, (2*k)*A)
1630
- g = lambdify(A, (2+k)*A)
1631
- h = lambdify(A, 2*A)
1632
- i = lambdify((B, C, D), 2*B*C*D)
1633
- assert numpy.array_equal(f(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
1634
- numpy.array([[2*k, 4*k, 6*k], [2*k, 4*k, 6*k], [2*k, 4*k, 6*k]], dtype=object))
1635
-
1636
- assert numpy.array_equal(g(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
1637
- numpy.array([[k + 2, 2*k + 4, 3*k + 6], [k + 2, 2*k + 4, 3*k + 6], \
1638
- [k + 2, 2*k + 4, 3*k + 6]], dtype=object))
1639
-
1640
- assert numpy.array_equal(h(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
1641
- numpy.array([[2, 4, 6], [2, 4, 6], [2, 4, 6]]))
1642
-
1643
- assert numpy.array_equal(i(numpy.array([[1, 2, 3], [1, 2, 3]]), numpy.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]), \
1644
- numpy.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])), numpy.array([[ 120, 240, 360, 480, 600], \
1645
- [ 120, 240, 360, 480, 600]]))
1646
-
1647
-
1648
- def test_issue_16930():
1649
- if not scipy:
1650
- skip("scipy not installed")
1651
-
1652
- x = symbols("x")
1653
- f = lambda x: S.GoldenRatio * x**2
1654
- f_ = lambdify(x, f(x), modules='scipy')
1655
- assert f_(1) == scipy.constants.golden_ratio
1656
-
1657
- def test_issue_17898():
1658
- if not scipy:
1659
- skip("scipy not installed")
1660
- x = symbols("x")
1661
- f_ = lambdify([x], sympy.LambertW(x,-1), modules='scipy')
1662
- assert f_(0.1) == mpmath.lambertw(0.1, -1)
1663
-
1664
- def test_issue_13167_21411():
1665
- if not numpy:
1666
- skip("numpy not installed")
1667
- f1 = lambdify(x, sympy.Heaviside(x))
1668
- f2 = lambdify(x, sympy.Heaviside(x, 1))
1669
- res1 = f1([-1, 0, 1])
1670
- res2 = f2([-1, 0, 1])
1671
- assert Abs(res1[0]).n() < 1e-15 # First functionality: only one argument passed
1672
- assert Abs(res1[1] - 1/2).n() < 1e-15
1673
- assert Abs(res1[2] - 1).n() < 1e-15
1674
- assert Abs(res2[0]).n() < 1e-15 # Second functionality: two arguments passed
1675
- assert Abs(res2[1] - 1).n() < 1e-15
1676
- assert Abs(res2[2] - 1).n() < 1e-15
1677
-
1678
- def test_single_e():
1679
- f = lambdify(x, E)
1680
- assert f(23) == exp(1.0)
1681
-
1682
- def test_issue_16536():
1683
- if not scipy:
1684
- skip("scipy not installed")
1685
-
1686
- a = symbols('a')
1687
- f1 = lowergamma(a, x)
1688
- F = lambdify((a, x), f1, modules='scipy')
1689
- assert abs(lowergamma(1, 3) - F(1, 3)) <= 1e-10
1690
-
1691
- f2 = uppergamma(a, x)
1692
- F = lambdify((a, x), f2, modules='scipy')
1693
- assert abs(uppergamma(1, 3) - F(1, 3)) <= 1e-10
1694
-
1695
-
1696
- def test_issue_22726():
1697
- if not numpy:
1698
- skip("numpy not installed")
1699
-
1700
- x1, x2 = symbols('x1 x2')
1701
- f = Max(S.Zero, Min(x1, x2))
1702
- g = derive_by_array(f, (x1, x2))
1703
- G = lambdify((x1, x2), g, modules='numpy')
1704
- point = {x1: 1, x2: 2}
1705
- assert (abs(g.subs(point) - G(*point.values())) <= 1e-10).all()
1706
-
1707
-
1708
- def test_issue_22739():
1709
- if not numpy:
1710
- skip("numpy not installed")
1711
-
1712
- x1, x2 = symbols('x1 x2')
1713
- f = Heaviside(Min(x1, x2))
1714
- F = lambdify((x1, x2), f, modules='numpy')
1715
- point = {x1: 1, x2: 2}
1716
- assert abs(f.subs(point) - F(*point.values())) <= 1e-10
1717
-
1718
-
1719
- def test_issue_22992():
1720
- if not numpy:
1721
- skip("numpy not installed")
1722
-
1723
- a, t = symbols('a t')
1724
- expr = a*(log(cot(t/2)) - cos(t))
1725
- F = lambdify([a, t], expr, 'numpy')
1726
-
1727
- point = {a: 10, t: 2}
1728
-
1729
- assert abs(expr.subs(point) - F(*point.values())) <= 1e-10
1730
-
1731
- # Standard math
1732
- F = lambdify([a, t], expr)
1733
-
1734
- assert abs(expr.subs(point) - F(*point.values())) <= 1e-10
1735
-
1736
-
1737
- def test_issue_19764():
1738
- if not numpy:
1739
- skip("numpy not installed")
1740
-
1741
- expr = Array([x, x**2])
1742
- f = lambdify(x, expr, 'numpy')
1743
-
1744
- assert f(1).__class__ == numpy.ndarray
1745
-
1746
- def test_issue_20070():
1747
- if not numba:
1748
- skip("numba not installed")
1749
-
1750
- f = lambdify(x, sin(x), 'numpy')
1751
- assert numba.jit(f, nopython=True)(1)==0.8414709848078965
1752
-
1753
-
1754
- def test_fresnel_integrals_scipy():
1755
- if not scipy:
1756
- skip("scipy not installed")
1757
-
1758
- f1 = fresnelc(x)
1759
- f2 = fresnels(x)
1760
- F1 = lambdify(x, f1, modules='scipy')
1761
- F2 = lambdify(x, f2, modules='scipy')
1762
-
1763
- assert abs(fresnelc(1.3) - F1(1.3)) <= 1e-10
1764
- assert abs(fresnels(1.3) - F2(1.3)) <= 1e-10
1765
-
1766
-
1767
- def test_beta_scipy():
1768
- if not scipy:
1769
- skip("scipy not installed")
1770
-
1771
- f = beta(x, y)
1772
- F = lambdify((x, y), f, modules='scipy')
1773
-
1774
- assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10
1775
-
1776
-
1777
- def test_beta_math():
1778
- f = beta(x, y)
1779
- F = lambdify((x, y), f, modules='math')
1780
-
1781
- assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10
1782
-
1783
-
1784
- def test_betainc_scipy():
1785
- if not scipy:
1786
- skip("scipy not installed")
1787
-
1788
- f = betainc(w, x, y, z)
1789
- F = lambdify((w, x, y, z), f, modules='scipy')
1790
-
1791
- assert abs(betainc(1.4, 3.1, 0.1, 0.5) - F(1.4, 3.1, 0.1, 0.5)) <= 1e-10
1792
-
1793
-
1794
- def test_betainc_regularized_scipy():
1795
- if not scipy:
1796
- skip("scipy not installed")
1797
-
1798
- f = betainc_regularized(w, x, y, z)
1799
- F = lambdify((w, x, y, z), f, modules='scipy')
1800
-
1801
- assert abs(betainc_regularized(0.2, 3.5, 0.1, 1) - F(0.2, 3.5, 0.1, 1)) <= 1e-10
1802
-
1803
-
1804
- def test_numpy_special_math():
1805
- if not numpy:
1806
- skip("numpy not installed")
1807
-
1808
- funcs = [expm1, log1p, exp2, log2, log10, hypot, logaddexp, logaddexp2]
1809
- for func in funcs:
1810
- if 2 in func.nargs:
1811
- expr = func(x, y)
1812
- args = (x, y)
1813
- num_args = (0.3, 0.4)
1814
- elif 1 in func.nargs:
1815
- expr = func(x)
1816
- args = (x,)
1817
- num_args = (0.3,)
1818
- else:
1819
- raise NotImplementedError("Need to handle other than unary & binary functions in test")
1820
- f = lambdify(args, expr)
1821
- result = f(*num_args)
1822
- reference = expr.subs(dict(zip(args, num_args))).evalf()
1823
- assert numpy.allclose(result, float(reference))
1824
-
1825
- lae2 = lambdify((x, y), logaddexp2(log2(x), log2(y)))
1826
- assert abs(2.0**lae2(1e-50, 2.5e-50) - 3.5e-50) < 1e-62 # from NumPy's docstring
1827
-
1828
-
1829
- def test_scipy_special_math():
1830
- if not scipy:
1831
- skip("scipy not installed")
1832
-
1833
- cm1 = lambdify((x,), cosm1(x), modules='scipy')
1834
- assert abs(cm1(1e-20) + 5e-41) < 1e-200
1835
-
1836
- have_scipy_1_10plus = tuple(map(int, scipy.version.version.split('.')[:2])) >= (1, 10)
1837
-
1838
- if have_scipy_1_10plus:
1839
- cm2 = lambdify((x, y), powm1(x, y), modules='scipy')
1840
- assert abs(cm2(1.2, 1e-9) - 1.82321557e-10) < 1e-17
1841
-
1842
-
1843
- def test_scipy_bernoulli():
1844
- if not scipy:
1845
- skip("scipy not installed")
1846
-
1847
- bern = lambdify((x,), bernoulli(x), modules='scipy')
1848
- assert bern(1) == 0.5
1849
-
1850
-
1851
- def test_scipy_harmonic():
1852
- if not scipy:
1853
- skip("scipy not installed")
1854
-
1855
- hn = lambdify((x,), harmonic(x), modules='scipy')
1856
- assert hn(2) == 1.5
1857
- hnm = lambdify((x, y), harmonic(x, y), modules='scipy')
1858
- assert hnm(2, 2) == 1.25
1859
-
1860
-
1861
- def test_cupy_array_arg():
1862
- if not cupy:
1863
- skip("CuPy not installed")
1864
-
1865
- f = lambdify([[x, y]], x*x + y, 'cupy')
1866
- result = f(cupy.array([2.0, 1.0]))
1867
- assert result == 5
1868
- assert "cupy" in str(type(result))
1869
-
1870
-
1871
- def test_cupy_array_arg_using_numpy():
1872
- # numpy functions can be run on cupy arrays
1873
- # unclear if we can "officially" support this,
1874
- # depends on numpy __array_function__ support
1875
- if not cupy:
1876
- skip("CuPy not installed")
1877
-
1878
- f = lambdify([[x, y]], x*x + y, 'numpy')
1879
- result = f(cupy.array([2.0, 1.0]))
1880
- assert result == 5
1881
- assert "cupy" in str(type(result))
1882
-
1883
- def test_cupy_dotproduct():
1884
- if not cupy:
1885
- skip("CuPy not installed")
1886
-
1887
- A = Matrix([x, y, z])
1888
- f1 = lambdify([x, y, z], DotProduct(A, A), modules='cupy')
1889
- f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy')
1890
- f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='cupy')
1891
- f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy')
1892
-
1893
- assert f1(1, 2, 3) == \
1894
- f2(1, 2, 3) == \
1895
- f3(1, 2, 3) == \
1896
- f4(1, 2, 3) == \
1897
- cupy.array([14])
1898
-
1899
-
1900
- def test_jax_array_arg():
1901
- if not jax:
1902
- skip("JAX not installed")
1903
-
1904
- f = lambdify([[x, y]], x*x + y, 'jax')
1905
- result = f(jax.numpy.array([2.0, 1.0]))
1906
- assert result == 5
1907
- assert "jax" in str(type(result))
1908
-
1909
-
1910
- def test_jax_array_arg_using_numpy():
1911
- if not jax:
1912
- skip("JAX not installed")
1913
-
1914
- f = lambdify([[x, y]], x*x + y, 'numpy')
1915
- result = f(jax.numpy.array([2.0, 1.0]))
1916
- assert result == 5
1917
- assert "jax" in str(type(result))
1918
-
1919
-
1920
- def test_jax_dotproduct():
1921
- if not jax:
1922
- skip("JAX not installed")
1923
-
1924
- A = Matrix([x, y, z])
1925
- f1 = lambdify([x, y, z], DotProduct(A, A), modules='jax')
1926
- f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax')
1927
- f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='jax')
1928
- f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax')
1929
-
1930
- assert f1(1, 2, 3) == \
1931
- f2(1, 2, 3) == \
1932
- f3(1, 2, 3) == \
1933
- f4(1, 2, 3) == \
1934
- jax.numpy.array([14])
1935
-
1936
-
1937
- def test_lambdify_cse():
1938
- def no_op_cse(exprs):
1939
- return (), exprs
1940
-
1941
- def dummy_cse(exprs):
1942
- from sympy.simplify.cse_main import cse
1943
- return cse(exprs, symbols=numbered_symbols(cls=Dummy))
1944
-
1945
- def minmem(exprs):
1946
- from sympy.simplify.cse_main import cse_release_variables, cse
1947
- return cse(exprs, postprocess=cse_release_variables)
1948
-
1949
- class Case:
1950
- def __init__(self, *, args, exprs, num_args, requires_numpy=False):
1951
- self.args = args
1952
- self.exprs = exprs
1953
- self.num_args = num_args
1954
- subs_dict = dict(zip(self.args, self.num_args))
1955
- self.ref = [e.subs(subs_dict).evalf() for e in exprs]
1956
- self.requires_numpy = requires_numpy
1957
-
1958
- def lambdify(self, *, cse):
1959
- return lambdify(self.args, self.exprs, cse=cse)
1960
-
1961
- def assertAllClose(self, result, *, abstol=1e-15, reltol=1e-15):
1962
- if self.requires_numpy:
1963
- assert all(numpy.allclose(result[i], numpy.asarray(r, dtype=float),
1964
- rtol=reltol, atol=abstol)
1965
- for i, r in enumerate(self.ref))
1966
- return
1967
-
1968
- for i, r in enumerate(self.ref):
1969
- abs_err = abs(result[i] - r)
1970
- if r == 0:
1971
- assert abs_err < abstol
1972
- else:
1973
- assert abs_err/abs(r) < reltol
1974
-
1975
- cases = [
1976
- Case(
1977
- args=(x, y, z),
1978
- exprs=[
1979
- x + y + z,
1980
- x + y - z,
1981
- 2*x + 2*y - z,
1982
- (x+y)**2 + (y+z)**2,
1983
- ],
1984
- num_args=(2., 3., 4.)
1985
- ),
1986
- Case(
1987
- args=(x, y, z),
1988
- exprs=[
1989
- x + sympy.Heaviside(x),
1990
- y + sympy.Heaviside(x),
1991
- z + sympy.Heaviside(x, 1),
1992
- z/sympy.Heaviside(x, 1)
1993
- ],
1994
- num_args=(0., 3., 4.)
1995
- ),
1996
- Case(
1997
- args=(x, y, z),
1998
- exprs=[
1999
- x + sinc(y),
2000
- y + sinc(y),
2001
- z - sinc(y)
2002
- ],
2003
- num_args=(0.1, 0.2, 0.3)
2004
- ),
2005
- Case(
2006
- args=(x, y, z),
2007
- exprs=[
2008
- Matrix([[x, x*y], [sin(z) + 4, x**z]]),
2009
- x*y+sin(z)-x**z,
2010
- Matrix([x*x, sin(z), x**z])
2011
- ],
2012
- num_args=(1.,2.,3.),
2013
- requires_numpy=True
2014
- ),
2015
- Case(
2016
- args=(x, y),
2017
- exprs=[(x + y - 1)**2, x, x + y,
2018
- (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)],
2019
- num_args=(1,2)
2020
- )
2021
- ]
2022
- for case in cases:
2023
- if not numpy and case.requires_numpy:
2024
- continue
2025
- for _cse in [False, True, minmem, no_op_cse, dummy_cse]:
2026
- f = case.lambdify(cse=_cse)
2027
- result = f(*case.num_args)
2028
- case.assertAllClose(result)
2029
-
2030
- def test_issue_25288():
2031
- syms = numbered_symbols(cls=Dummy)
2032
- ok = lambdify(x, [x**2, sin(x**2)], cse=lambda e: cse(e, symbols=syms))(2)
2033
- assert ok
2034
-
2035
-
2036
- def test_deprecated_set():
2037
- with warns_deprecated_sympy():
2038
- lambdify({x, y}, x + y)
2039
-
2040
- def test_issue_13881():
2041
- if not numpy:
2042
- skip("numpy not installed.")
2043
-
2044
- X = MatrixSymbol('X', 3, 1)
2045
-
2046
- f = lambdify(X, X.T*X, 'numpy')
2047
- assert f(numpy.array([1, 2, 3])) == 14
2048
- assert f(numpy.array([3, 2, 1])) == 14
2049
-
2050
- f = lambdify(X, X*X.T, 'numpy')
2051
- assert f(numpy.array([1, 2, 3])) == 14
2052
- assert f(numpy.array([3, 2, 1])) == 14
2053
-
2054
- f = lambdify(X, (X*X.T)*X, 'numpy')
2055
- arr1 = numpy.array([[1], [2], [3]])
2056
- arr2 = numpy.array([[14],[28],[42]])
2057
-
2058
- assert numpy.array_equal(f(arr1), arr2)
2059
-
2060
-
2061
- def test_23536_lambdify_cse_dummy():
2062
-
2063
- f = Function('x')(y)
2064
- g = Function('w')(y)
2065
- expr = z + (f**4 + g**5)*(f**3 + (g*f)**3)
2066
- expr = expr.expand()
2067
- eval_expr = lambdify(((f, g), z), expr, cse=True)
2068
- ans = eval_expr((1.0, 2.0), 3.0) # shouldn't raise NameError
2069
- assert ans == 300.0 # not a list and value is 300
2070
-
2071
-
2072
- class LambdifyDocstringTestCase:
2073
- SIGNATURE = None
2074
- EXPR = None
2075
- SRC = None
2076
-
2077
- def __init__(self, docstring_limit, expected_redacted):
2078
- self.docstring_limit = docstring_limit
2079
- self.expected_redacted = expected_redacted
2080
-
2081
- @property
2082
- def expected_expr(self):
2083
- expr_redacted_msg = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
2084
- return self.EXPR if not self.expected_redacted else expr_redacted_msg
2085
-
2086
- @property
2087
- def expected_src(self):
2088
- src_redacted_msg = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
2089
- return self.SRC if not self.expected_redacted else src_redacted_msg
2090
-
2091
- @property
2092
- def expected_docstring(self):
2093
- expected_docstring = (
2094
- f'Created with lambdify. Signature:\n\n'
2095
- f'func({self.SIGNATURE})\n\n'
2096
- f'Expression:\n\n'
2097
- f'{self.expected_expr}\n\n'
2098
- f'Source code:\n\n'
2099
- f'{self.expected_src}\n\n'
2100
- f'Imported modules:\n\n'
2101
- )
2102
- return expected_docstring
2103
-
2104
- def __len__(self):
2105
- return len(self.expected_docstring)
2106
-
2107
- def __repr__(self):
2108
- return (
2109
- f'{self.__class__.__name__}('
2110
- f'docstring_limit={self.docstring_limit}, '
2111
- f'expected_redacted={self.expected_redacted})'
2112
- )
2113
-
2114
-
2115
- def test_lambdify_docstring_size_limit_simple_symbol():
2116
-
2117
- class SimpleSymbolTestCase(LambdifyDocstringTestCase):
2118
- SIGNATURE = 'x'
2119
- EXPR = 'x'
2120
- SRC = (
2121
- 'def _lambdifygenerated(x):\n'
2122
- ' return x\n'
2123
- )
2124
-
2125
- x = symbols('x')
2126
-
2127
- test_cases = (
2128
- SimpleSymbolTestCase(docstring_limit=None, expected_redacted=False),
2129
- SimpleSymbolTestCase(docstring_limit=100, expected_redacted=False),
2130
- SimpleSymbolTestCase(docstring_limit=1, expected_redacted=False),
2131
- SimpleSymbolTestCase(docstring_limit=0, expected_redacted=True),
2132
- SimpleSymbolTestCase(docstring_limit=-1, expected_redacted=True),
2133
- )
2134
- for test_case in test_cases:
2135
- lambdified_expr = lambdify(
2136
- [x],
2137
- x,
2138
- 'sympy',
2139
- docstring_limit=test_case.docstring_limit,
2140
- )
2141
- assert lambdified_expr.__doc__ == test_case.expected_docstring
2142
-
2143
-
2144
- def test_lambdify_docstring_size_limit_nested_expr():
2145
-
2146
- class ExprListTestCase(LambdifyDocstringTestCase):
2147
- SIGNATURE = 'x, y, z'
2148
- EXPR = (
2149
- '[x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z '
2150
- '+ 3*x*z**2 +...'
2151
- )
2152
- SRC = (
2153
- 'def _lambdifygenerated(x, y, z):\n'
2154
- ' return [x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 '
2155
- '+ 6*x*y*z + 3*x*z**2 + y**3 + 3*y**2*z + 3*y*z**2 + z**3]\n'
2156
- )
2157
-
2158
- x, y, z = symbols('x, y, z')
2159
- expr = [x, [y], z, ((x + y + z)**3).expand()]
2160
-
2161
- test_cases = (
2162
- ExprListTestCase(docstring_limit=None, expected_redacted=False),
2163
- ExprListTestCase(docstring_limit=200, expected_redacted=False),
2164
- ExprListTestCase(docstring_limit=50, expected_redacted=True),
2165
- ExprListTestCase(docstring_limit=0, expected_redacted=True),
2166
- ExprListTestCase(docstring_limit=-1, expected_redacted=True),
2167
- )
2168
- for test_case in test_cases:
2169
- lambdified_expr = lambdify(
2170
- [x, y, z],
2171
- expr,
2172
- 'sympy',
2173
- docstring_limit=test_case.docstring_limit,
2174
- )
2175
- assert lambdified_expr.__doc__ == test_case.expected_docstring
2176
-
2177
-
2178
- def test_lambdify_docstring_size_limit_matrix():
2179
-
2180
- class MatrixTestCase(LambdifyDocstringTestCase):
2181
- SIGNATURE = 'x, y, z'
2182
- EXPR = (
2183
- 'Matrix([[0, x], [x + y + z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 '
2184
- '+ 6*x*y*z...'
2185
- )
2186
- SRC = (
2187
- 'def _lambdifygenerated(x, y, z):\n'
2188
- ' return ImmutableDenseMatrix([[0, x], [x + y + z, x**3 '
2189
- '+ 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z + 3*x*z**2 + y**3 '
2190
- '+ 3*y**2*z + 3*y*z**2 + z**3]])\n'
2191
- )
2192
-
2193
- x, y, z = symbols('x, y, z')
2194
- expr = Matrix([[S.Zero, x], [x + y + z, ((x + y + z)**3).expand()]])
2195
-
2196
- test_cases = (
2197
- MatrixTestCase(docstring_limit=None, expected_redacted=False),
2198
- MatrixTestCase(docstring_limit=200, expected_redacted=False),
2199
- MatrixTestCase(docstring_limit=50, expected_redacted=True),
2200
- MatrixTestCase(docstring_limit=0, expected_redacted=True),
2201
- MatrixTestCase(docstring_limit=-1, expected_redacted=True),
2202
- )
2203
- for test_case in test_cases:
2204
- lambdified_expr = lambdify(
2205
- [x, y, z],
2206
- expr,
2207
- 'sympy',
2208
- docstring_limit=test_case.docstring_limit,
2209
- )
2210
- assert lambdified_expr.__doc__ == test_case.expected_docstring
2211
-
2212
-
2213
- def test_lambdify_empty_tuple():
2214
- a = symbols("a")
2215
- expr = ((), (a,))
2216
- f = lambdify(a, expr)
2217
- result = f(1)
2218
- assert result == ((), (1,)), "Lambdify did not handle the empty tuple correctly."
2219
-
2220
-
2221
- def test_assoc_legendre_numerical_evaluation():
2222
-
2223
- tol = 1e-10
2224
-
2225
- sympy_result_integer = assoc_legendre(1, 1/2, 0.1).evalf()
2226
- sympy_result_complex = assoc_legendre(2, 1, 3).evalf()
2227
- mpmath_result_integer = -0.474572528387641
2228
- mpmath_result_complex = -25.45584412271571*I
2229
-
2230
- assert all_close(sympy_result_integer, mpmath_result_integer, tol)
2231
- assert all_close(sympy_result_complex, mpmath_result_complex, tol)
2232
-
2233
-
2234
- def test_Piecewise():
2235
-
2236
- modules = [math]
2237
- if numpy:
2238
- modules.append('numpy')
2239
-
2240
- for mod in modules:
2241
- # test isinf
2242
- f = lambdify(x, Piecewise((7.0, isinf(x)), (3.0, True)), mod)
2243
- assert f(+float('inf')) == +7.0
2244
- assert f(-float('inf')) == +7.0
2245
- assert f(42.) == 3.0
2246
-
2247
- f2 = lambdify(x, Piecewise((7.0*sign(x), isinf(x)), (3.0, True)), mod)
2248
- assert f2(+float('inf')) == +7.0
2249
- assert f2(-float('inf')) == -7.0
2250
- assert f2(42.) == 3.0
2251
-
2252
- # test isnan (gh-26784)
2253
- g = lambdify(x, Piecewise((7.0, isnan(x)), (3.0, True)), mod)
2254
- assert g(float('nan')) == 7.0
2255
- assert g(42.) == 3.0
2256
-
2257
-
2258
- def test_array_symbol():
2259
- if not numpy:
2260
- skip("numpy not installed.")
2261
- a = ArraySymbol('a', (3,))
2262
- f = lambdify((a), a)
2263
- assert numpy.all(f(numpy.array([1,2,3])) == numpy.array([1,2,3]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_matchpy_connector.py DELETED
@@ -1,164 +0,0 @@
1
- import pickle
2
-
3
- from sympy.core.relational import (Eq, Ne)
4
- from sympy.core.singleton import S
5
- from sympy.core.symbol import symbols
6
- from sympy.functions.elementary.miscellaneous import sqrt
7
- from sympy.functions.elementary.trigonometric import (cos, sin)
8
- from sympy.external import import_module
9
- from sympy.testing.pytest import skip
10
- from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar, Replacer
11
-
12
- matchpy = import_module("matchpy")
13
-
14
- x, y, z = symbols("x y z")
15
-
16
-
17
- def _get_first_match(expr, pattern):
18
- from matchpy import ManyToOneMatcher, Pattern
19
-
20
- matcher = ManyToOneMatcher()
21
- matcher.add(Pattern(pattern))
22
- return next(iter(matcher.match(expr)))
23
-
24
-
25
- def test_matchpy_connector():
26
- if matchpy is None:
27
- skip("matchpy not installed")
28
-
29
- from multiset import Multiset
30
- from matchpy import Pattern, Substitution
31
-
32
- w_ = WildDot("w_")
33
- w__ = WildPlus("w__")
34
- w___ = WildStar("w___")
35
-
36
- expr = x + y
37
- pattern = x + w_
38
- p, subst = _get_first_match(expr, pattern)
39
- assert p == Pattern(pattern)
40
- assert subst == Substitution({'w_': y})
41
-
42
- expr = x + y + z
43
- pattern = x + w__
44
- p, subst = _get_first_match(expr, pattern)
45
- assert p == Pattern(pattern)
46
- assert subst == Substitution({'w__': Multiset([y, z])})
47
-
48
- expr = x + y + z
49
- pattern = x + y + z + w___
50
- p, subst = _get_first_match(expr, pattern)
51
- assert p == Pattern(pattern)
52
- assert subst == Substitution({'w___': Multiset()})
53
-
54
-
55
- def test_matchpy_optional():
56
- if matchpy is None:
57
- skip("matchpy not installed")
58
-
59
- from matchpy import Pattern, Substitution
60
- from matchpy import ManyToOneReplacer, ReplacementRule
61
-
62
- p = WildDot("p", optional=1)
63
- q = WildDot("q", optional=0)
64
-
65
- pattern = p*x + q
66
-
67
- expr1 = 2*x
68
- pa, subst = _get_first_match(expr1, pattern)
69
- assert pa == Pattern(pattern)
70
- assert subst == Substitution({'p': 2, 'q': 0})
71
-
72
- expr2 = x + 3
73
- pa, subst = _get_first_match(expr2, pattern)
74
- assert pa == Pattern(pattern)
75
- assert subst == Substitution({'p': 1, 'q': 3})
76
-
77
- expr3 = x
78
- pa, subst = _get_first_match(expr3, pattern)
79
- assert pa == Pattern(pattern)
80
- assert subst == Substitution({'p': 1, 'q': 0})
81
-
82
- expr4 = x*y + z
83
- pa, subst = _get_first_match(expr4, pattern)
84
- assert pa == Pattern(pattern)
85
- assert subst == Substitution({'p': y, 'q': z})
86
-
87
- replacer = ManyToOneReplacer()
88
- replacer.add(ReplacementRule(Pattern(pattern), lambda p, q: sin(p)*cos(q)))
89
- assert replacer.replace(expr1) == sin(2)*cos(0)
90
- assert replacer.replace(expr2) == sin(1)*cos(3)
91
- assert replacer.replace(expr3) == sin(1)*cos(0)
92
- assert replacer.replace(expr4) == sin(y)*cos(z)
93
-
94
-
95
- def test_replacer():
96
- if matchpy is None:
97
- skip("matchpy not installed")
98
-
99
- for info in [True, False]:
100
- for lambdify in [True, False]:
101
- _perform_test_replacer(info, lambdify)
102
-
103
-
104
- def _perform_test_replacer(info, lambdify):
105
-
106
- x1_ = WildDot("x1_")
107
- x2_ = WildDot("x2_")
108
-
109
- a_ = WildDot("a_", optional=S.One)
110
- b_ = WildDot("b_", optional=S.One)
111
- c_ = WildDot("c_", optional=S.Zero)
112
-
113
- replacer = Replacer(common_constraints=[
114
- matchpy.CustomConstraint(lambda a_: not a_.has(x)),
115
- matchpy.CustomConstraint(lambda b_: not b_.has(x)),
116
- matchpy.CustomConstraint(lambda c_: not c_.has(x)),
117
- ], lambdify=lambdify, info=info)
118
-
119
- # Rewrite the equation into implicit form, unless it's already solved:
120
- replacer.add(Eq(x1_, x2_), Eq(x1_ - x2_, 0), conditions_nonfalse=[Ne(x2_, 0), Ne(x1_, 0), Ne(x1_, x), Ne(x2_, x)], info=1)
121
-
122
- # Simple equation solver for real numbers:
123
- replacer.add(Eq(a_*x + b_, 0), Eq(x, -b_/a_), info=2)
124
- disc = b_**2 - 4*a_*c_
125
- replacer.add(
126
- Eq(a_*x**2 + b_*x + c_, 0),
127
- Eq(x, (-b_ - sqrt(disc))/(2*a_)) | Eq(x, (-b_ + sqrt(disc))/(2*a_)),
128
- conditions_nonfalse=[disc >= 0],
129
- info=3
130
- )
131
- replacer.add(
132
- Eq(a_*x**2 + c_, 0),
133
- Eq(x, sqrt(-c_/a_)) | Eq(x, -sqrt(-c_/a_)),
134
- conditions_nonfalse=[-c_*a_ > 0],
135
- info=4
136
- )
137
-
138
- g = lambda expr, infos: (expr, infos) if info else expr
139
-
140
- assert replacer.replace(Eq(3*x, y)) == g(Eq(x, y/3), [1, 2])
141
- assert replacer.replace(Eq(x**2 + 1, 0)) == g(Eq(x**2 + 1, 0), [])
142
- assert replacer.replace(Eq(x**2, 4)) == g((Eq(x, 2) | Eq(x, -2)), [1, 4])
143
- assert replacer.replace(Eq(x**2 + 4*y*x + 4*y**2, 0)) == g(Eq(x, -2*y), [3])
144
-
145
-
146
- def test_matchpy_object_pickle():
147
- if matchpy is None:
148
- return
149
-
150
- a1 = WildDot("a")
151
- a2 = pickle.loads(pickle.dumps(a1))
152
- assert a1 == a2
153
-
154
- a1 = WildDot("a", S(1))
155
- a2 = pickle.loads(pickle.dumps(a1))
156
- assert a1 == a2
157
-
158
- a1 = WildPlus("a", S(1))
159
- a2 = pickle.loads(pickle.dumps(a1))
160
- assert a1 == a2
161
-
162
- a1 = WildStar("a", S(1))
163
- a2 = pickle.loads(pickle.dumps(a1))
164
- assert a1 == a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_mathml.py DELETED
@@ -1,33 +0,0 @@
1
- import os
2
- from textwrap import dedent
3
- from sympy.external import import_module
4
- from sympy.testing.pytest import skip
5
- from sympy.utilities.mathml import apply_xsl
6
-
7
-
8
-
9
- lxml = import_module('lxml')
10
-
11
- path = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_xxe.py"))
12
-
13
-
14
- def test_xxe():
15
- assert os.path.isfile(path)
16
- if not lxml:
17
- skip("lxml not installed.")
18
-
19
- mml = dedent(
20
- rf"""
21
- <!--?xml version="1.0" ?-->
22
- <!DOCTYPE replace [<!ENTITY ent SYSTEM "file://{path}"> ]>
23
- <userInfo>
24
- <firstName>John</firstName>
25
- <lastName>&ent;</lastName>
26
- </userInfo>
27
- """
28
- )
29
- xsl = 'mathml/data/simple_mmlctop.xsl'
30
-
31
- res = apply_xsl(mml, xsl)
32
- assert res == \
33
- '<?xml version="1.0"?>\n<userInfo>\n<firstName>John</firstName>\n<lastName/>\n</userInfo>\n'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_misc.py DELETED
@@ -1,148 +0,0 @@
1
- from textwrap import dedent
2
- import sys
3
- from subprocess import Popen, PIPE
4
- import os
5
-
6
- from sympy.core.singleton import S
7
- from sympy.testing.pytest import (raises, warns_deprecated_sympy,
8
- skip_under_pyodide)
9
- from sympy.utilities.misc import (translate, replace, ordinal, rawlines,
10
- strlines, as_int, find_executable)
11
-
12
-
13
- def test_translate():
14
- abc = 'abc'
15
- assert translate(abc, None, 'a') == 'bc'
16
- assert translate(abc, None, '') == 'abc'
17
- assert translate(abc, {'a': 'x'}, 'c') == 'xb'
18
- assert translate(abc, {'a': 'bc'}, 'c') == 'bcb'
19
- assert translate(abc, {'ab': 'x'}, 'c') == 'x'
20
- assert translate(abc, {'ab': ''}, 'c') == ''
21
- assert translate(abc, {'bc': 'x'}, 'c') == 'ab'
22
- assert translate(abc, {'abc': 'x', 'a': 'y'}) == 'x'
23
- u = chr(4096)
24
- assert translate(abc, 'a', 'x', u) == 'xbc'
25
- assert (u in translate(abc, 'a', u, u)) is True
26
-
27
-
28
- def test_replace():
29
- assert replace('abc', ('a', 'b')) == 'bbc'
30
- assert replace('abc', {'a': 'Aa'}) == 'Aabc'
31
- assert replace('abc', ('a', 'b'), ('c', 'C')) == 'bbC'
32
-
33
-
34
- def test_ordinal():
35
- assert ordinal(-1) == '-1st'
36
- assert ordinal(0) == '0th'
37
- assert ordinal(1) == '1st'
38
- assert ordinal(2) == '2nd'
39
- assert ordinal(3) == '3rd'
40
- assert all(ordinal(i).endswith('th') for i in range(4, 21))
41
- assert ordinal(100) == '100th'
42
- assert ordinal(101) == '101st'
43
- assert ordinal(102) == '102nd'
44
- assert ordinal(103) == '103rd'
45
- assert ordinal(104) == '104th'
46
- assert ordinal(200) == '200th'
47
- assert all(ordinal(i) == str(i) + 'th' for i in range(-220, -203))
48
-
49
-
50
- def test_rawlines():
51
- assert rawlines('a a\na') == "dedent('''\\\n a a\n a''')"
52
- assert rawlines('a a') == "'a a'"
53
- assert rawlines(strlines('\\le"ft')) == (
54
- '(\n'
55
- " '(\\n'\n"
56
- ' \'r\\\'\\\\le"ft\\\'\\n\'\n'
57
- " ')'\n"
58
- ')')
59
-
60
-
61
- def test_strlines():
62
- q = 'this quote (") is in the middle'
63
- # the following assert rhs was prepared with
64
- # print(rawlines(strlines(q, 10)))
65
- assert strlines(q, 10) == dedent('''\
66
- (
67
- 'this quo'
68
- 'te (") i'
69
- 's in the'
70
- ' middle'
71
- )''')
72
- assert q == (
73
- 'this quo'
74
- 'te (") i'
75
- 's in the'
76
- ' middle'
77
- )
78
- q = "this quote (') is in the middle"
79
- assert strlines(q, 20) == dedent('''\
80
- (
81
- "this quote (') is "
82
- "in the middle"
83
- )''')
84
- assert strlines('\\left') == (
85
- '(\n'
86
- "r'\\left'\n"
87
- ')')
88
- assert strlines('\\left', short=True) == r"r'\left'"
89
- assert strlines('\\le"ft') == (
90
- '(\n'
91
- 'r\'\\le"ft\'\n'
92
- ')')
93
- q = 'this\nother line'
94
- assert strlines(q) == rawlines(q)
95
-
96
-
97
- def test_translate_args():
98
- try:
99
- translate(None, None, None, 'not_none')
100
- except ValueError:
101
- pass # Exception raised successfully
102
- else:
103
- assert False
104
-
105
- assert translate('s', None, None, None) == 's'
106
-
107
- try:
108
- translate('s', 'a', 'bc')
109
- except ValueError:
110
- pass # Exception raised successfully
111
- else:
112
- assert False
113
-
114
-
115
- @skip_under_pyodide("Cannot create subprocess under pyodide.")
116
- def test_debug_output():
117
- env = os.environ.copy()
118
- env['SYMPY_DEBUG'] = 'True'
119
- cmd = 'from sympy import *; x = Symbol("x"); print(integrate((1-cos(x))/x, x))'
120
- cmdline = [sys.executable, '-c', cmd]
121
- proc = Popen(cmdline, env=env, stdout=PIPE, stderr=PIPE)
122
- out, err = proc.communicate()
123
- out = out.decode('ascii') # utf-8?
124
- err = err.decode('ascii')
125
- expected = 'substituted: -x*(1 - cos(x)), u: 1/x, u_var: _u'
126
- assert expected in err, err
127
-
128
-
129
- def test_as_int():
130
- raises(ValueError, lambda : as_int(True))
131
- raises(ValueError, lambda : as_int(1.1))
132
- raises(ValueError, lambda : as_int([]))
133
- raises(ValueError, lambda : as_int(S.NaN))
134
- raises(ValueError, lambda : as_int(S.Infinity))
135
- raises(ValueError, lambda : as_int(S.NegativeInfinity))
136
- raises(ValueError, lambda : as_int(S.ComplexInfinity))
137
- # for the following, limited precision makes int(arg) == arg
138
- # but the int value is not necessarily what a user might have
139
- # expected; Q.prime is more nuanced in its response for
140
- # expressions which might be complex representations of an
141
- # integer. This is not -- by design -- as_ints role.
142
- raises(ValueError, lambda : as_int(1e23))
143
- raises(ValueError, lambda : as_int(S('1.'+'0'*20+'1')))
144
- assert as_int(True, strict=False) == 1
145
-
146
- def test_deprecated_find_executable():
147
- with warns_deprecated_sympy():
148
- find_executable('python')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_pickling.py DELETED
@@ -1,723 +0,0 @@
1
- import inspect
2
- import copy
3
- import pickle
4
-
5
- from sympy.physics.units import meter
6
-
7
- from sympy.testing.pytest import XFAIL, raises, ignore_warnings
8
-
9
- from sympy.core.basic import Atom, Basic
10
- from sympy.core.singleton import SingletonRegistry
11
- from sympy.core.symbol import Str, Dummy, Symbol, Wild
12
- from sympy.core.numbers import (E, I, pi, oo, zoo, nan, Integer,
13
- Rational, Float, AlgebraicNumber)
14
- from sympy.core.relational import (Equality, GreaterThan, LessThan, Relational,
15
- StrictGreaterThan, StrictLessThan, Unequality)
16
- from sympy.core.add import Add
17
- from sympy.core.mul import Mul
18
- from sympy.core.power import Pow
19
- from sympy.core.function import Derivative, Function, FunctionClass, Lambda, \
20
- WildFunction
21
- from sympy.sets.sets import Interval
22
- from sympy.core.multidimensional import vectorize
23
-
24
- from sympy.external.gmpy import gmpy as _gmpy
25
- from sympy.utilities.exceptions import SymPyDeprecationWarning
26
-
27
- from sympy.core.singleton import S
28
- from sympy.core.symbol import symbols
29
-
30
- from sympy.external import import_module
31
- cloudpickle = import_module('cloudpickle')
32
-
33
-
34
- not_equal_attrs = {
35
- '_assumptions', # This is a local cache that isn't automatically filled on creation
36
- '_mhash', # Cached after __hash__ is called but set to None after creation
37
- }
38
-
39
-
40
- deprecated_attrs = {
41
- 'is_EmptySet', # Deprecated from SymPy 1.5. This can be removed when is_EmptySet is removed.
42
- 'expr_free_symbols', # Deprecated from SymPy 1.9. This can be removed when exr_free_symbols is removed.
43
- }
44
-
45
- dont_check_attrs = {
46
- '_sage_', # Fails because Sage is not installed
47
- }
48
-
49
-
50
- def check(a, exclude=[], check_attr=True, deprecated=()):
51
- """ Check that pickling and copying round-trips.
52
- """
53
- # Pickling with protocols 0 and 1 is disabled for Basic instances:
54
- if isinstance(a, Basic):
55
- for protocol in [0, 1]:
56
- raises(NotImplementedError, lambda: pickle.dumps(a, protocol))
57
-
58
- protocols = [2, copy.copy, copy.deepcopy, 3, 4]
59
- if cloudpickle:
60
- protocols.extend([cloudpickle])
61
-
62
- for protocol in protocols:
63
- if protocol in exclude:
64
- continue
65
-
66
- if callable(protocol):
67
- if isinstance(a, type):
68
- # Classes can't be copied, but that's okay.
69
- continue
70
- b = protocol(a)
71
- elif inspect.ismodule(protocol):
72
- b = protocol.loads(protocol.dumps(a))
73
- else:
74
- b = pickle.loads(pickle.dumps(a, protocol))
75
-
76
- d1 = dir(a)
77
- d2 = dir(b)
78
- assert set(d1) == set(d2)
79
-
80
- if not check_attr:
81
- continue
82
-
83
- def c(a, b, d):
84
- for i in d:
85
- if i in dont_check_attrs:
86
- continue
87
- elif i in not_equal_attrs:
88
- if hasattr(a, i):
89
- assert hasattr(b, i), i
90
- elif i in deprecated_attrs or i in deprecated:
91
- with ignore_warnings(SymPyDeprecationWarning):
92
- assert getattr(a, i) == getattr(b, i), i
93
- elif not hasattr(a, i):
94
- continue
95
- else:
96
- attr = getattr(a, i)
97
- if not hasattr(attr, "__call__"):
98
- assert hasattr(b, i), i
99
- assert getattr(b, i) == attr, "%s != %s, protocol: %s" % (getattr(b, i), attr, protocol)
100
-
101
- c(a, b, d1)
102
- c(b, a, d2)
103
-
104
-
105
-
106
- #================== core =========================
107
-
108
-
109
- def test_core_basic():
110
- for c in (Atom, Atom(), Basic, Basic(), SingletonRegistry, S):
111
- check(c)
112
-
113
- def test_core_Str():
114
- check(Str('x'))
115
-
116
- def test_core_symbol():
117
- # make the Symbol a unique name that doesn't class with any other
118
- # testing variable in this file since after this test the symbol
119
- # having the same name will be cached as noncommutative
120
- for c in (Dummy, Dummy("x", commutative=False), Symbol,
121
- Symbol("_issue_3130", commutative=False), Wild, Wild("x")):
122
- check(c)
123
-
124
-
125
- def test_core_numbers():
126
- for c in (Integer(2), Rational(2, 3), Float("1.2")):
127
- check(c)
128
- for c in (AlgebraicNumber, AlgebraicNumber(sqrt(3))):
129
- check(c, check_attr=False)
130
-
131
-
132
- def test_core_float_copy():
133
- # See gh-7457
134
- y = Symbol("x") + 1.0
135
- check(y) # does not raise TypeError ("argument is not an mpz")
136
-
137
-
138
- def test_core_relational():
139
- x = Symbol("x")
140
- y = Symbol("y")
141
- for c in (Equality, Equality(x, y), GreaterThan, GreaterThan(x, y),
142
- LessThan, LessThan(x, y), Relational, Relational(x, y),
143
- StrictGreaterThan, StrictGreaterThan(x, y), StrictLessThan,
144
- StrictLessThan(x, y), Unequality, Unequality(x, y)):
145
- check(c)
146
-
147
-
148
- def test_core_add():
149
- x = Symbol("x")
150
- for c in (Add, Add(x, 4)):
151
- check(c)
152
-
153
-
154
- def test_core_mul():
155
- x = Symbol("x")
156
- for c in (Mul, Mul(x, 4)):
157
- check(c)
158
-
159
-
160
- def test_core_power():
161
- x = Symbol("x")
162
- for c in (Pow, Pow(x, 4)):
163
- check(c)
164
-
165
-
166
- def test_core_function():
167
- x = Symbol("x")
168
- for f in (Derivative, Derivative(x), Function, FunctionClass, Lambda,
169
- WildFunction):
170
- check(f)
171
-
172
-
173
- def test_core_undefinedfunctions():
174
- f = Function("f")
175
- check(f)
176
-
177
-
178
- def test_core_appliedundef():
179
- x = Symbol("_long_unique_name_1")
180
- f = Function("_long_unique_name_2")
181
- check(f(x))
182
-
183
-
184
- def test_core_interval():
185
- for c in (Interval, Interval(0, 2)):
186
- check(c)
187
-
188
-
189
- def test_core_multidimensional():
190
- for c in (vectorize, vectorize(0)):
191
- check(c)
192
-
193
-
194
- def test_Singletons():
195
- protocols = [0, 1, 2, 3, 4]
196
- copiers = [copy.copy, copy.deepcopy]
197
- copiers += [lambda x: pickle.loads(pickle.dumps(x, proto))
198
- for proto in protocols]
199
- if cloudpickle:
200
- copiers += [lambda x: cloudpickle.loads(cloudpickle.dumps(x))]
201
-
202
- for obj in (Integer(-1), Integer(0), Integer(1), Rational(1, 2), pi, E, I,
203
- oo, -oo, zoo, nan, S.GoldenRatio, S.TribonacciConstant,
204
- S.EulerGamma, S.Catalan, S.EmptySet, S.IdentityFunction):
205
- for func in copiers:
206
- assert func(obj) is obj
207
-
208
- #================== combinatorics ===================
209
- from sympy.combinatorics.free_groups import FreeGroup
210
-
211
- def test_free_group():
212
- check(FreeGroup("x, y, z"), check_attr=False)
213
-
214
- #================== functions ===================
215
- from sympy.functions import (Piecewise, lowergamma, acosh, chebyshevu,
216
- chebyshevt, ln, chebyshevt_root, legendre, Heaviside, bernoulli, coth,
217
- tanh, assoc_legendre, sign, arg, asin, DiracDelta, re, rf, Abs,
218
- uppergamma, binomial, sinh, cos, cot, acos, acot, gamma, bell,
219
- hermite, harmonic, LambertW, zeta, log, factorial, asinh, acoth, cosh,
220
- dirichlet_eta, Eijk, loggamma, erf, ceiling, im, fibonacci,
221
- tribonacci, conjugate, tan, chebyshevu_root, floor, atanh, sqrt, sin,
222
- atan, ff, lucas, atan2, polygamma, exp)
223
-
224
-
225
- def test_functions():
226
- one_var = (acosh, ln, Heaviside, factorial, bernoulli, coth, tanh,
227
- sign, arg, asin, DiracDelta, re, Abs, sinh, cos, cot, acos, acot,
228
- gamma, bell, harmonic, LambertW, zeta, log, factorial, asinh,
229
- acoth, cosh, dirichlet_eta, loggamma, erf, ceiling, im, fibonacci,
230
- tribonacci, conjugate, tan, floor, atanh, sin, atan, lucas, exp)
231
- two_var = (rf, ff, lowergamma, chebyshevu, chebyshevt, binomial,
232
- atan2, polygamma, hermite, legendre, uppergamma)
233
- x, y, z = symbols("x,y,z")
234
- others = (chebyshevt_root, chebyshevu_root, Eijk(x, y, z),
235
- Piecewise( (0, x < -1), (x**2, x <= 1), (x**3, True)),
236
- assoc_legendre)
237
- for cls in one_var:
238
- check(cls)
239
- c = cls(x)
240
- check(c)
241
- for cls in two_var:
242
- check(cls)
243
- c = cls(x, y)
244
- check(c)
245
- for cls in others:
246
- check(cls)
247
-
248
- #================== geometry ====================
249
- from sympy.geometry.entity import GeometryEntity
250
- from sympy.geometry.point import Point
251
- from sympy.geometry.ellipse import Circle, Ellipse
252
- from sympy.geometry.line import Line, LinearEntity, Ray, Segment
253
- from sympy.geometry.polygon import Polygon, RegularPolygon, Triangle
254
-
255
-
256
- def test_geometry():
257
- p1 = Point(1, 2)
258
- p2 = Point(2, 3)
259
- p3 = Point(0, 0)
260
- p4 = Point(0, 1)
261
- for c in (
262
- GeometryEntity, GeometryEntity(), Point, p1, Circle, Circle(p1, 2),
263
- Ellipse, Ellipse(p1, 3, 4), Line, Line(p1, p2), LinearEntity,
264
- LinearEntity(p1, p2), Ray, Ray(p1, p2), Segment, Segment(p1, p2),
265
- Polygon, Polygon(p1, p2, p3, p4), RegularPolygon,
266
- RegularPolygon(p1, 4, 5), Triangle, Triangle(p1, p2, p3)):
267
- check(c, check_attr=False)
268
-
269
- #================== integrals ====================
270
- from sympy.integrals.integrals import Integral
271
-
272
-
273
- def test_integrals():
274
- x = Symbol("x")
275
- for c in (Integral, Integral(x)):
276
- check(c)
277
-
278
- #==================== logic =====================
279
- from sympy.core.logic import Logic
280
-
281
-
282
- def test_logic():
283
- for c in (Logic, Logic(1)):
284
- check(c)
285
-
286
- #================== matrices ====================
287
- from sympy.matrices import Matrix, SparseMatrix
288
-
289
-
290
- def test_matrices():
291
- for c in (Matrix, Matrix([1, 2, 3]), SparseMatrix, SparseMatrix([[1, 2], [3, 4]])):
292
- check(c, deprecated=['_smat', '_mat'])
293
-
294
- #================== ntheory =====================
295
- from sympy.ntheory.generate import Sieve
296
-
297
-
298
- def test_ntheory():
299
- for c in (Sieve, Sieve()):
300
- check(c)
301
-
302
- #================== physics =====================
303
- from sympy.physics.paulialgebra import Pauli
304
- from sympy.physics.units import Unit
305
-
306
-
307
- def test_physics():
308
- for c in (Unit, meter, Pauli, Pauli(1)):
309
- check(c)
310
-
311
- #================== plotting ====================
312
- # XXX: These tests are not complete, so XFAIL them
313
-
314
-
315
- @XFAIL
316
- def test_plotting():
317
- from sympy.plotting.pygletplot.color_scheme import ColorGradient, ColorScheme
318
- from sympy.plotting.pygletplot.managed_window import ManagedWindow
319
- from sympy.plotting.plot import Plot, ScreenShot
320
- from sympy.plotting.pygletplot.plot_axes import PlotAxes, PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate
321
- from sympy.plotting.pygletplot.plot_camera import PlotCamera
322
- from sympy.plotting.pygletplot.plot_controller import PlotController
323
- from sympy.plotting.pygletplot.plot_curve import PlotCurve
324
- from sympy.plotting.pygletplot.plot_interval import PlotInterval
325
- from sympy.plotting.pygletplot.plot_mode import PlotMode
326
- from sympy.plotting.pygletplot.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \
327
- ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical
328
- from sympy.plotting.pygletplot.plot_object import PlotObject
329
- from sympy.plotting.pygletplot.plot_surface import PlotSurface
330
- from sympy.plotting.pygletplot.plot_window import PlotWindow
331
- for c in (
332
- ColorGradient, ColorGradient(0.2, 0.4), ColorScheme, ManagedWindow,
333
- ManagedWindow, Plot, ScreenShot, PlotAxes, PlotAxesBase,
334
- PlotAxesFrame, PlotAxesOrdinate, PlotCamera, PlotController,
335
- PlotCurve, PlotInterval, PlotMode, Cartesian2D, Cartesian3D,
336
- Cylindrical, ParametricCurve2D, ParametricCurve3D,
337
- ParametricSurface, Polar, Spherical, PlotObject, PlotSurface,
338
- PlotWindow):
339
- check(c)
340
-
341
-
342
- @XFAIL
343
- def test_plotting2():
344
- #from sympy.plotting.color_scheme import ColorGradient
345
- from sympy.plotting.pygletplot.color_scheme import ColorScheme
346
- #from sympy.plotting.managed_window import ManagedWindow
347
- from sympy.plotting.plot import Plot
348
- #from sympy.plotting.plot import ScreenShot
349
- from sympy.plotting.pygletplot.plot_axes import PlotAxes
350
- #from sympy.plotting.plot_axes import PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate
351
- #from sympy.plotting.plot_camera import PlotCamera
352
- #from sympy.plotting.plot_controller import PlotController
353
- #from sympy.plotting.plot_curve import PlotCurve
354
- #from sympy.plotting.plot_interval import PlotInterval
355
- #from sympy.plotting.plot_mode import PlotMode
356
- #from sympy.plotting.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \
357
- # ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical
358
- #from sympy.plotting.plot_object import PlotObject
359
- #from sympy.plotting.plot_surface import PlotSurface
360
- # from sympy.plotting.plot_window import PlotWindow
361
- check(ColorScheme("rainbow"))
362
- check(Plot(1, visible=False))
363
- check(PlotAxes())
364
-
365
- #================== polys =======================
366
- from sympy.polys.domains.integerring import ZZ
367
- from sympy.polys.domains.rationalfield import QQ
368
- from sympy.polys.orderings import lex
369
- from sympy.polys.polytools import Poly
370
-
371
- def test_pickling_polys_polytools():
372
- from sympy.polys.polytools import PurePoly
373
- # from sympy.polys.polytools import GroebnerBasis
374
- x = Symbol('x')
375
-
376
- for c in (Poly, Poly(x, x)):
377
- check(c)
378
-
379
- for c in (PurePoly, PurePoly(x)):
380
- check(c)
381
-
382
- # TODO: fix pickling of Options class (see GroebnerBasis._options)
383
- # for c in (GroebnerBasis, GroebnerBasis([x**2 - 1], x, order=lex)):
384
- # check(c)
385
-
386
- def test_pickling_polys_polyclasses():
387
- from sympy.polys.polyclasses import DMP, DMF, ANP
388
-
389
- for c in (DMP, DMP([[ZZ(1)], [ZZ(2)], [ZZ(3)]], ZZ)):
390
- check(c, deprecated=['rep'])
391
- for c in (DMF, DMF(([ZZ(1), ZZ(2)], [ZZ(1), ZZ(3)]), ZZ)):
392
- check(c)
393
- for c in (ANP, ANP([QQ(1), QQ(2)], [QQ(1), QQ(2), QQ(3)], QQ)):
394
- check(c)
395
-
396
- @XFAIL
397
- def test_pickling_polys_rings():
398
- # NOTE: can't use protocols < 2 because we have to execute __new__ to
399
- # make sure caching of rings works properly.
400
-
401
- from sympy.polys.rings import PolyRing
402
-
403
- ring = PolyRing("x,y,z", ZZ, lex)
404
-
405
- for c in (PolyRing, ring):
406
- check(c, exclude=[0, 1])
407
-
408
- for c in (ring.dtype, ring.one):
409
- check(c, exclude=[0, 1], check_attr=False) # TODO: Py3k
410
-
411
- def test_pickling_polys_fields():
412
- pass
413
- # NOTE: can't use protocols < 2 because we have to execute __new__ to
414
- # make sure caching of fields works properly.
415
-
416
- # from sympy.polys.fields import FracField
417
-
418
- # field = FracField("x,y,z", ZZ, lex)
419
-
420
- # TODO: AssertionError: assert id(obj) not in self.memo
421
- # for c in (FracField, field):
422
- # check(c, exclude=[0, 1])
423
-
424
- # TODO: AssertionError: assert id(obj) not in self.memo
425
- # for c in (field.dtype, field.one):
426
- # check(c, exclude=[0, 1])
427
-
428
- def test_pickling_polys_elements():
429
- from sympy.polys.domains.pythonrational import PythonRational
430
- #from sympy.polys.domains.pythonfinitefield import PythonFiniteField
431
- #from sympy.polys.domains.mpelements import MPContext
432
-
433
- for c in (PythonRational, PythonRational(1, 7)):
434
- check(c)
435
-
436
- #gf = PythonFiniteField(17)
437
-
438
- # TODO: fix pickling of ModularInteger
439
- # for c in (gf.dtype, gf(5)):
440
- # check(c)
441
-
442
- #mp = MPContext()
443
-
444
- # TODO: fix pickling of RealElement
445
- # for c in (mp.mpf, mp.mpf(1.0)):
446
- # check(c)
447
-
448
- # TODO: fix pickling of ComplexElement
449
- # for c in (mp.mpc, mp.mpc(1.0, -1.5)):
450
- # check(c)
451
-
452
- def test_pickling_polys_domains():
453
- # from sympy.polys.domains.pythonfinitefield import PythonFiniteField
454
- from sympy.polys.domains.pythonintegerring import PythonIntegerRing
455
- from sympy.polys.domains.pythonrationalfield import PythonRationalField
456
-
457
- # TODO: fix pickling of ModularInteger
458
- # for c in (PythonFiniteField, PythonFiniteField(17)):
459
- # check(c)
460
-
461
- for c in (PythonIntegerRing, PythonIntegerRing()):
462
- check(c, check_attr=False)
463
-
464
- for c in (PythonRationalField, PythonRationalField()):
465
- check(c, check_attr=False)
466
-
467
- if _gmpy is not None:
468
- # from sympy.polys.domains.gmpyfinitefield import GMPYFiniteField
469
- from sympy.polys.domains.gmpyintegerring import GMPYIntegerRing
470
- from sympy.polys.domains.gmpyrationalfield import GMPYRationalField
471
-
472
- # TODO: fix pickling of ModularInteger
473
- # for c in (GMPYFiniteField, GMPYFiniteField(17)):
474
- # check(c)
475
-
476
- for c in (GMPYIntegerRing, GMPYIntegerRing()):
477
- check(c, check_attr=False)
478
-
479
- for c in (GMPYRationalField, GMPYRationalField()):
480
- check(c, check_attr=False)
481
-
482
- #from sympy.polys.domains.realfield import RealField
483
- #from sympy.polys.domains.complexfield import ComplexField
484
- from sympy.polys.domains.algebraicfield import AlgebraicField
485
- #from sympy.polys.domains.polynomialring import PolynomialRing
486
- #from sympy.polys.domains.fractionfield import FractionField
487
- from sympy.polys.domains.expressiondomain import ExpressionDomain
488
-
489
- # TODO: fix pickling of RealElement
490
- # for c in (RealField, RealField(100)):
491
- # check(c)
492
-
493
- # TODO: fix pickling of ComplexElement
494
- # for c in (ComplexField, ComplexField(100)):
495
- # check(c)
496
-
497
- for c in (AlgebraicField, AlgebraicField(QQ, sqrt(3))):
498
- check(c, check_attr=False)
499
-
500
- # TODO: AssertionError
501
- # for c in (PolynomialRing, PolynomialRing(ZZ, "x,y,z")):
502
- # check(c)
503
-
504
- # TODO: AttributeError: 'PolyElement' object has no attribute 'ring'
505
- # for c in (FractionField, FractionField(ZZ, "x,y,z")):
506
- # check(c)
507
-
508
- for c in (ExpressionDomain, ExpressionDomain()):
509
- check(c, check_attr=False)
510
-
511
-
512
- def test_pickling_polys_orderings():
513
- from sympy.polys.orderings import (LexOrder, GradedLexOrder,
514
- ReversedGradedLexOrder, InverseOrder)
515
- # from sympy.polys.orderings import ProductOrder
516
-
517
- for c in (LexOrder, LexOrder()):
518
- check(c)
519
-
520
- for c in (GradedLexOrder, GradedLexOrder()):
521
- check(c)
522
-
523
- for c in (ReversedGradedLexOrder, ReversedGradedLexOrder()):
524
- check(c)
525
-
526
- # TODO: Argh, Python is so naive. No lambdas nor inner function support in
527
- # pickling module. Maybe someone could figure out what to do with this.
528
- #
529
- # for c in (ProductOrder, ProductOrder((LexOrder(), lambda m: m[:2]),
530
- # (GradedLexOrder(), lambda m: m[2:]))):
531
- # check(c)
532
-
533
- for c in (InverseOrder, InverseOrder(LexOrder())):
534
- check(c)
535
-
536
- def test_pickling_polys_monomials():
537
- from sympy.polys.monomials import MonomialOps, Monomial
538
- x, y, z = symbols("x,y,z")
539
-
540
- for c in (MonomialOps, MonomialOps(3)):
541
- check(c)
542
-
543
- for c in (Monomial, Monomial((1, 2, 3), (x, y, z))):
544
- check(c)
545
-
546
- def test_pickling_polys_errors():
547
- from sympy.polys.polyerrors import (HeuristicGCDFailed,
548
- HomomorphismFailed, IsomorphismFailed, ExtraneousFactors,
549
- EvaluationFailed, RefinementFailed, CoercionFailed, NotInvertible,
550
- NotReversible, NotAlgebraic, DomainError, PolynomialError,
551
- UnificationFailed, GeneratorsError, GeneratorsNeeded,
552
- UnivariatePolynomialError, MultivariatePolynomialError, OptionError,
553
- FlagError)
554
- # from sympy.polys.polyerrors import (ExactQuotientFailed,
555
- # OperationNotSupported, ComputationFailed, PolificationFailed)
556
-
557
- # x = Symbol('x')
558
-
559
- # TODO: TypeError: __init__() takes at least 3 arguments (1 given)
560
- # for c in (ExactQuotientFailed, ExactQuotientFailed(x, 3*x, ZZ)):
561
- # check(c)
562
-
563
- # TODO: TypeError: can't pickle instancemethod objects
564
- # for c in (OperationNotSupported, OperationNotSupported(Poly(x), Poly.gcd)):
565
- # check(c)
566
-
567
- for c in (HeuristicGCDFailed, HeuristicGCDFailed()):
568
- check(c)
569
-
570
- for c in (HomomorphismFailed, HomomorphismFailed()):
571
- check(c)
572
-
573
- for c in (IsomorphismFailed, IsomorphismFailed()):
574
- check(c)
575
-
576
- for c in (ExtraneousFactors, ExtraneousFactors()):
577
- check(c)
578
-
579
- for c in (EvaluationFailed, EvaluationFailed()):
580
- check(c)
581
-
582
- for c in (RefinementFailed, RefinementFailed()):
583
- check(c)
584
-
585
- for c in (CoercionFailed, CoercionFailed()):
586
- check(c)
587
-
588
- for c in (NotInvertible, NotInvertible()):
589
- check(c)
590
-
591
- for c in (NotReversible, NotReversible()):
592
- check(c)
593
-
594
- for c in (NotAlgebraic, NotAlgebraic()):
595
- check(c)
596
-
597
- for c in (DomainError, DomainError()):
598
- check(c)
599
-
600
- for c in (PolynomialError, PolynomialError()):
601
- check(c)
602
-
603
- for c in (UnificationFailed, UnificationFailed()):
604
- check(c)
605
-
606
- for c in (GeneratorsError, GeneratorsError()):
607
- check(c)
608
-
609
- for c in (GeneratorsNeeded, GeneratorsNeeded()):
610
- check(c)
611
-
612
- # TODO: PicklingError: Can't pickle <function <lambda> at 0x38578c0>: it's not found as __main__.<lambda>
613
- # for c in (ComputationFailed, ComputationFailed(lambda t: t, 3, None)):
614
- # check(c)
615
-
616
- for c in (UnivariatePolynomialError, UnivariatePolynomialError()):
617
- check(c)
618
-
619
- for c in (MultivariatePolynomialError, MultivariatePolynomialError()):
620
- check(c)
621
-
622
- # TODO: TypeError: __init__() takes at least 3 arguments (1 given)
623
- # for c in (PolificationFailed, PolificationFailed({}, x, x, False)):
624
- # check(c)
625
-
626
- for c in (OptionError, OptionError()):
627
- check(c)
628
-
629
- for c in (FlagError, FlagError()):
630
- check(c)
631
-
632
- #def test_pickling_polys_options():
633
- #from sympy.polys.polyoptions import Options
634
-
635
- # TODO: fix pickling of `symbols' flag
636
- # for c in (Options, Options((), dict(domain='ZZ', polys=False))):
637
- # check(c)
638
-
639
- # TODO: def test_pickling_polys_rootisolation():
640
- # RealInterval
641
- # ComplexInterval
642
-
643
- def test_pickling_polys_rootoftools():
644
- from sympy.polys.rootoftools import CRootOf, RootSum
645
-
646
- x = Symbol('x')
647
- f = x**3 + x + 3
648
-
649
- for c in (CRootOf, CRootOf(f, 0)):
650
- check(c)
651
-
652
- for c in (RootSum, RootSum(f, exp)):
653
- check(c)
654
-
655
- #================== printing ====================
656
- from sympy.printing.latex import LatexPrinter
657
- from sympy.printing.mathml import MathMLContentPrinter, MathMLPresentationPrinter
658
- from sympy.printing.pretty.pretty import PrettyPrinter
659
- from sympy.printing.pretty.stringpict import prettyForm, stringPict
660
- from sympy.printing.printer import Printer
661
- from sympy.printing.python import PythonPrinter
662
-
663
-
664
- def test_printing():
665
- for c in (LatexPrinter, LatexPrinter(), MathMLContentPrinter,
666
- MathMLPresentationPrinter, PrettyPrinter, prettyForm, stringPict,
667
- stringPict("a"), Printer, Printer(), PythonPrinter,
668
- PythonPrinter()):
669
- check(c)
670
-
671
-
672
- @XFAIL
673
- def test_printing1():
674
- check(MathMLContentPrinter())
675
-
676
-
677
- @XFAIL
678
- def test_printing2():
679
- check(MathMLPresentationPrinter())
680
-
681
-
682
- @XFAIL
683
- def test_printing3():
684
- check(PrettyPrinter())
685
-
686
- #================== series ======================
687
- from sympy.series.limits import Limit
688
- from sympy.series.order import Order
689
-
690
-
691
- def test_series():
692
- e = Symbol("e")
693
- x = Symbol("x")
694
- for c in (Limit, Limit(e, x, 1), Order, Order(e)):
695
- check(c)
696
-
697
- #================== concrete ==================
698
- from sympy.concrete.products import Product
699
- from sympy.concrete.summations import Sum
700
-
701
-
702
- def test_concrete():
703
- x = Symbol("x")
704
- for c in (Product, Product(x, (x, 2, 4)), Sum, Sum(x, (x, 2, 4))):
705
- check(c)
706
-
707
- def test_deprecation_warning():
708
- w = SymPyDeprecationWarning("message", deprecated_since_version='1.0', active_deprecations_target="active-deprecations")
709
- check(w)
710
-
711
- def test_issue_18438():
712
- assert pickle.loads(pickle.dumps(S.Half)) == S.Half
713
-
714
-
715
- #================= old pickles =================
716
- def test_unpickle_from_older_versions():
717
- data = (
718
- b'\x80\x04\x95^\x00\x00\x00\x00\x00\x00\x00\x8c\x10sympy.core.power'
719
- b'\x94\x8c\x03Pow\x94\x93\x94\x8c\x12sympy.core.numbers\x94\x8c'
720
- b'\x07Integer\x94\x93\x94K\x02\x85\x94R\x94}\x94bh\x03\x8c\x04Half'
721
- b'\x94\x93\x94)R\x94}\x94b\x86\x94R\x94}\x94b.'
722
- )
723
- assert pickle.loads(data) == sqrt(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_source.py DELETED
@@ -1,11 +0,0 @@
1
- from sympy.utilities.source import get_mod_func, get_class
2
-
3
-
4
- def test_get_mod_func():
5
- assert get_mod_func(
6
- 'sympy.core.basic.Basic') == ('sympy.core.basic', 'Basic')
7
-
8
-
9
- def test_get_class():
10
- _basic = get_class('sympy.core.basic.Basic')
11
- assert _basic.__name__ == 'Basic'