chore: remove stray .venv files (1400-1482)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .venv/lib/python3.13/site-packages/sympy/unify/tests/__init__.py +0 -0
- .venv/lib/python3.13/site-packages/sympy/unify/tests/test_rewrite.py +0 -74
- .venv/lib/python3.13/site-packages/sympy/unify/tests/test_sympy.py +0 -162
- .venv/lib/python3.13/site-packages/sympy/unify/tests/test_unify.py +0 -88
- .venv/lib/python3.13/site-packages/sympy/utilities/__init__.py +0 -30
- .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/__init__.py +0 -22
- .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/availability.py +0 -77
- .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/compilation.py +0 -657
- .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/runners.py +0 -301
- .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/__init__.py +0 -0
- .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/test_compilation.py +0 -104
- .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/util.py +0 -312
- .venv/lib/python3.13/site-packages/sympy/utilities/autowrap.py +0 -1178
- .venv/lib/python3.13/site-packages/sympy/utilities/codegen.py +0 -2237
- .venv/lib/python3.13/site-packages/sympy/utilities/decorator.py +0 -339
- .venv/lib/python3.13/site-packages/sympy/utilities/enumerative.py +0 -1155
- .venv/lib/python3.13/site-packages/sympy/utilities/exceptions.py +0 -271
- .venv/lib/python3.13/site-packages/sympy/utilities/iterables.py +0 -3179
- .venv/lib/python3.13/site-packages/sympy/utilities/lambdify.py +0 -1592
- .venv/lib/python3.13/site-packages/sympy/utilities/magic.py +0 -12
- .venv/lib/python3.13/site-packages/sympy/utilities/matchpy_connector.py +0 -340
- .venv/lib/python3.13/site-packages/sympy/utilities/mathml/__init__.py +0 -122
- .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/__init__.py +0 -0
- .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmlctop.xsl +0 -0
- .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmltex.xsl +0 -0
- .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/simple_mmlctop.xsl +0 -0
- .venv/lib/python3.13/site-packages/sympy/utilities/memoization.py +0 -76
- .venv/lib/python3.13/site-packages/sympy/utilities/misc.py +0 -564
- .venv/lib/python3.13/site-packages/sympy/utilities/pkgdata.py +0 -33
- .venv/lib/python3.13/site-packages/sympy/utilities/pytest.py +0 -12
- .venv/lib/python3.13/site-packages/sympy/utilities/randtest.py +0 -12
- .venv/lib/python3.13/site-packages/sympy/utilities/runtests.py +0 -13
- .venv/lib/python3.13/site-packages/sympy/utilities/source.py +0 -40
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/__init__.py +0 -0
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_autowrap.py +0 -467
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen.py +0 -1632
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_julia.py +0 -620
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_octave.py +0 -589
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_rust.py +0 -401
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_decorator.py +0 -129
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_deprecated.py +0 -13
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_enumerative.py +0 -179
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_exceptions.py +0 -12
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_iterables.py +0 -945
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_lambdify.py +0 -2263
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_matchpy_connector.py +0 -164
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_mathml.py +0 -33
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_misc.py +0 -148
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_pickling.py +0 -723
- .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_source.py +0 -11
.venv/lib/python3.13/site-packages/sympy/unify/tests/__init__.py
DELETED
|
File without changes
|
.venv/lib/python3.13/site-packages/sympy/unify/tests/test_rewrite.py
DELETED
|
@@ -1,74 +0,0 @@
|
|
| 1 |
-
from sympy.unify.rewrite import rewriterule
|
| 2 |
-
from sympy.core.basic import Basic
|
| 3 |
-
from sympy.core.singleton import S
|
| 4 |
-
from sympy.core.symbol import Symbol
|
| 5 |
-
from sympy.functions.elementary.trigonometric import sin
|
| 6 |
-
from sympy.abc import x, y
|
| 7 |
-
from sympy.strategies.rl import rebuild
|
| 8 |
-
from sympy.assumptions import Q
|
| 9 |
-
|
| 10 |
-
p, q = Symbol('p'), Symbol('q')
|
| 11 |
-
|
| 12 |
-
def test_simple():
|
| 13 |
-
rl = rewriterule(Basic(p, S(1)), Basic(p, S(2)), variables=(p,))
|
| 14 |
-
assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))]
|
| 15 |
-
|
| 16 |
-
p1 = p**2
|
| 17 |
-
p2 = p**3
|
| 18 |
-
rl = rewriterule(p1, p2, variables=(p,))
|
| 19 |
-
|
| 20 |
-
expr = x**2
|
| 21 |
-
assert list(rl(expr)) == [x**3]
|
| 22 |
-
|
| 23 |
-
def test_simple_variables():
|
| 24 |
-
rl = rewriterule(Basic(x, S(1)), Basic(x, S(2)), variables=(x,))
|
| 25 |
-
assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))]
|
| 26 |
-
|
| 27 |
-
rl = rewriterule(x**2, x**3, variables=(x,))
|
| 28 |
-
assert list(rl(y**2)) == [y**3]
|
| 29 |
-
|
| 30 |
-
def test_moderate():
|
| 31 |
-
p1 = p**2 + q**3
|
| 32 |
-
p2 = (p*q)**4
|
| 33 |
-
rl = rewriterule(p1, p2, (p, q))
|
| 34 |
-
|
| 35 |
-
expr = x**2 + y**3
|
| 36 |
-
assert list(rl(expr)) == [(x*y)**4]
|
| 37 |
-
|
| 38 |
-
def test_sincos():
|
| 39 |
-
p1 = sin(p)**2 + sin(p)**2
|
| 40 |
-
p2 = 1
|
| 41 |
-
rl = rewriterule(p1, p2, (p, q))
|
| 42 |
-
|
| 43 |
-
assert list(rl(sin(x)**2 + sin(x)**2)) == [1]
|
| 44 |
-
assert list(rl(sin(y)**2 + sin(y)**2)) == [1]
|
| 45 |
-
|
| 46 |
-
def test_Exprs_ok():
|
| 47 |
-
rl = rewriterule(p+q, q+p, (p, q))
|
| 48 |
-
next(rl(x+y)).is_commutative
|
| 49 |
-
str(next(rl(x+y)))
|
| 50 |
-
|
| 51 |
-
def test_condition_simple():
|
| 52 |
-
rl = rewriterule(x, x+1, [x], lambda x: x < 10)
|
| 53 |
-
assert not list(rl(S(15)))
|
| 54 |
-
assert rebuild(next(rl(S(5)))) == 6
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def test_condition_multiple():
|
| 58 |
-
rl = rewriterule(x + y, x**y, [x,y], lambda x, y: x.is_integer)
|
| 59 |
-
|
| 60 |
-
a = Symbol('a')
|
| 61 |
-
b = Symbol('b', integer=True)
|
| 62 |
-
expr = a + b
|
| 63 |
-
assert list(rl(expr)) == [b**a]
|
| 64 |
-
|
| 65 |
-
c = Symbol('c', integer=True)
|
| 66 |
-
d = Symbol('d', integer=True)
|
| 67 |
-
assert set(rl(c + d)) == {c**d, d**c}
|
| 68 |
-
|
| 69 |
-
def test_assumptions():
|
| 70 |
-
rl = rewriterule(x + y, x**y, [x, y], assume=Q.integer(x))
|
| 71 |
-
|
| 72 |
-
a, b = map(Symbol, 'ab')
|
| 73 |
-
expr = a + b
|
| 74 |
-
assert list(rl(expr, Q.integer(b))) == [b**a]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/unify/tests/test_sympy.py
DELETED
|
@@ -1,162 +0,0 @@
|
|
| 1 |
-
from sympy.core.add import Add
|
| 2 |
-
from sympy.core.basic import Basic
|
| 3 |
-
from sympy.core.containers import Tuple
|
| 4 |
-
from sympy.core.singleton import S
|
| 5 |
-
from sympy.core.symbol import (Symbol, symbols)
|
| 6 |
-
from sympy.logic.boolalg import And
|
| 7 |
-
from sympy.core.symbol import Str
|
| 8 |
-
from sympy.unify.core import Compound, Variable
|
| 9 |
-
from sympy.unify.usympy import (deconstruct, construct, unify, is_associative,
|
| 10 |
-
is_commutative)
|
| 11 |
-
from sympy.abc import x, y, z, n
|
| 12 |
-
|
| 13 |
-
def test_deconstruct():
|
| 14 |
-
expr = Basic(S(1), S(2), S(3))
|
| 15 |
-
expected = Compound(Basic, (1, 2, 3))
|
| 16 |
-
assert deconstruct(expr) == expected
|
| 17 |
-
|
| 18 |
-
assert deconstruct(1) == 1
|
| 19 |
-
assert deconstruct(x) == x
|
| 20 |
-
assert deconstruct(x, variables=(x,)) == Variable(x)
|
| 21 |
-
assert deconstruct(Add(1, x, evaluate=False)) == Compound(Add, (1, x))
|
| 22 |
-
assert deconstruct(Add(1, x, evaluate=False), variables=(x,)) == \
|
| 23 |
-
Compound(Add, (1, Variable(x)))
|
| 24 |
-
|
| 25 |
-
def test_construct():
|
| 26 |
-
expr = Compound(Basic, (S(1), S(2), S(3)))
|
| 27 |
-
expected = Basic(S(1), S(2), S(3))
|
| 28 |
-
assert construct(expr) == expected
|
| 29 |
-
|
| 30 |
-
def test_nested():
|
| 31 |
-
expr = Basic(S(1), Basic(S(2)), S(3))
|
| 32 |
-
cmpd = Compound(Basic, (S(1), Compound(Basic, Tuple(2)), S(3)))
|
| 33 |
-
assert deconstruct(expr) == cmpd
|
| 34 |
-
assert construct(cmpd) == expr
|
| 35 |
-
|
| 36 |
-
def test_unify():
|
| 37 |
-
expr = Basic(S(1), S(2), S(3))
|
| 38 |
-
a, b, c = map(Symbol, 'abc')
|
| 39 |
-
pattern = Basic(a, b, c)
|
| 40 |
-
assert list(unify(expr, pattern, {}, (a, b, c))) == [{a: 1, b: 2, c: 3}]
|
| 41 |
-
assert list(unify(expr, pattern, variables=(a, b, c))) == \
|
| 42 |
-
[{a: 1, b: 2, c: 3}]
|
| 43 |
-
|
| 44 |
-
def test_unify_variables():
|
| 45 |
-
assert list(unify(Basic(S(1), S(2)), Basic(S(1), x), {}, variables=(x,))) == [{x: 2}]
|
| 46 |
-
|
| 47 |
-
def test_s_input():
|
| 48 |
-
expr = Basic(S(1), S(2))
|
| 49 |
-
a, b = map(Symbol, 'ab')
|
| 50 |
-
pattern = Basic(a, b)
|
| 51 |
-
assert list(unify(expr, pattern, {}, (a, b))) == [{a: 1, b: 2}]
|
| 52 |
-
assert list(unify(expr, pattern, {a: 5}, (a, b))) == []
|
| 53 |
-
|
| 54 |
-
def iterdicteq(a, b):
|
| 55 |
-
a = tuple(a)
|
| 56 |
-
b = tuple(b)
|
| 57 |
-
return len(a) == len(b) and all(x in b for x in a)
|
| 58 |
-
|
| 59 |
-
def test_unify_commutative():
|
| 60 |
-
expr = Add(1, 2, 3, evaluate=False)
|
| 61 |
-
a, b, c = map(Symbol, 'abc')
|
| 62 |
-
pattern = Add(a, b, c, evaluate=False)
|
| 63 |
-
|
| 64 |
-
result = tuple(unify(expr, pattern, {}, (a, b, c)))
|
| 65 |
-
expected = ({a: 1, b: 2, c: 3},
|
| 66 |
-
{a: 1, b: 3, c: 2},
|
| 67 |
-
{a: 2, b: 1, c: 3},
|
| 68 |
-
{a: 2, b: 3, c: 1},
|
| 69 |
-
{a: 3, b: 1, c: 2},
|
| 70 |
-
{a: 3, b: 2, c: 1})
|
| 71 |
-
|
| 72 |
-
assert iterdicteq(result, expected)
|
| 73 |
-
|
| 74 |
-
def test_unify_iter():
|
| 75 |
-
expr = Add(1, 2, 3, evaluate=False)
|
| 76 |
-
a, b, c = map(Symbol, 'abc')
|
| 77 |
-
pattern = Add(a, c, evaluate=False)
|
| 78 |
-
assert is_associative(deconstruct(pattern))
|
| 79 |
-
assert is_commutative(deconstruct(pattern))
|
| 80 |
-
|
| 81 |
-
result = list(unify(expr, pattern, {}, (a, c)))
|
| 82 |
-
expected = [{a: 1, c: Add(2, 3, evaluate=False)},
|
| 83 |
-
{a: 1, c: Add(3, 2, evaluate=False)},
|
| 84 |
-
{a: 2, c: Add(1, 3, evaluate=False)},
|
| 85 |
-
{a: 2, c: Add(3, 1, evaluate=False)},
|
| 86 |
-
{a: 3, c: Add(1, 2, evaluate=False)},
|
| 87 |
-
{a: 3, c: Add(2, 1, evaluate=False)},
|
| 88 |
-
{a: Add(1, 2, evaluate=False), c: 3},
|
| 89 |
-
{a: Add(2, 1, evaluate=False), c: 3},
|
| 90 |
-
{a: Add(1, 3, evaluate=False), c: 2},
|
| 91 |
-
{a: Add(3, 1, evaluate=False), c: 2},
|
| 92 |
-
{a: Add(2, 3, evaluate=False), c: 1},
|
| 93 |
-
{a: Add(3, 2, evaluate=False), c: 1}]
|
| 94 |
-
|
| 95 |
-
assert iterdicteq(result, expected)
|
| 96 |
-
|
| 97 |
-
def test_hard_match():
|
| 98 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 99 |
-
expr = sin(x) + cos(x)**2
|
| 100 |
-
p, q = map(Symbol, 'pq')
|
| 101 |
-
pattern = sin(p) + cos(p)**2
|
| 102 |
-
assert list(unify(expr, pattern, {}, (p, q))) == [{p: x}]
|
| 103 |
-
|
| 104 |
-
def test_matrix():
|
| 105 |
-
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
| 106 |
-
X = MatrixSymbol('X', n, n)
|
| 107 |
-
Y = MatrixSymbol('Y', 2, 2)
|
| 108 |
-
Z = MatrixSymbol('Z', 2, 3)
|
| 109 |
-
assert list(unify(X, Y, {}, variables=[n, Str('X')])) == [{Str('X'): Str('Y'), n: 2}]
|
| 110 |
-
assert list(unify(X, Z, {}, variables=[n, Str('X')])) == []
|
| 111 |
-
|
| 112 |
-
def test_non_frankenAdds():
|
| 113 |
-
# the is_commutative property used to fail because of Basic.__new__
|
| 114 |
-
# This caused is_commutative and str calls to fail
|
| 115 |
-
expr = x+y*2
|
| 116 |
-
rebuilt = construct(deconstruct(expr))
|
| 117 |
-
# Ensure that we can run these commands without causing an error
|
| 118 |
-
str(rebuilt)
|
| 119 |
-
rebuilt.is_commutative
|
| 120 |
-
|
| 121 |
-
def test_FiniteSet_commutivity():
|
| 122 |
-
from sympy.sets.sets import FiniteSet
|
| 123 |
-
a, b, c, x, y = symbols('a,b,c,x,y')
|
| 124 |
-
s = FiniteSet(a, b, c)
|
| 125 |
-
t = FiniteSet(x, y)
|
| 126 |
-
variables = (x, y)
|
| 127 |
-
assert {x: FiniteSet(a, c), y: b} in tuple(unify(s, t, variables=variables))
|
| 128 |
-
|
| 129 |
-
def test_FiniteSet_complex():
|
| 130 |
-
from sympy.sets.sets import FiniteSet
|
| 131 |
-
a, b, c, x, y, z = symbols('a,b,c,x,y,z')
|
| 132 |
-
expr = FiniteSet(Basic(S(1), x), y, Basic(x, z))
|
| 133 |
-
pattern = FiniteSet(a, Basic(x, b))
|
| 134 |
-
variables = a, b
|
| 135 |
-
expected = ({b: 1, a: FiniteSet(y, Basic(x, z))},
|
| 136 |
-
{b: z, a: FiniteSet(y, Basic(S(1), x))})
|
| 137 |
-
assert iterdicteq(unify(expr, pattern, variables=variables), expected)
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
def test_and():
|
| 141 |
-
variables = x, y
|
| 142 |
-
expected = ({x: z > 0, y: n < 3},)
|
| 143 |
-
assert iterdicteq(unify((z>0) & (n<3), And(x, y), variables=variables),
|
| 144 |
-
expected)
|
| 145 |
-
|
| 146 |
-
def test_Union():
|
| 147 |
-
from sympy.sets.sets import Interval
|
| 148 |
-
assert list(unify(Interval(0, 1) + Interval(10, 11),
|
| 149 |
-
Interval(0, 1) + Interval(12, 13),
|
| 150 |
-
variables=(Interval(12, 13),)))
|
| 151 |
-
|
| 152 |
-
def test_is_commutative():
|
| 153 |
-
assert is_commutative(deconstruct(x+y))
|
| 154 |
-
assert is_commutative(deconstruct(x*y))
|
| 155 |
-
assert not is_commutative(deconstruct(x**y))
|
| 156 |
-
|
| 157 |
-
def test_commutative_in_commutative():
|
| 158 |
-
from sympy.abc import a,b,c,d
|
| 159 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 160 |
-
eq = sin(3)*sin(4)*sin(5) + 4*cos(3)*cos(4)
|
| 161 |
-
pat = a*cos(b)*cos(c) + d*sin(b)*sin(c)
|
| 162 |
-
assert next(unify(eq, pat, variables=(a,b,c,d)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/unify/tests/test_unify.py
DELETED
|
@@ -1,88 +0,0 @@
|
|
| 1 |
-
from sympy.unify.core import Compound, Variable, CondVariable, allcombinations
|
| 2 |
-
from sympy.unify import core
|
| 3 |
-
|
| 4 |
-
a,b,c = 'a', 'b', 'c'
|
| 5 |
-
w,x,y,z = map(Variable, 'wxyz')
|
| 6 |
-
|
| 7 |
-
C = Compound
|
| 8 |
-
|
| 9 |
-
def is_associative(x):
|
| 10 |
-
return isinstance(x, Compound) and (x.op in ('Add', 'Mul', 'CAdd', 'CMul'))
|
| 11 |
-
def is_commutative(x):
|
| 12 |
-
return isinstance(x, Compound) and (x.op in ('CAdd', 'CMul'))
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def unify(a, b, s={}):
|
| 16 |
-
return core.unify(a, b, s=s, is_associative=is_associative,
|
| 17 |
-
is_commutative=is_commutative)
|
| 18 |
-
|
| 19 |
-
def test_basic():
|
| 20 |
-
assert list(unify(a, x, {})) == [{x: a}]
|
| 21 |
-
assert list(unify(a, x, {x: 10})) == []
|
| 22 |
-
assert list(unify(1, x, {})) == [{x: 1}]
|
| 23 |
-
assert list(unify(a, a, {})) == [{}]
|
| 24 |
-
assert list(unify((w, x), (y, z), {})) == [{w: y, x: z}]
|
| 25 |
-
assert list(unify(x, (a, b), {})) == [{x: (a, b)}]
|
| 26 |
-
|
| 27 |
-
assert list(unify((a, b), (x, x), {})) == []
|
| 28 |
-
assert list(unify((y, z), (x, x), {}))!= []
|
| 29 |
-
assert list(unify((a, (b, c)), (a, (x, y)), {})) == [{x: b, y: c}]
|
| 30 |
-
|
| 31 |
-
def test_ops():
|
| 32 |
-
assert list(unify(C('Add', (a,b,c)), C('Add', (a,x,y)), {})) == \
|
| 33 |
-
[{x:b, y:c}]
|
| 34 |
-
assert list(unify(C('Add', (C('Mul', (1,2)), b,c)), C('Add', (x,y,c)), {})) == \
|
| 35 |
-
[{x: C('Mul', (1,2)), y:b}]
|
| 36 |
-
|
| 37 |
-
def test_associative():
|
| 38 |
-
c1 = C('Add', (1,2,3))
|
| 39 |
-
c2 = C('Add', (x,y))
|
| 40 |
-
assert tuple(unify(c1, c2, {})) == ({x: 1, y: C('Add', (2, 3))},
|
| 41 |
-
{x: C('Add', (1, 2)), y: 3})
|
| 42 |
-
|
| 43 |
-
def test_commutative():
|
| 44 |
-
c1 = C('CAdd', (1,2,3))
|
| 45 |
-
c2 = C('CAdd', (x,y))
|
| 46 |
-
result = list(unify(c1, c2, {}))
|
| 47 |
-
assert {x: 1, y: C('CAdd', (2, 3))} in result
|
| 48 |
-
assert ({x: 2, y: C('CAdd', (1, 3))} in result or
|
| 49 |
-
{x: 2, y: C('CAdd', (3, 1))} in result)
|
| 50 |
-
|
| 51 |
-
def _test_combinations_assoc():
|
| 52 |
-
assert set(allcombinations((1,2,3), (a,b), True)) == \
|
| 53 |
-
{(((1, 2), (3,)), (a, b)), (((1,), (2, 3)), (a, b))}
|
| 54 |
-
|
| 55 |
-
def _test_combinations_comm():
|
| 56 |
-
assert set(allcombinations((1,2,3), (a,b), None)) == \
|
| 57 |
-
{(((1,), (2, 3)), ('a', 'b')), (((2,), (3, 1)), ('a', 'b')),
|
| 58 |
-
(((3,), (1, 2)), ('a', 'b')), (((1, 2), (3,)), ('a', 'b')),
|
| 59 |
-
(((2, 3), (1,)), ('a', 'b')), (((3, 1), (2,)), ('a', 'b'))}
|
| 60 |
-
|
| 61 |
-
def test_allcombinations():
|
| 62 |
-
assert set(allcombinations((1,2), (1,2), 'commutative')) ==\
|
| 63 |
-
{(((1,),(2,)), ((1,),(2,))), (((1,),(2,)), ((2,),(1,)))}
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def test_commutativity():
|
| 67 |
-
c1 = Compound('CAdd', (a, b))
|
| 68 |
-
c2 = Compound('CAdd', (x, y))
|
| 69 |
-
assert is_commutative(c1) and is_commutative(c2)
|
| 70 |
-
assert len(list(unify(c1, c2, {}))) == 2
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def test_CondVariable():
|
| 74 |
-
expr = C('CAdd', (1, 2))
|
| 75 |
-
x = Variable('x')
|
| 76 |
-
y = CondVariable('y', lambda a: a % 2 == 0)
|
| 77 |
-
z = CondVariable('z', lambda a: a > 3)
|
| 78 |
-
pattern = C('CAdd', (x, y))
|
| 79 |
-
assert list(unify(expr, pattern, {})) == \
|
| 80 |
-
[{x: 1, y: 2}]
|
| 81 |
-
|
| 82 |
-
z = CondVariable('z', lambda a: a > 3)
|
| 83 |
-
pattern = C('CAdd', (z, y))
|
| 84 |
-
|
| 85 |
-
assert list(unify(expr, pattern, {})) == []
|
| 86 |
-
|
| 87 |
-
def test_defaultdict():
|
| 88 |
-
assert next(unify(Variable('x'), 'foo')) == {Variable('x'): 'foo'}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/__init__.py
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
"""This module contains some general purpose utilities that are used across
|
| 2 |
-
SymPy.
|
| 3 |
-
"""
|
| 4 |
-
from .iterables import (flatten, group, take, subsets,
|
| 5 |
-
variations, numbered_symbols, cartes, capture, dict_merge,
|
| 6 |
-
prefixes, postfixes, sift, topological_sort, unflatten,
|
| 7 |
-
has_dups, has_variety, reshape, rotations)
|
| 8 |
-
|
| 9 |
-
from .misc import filldedent
|
| 10 |
-
|
| 11 |
-
from .lambdify import lambdify
|
| 12 |
-
|
| 13 |
-
from .decorator import threaded, xthreaded, public, memoize_property
|
| 14 |
-
|
| 15 |
-
from .timeutils import timed
|
| 16 |
-
|
| 17 |
-
__all__ = [
|
| 18 |
-
'flatten', 'group', 'take', 'subsets', 'variations', 'numbered_symbols',
|
| 19 |
-
'cartes', 'capture', 'dict_merge', 'prefixes', 'postfixes', 'sift',
|
| 20 |
-
'topological_sort', 'unflatten', 'has_dups', 'has_variety', 'reshape',
|
| 21 |
-
'rotations',
|
| 22 |
-
|
| 23 |
-
'filldedent',
|
| 24 |
-
|
| 25 |
-
'lambdify',
|
| 26 |
-
|
| 27 |
-
'threaded', 'xthreaded', 'public', 'memoize_property',
|
| 28 |
-
|
| 29 |
-
'timed',
|
| 30 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/__init__.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
""" This sub-module is private, i.e. external code should not depend on it.
|
| 2 |
-
|
| 3 |
-
These functions are used by tests run as part of continuous integration.
|
| 4 |
-
Once the implementation is mature (it should support the major
|
| 5 |
-
platforms: Windows, OS X & Linux) it may become official API which
|
| 6 |
-
may be relied upon by downstream libraries. Until then API may break
|
| 7 |
-
without prior notice.
|
| 8 |
-
|
| 9 |
-
TODO:
|
| 10 |
-
- (optionally) clean up after tempfile.mkdtemp()
|
| 11 |
-
- cross-platform testing
|
| 12 |
-
- caching of compiler choice and intermediate files
|
| 13 |
-
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
from .compilation import compile_link_import_strings, compile_run_strings
|
| 17 |
-
from .availability import has_fortran, has_c, has_cxx
|
| 18 |
-
|
| 19 |
-
__all__ = [
|
| 20 |
-
'compile_link_import_strings', 'compile_run_strings',
|
| 21 |
-
'has_fortran', 'has_c', 'has_cxx',
|
| 22 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/availability.py
DELETED
|
@@ -1,77 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from .compilation import compile_run_strings
|
| 3 |
-
from .util import CompilerNotFoundError
|
| 4 |
-
|
| 5 |
-
def has_fortran():
|
| 6 |
-
if not hasattr(has_fortran, 'result'):
|
| 7 |
-
try:
|
| 8 |
-
(stdout, stderr), info = compile_run_strings(
|
| 9 |
-
[('main.f90', (
|
| 10 |
-
'program foo\n'
|
| 11 |
-
'print *, "hello world"\n'
|
| 12 |
-
'end program'
|
| 13 |
-
))], clean=True
|
| 14 |
-
)
|
| 15 |
-
except CompilerNotFoundError:
|
| 16 |
-
has_fortran.result = False
|
| 17 |
-
if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
|
| 18 |
-
raise
|
| 19 |
-
else:
|
| 20 |
-
if info['exit_status'] != os.EX_OK or 'hello world' not in stdout:
|
| 21 |
-
if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
|
| 22 |
-
raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr))
|
| 23 |
-
has_fortran.result = False
|
| 24 |
-
else:
|
| 25 |
-
has_fortran.result = True
|
| 26 |
-
return has_fortran.result
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def has_c():
|
| 30 |
-
if not hasattr(has_c, 'result'):
|
| 31 |
-
try:
|
| 32 |
-
(stdout, stderr), info = compile_run_strings(
|
| 33 |
-
[('main.c', (
|
| 34 |
-
'#include <stdio.h>\n'
|
| 35 |
-
'int main(){\n'
|
| 36 |
-
'printf("hello world\\n");\n'
|
| 37 |
-
'return 0;\n'
|
| 38 |
-
'}'
|
| 39 |
-
))], clean=True
|
| 40 |
-
)
|
| 41 |
-
except CompilerNotFoundError:
|
| 42 |
-
has_c.result = False
|
| 43 |
-
if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
|
| 44 |
-
raise
|
| 45 |
-
else:
|
| 46 |
-
if info['exit_status'] != os.EX_OK or 'hello world' not in stdout:
|
| 47 |
-
if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
|
| 48 |
-
raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr))
|
| 49 |
-
has_c.result = False
|
| 50 |
-
else:
|
| 51 |
-
has_c.result = True
|
| 52 |
-
return has_c.result
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def has_cxx():
|
| 56 |
-
if not hasattr(has_cxx, 'result'):
|
| 57 |
-
try:
|
| 58 |
-
(stdout, stderr), info = compile_run_strings(
|
| 59 |
-
[('main.cxx', (
|
| 60 |
-
'#include <iostream>\n'
|
| 61 |
-
'int main(){\n'
|
| 62 |
-
'std::cout << "hello world" << std::endl;\n'
|
| 63 |
-
'}'
|
| 64 |
-
))], clean=True
|
| 65 |
-
)
|
| 66 |
-
except CompilerNotFoundError:
|
| 67 |
-
has_cxx.result = False
|
| 68 |
-
if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
|
| 69 |
-
raise
|
| 70 |
-
else:
|
| 71 |
-
if info['exit_status'] != os.EX_OK or 'hello world' not in stdout:
|
| 72 |
-
if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
|
| 73 |
-
raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr))
|
| 74 |
-
has_cxx.result = False
|
| 75 |
-
else:
|
| 76 |
-
has_cxx.result = True
|
| 77 |
-
return has_cxx.result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/compilation.py
DELETED
|
@@ -1,657 +0,0 @@
|
|
| 1 |
-
import glob
|
| 2 |
-
import os
|
| 3 |
-
import shutil
|
| 4 |
-
import subprocess
|
| 5 |
-
import sys
|
| 6 |
-
import tempfile
|
| 7 |
-
import warnings
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
from sysconfig import get_config_var, get_config_vars, get_path
|
| 10 |
-
|
| 11 |
-
from .runners import (
|
| 12 |
-
CCompilerRunner,
|
| 13 |
-
CppCompilerRunner,
|
| 14 |
-
FortranCompilerRunner
|
| 15 |
-
)
|
| 16 |
-
from .util import (
|
| 17 |
-
get_abspath, make_dirs, copy, Glob, ArbitraryDepthGlob,
|
| 18 |
-
glob_at_depth, import_module_from_file, pyx_is_cplus,
|
| 19 |
-
sha256_of_string, sha256_of_file, CompileError
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
if os.name == 'posix':
|
| 23 |
-
objext = '.o'
|
| 24 |
-
elif os.name == 'nt':
|
| 25 |
-
objext = '.obj'
|
| 26 |
-
else:
|
| 27 |
-
warnings.warn("Unknown os.name: {}".format(os.name))
|
| 28 |
-
objext = '.o'
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def compile_sources(files, Runner=None, destdir=None, cwd=None, keep_dir_struct=False,
|
| 32 |
-
per_file_kwargs=None, **kwargs):
|
| 33 |
-
""" Compile source code files to object files.
|
| 34 |
-
|
| 35 |
-
Parameters
|
| 36 |
-
==========
|
| 37 |
-
|
| 38 |
-
files : iterable of str
|
| 39 |
-
Paths to source files, if ``cwd`` is given, the paths are taken as relative.
|
| 40 |
-
Runner: CompilerRunner subclass (optional)
|
| 41 |
-
Could be e.g. ``FortranCompilerRunner``. Will be inferred from filename
|
| 42 |
-
extensions if missing.
|
| 43 |
-
destdir: str
|
| 44 |
-
Output directory, if cwd is given, the path is taken as relative.
|
| 45 |
-
cwd: str
|
| 46 |
-
Working directory. Specify to have compiler run in other directory.
|
| 47 |
-
also used as root of relative paths.
|
| 48 |
-
keep_dir_struct: bool
|
| 49 |
-
Reproduce directory structure in `destdir`. default: ``False``
|
| 50 |
-
per_file_kwargs: dict
|
| 51 |
-
Dict mapping instances in ``files`` to keyword arguments.
|
| 52 |
-
\\*\\*kwargs: dict
|
| 53 |
-
Default keyword arguments to pass to ``Runner``.
|
| 54 |
-
|
| 55 |
-
Returns
|
| 56 |
-
=======
|
| 57 |
-
List of strings (paths of object files).
|
| 58 |
-
"""
|
| 59 |
-
_per_file_kwargs = {}
|
| 60 |
-
|
| 61 |
-
if per_file_kwargs is not None:
|
| 62 |
-
for k, v in per_file_kwargs.items():
|
| 63 |
-
if isinstance(k, Glob):
|
| 64 |
-
for path in glob.glob(k.pathname):
|
| 65 |
-
_per_file_kwargs[path] = v
|
| 66 |
-
elif isinstance(k, ArbitraryDepthGlob):
|
| 67 |
-
for path in glob_at_depth(k.filename, cwd):
|
| 68 |
-
_per_file_kwargs[path] = v
|
| 69 |
-
else:
|
| 70 |
-
_per_file_kwargs[k] = v
|
| 71 |
-
|
| 72 |
-
# Set up destination directory
|
| 73 |
-
destdir = destdir or '.'
|
| 74 |
-
if not os.path.isdir(destdir):
|
| 75 |
-
if os.path.exists(destdir):
|
| 76 |
-
raise OSError("{} is not a directory".format(destdir))
|
| 77 |
-
else:
|
| 78 |
-
make_dirs(destdir)
|
| 79 |
-
if cwd is None:
|
| 80 |
-
cwd = '.'
|
| 81 |
-
for f in files:
|
| 82 |
-
copy(f, destdir, only_update=True, dest_is_dir=True)
|
| 83 |
-
|
| 84 |
-
# Compile files and return list of paths to the objects
|
| 85 |
-
dstpaths = []
|
| 86 |
-
for f in files:
|
| 87 |
-
if keep_dir_struct:
|
| 88 |
-
name, ext = os.path.splitext(f)
|
| 89 |
-
else:
|
| 90 |
-
name, ext = os.path.splitext(os.path.basename(f))
|
| 91 |
-
file_kwargs = kwargs.copy()
|
| 92 |
-
file_kwargs.update(_per_file_kwargs.get(f, {}))
|
| 93 |
-
dstpaths.append(src2obj(f, Runner, cwd=cwd, **file_kwargs))
|
| 94 |
-
return dstpaths
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def get_mixed_fort_c_linker(vendor=None, cplus=False, cwd=None):
|
| 98 |
-
vendor = vendor or os.environ.get('SYMPY_COMPILER_VENDOR', 'gnu')
|
| 99 |
-
|
| 100 |
-
if vendor.lower() == 'intel':
|
| 101 |
-
if cplus:
|
| 102 |
-
return (FortranCompilerRunner,
|
| 103 |
-
{'flags': ['-nofor_main', '-cxxlib']}, vendor)
|
| 104 |
-
else:
|
| 105 |
-
return (FortranCompilerRunner,
|
| 106 |
-
{'flags': ['-nofor_main']}, vendor)
|
| 107 |
-
elif vendor.lower() == 'gnu' or 'llvm':
|
| 108 |
-
if cplus:
|
| 109 |
-
return (CppCompilerRunner,
|
| 110 |
-
{'lib_options': ['fortran']}, vendor)
|
| 111 |
-
else:
|
| 112 |
-
return (FortranCompilerRunner,
|
| 113 |
-
{}, vendor)
|
| 114 |
-
else:
|
| 115 |
-
raise ValueError("No vendor found.")
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def link(obj_files, out_file=None, shared=False, Runner=None,
|
| 119 |
-
cwd=None, cplus=False, fort=False, extra_objs=None, **kwargs):
|
| 120 |
-
""" Link object files.
|
| 121 |
-
|
| 122 |
-
Parameters
|
| 123 |
-
==========
|
| 124 |
-
|
| 125 |
-
obj_files: iterable of str
|
| 126 |
-
Paths to object files.
|
| 127 |
-
out_file: str (optional)
|
| 128 |
-
Path to executable/shared library, if ``None`` it will be
|
| 129 |
-
deduced from the last item in obj_files.
|
| 130 |
-
shared: bool
|
| 131 |
-
Generate a shared library?
|
| 132 |
-
Runner: CompilerRunner subclass (optional)
|
| 133 |
-
If not given the ``cplus`` and ``fort`` flags will be inspected
|
| 134 |
-
(fallback is the C compiler).
|
| 135 |
-
cwd: str
|
| 136 |
-
Path to the root of relative paths and working directory for compiler.
|
| 137 |
-
cplus: bool
|
| 138 |
-
C++ objects? default: ``False``.
|
| 139 |
-
fort: bool
|
| 140 |
-
Fortran objects? default: ``False``.
|
| 141 |
-
extra_objs: list
|
| 142 |
-
List of paths to extra object files / static libraries.
|
| 143 |
-
\\*\\*kwargs: dict
|
| 144 |
-
Keyword arguments passed to ``Runner``.
|
| 145 |
-
|
| 146 |
-
Returns
|
| 147 |
-
=======
|
| 148 |
-
|
| 149 |
-
The absolute path to the generated shared object / executable.
|
| 150 |
-
|
| 151 |
-
"""
|
| 152 |
-
if out_file is None:
|
| 153 |
-
out_file, ext = os.path.splitext(os.path.basename(obj_files[-1]))
|
| 154 |
-
if shared:
|
| 155 |
-
out_file += get_config_var('EXT_SUFFIX')
|
| 156 |
-
|
| 157 |
-
if not Runner:
|
| 158 |
-
if fort:
|
| 159 |
-
Runner, extra_kwargs, vendor = \
|
| 160 |
-
get_mixed_fort_c_linker(
|
| 161 |
-
vendor=kwargs.get('vendor', None),
|
| 162 |
-
cplus=cplus,
|
| 163 |
-
cwd=cwd,
|
| 164 |
-
)
|
| 165 |
-
for k, v in extra_kwargs.items():
|
| 166 |
-
if k in kwargs:
|
| 167 |
-
kwargs[k].expand(v)
|
| 168 |
-
else:
|
| 169 |
-
kwargs[k] = v
|
| 170 |
-
else:
|
| 171 |
-
if cplus:
|
| 172 |
-
Runner = CppCompilerRunner
|
| 173 |
-
else:
|
| 174 |
-
Runner = CCompilerRunner
|
| 175 |
-
|
| 176 |
-
flags = kwargs.pop('flags', [])
|
| 177 |
-
if shared:
|
| 178 |
-
if '-shared' not in flags:
|
| 179 |
-
flags.append('-shared')
|
| 180 |
-
run_linker = kwargs.pop('run_linker', True)
|
| 181 |
-
if not run_linker:
|
| 182 |
-
raise ValueError("run_linker was set to False (nonsensical).")
|
| 183 |
-
|
| 184 |
-
out_file = get_abspath(out_file, cwd=cwd)
|
| 185 |
-
runner = Runner(obj_files+(extra_objs or []), out_file, flags, cwd=cwd, **kwargs)
|
| 186 |
-
runner.run()
|
| 187 |
-
return out_file
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def link_py_so(obj_files, so_file=None, cwd=None, libraries=None,
|
| 191 |
-
cplus=False, fort=False, extra_objs=None, **kwargs):
|
| 192 |
-
""" Link Python extension module (shared object) for importing
|
| 193 |
-
|
| 194 |
-
Parameters
|
| 195 |
-
==========
|
| 196 |
-
|
| 197 |
-
obj_files: iterable of str
|
| 198 |
-
Paths to object files to be linked.
|
| 199 |
-
so_file: str
|
| 200 |
-
Name (path) of shared object file to create. If not specified it will
|
| 201 |
-
have the basname of the last object file in `obj_files` but with the
|
| 202 |
-
extension '.so' (Unix).
|
| 203 |
-
cwd: path string
|
| 204 |
-
Root of relative paths and working directory of linker.
|
| 205 |
-
libraries: iterable of strings
|
| 206 |
-
Libraries to link against, e.g. ['m'].
|
| 207 |
-
cplus: bool
|
| 208 |
-
Any C++ objects? default: ``False``.
|
| 209 |
-
fort: bool
|
| 210 |
-
Any Fortran objects? default: ``False``.
|
| 211 |
-
extra_objs: list
|
| 212 |
-
List of paths of extra object files / static libraries to link against.
|
| 213 |
-
kwargs**: dict
|
| 214 |
-
Keyword arguments passed to ``link(...)``.
|
| 215 |
-
|
| 216 |
-
Returns
|
| 217 |
-
=======
|
| 218 |
-
|
| 219 |
-
Absolute path to the generate shared object.
|
| 220 |
-
"""
|
| 221 |
-
libraries = libraries or []
|
| 222 |
-
|
| 223 |
-
include_dirs = kwargs.pop('include_dirs', [])
|
| 224 |
-
library_dirs = kwargs.pop('library_dirs', [])
|
| 225 |
-
|
| 226 |
-
# Add Python include and library directories
|
| 227 |
-
# PY_LDFLAGS does not available on all python implementations
|
| 228 |
-
# e.g. when with pypy, so it's LDFLAGS we need to use
|
| 229 |
-
if sys.platform == "win32":
|
| 230 |
-
warnings.warn("Windows not yet supported.")
|
| 231 |
-
elif sys.platform == 'darwin':
|
| 232 |
-
cfgDict = get_config_vars()
|
| 233 |
-
kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']]
|
| 234 |
-
library_dirs += [cfgDict['LIBDIR']]
|
| 235 |
-
|
| 236 |
-
# In macOS, linker needs to compile frameworks
|
| 237 |
-
# e.g. "-framework CoreFoundation"
|
| 238 |
-
is_framework = False
|
| 239 |
-
for opt in cfgDict['LIBS'].split():
|
| 240 |
-
if is_framework:
|
| 241 |
-
kwargs['linkline'] = kwargs.get('linkline', []) + ['-framework', opt]
|
| 242 |
-
is_framework = False
|
| 243 |
-
elif opt.startswith('-l'):
|
| 244 |
-
libraries.append(opt[2:])
|
| 245 |
-
elif opt.startswith('-framework'):
|
| 246 |
-
is_framework = True
|
| 247 |
-
# The python library is not included in LIBS
|
| 248 |
-
libfile = cfgDict['LIBRARY']
|
| 249 |
-
libname = ".".join(libfile.split('.')[:-1])[3:]
|
| 250 |
-
libraries.append(libname)
|
| 251 |
-
|
| 252 |
-
elif sys.platform[:3] == 'aix':
|
| 253 |
-
# Don't use the default code below
|
| 254 |
-
pass
|
| 255 |
-
else:
|
| 256 |
-
if get_config_var('Py_ENABLE_SHARED'):
|
| 257 |
-
cfgDict = get_config_vars()
|
| 258 |
-
kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']]
|
| 259 |
-
library_dirs += [cfgDict['LIBDIR']]
|
| 260 |
-
for opt in cfgDict['BLDLIBRARY'].split():
|
| 261 |
-
if opt.startswith('-l'):
|
| 262 |
-
libraries += [opt[2:]]
|
| 263 |
-
else:
|
| 264 |
-
pass
|
| 265 |
-
|
| 266 |
-
flags = kwargs.pop('flags', [])
|
| 267 |
-
needed_flags = ('-pthread',)
|
| 268 |
-
for flag in needed_flags:
|
| 269 |
-
if flag not in flags:
|
| 270 |
-
flags.append(flag)
|
| 271 |
-
|
| 272 |
-
return link(obj_files, shared=True, flags=flags, cwd=cwd, cplus=cplus, fort=fort,
|
| 273 |
-
include_dirs=include_dirs, libraries=libraries,
|
| 274 |
-
library_dirs=library_dirs, extra_objs=extra_objs, **kwargs)
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
def simple_cythonize(src, destdir=None, cwd=None, **cy_kwargs):
|
| 278 |
-
""" Generates a C file from a Cython source file.
|
| 279 |
-
|
| 280 |
-
Parameters
|
| 281 |
-
==========
|
| 282 |
-
|
| 283 |
-
src: str
|
| 284 |
-
Path to Cython source.
|
| 285 |
-
destdir: str (optional)
|
| 286 |
-
Path to output directory (default: '.').
|
| 287 |
-
cwd: path string (optional)
|
| 288 |
-
Root of relative paths (default: '.').
|
| 289 |
-
**cy_kwargs:
|
| 290 |
-
Second argument passed to cy_compile. Generates a .cpp file if ``cplus=True`` in ``cy_kwargs``,
|
| 291 |
-
else a .c file.
|
| 292 |
-
"""
|
| 293 |
-
from Cython.Compiler.Main import (
|
| 294 |
-
default_options, CompilationOptions
|
| 295 |
-
)
|
| 296 |
-
from Cython.Compiler.Main import compile as cy_compile
|
| 297 |
-
|
| 298 |
-
assert src.lower().endswith('.pyx') or src.lower().endswith('.py')
|
| 299 |
-
cwd = cwd or '.'
|
| 300 |
-
destdir = destdir or '.'
|
| 301 |
-
|
| 302 |
-
ext = '.cpp' if cy_kwargs.get('cplus', False) else '.c'
|
| 303 |
-
c_name = os.path.splitext(os.path.basename(src))[0] + ext
|
| 304 |
-
|
| 305 |
-
dstfile = os.path.join(destdir, c_name)
|
| 306 |
-
|
| 307 |
-
if cwd:
|
| 308 |
-
ori_dir = os.getcwd()
|
| 309 |
-
else:
|
| 310 |
-
ori_dir = '.'
|
| 311 |
-
os.chdir(cwd)
|
| 312 |
-
try:
|
| 313 |
-
cy_options = CompilationOptions(default_options)
|
| 314 |
-
cy_options.__dict__.update(cy_kwargs)
|
| 315 |
-
# Set language_level if not set by cy_kwargs
|
| 316 |
-
# as not setting it is deprecated
|
| 317 |
-
if 'language_level' not in cy_kwargs:
|
| 318 |
-
cy_options.__dict__['language_level'] = 3
|
| 319 |
-
cy_result = cy_compile([src], cy_options)
|
| 320 |
-
if cy_result.num_errors > 0:
|
| 321 |
-
raise ValueError("Cython compilation failed.")
|
| 322 |
-
|
| 323 |
-
# Move generated C file to destination
|
| 324 |
-
# In macOS, the generated C file is in the same directory as the source
|
| 325 |
-
# but the /var is a symlink to /private/var, so we need to use realpath
|
| 326 |
-
if os.path.realpath(os.path.dirname(src)) != os.path.realpath(destdir):
|
| 327 |
-
if os.path.exists(dstfile):
|
| 328 |
-
os.unlink(dstfile)
|
| 329 |
-
shutil.move(os.path.join(os.path.dirname(src), c_name), destdir)
|
| 330 |
-
finally:
|
| 331 |
-
os.chdir(ori_dir)
|
| 332 |
-
return dstfile
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
extension_mapping = {
|
| 336 |
-
'.c': (CCompilerRunner, None),
|
| 337 |
-
'.cpp': (CppCompilerRunner, None),
|
| 338 |
-
'.cxx': (CppCompilerRunner, None),
|
| 339 |
-
'.f': (FortranCompilerRunner, None),
|
| 340 |
-
'.for': (FortranCompilerRunner, None),
|
| 341 |
-
'.ftn': (FortranCompilerRunner, None),
|
| 342 |
-
'.f90': (FortranCompilerRunner, None), # ifort only knows about .f90
|
| 343 |
-
'.f95': (FortranCompilerRunner, 'f95'),
|
| 344 |
-
'.f03': (FortranCompilerRunner, 'f2003'),
|
| 345 |
-
'.f08': (FortranCompilerRunner, 'f2008'),
|
| 346 |
-
}
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
def src2obj(srcpath, Runner=None, objpath=None, cwd=None, inc_py=False, **kwargs):
|
| 350 |
-
""" Compiles a source code file to an object file.
|
| 351 |
-
|
| 352 |
-
Files ending with '.pyx' assumed to be cython files and
|
| 353 |
-
are dispatched to pyx2obj.
|
| 354 |
-
|
| 355 |
-
Parameters
|
| 356 |
-
==========
|
| 357 |
-
|
| 358 |
-
srcpath: str
|
| 359 |
-
Path to source file.
|
| 360 |
-
Runner: CompilerRunner subclass (optional)
|
| 361 |
-
If ``None``: deduced from extension of srcpath.
|
| 362 |
-
objpath : str (optional)
|
| 363 |
-
Path to generated object. If ``None``: deduced from ``srcpath``.
|
| 364 |
-
cwd: str (optional)
|
| 365 |
-
Working directory and root of relative paths. If ``None``: current dir.
|
| 366 |
-
inc_py: bool
|
| 367 |
-
Add Python include path to kwarg "include_dirs". Default: False
|
| 368 |
-
\\*\\*kwargs: dict
|
| 369 |
-
keyword arguments passed to Runner or pyx2obj
|
| 370 |
-
|
| 371 |
-
"""
|
| 372 |
-
name, ext = os.path.splitext(os.path.basename(srcpath))
|
| 373 |
-
if objpath is None:
|
| 374 |
-
if os.path.isabs(srcpath):
|
| 375 |
-
objpath = '.'
|
| 376 |
-
else:
|
| 377 |
-
objpath = os.path.dirname(srcpath)
|
| 378 |
-
objpath = objpath or '.' # avoid objpath == ''
|
| 379 |
-
|
| 380 |
-
if os.path.isdir(objpath):
|
| 381 |
-
objpath = os.path.join(objpath, name + objext)
|
| 382 |
-
|
| 383 |
-
include_dirs = kwargs.pop('include_dirs', [])
|
| 384 |
-
if inc_py:
|
| 385 |
-
py_inc_dir = get_path('include')
|
| 386 |
-
if py_inc_dir not in include_dirs:
|
| 387 |
-
include_dirs.append(py_inc_dir)
|
| 388 |
-
|
| 389 |
-
if ext.lower() == '.pyx':
|
| 390 |
-
return pyx2obj(srcpath, objpath=objpath, include_dirs=include_dirs, cwd=cwd,
|
| 391 |
-
**kwargs)
|
| 392 |
-
|
| 393 |
-
if Runner is None:
|
| 394 |
-
Runner, std = extension_mapping[ext.lower()]
|
| 395 |
-
if 'std' not in kwargs:
|
| 396 |
-
kwargs['std'] = std
|
| 397 |
-
|
| 398 |
-
flags = kwargs.pop('flags', [])
|
| 399 |
-
needed_flags = ('-fPIC',)
|
| 400 |
-
for flag in needed_flags:
|
| 401 |
-
if flag not in flags:
|
| 402 |
-
flags.append(flag)
|
| 403 |
-
|
| 404 |
-
# src2obj implies not running the linker...
|
| 405 |
-
run_linker = kwargs.pop('run_linker', False)
|
| 406 |
-
if run_linker:
|
| 407 |
-
raise CompileError("src2obj called with run_linker=True")
|
| 408 |
-
|
| 409 |
-
runner = Runner([srcpath], objpath, include_dirs=include_dirs,
|
| 410 |
-
run_linker=run_linker, cwd=cwd, flags=flags, **kwargs)
|
| 411 |
-
runner.run()
|
| 412 |
-
return objpath
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
def pyx2obj(pyxpath, objpath=None, destdir=None, cwd=None,
|
| 416 |
-
include_dirs=None, cy_kwargs=None, cplus=None, **kwargs):
|
| 417 |
-
"""
|
| 418 |
-
Convenience function
|
| 419 |
-
|
| 420 |
-
If cwd is specified, pyxpath and dst are taken to be relative
|
| 421 |
-
If only_update is set to `True` the modification time is checked
|
| 422 |
-
and compilation is only run if the source is newer than the
|
| 423 |
-
destination
|
| 424 |
-
|
| 425 |
-
Parameters
|
| 426 |
-
==========
|
| 427 |
-
|
| 428 |
-
pyxpath: str
|
| 429 |
-
Path to Cython source file.
|
| 430 |
-
objpath: str (optional)
|
| 431 |
-
Path to object file to generate.
|
| 432 |
-
destdir: str (optional)
|
| 433 |
-
Directory to put generated C file. When ``None``: directory of ``objpath``.
|
| 434 |
-
cwd: str (optional)
|
| 435 |
-
Working directory and root of relative paths.
|
| 436 |
-
include_dirs: iterable of path strings (optional)
|
| 437 |
-
Passed onto src2obj and via cy_kwargs['include_path']
|
| 438 |
-
to simple_cythonize.
|
| 439 |
-
cy_kwargs: dict (optional)
|
| 440 |
-
Keyword arguments passed onto `simple_cythonize`
|
| 441 |
-
cplus: bool (optional)
|
| 442 |
-
Indicate whether C++ is used. default: auto-detect using ``.util.pyx_is_cplus``.
|
| 443 |
-
compile_kwargs: dict
|
| 444 |
-
keyword arguments passed onto src2obj
|
| 445 |
-
|
| 446 |
-
Returns
|
| 447 |
-
=======
|
| 448 |
-
|
| 449 |
-
Absolute path of generated object file.
|
| 450 |
-
|
| 451 |
-
"""
|
| 452 |
-
assert pyxpath.endswith('.pyx')
|
| 453 |
-
cwd = cwd or '.'
|
| 454 |
-
objpath = objpath or '.'
|
| 455 |
-
destdir = destdir or os.path.dirname(objpath)
|
| 456 |
-
|
| 457 |
-
abs_objpath = get_abspath(objpath, cwd=cwd)
|
| 458 |
-
|
| 459 |
-
if os.path.isdir(abs_objpath):
|
| 460 |
-
pyx_fname = os.path.basename(pyxpath)
|
| 461 |
-
name, ext = os.path.splitext(pyx_fname)
|
| 462 |
-
objpath = os.path.join(objpath, name + objext)
|
| 463 |
-
|
| 464 |
-
cy_kwargs = cy_kwargs or {}
|
| 465 |
-
cy_kwargs['output_dir'] = cwd
|
| 466 |
-
if cplus is None:
|
| 467 |
-
cplus = pyx_is_cplus(pyxpath)
|
| 468 |
-
cy_kwargs['cplus'] = cplus
|
| 469 |
-
|
| 470 |
-
interm_c_file = simple_cythonize(pyxpath, destdir=destdir, cwd=cwd, **cy_kwargs)
|
| 471 |
-
|
| 472 |
-
include_dirs = include_dirs or []
|
| 473 |
-
flags = kwargs.pop('flags', [])
|
| 474 |
-
needed_flags = ('-fwrapv', '-pthread', '-fPIC')
|
| 475 |
-
for flag in needed_flags:
|
| 476 |
-
if flag not in flags:
|
| 477 |
-
flags.append(flag)
|
| 478 |
-
|
| 479 |
-
options = kwargs.pop('options', [])
|
| 480 |
-
|
| 481 |
-
if kwargs.pop('strict_aliasing', False):
|
| 482 |
-
raise CompileError("Cython requires strict aliasing to be disabled.")
|
| 483 |
-
|
| 484 |
-
# Let's be explicit about standard
|
| 485 |
-
if cplus:
|
| 486 |
-
std = kwargs.pop('std', 'c++98')
|
| 487 |
-
else:
|
| 488 |
-
std = kwargs.pop('std', 'c99')
|
| 489 |
-
|
| 490 |
-
return src2obj(interm_c_file, objpath=objpath, cwd=cwd,
|
| 491 |
-
include_dirs=include_dirs, flags=flags, std=std,
|
| 492 |
-
options=options, inc_py=True, strict_aliasing=False,
|
| 493 |
-
**kwargs)
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
def _any_X(srcs, cls):
|
| 497 |
-
for src in srcs:
|
| 498 |
-
name, ext = os.path.splitext(src)
|
| 499 |
-
key = ext.lower()
|
| 500 |
-
if key in extension_mapping:
|
| 501 |
-
if extension_mapping[key][0] == cls:
|
| 502 |
-
return True
|
| 503 |
-
return False
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
def any_fortran_src(srcs):
|
| 507 |
-
return _any_X(srcs, FortranCompilerRunner)
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
def any_cplus_src(srcs):
|
| 511 |
-
return _any_X(srcs, CppCompilerRunner)
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
def compile_link_import_py_ext(sources, extname=None, build_dir='.', compile_kwargs=None,
|
| 515 |
-
link_kwargs=None, extra_objs=None):
|
| 516 |
-
""" Compiles sources to a shared object (Python extension) and imports it
|
| 517 |
-
|
| 518 |
-
Sources in ``sources`` which is imported. If shared object is newer than the sources, they
|
| 519 |
-
are not recompiled but instead it is imported.
|
| 520 |
-
|
| 521 |
-
Parameters
|
| 522 |
-
==========
|
| 523 |
-
|
| 524 |
-
sources : list of strings
|
| 525 |
-
List of paths to sources.
|
| 526 |
-
extname : string
|
| 527 |
-
Name of extension (default: ``None``).
|
| 528 |
-
If ``None``: taken from the last file in ``sources`` without extension.
|
| 529 |
-
build_dir: str
|
| 530 |
-
Path to directory in which objects files etc. are generated.
|
| 531 |
-
compile_kwargs: dict
|
| 532 |
-
keyword arguments passed to ``compile_sources``
|
| 533 |
-
link_kwargs: dict
|
| 534 |
-
keyword arguments passed to ``link_py_so``
|
| 535 |
-
extra_objs: list
|
| 536 |
-
List of paths to (prebuilt) object files / static libraries to link against.
|
| 537 |
-
|
| 538 |
-
Returns
|
| 539 |
-
=======
|
| 540 |
-
|
| 541 |
-
The imported module from of the Python extension.
|
| 542 |
-
"""
|
| 543 |
-
if extname is None:
|
| 544 |
-
extname = os.path.splitext(os.path.basename(sources[-1]))[0]
|
| 545 |
-
|
| 546 |
-
compile_kwargs = compile_kwargs or {}
|
| 547 |
-
link_kwargs = link_kwargs or {}
|
| 548 |
-
|
| 549 |
-
try:
|
| 550 |
-
mod = import_module_from_file(os.path.join(build_dir, extname), sources)
|
| 551 |
-
except ImportError:
|
| 552 |
-
objs = compile_sources(list(map(get_abspath, sources)), destdir=build_dir,
|
| 553 |
-
cwd=build_dir, **compile_kwargs)
|
| 554 |
-
so = link_py_so(objs, cwd=build_dir, fort=any_fortran_src(sources),
|
| 555 |
-
cplus=any_cplus_src(sources), extra_objs=extra_objs, **link_kwargs)
|
| 556 |
-
mod = import_module_from_file(so)
|
| 557 |
-
return mod
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
def _write_sources_to_build_dir(sources, build_dir):
|
| 561 |
-
build_dir = build_dir or tempfile.mkdtemp()
|
| 562 |
-
if not os.path.isdir(build_dir):
|
| 563 |
-
raise OSError("Non-existent directory: ", build_dir)
|
| 564 |
-
|
| 565 |
-
source_files = []
|
| 566 |
-
for name, src in sources:
|
| 567 |
-
dest = os.path.join(build_dir, name)
|
| 568 |
-
differs = True
|
| 569 |
-
sha256_in_mem = sha256_of_string(src.encode('utf-8')).hexdigest()
|
| 570 |
-
if os.path.exists(dest):
|
| 571 |
-
if os.path.exists(dest + '.sha256'):
|
| 572 |
-
sha256_on_disk = Path(dest + '.sha256').read_text()
|
| 573 |
-
else:
|
| 574 |
-
sha256_on_disk = sha256_of_file(dest).hexdigest()
|
| 575 |
-
|
| 576 |
-
differs = sha256_on_disk != sha256_in_mem
|
| 577 |
-
if differs:
|
| 578 |
-
with open(dest, 'wt') as fh:
|
| 579 |
-
fh.write(src)
|
| 580 |
-
with open(dest + '.sha256', 'wt') as fh:
|
| 581 |
-
fh.write(sha256_in_mem)
|
| 582 |
-
source_files.append(dest)
|
| 583 |
-
return source_files, build_dir
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
def compile_link_import_strings(sources, build_dir=None, **kwargs):
|
| 587 |
-
""" Compiles, links and imports extension module from source.
|
| 588 |
-
|
| 589 |
-
Parameters
|
| 590 |
-
==========
|
| 591 |
-
|
| 592 |
-
sources : iterable of name/source pair tuples
|
| 593 |
-
build_dir : string (default: None)
|
| 594 |
-
Path. ``None`` implies use a temporary directory.
|
| 595 |
-
**kwargs:
|
| 596 |
-
Keyword arguments passed onto `compile_link_import_py_ext`.
|
| 597 |
-
|
| 598 |
-
Returns
|
| 599 |
-
=======
|
| 600 |
-
|
| 601 |
-
mod : module
|
| 602 |
-
The compiled and imported extension module.
|
| 603 |
-
info : dict
|
| 604 |
-
Containing ``build_dir`` as 'build_dir'.
|
| 605 |
-
|
| 606 |
-
"""
|
| 607 |
-
source_files, build_dir = _write_sources_to_build_dir(sources, build_dir)
|
| 608 |
-
mod = compile_link_import_py_ext(source_files, build_dir=build_dir, **kwargs)
|
| 609 |
-
info = {"build_dir": build_dir}
|
| 610 |
-
return mod, info
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
def compile_run_strings(sources, build_dir=None, clean=False, compile_kwargs=None, link_kwargs=None):
|
| 614 |
-
""" Compiles, links and runs a program built from sources.
|
| 615 |
-
|
| 616 |
-
Parameters
|
| 617 |
-
==========
|
| 618 |
-
|
| 619 |
-
sources : iterable of name/source pair tuples
|
| 620 |
-
build_dir : string (default: None)
|
| 621 |
-
Path. ``None`` implies use a temporary directory.
|
| 622 |
-
clean : bool
|
| 623 |
-
Whether to remove build_dir after use. This will only have an
|
| 624 |
-
effect if ``build_dir`` is ``None`` (which creates a temporary directory).
|
| 625 |
-
Passing ``clean == True`` and ``build_dir != None`` raises a ``ValueError``.
|
| 626 |
-
This will also set ``build_dir`` in returned info dictionary to ``None``.
|
| 627 |
-
compile_kwargs: dict
|
| 628 |
-
Keyword arguments passed onto ``compile_sources``
|
| 629 |
-
link_kwargs: dict
|
| 630 |
-
Keyword arguments passed onto ``link``
|
| 631 |
-
|
| 632 |
-
Returns
|
| 633 |
-
=======
|
| 634 |
-
|
| 635 |
-
(stdout, stderr): pair of strings
|
| 636 |
-
info: dict
|
| 637 |
-
Containing exit status as 'exit_status' and ``build_dir`` as 'build_dir'
|
| 638 |
-
|
| 639 |
-
"""
|
| 640 |
-
if clean and build_dir is not None:
|
| 641 |
-
raise ValueError("Automatic removal of build_dir is only available for temporary directory.")
|
| 642 |
-
try:
|
| 643 |
-
source_files, build_dir = _write_sources_to_build_dir(sources, build_dir)
|
| 644 |
-
objs = compile_sources(list(map(get_abspath, source_files)), destdir=build_dir,
|
| 645 |
-
cwd=build_dir, **(compile_kwargs or {}))
|
| 646 |
-
prog = link(objs, cwd=build_dir,
|
| 647 |
-
fort=any_fortran_src(source_files),
|
| 648 |
-
cplus=any_cplus_src(source_files), **(link_kwargs or {}))
|
| 649 |
-
p = subprocess.Popen([prog], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
| 650 |
-
exit_status = p.wait()
|
| 651 |
-
stdout, stderr = [txt.decode('utf-8') for txt in p.communicate()]
|
| 652 |
-
finally:
|
| 653 |
-
if clean and os.path.isdir(build_dir):
|
| 654 |
-
shutil.rmtree(build_dir)
|
| 655 |
-
build_dir = None
|
| 656 |
-
info = {"exit_status": exit_status, "build_dir": build_dir}
|
| 657 |
-
return (stdout, stderr), info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/runners.py
DELETED
|
@@ -1,301 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
from typing import Callable, Optional
|
| 3 |
-
|
| 4 |
-
from collections import OrderedDict
|
| 5 |
-
import os
|
| 6 |
-
import re
|
| 7 |
-
import subprocess
|
| 8 |
-
import warnings
|
| 9 |
-
|
| 10 |
-
from .util import (
|
| 11 |
-
find_binary_of_command, unique_list, CompileError
|
| 12 |
-
)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class CompilerRunner:
|
| 16 |
-
""" CompilerRunner base class.
|
| 17 |
-
|
| 18 |
-
Parameters
|
| 19 |
-
==========
|
| 20 |
-
|
| 21 |
-
sources : list of str
|
| 22 |
-
Paths to sources.
|
| 23 |
-
out : str
|
| 24 |
-
flags : iterable of str
|
| 25 |
-
Compiler flags.
|
| 26 |
-
run_linker : bool
|
| 27 |
-
compiler_name_exe : (str, str) tuple
|
| 28 |
-
Tuple of compiler name & command to call.
|
| 29 |
-
cwd : str
|
| 30 |
-
Path of root of relative paths.
|
| 31 |
-
include_dirs : list of str
|
| 32 |
-
Include directories.
|
| 33 |
-
libraries : list of str
|
| 34 |
-
Libraries to link against.
|
| 35 |
-
library_dirs : list of str
|
| 36 |
-
Paths to search for shared libraries.
|
| 37 |
-
std : str
|
| 38 |
-
Standard string, e.g. ``'c++11'``, ``'c99'``, ``'f2003'``.
|
| 39 |
-
define: iterable of strings
|
| 40 |
-
macros to define
|
| 41 |
-
undef : iterable of strings
|
| 42 |
-
macros to undefine
|
| 43 |
-
preferred_vendor : string
|
| 44 |
-
name of preferred vendor e.g. 'gnu' or 'intel'
|
| 45 |
-
|
| 46 |
-
Methods
|
| 47 |
-
=======
|
| 48 |
-
|
| 49 |
-
run():
|
| 50 |
-
Invoke compilation as a subprocess.
|
| 51 |
-
|
| 52 |
-
"""
|
| 53 |
-
|
| 54 |
-
environ_key_compiler: str # e.g. 'CC', 'CXX', ...
|
| 55 |
-
environ_key_flags: str # e.g. 'CFLAGS', 'CXXFLAGS', ...
|
| 56 |
-
environ_key_ldflags: str = "LDFLAGS" # typically 'LDFLAGS'
|
| 57 |
-
|
| 58 |
-
# Subclass to vendor/binary dict
|
| 59 |
-
compiler_dict: dict[str, str]
|
| 60 |
-
|
| 61 |
-
# Standards should be a tuple of supported standards
|
| 62 |
-
# (first one will be the default)
|
| 63 |
-
standards: tuple[None | str, ...]
|
| 64 |
-
|
| 65 |
-
# Subclass to dict of binary/formater-callback
|
| 66 |
-
std_formater: dict[str, Callable[[Optional[str]], str]]
|
| 67 |
-
|
| 68 |
-
# subclass to be e.g. {'gcc': 'gnu', ...}
|
| 69 |
-
compiler_name_vendor_mapping: dict[str, str]
|
| 70 |
-
|
| 71 |
-
def __init__(self, sources, out, flags=None, run_linker=True, compiler=None, cwd='.',
|
| 72 |
-
include_dirs=None, libraries=None, library_dirs=None, std=None, define=None,
|
| 73 |
-
undef=None, strict_aliasing=None, preferred_vendor=None, linkline=None, **kwargs):
|
| 74 |
-
if isinstance(sources, str):
|
| 75 |
-
raise ValueError("Expected argument sources to be a list of strings.")
|
| 76 |
-
self.sources = list(sources)
|
| 77 |
-
self.out = out
|
| 78 |
-
self.flags = flags or []
|
| 79 |
-
if os.environ.get(self.environ_key_flags):
|
| 80 |
-
self.flags += os.environ[self.environ_key_flags].split()
|
| 81 |
-
self.cwd = cwd
|
| 82 |
-
if compiler:
|
| 83 |
-
self.compiler_name, self.compiler_binary = compiler
|
| 84 |
-
elif os.environ.get(self.environ_key_compiler):
|
| 85 |
-
self.compiler_binary = os.environ[self.environ_key_compiler]
|
| 86 |
-
for k, v in self.compiler_dict.items():
|
| 87 |
-
if k in self.compiler_binary:
|
| 88 |
-
self.compiler_vendor = k
|
| 89 |
-
self.compiler_name = v
|
| 90 |
-
break
|
| 91 |
-
else:
|
| 92 |
-
self.compiler_vendor, self.compiler_name = list(self.compiler_dict.items())[0]
|
| 93 |
-
warnings.warn("failed to determine what kind of compiler %s is, assuming %s" %
|
| 94 |
-
(self.compiler_binary, self.compiler_name))
|
| 95 |
-
else:
|
| 96 |
-
# Find a compiler
|
| 97 |
-
if preferred_vendor is None:
|
| 98 |
-
preferred_vendor = os.environ.get('SYMPY_COMPILER_VENDOR', None)
|
| 99 |
-
self.compiler_name, self.compiler_binary, self.compiler_vendor = self.find_compiler(preferred_vendor)
|
| 100 |
-
if self.compiler_binary is None:
|
| 101 |
-
raise ValueError("No compiler found (searched: {})".format(', '.join(self.compiler_dict.values())))
|
| 102 |
-
self.define = define or []
|
| 103 |
-
self.undef = undef or []
|
| 104 |
-
self.include_dirs = include_dirs or []
|
| 105 |
-
self.libraries = libraries or []
|
| 106 |
-
self.library_dirs = library_dirs or []
|
| 107 |
-
self.std = std or self.standards[0]
|
| 108 |
-
self.run_linker = run_linker
|
| 109 |
-
if self.run_linker:
|
| 110 |
-
# both gnu and intel compilers use '-c' for disabling linker
|
| 111 |
-
self.flags = list(filter(lambda x: x != '-c', self.flags))
|
| 112 |
-
else:
|
| 113 |
-
if '-c' not in self.flags:
|
| 114 |
-
self.flags.append('-c')
|
| 115 |
-
|
| 116 |
-
if self.std:
|
| 117 |
-
self.flags.append(self.std_formater[
|
| 118 |
-
self.compiler_name](self.std))
|
| 119 |
-
|
| 120 |
-
self.linkline = (linkline or []) + [lf for lf in map(
|
| 121 |
-
str.strip, os.environ.get(self.environ_key_ldflags, "").split()
|
| 122 |
-
) if lf != ""]
|
| 123 |
-
|
| 124 |
-
if strict_aliasing is not None:
|
| 125 |
-
nsa_re = re.compile("no-strict-aliasing$")
|
| 126 |
-
sa_re = re.compile("strict-aliasing$")
|
| 127 |
-
if strict_aliasing is True:
|
| 128 |
-
if any(map(nsa_re.match, flags)):
|
| 129 |
-
raise CompileError("Strict aliasing cannot be both enforced and disabled")
|
| 130 |
-
elif any(map(sa_re.match, flags)):
|
| 131 |
-
pass # already enforced
|
| 132 |
-
else:
|
| 133 |
-
flags.append('-fstrict-aliasing')
|
| 134 |
-
elif strict_aliasing is False:
|
| 135 |
-
if any(map(nsa_re.match, flags)):
|
| 136 |
-
pass # already disabled
|
| 137 |
-
else:
|
| 138 |
-
if any(map(sa_re.match, flags)):
|
| 139 |
-
raise CompileError("Strict aliasing cannot be both enforced and disabled")
|
| 140 |
-
else:
|
| 141 |
-
flags.append('-fno-strict-aliasing')
|
| 142 |
-
else:
|
| 143 |
-
msg = "Expected argument strict_aliasing to be True/False, got {}"
|
| 144 |
-
raise ValueError(msg.format(strict_aliasing))
|
| 145 |
-
|
| 146 |
-
@classmethod
|
| 147 |
-
def find_compiler(cls, preferred_vendor=None):
|
| 148 |
-
""" Identify a suitable C/fortran/other compiler. """
|
| 149 |
-
candidates = list(cls.compiler_dict.keys())
|
| 150 |
-
if preferred_vendor:
|
| 151 |
-
if preferred_vendor in candidates:
|
| 152 |
-
candidates = [preferred_vendor]+candidates
|
| 153 |
-
else:
|
| 154 |
-
raise ValueError("Unknown vendor {}".format(preferred_vendor))
|
| 155 |
-
name, path = find_binary_of_command([cls.compiler_dict[x] for x in candidates])
|
| 156 |
-
return name, path, cls.compiler_name_vendor_mapping[name]
|
| 157 |
-
|
| 158 |
-
def cmd(self):
|
| 159 |
-
""" List of arguments (str) to be passed to e.g. ``subprocess.Popen``. """
|
| 160 |
-
cmd = (
|
| 161 |
-
[self.compiler_binary] +
|
| 162 |
-
self.flags +
|
| 163 |
-
['-U'+x for x in self.undef] +
|
| 164 |
-
['-D'+x for x in self.define] +
|
| 165 |
-
['-I'+x for x in self.include_dirs] +
|
| 166 |
-
self.sources
|
| 167 |
-
)
|
| 168 |
-
if self.run_linker:
|
| 169 |
-
cmd += (['-L'+x for x in self.library_dirs] +
|
| 170 |
-
['-l'+x for x in self.libraries] +
|
| 171 |
-
self.linkline)
|
| 172 |
-
counted = []
|
| 173 |
-
for envvar in re.findall(r'\$\{(\w+)\}', ' '.join(cmd)):
|
| 174 |
-
if os.getenv(envvar) is None:
|
| 175 |
-
if envvar not in counted:
|
| 176 |
-
counted.append(envvar)
|
| 177 |
-
msg = "Environment variable '{}' undefined.".format(envvar)
|
| 178 |
-
raise CompileError(msg)
|
| 179 |
-
return cmd
|
| 180 |
-
|
| 181 |
-
def run(self):
|
| 182 |
-
self.flags = unique_list(self.flags)
|
| 183 |
-
|
| 184 |
-
# Append output flag and name to tail of flags
|
| 185 |
-
self.flags.extend(['-o', self.out])
|
| 186 |
-
env = os.environ.copy()
|
| 187 |
-
env['PWD'] = self.cwd
|
| 188 |
-
|
| 189 |
-
# NOTE: intel compilers seems to need shell=True
|
| 190 |
-
p = subprocess.Popen(' '.join(self.cmd()),
|
| 191 |
-
shell=True,
|
| 192 |
-
cwd=self.cwd,
|
| 193 |
-
stdin=subprocess.PIPE,
|
| 194 |
-
stdout=subprocess.PIPE,
|
| 195 |
-
stderr=subprocess.STDOUT,
|
| 196 |
-
env=env)
|
| 197 |
-
comm = p.communicate()
|
| 198 |
-
try:
|
| 199 |
-
self.cmd_outerr = comm[0].decode('utf-8')
|
| 200 |
-
except UnicodeDecodeError:
|
| 201 |
-
self.cmd_outerr = comm[0].decode('iso-8859-1') # win32
|
| 202 |
-
self.cmd_returncode = p.returncode
|
| 203 |
-
|
| 204 |
-
# Error handling
|
| 205 |
-
if self.cmd_returncode != 0:
|
| 206 |
-
msg = "Error executing '{}' in {} (exited status {}):\n {}\n".format(
|
| 207 |
-
' '.join(self.cmd()), self.cwd, str(self.cmd_returncode), self.cmd_outerr
|
| 208 |
-
)
|
| 209 |
-
raise CompileError(msg)
|
| 210 |
-
|
| 211 |
-
return self.cmd_outerr, self.cmd_returncode
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
class CCompilerRunner(CompilerRunner):
|
| 215 |
-
|
| 216 |
-
environ_key_compiler = 'CC'
|
| 217 |
-
environ_key_flags = 'CFLAGS'
|
| 218 |
-
|
| 219 |
-
compiler_dict = OrderedDict([
|
| 220 |
-
('gnu', 'gcc'),
|
| 221 |
-
('intel', 'icc'),
|
| 222 |
-
('llvm', 'clang'),
|
| 223 |
-
])
|
| 224 |
-
|
| 225 |
-
standards = ('c89', 'c90', 'c99', 'c11') # First is default
|
| 226 |
-
|
| 227 |
-
std_formater = {
|
| 228 |
-
'gcc': '-std={}'.format,
|
| 229 |
-
'icc': '-std={}'.format,
|
| 230 |
-
'clang': '-std={}'.format,
|
| 231 |
-
}
|
| 232 |
-
|
| 233 |
-
compiler_name_vendor_mapping = {
|
| 234 |
-
'gcc': 'gnu',
|
| 235 |
-
'icc': 'intel',
|
| 236 |
-
'clang': 'llvm'
|
| 237 |
-
}
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
def _mk_flag_filter(cmplr_name): # helper for class initialization
|
| 241 |
-
not_welcome = {'g++': ("Wimplicit-interface",)} # "Wstrict-prototypes",)}
|
| 242 |
-
if cmplr_name in not_welcome:
|
| 243 |
-
def fltr(x):
|
| 244 |
-
for nw in not_welcome[cmplr_name]:
|
| 245 |
-
if nw in x:
|
| 246 |
-
return False
|
| 247 |
-
return True
|
| 248 |
-
else:
|
| 249 |
-
def fltr(x):
|
| 250 |
-
return True
|
| 251 |
-
return fltr
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
class CppCompilerRunner(CompilerRunner):
|
| 255 |
-
|
| 256 |
-
environ_key_compiler = 'CXX'
|
| 257 |
-
environ_key_flags = 'CXXFLAGS'
|
| 258 |
-
|
| 259 |
-
compiler_dict = OrderedDict([
|
| 260 |
-
('gnu', 'g++'),
|
| 261 |
-
('intel', 'icpc'),
|
| 262 |
-
('llvm', 'clang++'),
|
| 263 |
-
])
|
| 264 |
-
|
| 265 |
-
# First is the default, c++0x == c++11
|
| 266 |
-
standards = ('c++98', 'c++0x')
|
| 267 |
-
|
| 268 |
-
std_formater = {
|
| 269 |
-
'g++': '-std={}'.format,
|
| 270 |
-
'icpc': '-std={}'.format,
|
| 271 |
-
'clang++': '-std={}'.format,
|
| 272 |
-
}
|
| 273 |
-
|
| 274 |
-
compiler_name_vendor_mapping = {
|
| 275 |
-
'g++': 'gnu',
|
| 276 |
-
'icpc': 'intel',
|
| 277 |
-
'clang++': 'llvm'
|
| 278 |
-
}
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
class FortranCompilerRunner(CompilerRunner):
|
| 282 |
-
|
| 283 |
-
environ_key_compiler = 'FC'
|
| 284 |
-
environ_key_flags = 'FFLAGS'
|
| 285 |
-
|
| 286 |
-
standards = (None, 'f77', 'f95', 'f2003', 'f2008')
|
| 287 |
-
|
| 288 |
-
std_formater = {
|
| 289 |
-
'gfortran': lambda x: '-std=gnu' if x is None else '-std=legacy' if x == 'f77' else '-std={}'.format(x),
|
| 290 |
-
'ifort': lambda x: '-stand f08' if x is None else '-stand f{}'.format(x[-2:]), # f2008 => f08
|
| 291 |
-
}
|
| 292 |
-
|
| 293 |
-
compiler_dict = OrderedDict([
|
| 294 |
-
('gnu', 'gfortran'),
|
| 295 |
-
('intel', 'ifort'),
|
| 296 |
-
])
|
| 297 |
-
|
| 298 |
-
compiler_name_vendor_mapping = {
|
| 299 |
-
'gfortran': 'gnu',
|
| 300 |
-
'ifort': 'intel',
|
| 301 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/__init__.py
DELETED
|
File without changes
|
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/test_compilation.py
DELETED
|
@@ -1,104 +0,0 @@
|
|
| 1 |
-
import shutil
|
| 2 |
-
import os
|
| 3 |
-
import subprocess
|
| 4 |
-
import tempfile
|
| 5 |
-
from sympy.external import import_module
|
| 6 |
-
from sympy.testing.pytest import skip, skip_under_pyodide
|
| 7 |
-
|
| 8 |
-
from sympy.utilities._compilation.compilation import compile_link_import_py_ext, compile_link_import_strings, compile_sources, get_abspath
|
| 9 |
-
|
| 10 |
-
numpy = import_module('numpy')
|
| 11 |
-
cython = import_module('cython')
|
| 12 |
-
|
| 13 |
-
_sources1 = [
|
| 14 |
-
('sigmoid.c', r"""
|
| 15 |
-
#include <math.h>
|
| 16 |
-
|
| 17 |
-
void sigmoid(int n, const double * const restrict in,
|
| 18 |
-
double * const restrict out, double lim){
|
| 19 |
-
for (int i=0; i<n; ++i){
|
| 20 |
-
const double x = in[i];
|
| 21 |
-
out[i] = x*pow(pow(x/lim, 8)+1, -1./8.);
|
| 22 |
-
}
|
| 23 |
-
}
|
| 24 |
-
"""),
|
| 25 |
-
('_sigmoid.pyx', r"""
|
| 26 |
-
import numpy as np
|
| 27 |
-
cimport numpy as cnp
|
| 28 |
-
|
| 29 |
-
cdef extern void c_sigmoid "sigmoid" (int, const double * const,
|
| 30 |
-
double * const, double)
|
| 31 |
-
|
| 32 |
-
def sigmoid(double [:] inp, double lim=350.0):
|
| 33 |
-
cdef cnp.ndarray[cnp.float64_t, ndim=1] out = np.empty(
|
| 34 |
-
inp.size, dtype=np.float64)
|
| 35 |
-
c_sigmoid(inp.size, &inp[0], &out[0], lim)
|
| 36 |
-
return out
|
| 37 |
-
""")
|
| 38 |
-
]
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def npy(data, lim=350.0):
|
| 42 |
-
return data/((data/lim)**8+1)**(1/8.)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def test_compile_link_import_strings():
|
| 46 |
-
if not numpy:
|
| 47 |
-
skip("numpy not installed.")
|
| 48 |
-
if not cython:
|
| 49 |
-
skip("cython not installed.")
|
| 50 |
-
|
| 51 |
-
from sympy.utilities._compilation import has_c
|
| 52 |
-
if not has_c():
|
| 53 |
-
skip("No C compiler found.")
|
| 54 |
-
|
| 55 |
-
compile_kw = {"std": 'c99', "include_dirs": [numpy.get_include()]}
|
| 56 |
-
info = None
|
| 57 |
-
try:
|
| 58 |
-
mod, info = compile_link_import_strings(_sources1, compile_kwargs=compile_kw)
|
| 59 |
-
data = numpy.random.random(1024*1024*8) # 64 MB of RAM needed..
|
| 60 |
-
res_mod = mod.sigmoid(data)
|
| 61 |
-
res_npy = npy(data)
|
| 62 |
-
assert numpy.allclose(res_mod, res_npy)
|
| 63 |
-
finally:
|
| 64 |
-
if info and info['build_dir']:
|
| 65 |
-
shutil.rmtree(info['build_dir'])
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
@skip_under_pyodide("Emscripten does not support subprocesses")
|
| 69 |
-
def test_compile_sources():
|
| 70 |
-
tmpdir = tempfile.mkdtemp()
|
| 71 |
-
|
| 72 |
-
from sympy.utilities._compilation import has_c
|
| 73 |
-
if not has_c():
|
| 74 |
-
skip("No C compiler found.")
|
| 75 |
-
|
| 76 |
-
build_dir = str(tmpdir)
|
| 77 |
-
_handle, file_path = tempfile.mkstemp('.c', dir=build_dir)
|
| 78 |
-
with open(file_path, 'wt') as ofh:
|
| 79 |
-
ofh.write("""
|
| 80 |
-
int foo(int bar) {
|
| 81 |
-
return 2*bar;
|
| 82 |
-
}
|
| 83 |
-
""")
|
| 84 |
-
obj, = compile_sources([file_path], cwd=build_dir)
|
| 85 |
-
obj_path = get_abspath(obj, cwd=build_dir)
|
| 86 |
-
assert os.path.exists(obj_path)
|
| 87 |
-
try:
|
| 88 |
-
_ = subprocess.check_output(["nm", "--help"])
|
| 89 |
-
except subprocess.CalledProcessError:
|
| 90 |
-
pass # we cannot test contents of object file
|
| 91 |
-
else:
|
| 92 |
-
nm_out = subprocess.check_output(["nm", obj_path])
|
| 93 |
-
assert 'foo' in nm_out.decode('utf-8')
|
| 94 |
-
|
| 95 |
-
if not cython:
|
| 96 |
-
return # the final (optional) part of the test below requires Cython.
|
| 97 |
-
|
| 98 |
-
_handle, pyx_path = tempfile.mkstemp('.pyx', dir=build_dir)
|
| 99 |
-
with open(pyx_path, 'wt') as ofh:
|
| 100 |
-
ofh.write(("cdef extern int foo(int)\n"
|
| 101 |
-
"def _foo(arg):\n"
|
| 102 |
-
" return foo(arg)"))
|
| 103 |
-
mod = compile_link_import_py_ext([pyx_path], extra_objs=[obj_path], build_dir=build_dir)
|
| 104 |
-
assert mod._foo(21) == 42
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/util.py
DELETED
|
@@ -1,312 +0,0 @@
|
|
| 1 |
-
from collections import namedtuple
|
| 2 |
-
from hashlib import sha256
|
| 3 |
-
import os
|
| 4 |
-
import shutil
|
| 5 |
-
import sys
|
| 6 |
-
import fnmatch
|
| 7 |
-
|
| 8 |
-
from sympy.testing.pytest import XFAIL
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def may_xfail(func):
|
| 12 |
-
if sys.platform.lower() == 'darwin' or os.name == 'nt':
|
| 13 |
-
# sympy.utilities._compilation needs more testing on Windows and macOS
|
| 14 |
-
# once those two platforms are reliably supported this xfail decorator
|
| 15 |
-
# may be removed.
|
| 16 |
-
return XFAIL(func)
|
| 17 |
-
else:
|
| 18 |
-
return func
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class CompilerNotFoundError(FileNotFoundError):
|
| 22 |
-
pass
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class CompileError (Exception):
|
| 26 |
-
"""Failure to compile one or more C/C++ source files."""
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def get_abspath(path, cwd='.'):
|
| 30 |
-
""" Returns the absolute path.
|
| 31 |
-
|
| 32 |
-
Parameters
|
| 33 |
-
==========
|
| 34 |
-
|
| 35 |
-
path : str
|
| 36 |
-
(relative) path.
|
| 37 |
-
cwd : str
|
| 38 |
-
Path to root of relative path.
|
| 39 |
-
"""
|
| 40 |
-
if os.path.isabs(path):
|
| 41 |
-
return path
|
| 42 |
-
else:
|
| 43 |
-
if not os.path.isabs(cwd):
|
| 44 |
-
cwd = os.path.abspath(cwd)
|
| 45 |
-
return os.path.abspath(
|
| 46 |
-
os.path.join(cwd, path)
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def make_dirs(path):
|
| 51 |
-
""" Create directories (equivalent of ``mkdir -p``). """
|
| 52 |
-
if path[-1] == '/':
|
| 53 |
-
parent = os.path.dirname(path[:-1])
|
| 54 |
-
else:
|
| 55 |
-
parent = os.path.dirname(path)
|
| 56 |
-
|
| 57 |
-
if len(parent) > 0:
|
| 58 |
-
if not os.path.exists(parent):
|
| 59 |
-
make_dirs(parent)
|
| 60 |
-
|
| 61 |
-
if not os.path.exists(path):
|
| 62 |
-
os.mkdir(path, 0o777)
|
| 63 |
-
else:
|
| 64 |
-
assert os.path.isdir(path)
|
| 65 |
-
|
| 66 |
-
def missing_or_other_newer(path, other_path, cwd=None):
|
| 67 |
-
"""
|
| 68 |
-
Investigate if path is non-existent or older than provided reference
|
| 69 |
-
path.
|
| 70 |
-
|
| 71 |
-
Parameters
|
| 72 |
-
==========
|
| 73 |
-
path: string
|
| 74 |
-
path to path which might be missing or too old
|
| 75 |
-
other_path: string
|
| 76 |
-
reference path
|
| 77 |
-
cwd: string
|
| 78 |
-
working directory (root of relative paths)
|
| 79 |
-
|
| 80 |
-
Returns
|
| 81 |
-
=======
|
| 82 |
-
True if path is older or missing.
|
| 83 |
-
"""
|
| 84 |
-
cwd = cwd or '.'
|
| 85 |
-
path = get_abspath(path, cwd=cwd)
|
| 86 |
-
other_path = get_abspath(other_path, cwd=cwd)
|
| 87 |
-
if not os.path.exists(path):
|
| 88 |
-
return True
|
| 89 |
-
if os.path.getmtime(other_path) - 1e-6 >= os.path.getmtime(path):
|
| 90 |
-
# 1e-6 is needed because http://stackoverflow.com/questions/17086426/
|
| 91 |
-
return True
|
| 92 |
-
return False
|
| 93 |
-
|
| 94 |
-
def copy(src, dst, only_update=False, copystat=True, cwd=None,
|
| 95 |
-
dest_is_dir=False, create_dest_dirs=False):
|
| 96 |
-
""" Variation of ``shutil.copy`` with extra options.
|
| 97 |
-
|
| 98 |
-
Parameters
|
| 99 |
-
==========
|
| 100 |
-
|
| 101 |
-
src : str
|
| 102 |
-
Path to source file.
|
| 103 |
-
dst : str
|
| 104 |
-
Path to destination.
|
| 105 |
-
only_update : bool
|
| 106 |
-
Only copy if source is newer than destination
|
| 107 |
-
(returns None if it was newer), default: ``False``.
|
| 108 |
-
copystat : bool
|
| 109 |
-
See ``shutil.copystat``. default: ``True``.
|
| 110 |
-
cwd : str
|
| 111 |
-
Path to working directory (root of relative paths).
|
| 112 |
-
dest_is_dir : bool
|
| 113 |
-
Ensures that dst is treated as a directory. default: ``False``
|
| 114 |
-
create_dest_dirs : bool
|
| 115 |
-
Creates directories if needed.
|
| 116 |
-
|
| 117 |
-
Returns
|
| 118 |
-
=======
|
| 119 |
-
|
| 120 |
-
Path to the copied file.
|
| 121 |
-
|
| 122 |
-
"""
|
| 123 |
-
if cwd: # Handle working directory
|
| 124 |
-
if not os.path.isabs(src):
|
| 125 |
-
src = os.path.join(cwd, src)
|
| 126 |
-
if not os.path.isabs(dst):
|
| 127 |
-
dst = os.path.join(cwd, dst)
|
| 128 |
-
|
| 129 |
-
if not os.path.exists(src): # Make sure source file exists
|
| 130 |
-
raise FileNotFoundError("Source: `{}` does not exist".format(src))
|
| 131 |
-
|
| 132 |
-
# We accept both (re)naming destination file _or_
|
| 133 |
-
# passing a (possible non-existent) destination directory
|
| 134 |
-
if dest_is_dir:
|
| 135 |
-
if not dst[-1] == '/':
|
| 136 |
-
dst = dst+'/'
|
| 137 |
-
else:
|
| 138 |
-
if os.path.exists(dst) and os.path.isdir(dst):
|
| 139 |
-
dest_is_dir = True
|
| 140 |
-
|
| 141 |
-
if dest_is_dir:
|
| 142 |
-
dest_dir = dst
|
| 143 |
-
dest_fname = os.path.basename(src)
|
| 144 |
-
dst = os.path.join(dest_dir, dest_fname)
|
| 145 |
-
else:
|
| 146 |
-
dest_dir = os.path.dirname(dst)
|
| 147 |
-
|
| 148 |
-
if not os.path.exists(dest_dir):
|
| 149 |
-
if create_dest_dirs:
|
| 150 |
-
make_dirs(dest_dir)
|
| 151 |
-
else:
|
| 152 |
-
raise FileNotFoundError("You must create directory first.")
|
| 153 |
-
|
| 154 |
-
if only_update:
|
| 155 |
-
if not missing_or_other_newer(dst, src):
|
| 156 |
-
return
|
| 157 |
-
|
| 158 |
-
if os.path.islink(dst):
|
| 159 |
-
dst = os.path.abspath(os.path.realpath(dst), cwd=cwd)
|
| 160 |
-
|
| 161 |
-
shutil.copy(src, dst)
|
| 162 |
-
if copystat:
|
| 163 |
-
shutil.copystat(src, dst)
|
| 164 |
-
|
| 165 |
-
return dst
|
| 166 |
-
|
| 167 |
-
Glob = namedtuple('Glob', 'pathname')
|
| 168 |
-
ArbitraryDepthGlob = namedtuple('ArbitraryDepthGlob', 'filename')
|
| 169 |
-
|
| 170 |
-
def glob_at_depth(filename_glob, cwd=None):
|
| 171 |
-
if cwd is not None:
|
| 172 |
-
cwd = '.'
|
| 173 |
-
globbed = []
|
| 174 |
-
for root, dirs, filenames in os.walk(cwd):
|
| 175 |
-
for fn in filenames:
|
| 176 |
-
# This is not tested:
|
| 177 |
-
if fnmatch.fnmatch(fn, filename_glob):
|
| 178 |
-
globbed.append(os.path.join(root, fn))
|
| 179 |
-
return globbed
|
| 180 |
-
|
| 181 |
-
def sha256_of_file(path, nblocks=128):
|
| 182 |
-
""" Computes the SHA256 hash of a file.
|
| 183 |
-
|
| 184 |
-
Parameters
|
| 185 |
-
==========
|
| 186 |
-
|
| 187 |
-
path : string
|
| 188 |
-
Path to file to compute hash of.
|
| 189 |
-
nblocks : int
|
| 190 |
-
Number of blocks to read per iteration.
|
| 191 |
-
|
| 192 |
-
Returns
|
| 193 |
-
=======
|
| 194 |
-
|
| 195 |
-
hashlib sha256 hash object. Use ``.digest()`` or ``.hexdigest()``
|
| 196 |
-
on returned object to get binary or hex encoded string.
|
| 197 |
-
"""
|
| 198 |
-
sh = sha256()
|
| 199 |
-
with open(path, 'rb') as f:
|
| 200 |
-
for chunk in iter(lambda: f.read(nblocks*sh.block_size), b''):
|
| 201 |
-
sh.update(chunk)
|
| 202 |
-
return sh
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def sha256_of_string(string):
|
| 206 |
-
""" Computes the SHA256 hash of a string. """
|
| 207 |
-
sh = sha256()
|
| 208 |
-
sh.update(string)
|
| 209 |
-
return sh
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
def pyx_is_cplus(path):
|
| 213 |
-
"""
|
| 214 |
-
Inspect a Cython source file (.pyx) and look for comment line like:
|
| 215 |
-
|
| 216 |
-
# distutils: language = c++
|
| 217 |
-
|
| 218 |
-
Returns True if such a file is present in the file, else False.
|
| 219 |
-
"""
|
| 220 |
-
with open(path) as fh:
|
| 221 |
-
for line in fh:
|
| 222 |
-
if line.startswith('#') and '=' in line:
|
| 223 |
-
splitted = line.split('=')
|
| 224 |
-
if len(splitted) != 2:
|
| 225 |
-
continue
|
| 226 |
-
lhs, rhs = splitted
|
| 227 |
-
if lhs.strip().split()[-1].lower() == 'language' and \
|
| 228 |
-
rhs.strip().split()[0].lower() == 'c++':
|
| 229 |
-
return True
|
| 230 |
-
return False
|
| 231 |
-
|
| 232 |
-
def import_module_from_file(filename, only_if_newer_than=None):
|
| 233 |
-
""" Imports Python extension (from shared object file)
|
| 234 |
-
|
| 235 |
-
Provide a list of paths in `only_if_newer_than` to check
|
| 236 |
-
timestamps of dependencies. import_ raises an ImportError
|
| 237 |
-
if any is newer.
|
| 238 |
-
|
| 239 |
-
Word of warning: The OS may cache shared objects which makes
|
| 240 |
-
reimporting same path of an shared object file very problematic.
|
| 241 |
-
|
| 242 |
-
It will not detect the new time stamp, nor new checksum, but will
|
| 243 |
-
instead silently use old module. Use unique names for this reason.
|
| 244 |
-
|
| 245 |
-
Parameters
|
| 246 |
-
==========
|
| 247 |
-
|
| 248 |
-
filename : str
|
| 249 |
-
Path to shared object.
|
| 250 |
-
only_if_newer_than : iterable of strings
|
| 251 |
-
Paths to dependencies of the shared object.
|
| 252 |
-
|
| 253 |
-
Raises
|
| 254 |
-
======
|
| 255 |
-
|
| 256 |
-
``ImportError`` if any of the files specified in ``only_if_newer_than`` are newer
|
| 257 |
-
than the file given by filename.
|
| 258 |
-
"""
|
| 259 |
-
path, name = os.path.split(filename)
|
| 260 |
-
name, ext = os.path.splitext(name)
|
| 261 |
-
name = name.split('.')[0]
|
| 262 |
-
if sys.version_info[0] == 2:
|
| 263 |
-
from imp import find_module, load_module
|
| 264 |
-
fobj, filename, data = find_module(name, [path])
|
| 265 |
-
if only_if_newer_than:
|
| 266 |
-
for dep in only_if_newer_than:
|
| 267 |
-
if os.path.getmtime(filename) < os.path.getmtime(dep):
|
| 268 |
-
raise ImportError("{} is newer than {}".format(dep, filename))
|
| 269 |
-
mod = load_module(name, fobj, filename, data)
|
| 270 |
-
else:
|
| 271 |
-
import importlib.util
|
| 272 |
-
spec = importlib.util.spec_from_file_location(name, filename)
|
| 273 |
-
if spec is None:
|
| 274 |
-
raise ImportError("Failed to import: '%s'" % filename)
|
| 275 |
-
mod = importlib.util.module_from_spec(spec)
|
| 276 |
-
spec.loader.exec_module(mod)
|
| 277 |
-
return mod
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
def find_binary_of_command(candidates):
|
| 281 |
-
""" Finds binary first matching name among candidates.
|
| 282 |
-
|
| 283 |
-
Calls ``which`` from shutils for provided candidates and returns
|
| 284 |
-
first hit.
|
| 285 |
-
|
| 286 |
-
Parameters
|
| 287 |
-
==========
|
| 288 |
-
|
| 289 |
-
candidates : iterable of str
|
| 290 |
-
Names of candidate commands
|
| 291 |
-
|
| 292 |
-
Raises
|
| 293 |
-
======
|
| 294 |
-
|
| 295 |
-
CompilerNotFoundError if no candidates match.
|
| 296 |
-
"""
|
| 297 |
-
from shutil import which
|
| 298 |
-
for c in candidates:
|
| 299 |
-
binary_path = which(c)
|
| 300 |
-
if c and binary_path:
|
| 301 |
-
return c, binary_path
|
| 302 |
-
|
| 303 |
-
raise CompilerNotFoundError('No binary located for candidates: {}'.format(candidates))
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
def unique_list(l):
|
| 307 |
-
""" Uniquify a list (skip duplicate items). """
|
| 308 |
-
result = []
|
| 309 |
-
for x in l:
|
| 310 |
-
if x not in result:
|
| 311 |
-
result.append(x)
|
| 312 |
-
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/autowrap.py
DELETED
|
@@ -1,1178 +0,0 @@
|
|
| 1 |
-
"""Module for compiling codegen output, and wrap the binary for use in
|
| 2 |
-
python.
|
| 3 |
-
|
| 4 |
-
.. note:: To use the autowrap module it must first be imported
|
| 5 |
-
|
| 6 |
-
>>> from sympy.utilities.autowrap import autowrap
|
| 7 |
-
|
| 8 |
-
This module provides a common interface for different external backends, such
|
| 9 |
-
as f2py, fwrap, Cython, SWIG(?) etc. (Currently only f2py and Cython are
|
| 10 |
-
implemented) The goal is to provide access to compiled binaries of acceptable
|
| 11 |
-
performance with a one-button user interface, e.g.,
|
| 12 |
-
|
| 13 |
-
>>> from sympy.abc import x,y
|
| 14 |
-
>>> expr = (x - y)**25
|
| 15 |
-
>>> flat = expr.expand()
|
| 16 |
-
>>> binary_callable = autowrap(flat)
|
| 17 |
-
>>> binary_callable(2, 3)
|
| 18 |
-
-1.0
|
| 19 |
-
|
| 20 |
-
Although a SymPy user might primarily be interested in working with
|
| 21 |
-
mathematical expressions and not in the details of wrapping tools
|
| 22 |
-
needed to evaluate such expressions efficiently in numerical form,
|
| 23 |
-
the user cannot do so without some understanding of the
|
| 24 |
-
limits in the target language. For example, the expanded expression
|
| 25 |
-
contains large coefficients which result in loss of precision when
|
| 26 |
-
computing the expression:
|
| 27 |
-
|
| 28 |
-
>>> binary_callable(3, 2)
|
| 29 |
-
0.0
|
| 30 |
-
>>> binary_callable(4, 5), binary_callable(5, 4)
|
| 31 |
-
(-22925376.0, 25165824.0)
|
| 32 |
-
|
| 33 |
-
Wrapping the unexpanded expression gives the expected behavior:
|
| 34 |
-
|
| 35 |
-
>>> e = autowrap(expr)
|
| 36 |
-
>>> e(4, 5), e(5, 4)
|
| 37 |
-
(-1.0, 1.0)
|
| 38 |
-
|
| 39 |
-
The callable returned from autowrap() is a binary Python function, not a
|
| 40 |
-
SymPy object. If it is desired to use the compiled function in symbolic
|
| 41 |
-
expressions, it is better to use binary_function() which returns a SymPy
|
| 42 |
-
Function object. The binary callable is attached as the _imp_ attribute and
|
| 43 |
-
invoked when a numerical evaluation is requested with evalf(), or with
|
| 44 |
-
lambdify().
|
| 45 |
-
|
| 46 |
-
>>> from sympy.utilities.autowrap import binary_function
|
| 47 |
-
>>> f = binary_function('f', expr)
|
| 48 |
-
>>> 2*f(x, y) + y
|
| 49 |
-
y + 2*f(x, y)
|
| 50 |
-
>>> (2*f(x, y) + y).evalf(2, subs={x: 1, y:2})
|
| 51 |
-
0.e-110
|
| 52 |
-
|
| 53 |
-
When is this useful?
|
| 54 |
-
|
| 55 |
-
1) For computations on large arrays, Python iterations may be too slow,
|
| 56 |
-
and depending on the mathematical expression, it may be difficult to
|
| 57 |
-
exploit the advanced index operations provided by NumPy.
|
| 58 |
-
|
| 59 |
-
2) For *really* long expressions that will be called repeatedly, the
|
| 60 |
-
compiled binary should be significantly faster than SymPy's .evalf()
|
| 61 |
-
|
| 62 |
-
3) If you are generating code with the codegen utility in order to use
|
| 63 |
-
it in another project, the automatic Python wrappers let you test the
|
| 64 |
-
binaries immediately from within SymPy.
|
| 65 |
-
|
| 66 |
-
4) To create customized ufuncs for use with numpy arrays.
|
| 67 |
-
See *ufuncify*.
|
| 68 |
-
|
| 69 |
-
When is this module NOT the best approach?
|
| 70 |
-
|
| 71 |
-
1) If you are really concerned about speed or memory optimizations,
|
| 72 |
-
you will probably get better results by working directly with the
|
| 73 |
-
wrapper tools and the low level code. However, the files generated
|
| 74 |
-
by this utility may provide a useful starting point and reference
|
| 75 |
-
code. Temporary files will be left intact if you supply the keyword
|
| 76 |
-
tempdir="path/to/files/".
|
| 77 |
-
|
| 78 |
-
2) If the array computation can be handled easily by numpy, and you
|
| 79 |
-
do not need the binaries for another project.
|
| 80 |
-
|
| 81 |
-
"""
|
| 82 |
-
|
| 83 |
-
import sys
|
| 84 |
-
import os
|
| 85 |
-
import shutil
|
| 86 |
-
import tempfile
|
| 87 |
-
from pathlib import Path
|
| 88 |
-
from subprocess import STDOUT, CalledProcessError, check_output
|
| 89 |
-
from string import Template
|
| 90 |
-
from warnings import warn
|
| 91 |
-
|
| 92 |
-
from sympy.core.cache import cacheit
|
| 93 |
-
from sympy.core.function import Lambda
|
| 94 |
-
from sympy.core.relational import Eq
|
| 95 |
-
from sympy.core.symbol import Dummy, Symbol
|
| 96 |
-
from sympy.tensor.indexed import Idx, IndexedBase
|
| 97 |
-
from sympy.utilities.codegen import (make_routine, get_code_generator,
|
| 98 |
-
OutputArgument, InOutArgument,
|
| 99 |
-
InputArgument, CodeGenArgumentListError,
|
| 100 |
-
Result, ResultBase, C99CodeGen)
|
| 101 |
-
from sympy.utilities.iterables import iterable
|
| 102 |
-
from sympy.utilities.lambdify import implemented_function
|
| 103 |
-
from sympy.utilities.decorator import doctest_depends_on
|
| 104 |
-
|
| 105 |
-
_doctest_depends_on = {'exe': ('f2py', 'gfortran', 'gcc'),
|
| 106 |
-
'modules': ('numpy',)}
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
class CodeWrapError(Exception):
|
| 110 |
-
pass
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
class CodeWrapper:
|
| 114 |
-
"""Base Class for code wrappers"""
|
| 115 |
-
_filename = "wrapped_code"
|
| 116 |
-
_module_basename = "wrapper_module"
|
| 117 |
-
_module_counter = 0
|
| 118 |
-
|
| 119 |
-
@property
|
| 120 |
-
def filename(self):
|
| 121 |
-
return "%s_%s" % (self._filename, CodeWrapper._module_counter)
|
| 122 |
-
|
| 123 |
-
@property
|
| 124 |
-
def module_name(self):
|
| 125 |
-
return "%s_%s" % (self._module_basename, CodeWrapper._module_counter)
|
| 126 |
-
|
| 127 |
-
def __init__(self, generator, filepath=None, flags=[], verbose=False):
|
| 128 |
-
"""
|
| 129 |
-
generator -- the code generator to use
|
| 130 |
-
"""
|
| 131 |
-
self.generator = generator
|
| 132 |
-
self.filepath = filepath
|
| 133 |
-
self.flags = flags
|
| 134 |
-
self.quiet = not verbose
|
| 135 |
-
|
| 136 |
-
@property
|
| 137 |
-
def include_header(self):
|
| 138 |
-
return bool(self.filepath)
|
| 139 |
-
|
| 140 |
-
@property
|
| 141 |
-
def include_empty(self):
|
| 142 |
-
return bool(self.filepath)
|
| 143 |
-
|
| 144 |
-
def _generate_code(self, main_routine, routines):
|
| 145 |
-
routines.append(main_routine)
|
| 146 |
-
self.generator.write(
|
| 147 |
-
routines, self.filename, True, self.include_header,
|
| 148 |
-
self.include_empty)
|
| 149 |
-
|
| 150 |
-
def wrap_code(self, routine, helpers=None):
|
| 151 |
-
helpers = helpers or []
|
| 152 |
-
if self.filepath:
|
| 153 |
-
workdir = os.path.abspath(self.filepath)
|
| 154 |
-
else:
|
| 155 |
-
workdir = tempfile.mkdtemp("_sympy_compile")
|
| 156 |
-
if not os.access(workdir, os.F_OK):
|
| 157 |
-
os.mkdir(workdir)
|
| 158 |
-
oldwork = os.getcwd()
|
| 159 |
-
os.chdir(workdir)
|
| 160 |
-
try:
|
| 161 |
-
sys.path.append(workdir)
|
| 162 |
-
self._generate_code(routine, helpers)
|
| 163 |
-
self._prepare_files(routine)
|
| 164 |
-
self._process_files(routine)
|
| 165 |
-
mod = __import__(self.module_name)
|
| 166 |
-
finally:
|
| 167 |
-
sys.path.remove(workdir)
|
| 168 |
-
CodeWrapper._module_counter += 1
|
| 169 |
-
os.chdir(oldwork)
|
| 170 |
-
if not self.filepath:
|
| 171 |
-
try:
|
| 172 |
-
shutil.rmtree(workdir)
|
| 173 |
-
except OSError:
|
| 174 |
-
# Could be some issues on Windows
|
| 175 |
-
pass
|
| 176 |
-
|
| 177 |
-
return self._get_wrapped_function(mod, routine.name)
|
| 178 |
-
|
| 179 |
-
def _process_files(self, routine):
|
| 180 |
-
command = self.command
|
| 181 |
-
command.extend(self.flags)
|
| 182 |
-
try:
|
| 183 |
-
retoutput = check_output(command, stderr=STDOUT)
|
| 184 |
-
except CalledProcessError as e:
|
| 185 |
-
raise CodeWrapError(
|
| 186 |
-
"Error while executing command: %s. Command output is:\n%s" % (
|
| 187 |
-
" ".join(command), e.output.decode('utf-8')))
|
| 188 |
-
if not self.quiet:
|
| 189 |
-
print(retoutput)
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
class DummyWrapper(CodeWrapper):
|
| 193 |
-
"""Class used for testing independent of backends """
|
| 194 |
-
|
| 195 |
-
template = """# dummy module for testing of SymPy
|
| 196 |
-
def %(name)s():
|
| 197 |
-
return "%(expr)s"
|
| 198 |
-
%(name)s.args = "%(args)s"
|
| 199 |
-
%(name)s.returns = "%(retvals)s"
|
| 200 |
-
"""
|
| 201 |
-
|
| 202 |
-
def _prepare_files(self, routine):
|
| 203 |
-
return
|
| 204 |
-
|
| 205 |
-
def _generate_code(self, routine, helpers):
|
| 206 |
-
with open('%s.py' % self.module_name, 'w') as f:
|
| 207 |
-
printed = ", ".join(
|
| 208 |
-
[str(res.expr) for res in routine.result_variables])
|
| 209 |
-
# convert OutputArguments to return value like f2py
|
| 210 |
-
args = filter(lambda x: not isinstance(
|
| 211 |
-
x, OutputArgument), routine.arguments)
|
| 212 |
-
retvals = []
|
| 213 |
-
for val in routine.result_variables:
|
| 214 |
-
if isinstance(val, Result):
|
| 215 |
-
retvals.append('nameless')
|
| 216 |
-
else:
|
| 217 |
-
retvals.append(val.result_var)
|
| 218 |
-
|
| 219 |
-
print(DummyWrapper.template % {
|
| 220 |
-
'name': routine.name,
|
| 221 |
-
'expr': printed,
|
| 222 |
-
'args': ", ".join([str(a.name) for a in args]),
|
| 223 |
-
'retvals': ", ".join([str(val) for val in retvals])
|
| 224 |
-
}, end="", file=f)
|
| 225 |
-
|
| 226 |
-
def _process_files(self, routine):
|
| 227 |
-
return
|
| 228 |
-
|
| 229 |
-
@classmethod
|
| 230 |
-
def _get_wrapped_function(cls, mod, name):
|
| 231 |
-
return getattr(mod, name)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
class CythonCodeWrapper(CodeWrapper):
|
| 235 |
-
"""Wrapper that uses Cython"""
|
| 236 |
-
|
| 237 |
-
setup_template = """\
|
| 238 |
-
from setuptools import setup
|
| 239 |
-
from setuptools import Extension
|
| 240 |
-
from Cython.Build import cythonize
|
| 241 |
-
cy_opts = {cythonize_options}
|
| 242 |
-
{np_import}
|
| 243 |
-
ext_mods = [Extension(
|
| 244 |
-
{ext_args},
|
| 245 |
-
include_dirs={include_dirs},
|
| 246 |
-
library_dirs={library_dirs},
|
| 247 |
-
libraries={libraries},
|
| 248 |
-
extra_compile_args={extra_compile_args},
|
| 249 |
-
extra_link_args={extra_link_args}
|
| 250 |
-
)]
|
| 251 |
-
setup(ext_modules=cythonize(ext_mods, **cy_opts))
|
| 252 |
-
"""
|
| 253 |
-
|
| 254 |
-
_cythonize_options = {'compiler_directives':{'language_level' : "3"}}
|
| 255 |
-
|
| 256 |
-
pyx_imports = (
|
| 257 |
-
"import numpy as np\n"
|
| 258 |
-
"cimport numpy as np\n\n")
|
| 259 |
-
|
| 260 |
-
pyx_header = (
|
| 261 |
-
"cdef extern from '{header_file}.h':\n"
|
| 262 |
-
" {prototype}\n\n")
|
| 263 |
-
|
| 264 |
-
pyx_func = (
|
| 265 |
-
"def {name}_c({arg_string}):\n"
|
| 266 |
-
"\n"
|
| 267 |
-
"{declarations}"
|
| 268 |
-
"{body}")
|
| 269 |
-
|
| 270 |
-
std_compile_flag = '-std=c99'
|
| 271 |
-
|
| 272 |
-
def __init__(self, *args, **kwargs):
|
| 273 |
-
"""Instantiates a Cython code wrapper.
|
| 274 |
-
|
| 275 |
-
The following optional parameters get passed to ``setuptools.Extension``
|
| 276 |
-
for building the Python extension module. Read its documentation to
|
| 277 |
-
learn more.
|
| 278 |
-
|
| 279 |
-
Parameters
|
| 280 |
-
==========
|
| 281 |
-
include_dirs : [list of strings]
|
| 282 |
-
A list of directories to search for C/C++ header files (in Unix
|
| 283 |
-
form for portability).
|
| 284 |
-
library_dirs : [list of strings]
|
| 285 |
-
A list of directories to search for C/C++ libraries at link time.
|
| 286 |
-
libraries : [list of strings]
|
| 287 |
-
A list of library names (not filenames or paths) to link against.
|
| 288 |
-
extra_compile_args : [list of strings]
|
| 289 |
-
Any extra platform- and compiler-specific information to use when
|
| 290 |
-
compiling the source files in 'sources'. For platforms and
|
| 291 |
-
compilers where "command line" makes sense, this is typically a
|
| 292 |
-
list of command-line arguments, but for other platforms it could be
|
| 293 |
-
anything. Note that the attribute ``std_compile_flag`` will be
|
| 294 |
-
appended to this list.
|
| 295 |
-
extra_link_args : [list of strings]
|
| 296 |
-
Any extra platform- and compiler-specific information to use when
|
| 297 |
-
linking object files together to create the extension (or to create
|
| 298 |
-
a new static Python interpreter). Similar interpretation as for
|
| 299 |
-
'extra_compile_args'.
|
| 300 |
-
cythonize_options : [dictionary]
|
| 301 |
-
Keyword arguments passed on to cythonize.
|
| 302 |
-
|
| 303 |
-
"""
|
| 304 |
-
|
| 305 |
-
self._include_dirs = kwargs.pop('include_dirs', [])
|
| 306 |
-
self._library_dirs = kwargs.pop('library_dirs', [])
|
| 307 |
-
self._libraries = kwargs.pop('libraries', [])
|
| 308 |
-
self._extra_compile_args = kwargs.pop('extra_compile_args', [])
|
| 309 |
-
self._extra_compile_args.append(self.std_compile_flag)
|
| 310 |
-
self._extra_link_args = kwargs.pop('extra_link_args', [])
|
| 311 |
-
self._cythonize_options = kwargs.pop('cythonize_options', self._cythonize_options)
|
| 312 |
-
|
| 313 |
-
self._need_numpy = False
|
| 314 |
-
|
| 315 |
-
super().__init__(*args, **kwargs)
|
| 316 |
-
|
| 317 |
-
@property
|
| 318 |
-
def command(self):
|
| 319 |
-
command = [sys.executable, "setup.py", "build_ext", "--inplace"]
|
| 320 |
-
return command
|
| 321 |
-
|
| 322 |
-
def _prepare_files(self, routine, build_dir=os.curdir):
|
| 323 |
-
# NOTE : build_dir is used for testing purposes.
|
| 324 |
-
pyxfilename = self.module_name + '.pyx'
|
| 325 |
-
codefilename = "%s.%s" % (self.filename, self.generator.code_extension)
|
| 326 |
-
|
| 327 |
-
# pyx
|
| 328 |
-
with open(os.path.join(build_dir, pyxfilename), 'w') as f:
|
| 329 |
-
self.dump_pyx([routine], f, self.filename)
|
| 330 |
-
|
| 331 |
-
# setup.py
|
| 332 |
-
ext_args = [repr(self.module_name), repr([pyxfilename, codefilename])]
|
| 333 |
-
if self._need_numpy:
|
| 334 |
-
np_import = 'import numpy as np\n'
|
| 335 |
-
self._include_dirs.append('np.get_include()')
|
| 336 |
-
else:
|
| 337 |
-
np_import = ''
|
| 338 |
-
|
| 339 |
-
includes = str(self._include_dirs).replace("'np.get_include()'",
|
| 340 |
-
'np.get_include()')
|
| 341 |
-
code = self.setup_template.format(
|
| 342 |
-
ext_args=", ".join(ext_args),
|
| 343 |
-
np_import=np_import,
|
| 344 |
-
include_dirs=includes,
|
| 345 |
-
library_dirs=self._library_dirs,
|
| 346 |
-
libraries=self._libraries,
|
| 347 |
-
extra_compile_args=self._extra_compile_args,
|
| 348 |
-
extra_link_args=self._extra_link_args,
|
| 349 |
-
cythonize_options=self._cythonize_options)
|
| 350 |
-
Path(os.path.join(build_dir, 'setup.py')).write_text(code)
|
| 351 |
-
|
| 352 |
-
@classmethod
|
| 353 |
-
def _get_wrapped_function(cls, mod, name):
|
| 354 |
-
return getattr(mod, name + '_c')
|
| 355 |
-
|
| 356 |
-
def dump_pyx(self, routines, f, prefix):
|
| 357 |
-
"""Write a Cython file with Python wrappers
|
| 358 |
-
|
| 359 |
-
This file contains all the definitions of the routines in c code and
|
| 360 |
-
refers to the header file.
|
| 361 |
-
|
| 362 |
-
Arguments
|
| 363 |
-
---------
|
| 364 |
-
routines
|
| 365 |
-
List of Routine instances
|
| 366 |
-
f
|
| 367 |
-
File-like object to write the file to
|
| 368 |
-
prefix
|
| 369 |
-
The filename prefix, used to refer to the proper header file.
|
| 370 |
-
Only the basename of the prefix is used.
|
| 371 |
-
"""
|
| 372 |
-
headers = []
|
| 373 |
-
functions = []
|
| 374 |
-
for routine in routines:
|
| 375 |
-
prototype = self.generator.get_prototype(routine)
|
| 376 |
-
|
| 377 |
-
# C Function Header Import
|
| 378 |
-
headers.append(self.pyx_header.format(header_file=prefix,
|
| 379 |
-
prototype=prototype))
|
| 380 |
-
|
| 381 |
-
# Partition the C function arguments into categories
|
| 382 |
-
py_rets, py_args, py_loc, py_inf = self._partition_args(routine.arguments)
|
| 383 |
-
|
| 384 |
-
# Function prototype
|
| 385 |
-
name = routine.name
|
| 386 |
-
arg_string = ", ".join(self._prototype_arg(arg) for arg in py_args)
|
| 387 |
-
|
| 388 |
-
# Local Declarations
|
| 389 |
-
local_decs = []
|
| 390 |
-
for arg, val in py_inf.items():
|
| 391 |
-
proto = self._prototype_arg(arg)
|
| 392 |
-
mat, ind = [self._string_var(v) for v in val]
|
| 393 |
-
local_decs.append(" cdef {} = {}.shape[{}]".format(proto, mat, ind))
|
| 394 |
-
local_decs.extend([" cdef {}".format(self._declare_arg(a)) for a in py_loc])
|
| 395 |
-
declarations = "\n".join(local_decs)
|
| 396 |
-
if declarations:
|
| 397 |
-
declarations = declarations + "\n"
|
| 398 |
-
|
| 399 |
-
# Function Body
|
| 400 |
-
args_c = ", ".join([self._call_arg(a) for a in routine.arguments])
|
| 401 |
-
rets = ", ".join([self._string_var(r.name) for r in py_rets])
|
| 402 |
-
if routine.results:
|
| 403 |
-
body = ' return %s(%s)' % (routine.name, args_c)
|
| 404 |
-
if rets:
|
| 405 |
-
body = body + ', ' + rets
|
| 406 |
-
else:
|
| 407 |
-
body = ' %s(%s)\n' % (routine.name, args_c)
|
| 408 |
-
body = body + ' return ' + rets
|
| 409 |
-
|
| 410 |
-
functions.append(self.pyx_func.format(name=name, arg_string=arg_string,
|
| 411 |
-
declarations=declarations, body=body))
|
| 412 |
-
|
| 413 |
-
# Write text to file
|
| 414 |
-
if self._need_numpy:
|
| 415 |
-
# Only import numpy if required
|
| 416 |
-
f.write(self.pyx_imports)
|
| 417 |
-
f.write('\n'.join(headers))
|
| 418 |
-
f.write('\n'.join(functions))
|
| 419 |
-
|
| 420 |
-
def _partition_args(self, args):
|
| 421 |
-
"""Group function arguments into categories."""
|
| 422 |
-
py_args = []
|
| 423 |
-
py_returns = []
|
| 424 |
-
py_locals = []
|
| 425 |
-
py_inferred = {}
|
| 426 |
-
for arg in args:
|
| 427 |
-
if isinstance(arg, OutputArgument):
|
| 428 |
-
py_returns.append(arg)
|
| 429 |
-
py_locals.append(arg)
|
| 430 |
-
elif isinstance(arg, InOutArgument):
|
| 431 |
-
py_returns.append(arg)
|
| 432 |
-
py_args.append(arg)
|
| 433 |
-
else:
|
| 434 |
-
py_args.append(arg)
|
| 435 |
-
# Find arguments that are array dimensions. These can be inferred
|
| 436 |
-
# locally in the Cython code.
|
| 437 |
-
if isinstance(arg, (InputArgument, InOutArgument)) and arg.dimensions:
|
| 438 |
-
dims = [d[1] + 1 for d in arg.dimensions]
|
| 439 |
-
sym_dims = [(i, d) for (i, d) in enumerate(dims) if
|
| 440 |
-
isinstance(d, Symbol)]
|
| 441 |
-
for (i, d) in sym_dims:
|
| 442 |
-
py_inferred[d] = (arg.name, i)
|
| 443 |
-
for arg in args:
|
| 444 |
-
if arg.name in py_inferred:
|
| 445 |
-
py_inferred[arg] = py_inferred.pop(arg.name)
|
| 446 |
-
# Filter inferred arguments from py_args
|
| 447 |
-
py_args = [a for a in py_args if a not in py_inferred]
|
| 448 |
-
return py_returns, py_args, py_locals, py_inferred
|
| 449 |
-
|
| 450 |
-
def _prototype_arg(self, arg):
|
| 451 |
-
mat_dec = "np.ndarray[{mtype}, ndim={ndim}] {name}"
|
| 452 |
-
np_types = {'double': 'np.double_t',
|
| 453 |
-
'int': 'np.int_t'}
|
| 454 |
-
t = arg.get_datatype('c')
|
| 455 |
-
if arg.dimensions:
|
| 456 |
-
self._need_numpy = True
|
| 457 |
-
ndim = len(arg.dimensions)
|
| 458 |
-
mtype = np_types[t]
|
| 459 |
-
return mat_dec.format(mtype=mtype, ndim=ndim, name=self._string_var(arg.name))
|
| 460 |
-
else:
|
| 461 |
-
return "%s %s" % (t, self._string_var(arg.name))
|
| 462 |
-
|
| 463 |
-
def _declare_arg(self, arg):
|
| 464 |
-
proto = self._prototype_arg(arg)
|
| 465 |
-
if arg.dimensions:
|
| 466 |
-
shape = '(' + ','.join(self._string_var(i[1] + 1) for i in arg.dimensions) + ')'
|
| 467 |
-
return proto + " = np.empty({shape})".format(shape=shape)
|
| 468 |
-
else:
|
| 469 |
-
return proto + " = 0"
|
| 470 |
-
|
| 471 |
-
def _call_arg(self, arg):
|
| 472 |
-
if arg.dimensions:
|
| 473 |
-
t = arg.get_datatype('c')
|
| 474 |
-
return "<{}*> {}.data".format(t, self._string_var(arg.name))
|
| 475 |
-
elif isinstance(arg, ResultBase):
|
| 476 |
-
return "&{}".format(self._string_var(arg.name))
|
| 477 |
-
else:
|
| 478 |
-
return self._string_var(arg.name)
|
| 479 |
-
|
| 480 |
-
def _string_var(self, var):
|
| 481 |
-
printer = self.generator.printer.doprint
|
| 482 |
-
return printer(var)
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
class F2PyCodeWrapper(CodeWrapper):
|
| 486 |
-
"""Wrapper that uses f2py"""
|
| 487 |
-
|
| 488 |
-
def __init__(self, *args, **kwargs):
|
| 489 |
-
|
| 490 |
-
ext_keys = ['include_dirs', 'library_dirs', 'libraries',
|
| 491 |
-
'extra_compile_args', 'extra_link_args']
|
| 492 |
-
msg = ('The compilation option kwarg {} is not supported with the f2py '
|
| 493 |
-
'backend.')
|
| 494 |
-
|
| 495 |
-
for k in ext_keys:
|
| 496 |
-
if k in kwargs.keys():
|
| 497 |
-
warn(msg.format(k))
|
| 498 |
-
kwargs.pop(k, None)
|
| 499 |
-
|
| 500 |
-
super().__init__(*args, **kwargs)
|
| 501 |
-
|
| 502 |
-
@property
|
| 503 |
-
def command(self):
|
| 504 |
-
filename = self.filename + '.' + self.generator.code_extension
|
| 505 |
-
args = ['-c', '-m', self.module_name, filename]
|
| 506 |
-
command = [sys.executable, "-c", "import numpy.f2py as f2py2e;f2py2e.main()"]+args
|
| 507 |
-
return command
|
| 508 |
-
|
| 509 |
-
def _prepare_files(self, routine):
|
| 510 |
-
pass
|
| 511 |
-
|
| 512 |
-
@classmethod
|
| 513 |
-
def _get_wrapped_function(cls, mod, name):
|
| 514 |
-
return getattr(mod, name)
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
# Here we define a lookup of backends -> tuples of languages. For now, each
|
| 518 |
-
# tuple is of length 1, but if a backend supports more than one language,
|
| 519 |
-
# the most preferable language is listed first.
|
| 520 |
-
_lang_lookup = {'CYTHON': ('C99', 'C89', 'C'),
|
| 521 |
-
'F2PY': ('F95',),
|
| 522 |
-
'NUMPY': ('C99', 'C89', 'C'),
|
| 523 |
-
'DUMMY': ('F95',)} # Dummy here just for testing
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
def _infer_language(backend):
|
| 527 |
-
"""For a given backend, return the top choice of language"""
|
| 528 |
-
langs = _lang_lookup.get(backend.upper(), False)
|
| 529 |
-
if not langs:
|
| 530 |
-
raise ValueError("Unrecognized backend: " + backend)
|
| 531 |
-
return langs[0]
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
def _validate_backend_language(backend, language):
|
| 535 |
-
"""Throws error if backend and language are incompatible"""
|
| 536 |
-
langs = _lang_lookup.get(backend.upper(), False)
|
| 537 |
-
if not langs:
|
| 538 |
-
raise ValueError("Unrecognized backend: " + backend)
|
| 539 |
-
if language.upper() not in langs:
|
| 540 |
-
raise ValueError(("Backend {} and language {} are "
|
| 541 |
-
"incompatible").format(backend, language))
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
@cacheit
|
| 545 |
-
@doctest_depends_on(exe=('f2py', 'gfortran'), modules=('numpy',))
|
| 546 |
-
def autowrap(expr, language=None, backend='f2py', tempdir=None, args=None,
|
| 547 |
-
flags=None, verbose=False, helpers=None, code_gen=None, **kwargs):
|
| 548 |
-
"""Generates Python callable binaries based on the math expression.
|
| 549 |
-
|
| 550 |
-
Parameters
|
| 551 |
-
==========
|
| 552 |
-
|
| 553 |
-
expr
|
| 554 |
-
The SymPy expression that should be wrapped as a binary routine.
|
| 555 |
-
language : string, optional
|
| 556 |
-
If supplied, (options: 'C' or 'F95'), specifies the language of the
|
| 557 |
-
generated code. If ``None`` [default], the language is inferred based
|
| 558 |
-
upon the specified backend.
|
| 559 |
-
backend : string, optional
|
| 560 |
-
Backend used to wrap the generated code. Either 'f2py' [default],
|
| 561 |
-
or 'cython'.
|
| 562 |
-
tempdir : string, optional
|
| 563 |
-
Path to directory for temporary files. If this argument is supplied,
|
| 564 |
-
the generated code and the wrapper input files are left intact in the
|
| 565 |
-
specified path.
|
| 566 |
-
args : iterable, optional
|
| 567 |
-
An ordered iterable of symbols. Specifies the argument sequence for the
|
| 568 |
-
function.
|
| 569 |
-
flags : iterable, optional
|
| 570 |
-
Additional option flags that will be passed to the backend.
|
| 571 |
-
verbose : bool, optional
|
| 572 |
-
If True, autowrap will not mute the command line backends. This can be
|
| 573 |
-
helpful for debugging.
|
| 574 |
-
helpers : 3-tuple or iterable of 3-tuples, optional
|
| 575 |
-
Used to define auxiliary functions needed for the main expression.
|
| 576 |
-
Each tuple should be of the form (name, expr, args) where:
|
| 577 |
-
|
| 578 |
-
- name : str, the function name
|
| 579 |
-
- expr : sympy expression, the function
|
| 580 |
-
- args : iterable, the function arguments (can be any iterable of symbols)
|
| 581 |
-
|
| 582 |
-
code_gen : CodeGen instance
|
| 583 |
-
An instance of a CodeGen subclass. Overrides ``language``.
|
| 584 |
-
include_dirs : [string]
|
| 585 |
-
A list of directories to search for C/C++ header files (in Unix form
|
| 586 |
-
for portability).
|
| 587 |
-
library_dirs : [string]
|
| 588 |
-
A list of directories to search for C/C++ libraries at link time.
|
| 589 |
-
libraries : [string]
|
| 590 |
-
A list of library names (not filenames or paths) to link against.
|
| 591 |
-
extra_compile_args : [string]
|
| 592 |
-
Any extra platform- and compiler-specific information to use when
|
| 593 |
-
compiling the source files in 'sources'. For platforms and compilers
|
| 594 |
-
where "command line" makes sense, this is typically a list of
|
| 595 |
-
command-line arguments, but for other platforms it could be anything.
|
| 596 |
-
extra_link_args : [string]
|
| 597 |
-
Any extra platform- and compiler-specific information to use when
|
| 598 |
-
linking object files together to create the extension (or to create a
|
| 599 |
-
new static Python interpreter). Similar interpretation as for
|
| 600 |
-
'extra_compile_args'.
|
| 601 |
-
|
| 602 |
-
Examples
|
| 603 |
-
========
|
| 604 |
-
|
| 605 |
-
Basic usage:
|
| 606 |
-
|
| 607 |
-
>>> from sympy.abc import x, y, z
|
| 608 |
-
>>> from sympy.utilities.autowrap import autowrap
|
| 609 |
-
>>> expr = ((x - y + z)**(13)).expand()
|
| 610 |
-
>>> binary_func = autowrap(expr)
|
| 611 |
-
>>> binary_func(1, 4, 2)
|
| 612 |
-
-1.0
|
| 613 |
-
|
| 614 |
-
Using helper functions:
|
| 615 |
-
|
| 616 |
-
>>> from sympy.abc import x, t
|
| 617 |
-
>>> from sympy import Function
|
| 618 |
-
>>> helper_func = Function('helper_func') # Define symbolic function
|
| 619 |
-
>>> expr = 3*x + helper_func(t) # Main expression using helper function
|
| 620 |
-
>>> # Define helper_func(x) = 4*x using f2py backend
|
| 621 |
-
>>> binary_func = autowrap(expr, args=[x, t],
|
| 622 |
-
... helpers=('helper_func', 4*x, [x]))
|
| 623 |
-
>>> binary_func(2, 5) # 3*2 + helper_func(5) = 6 + 20
|
| 624 |
-
26.0
|
| 625 |
-
>>> # Same example using cython backend
|
| 626 |
-
>>> binary_func = autowrap(expr, args=[x, t], backend='cython',
|
| 627 |
-
... helpers=[('helper_func', 4*x, [x])])
|
| 628 |
-
>>> binary_func(2, 5) # 3*2 + helper_func(5) = 6 + 20
|
| 629 |
-
26.0
|
| 630 |
-
|
| 631 |
-
Type handling example:
|
| 632 |
-
|
| 633 |
-
>>> import numpy as np
|
| 634 |
-
>>> expr = x + y
|
| 635 |
-
>>> f_cython = autowrap(expr, backend='cython')
|
| 636 |
-
>>> f_cython(1, 2) # doctest: +ELLIPSIS
|
| 637 |
-
Traceback (most recent call last):
|
| 638 |
-
...
|
| 639 |
-
TypeError: Argument '_x' has incorrect type (expected numpy.ndarray, got int)
|
| 640 |
-
>>> f_cython(np.array([1.0]), np.array([2.0]))
|
| 641 |
-
array([ 3.])
|
| 642 |
-
|
| 643 |
-
"""
|
| 644 |
-
if language:
|
| 645 |
-
if not isinstance(language, type):
|
| 646 |
-
_validate_backend_language(backend, language)
|
| 647 |
-
else:
|
| 648 |
-
language = _infer_language(backend)
|
| 649 |
-
|
| 650 |
-
# two cases 1) helpers is an iterable of 3-tuples and 2) helpers is a
|
| 651 |
-
# 3-tuple
|
| 652 |
-
if iterable(helpers) and len(helpers) != 0 and iterable(helpers[0]):
|
| 653 |
-
helpers = helpers if helpers else ()
|
| 654 |
-
else:
|
| 655 |
-
helpers = [helpers] if helpers else ()
|
| 656 |
-
args = list(args) if iterable(args, exclude=set) else args
|
| 657 |
-
|
| 658 |
-
if code_gen is None:
|
| 659 |
-
code_gen = get_code_generator(language, "autowrap")
|
| 660 |
-
|
| 661 |
-
CodeWrapperClass = {
|
| 662 |
-
'F2PY': F2PyCodeWrapper,
|
| 663 |
-
'CYTHON': CythonCodeWrapper,
|
| 664 |
-
'DUMMY': DummyWrapper
|
| 665 |
-
}[backend.upper()]
|
| 666 |
-
code_wrapper = CodeWrapperClass(code_gen, tempdir, flags if flags else (),
|
| 667 |
-
verbose, **kwargs)
|
| 668 |
-
|
| 669 |
-
helps = []
|
| 670 |
-
for name_h, expr_h, args_h in helpers:
|
| 671 |
-
helps.append(code_gen.routine(name_h, expr_h, args_h))
|
| 672 |
-
|
| 673 |
-
for name_h, expr_h, args_h in helpers:
|
| 674 |
-
if expr.has(expr_h):
|
| 675 |
-
name_h = binary_function(name_h, expr_h, backend='dummy')
|
| 676 |
-
expr = expr.subs(expr_h, name_h(*args_h))
|
| 677 |
-
try:
|
| 678 |
-
routine = code_gen.routine('autofunc', expr, args)
|
| 679 |
-
except CodeGenArgumentListError as e:
|
| 680 |
-
# if all missing arguments are for pure output, we simply attach them
|
| 681 |
-
# at the end and try again, because the wrappers will silently convert
|
| 682 |
-
# them to return values anyway.
|
| 683 |
-
new_args = []
|
| 684 |
-
for missing in e.missing_args:
|
| 685 |
-
if not isinstance(missing, OutputArgument):
|
| 686 |
-
raise
|
| 687 |
-
new_args.append(missing.name)
|
| 688 |
-
routine = code_gen.routine('autofunc', expr, args + new_args)
|
| 689 |
-
|
| 690 |
-
return code_wrapper.wrap_code(routine, helpers=helps)
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
@doctest_depends_on(exe=('f2py', 'gfortran'), modules=('numpy',))
|
| 694 |
-
def binary_function(symfunc, expr, **kwargs):
|
| 695 |
-
"""Returns a SymPy function with expr as binary implementation
|
| 696 |
-
|
| 697 |
-
This is a convenience function that automates the steps needed to
|
| 698 |
-
autowrap the SymPy expression and attaching it to a Function object
|
| 699 |
-
with implemented_function().
|
| 700 |
-
|
| 701 |
-
Parameters
|
| 702 |
-
==========
|
| 703 |
-
|
| 704 |
-
symfunc : SymPy Function
|
| 705 |
-
The function to bind the callable to.
|
| 706 |
-
expr : SymPy Expression
|
| 707 |
-
The expression used to generate the function.
|
| 708 |
-
kwargs : dict
|
| 709 |
-
Any kwargs accepted by autowrap.
|
| 710 |
-
|
| 711 |
-
Examples
|
| 712 |
-
========
|
| 713 |
-
|
| 714 |
-
>>> from sympy.abc import x, y
|
| 715 |
-
>>> from sympy.utilities.autowrap import binary_function
|
| 716 |
-
>>> expr = ((x - y)**(25)).expand()
|
| 717 |
-
>>> f = binary_function('f', expr)
|
| 718 |
-
>>> type(f)
|
| 719 |
-
<class 'sympy.core.function.UndefinedFunction'>
|
| 720 |
-
>>> 2*f(x, y)
|
| 721 |
-
2*f(x, y)
|
| 722 |
-
>>> f(x, y).evalf(2, subs={x: 1, y: 2})
|
| 723 |
-
-1.0
|
| 724 |
-
|
| 725 |
-
"""
|
| 726 |
-
binary = autowrap(expr, **kwargs)
|
| 727 |
-
return implemented_function(symfunc, binary)
|
| 728 |
-
|
| 729 |
-
#################################################################
|
| 730 |
-
# UFUNCIFY #
|
| 731 |
-
#################################################################
|
| 732 |
-
|
| 733 |
-
_ufunc_top = Template("""\
|
| 734 |
-
#include "Python.h"
|
| 735 |
-
#include "math.h"
|
| 736 |
-
#include "numpy/ndarraytypes.h"
|
| 737 |
-
#include "numpy/ufuncobject.h"
|
| 738 |
-
#include "numpy/halffloat.h"
|
| 739 |
-
#include ${include_file}
|
| 740 |
-
|
| 741 |
-
static PyMethodDef ${module}Methods[] = {
|
| 742 |
-
{NULL, NULL, 0, NULL}
|
| 743 |
-
};""")
|
| 744 |
-
|
| 745 |
-
_ufunc_outcalls = Template("*((double *)out${outnum}) = ${funcname}(${call_args});")
|
| 746 |
-
|
| 747 |
-
_ufunc_body = Template("""\
|
| 748 |
-
#ifdef NPY_1_19_API_VERSION
|
| 749 |
-
static void ${funcname}_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
|
| 750 |
-
#else
|
| 751 |
-
static void ${funcname}_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
|
| 752 |
-
#endif
|
| 753 |
-
{
|
| 754 |
-
npy_intp i;
|
| 755 |
-
npy_intp n = dimensions[0];
|
| 756 |
-
${declare_args}
|
| 757 |
-
${declare_steps}
|
| 758 |
-
for (i = 0; i < n; i++) {
|
| 759 |
-
${outcalls}
|
| 760 |
-
${step_increments}
|
| 761 |
-
}
|
| 762 |
-
}
|
| 763 |
-
PyUFuncGenericFunction ${funcname}_funcs[1] = {&${funcname}_ufunc};
|
| 764 |
-
static char ${funcname}_types[${n_types}] = ${types}
|
| 765 |
-
static void *${funcname}_data[1] = {NULL};""")
|
| 766 |
-
|
| 767 |
-
_ufunc_bottom = Template("""\
|
| 768 |
-
#if PY_VERSION_HEX >= 0x03000000
|
| 769 |
-
static struct PyModuleDef moduledef = {
|
| 770 |
-
PyModuleDef_HEAD_INIT,
|
| 771 |
-
"${module}",
|
| 772 |
-
NULL,
|
| 773 |
-
-1,
|
| 774 |
-
${module}Methods,
|
| 775 |
-
NULL,
|
| 776 |
-
NULL,
|
| 777 |
-
NULL,
|
| 778 |
-
NULL
|
| 779 |
-
};
|
| 780 |
-
|
| 781 |
-
PyMODINIT_FUNC PyInit_${module}(void)
|
| 782 |
-
{
|
| 783 |
-
PyObject *m, *d;
|
| 784 |
-
${function_creation}
|
| 785 |
-
m = PyModule_Create(&moduledef);
|
| 786 |
-
if (!m) {
|
| 787 |
-
return NULL;
|
| 788 |
-
}
|
| 789 |
-
import_array();
|
| 790 |
-
import_umath();
|
| 791 |
-
d = PyModule_GetDict(m);
|
| 792 |
-
${ufunc_init}
|
| 793 |
-
return m;
|
| 794 |
-
}
|
| 795 |
-
#else
|
| 796 |
-
PyMODINIT_FUNC init${module}(void)
|
| 797 |
-
{
|
| 798 |
-
PyObject *m, *d;
|
| 799 |
-
${function_creation}
|
| 800 |
-
m = Py_InitModule("${module}", ${module}Methods);
|
| 801 |
-
if (m == NULL) {
|
| 802 |
-
return;
|
| 803 |
-
}
|
| 804 |
-
import_array();
|
| 805 |
-
import_umath();
|
| 806 |
-
d = PyModule_GetDict(m);
|
| 807 |
-
${ufunc_init}
|
| 808 |
-
}
|
| 809 |
-
#endif\
|
| 810 |
-
""")
|
| 811 |
-
|
| 812 |
-
_ufunc_init_form = Template("""\
|
| 813 |
-
ufunc${ind} = PyUFunc_FromFuncAndData(${funcname}_funcs, ${funcname}_data, ${funcname}_types, 1, ${n_in}, ${n_out},
|
| 814 |
-
PyUFunc_None, "${module}", ${docstring}, 0);
|
| 815 |
-
PyDict_SetItemString(d, "${funcname}", ufunc${ind});
|
| 816 |
-
Py_DECREF(ufunc${ind});""")
|
| 817 |
-
|
| 818 |
-
_ufunc_setup = Template("""\
|
| 819 |
-
from setuptools.extension import Extension
|
| 820 |
-
from setuptools import setup
|
| 821 |
-
|
| 822 |
-
from numpy import get_include
|
| 823 |
-
|
| 824 |
-
if __name__ == "__main__":
|
| 825 |
-
setup(ext_modules=[
|
| 826 |
-
Extension('${module}',
|
| 827 |
-
sources=['${module}.c', '${filename}.c'],
|
| 828 |
-
include_dirs=[get_include()])])
|
| 829 |
-
""")
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
class UfuncifyCodeWrapper(CodeWrapper):
|
| 833 |
-
"""Wrapper for Ufuncify"""
|
| 834 |
-
|
| 835 |
-
def __init__(self, *args, **kwargs):
|
| 836 |
-
|
| 837 |
-
ext_keys = ['include_dirs', 'library_dirs', 'libraries',
|
| 838 |
-
'extra_compile_args', 'extra_link_args']
|
| 839 |
-
msg = ('The compilation option kwarg {} is not supported with the numpy'
|
| 840 |
-
' backend.')
|
| 841 |
-
|
| 842 |
-
for k in ext_keys:
|
| 843 |
-
if k in kwargs.keys():
|
| 844 |
-
warn(msg.format(k))
|
| 845 |
-
kwargs.pop(k, None)
|
| 846 |
-
|
| 847 |
-
super().__init__(*args, **kwargs)
|
| 848 |
-
|
| 849 |
-
@property
|
| 850 |
-
def command(self):
|
| 851 |
-
command = [sys.executable, "setup.py", "build_ext", "--inplace"]
|
| 852 |
-
return command
|
| 853 |
-
|
| 854 |
-
def wrap_code(self, routines, helpers=None):
|
| 855 |
-
# This routine overrides CodeWrapper because we can't assume funcname == routines[0].name
|
| 856 |
-
# Therefore we have to break the CodeWrapper private API.
|
| 857 |
-
# There isn't an obvious way to extend multi-expr support to
|
| 858 |
-
# the other autowrap backends, so we limit this change to ufuncify.
|
| 859 |
-
helpers = helpers if helpers is not None else []
|
| 860 |
-
# We just need a consistent name
|
| 861 |
-
funcname = 'wrapped_' + str(id(routines) + id(helpers))
|
| 862 |
-
|
| 863 |
-
workdir = self.filepath or tempfile.mkdtemp("_sympy_compile")
|
| 864 |
-
if not os.access(workdir, os.F_OK):
|
| 865 |
-
os.mkdir(workdir)
|
| 866 |
-
oldwork = os.getcwd()
|
| 867 |
-
os.chdir(workdir)
|
| 868 |
-
try:
|
| 869 |
-
sys.path.append(workdir)
|
| 870 |
-
self._generate_code(routines, helpers)
|
| 871 |
-
self._prepare_files(routines, funcname)
|
| 872 |
-
self._process_files(routines)
|
| 873 |
-
mod = __import__(self.module_name)
|
| 874 |
-
finally:
|
| 875 |
-
sys.path.remove(workdir)
|
| 876 |
-
CodeWrapper._module_counter += 1
|
| 877 |
-
os.chdir(oldwork)
|
| 878 |
-
if not self.filepath:
|
| 879 |
-
try:
|
| 880 |
-
shutil.rmtree(workdir)
|
| 881 |
-
except OSError:
|
| 882 |
-
# Could be some issues on Windows
|
| 883 |
-
pass
|
| 884 |
-
|
| 885 |
-
return self._get_wrapped_function(mod, funcname)
|
| 886 |
-
|
| 887 |
-
def _generate_code(self, main_routines, helper_routines):
|
| 888 |
-
all_routines = main_routines + helper_routines
|
| 889 |
-
self.generator.write(
|
| 890 |
-
all_routines, self.filename, True, self.include_header,
|
| 891 |
-
self.include_empty)
|
| 892 |
-
|
| 893 |
-
def _prepare_files(self, routines, funcname):
|
| 894 |
-
|
| 895 |
-
# C
|
| 896 |
-
codefilename = self.module_name + '.c'
|
| 897 |
-
with open(codefilename, 'w') as f:
|
| 898 |
-
self.dump_c(routines, f, self.filename, funcname=funcname)
|
| 899 |
-
|
| 900 |
-
# setup.py
|
| 901 |
-
with open('setup.py', 'w') as f:
|
| 902 |
-
self.dump_setup(f)
|
| 903 |
-
|
| 904 |
-
@classmethod
|
| 905 |
-
def _get_wrapped_function(cls, mod, name):
|
| 906 |
-
return getattr(mod, name)
|
| 907 |
-
|
| 908 |
-
def dump_setup(self, f):
|
| 909 |
-
setup = _ufunc_setup.substitute(module=self.module_name,
|
| 910 |
-
filename=self.filename)
|
| 911 |
-
f.write(setup)
|
| 912 |
-
|
| 913 |
-
def dump_c(self, routines, f, prefix, funcname=None):
|
| 914 |
-
"""Write a C file with Python wrappers
|
| 915 |
-
|
| 916 |
-
This file contains all the definitions of the routines in c code.
|
| 917 |
-
|
| 918 |
-
Arguments
|
| 919 |
-
---------
|
| 920 |
-
routines
|
| 921 |
-
List of Routine instances
|
| 922 |
-
f
|
| 923 |
-
File-like object to write the file to
|
| 924 |
-
prefix
|
| 925 |
-
The filename prefix, used to name the imported module.
|
| 926 |
-
funcname
|
| 927 |
-
Name of the main function to be returned.
|
| 928 |
-
"""
|
| 929 |
-
if funcname is None:
|
| 930 |
-
if len(routines) == 1:
|
| 931 |
-
funcname = routines[0].name
|
| 932 |
-
else:
|
| 933 |
-
msg = 'funcname must be specified for multiple output routines'
|
| 934 |
-
raise ValueError(msg)
|
| 935 |
-
functions = []
|
| 936 |
-
function_creation = []
|
| 937 |
-
ufunc_init = []
|
| 938 |
-
module = self.module_name
|
| 939 |
-
include_file = "\"{}.h\"".format(prefix)
|
| 940 |
-
top = _ufunc_top.substitute(include_file=include_file, module=module)
|
| 941 |
-
|
| 942 |
-
name = funcname
|
| 943 |
-
|
| 944 |
-
# Partition the C function arguments into categories
|
| 945 |
-
# Here we assume all routines accept the same arguments
|
| 946 |
-
r_index = 0
|
| 947 |
-
py_in, _ = self._partition_args(routines[0].arguments)
|
| 948 |
-
n_in = len(py_in)
|
| 949 |
-
n_out = len(routines)
|
| 950 |
-
|
| 951 |
-
# Declare Args
|
| 952 |
-
form = "char *{0}{1} = args[{2}];"
|
| 953 |
-
arg_decs = [form.format('in', i, i) for i in range(n_in)]
|
| 954 |
-
arg_decs.extend([form.format('out', i, i+n_in) for i in range(n_out)])
|
| 955 |
-
declare_args = '\n '.join(arg_decs)
|
| 956 |
-
|
| 957 |
-
# Declare Steps
|
| 958 |
-
form = "npy_intp {0}{1}_step = steps[{2}];"
|
| 959 |
-
step_decs = [form.format('in', i, i) for i in range(n_in)]
|
| 960 |
-
step_decs.extend([form.format('out', i, i+n_in) for i in range(n_out)])
|
| 961 |
-
declare_steps = '\n '.join(step_decs)
|
| 962 |
-
|
| 963 |
-
# Call Args
|
| 964 |
-
form = "*(double *)in{0}"
|
| 965 |
-
call_args = ', '.join([form.format(a) for a in range(n_in)])
|
| 966 |
-
|
| 967 |
-
# Step Increments
|
| 968 |
-
form = "{0}{1} += {0}{1}_step;"
|
| 969 |
-
step_incs = [form.format('in', i) for i in range(n_in)]
|
| 970 |
-
step_incs.extend([form.format('out', i, i) for i in range(n_out)])
|
| 971 |
-
step_increments = '\n '.join(step_incs)
|
| 972 |
-
|
| 973 |
-
# Types
|
| 974 |
-
n_types = n_in + n_out
|
| 975 |
-
types = "{" + ', '.join(["NPY_DOUBLE"]*n_types) + "};"
|
| 976 |
-
|
| 977 |
-
# Docstring
|
| 978 |
-
docstring = '"Created in SymPy with Ufuncify"'
|
| 979 |
-
|
| 980 |
-
# Function Creation
|
| 981 |
-
function_creation.append("PyObject *ufunc{};".format(r_index))
|
| 982 |
-
|
| 983 |
-
# Ufunc initialization
|
| 984 |
-
init_form = _ufunc_init_form.substitute(module=module,
|
| 985 |
-
funcname=name,
|
| 986 |
-
docstring=docstring,
|
| 987 |
-
n_in=n_in, n_out=n_out,
|
| 988 |
-
ind=r_index)
|
| 989 |
-
ufunc_init.append(init_form)
|
| 990 |
-
|
| 991 |
-
outcalls = [_ufunc_outcalls.substitute(
|
| 992 |
-
outnum=i, call_args=call_args, funcname=routines[i].name) for i in
|
| 993 |
-
range(n_out)]
|
| 994 |
-
|
| 995 |
-
body = _ufunc_body.substitute(module=module, funcname=name,
|
| 996 |
-
declare_args=declare_args,
|
| 997 |
-
declare_steps=declare_steps,
|
| 998 |
-
call_args=call_args,
|
| 999 |
-
step_increments=step_increments,
|
| 1000 |
-
n_types=n_types, types=types,
|
| 1001 |
-
outcalls='\n '.join(outcalls))
|
| 1002 |
-
functions.append(body)
|
| 1003 |
-
|
| 1004 |
-
body = '\n\n'.join(functions)
|
| 1005 |
-
ufunc_init = '\n '.join(ufunc_init)
|
| 1006 |
-
function_creation = '\n '.join(function_creation)
|
| 1007 |
-
bottom = _ufunc_bottom.substitute(module=module,
|
| 1008 |
-
ufunc_init=ufunc_init,
|
| 1009 |
-
function_creation=function_creation)
|
| 1010 |
-
text = [top, body, bottom]
|
| 1011 |
-
f.write('\n\n'.join(text))
|
| 1012 |
-
|
| 1013 |
-
def _partition_args(self, args):
|
| 1014 |
-
"""Group function arguments into categories."""
|
| 1015 |
-
py_in = []
|
| 1016 |
-
py_out = []
|
| 1017 |
-
for arg in args:
|
| 1018 |
-
if isinstance(arg, OutputArgument):
|
| 1019 |
-
py_out.append(arg)
|
| 1020 |
-
elif isinstance(arg, InOutArgument):
|
| 1021 |
-
raise ValueError("Ufuncify doesn't support InOutArguments")
|
| 1022 |
-
else:
|
| 1023 |
-
py_in.append(arg)
|
| 1024 |
-
return py_in, py_out
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
@cacheit
|
| 1028 |
-
@doctest_depends_on(exe=('f2py', 'gfortran', 'gcc'), modules=('numpy',))
|
| 1029 |
-
def ufuncify(args, expr, language=None, backend='numpy', tempdir=None,
|
| 1030 |
-
flags=None, verbose=False, helpers=None, **kwargs):
|
| 1031 |
-
"""Generates a binary function that supports broadcasting on numpy arrays.
|
| 1032 |
-
|
| 1033 |
-
Parameters
|
| 1034 |
-
==========
|
| 1035 |
-
|
| 1036 |
-
args : iterable
|
| 1037 |
-
Either a Symbol or an iterable of symbols. Specifies the argument
|
| 1038 |
-
sequence for the function.
|
| 1039 |
-
expr
|
| 1040 |
-
A SymPy expression that defines the element wise operation.
|
| 1041 |
-
language : string, optional
|
| 1042 |
-
If supplied, (options: 'C' or 'F95'), specifies the language of the
|
| 1043 |
-
generated code. If ``None`` [default], the language is inferred based
|
| 1044 |
-
upon the specified backend.
|
| 1045 |
-
backend : string, optional
|
| 1046 |
-
Backend used to wrap the generated code. Either 'numpy' [default],
|
| 1047 |
-
'cython', or 'f2py'.
|
| 1048 |
-
tempdir : string, optional
|
| 1049 |
-
Path to directory for temporary files. If this argument is supplied,
|
| 1050 |
-
the generated code and the wrapper input files are left intact in
|
| 1051 |
-
the specified path.
|
| 1052 |
-
flags : iterable, optional
|
| 1053 |
-
Additional option flags that will be passed to the backend.
|
| 1054 |
-
verbose : bool, optional
|
| 1055 |
-
If True, autowrap will not mute the command line backends. This can
|
| 1056 |
-
be helpful for debugging.
|
| 1057 |
-
helpers : 3-tuple or iterable of 3-tuples, optional
|
| 1058 |
-
Used to define auxiliary functions needed for the main expression.
|
| 1059 |
-
Each tuple should be of the form (name, expr, args) where:
|
| 1060 |
-
|
| 1061 |
-
- name : str, the function name
|
| 1062 |
-
- expr : sympy expression, the function
|
| 1063 |
-
- args : iterable, the function arguments (can be any iterable of symbols)
|
| 1064 |
-
|
| 1065 |
-
kwargs : dict
|
| 1066 |
-
These kwargs will be passed to autowrap if the `f2py` or `cython`
|
| 1067 |
-
backend is used and ignored if the `numpy` backend is used.
|
| 1068 |
-
|
| 1069 |
-
Notes
|
| 1070 |
-
=====
|
| 1071 |
-
|
| 1072 |
-
The default backend ('numpy') will create actual instances of
|
| 1073 |
-
``numpy.ufunc``. These support ndimensional broadcasting, and implicit type
|
| 1074 |
-
conversion. Use of the other backends will result in a "ufunc-like"
|
| 1075 |
-
function, which requires equal length 1-dimensional arrays for all
|
| 1076 |
-
arguments, and will not perform any type conversions.
|
| 1077 |
-
|
| 1078 |
-
References
|
| 1079 |
-
==========
|
| 1080 |
-
|
| 1081 |
-
.. [1] https://numpy.org/doc/stable/reference/ufuncs.html
|
| 1082 |
-
|
| 1083 |
-
Examples
|
| 1084 |
-
========
|
| 1085 |
-
|
| 1086 |
-
Basic usage:
|
| 1087 |
-
|
| 1088 |
-
>>> from sympy.utilities.autowrap import ufuncify
|
| 1089 |
-
>>> from sympy.abc import x, y
|
| 1090 |
-
>>> import numpy as np
|
| 1091 |
-
>>> f = ufuncify((x, y), y + x**2)
|
| 1092 |
-
>>> type(f)
|
| 1093 |
-
<class 'numpy.ufunc'>
|
| 1094 |
-
>>> f([1, 2, 3], 2)
|
| 1095 |
-
array([ 3., 6., 11.])
|
| 1096 |
-
>>> f(np.arange(5), 3)
|
| 1097 |
-
array([ 3., 4., 7., 12., 19.])
|
| 1098 |
-
|
| 1099 |
-
Using helper functions:
|
| 1100 |
-
|
| 1101 |
-
>>> from sympy import Function
|
| 1102 |
-
>>> helper_func = Function('helper_func') # Define symbolic function
|
| 1103 |
-
>>> expr = x**2 + y*helper_func(x) # Main expression using helper function
|
| 1104 |
-
>>> # Define helper_func(x) = x**3
|
| 1105 |
-
>>> f = ufuncify((x, y), expr, helpers=[('helper_func', x**3, [x])])
|
| 1106 |
-
>>> f([1, 2], [3, 4])
|
| 1107 |
-
array([ 4., 36.])
|
| 1108 |
-
|
| 1109 |
-
Type handling with different backends:
|
| 1110 |
-
|
| 1111 |
-
For the 'f2py' and 'cython' backends, inputs are required to be equal length
|
| 1112 |
-
1-dimensional arrays. The 'f2py' backend will perform type conversion, but
|
| 1113 |
-
the Cython backend will error if the inputs are not of the expected type.
|
| 1114 |
-
|
| 1115 |
-
>>> f_fortran = ufuncify((x, y), y + x**2, backend='f2py')
|
| 1116 |
-
>>> f_fortran(1, 2)
|
| 1117 |
-
array([ 3.])
|
| 1118 |
-
>>> f_fortran(np.array([1, 2, 3]), np.array([1.0, 2.0, 3.0]))
|
| 1119 |
-
array([ 2., 6., 12.])
|
| 1120 |
-
>>> f_cython = ufuncify((x, y), y + x**2, backend='Cython')
|
| 1121 |
-
>>> f_cython(1, 2) # doctest: +ELLIPSIS
|
| 1122 |
-
Traceback (most recent call last):
|
| 1123 |
-
...
|
| 1124 |
-
TypeError: Argument '_x' has incorrect type (expected numpy.ndarray, got int)
|
| 1125 |
-
>>> f_cython(np.array([1.0]), np.array([2.0]))
|
| 1126 |
-
array([ 3.])
|
| 1127 |
-
|
| 1128 |
-
"""
|
| 1129 |
-
|
| 1130 |
-
if isinstance(args, Symbol):
|
| 1131 |
-
args = (args,)
|
| 1132 |
-
else:
|
| 1133 |
-
args = tuple(args)
|
| 1134 |
-
|
| 1135 |
-
if language:
|
| 1136 |
-
_validate_backend_language(backend, language)
|
| 1137 |
-
else:
|
| 1138 |
-
language = _infer_language(backend)
|
| 1139 |
-
|
| 1140 |
-
helpers = helpers if helpers else ()
|
| 1141 |
-
flags = flags if flags else ()
|
| 1142 |
-
|
| 1143 |
-
if backend.upper() == 'NUMPY':
|
| 1144 |
-
# maxargs is set by numpy compile-time constant NPY_MAXARGS
|
| 1145 |
-
# If a future version of numpy modifies or removes this restriction
|
| 1146 |
-
# this variable should be changed or removed
|
| 1147 |
-
maxargs = 32
|
| 1148 |
-
helps = []
|
| 1149 |
-
for name, expr, args in helpers:
|
| 1150 |
-
helps.append(make_routine(name, expr, args))
|
| 1151 |
-
code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"), tempdir,
|
| 1152 |
-
flags, verbose)
|
| 1153 |
-
if not isinstance(expr, (list, tuple)):
|
| 1154 |
-
expr = [expr]
|
| 1155 |
-
if len(expr) == 0:
|
| 1156 |
-
raise ValueError('Expression iterable has zero length')
|
| 1157 |
-
if len(expr) + len(args) > maxargs:
|
| 1158 |
-
msg = ('Cannot create ufunc with more than {0} total arguments: '
|
| 1159 |
-
'got {1} in, {2} out')
|
| 1160 |
-
raise ValueError(msg.format(maxargs, len(args), len(expr)))
|
| 1161 |
-
routines = [make_routine('autofunc{}'.format(idx), exprx, args) for
|
| 1162 |
-
idx, exprx in enumerate(expr)]
|
| 1163 |
-
return code_wrapper.wrap_code(routines, helpers=helps)
|
| 1164 |
-
else:
|
| 1165 |
-
# Dummies are used for all added expressions to prevent name clashes
|
| 1166 |
-
# within the original expression.
|
| 1167 |
-
y = IndexedBase(Dummy('y'))
|
| 1168 |
-
m = Dummy('m', integer=True)
|
| 1169 |
-
i = Idx(Dummy('i', integer=True), m)
|
| 1170 |
-
f_dummy = Dummy('f')
|
| 1171 |
-
f = implemented_function('%s_%d' % (f_dummy.name, f_dummy.dummy_index), Lambda(args, expr))
|
| 1172 |
-
# For each of the args create an indexed version.
|
| 1173 |
-
indexed_args = [IndexedBase(Dummy(str(a))) for a in args]
|
| 1174 |
-
# Order the arguments (out, args, dim)
|
| 1175 |
-
args = [y] + indexed_args + [m]
|
| 1176 |
-
args_with_indices = [a[i] for a in indexed_args]
|
| 1177 |
-
return autowrap(Eq(y[i], f(*args_with_indices)), language, backend,
|
| 1178 |
-
tempdir, args, flags, verbose, helpers, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/codegen.py
DELETED
|
@@ -1,2237 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
module for generating C, C++, Fortran77, Fortran90, Julia, Rust
|
| 3 |
-
and Octave/Matlab routines that evaluate SymPy expressions.
|
| 4 |
-
This module is work in progress.
|
| 5 |
-
Only the milestones with a '+' character in the list below have been completed.
|
| 6 |
-
|
| 7 |
-
--- How is sympy.utilities.codegen different from sympy.printing.ccode? ---
|
| 8 |
-
|
| 9 |
-
We considered the idea to extend the printing routines for SymPy functions in
|
| 10 |
-
such a way that it prints complete compilable code, but this leads to a few
|
| 11 |
-
unsurmountable issues that can only be tackled with dedicated code generator:
|
| 12 |
-
|
| 13 |
-
- For C, one needs both a code and a header file, while the printing routines
|
| 14 |
-
generate just one string. This code generator can be extended to support
|
| 15 |
-
.pyf files for f2py.
|
| 16 |
-
|
| 17 |
-
- SymPy functions are not concerned with programming-technical issues, such
|
| 18 |
-
as input, output and input-output arguments. Other examples are contiguous
|
| 19 |
-
or non-contiguous arrays, including headers of other libraries such as gsl
|
| 20 |
-
or others.
|
| 21 |
-
|
| 22 |
-
- It is highly interesting to evaluate several SymPy functions in one C
|
| 23 |
-
routine, eventually sharing common intermediate results with the help
|
| 24 |
-
of the cse routine. This is more than just printing.
|
| 25 |
-
|
| 26 |
-
- From the programming perspective, expressions with constants should be
|
| 27 |
-
evaluated in the code generator as much as possible. This is different
|
| 28 |
-
for printing.
|
| 29 |
-
|
| 30 |
-
--- Basic assumptions ---
|
| 31 |
-
|
| 32 |
-
* A generic Routine data structure describes the routine that must be
|
| 33 |
-
translated into C/Fortran/... code. This data structure covers all
|
| 34 |
-
features present in one or more of the supported languages.
|
| 35 |
-
|
| 36 |
-
* Descendants from the CodeGen class transform multiple Routine instances
|
| 37 |
-
into compilable code. Each derived class translates into a specific
|
| 38 |
-
language.
|
| 39 |
-
|
| 40 |
-
* In many cases, one wants a simple workflow. The friendly functions in the
|
| 41 |
-
last part are a simple api on top of the Routine/CodeGen stuff. They are
|
| 42 |
-
easier to use, but are less powerful.
|
| 43 |
-
|
| 44 |
-
--- Milestones ---
|
| 45 |
-
|
| 46 |
-
+ First working version with scalar input arguments, generating C code,
|
| 47 |
-
tests
|
| 48 |
-
+ Friendly functions that are easier to use than the rigorous
|
| 49 |
-
Routine/CodeGen workflow.
|
| 50 |
-
+ Integer and Real numbers as input and output
|
| 51 |
-
+ Output arguments
|
| 52 |
-
+ InputOutput arguments
|
| 53 |
-
+ Sort input/output arguments properly
|
| 54 |
-
+ Contiguous array arguments (numpy matrices)
|
| 55 |
-
+ Also generate .pyf code for f2py (in autowrap module)
|
| 56 |
-
+ Isolate constants and evaluate them beforehand in double precision
|
| 57 |
-
+ Fortran 90
|
| 58 |
-
+ Octave/Matlab
|
| 59 |
-
|
| 60 |
-
- Common Subexpression Elimination
|
| 61 |
-
- User defined comments in the generated code
|
| 62 |
-
- Optional extra include lines for libraries/objects that can eval special
|
| 63 |
-
functions
|
| 64 |
-
- Test other C compilers and libraries: gcc, tcc, libtcc, gcc+gsl, ...
|
| 65 |
-
- Contiguous array arguments (SymPy matrices)
|
| 66 |
-
- Non-contiguous array arguments (SymPy matrices)
|
| 67 |
-
- ccode must raise an error when it encounters something that cannot be
|
| 68 |
-
translated into c. ccode(integrate(sin(x)/x, x)) does not make sense.
|
| 69 |
-
- Complex numbers as input and output
|
| 70 |
-
- A default complex datatype
|
| 71 |
-
- Include extra information in the header: date, user, hostname, sha1
|
| 72 |
-
hash, ...
|
| 73 |
-
- Fortran 77
|
| 74 |
-
- C++
|
| 75 |
-
- Python
|
| 76 |
-
- Julia
|
| 77 |
-
- Rust
|
| 78 |
-
- ...
|
| 79 |
-
|
| 80 |
-
"""
|
| 81 |
-
|
| 82 |
-
import os
|
| 83 |
-
import textwrap
|
| 84 |
-
from io import StringIO
|
| 85 |
-
|
| 86 |
-
from sympy import __version__ as sympy_version
|
| 87 |
-
from sympy.core import Symbol, S, Tuple, Equality, Function, Basic
|
| 88 |
-
from sympy.printing.c import c_code_printers
|
| 89 |
-
from sympy.printing.codeprinter import AssignmentError
|
| 90 |
-
from sympy.printing.fortran import FCodePrinter
|
| 91 |
-
from sympy.printing.julia import JuliaCodePrinter
|
| 92 |
-
from sympy.printing.octave import OctaveCodePrinter
|
| 93 |
-
from sympy.printing.rust import RustCodePrinter
|
| 94 |
-
from sympy.tensor import Idx, Indexed, IndexedBase
|
| 95 |
-
from sympy.matrices import (MatrixSymbol, ImmutableMatrix, MatrixBase,
|
| 96 |
-
MatrixExpr, MatrixSlice)
|
| 97 |
-
from sympy.utilities.iterables import is_sequence
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
__all__ = [
|
| 101 |
-
# description of routines
|
| 102 |
-
"Routine", "DataType", "default_datatypes", "get_default_datatype",
|
| 103 |
-
"Argument", "InputArgument", "OutputArgument", "Result",
|
| 104 |
-
# routines -> code
|
| 105 |
-
"CodeGen", "CCodeGen", "FCodeGen", "JuliaCodeGen", "OctaveCodeGen",
|
| 106 |
-
"RustCodeGen",
|
| 107 |
-
# friendly functions
|
| 108 |
-
"codegen", "make_routine",
|
| 109 |
-
]
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
#
|
| 113 |
-
# Description of routines
|
| 114 |
-
#
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
class Routine:
|
| 118 |
-
"""Generic description of evaluation routine for set of expressions.
|
| 119 |
-
|
| 120 |
-
A CodeGen class can translate instances of this class into code in a
|
| 121 |
-
particular language. The routine specification covers all the features
|
| 122 |
-
present in these languages. The CodeGen part must raise an exception
|
| 123 |
-
when certain features are not present in the target language. For
|
| 124 |
-
example, multiple return values are possible in Python, but not in C or
|
| 125 |
-
Fortran. Another example: Fortran and Python support complex numbers,
|
| 126 |
-
while C does not.
|
| 127 |
-
|
| 128 |
-
"""
|
| 129 |
-
|
| 130 |
-
def __init__(self, name, arguments, results, local_vars, global_vars):
|
| 131 |
-
"""Initialize a Routine instance.
|
| 132 |
-
|
| 133 |
-
Parameters
|
| 134 |
-
==========
|
| 135 |
-
|
| 136 |
-
name : string
|
| 137 |
-
Name of the routine.
|
| 138 |
-
|
| 139 |
-
arguments : list of Arguments
|
| 140 |
-
These are things that appear in arguments of a routine, often
|
| 141 |
-
appearing on the right-hand side of a function call. These are
|
| 142 |
-
commonly InputArguments but in some languages, they can also be
|
| 143 |
-
OutputArguments or InOutArguments (e.g., pass-by-reference in C
|
| 144 |
-
code).
|
| 145 |
-
|
| 146 |
-
results : list of Results
|
| 147 |
-
These are the return values of the routine, often appearing on
|
| 148 |
-
the left-hand side of a function call. The difference between
|
| 149 |
-
Results and OutputArguments and when you should use each is
|
| 150 |
-
language-specific.
|
| 151 |
-
|
| 152 |
-
local_vars : list of Results
|
| 153 |
-
These are variables that will be defined at the beginning of the
|
| 154 |
-
function.
|
| 155 |
-
|
| 156 |
-
global_vars : list of Symbols
|
| 157 |
-
Variables which will not be passed into the function.
|
| 158 |
-
|
| 159 |
-
"""
|
| 160 |
-
|
| 161 |
-
# extract all input symbols and all symbols appearing in an expression
|
| 162 |
-
input_symbols = set()
|
| 163 |
-
symbols = set()
|
| 164 |
-
for arg in arguments:
|
| 165 |
-
if isinstance(arg, OutputArgument):
|
| 166 |
-
symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed))
|
| 167 |
-
elif isinstance(arg, InputArgument):
|
| 168 |
-
input_symbols.add(arg.name)
|
| 169 |
-
elif isinstance(arg, InOutArgument):
|
| 170 |
-
input_symbols.add(arg.name)
|
| 171 |
-
symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed))
|
| 172 |
-
else:
|
| 173 |
-
raise ValueError("Unknown Routine argument: %s" % arg)
|
| 174 |
-
|
| 175 |
-
for r in results:
|
| 176 |
-
if not isinstance(r, Result):
|
| 177 |
-
raise ValueError("Unknown Routine result: %s" % r)
|
| 178 |
-
symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed))
|
| 179 |
-
|
| 180 |
-
local_symbols = set()
|
| 181 |
-
for r in local_vars:
|
| 182 |
-
if isinstance(r, Result):
|
| 183 |
-
symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed))
|
| 184 |
-
local_symbols.add(r.name)
|
| 185 |
-
else:
|
| 186 |
-
local_symbols.add(r)
|
| 187 |
-
|
| 188 |
-
symbols = {s.label if isinstance(s, Idx) else s for s in symbols}
|
| 189 |
-
|
| 190 |
-
# Check that all symbols in the expressions are covered by
|
| 191 |
-
# InputArguments/InOutArguments---subset because user could
|
| 192 |
-
# specify additional (unused) InputArguments or local_vars.
|
| 193 |
-
notcovered = symbols.difference(
|
| 194 |
-
input_symbols.union(local_symbols).union(global_vars))
|
| 195 |
-
if notcovered != set():
|
| 196 |
-
raise ValueError("Symbols needed for output are not in input " +
|
| 197 |
-
", ".join([str(x) for x in notcovered]))
|
| 198 |
-
|
| 199 |
-
self.name = name
|
| 200 |
-
self.arguments = arguments
|
| 201 |
-
self.results = results
|
| 202 |
-
self.local_vars = local_vars
|
| 203 |
-
self.global_vars = global_vars
|
| 204 |
-
|
| 205 |
-
def __str__(self):
|
| 206 |
-
return self.__class__.__name__ + "({name!r}, {arguments}, {results}, {local_vars}, {global_vars})".format(**self.__dict__)
|
| 207 |
-
|
| 208 |
-
__repr__ = __str__
|
| 209 |
-
|
| 210 |
-
@property
|
| 211 |
-
def variables(self):
|
| 212 |
-
"""Returns a set of all variables possibly used in the routine.
|
| 213 |
-
|
| 214 |
-
For routines with unnamed return values, the dummies that may or
|
| 215 |
-
may not be used will be included in the set.
|
| 216 |
-
|
| 217 |
-
"""
|
| 218 |
-
v = set(self.local_vars)
|
| 219 |
-
v.update(arg.name for arg in self.arguments)
|
| 220 |
-
v.update(res.result_var for res in self.results)
|
| 221 |
-
return v
|
| 222 |
-
|
| 223 |
-
@property
|
| 224 |
-
def result_variables(self):
|
| 225 |
-
"""Returns a list of OutputArgument, InOutArgument and Result.
|
| 226 |
-
|
| 227 |
-
If return values are present, they are at the end of the list.
|
| 228 |
-
"""
|
| 229 |
-
args = [arg for arg in self.arguments if isinstance(
|
| 230 |
-
arg, (OutputArgument, InOutArgument))]
|
| 231 |
-
args.extend(self.results)
|
| 232 |
-
return args
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
class DataType:
|
| 236 |
-
"""Holds strings for a certain datatype in different languages."""
|
| 237 |
-
def __init__(self, cname, fname, pyname, jlname, octname, rsname):
|
| 238 |
-
self.cname = cname
|
| 239 |
-
self.fname = fname
|
| 240 |
-
self.pyname = pyname
|
| 241 |
-
self.jlname = jlname
|
| 242 |
-
self.octname = octname
|
| 243 |
-
self.rsname = rsname
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
default_datatypes = {
|
| 247 |
-
"int": DataType("int", "INTEGER*4", "int", "", "", "i32"),
|
| 248 |
-
"float": DataType("double", "REAL*8", "float", "", "", "f64"),
|
| 249 |
-
"complex": DataType("double", "COMPLEX*16", "complex", "", "", "float") #FIXME:
|
| 250 |
-
# complex is only supported in fortran, python, julia, and octave.
|
| 251 |
-
# So to not break c or rust code generation, we stick with double or
|
| 252 |
-
# float, respectively (but actually should raise an exception for
|
| 253 |
-
# explicitly complex variables (x.is_complex==True))
|
| 254 |
-
}
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
COMPLEX_ALLOWED = False
|
| 258 |
-
def get_default_datatype(expr, complex_allowed=None):
|
| 259 |
-
"""Derives an appropriate datatype based on the expression."""
|
| 260 |
-
if complex_allowed is None:
|
| 261 |
-
complex_allowed = COMPLEX_ALLOWED
|
| 262 |
-
if complex_allowed:
|
| 263 |
-
final_dtype = "complex"
|
| 264 |
-
else:
|
| 265 |
-
final_dtype = "float"
|
| 266 |
-
if expr.is_integer:
|
| 267 |
-
return default_datatypes["int"]
|
| 268 |
-
elif expr.is_real:
|
| 269 |
-
return default_datatypes["float"]
|
| 270 |
-
elif isinstance(expr, MatrixBase):
|
| 271 |
-
#check all entries
|
| 272 |
-
dt = "int"
|
| 273 |
-
for element in expr:
|
| 274 |
-
if dt == "int" and not element.is_integer:
|
| 275 |
-
dt = "float"
|
| 276 |
-
if dt == "float" and not element.is_real:
|
| 277 |
-
return default_datatypes[final_dtype]
|
| 278 |
-
return default_datatypes[dt]
|
| 279 |
-
else:
|
| 280 |
-
return default_datatypes[final_dtype]
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
class Variable:
|
| 284 |
-
"""Represents a typed variable."""
|
| 285 |
-
|
| 286 |
-
def __init__(self, name, datatype=None, dimensions=None, precision=None):
|
| 287 |
-
"""Return a new variable.
|
| 288 |
-
|
| 289 |
-
Parameters
|
| 290 |
-
==========
|
| 291 |
-
|
| 292 |
-
name : Symbol or MatrixSymbol
|
| 293 |
-
|
| 294 |
-
datatype : optional
|
| 295 |
-
When not given, the data type will be guessed based on the
|
| 296 |
-
assumptions on the symbol argument.
|
| 297 |
-
|
| 298 |
-
dimensions : sequence containing tuples, optional
|
| 299 |
-
If present, the argument is interpreted as an array, where this
|
| 300 |
-
sequence of tuples specifies (lower, upper) bounds for each
|
| 301 |
-
index of the array.
|
| 302 |
-
|
| 303 |
-
precision : int, optional
|
| 304 |
-
Controls the precision of floating point constants.
|
| 305 |
-
|
| 306 |
-
"""
|
| 307 |
-
if not isinstance(name, (Symbol, MatrixSymbol)):
|
| 308 |
-
raise TypeError("The first argument must be a SymPy symbol.")
|
| 309 |
-
if datatype is None:
|
| 310 |
-
datatype = get_default_datatype(name)
|
| 311 |
-
elif not isinstance(datatype, DataType):
|
| 312 |
-
raise TypeError("The (optional) `datatype' argument must be an "
|
| 313 |
-
"instance of the DataType class.")
|
| 314 |
-
if dimensions and not isinstance(dimensions, (tuple, list)):
|
| 315 |
-
raise TypeError(
|
| 316 |
-
"The dimensions argument must be a sequence of tuples")
|
| 317 |
-
|
| 318 |
-
self._name = name
|
| 319 |
-
self._datatype = {
|
| 320 |
-
'C': datatype.cname,
|
| 321 |
-
'FORTRAN': datatype.fname,
|
| 322 |
-
'JULIA': datatype.jlname,
|
| 323 |
-
'OCTAVE': datatype.octname,
|
| 324 |
-
'PYTHON': datatype.pyname,
|
| 325 |
-
'RUST': datatype.rsname,
|
| 326 |
-
}
|
| 327 |
-
self.dimensions = dimensions
|
| 328 |
-
self.precision = precision
|
| 329 |
-
|
| 330 |
-
def __str__(self):
|
| 331 |
-
return "%s(%r)" % (self.__class__.__name__, self.name)
|
| 332 |
-
|
| 333 |
-
__repr__ = __str__
|
| 334 |
-
|
| 335 |
-
@property
|
| 336 |
-
def name(self):
|
| 337 |
-
return self._name
|
| 338 |
-
|
| 339 |
-
def get_datatype(self, language):
|
| 340 |
-
"""Returns the datatype string for the requested language.
|
| 341 |
-
|
| 342 |
-
Examples
|
| 343 |
-
========
|
| 344 |
-
|
| 345 |
-
>>> from sympy import Symbol
|
| 346 |
-
>>> from sympy.utilities.codegen import Variable
|
| 347 |
-
>>> x = Variable(Symbol('x'))
|
| 348 |
-
>>> x.get_datatype('c')
|
| 349 |
-
'double'
|
| 350 |
-
>>> x.get_datatype('fortran')
|
| 351 |
-
'REAL*8'
|
| 352 |
-
|
| 353 |
-
"""
|
| 354 |
-
try:
|
| 355 |
-
return self._datatype[language.upper()]
|
| 356 |
-
except KeyError:
|
| 357 |
-
raise CodeGenError("Has datatypes for languages: %s" %
|
| 358 |
-
", ".join(self._datatype))
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
class Argument(Variable):
|
| 362 |
-
"""An abstract Argument data structure: a name and a data type.
|
| 363 |
-
|
| 364 |
-
This structure is refined in the descendants below.
|
| 365 |
-
|
| 366 |
-
"""
|
| 367 |
-
pass
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
class InputArgument(Argument):
|
| 371 |
-
pass
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
class ResultBase:
|
| 375 |
-
"""Base class for all "outgoing" information from a routine.
|
| 376 |
-
|
| 377 |
-
Objects of this class stores a SymPy expression, and a SymPy object
|
| 378 |
-
representing a result variable that will be used in the generated code
|
| 379 |
-
only if necessary.
|
| 380 |
-
|
| 381 |
-
"""
|
| 382 |
-
def __init__(self, expr, result_var):
|
| 383 |
-
self.expr = expr
|
| 384 |
-
self.result_var = result_var
|
| 385 |
-
|
| 386 |
-
def __str__(self):
|
| 387 |
-
return "%s(%r, %r)" % (self.__class__.__name__, self.expr,
|
| 388 |
-
self.result_var)
|
| 389 |
-
|
| 390 |
-
__repr__ = __str__
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
class OutputArgument(Argument, ResultBase):
|
| 394 |
-
"""OutputArgument are always initialized in the routine."""
|
| 395 |
-
|
| 396 |
-
def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None):
|
| 397 |
-
"""Return a new variable.
|
| 398 |
-
|
| 399 |
-
Parameters
|
| 400 |
-
==========
|
| 401 |
-
|
| 402 |
-
name : Symbol, MatrixSymbol
|
| 403 |
-
The name of this variable. When used for code generation, this
|
| 404 |
-
might appear, for example, in the prototype of function in the
|
| 405 |
-
argument list.
|
| 406 |
-
|
| 407 |
-
result_var : Symbol, Indexed
|
| 408 |
-
Something that can be used to assign a value to this variable.
|
| 409 |
-
Typically the same as `name` but for Indexed this should be e.g.,
|
| 410 |
-
"y[i]" whereas `name` should be the Symbol "y".
|
| 411 |
-
|
| 412 |
-
expr : object
|
| 413 |
-
The expression that should be output, typically a SymPy
|
| 414 |
-
expression.
|
| 415 |
-
|
| 416 |
-
datatype : optional
|
| 417 |
-
When not given, the data type will be guessed based on the
|
| 418 |
-
assumptions on the symbol argument.
|
| 419 |
-
|
| 420 |
-
dimensions : sequence containing tuples, optional
|
| 421 |
-
If present, the argument is interpreted as an array, where this
|
| 422 |
-
sequence of tuples specifies (lower, upper) bounds for each
|
| 423 |
-
index of the array.
|
| 424 |
-
|
| 425 |
-
precision : int, optional
|
| 426 |
-
Controls the precision of floating point constants.
|
| 427 |
-
|
| 428 |
-
"""
|
| 429 |
-
|
| 430 |
-
Argument.__init__(self, name, datatype, dimensions, precision)
|
| 431 |
-
ResultBase.__init__(self, expr, result_var)
|
| 432 |
-
|
| 433 |
-
def __str__(self):
|
| 434 |
-
return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.result_var, self.expr)
|
| 435 |
-
|
| 436 |
-
__repr__ = __str__
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
class InOutArgument(Argument, ResultBase):
|
| 440 |
-
"""InOutArgument are never initialized in the routine."""
|
| 441 |
-
|
| 442 |
-
def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None):
|
| 443 |
-
if not datatype:
|
| 444 |
-
datatype = get_default_datatype(expr)
|
| 445 |
-
Argument.__init__(self, name, datatype, dimensions, precision)
|
| 446 |
-
ResultBase.__init__(self, expr, result_var)
|
| 447 |
-
__init__.__doc__ = OutputArgument.__init__.__doc__
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
def __str__(self):
|
| 451 |
-
return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.expr,
|
| 452 |
-
self.result_var)
|
| 453 |
-
|
| 454 |
-
__repr__ = __str__
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
class Result(Variable, ResultBase):
|
| 458 |
-
"""An expression for a return value.
|
| 459 |
-
|
| 460 |
-
The name result is used to avoid conflicts with the reserved word
|
| 461 |
-
"return" in the Python language. It is also shorter than ReturnValue.
|
| 462 |
-
|
| 463 |
-
These may or may not need a name in the destination (e.g., "return(x*y)"
|
| 464 |
-
might return a value without ever naming it).
|
| 465 |
-
|
| 466 |
-
"""
|
| 467 |
-
|
| 468 |
-
def __init__(self, expr, name=None, result_var=None, datatype=None,
|
| 469 |
-
dimensions=None, precision=None):
|
| 470 |
-
"""Initialize a return value.
|
| 471 |
-
|
| 472 |
-
Parameters
|
| 473 |
-
==========
|
| 474 |
-
|
| 475 |
-
expr : SymPy expression
|
| 476 |
-
|
| 477 |
-
name : Symbol, MatrixSymbol, optional
|
| 478 |
-
The name of this return variable. When used for code generation,
|
| 479 |
-
this might appear, for example, in the prototype of function in a
|
| 480 |
-
list of return values. A dummy name is generated if omitted.
|
| 481 |
-
|
| 482 |
-
result_var : Symbol, Indexed, optional
|
| 483 |
-
Something that can be used to assign a value to this variable.
|
| 484 |
-
Typically the same as `name` but for Indexed this should be e.g.,
|
| 485 |
-
"y[i]" whereas `name` should be the Symbol "y". Defaults to
|
| 486 |
-
`name` if omitted.
|
| 487 |
-
|
| 488 |
-
datatype : optional
|
| 489 |
-
When not given, the data type will be guessed based on the
|
| 490 |
-
assumptions on the expr argument.
|
| 491 |
-
|
| 492 |
-
dimensions : sequence containing tuples, optional
|
| 493 |
-
If present, this variable is interpreted as an array,
|
| 494 |
-
where this sequence of tuples specifies (lower, upper)
|
| 495 |
-
bounds for each index of the array.
|
| 496 |
-
|
| 497 |
-
precision : int, optional
|
| 498 |
-
Controls the precision of floating point constants.
|
| 499 |
-
|
| 500 |
-
"""
|
| 501 |
-
# Basic because it is the base class for all types of expressions
|
| 502 |
-
if not isinstance(expr, (Basic, MatrixBase)):
|
| 503 |
-
raise TypeError("The first argument must be a SymPy expression.")
|
| 504 |
-
|
| 505 |
-
if name is None:
|
| 506 |
-
name = 'result_%d' % abs(hash(expr))
|
| 507 |
-
|
| 508 |
-
if datatype is None:
|
| 509 |
-
#try to infer data type from the expression
|
| 510 |
-
datatype = get_default_datatype(expr)
|
| 511 |
-
|
| 512 |
-
if isinstance(name, str):
|
| 513 |
-
if isinstance(expr, (MatrixBase, MatrixExpr)):
|
| 514 |
-
name = MatrixSymbol(name, *expr.shape)
|
| 515 |
-
else:
|
| 516 |
-
name = Symbol(name)
|
| 517 |
-
|
| 518 |
-
if result_var is None:
|
| 519 |
-
result_var = name
|
| 520 |
-
|
| 521 |
-
Variable.__init__(self, name, datatype=datatype,
|
| 522 |
-
dimensions=dimensions, precision=precision)
|
| 523 |
-
ResultBase.__init__(self, expr, result_var)
|
| 524 |
-
|
| 525 |
-
def __str__(self):
|
| 526 |
-
return "%s(%r, %r, %r)" % (self.__class__.__name__, self.expr, self.name,
|
| 527 |
-
self.result_var)
|
| 528 |
-
|
| 529 |
-
__repr__ = __str__
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
#
|
| 533 |
-
# Transformation of routine objects into code
|
| 534 |
-
#
|
| 535 |
-
|
| 536 |
-
class CodeGen:
|
| 537 |
-
"""Abstract class for the code generators."""
|
| 538 |
-
|
| 539 |
-
printer = None # will be set to an instance of a CodePrinter subclass
|
| 540 |
-
|
| 541 |
-
def _indent_code(self, codelines):
|
| 542 |
-
return self.printer.indent_code(codelines)
|
| 543 |
-
|
| 544 |
-
def _printer_method_with_settings(self, method, settings=None, *args, **kwargs):
|
| 545 |
-
settings = settings or {}
|
| 546 |
-
ori = {k: self.printer._settings[k] for k in settings}
|
| 547 |
-
for k, v in settings.items():
|
| 548 |
-
self.printer._settings[k] = v
|
| 549 |
-
result = getattr(self.printer, method)(*args, **kwargs)
|
| 550 |
-
for k, v in ori.items():
|
| 551 |
-
self.printer._settings[k] = v
|
| 552 |
-
return result
|
| 553 |
-
|
| 554 |
-
def _get_symbol(self, s):
|
| 555 |
-
"""Returns the symbol as fcode prints it."""
|
| 556 |
-
if self.printer._settings['human']:
|
| 557 |
-
expr_str = self.printer.doprint(s)
|
| 558 |
-
else:
|
| 559 |
-
constants, not_supported, expr_str = self.printer.doprint(s)
|
| 560 |
-
if constants or not_supported:
|
| 561 |
-
raise ValueError("Failed to print %s" % str(s))
|
| 562 |
-
return expr_str.strip()
|
| 563 |
-
|
| 564 |
-
def __init__(self, project="project", cse=False):
|
| 565 |
-
"""Initialize a code generator.
|
| 566 |
-
|
| 567 |
-
Derived classes will offer more options that affect the generated
|
| 568 |
-
code.
|
| 569 |
-
|
| 570 |
-
"""
|
| 571 |
-
self.project = project
|
| 572 |
-
self.cse = cse
|
| 573 |
-
|
| 574 |
-
def routine(self, name, expr, argument_sequence=None, global_vars=None):
|
| 575 |
-
"""Creates an Routine object that is appropriate for this language.
|
| 576 |
-
|
| 577 |
-
This implementation is appropriate for at least C/Fortran. Subclasses
|
| 578 |
-
can override this if necessary.
|
| 579 |
-
|
| 580 |
-
Here, we assume at most one return value (the l-value) which must be
|
| 581 |
-
scalar. Additional outputs are OutputArguments (e.g., pointers on
|
| 582 |
-
right-hand-side or pass-by-reference). Matrices are always returned
|
| 583 |
-
via OutputArguments. If ``argument_sequence`` is None, arguments will
|
| 584 |
-
be ordered alphabetically, but with all InputArguments first, and then
|
| 585 |
-
OutputArgument and InOutArguments.
|
| 586 |
-
|
| 587 |
-
"""
|
| 588 |
-
|
| 589 |
-
if self.cse:
|
| 590 |
-
from sympy.simplify.cse_main import cse
|
| 591 |
-
|
| 592 |
-
if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
|
| 593 |
-
if not expr:
|
| 594 |
-
raise ValueError("No expression given")
|
| 595 |
-
for e in expr:
|
| 596 |
-
if not e.is_Equality:
|
| 597 |
-
raise CodeGenError("Lists of expressions must all be Equalities. {} is not.".format(e))
|
| 598 |
-
|
| 599 |
-
# create a list of right hand sides and simplify them
|
| 600 |
-
rhs = [e.rhs for e in expr]
|
| 601 |
-
common, simplified = cse(rhs)
|
| 602 |
-
|
| 603 |
-
# pack the simplified expressions back up with their left hand sides
|
| 604 |
-
expr = [Equality(e.lhs, rhs) for e, rhs in zip(expr, simplified)]
|
| 605 |
-
else:
|
| 606 |
-
if isinstance(expr, Equality):
|
| 607 |
-
common, simplified = cse(expr.rhs) #, ignore=in_out_args)
|
| 608 |
-
expr = Equality(expr.lhs, simplified[0])
|
| 609 |
-
else:
|
| 610 |
-
common, simplified = cse(expr)
|
| 611 |
-
expr = simplified
|
| 612 |
-
|
| 613 |
-
local_vars = [Result(b,a) for a,b in common]
|
| 614 |
-
local_symbols = {a for a,_ in common}
|
| 615 |
-
local_expressions = Tuple(*[b for _,b in common])
|
| 616 |
-
else:
|
| 617 |
-
local_expressions = Tuple()
|
| 618 |
-
|
| 619 |
-
if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
|
| 620 |
-
if not expr:
|
| 621 |
-
raise ValueError("No expression given")
|
| 622 |
-
expressions = Tuple(*expr)
|
| 623 |
-
else:
|
| 624 |
-
expressions = Tuple(expr)
|
| 625 |
-
|
| 626 |
-
if self.cse:
|
| 627 |
-
if {i.label for i in expressions.atoms(Idx)} != set():
|
| 628 |
-
raise CodeGenError("CSE and Indexed expressions do not play well together yet")
|
| 629 |
-
else:
|
| 630 |
-
# local variables for indexed expressions
|
| 631 |
-
local_vars = {i.label for i in expressions.atoms(Idx)}
|
| 632 |
-
local_symbols = local_vars
|
| 633 |
-
|
| 634 |
-
# global variables
|
| 635 |
-
global_vars = set() if global_vars is None else set(global_vars)
|
| 636 |
-
|
| 637 |
-
# symbols that should be arguments
|
| 638 |
-
symbols = (expressions.free_symbols | local_expressions.free_symbols) - local_symbols - global_vars
|
| 639 |
-
new_symbols = set()
|
| 640 |
-
new_symbols.update(symbols)
|
| 641 |
-
|
| 642 |
-
for symbol in symbols:
|
| 643 |
-
if isinstance(symbol, Idx):
|
| 644 |
-
new_symbols.remove(symbol)
|
| 645 |
-
new_symbols.update(symbol.args[1].free_symbols)
|
| 646 |
-
if isinstance(symbol, Indexed):
|
| 647 |
-
new_symbols.remove(symbol)
|
| 648 |
-
symbols = new_symbols
|
| 649 |
-
|
| 650 |
-
# Decide whether to use output argument or return value
|
| 651 |
-
return_val = []
|
| 652 |
-
output_args = []
|
| 653 |
-
for expr in expressions:
|
| 654 |
-
if isinstance(expr, Equality):
|
| 655 |
-
out_arg = expr.lhs
|
| 656 |
-
expr = expr.rhs
|
| 657 |
-
if isinstance(out_arg, Indexed):
|
| 658 |
-
dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape])
|
| 659 |
-
symbol = out_arg.base.label
|
| 660 |
-
elif isinstance(out_arg, Symbol):
|
| 661 |
-
dims = []
|
| 662 |
-
symbol = out_arg
|
| 663 |
-
elif isinstance(out_arg, MatrixSymbol):
|
| 664 |
-
dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape])
|
| 665 |
-
symbol = out_arg
|
| 666 |
-
else:
|
| 667 |
-
raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
|
| 668 |
-
"can define output arguments.")
|
| 669 |
-
|
| 670 |
-
if expr.has(symbol):
|
| 671 |
-
output_args.append(
|
| 672 |
-
InOutArgument(symbol, out_arg, expr, dimensions=dims))
|
| 673 |
-
else:
|
| 674 |
-
output_args.append(
|
| 675 |
-
OutputArgument(symbol, out_arg, expr, dimensions=dims))
|
| 676 |
-
|
| 677 |
-
# remove duplicate arguments when they are not local variables
|
| 678 |
-
if symbol not in local_vars:
|
| 679 |
-
# avoid duplicate arguments
|
| 680 |
-
symbols.remove(symbol)
|
| 681 |
-
elif isinstance(expr, (ImmutableMatrix, MatrixSlice)):
|
| 682 |
-
# Create a "dummy" MatrixSymbol to use as the Output arg
|
| 683 |
-
out_arg = MatrixSymbol('out_%s' % abs(hash(expr)), *expr.shape)
|
| 684 |
-
dims = tuple([(S.Zero, dim - 1) for dim in out_arg.shape])
|
| 685 |
-
output_args.append(
|
| 686 |
-
OutputArgument(out_arg, out_arg, expr, dimensions=dims))
|
| 687 |
-
else:
|
| 688 |
-
return_val.append(Result(expr))
|
| 689 |
-
|
| 690 |
-
arg_list = []
|
| 691 |
-
|
| 692 |
-
# setup input argument list
|
| 693 |
-
|
| 694 |
-
# helper to get dimensions for data for array-like args
|
| 695 |
-
def dimensions(s):
|
| 696 |
-
return [(S.Zero, dim - 1) for dim in s.shape]
|
| 697 |
-
|
| 698 |
-
array_symbols = {}
|
| 699 |
-
for array in expressions.atoms(Indexed) | local_expressions.atoms(Indexed):
|
| 700 |
-
array_symbols[array.base.label] = array
|
| 701 |
-
for array in expressions.atoms(MatrixSymbol) | local_expressions.atoms(MatrixSymbol):
|
| 702 |
-
array_symbols[array] = array
|
| 703 |
-
|
| 704 |
-
for symbol in sorted(symbols, key=str):
|
| 705 |
-
if symbol in array_symbols:
|
| 706 |
-
array = array_symbols[symbol]
|
| 707 |
-
metadata = {'dimensions': dimensions(array)}
|
| 708 |
-
else:
|
| 709 |
-
metadata = {}
|
| 710 |
-
|
| 711 |
-
arg_list.append(InputArgument(symbol, **metadata))
|
| 712 |
-
|
| 713 |
-
output_args.sort(key=lambda x: str(x.name))
|
| 714 |
-
arg_list.extend(output_args)
|
| 715 |
-
|
| 716 |
-
if argument_sequence is not None:
|
| 717 |
-
# if the user has supplied IndexedBase instances, we'll accept that
|
| 718 |
-
new_sequence = []
|
| 719 |
-
for arg in argument_sequence:
|
| 720 |
-
if isinstance(arg, IndexedBase):
|
| 721 |
-
new_sequence.append(arg.label)
|
| 722 |
-
else:
|
| 723 |
-
new_sequence.append(arg)
|
| 724 |
-
argument_sequence = new_sequence
|
| 725 |
-
|
| 726 |
-
missing = [x for x in arg_list if x.name not in argument_sequence]
|
| 727 |
-
if missing:
|
| 728 |
-
msg = "Argument list didn't specify: {0} "
|
| 729 |
-
msg = msg.format(", ".join([str(m.name) for m in missing]))
|
| 730 |
-
raise CodeGenArgumentListError(msg, missing)
|
| 731 |
-
|
| 732 |
-
# create redundant arguments to produce the requested sequence
|
| 733 |
-
name_arg_dict = {x.name: x for x in arg_list}
|
| 734 |
-
new_args = []
|
| 735 |
-
for symbol in argument_sequence:
|
| 736 |
-
try:
|
| 737 |
-
new_args.append(name_arg_dict[symbol])
|
| 738 |
-
except KeyError:
|
| 739 |
-
if isinstance(symbol, (IndexedBase, MatrixSymbol)):
|
| 740 |
-
metadata = {'dimensions': dimensions(symbol)}
|
| 741 |
-
else:
|
| 742 |
-
metadata = {}
|
| 743 |
-
new_args.append(InputArgument(symbol, **metadata))
|
| 744 |
-
arg_list = new_args
|
| 745 |
-
|
| 746 |
-
return Routine(name, arg_list, return_val, local_vars, global_vars)
|
| 747 |
-
|
| 748 |
-
def write(self, routines, prefix, to_files=False, header=True, empty=True):
|
| 749 |
-
"""Writes all the source code files for the given routines.
|
| 750 |
-
|
| 751 |
-
The generated source is returned as a list of (filename, contents)
|
| 752 |
-
tuples, or is written to files (see below). Each filename consists
|
| 753 |
-
of the given prefix, appended with an appropriate extension.
|
| 754 |
-
|
| 755 |
-
Parameters
|
| 756 |
-
==========
|
| 757 |
-
|
| 758 |
-
routines : list
|
| 759 |
-
A list of Routine instances to be written
|
| 760 |
-
|
| 761 |
-
prefix : string
|
| 762 |
-
The prefix for the output files
|
| 763 |
-
|
| 764 |
-
to_files : bool, optional
|
| 765 |
-
When True, the output is written to files. Otherwise, a list
|
| 766 |
-
of (filename, contents) tuples is returned. [default: False]
|
| 767 |
-
|
| 768 |
-
header : bool, optional
|
| 769 |
-
When True, a header comment is included on top of each source
|
| 770 |
-
file. [default: True]
|
| 771 |
-
|
| 772 |
-
empty : bool, optional
|
| 773 |
-
When True, empty lines are included to structure the source
|
| 774 |
-
files. [default: True]
|
| 775 |
-
|
| 776 |
-
"""
|
| 777 |
-
if to_files:
|
| 778 |
-
for dump_fn in self.dump_fns:
|
| 779 |
-
filename = "%s.%s" % (prefix, dump_fn.extension)
|
| 780 |
-
with open(filename, "w") as f:
|
| 781 |
-
dump_fn(self, routines, f, prefix, header, empty)
|
| 782 |
-
else:
|
| 783 |
-
result = []
|
| 784 |
-
for dump_fn in self.dump_fns:
|
| 785 |
-
filename = "%s.%s" % (prefix, dump_fn.extension)
|
| 786 |
-
contents = StringIO()
|
| 787 |
-
dump_fn(self, routines, contents, prefix, header, empty)
|
| 788 |
-
result.append((filename, contents.getvalue()))
|
| 789 |
-
return result
|
| 790 |
-
|
| 791 |
-
def dump_code(self, routines, f, prefix, header=True, empty=True):
|
| 792 |
-
"""Write the code by calling language specific methods.
|
| 793 |
-
|
| 794 |
-
The generated file contains all the definitions of the routines in
|
| 795 |
-
low-level code and refers to the header file if appropriate.
|
| 796 |
-
|
| 797 |
-
Parameters
|
| 798 |
-
==========
|
| 799 |
-
|
| 800 |
-
routines : list
|
| 801 |
-
A list of Routine instances.
|
| 802 |
-
|
| 803 |
-
f : file-like
|
| 804 |
-
Where to write the file.
|
| 805 |
-
|
| 806 |
-
prefix : string
|
| 807 |
-
The filename prefix, used to refer to the proper header file.
|
| 808 |
-
Only the basename of the prefix is used.
|
| 809 |
-
|
| 810 |
-
header : bool, optional
|
| 811 |
-
When True, a header comment is included on top of each source
|
| 812 |
-
file. [default : True]
|
| 813 |
-
|
| 814 |
-
empty : bool, optional
|
| 815 |
-
When True, empty lines are included to structure the source
|
| 816 |
-
files. [default : True]
|
| 817 |
-
|
| 818 |
-
"""
|
| 819 |
-
|
| 820 |
-
code_lines = self._preprocessor_statements(prefix)
|
| 821 |
-
|
| 822 |
-
for routine in routines:
|
| 823 |
-
if empty:
|
| 824 |
-
code_lines.append("\n")
|
| 825 |
-
code_lines.extend(self._get_routine_opening(routine))
|
| 826 |
-
code_lines.extend(self._declare_arguments(routine))
|
| 827 |
-
code_lines.extend(self._declare_globals(routine))
|
| 828 |
-
code_lines.extend(self._declare_locals(routine))
|
| 829 |
-
if empty:
|
| 830 |
-
code_lines.append("\n")
|
| 831 |
-
code_lines.extend(self._call_printer(routine))
|
| 832 |
-
if empty:
|
| 833 |
-
code_lines.append("\n")
|
| 834 |
-
code_lines.extend(self._get_routine_ending(routine))
|
| 835 |
-
|
| 836 |
-
code_lines = self._indent_code(''.join(code_lines))
|
| 837 |
-
|
| 838 |
-
if header:
|
| 839 |
-
code_lines = ''.join(self._get_header() + [code_lines])
|
| 840 |
-
|
| 841 |
-
if code_lines:
|
| 842 |
-
f.write(code_lines)
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
class CodeGenError(Exception):
|
| 846 |
-
pass
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
class CodeGenArgumentListError(Exception):
|
| 850 |
-
@property
|
| 851 |
-
def missing_args(self):
|
| 852 |
-
return self.args[1]
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
header_comment = """Code generated with SymPy %(version)s
|
| 856 |
-
|
| 857 |
-
See http://www.sympy.org/ for more information.
|
| 858 |
-
|
| 859 |
-
This file is part of '%(project)s'
|
| 860 |
-
"""
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
class CCodeGen(CodeGen):
|
| 864 |
-
"""Generator for C code.
|
| 865 |
-
|
| 866 |
-
The .write() method inherited from CodeGen will output a code file and
|
| 867 |
-
an interface file, <prefix>.c and <prefix>.h respectively.
|
| 868 |
-
|
| 869 |
-
"""
|
| 870 |
-
|
| 871 |
-
code_extension = "c"
|
| 872 |
-
interface_extension = "h"
|
| 873 |
-
standard = 'c99'
|
| 874 |
-
|
| 875 |
-
def __init__(self, project="project", printer=None,
|
| 876 |
-
preprocessor_statements=None, cse=False):
|
| 877 |
-
super().__init__(project=project, cse=cse)
|
| 878 |
-
self.printer = printer or c_code_printers[self.standard.lower()]()
|
| 879 |
-
|
| 880 |
-
self.preprocessor_statements = preprocessor_statements
|
| 881 |
-
if preprocessor_statements is None:
|
| 882 |
-
self.preprocessor_statements = ['#include <math.h>']
|
| 883 |
-
|
| 884 |
-
def _get_header(self):
|
| 885 |
-
"""Writes a common header for the generated files."""
|
| 886 |
-
code_lines = []
|
| 887 |
-
code_lines.append("/" + "*"*78 + '\n')
|
| 888 |
-
tmp = header_comment % {"version": sympy_version,
|
| 889 |
-
"project": self.project}
|
| 890 |
-
for line in tmp.splitlines():
|
| 891 |
-
code_lines.append(" *%s*\n" % line.center(76))
|
| 892 |
-
code_lines.append(" " + "*"*78 + "/\n")
|
| 893 |
-
return code_lines
|
| 894 |
-
|
| 895 |
-
def get_prototype(self, routine):
|
| 896 |
-
"""Returns a string for the function prototype of the routine.
|
| 897 |
-
|
| 898 |
-
If the routine has multiple result objects, an CodeGenError is
|
| 899 |
-
raised.
|
| 900 |
-
|
| 901 |
-
See: https://en.wikipedia.org/wiki/Function_prototype
|
| 902 |
-
|
| 903 |
-
"""
|
| 904 |
-
if len(routine.results) > 1:
|
| 905 |
-
raise CodeGenError("C only supports a single or no return value.")
|
| 906 |
-
elif len(routine.results) == 1:
|
| 907 |
-
ctype = routine.results[0].get_datatype('C')
|
| 908 |
-
else:
|
| 909 |
-
ctype = "void"
|
| 910 |
-
|
| 911 |
-
type_args = []
|
| 912 |
-
for arg in routine.arguments:
|
| 913 |
-
name = self.printer.doprint(arg.name)
|
| 914 |
-
if arg.dimensions or isinstance(arg, ResultBase):
|
| 915 |
-
type_args.append((arg.get_datatype('C'), "*%s" % name))
|
| 916 |
-
else:
|
| 917 |
-
type_args.append((arg.get_datatype('C'), name))
|
| 918 |
-
arguments = ", ".join([ "%s %s" % t for t in type_args])
|
| 919 |
-
return "%s %s(%s)" % (ctype, routine.name, arguments)
|
| 920 |
-
|
| 921 |
-
def _preprocessor_statements(self, prefix):
|
| 922 |
-
code_lines = []
|
| 923 |
-
code_lines.append('#include "{}.h"'.format(os.path.basename(prefix)))
|
| 924 |
-
code_lines.extend(self.preprocessor_statements)
|
| 925 |
-
code_lines = ['{}\n'.format(l) for l in code_lines]
|
| 926 |
-
return code_lines
|
| 927 |
-
|
| 928 |
-
def _get_routine_opening(self, routine):
|
| 929 |
-
prototype = self.get_prototype(routine)
|
| 930 |
-
return ["%s {\n" % prototype]
|
| 931 |
-
|
| 932 |
-
def _declare_arguments(self, routine):
|
| 933 |
-
# arguments are declared in prototype
|
| 934 |
-
return []
|
| 935 |
-
|
| 936 |
-
def _declare_globals(self, routine):
|
| 937 |
-
# global variables are not explicitly declared within C functions
|
| 938 |
-
return []
|
| 939 |
-
|
| 940 |
-
def _declare_locals(self, routine):
|
| 941 |
-
|
| 942 |
-
# Compose a list of symbols to be dereferenced in the function
|
| 943 |
-
# body. These are the arguments that were passed by a reference
|
| 944 |
-
# pointer, excluding arrays.
|
| 945 |
-
dereference = []
|
| 946 |
-
for arg in routine.arguments:
|
| 947 |
-
if isinstance(arg, ResultBase) and not arg.dimensions:
|
| 948 |
-
dereference.append(arg.name)
|
| 949 |
-
|
| 950 |
-
code_lines = []
|
| 951 |
-
for result in routine.local_vars:
|
| 952 |
-
|
| 953 |
-
# local variables that are simple symbols such as those used as indices into
|
| 954 |
-
# for loops are defined declared elsewhere.
|
| 955 |
-
if not isinstance(result, Result):
|
| 956 |
-
continue
|
| 957 |
-
|
| 958 |
-
if result.name != result.result_var:
|
| 959 |
-
raise CodeGen("Result variable and name should match: {}".format(result))
|
| 960 |
-
assign_to = result.name
|
| 961 |
-
t = result.get_datatype('c')
|
| 962 |
-
if isinstance(result.expr, (MatrixBase, MatrixExpr)):
|
| 963 |
-
dims = result.expr.shape
|
| 964 |
-
code_lines.append("{} {}[{}];\n".format(t, str(assign_to), dims[0]*dims[1]))
|
| 965 |
-
prefix = ""
|
| 966 |
-
else:
|
| 967 |
-
prefix = "const {} ".format(t)
|
| 968 |
-
|
| 969 |
-
constants, not_c, c_expr = self._printer_method_with_settings(
|
| 970 |
-
'doprint', {"human": False, "dereference": dereference, "strict": False},
|
| 971 |
-
result.expr, assign_to=assign_to)
|
| 972 |
-
|
| 973 |
-
for name, value in sorted(constants, key=str):
|
| 974 |
-
code_lines.append("double const %s = %s;\n" % (name, value))
|
| 975 |
-
|
| 976 |
-
code_lines.append("{}{}\n".format(prefix, c_expr))
|
| 977 |
-
|
| 978 |
-
return code_lines
|
| 979 |
-
|
| 980 |
-
def _call_printer(self, routine):
|
| 981 |
-
code_lines = []
|
| 982 |
-
|
| 983 |
-
# Compose a list of symbols to be dereferenced in the function
|
| 984 |
-
# body. These are the arguments that were passed by a reference
|
| 985 |
-
# pointer, excluding arrays.
|
| 986 |
-
dereference = []
|
| 987 |
-
for arg in routine.arguments:
|
| 988 |
-
if isinstance(arg, ResultBase) and not arg.dimensions:
|
| 989 |
-
dereference.append(arg.name)
|
| 990 |
-
|
| 991 |
-
return_val = None
|
| 992 |
-
for result in routine.result_variables:
|
| 993 |
-
if isinstance(result, Result):
|
| 994 |
-
assign_to = routine.name + "_result"
|
| 995 |
-
t = result.get_datatype('c')
|
| 996 |
-
code_lines.append("{} {};\n".format(t, str(assign_to)))
|
| 997 |
-
return_val = assign_to
|
| 998 |
-
else:
|
| 999 |
-
assign_to = result.result_var
|
| 1000 |
-
|
| 1001 |
-
try:
|
| 1002 |
-
constants, not_c, c_expr = self._printer_method_with_settings(
|
| 1003 |
-
'doprint', {"human": False, "dereference": dereference, "strict": False},
|
| 1004 |
-
result.expr, assign_to=assign_to)
|
| 1005 |
-
except AssignmentError:
|
| 1006 |
-
assign_to = result.result_var
|
| 1007 |
-
code_lines.append(
|
| 1008 |
-
"%s %s;\n" % (result.get_datatype('c'), str(assign_to)))
|
| 1009 |
-
constants, not_c, c_expr = self._printer_method_with_settings(
|
| 1010 |
-
'doprint', {"human": False, "dereference": dereference, "strict": False},
|
| 1011 |
-
result.expr, assign_to=assign_to)
|
| 1012 |
-
|
| 1013 |
-
for name, value in sorted(constants, key=str):
|
| 1014 |
-
code_lines.append("double const %s = %s;\n" % (name, value))
|
| 1015 |
-
code_lines.append("%s\n" % c_expr)
|
| 1016 |
-
|
| 1017 |
-
if return_val:
|
| 1018 |
-
code_lines.append(" return %s;\n" % return_val)
|
| 1019 |
-
return code_lines
|
| 1020 |
-
|
| 1021 |
-
def _get_routine_ending(self, routine):
|
| 1022 |
-
return ["}\n"]
|
| 1023 |
-
|
| 1024 |
-
def dump_c(self, routines, f, prefix, header=True, empty=True):
|
| 1025 |
-
self.dump_code(routines, f, prefix, header, empty)
|
| 1026 |
-
dump_c.extension = code_extension # type: ignore
|
| 1027 |
-
dump_c.__doc__ = CodeGen.dump_code.__doc__
|
| 1028 |
-
|
| 1029 |
-
def dump_h(self, routines, f, prefix, header=True, empty=True):
|
| 1030 |
-
"""Writes the C header file.
|
| 1031 |
-
|
| 1032 |
-
This file contains all the function declarations.
|
| 1033 |
-
|
| 1034 |
-
Parameters
|
| 1035 |
-
==========
|
| 1036 |
-
|
| 1037 |
-
routines : list
|
| 1038 |
-
A list of Routine instances.
|
| 1039 |
-
|
| 1040 |
-
f : file-like
|
| 1041 |
-
Where to write the file.
|
| 1042 |
-
|
| 1043 |
-
prefix : string
|
| 1044 |
-
The filename prefix, used to construct the include guards.
|
| 1045 |
-
Only the basename of the prefix is used.
|
| 1046 |
-
|
| 1047 |
-
header : bool, optional
|
| 1048 |
-
When True, a header comment is included on top of each source
|
| 1049 |
-
file. [default : True]
|
| 1050 |
-
|
| 1051 |
-
empty : bool, optional
|
| 1052 |
-
When True, empty lines are included to structure the source
|
| 1053 |
-
files. [default : True]
|
| 1054 |
-
|
| 1055 |
-
"""
|
| 1056 |
-
if header:
|
| 1057 |
-
print(''.join(self._get_header()), file=f)
|
| 1058 |
-
guard_name = "%s__%s__H" % (self.project.replace(
|
| 1059 |
-
" ", "_").upper(), prefix.replace("/", "_").upper())
|
| 1060 |
-
# include guards
|
| 1061 |
-
if empty:
|
| 1062 |
-
print(file=f)
|
| 1063 |
-
print("#ifndef %s" % guard_name, file=f)
|
| 1064 |
-
print("#define %s" % guard_name, file=f)
|
| 1065 |
-
if empty:
|
| 1066 |
-
print(file=f)
|
| 1067 |
-
# declaration of the function prototypes
|
| 1068 |
-
for routine in routines:
|
| 1069 |
-
prototype = self.get_prototype(routine)
|
| 1070 |
-
print("%s;" % prototype, file=f)
|
| 1071 |
-
# end if include guards
|
| 1072 |
-
if empty:
|
| 1073 |
-
print(file=f)
|
| 1074 |
-
print("#endif", file=f)
|
| 1075 |
-
if empty:
|
| 1076 |
-
print(file=f)
|
| 1077 |
-
dump_h.extension = interface_extension # type: ignore
|
| 1078 |
-
|
| 1079 |
-
# This list of dump functions is used by CodeGen.write to know which dump
|
| 1080 |
-
# functions it has to call.
|
| 1081 |
-
dump_fns = [dump_c, dump_h]
|
| 1082 |
-
|
| 1083 |
-
class C89CodeGen(CCodeGen):
|
| 1084 |
-
standard = 'C89'
|
| 1085 |
-
|
| 1086 |
-
class C99CodeGen(CCodeGen):
|
| 1087 |
-
standard = 'C99'
|
| 1088 |
-
|
| 1089 |
-
class FCodeGen(CodeGen):
|
| 1090 |
-
"""Generator for Fortran 95 code
|
| 1091 |
-
|
| 1092 |
-
The .write() method inherited from CodeGen will output a code file and
|
| 1093 |
-
an interface file, <prefix>.f90 and <prefix>.h respectively.
|
| 1094 |
-
|
| 1095 |
-
"""
|
| 1096 |
-
|
| 1097 |
-
code_extension = "f90"
|
| 1098 |
-
interface_extension = "h"
|
| 1099 |
-
|
| 1100 |
-
def __init__(self, project='project', printer=None):
|
| 1101 |
-
super().__init__(project)
|
| 1102 |
-
self.printer = printer or FCodePrinter()
|
| 1103 |
-
|
| 1104 |
-
def _get_header(self):
|
| 1105 |
-
"""Writes a common header for the generated files."""
|
| 1106 |
-
code_lines = []
|
| 1107 |
-
code_lines.append("!" + "*"*78 + '\n')
|
| 1108 |
-
tmp = header_comment % {"version": sympy_version,
|
| 1109 |
-
"project": self.project}
|
| 1110 |
-
for line in tmp.splitlines():
|
| 1111 |
-
code_lines.append("!*%s*\n" % line.center(76))
|
| 1112 |
-
code_lines.append("!" + "*"*78 + '\n')
|
| 1113 |
-
return code_lines
|
| 1114 |
-
|
| 1115 |
-
def _preprocessor_statements(self, prefix):
|
| 1116 |
-
return []
|
| 1117 |
-
|
| 1118 |
-
def _get_routine_opening(self, routine):
|
| 1119 |
-
"""Returns the opening statements of the fortran routine."""
|
| 1120 |
-
code_list = []
|
| 1121 |
-
if len(routine.results) > 1:
|
| 1122 |
-
raise CodeGenError(
|
| 1123 |
-
"Fortran only supports a single or no return value.")
|
| 1124 |
-
elif len(routine.results) == 1:
|
| 1125 |
-
result = routine.results[0]
|
| 1126 |
-
code_list.append(result.get_datatype('fortran'))
|
| 1127 |
-
code_list.append("function")
|
| 1128 |
-
else:
|
| 1129 |
-
code_list.append("subroutine")
|
| 1130 |
-
|
| 1131 |
-
args = ", ".join("%s" % self._get_symbol(arg.name)
|
| 1132 |
-
for arg in routine.arguments)
|
| 1133 |
-
|
| 1134 |
-
call_sig = "{}({})\n".format(routine.name, args)
|
| 1135 |
-
# Fortran 95 requires all lines be less than 132 characters, so wrap
|
| 1136 |
-
# this line before appending.
|
| 1137 |
-
call_sig = ' &\n'.join(textwrap.wrap(call_sig,
|
| 1138 |
-
width=60,
|
| 1139 |
-
break_long_words=False)) + '\n'
|
| 1140 |
-
code_list.append(call_sig)
|
| 1141 |
-
code_list = [' '.join(code_list)]
|
| 1142 |
-
code_list.append('implicit none\n')
|
| 1143 |
-
return code_list
|
| 1144 |
-
|
| 1145 |
-
def _declare_arguments(self, routine):
|
| 1146 |
-
# argument type declarations
|
| 1147 |
-
code_list = []
|
| 1148 |
-
array_list = []
|
| 1149 |
-
scalar_list = []
|
| 1150 |
-
for arg in routine.arguments:
|
| 1151 |
-
|
| 1152 |
-
if isinstance(arg, InputArgument):
|
| 1153 |
-
typeinfo = "%s, intent(in)" % arg.get_datatype('fortran')
|
| 1154 |
-
elif isinstance(arg, InOutArgument):
|
| 1155 |
-
typeinfo = "%s, intent(inout)" % arg.get_datatype('fortran')
|
| 1156 |
-
elif isinstance(arg, OutputArgument):
|
| 1157 |
-
typeinfo = "%s, intent(out)" % arg.get_datatype('fortran')
|
| 1158 |
-
else:
|
| 1159 |
-
raise CodeGenError("Unknown Argument type: %s" % type(arg))
|
| 1160 |
-
|
| 1161 |
-
fprint = self._get_symbol
|
| 1162 |
-
|
| 1163 |
-
if arg.dimensions:
|
| 1164 |
-
# fortran arrays start at 1
|
| 1165 |
-
dimstr = ", ".join(["%s:%s" % (
|
| 1166 |
-
fprint(dim[0] + 1), fprint(dim[1] + 1))
|
| 1167 |
-
for dim in arg.dimensions])
|
| 1168 |
-
typeinfo += ", dimension(%s)" % dimstr
|
| 1169 |
-
array_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name)))
|
| 1170 |
-
else:
|
| 1171 |
-
scalar_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name)))
|
| 1172 |
-
|
| 1173 |
-
# scalars first, because they can be used in array declarations
|
| 1174 |
-
code_list.extend(scalar_list)
|
| 1175 |
-
code_list.extend(array_list)
|
| 1176 |
-
|
| 1177 |
-
return code_list
|
| 1178 |
-
|
| 1179 |
-
def _declare_globals(self, routine):
|
| 1180 |
-
# Global variables not explicitly declared within Fortran 90 functions.
|
| 1181 |
-
# Note: a future F77 mode may need to generate "common" blocks.
|
| 1182 |
-
return []
|
| 1183 |
-
|
| 1184 |
-
def _declare_locals(self, routine):
|
| 1185 |
-
code_list = []
|
| 1186 |
-
for var in sorted(routine.local_vars, key=str):
|
| 1187 |
-
typeinfo = get_default_datatype(var)
|
| 1188 |
-
code_list.append("%s :: %s\n" % (
|
| 1189 |
-
typeinfo.fname, self._get_symbol(var)))
|
| 1190 |
-
return code_list
|
| 1191 |
-
|
| 1192 |
-
def _get_routine_ending(self, routine):
|
| 1193 |
-
"""Returns the closing statements of the fortran routine."""
|
| 1194 |
-
if len(routine.results) == 1:
|
| 1195 |
-
return ["end function\n"]
|
| 1196 |
-
else:
|
| 1197 |
-
return ["end subroutine\n"]
|
| 1198 |
-
|
| 1199 |
-
def get_interface(self, routine):
|
| 1200 |
-
"""Returns a string for the function interface.
|
| 1201 |
-
|
| 1202 |
-
The routine should have a single result object, which can be None.
|
| 1203 |
-
If the routine has multiple result objects, a CodeGenError is
|
| 1204 |
-
raised.
|
| 1205 |
-
|
| 1206 |
-
See: https://en.wikipedia.org/wiki/Function_prototype
|
| 1207 |
-
|
| 1208 |
-
"""
|
| 1209 |
-
prototype = [ "interface\n" ]
|
| 1210 |
-
prototype.extend(self._get_routine_opening(routine))
|
| 1211 |
-
prototype.extend(self._declare_arguments(routine))
|
| 1212 |
-
prototype.extend(self._get_routine_ending(routine))
|
| 1213 |
-
prototype.append("end interface\n")
|
| 1214 |
-
|
| 1215 |
-
return "".join(prototype)
|
| 1216 |
-
|
| 1217 |
-
def _call_printer(self, routine):
|
| 1218 |
-
declarations = []
|
| 1219 |
-
code_lines = []
|
| 1220 |
-
for result in routine.result_variables:
|
| 1221 |
-
if isinstance(result, Result):
|
| 1222 |
-
assign_to = routine.name
|
| 1223 |
-
elif isinstance(result, (OutputArgument, InOutArgument)):
|
| 1224 |
-
assign_to = result.result_var
|
| 1225 |
-
|
| 1226 |
-
constants, not_fortran, f_expr = self._printer_method_with_settings(
|
| 1227 |
-
'doprint', {"human": False, "source_format": 'free', "standard": 95, "strict": False},
|
| 1228 |
-
result.expr, assign_to=assign_to)
|
| 1229 |
-
|
| 1230 |
-
for obj, v in sorted(constants, key=str):
|
| 1231 |
-
t = get_default_datatype(obj)
|
| 1232 |
-
declarations.append(
|
| 1233 |
-
"%s, parameter :: %s = %s\n" % (t.fname, obj, v))
|
| 1234 |
-
for obj in sorted(not_fortran, key=str):
|
| 1235 |
-
t = get_default_datatype(obj)
|
| 1236 |
-
if isinstance(obj, Function):
|
| 1237 |
-
name = obj.func
|
| 1238 |
-
else:
|
| 1239 |
-
name = obj
|
| 1240 |
-
declarations.append("%s :: %s\n" % (t.fname, name))
|
| 1241 |
-
|
| 1242 |
-
code_lines.append("%s\n" % f_expr)
|
| 1243 |
-
return declarations + code_lines
|
| 1244 |
-
|
| 1245 |
-
def _indent_code(self, codelines):
|
| 1246 |
-
return self._printer_method_with_settings(
|
| 1247 |
-
'indent_code', {"human": False, "source_format": 'free', "strict": False}, codelines)
|
| 1248 |
-
|
| 1249 |
-
def dump_f95(self, routines, f, prefix, header=True, empty=True):
|
| 1250 |
-
# check that symbols are unique with ignorecase
|
| 1251 |
-
for r in routines:
|
| 1252 |
-
lowercase = {str(x).lower() for x in r.variables}
|
| 1253 |
-
orig_case = {str(x) for x in r.variables}
|
| 1254 |
-
if len(lowercase) < len(orig_case):
|
| 1255 |
-
raise CodeGenError("Fortran ignores case. Got symbols: %s" %
|
| 1256 |
-
(", ".join([str(var) for var in r.variables])))
|
| 1257 |
-
self.dump_code(routines, f, prefix, header, empty)
|
| 1258 |
-
dump_f95.extension = code_extension # type: ignore
|
| 1259 |
-
dump_f95.__doc__ = CodeGen.dump_code.__doc__
|
| 1260 |
-
|
| 1261 |
-
def dump_h(self, routines, f, prefix, header=True, empty=True):
|
| 1262 |
-
"""Writes the interface to a header file.
|
| 1263 |
-
|
| 1264 |
-
This file contains all the function declarations.
|
| 1265 |
-
|
| 1266 |
-
Parameters
|
| 1267 |
-
==========
|
| 1268 |
-
|
| 1269 |
-
routines : list
|
| 1270 |
-
A list of Routine instances.
|
| 1271 |
-
|
| 1272 |
-
f : file-like
|
| 1273 |
-
Where to write the file.
|
| 1274 |
-
|
| 1275 |
-
prefix : string
|
| 1276 |
-
The filename prefix.
|
| 1277 |
-
|
| 1278 |
-
header : bool, optional
|
| 1279 |
-
When True, a header comment is included on top of each source
|
| 1280 |
-
file. [default : True]
|
| 1281 |
-
|
| 1282 |
-
empty : bool, optional
|
| 1283 |
-
When True, empty lines are included to structure the source
|
| 1284 |
-
files. [default : True]
|
| 1285 |
-
|
| 1286 |
-
"""
|
| 1287 |
-
if header:
|
| 1288 |
-
print(''.join(self._get_header()), file=f)
|
| 1289 |
-
if empty:
|
| 1290 |
-
print(file=f)
|
| 1291 |
-
# declaration of the function prototypes
|
| 1292 |
-
for routine in routines:
|
| 1293 |
-
prototype = self.get_interface(routine)
|
| 1294 |
-
f.write(prototype)
|
| 1295 |
-
if empty:
|
| 1296 |
-
print(file=f)
|
| 1297 |
-
dump_h.extension = interface_extension # type: ignore
|
| 1298 |
-
|
| 1299 |
-
# This list of dump functions is used by CodeGen.write to know which dump
|
| 1300 |
-
# functions it has to call.
|
| 1301 |
-
dump_fns = [dump_f95, dump_h]
|
| 1302 |
-
|
| 1303 |
-
|
| 1304 |
-
class JuliaCodeGen(CodeGen):
|
| 1305 |
-
"""Generator for Julia code.
|
| 1306 |
-
|
| 1307 |
-
The .write() method inherited from CodeGen will output a code file
|
| 1308 |
-
<prefix>.jl.
|
| 1309 |
-
|
| 1310 |
-
"""
|
| 1311 |
-
|
| 1312 |
-
code_extension = "jl"
|
| 1313 |
-
|
| 1314 |
-
def __init__(self, project='project', printer=None):
|
| 1315 |
-
super().__init__(project)
|
| 1316 |
-
self.printer = printer or JuliaCodePrinter()
|
| 1317 |
-
|
| 1318 |
-
def routine(self, name, expr, argument_sequence, global_vars):
|
| 1319 |
-
"""Specialized Routine creation for Julia."""
|
| 1320 |
-
|
| 1321 |
-
if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
|
| 1322 |
-
if not expr:
|
| 1323 |
-
raise ValueError("No expression given")
|
| 1324 |
-
expressions = Tuple(*expr)
|
| 1325 |
-
else:
|
| 1326 |
-
expressions = Tuple(expr)
|
| 1327 |
-
|
| 1328 |
-
# local variables
|
| 1329 |
-
local_vars = {i.label for i in expressions.atoms(Idx)}
|
| 1330 |
-
|
| 1331 |
-
# global variables
|
| 1332 |
-
global_vars = set() if global_vars is None else set(global_vars)
|
| 1333 |
-
|
| 1334 |
-
# symbols that should be arguments
|
| 1335 |
-
old_symbols = expressions.free_symbols - local_vars - global_vars
|
| 1336 |
-
symbols = set()
|
| 1337 |
-
for s in old_symbols:
|
| 1338 |
-
if isinstance(s, Idx):
|
| 1339 |
-
symbols.update(s.args[1].free_symbols)
|
| 1340 |
-
elif not isinstance(s, Indexed):
|
| 1341 |
-
symbols.add(s)
|
| 1342 |
-
|
| 1343 |
-
# Julia supports multiple return values
|
| 1344 |
-
return_vals = []
|
| 1345 |
-
output_args = []
|
| 1346 |
-
for (i, expr) in enumerate(expressions):
|
| 1347 |
-
if isinstance(expr, Equality):
|
| 1348 |
-
out_arg = expr.lhs
|
| 1349 |
-
expr = expr.rhs
|
| 1350 |
-
symbol = out_arg
|
| 1351 |
-
if isinstance(out_arg, Indexed):
|
| 1352 |
-
dims = tuple([ (S.One, dim) for dim in out_arg.shape])
|
| 1353 |
-
symbol = out_arg.base.label
|
| 1354 |
-
output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims))
|
| 1355 |
-
if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)):
|
| 1356 |
-
raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
|
| 1357 |
-
"can define output arguments.")
|
| 1358 |
-
|
| 1359 |
-
return_vals.append(Result(expr, name=symbol, result_var=out_arg))
|
| 1360 |
-
if not expr.has(symbol):
|
| 1361 |
-
# this is a pure output: remove from the symbols list, so
|
| 1362 |
-
# it doesn't become an input.
|
| 1363 |
-
symbols.remove(symbol)
|
| 1364 |
-
|
| 1365 |
-
else:
|
| 1366 |
-
# we have no name for this output
|
| 1367 |
-
return_vals.append(Result(expr, name='out%d' % (i+1)))
|
| 1368 |
-
|
| 1369 |
-
# setup input argument list
|
| 1370 |
-
output_args.sort(key=lambda x: str(x.name))
|
| 1371 |
-
arg_list = list(output_args)
|
| 1372 |
-
array_symbols = {}
|
| 1373 |
-
for array in expressions.atoms(Indexed):
|
| 1374 |
-
array_symbols[array.base.label] = array
|
| 1375 |
-
for array in expressions.atoms(MatrixSymbol):
|
| 1376 |
-
array_symbols[array] = array
|
| 1377 |
-
|
| 1378 |
-
for symbol in sorted(symbols, key=str):
|
| 1379 |
-
arg_list.append(InputArgument(symbol))
|
| 1380 |
-
|
| 1381 |
-
if argument_sequence is not None:
|
| 1382 |
-
# if the user has supplied IndexedBase instances, we'll accept that
|
| 1383 |
-
new_sequence = []
|
| 1384 |
-
for arg in argument_sequence:
|
| 1385 |
-
if isinstance(arg, IndexedBase):
|
| 1386 |
-
new_sequence.append(arg.label)
|
| 1387 |
-
else:
|
| 1388 |
-
new_sequence.append(arg)
|
| 1389 |
-
argument_sequence = new_sequence
|
| 1390 |
-
|
| 1391 |
-
missing = [x for x in arg_list if x.name not in argument_sequence]
|
| 1392 |
-
if missing:
|
| 1393 |
-
msg = "Argument list didn't specify: {0} "
|
| 1394 |
-
msg = msg.format(", ".join([str(m.name) for m in missing]))
|
| 1395 |
-
raise CodeGenArgumentListError(msg, missing)
|
| 1396 |
-
|
| 1397 |
-
# create redundant arguments to produce the requested sequence
|
| 1398 |
-
name_arg_dict = {x.name: x for x in arg_list}
|
| 1399 |
-
new_args = []
|
| 1400 |
-
for symbol in argument_sequence:
|
| 1401 |
-
try:
|
| 1402 |
-
new_args.append(name_arg_dict[symbol])
|
| 1403 |
-
except KeyError:
|
| 1404 |
-
new_args.append(InputArgument(symbol))
|
| 1405 |
-
arg_list = new_args
|
| 1406 |
-
|
| 1407 |
-
return Routine(name, arg_list, return_vals, local_vars, global_vars)
|
| 1408 |
-
|
| 1409 |
-
def _get_header(self):
|
| 1410 |
-
"""Writes a common header for the generated files."""
|
| 1411 |
-
code_lines = []
|
| 1412 |
-
tmp = header_comment % {"version": sympy_version,
|
| 1413 |
-
"project": self.project}
|
| 1414 |
-
for line in tmp.splitlines():
|
| 1415 |
-
if line == '':
|
| 1416 |
-
code_lines.append("#\n")
|
| 1417 |
-
else:
|
| 1418 |
-
code_lines.append("# %s\n" % line)
|
| 1419 |
-
return code_lines
|
| 1420 |
-
|
| 1421 |
-
def _preprocessor_statements(self, prefix):
|
| 1422 |
-
return []
|
| 1423 |
-
|
| 1424 |
-
def _get_routine_opening(self, routine):
|
| 1425 |
-
"""Returns the opening statements of the routine."""
|
| 1426 |
-
code_list = []
|
| 1427 |
-
code_list.append("function ")
|
| 1428 |
-
|
| 1429 |
-
# Inputs
|
| 1430 |
-
args = []
|
| 1431 |
-
for arg in routine.arguments:
|
| 1432 |
-
if isinstance(arg, OutputArgument):
|
| 1433 |
-
raise CodeGenError("Julia: invalid argument of type %s" %
|
| 1434 |
-
str(type(arg)))
|
| 1435 |
-
if isinstance(arg, (InputArgument, InOutArgument)):
|
| 1436 |
-
args.append("%s" % self._get_symbol(arg.name))
|
| 1437 |
-
args = ", ".join(args)
|
| 1438 |
-
code_list.append("%s(%s)\n" % (routine.name, args))
|
| 1439 |
-
code_list = [ "".join(code_list) ]
|
| 1440 |
-
|
| 1441 |
-
return code_list
|
| 1442 |
-
|
| 1443 |
-
def _declare_arguments(self, routine):
|
| 1444 |
-
return []
|
| 1445 |
-
|
| 1446 |
-
def _declare_globals(self, routine):
|
| 1447 |
-
return []
|
| 1448 |
-
|
| 1449 |
-
def _declare_locals(self, routine):
|
| 1450 |
-
return []
|
| 1451 |
-
|
| 1452 |
-
def _get_routine_ending(self, routine):
|
| 1453 |
-
outs = []
|
| 1454 |
-
for result in routine.results:
|
| 1455 |
-
if isinstance(result, Result):
|
| 1456 |
-
# Note: name not result_var; want `y` not `y[i]` for Indexed
|
| 1457 |
-
s = self._get_symbol(result.name)
|
| 1458 |
-
else:
|
| 1459 |
-
raise CodeGenError("unexpected object in Routine results")
|
| 1460 |
-
outs.append(s)
|
| 1461 |
-
return ["return " + ", ".join(outs) + "\nend\n"]
|
| 1462 |
-
|
| 1463 |
-
def _call_printer(self, routine):
|
| 1464 |
-
declarations = []
|
| 1465 |
-
code_lines = []
|
| 1466 |
-
for result in routine.results:
|
| 1467 |
-
if isinstance(result, Result):
|
| 1468 |
-
assign_to = result.result_var
|
| 1469 |
-
else:
|
| 1470 |
-
raise CodeGenError("unexpected object in Routine results")
|
| 1471 |
-
|
| 1472 |
-
constants, not_supported, jl_expr = self._printer_method_with_settings(
|
| 1473 |
-
'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to)
|
| 1474 |
-
|
| 1475 |
-
for obj, v in sorted(constants, key=str):
|
| 1476 |
-
declarations.append(
|
| 1477 |
-
"%s = %s\n" % (obj, v))
|
| 1478 |
-
for obj in sorted(not_supported, key=str):
|
| 1479 |
-
if isinstance(obj, Function):
|
| 1480 |
-
name = obj.func
|
| 1481 |
-
else:
|
| 1482 |
-
name = obj
|
| 1483 |
-
declarations.append(
|
| 1484 |
-
"# unsupported: %s\n" % (name))
|
| 1485 |
-
code_lines.append("%s\n" % (jl_expr))
|
| 1486 |
-
return declarations + code_lines
|
| 1487 |
-
|
| 1488 |
-
def _indent_code(self, codelines):
|
| 1489 |
-
# Note that indenting seems to happen twice, first
|
| 1490 |
-
# statement-by-statement by JuliaPrinter then again here.
|
| 1491 |
-
p = JuliaCodePrinter({'human': False, "strict": False})
|
| 1492 |
-
return p.indent_code(codelines)
|
| 1493 |
-
|
| 1494 |
-
def dump_jl(self, routines, f, prefix, header=True, empty=True):
|
| 1495 |
-
self.dump_code(routines, f, prefix, header, empty)
|
| 1496 |
-
|
| 1497 |
-
dump_jl.extension = code_extension # type: ignore
|
| 1498 |
-
dump_jl.__doc__ = CodeGen.dump_code.__doc__
|
| 1499 |
-
|
| 1500 |
-
# This list of dump functions is used by CodeGen.write to know which dump
|
| 1501 |
-
# functions it has to call.
|
| 1502 |
-
dump_fns = [dump_jl]
|
| 1503 |
-
|
| 1504 |
-
|
| 1505 |
-
class OctaveCodeGen(CodeGen):
|
| 1506 |
-
"""Generator for Octave code.
|
| 1507 |
-
|
| 1508 |
-
The .write() method inherited from CodeGen will output a code file
|
| 1509 |
-
<prefix>.m.
|
| 1510 |
-
|
| 1511 |
-
Octave .m files usually contain one function. That function name should
|
| 1512 |
-
match the filename (``prefix``). If you pass multiple ``name_expr`` pairs,
|
| 1513 |
-
the latter ones are presumed to be private functions accessed by the
|
| 1514 |
-
primary function.
|
| 1515 |
-
|
| 1516 |
-
You should only pass inputs to ``argument_sequence``: outputs are ordered
|
| 1517 |
-
according to their order in ``name_expr``.
|
| 1518 |
-
|
| 1519 |
-
"""
|
| 1520 |
-
|
| 1521 |
-
code_extension = "m"
|
| 1522 |
-
|
| 1523 |
-
def __init__(self, project='project', printer=None):
|
| 1524 |
-
super().__init__(project)
|
| 1525 |
-
self.printer = printer or OctaveCodePrinter()
|
| 1526 |
-
|
| 1527 |
-
def routine(self, name, expr, argument_sequence, global_vars):
|
| 1528 |
-
"""Specialized Routine creation for Octave."""
|
| 1529 |
-
|
| 1530 |
-
# FIXME: this is probably general enough for other high-level
|
| 1531 |
-
# languages, perhaps its the C/Fortran one that is specialized!
|
| 1532 |
-
|
| 1533 |
-
if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
|
| 1534 |
-
if not expr:
|
| 1535 |
-
raise ValueError("No expression given")
|
| 1536 |
-
expressions = Tuple(*expr)
|
| 1537 |
-
else:
|
| 1538 |
-
expressions = Tuple(expr)
|
| 1539 |
-
|
| 1540 |
-
# local variables
|
| 1541 |
-
local_vars = {i.label for i in expressions.atoms(Idx)}
|
| 1542 |
-
|
| 1543 |
-
# global variables
|
| 1544 |
-
global_vars = set() if global_vars is None else set(global_vars)
|
| 1545 |
-
|
| 1546 |
-
# symbols that should be arguments
|
| 1547 |
-
old_symbols = expressions.free_symbols - local_vars - global_vars
|
| 1548 |
-
symbols = set()
|
| 1549 |
-
for s in old_symbols:
|
| 1550 |
-
if isinstance(s, Idx):
|
| 1551 |
-
symbols.update(s.args[1].free_symbols)
|
| 1552 |
-
elif not isinstance(s, Indexed):
|
| 1553 |
-
symbols.add(s)
|
| 1554 |
-
|
| 1555 |
-
# Octave supports multiple return values
|
| 1556 |
-
return_vals = []
|
| 1557 |
-
for (i, expr) in enumerate(expressions):
|
| 1558 |
-
if isinstance(expr, Equality):
|
| 1559 |
-
out_arg = expr.lhs
|
| 1560 |
-
expr = expr.rhs
|
| 1561 |
-
symbol = out_arg
|
| 1562 |
-
if isinstance(out_arg, Indexed):
|
| 1563 |
-
symbol = out_arg.base.label
|
| 1564 |
-
if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)):
|
| 1565 |
-
raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
|
| 1566 |
-
"can define output arguments.")
|
| 1567 |
-
|
| 1568 |
-
return_vals.append(Result(expr, name=symbol, result_var=out_arg))
|
| 1569 |
-
if not expr.has(symbol):
|
| 1570 |
-
# this is a pure output: remove from the symbols list, so
|
| 1571 |
-
# it doesn't become an input.
|
| 1572 |
-
symbols.remove(symbol)
|
| 1573 |
-
|
| 1574 |
-
else:
|
| 1575 |
-
# we have no name for this output
|
| 1576 |
-
return_vals.append(Result(expr, name='out%d' % (i+1)))
|
| 1577 |
-
|
| 1578 |
-
# setup input argument list
|
| 1579 |
-
arg_list = []
|
| 1580 |
-
array_symbols = {}
|
| 1581 |
-
for array in expressions.atoms(Indexed):
|
| 1582 |
-
array_symbols[array.base.label] = array
|
| 1583 |
-
for array in expressions.atoms(MatrixSymbol):
|
| 1584 |
-
array_symbols[array] = array
|
| 1585 |
-
|
| 1586 |
-
for symbol in sorted(symbols, key=str):
|
| 1587 |
-
arg_list.append(InputArgument(symbol))
|
| 1588 |
-
|
| 1589 |
-
if argument_sequence is not None:
|
| 1590 |
-
# if the user has supplied IndexedBase instances, we'll accept that
|
| 1591 |
-
new_sequence = []
|
| 1592 |
-
for arg in argument_sequence:
|
| 1593 |
-
if isinstance(arg, IndexedBase):
|
| 1594 |
-
new_sequence.append(arg.label)
|
| 1595 |
-
else:
|
| 1596 |
-
new_sequence.append(arg)
|
| 1597 |
-
argument_sequence = new_sequence
|
| 1598 |
-
|
| 1599 |
-
missing = [x for x in arg_list if x.name not in argument_sequence]
|
| 1600 |
-
if missing:
|
| 1601 |
-
msg = "Argument list didn't specify: {0} "
|
| 1602 |
-
msg = msg.format(", ".join([str(m.name) for m in missing]))
|
| 1603 |
-
raise CodeGenArgumentListError(msg, missing)
|
| 1604 |
-
|
| 1605 |
-
# create redundant arguments to produce the requested sequence
|
| 1606 |
-
name_arg_dict = {x.name: x for x in arg_list}
|
| 1607 |
-
new_args = []
|
| 1608 |
-
for symbol in argument_sequence:
|
| 1609 |
-
try:
|
| 1610 |
-
new_args.append(name_arg_dict[symbol])
|
| 1611 |
-
except KeyError:
|
| 1612 |
-
new_args.append(InputArgument(symbol))
|
| 1613 |
-
arg_list = new_args
|
| 1614 |
-
|
| 1615 |
-
return Routine(name, arg_list, return_vals, local_vars, global_vars)
|
| 1616 |
-
|
| 1617 |
-
def _get_header(self):
|
| 1618 |
-
"""Writes a common header for the generated files."""
|
| 1619 |
-
code_lines = []
|
| 1620 |
-
tmp = header_comment % {"version": sympy_version,
|
| 1621 |
-
"project": self.project}
|
| 1622 |
-
for line in tmp.splitlines():
|
| 1623 |
-
if line == '':
|
| 1624 |
-
code_lines.append("%\n")
|
| 1625 |
-
else:
|
| 1626 |
-
code_lines.append("%% %s\n" % line)
|
| 1627 |
-
return code_lines
|
| 1628 |
-
|
| 1629 |
-
def _preprocessor_statements(self, prefix):
|
| 1630 |
-
return []
|
| 1631 |
-
|
| 1632 |
-
def _get_routine_opening(self, routine):
|
| 1633 |
-
"""Returns the opening statements of the routine."""
|
| 1634 |
-
code_list = []
|
| 1635 |
-
code_list.append("function ")
|
| 1636 |
-
|
| 1637 |
-
# Outputs
|
| 1638 |
-
outs = []
|
| 1639 |
-
for result in routine.results:
|
| 1640 |
-
if isinstance(result, Result):
|
| 1641 |
-
# Note: name not result_var; want `y` not `y(i)` for Indexed
|
| 1642 |
-
s = self._get_symbol(result.name)
|
| 1643 |
-
else:
|
| 1644 |
-
raise CodeGenError("unexpected object in Routine results")
|
| 1645 |
-
outs.append(s)
|
| 1646 |
-
if len(outs) > 1:
|
| 1647 |
-
code_list.append("[" + (", ".join(outs)) + "]")
|
| 1648 |
-
else:
|
| 1649 |
-
code_list.append("".join(outs))
|
| 1650 |
-
code_list.append(" = ")
|
| 1651 |
-
|
| 1652 |
-
# Inputs
|
| 1653 |
-
args = []
|
| 1654 |
-
for arg in routine.arguments:
|
| 1655 |
-
if isinstance(arg, (OutputArgument, InOutArgument)):
|
| 1656 |
-
raise CodeGenError("Octave: invalid argument of type %s" %
|
| 1657 |
-
str(type(arg)))
|
| 1658 |
-
if isinstance(arg, InputArgument):
|
| 1659 |
-
args.append("%s" % self._get_symbol(arg.name))
|
| 1660 |
-
args = ", ".join(args)
|
| 1661 |
-
code_list.append("%s(%s)\n" % (routine.name, args))
|
| 1662 |
-
code_list = [ "".join(code_list) ]
|
| 1663 |
-
|
| 1664 |
-
return code_list
|
| 1665 |
-
|
| 1666 |
-
def _declare_arguments(self, routine):
|
| 1667 |
-
return []
|
| 1668 |
-
|
| 1669 |
-
def _declare_globals(self, routine):
|
| 1670 |
-
if not routine.global_vars:
|
| 1671 |
-
return []
|
| 1672 |
-
s = " ".join(sorted([self._get_symbol(g) for g in routine.global_vars]))
|
| 1673 |
-
return ["global " + s + "\n"]
|
| 1674 |
-
|
| 1675 |
-
def _declare_locals(self, routine):
|
| 1676 |
-
return []
|
| 1677 |
-
|
| 1678 |
-
def _get_routine_ending(self, routine):
|
| 1679 |
-
return ["end\n"]
|
| 1680 |
-
|
| 1681 |
-
def _call_printer(self, routine):
|
| 1682 |
-
declarations = []
|
| 1683 |
-
code_lines = []
|
| 1684 |
-
for result in routine.results:
|
| 1685 |
-
if isinstance(result, Result):
|
| 1686 |
-
assign_to = result.result_var
|
| 1687 |
-
else:
|
| 1688 |
-
raise CodeGenError("unexpected object in Routine results")
|
| 1689 |
-
|
| 1690 |
-
constants, not_supported, oct_expr = self._printer_method_with_settings(
|
| 1691 |
-
'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to)
|
| 1692 |
-
|
| 1693 |
-
for obj, v in sorted(constants, key=str):
|
| 1694 |
-
declarations.append(
|
| 1695 |
-
" %s = %s; %% constant\n" % (obj, v))
|
| 1696 |
-
for obj in sorted(not_supported, key=str):
|
| 1697 |
-
if isinstance(obj, Function):
|
| 1698 |
-
name = obj.func
|
| 1699 |
-
else:
|
| 1700 |
-
name = obj
|
| 1701 |
-
declarations.append(
|
| 1702 |
-
" %% unsupported: %s\n" % (name))
|
| 1703 |
-
code_lines.append("%s\n" % (oct_expr))
|
| 1704 |
-
return declarations + code_lines
|
| 1705 |
-
|
| 1706 |
-
def _indent_code(self, codelines):
|
| 1707 |
-
return self._printer_method_with_settings(
|
| 1708 |
-
'indent_code', {"human": False, "strict": False}, codelines)
|
| 1709 |
-
|
| 1710 |
-
def dump_m(self, routines, f, prefix, header=True, empty=True, inline=True):
|
| 1711 |
-
# Note used to call self.dump_code() but we need more control for header
|
| 1712 |
-
|
| 1713 |
-
code_lines = self._preprocessor_statements(prefix)
|
| 1714 |
-
|
| 1715 |
-
for i, routine in enumerate(routines):
|
| 1716 |
-
if i > 0:
|
| 1717 |
-
if empty:
|
| 1718 |
-
code_lines.append("\n")
|
| 1719 |
-
code_lines.extend(self._get_routine_opening(routine))
|
| 1720 |
-
if i == 0:
|
| 1721 |
-
if routine.name != prefix:
|
| 1722 |
-
raise ValueError('Octave function name should match prefix')
|
| 1723 |
-
if header:
|
| 1724 |
-
code_lines.append("%" + prefix.upper() +
|
| 1725 |
-
" Autogenerated by SymPy\n")
|
| 1726 |
-
code_lines.append(''.join(self._get_header()))
|
| 1727 |
-
code_lines.extend(self._declare_arguments(routine))
|
| 1728 |
-
code_lines.extend(self._declare_globals(routine))
|
| 1729 |
-
code_lines.extend(self._declare_locals(routine))
|
| 1730 |
-
if empty:
|
| 1731 |
-
code_lines.append("\n")
|
| 1732 |
-
code_lines.extend(self._call_printer(routine))
|
| 1733 |
-
if empty:
|
| 1734 |
-
code_lines.append("\n")
|
| 1735 |
-
code_lines.extend(self._get_routine_ending(routine))
|
| 1736 |
-
|
| 1737 |
-
code_lines = self._indent_code(''.join(code_lines))
|
| 1738 |
-
|
| 1739 |
-
if code_lines:
|
| 1740 |
-
f.write(code_lines)
|
| 1741 |
-
|
| 1742 |
-
dump_m.extension = code_extension # type: ignore
|
| 1743 |
-
dump_m.__doc__ = CodeGen.dump_code.__doc__
|
| 1744 |
-
|
| 1745 |
-
# This list of dump functions is used by CodeGen.write to know which dump
|
| 1746 |
-
# functions it has to call.
|
| 1747 |
-
dump_fns = [dump_m]
|
| 1748 |
-
|
| 1749 |
-
class RustCodeGen(CodeGen):
|
| 1750 |
-
"""Generator for Rust code.
|
| 1751 |
-
|
| 1752 |
-
The .write() method inherited from CodeGen will output a code file
|
| 1753 |
-
<prefix>.rs
|
| 1754 |
-
|
| 1755 |
-
"""
|
| 1756 |
-
|
| 1757 |
-
code_extension = "rs"
|
| 1758 |
-
|
| 1759 |
-
def __init__(self, project="project", printer=None):
|
| 1760 |
-
super().__init__(project=project)
|
| 1761 |
-
self.printer = printer or RustCodePrinter()
|
| 1762 |
-
|
| 1763 |
-
def routine(self, name, expr, argument_sequence, global_vars):
|
| 1764 |
-
"""Specialized Routine creation for Rust."""
|
| 1765 |
-
|
| 1766 |
-
if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
|
| 1767 |
-
if not expr:
|
| 1768 |
-
raise ValueError("No expression given")
|
| 1769 |
-
expressions = Tuple(*expr)
|
| 1770 |
-
else:
|
| 1771 |
-
expressions = Tuple(expr)
|
| 1772 |
-
|
| 1773 |
-
# local variables
|
| 1774 |
-
local_vars = {i.label for i in expressions.atoms(Idx)}
|
| 1775 |
-
|
| 1776 |
-
# global variables
|
| 1777 |
-
global_vars = set() if global_vars is None else set(global_vars)
|
| 1778 |
-
|
| 1779 |
-
# symbols that should be arguments
|
| 1780 |
-
symbols = expressions.free_symbols - local_vars - global_vars - expressions.atoms(Indexed)
|
| 1781 |
-
|
| 1782 |
-
# Rust supports multiple return values
|
| 1783 |
-
return_vals = []
|
| 1784 |
-
output_args = []
|
| 1785 |
-
for (i, expr) in enumerate(expressions):
|
| 1786 |
-
if isinstance(expr, Equality):
|
| 1787 |
-
out_arg = expr.lhs
|
| 1788 |
-
expr = expr.rhs
|
| 1789 |
-
symbol = out_arg
|
| 1790 |
-
if isinstance(out_arg, Indexed):
|
| 1791 |
-
dims = tuple([ (S.One, dim) for dim in out_arg.shape])
|
| 1792 |
-
symbol = out_arg.base.label
|
| 1793 |
-
output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims))
|
| 1794 |
-
if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)):
|
| 1795 |
-
raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
|
| 1796 |
-
"can define output arguments.")
|
| 1797 |
-
|
| 1798 |
-
return_vals.append(Result(expr, name=symbol, result_var=out_arg))
|
| 1799 |
-
if not expr.has(symbol):
|
| 1800 |
-
# this is a pure output: remove from the symbols list, so
|
| 1801 |
-
# it doesn't become an input.
|
| 1802 |
-
symbols.remove(symbol)
|
| 1803 |
-
|
| 1804 |
-
else:
|
| 1805 |
-
# we have no name for this output
|
| 1806 |
-
return_vals.append(Result(expr, name='out%d' % (i+1)))
|
| 1807 |
-
|
| 1808 |
-
# setup input argument list
|
| 1809 |
-
output_args.sort(key=lambda x: str(x.name))
|
| 1810 |
-
arg_list = list(output_args)
|
| 1811 |
-
array_symbols = {}
|
| 1812 |
-
for array in expressions.atoms(Indexed):
|
| 1813 |
-
array_symbols[array.base.label] = array
|
| 1814 |
-
for array in expressions.atoms(MatrixSymbol):
|
| 1815 |
-
array_symbols[array] = array
|
| 1816 |
-
|
| 1817 |
-
for symbol in sorted(symbols, key=str):
|
| 1818 |
-
arg_list.append(InputArgument(symbol))
|
| 1819 |
-
|
| 1820 |
-
if argument_sequence is not None:
|
| 1821 |
-
# if the user has supplied IndexedBase instances, we'll accept that
|
| 1822 |
-
new_sequence = []
|
| 1823 |
-
for arg in argument_sequence:
|
| 1824 |
-
if isinstance(arg, IndexedBase):
|
| 1825 |
-
new_sequence.append(arg.label)
|
| 1826 |
-
else:
|
| 1827 |
-
new_sequence.append(arg)
|
| 1828 |
-
argument_sequence = new_sequence
|
| 1829 |
-
|
| 1830 |
-
missing = [x for x in arg_list if x.name not in argument_sequence]
|
| 1831 |
-
if missing:
|
| 1832 |
-
msg = "Argument list didn't specify: {0} "
|
| 1833 |
-
msg = msg.format(", ".join([str(m.name) for m in missing]))
|
| 1834 |
-
raise CodeGenArgumentListError(msg, missing)
|
| 1835 |
-
|
| 1836 |
-
# create redundant arguments to produce the requested sequence
|
| 1837 |
-
name_arg_dict = {x.name: x for x in arg_list}
|
| 1838 |
-
new_args = []
|
| 1839 |
-
for symbol in argument_sequence:
|
| 1840 |
-
try:
|
| 1841 |
-
new_args.append(name_arg_dict[symbol])
|
| 1842 |
-
except KeyError:
|
| 1843 |
-
new_args.append(InputArgument(symbol))
|
| 1844 |
-
arg_list = new_args
|
| 1845 |
-
|
| 1846 |
-
return Routine(name, arg_list, return_vals, local_vars, global_vars)
|
| 1847 |
-
|
| 1848 |
-
|
| 1849 |
-
def _get_header(self):
|
| 1850 |
-
"""Writes a common header for the generated files."""
|
| 1851 |
-
code_lines = []
|
| 1852 |
-
code_lines.append("/*\n")
|
| 1853 |
-
tmp = header_comment % {"version": sympy_version,
|
| 1854 |
-
"project": self.project}
|
| 1855 |
-
for line in tmp.splitlines():
|
| 1856 |
-
code_lines.append((" *%s" % line.center(76)).rstrip() + "\n")
|
| 1857 |
-
code_lines.append(" */\n")
|
| 1858 |
-
return code_lines
|
| 1859 |
-
|
| 1860 |
-
def get_prototype(self, routine):
|
| 1861 |
-
"""Returns a string for the function prototype of the routine.
|
| 1862 |
-
|
| 1863 |
-
If the routine has multiple result objects, an CodeGenError is
|
| 1864 |
-
raised.
|
| 1865 |
-
|
| 1866 |
-
See: https://en.wikipedia.org/wiki/Function_prototype
|
| 1867 |
-
|
| 1868 |
-
"""
|
| 1869 |
-
results = [i.get_datatype('Rust') for i in routine.results]
|
| 1870 |
-
|
| 1871 |
-
if len(results) == 1:
|
| 1872 |
-
rstype = " -> " + results[0]
|
| 1873 |
-
elif len(routine.results) > 1:
|
| 1874 |
-
rstype = " -> (" + ", ".join(results) + ")"
|
| 1875 |
-
else:
|
| 1876 |
-
rstype = ""
|
| 1877 |
-
|
| 1878 |
-
type_args = []
|
| 1879 |
-
for arg in routine.arguments:
|
| 1880 |
-
name = self.printer.doprint(arg.name)
|
| 1881 |
-
if arg.dimensions or isinstance(arg, ResultBase):
|
| 1882 |
-
type_args.append(("*%s" % name, arg.get_datatype('Rust')))
|
| 1883 |
-
else:
|
| 1884 |
-
type_args.append((name, arg.get_datatype('Rust')))
|
| 1885 |
-
arguments = ", ".join([ "%s: %s" % t for t in type_args])
|
| 1886 |
-
return "fn %s(%s)%s" % (routine.name, arguments, rstype)
|
| 1887 |
-
|
| 1888 |
-
def _preprocessor_statements(self, prefix):
|
| 1889 |
-
code_lines = []
|
| 1890 |
-
# code_lines.append("use std::f64::consts::*;\n")
|
| 1891 |
-
return code_lines
|
| 1892 |
-
|
| 1893 |
-
def _get_routine_opening(self, routine):
|
| 1894 |
-
prototype = self.get_prototype(routine)
|
| 1895 |
-
return ["%s {\n" % prototype]
|
| 1896 |
-
|
| 1897 |
-
def _declare_arguments(self, routine):
|
| 1898 |
-
# arguments are declared in prototype
|
| 1899 |
-
return []
|
| 1900 |
-
|
| 1901 |
-
def _declare_globals(self, routine):
|
| 1902 |
-
# global variables are not explicitly declared within C functions
|
| 1903 |
-
return []
|
| 1904 |
-
|
| 1905 |
-
def _declare_locals(self, routine):
|
| 1906 |
-
# loop variables are declared in loop statement
|
| 1907 |
-
return []
|
| 1908 |
-
|
| 1909 |
-
def _call_printer(self, routine):
|
| 1910 |
-
|
| 1911 |
-
code_lines = []
|
| 1912 |
-
declarations = []
|
| 1913 |
-
returns = []
|
| 1914 |
-
|
| 1915 |
-
# Compose a list of symbols to be dereferenced in the function
|
| 1916 |
-
# body. These are the arguments that were passed by a reference
|
| 1917 |
-
# pointer, excluding arrays.
|
| 1918 |
-
dereference = []
|
| 1919 |
-
for arg in routine.arguments:
|
| 1920 |
-
if isinstance(arg, ResultBase) and not arg.dimensions:
|
| 1921 |
-
dereference.append(arg.name)
|
| 1922 |
-
|
| 1923 |
-
for result in routine.results:
|
| 1924 |
-
if isinstance(result, Result):
|
| 1925 |
-
assign_to = result.result_var
|
| 1926 |
-
returns.append(str(result.result_var))
|
| 1927 |
-
else:
|
| 1928 |
-
raise CodeGenError("unexpected object in Routine results")
|
| 1929 |
-
|
| 1930 |
-
constants, not_supported, rs_expr = self._printer_method_with_settings(
|
| 1931 |
-
'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to)
|
| 1932 |
-
|
| 1933 |
-
for name, value in sorted(constants, key=str):
|
| 1934 |
-
declarations.append("const %s: f64 = %s;\n" % (name, value))
|
| 1935 |
-
|
| 1936 |
-
for obj in sorted(not_supported, key=str):
|
| 1937 |
-
if isinstance(obj, Function):
|
| 1938 |
-
name = obj.func
|
| 1939 |
-
else:
|
| 1940 |
-
name = obj
|
| 1941 |
-
declarations.append("// unsupported: %s\n" % (name))
|
| 1942 |
-
|
| 1943 |
-
code_lines.append("let %s\n" % rs_expr)
|
| 1944 |
-
|
| 1945 |
-
if len(returns) > 1:
|
| 1946 |
-
returns = ['(' + ', '.join(returns) + ')']
|
| 1947 |
-
|
| 1948 |
-
returns.append('\n')
|
| 1949 |
-
|
| 1950 |
-
return declarations + code_lines + returns
|
| 1951 |
-
|
| 1952 |
-
def _get_routine_ending(self, routine):
|
| 1953 |
-
return ["}\n"]
|
| 1954 |
-
|
| 1955 |
-
def dump_rs(self, routines, f, prefix, header=True, empty=True):
|
| 1956 |
-
self.dump_code(routines, f, prefix, header, empty)
|
| 1957 |
-
|
| 1958 |
-
dump_rs.extension = code_extension # type: ignore
|
| 1959 |
-
dump_rs.__doc__ = CodeGen.dump_code.__doc__
|
| 1960 |
-
|
| 1961 |
-
# This list of dump functions is used by CodeGen.write to know which dump
|
| 1962 |
-
# functions it has to call.
|
| 1963 |
-
dump_fns = [dump_rs]
|
| 1964 |
-
|
| 1965 |
-
|
| 1966 |
-
|
| 1967 |
-
|
| 1968 |
-
def get_code_generator(language, project=None, standard=None, printer = None):
|
| 1969 |
-
if language == 'C':
|
| 1970 |
-
if standard is None:
|
| 1971 |
-
pass
|
| 1972 |
-
elif standard.lower() == 'c89':
|
| 1973 |
-
language = 'C89'
|
| 1974 |
-
elif standard.lower() == 'c99':
|
| 1975 |
-
language = 'C99'
|
| 1976 |
-
CodeGenClass = {"C": CCodeGen, "C89": C89CodeGen, "C99": C99CodeGen,
|
| 1977 |
-
"F95": FCodeGen, "JULIA": JuliaCodeGen,
|
| 1978 |
-
"OCTAVE": OctaveCodeGen,
|
| 1979 |
-
"RUST": RustCodeGen}.get(language.upper())
|
| 1980 |
-
if CodeGenClass is None:
|
| 1981 |
-
raise ValueError("Language '%s' is not supported." % language)
|
| 1982 |
-
return CodeGenClass(project, printer)
|
| 1983 |
-
|
| 1984 |
-
|
| 1985 |
-
#
|
| 1986 |
-
# Friendly functions
|
| 1987 |
-
#
|
| 1988 |
-
|
| 1989 |
-
|
| 1990 |
-
def codegen(name_expr, language=None, prefix=None, project="project",
|
| 1991 |
-
to_files=False, header=True, empty=True, argument_sequence=None,
|
| 1992 |
-
global_vars=None, standard=None, code_gen=None, printer=None):
|
| 1993 |
-
"""Generate source code for expressions in a given language.
|
| 1994 |
-
|
| 1995 |
-
Parameters
|
| 1996 |
-
==========
|
| 1997 |
-
|
| 1998 |
-
name_expr : tuple, or list of tuples
|
| 1999 |
-
A single (name, expression) tuple or a list of (name, expression)
|
| 2000 |
-
tuples. Each tuple corresponds to a routine. If the expression is
|
| 2001 |
-
an equality (an instance of class Equality) the left hand side is
|
| 2002 |
-
considered an output argument. If expression is an iterable, then
|
| 2003 |
-
the routine will have multiple outputs.
|
| 2004 |
-
|
| 2005 |
-
language : string,
|
| 2006 |
-
A string that indicates the source code language. This is case
|
| 2007 |
-
insensitive. Currently, 'C', 'F95' and 'Octave' are supported.
|
| 2008 |
-
'Octave' generates code compatible with both Octave and Matlab.
|
| 2009 |
-
|
| 2010 |
-
prefix : string, optional
|
| 2011 |
-
A prefix for the names of the files that contain the source code.
|
| 2012 |
-
Language-dependent suffixes will be appended. If omitted, the name
|
| 2013 |
-
of the first name_expr tuple is used.
|
| 2014 |
-
|
| 2015 |
-
project : string, optional
|
| 2016 |
-
A project name, used for making unique preprocessor instructions.
|
| 2017 |
-
[default: "project"]
|
| 2018 |
-
|
| 2019 |
-
to_files : bool, optional
|
| 2020 |
-
When True, the code will be written to one or more files with the
|
| 2021 |
-
given prefix, otherwise strings with the names and contents of
|
| 2022 |
-
these files are returned. [default: False]
|
| 2023 |
-
|
| 2024 |
-
header : bool, optional
|
| 2025 |
-
When True, a header is written on top of each source file.
|
| 2026 |
-
[default: True]
|
| 2027 |
-
|
| 2028 |
-
empty : bool, optional
|
| 2029 |
-
When True, empty lines are used to structure the code.
|
| 2030 |
-
[default: True]
|
| 2031 |
-
|
| 2032 |
-
argument_sequence : iterable, optional
|
| 2033 |
-
Sequence of arguments for the routine in a preferred order. A
|
| 2034 |
-
CodeGenError is raised if required arguments are missing.
|
| 2035 |
-
Redundant arguments are used without warning. If omitted,
|
| 2036 |
-
arguments will be ordered alphabetically, but with all input
|
| 2037 |
-
arguments first, and then output or in-out arguments.
|
| 2038 |
-
|
| 2039 |
-
global_vars : iterable, optional
|
| 2040 |
-
Sequence of global variables used by the routine. Variables
|
| 2041 |
-
listed here will not show up as function arguments.
|
| 2042 |
-
|
| 2043 |
-
standard : string, optional
|
| 2044 |
-
|
| 2045 |
-
code_gen : CodeGen instance, optional
|
| 2046 |
-
An instance of a CodeGen subclass. Overrides ``language``.
|
| 2047 |
-
|
| 2048 |
-
printer : Printer instance, optional
|
| 2049 |
-
An instance of a Printer subclass.
|
| 2050 |
-
|
| 2051 |
-
Examples
|
| 2052 |
-
========
|
| 2053 |
-
|
| 2054 |
-
>>> from sympy.utilities.codegen import codegen
|
| 2055 |
-
>>> from sympy.abc import x, y, z
|
| 2056 |
-
>>> [(c_name, c_code), (h_name, c_header)] = codegen(
|
| 2057 |
-
... ("f", x+y*z), "C89", "test", header=False, empty=False)
|
| 2058 |
-
>>> print(c_name)
|
| 2059 |
-
test.c
|
| 2060 |
-
>>> print(c_code)
|
| 2061 |
-
#include "test.h"
|
| 2062 |
-
#include <math.h>
|
| 2063 |
-
double f(double x, double y, double z) {
|
| 2064 |
-
double f_result;
|
| 2065 |
-
f_result = x + y*z;
|
| 2066 |
-
return f_result;
|
| 2067 |
-
}
|
| 2068 |
-
<BLANKLINE>
|
| 2069 |
-
>>> print(h_name)
|
| 2070 |
-
test.h
|
| 2071 |
-
>>> print(c_header)
|
| 2072 |
-
#ifndef PROJECT__TEST__H
|
| 2073 |
-
#define PROJECT__TEST__H
|
| 2074 |
-
double f(double x, double y, double z);
|
| 2075 |
-
#endif
|
| 2076 |
-
<BLANKLINE>
|
| 2077 |
-
|
| 2078 |
-
Another example using Equality objects to give named outputs. Here the
|
| 2079 |
-
filename (prefix) is taken from the first (name, expr) pair.
|
| 2080 |
-
|
| 2081 |
-
>>> from sympy.abc import f, g
|
| 2082 |
-
>>> from sympy import Eq
|
| 2083 |
-
>>> [(c_name, c_code), (h_name, c_header)] = codegen(
|
| 2084 |
-
... [("myfcn", x + y), ("fcn2", [Eq(f, 2*x), Eq(g, y)])],
|
| 2085 |
-
... "C99", header=False, empty=False)
|
| 2086 |
-
>>> print(c_name)
|
| 2087 |
-
myfcn.c
|
| 2088 |
-
>>> print(c_code)
|
| 2089 |
-
#include "myfcn.h"
|
| 2090 |
-
#include <math.h>
|
| 2091 |
-
double myfcn(double x, double y) {
|
| 2092 |
-
double myfcn_result;
|
| 2093 |
-
myfcn_result = x + y;
|
| 2094 |
-
return myfcn_result;
|
| 2095 |
-
}
|
| 2096 |
-
void fcn2(double x, double y, double *f, double *g) {
|
| 2097 |
-
(*f) = 2*x;
|
| 2098 |
-
(*g) = y;
|
| 2099 |
-
}
|
| 2100 |
-
<BLANKLINE>
|
| 2101 |
-
|
| 2102 |
-
If the generated function(s) will be part of a larger project where various
|
| 2103 |
-
global variables have been defined, the 'global_vars' option can be used
|
| 2104 |
-
to remove the specified variables from the function signature
|
| 2105 |
-
|
| 2106 |
-
>>> from sympy.utilities.codegen import codegen
|
| 2107 |
-
>>> from sympy.abc import x, y, z
|
| 2108 |
-
>>> [(f_name, f_code), header] = codegen(
|
| 2109 |
-
... ("f", x+y*z), "F95", header=False, empty=False,
|
| 2110 |
-
... argument_sequence=(x, y), global_vars=(z,))
|
| 2111 |
-
>>> print(f_code)
|
| 2112 |
-
REAL*8 function f(x, y)
|
| 2113 |
-
implicit none
|
| 2114 |
-
REAL*8, intent(in) :: x
|
| 2115 |
-
REAL*8, intent(in) :: y
|
| 2116 |
-
f = x + y*z
|
| 2117 |
-
end function
|
| 2118 |
-
<BLANKLINE>
|
| 2119 |
-
|
| 2120 |
-
"""
|
| 2121 |
-
|
| 2122 |
-
# Initialize the code generator.
|
| 2123 |
-
if language is None:
|
| 2124 |
-
if code_gen is None:
|
| 2125 |
-
raise ValueError("Need either language or code_gen")
|
| 2126 |
-
else:
|
| 2127 |
-
if code_gen is not None:
|
| 2128 |
-
raise ValueError("You cannot specify both language and code_gen.")
|
| 2129 |
-
code_gen = get_code_generator(language, project, standard, printer)
|
| 2130 |
-
|
| 2131 |
-
if isinstance(name_expr[0], str):
|
| 2132 |
-
# single tuple is given, turn it into a singleton list with a tuple.
|
| 2133 |
-
name_expr = [name_expr]
|
| 2134 |
-
|
| 2135 |
-
if prefix is None:
|
| 2136 |
-
prefix = name_expr[0][0]
|
| 2137 |
-
|
| 2138 |
-
# Construct Routines appropriate for this code_gen from (name, expr) pairs.
|
| 2139 |
-
routines = []
|
| 2140 |
-
for name, expr in name_expr:
|
| 2141 |
-
routines.append(code_gen.routine(name, expr, argument_sequence,
|
| 2142 |
-
global_vars))
|
| 2143 |
-
|
| 2144 |
-
# Write the code.
|
| 2145 |
-
return code_gen.write(routines, prefix, to_files, header, empty)
|
| 2146 |
-
|
| 2147 |
-
|
| 2148 |
-
def make_routine(name, expr, argument_sequence=None,
|
| 2149 |
-
global_vars=None, language="F95"):
|
| 2150 |
-
"""A factory that makes an appropriate Routine from an expression.
|
| 2151 |
-
|
| 2152 |
-
Parameters
|
| 2153 |
-
==========
|
| 2154 |
-
|
| 2155 |
-
name : string
|
| 2156 |
-
The name of this routine in the generated code.
|
| 2157 |
-
|
| 2158 |
-
expr : expression or list/tuple of expressions
|
| 2159 |
-
A SymPy expression that the Routine instance will represent. If
|
| 2160 |
-
given a list or tuple of expressions, the routine will be
|
| 2161 |
-
considered to have multiple return values and/or output arguments.
|
| 2162 |
-
|
| 2163 |
-
argument_sequence : list or tuple, optional
|
| 2164 |
-
List arguments for the routine in a preferred order. If omitted,
|
| 2165 |
-
the results are language dependent, for example, alphabetical order
|
| 2166 |
-
or in the same order as the given expressions.
|
| 2167 |
-
|
| 2168 |
-
global_vars : iterable, optional
|
| 2169 |
-
Sequence of global variables used by the routine. Variables
|
| 2170 |
-
listed here will not show up as function arguments.
|
| 2171 |
-
|
| 2172 |
-
language : string, optional
|
| 2173 |
-
Specify a target language. The Routine itself should be
|
| 2174 |
-
language-agnostic but the precise way one is created, error
|
| 2175 |
-
checking, etc depend on the language. [default: "F95"].
|
| 2176 |
-
|
| 2177 |
-
Notes
|
| 2178 |
-
=====
|
| 2179 |
-
|
| 2180 |
-
A decision about whether to use output arguments or return values is made
|
| 2181 |
-
depending on both the language and the particular mathematical expressions.
|
| 2182 |
-
For an expression of type Equality, the left hand side is typically made
|
| 2183 |
-
into an OutputArgument (or perhaps an InOutArgument if appropriate).
|
| 2184 |
-
Otherwise, typically, the calculated expression is made a return values of
|
| 2185 |
-
the routine.
|
| 2186 |
-
|
| 2187 |
-
Examples
|
| 2188 |
-
========
|
| 2189 |
-
|
| 2190 |
-
>>> from sympy.utilities.codegen import make_routine
|
| 2191 |
-
>>> from sympy.abc import x, y, f, g
|
| 2192 |
-
>>> from sympy import Eq
|
| 2193 |
-
>>> r = make_routine('test', [Eq(f, 2*x), Eq(g, x + y)])
|
| 2194 |
-
>>> [arg.result_var for arg in r.results]
|
| 2195 |
-
[]
|
| 2196 |
-
>>> [arg.name for arg in r.arguments]
|
| 2197 |
-
[x, y, f, g]
|
| 2198 |
-
>>> [arg.name for arg in r.result_variables]
|
| 2199 |
-
[f, g]
|
| 2200 |
-
>>> r.local_vars
|
| 2201 |
-
set()
|
| 2202 |
-
|
| 2203 |
-
Another more complicated example with a mixture of specified and
|
| 2204 |
-
automatically-assigned names. Also has Matrix output.
|
| 2205 |
-
|
| 2206 |
-
>>> from sympy import Matrix
|
| 2207 |
-
>>> r = make_routine('fcn', [x*y, Eq(f, 1), Eq(g, x + g), Matrix([[x, 2]])])
|
| 2208 |
-
>>> [arg.result_var for arg in r.results] # doctest: +SKIP
|
| 2209 |
-
[result_5397460570204848505]
|
| 2210 |
-
>>> [arg.expr for arg in r.results]
|
| 2211 |
-
[x*y]
|
| 2212 |
-
>>> [arg.name for arg in r.arguments] # doctest: +SKIP
|
| 2213 |
-
[x, y, f, g, out_8598435338387848786]
|
| 2214 |
-
|
| 2215 |
-
We can examine the various arguments more closely:
|
| 2216 |
-
|
| 2217 |
-
>>> from sympy.utilities.codegen import (InputArgument, OutputArgument,
|
| 2218 |
-
... InOutArgument)
|
| 2219 |
-
>>> [a.name for a in r.arguments if isinstance(a, InputArgument)]
|
| 2220 |
-
[x, y]
|
| 2221 |
-
|
| 2222 |
-
>>> [a.name for a in r.arguments if isinstance(a, OutputArgument)] # doctest: +SKIP
|
| 2223 |
-
[f, out_8598435338387848786]
|
| 2224 |
-
>>> [a.expr for a in r.arguments if isinstance(a, OutputArgument)]
|
| 2225 |
-
[1, Matrix([[x, 2]])]
|
| 2226 |
-
|
| 2227 |
-
>>> [a.name for a in r.arguments if isinstance(a, InOutArgument)]
|
| 2228 |
-
[g]
|
| 2229 |
-
>>> [a.expr for a in r.arguments if isinstance(a, InOutArgument)]
|
| 2230 |
-
[g + x]
|
| 2231 |
-
|
| 2232 |
-
"""
|
| 2233 |
-
|
| 2234 |
-
# initialize a new code generator
|
| 2235 |
-
code_gen = get_code_generator(language)
|
| 2236 |
-
|
| 2237 |
-
return code_gen.routine(name, expr, argument_sequence, global_vars)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/decorator.py
DELETED
|
@@ -1,339 +0,0 @@
|
|
| 1 |
-
"""Useful utility decorators. """
|
| 2 |
-
|
| 3 |
-
from typing import TypeVar
|
| 4 |
-
import sys
|
| 5 |
-
import types
|
| 6 |
-
import inspect
|
| 7 |
-
from functools import wraps, update_wrapper
|
| 8 |
-
|
| 9 |
-
from sympy.utilities.exceptions import sympy_deprecation_warning
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
T = TypeVar('T')
|
| 13 |
-
"""A generic type"""
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def threaded_factory(func, use_add):
|
| 17 |
-
"""A factory for ``threaded`` decorators. """
|
| 18 |
-
from sympy.core import sympify
|
| 19 |
-
from sympy.matrices import MatrixBase
|
| 20 |
-
from sympy.utilities.iterables import iterable
|
| 21 |
-
|
| 22 |
-
@wraps(func)
|
| 23 |
-
def threaded_func(expr, *args, **kwargs):
|
| 24 |
-
if isinstance(expr, MatrixBase):
|
| 25 |
-
return expr.applyfunc(lambda f: func(f, *args, **kwargs))
|
| 26 |
-
elif iterable(expr):
|
| 27 |
-
try:
|
| 28 |
-
return expr.__class__([func(f, *args, **kwargs) for f in expr])
|
| 29 |
-
except TypeError:
|
| 30 |
-
return expr
|
| 31 |
-
else:
|
| 32 |
-
expr = sympify(expr)
|
| 33 |
-
|
| 34 |
-
if use_add and expr.is_Add:
|
| 35 |
-
return expr.__class__(*[ func(f, *args, **kwargs) for f in expr.args ])
|
| 36 |
-
elif expr.is_Relational:
|
| 37 |
-
return expr.__class__(func(expr.lhs, *args, **kwargs),
|
| 38 |
-
func(expr.rhs, *args, **kwargs))
|
| 39 |
-
else:
|
| 40 |
-
return func(expr, *args, **kwargs)
|
| 41 |
-
|
| 42 |
-
return threaded_func
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def threaded(func):
|
| 46 |
-
"""Apply ``func`` to sub--elements of an object, including :class:`~.Add`.
|
| 47 |
-
|
| 48 |
-
This decorator is intended to make it uniformly possible to apply a
|
| 49 |
-
function to all elements of composite objects, e.g. matrices, lists, tuples
|
| 50 |
-
and other iterable containers, or just expressions.
|
| 51 |
-
|
| 52 |
-
This version of :func:`threaded` decorator allows threading over
|
| 53 |
-
elements of :class:`~.Add` class. If this behavior is not desirable
|
| 54 |
-
use :func:`xthreaded` decorator.
|
| 55 |
-
|
| 56 |
-
Functions using this decorator must have the following signature::
|
| 57 |
-
|
| 58 |
-
@threaded
|
| 59 |
-
def function(expr, *args, **kwargs):
|
| 60 |
-
|
| 61 |
-
"""
|
| 62 |
-
return threaded_factory(func, True)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def xthreaded(func):
|
| 66 |
-
"""Apply ``func`` to sub--elements of an object, excluding :class:`~.Add`.
|
| 67 |
-
|
| 68 |
-
This decorator is intended to make it uniformly possible to apply a
|
| 69 |
-
function to all elements of composite objects, e.g. matrices, lists, tuples
|
| 70 |
-
and other iterable containers, or just expressions.
|
| 71 |
-
|
| 72 |
-
This version of :func:`threaded` decorator disallows threading over
|
| 73 |
-
elements of :class:`~.Add` class. If this behavior is not desirable
|
| 74 |
-
use :func:`threaded` decorator.
|
| 75 |
-
|
| 76 |
-
Functions using this decorator must have the following signature::
|
| 77 |
-
|
| 78 |
-
@xthreaded
|
| 79 |
-
def function(expr, *args, **kwargs):
|
| 80 |
-
|
| 81 |
-
"""
|
| 82 |
-
return threaded_factory(func, False)
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def conserve_mpmath_dps(func):
|
| 86 |
-
"""After the function finishes, resets the value of ``mpmath.mp.dps`` to
|
| 87 |
-
the value it had before the function was run."""
|
| 88 |
-
import mpmath
|
| 89 |
-
|
| 90 |
-
def func_wrapper(*args, **kwargs):
|
| 91 |
-
dps = mpmath.mp.dps
|
| 92 |
-
try:
|
| 93 |
-
return func(*args, **kwargs)
|
| 94 |
-
finally:
|
| 95 |
-
mpmath.mp.dps = dps
|
| 96 |
-
|
| 97 |
-
func_wrapper = update_wrapper(func_wrapper, func)
|
| 98 |
-
return func_wrapper
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
class no_attrs_in_subclass:
|
| 102 |
-
"""Don't 'inherit' certain attributes from a base class
|
| 103 |
-
|
| 104 |
-
>>> from sympy.utilities.decorator import no_attrs_in_subclass
|
| 105 |
-
|
| 106 |
-
>>> class A(object):
|
| 107 |
-
... x = 'test'
|
| 108 |
-
|
| 109 |
-
>>> A.x = no_attrs_in_subclass(A, A.x)
|
| 110 |
-
|
| 111 |
-
>>> class B(A):
|
| 112 |
-
... pass
|
| 113 |
-
|
| 114 |
-
>>> hasattr(A, 'x')
|
| 115 |
-
True
|
| 116 |
-
>>> hasattr(B, 'x')
|
| 117 |
-
False
|
| 118 |
-
|
| 119 |
-
"""
|
| 120 |
-
def __init__(self, cls, f):
|
| 121 |
-
self.cls = cls
|
| 122 |
-
self.f = f
|
| 123 |
-
|
| 124 |
-
def __get__(self, instance, owner=None):
|
| 125 |
-
if owner == self.cls:
|
| 126 |
-
if hasattr(self.f, '__get__'):
|
| 127 |
-
return self.f.__get__(instance, owner)
|
| 128 |
-
return self.f
|
| 129 |
-
raise AttributeError
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def doctest_depends_on(exe=None, modules=None, disable_viewers=None,
|
| 133 |
-
python_version=None, ground_types=None):
|
| 134 |
-
"""
|
| 135 |
-
Adds metadata about the dependencies which need to be met for doctesting
|
| 136 |
-
the docstrings of the decorated objects.
|
| 137 |
-
|
| 138 |
-
``exe`` should be a list of executables
|
| 139 |
-
|
| 140 |
-
``modules`` should be a list of modules
|
| 141 |
-
|
| 142 |
-
``disable_viewers`` should be a list of viewers for :func:`~sympy.printing.preview.preview` to disable
|
| 143 |
-
|
| 144 |
-
``python_version`` should be the minimum Python version required, as a tuple
|
| 145 |
-
(like ``(3, 0)``)
|
| 146 |
-
"""
|
| 147 |
-
dependencies = {}
|
| 148 |
-
if exe is not None:
|
| 149 |
-
dependencies['executables'] = exe
|
| 150 |
-
if modules is not None:
|
| 151 |
-
dependencies['modules'] = modules
|
| 152 |
-
if disable_viewers is not None:
|
| 153 |
-
dependencies['disable_viewers'] = disable_viewers
|
| 154 |
-
if python_version is not None:
|
| 155 |
-
dependencies['python_version'] = python_version
|
| 156 |
-
if ground_types is not None:
|
| 157 |
-
dependencies['ground_types'] = ground_types
|
| 158 |
-
|
| 159 |
-
def skiptests():
|
| 160 |
-
from sympy.testing.runtests import DependencyError, SymPyDocTests, PyTestReporter # lazy import
|
| 161 |
-
r = PyTestReporter()
|
| 162 |
-
t = SymPyDocTests(r, None)
|
| 163 |
-
try:
|
| 164 |
-
t._check_dependencies(**dependencies)
|
| 165 |
-
except DependencyError:
|
| 166 |
-
return True # Skip doctests
|
| 167 |
-
else:
|
| 168 |
-
return False # Run doctests
|
| 169 |
-
|
| 170 |
-
def depends_on_deco(fn):
|
| 171 |
-
fn._doctest_depends_on = dependencies
|
| 172 |
-
fn.__doctest_skip__ = skiptests
|
| 173 |
-
|
| 174 |
-
if inspect.isclass(fn):
|
| 175 |
-
fn._doctest_depdends_on = no_attrs_in_subclass(
|
| 176 |
-
fn, fn._doctest_depends_on)
|
| 177 |
-
fn.__doctest_skip__ = no_attrs_in_subclass(
|
| 178 |
-
fn, fn.__doctest_skip__)
|
| 179 |
-
return fn
|
| 180 |
-
|
| 181 |
-
return depends_on_deco
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
def public(obj: T) -> T:
|
| 185 |
-
"""
|
| 186 |
-
Append ``obj``'s name to global ``__all__`` variable (call site).
|
| 187 |
-
|
| 188 |
-
By using this decorator on functions or classes you achieve the same goal
|
| 189 |
-
as by filling ``__all__`` variables manually, you just do not have to repeat
|
| 190 |
-
yourself (object's name). You also know if object is public at definition
|
| 191 |
-
site, not at some random location (where ``__all__`` was set).
|
| 192 |
-
|
| 193 |
-
Note that in multiple decorator setup (in almost all cases) ``@public``
|
| 194 |
-
decorator must be applied before any other decorators, because it relies
|
| 195 |
-
on the pointer to object's global namespace. If you apply other decorators
|
| 196 |
-
first, ``@public`` may end up modifying the wrong namespace.
|
| 197 |
-
|
| 198 |
-
Examples
|
| 199 |
-
========
|
| 200 |
-
|
| 201 |
-
>>> from sympy.utilities.decorator import public
|
| 202 |
-
|
| 203 |
-
>>> __all__ # noqa: F821
|
| 204 |
-
Traceback (most recent call last):
|
| 205 |
-
...
|
| 206 |
-
NameError: name '__all__' is not defined
|
| 207 |
-
|
| 208 |
-
>>> @public
|
| 209 |
-
... def some_function():
|
| 210 |
-
... pass
|
| 211 |
-
|
| 212 |
-
>>> __all__ # noqa: F821
|
| 213 |
-
['some_function']
|
| 214 |
-
|
| 215 |
-
"""
|
| 216 |
-
if isinstance(obj, types.FunctionType):
|
| 217 |
-
ns = obj.__globals__
|
| 218 |
-
name = obj.__name__
|
| 219 |
-
elif isinstance(obj, (type(type), type)):
|
| 220 |
-
ns = sys.modules[obj.__module__].__dict__
|
| 221 |
-
name = obj.__name__
|
| 222 |
-
else:
|
| 223 |
-
raise TypeError("expected a function or a class, got %s" % obj)
|
| 224 |
-
|
| 225 |
-
if "__all__" not in ns:
|
| 226 |
-
ns["__all__"] = [name]
|
| 227 |
-
else:
|
| 228 |
-
ns["__all__"].append(name)
|
| 229 |
-
|
| 230 |
-
return obj
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
def memoize_property(propfunc):
|
| 234 |
-
"""Property decorator that caches the value of potentially expensive
|
| 235 |
-
``propfunc`` after the first evaluation. The cached value is stored in
|
| 236 |
-
the corresponding property name with an attached underscore."""
|
| 237 |
-
attrname = '_' + propfunc.__name__
|
| 238 |
-
sentinel = object()
|
| 239 |
-
|
| 240 |
-
@wraps(propfunc)
|
| 241 |
-
def accessor(self):
|
| 242 |
-
val = getattr(self, attrname, sentinel)
|
| 243 |
-
if val is sentinel:
|
| 244 |
-
val = propfunc(self)
|
| 245 |
-
setattr(self, attrname, val)
|
| 246 |
-
return val
|
| 247 |
-
|
| 248 |
-
return property(accessor)
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
def deprecated(message, *, deprecated_since_version,
|
| 252 |
-
active_deprecations_target, stacklevel=3):
|
| 253 |
-
'''
|
| 254 |
-
Mark a function as deprecated.
|
| 255 |
-
|
| 256 |
-
This decorator should be used if an entire function or class is
|
| 257 |
-
deprecated. If only a certain functionality is deprecated, you should use
|
| 258 |
-
:func:`~.warns_deprecated_sympy` directly. This decorator is just a
|
| 259 |
-
convenience. There is no functional difference between using this
|
| 260 |
-
decorator and calling ``warns_deprecated_sympy()`` at the top of the
|
| 261 |
-
function.
|
| 262 |
-
|
| 263 |
-
The decorator takes the same arguments as
|
| 264 |
-
:func:`~.warns_deprecated_sympy`. See its
|
| 265 |
-
documentation for details on what the keywords to this decorator do.
|
| 266 |
-
|
| 267 |
-
See the :ref:`deprecation-policy` document for details on when and how
|
| 268 |
-
things should be deprecated in SymPy.
|
| 269 |
-
|
| 270 |
-
Examples
|
| 271 |
-
========
|
| 272 |
-
|
| 273 |
-
>>> from sympy.utilities.decorator import deprecated
|
| 274 |
-
>>> from sympy import simplify
|
| 275 |
-
>>> @deprecated("""\
|
| 276 |
-
... The simplify_this(expr) function is deprecated. Use simplify(expr)
|
| 277 |
-
... instead.""", deprecated_since_version="1.1",
|
| 278 |
-
... active_deprecations_target='simplify-this-deprecation')
|
| 279 |
-
... def simplify_this(expr):
|
| 280 |
-
... """
|
| 281 |
-
... Simplify ``expr``.
|
| 282 |
-
...
|
| 283 |
-
... .. deprecated:: 1.1
|
| 284 |
-
...
|
| 285 |
-
... The ``simplify_this`` function is deprecated. Use :func:`simplify`
|
| 286 |
-
... instead. See its documentation for more information. See
|
| 287 |
-
... :ref:`simplify-this-deprecation` for details.
|
| 288 |
-
...
|
| 289 |
-
... """
|
| 290 |
-
... return simplify(expr)
|
| 291 |
-
>>> from sympy.abc import x
|
| 292 |
-
>>> simplify_this(x*(x + 1) - x**2) # doctest: +SKIP
|
| 293 |
-
<stdin>:1: SymPyDeprecationWarning:
|
| 294 |
-
<BLANKLINE>
|
| 295 |
-
The simplify_this(expr) function is deprecated. Use simplify(expr)
|
| 296 |
-
instead.
|
| 297 |
-
<BLANKLINE>
|
| 298 |
-
See https://docs.sympy.org/latest/explanation/active-deprecations.html#simplify-this-deprecation
|
| 299 |
-
for details.
|
| 300 |
-
<BLANKLINE>
|
| 301 |
-
This has been deprecated since SymPy version 1.1. It
|
| 302 |
-
will be removed in a future version of SymPy.
|
| 303 |
-
<BLANKLINE>
|
| 304 |
-
simplify_this(x)
|
| 305 |
-
x
|
| 306 |
-
|
| 307 |
-
See Also
|
| 308 |
-
========
|
| 309 |
-
sympy.utilities.exceptions.SymPyDeprecationWarning
|
| 310 |
-
sympy.utilities.exceptions.sympy_deprecation_warning
|
| 311 |
-
sympy.utilities.exceptions.ignore_warnings
|
| 312 |
-
sympy.testing.pytest.warns_deprecated_sympy
|
| 313 |
-
|
| 314 |
-
'''
|
| 315 |
-
decorator_kwargs = {"deprecated_since_version": deprecated_since_version,
|
| 316 |
-
"active_deprecations_target": active_deprecations_target}
|
| 317 |
-
def deprecated_decorator(wrapped):
|
| 318 |
-
if hasattr(wrapped, '__mro__'): # wrapped is actually a class
|
| 319 |
-
class wrapper(wrapped):
|
| 320 |
-
__doc__ = wrapped.__doc__
|
| 321 |
-
__module__ = wrapped.__module__
|
| 322 |
-
_sympy_deprecated_func = wrapped
|
| 323 |
-
if '__new__' in wrapped.__dict__:
|
| 324 |
-
def __new__(cls, *args, **kwargs):
|
| 325 |
-
sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel)
|
| 326 |
-
return super().__new__(cls, *args, **kwargs)
|
| 327 |
-
else:
|
| 328 |
-
def __init__(self, *args, **kwargs):
|
| 329 |
-
sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel)
|
| 330 |
-
super().__init__(*args, **kwargs)
|
| 331 |
-
wrapper.__name__ = wrapped.__name__
|
| 332 |
-
else:
|
| 333 |
-
@wraps(wrapped)
|
| 334 |
-
def wrapper(*args, **kwargs):
|
| 335 |
-
sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel)
|
| 336 |
-
return wrapped(*args, **kwargs)
|
| 337 |
-
wrapper._sympy_deprecated_func = wrapped
|
| 338 |
-
return wrapper
|
| 339 |
-
return deprecated_decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/enumerative.py
DELETED
|
@@ -1,1155 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Algorithms and classes to support enumerative combinatorics.
|
| 3 |
-
|
| 4 |
-
Currently just multiset partitions, but more could be added.
|
| 5 |
-
|
| 6 |
-
Terminology (following Knuth, algorithm 7.1.2.5M TAOCP)
|
| 7 |
-
*multiset* aaabbcccc has a *partition* aaabc | bccc
|
| 8 |
-
|
| 9 |
-
The submultisets, aaabc and bccc of the partition are called
|
| 10 |
-
*parts*, or sometimes *vectors*. (Knuth notes that multiset
|
| 11 |
-
partitions can be thought of as partitions of vectors of integers,
|
| 12 |
-
where the ith element of the vector gives the multiplicity of
|
| 13 |
-
element i.)
|
| 14 |
-
|
| 15 |
-
The values a, b and c are *components* of the multiset. These
|
| 16 |
-
correspond to elements of a set, but in a multiset can be present
|
| 17 |
-
with a multiplicity greater than 1.
|
| 18 |
-
|
| 19 |
-
The algorithm deserves some explanation.
|
| 20 |
-
|
| 21 |
-
Think of the part aaabc from the multiset above. If we impose an
|
| 22 |
-
ordering on the components of the multiset, we can represent a part
|
| 23 |
-
with a vector, in which the value of the first element of the vector
|
| 24 |
-
corresponds to the multiplicity of the first component in that
|
| 25 |
-
part. Thus, aaabc can be represented by the vector [3, 1, 1]. We
|
| 26 |
-
can also define an ordering on parts, based on the lexicographic
|
| 27 |
-
ordering of the vector (leftmost vector element, i.e., the element
|
| 28 |
-
with the smallest component number, is the most significant), so
|
| 29 |
-
that [3, 1, 1] > [3, 1, 0] and [3, 1, 1] > [2, 1, 4]. The ordering
|
| 30 |
-
on parts can be extended to an ordering on partitions: First, sort
|
| 31 |
-
the parts in each partition, left-to-right in decreasing order. Then
|
| 32 |
-
partition A is greater than partition B if A's leftmost/greatest
|
| 33 |
-
part is greater than B's leftmost part. If the leftmost parts are
|
| 34 |
-
equal, compare the second parts, and so on.
|
| 35 |
-
|
| 36 |
-
In this ordering, the greatest partition of a given multiset has only
|
| 37 |
-
one part. The least partition is the one in which the components
|
| 38 |
-
are spread out, one per part.
|
| 39 |
-
|
| 40 |
-
The enumeration algorithms in this file yield the partitions of the
|
| 41 |
-
argument multiset in decreasing order. The main data structure is a
|
| 42 |
-
stack of parts, corresponding to the current partition. An
|
| 43 |
-
important invariant is that the parts on the stack are themselves in
|
| 44 |
-
decreasing order. This data structure is decremented to find the
|
| 45 |
-
next smaller partition. Most often, decrementing the partition will
|
| 46 |
-
only involve adjustments to the smallest parts at the top of the
|
| 47 |
-
stack, much as adjacent integers *usually* differ only in their last
|
| 48 |
-
few digits.
|
| 49 |
-
|
| 50 |
-
Knuth's algorithm uses two main operations on parts:
|
| 51 |
-
|
| 52 |
-
Decrement - change the part so that it is smaller in the
|
| 53 |
-
(vector) lexicographic order, but reduced by the smallest amount possible.
|
| 54 |
-
For example, if the multiset has vector [5,
|
| 55 |
-
3, 1], and the bottom/greatest part is [4, 2, 1], this part would
|
| 56 |
-
decrement to [4, 2, 0], while [4, 0, 0] would decrement to [3, 3,
|
| 57 |
-
1]. A singleton part is never decremented -- [1, 0, 0] is not
|
| 58 |
-
decremented to [0, 3, 1]. Instead, the decrement operator needs
|
| 59 |
-
to fail for this case. In Knuth's pseudocode, the decrement
|
| 60 |
-
operator is step m5.
|
| 61 |
-
|
| 62 |
-
Spread unallocated multiplicity - Once a part has been decremented,
|
| 63 |
-
it cannot be the rightmost part in the partition. There is some
|
| 64 |
-
multiplicity that has not been allocated, and new parts must be
|
| 65 |
-
created above it in the stack to use up this multiplicity. To
|
| 66 |
-
maintain the invariant that the parts on the stack are in
|
| 67 |
-
decreasing order, these new parts must be less than or equal to
|
| 68 |
-
the decremented part.
|
| 69 |
-
For example, if the multiset is [5, 3, 1], and its most
|
| 70 |
-
significant part has just been decremented to [5, 3, 0], the
|
| 71 |
-
spread operation will add a new part so that the stack becomes
|
| 72 |
-
[[5, 3, 0], [0, 0, 1]]. If the most significant part (for the
|
| 73 |
-
same multiset) has been decremented to [2, 0, 0] the stack becomes
|
| 74 |
-
[[2, 0, 0], [2, 0, 0], [1, 3, 1]]. In the pseudocode, the spread
|
| 75 |
-
operation for one part is step m2. The complete spread operation
|
| 76 |
-
is a loop of steps m2 and m3.
|
| 77 |
-
|
| 78 |
-
In order to facilitate the spread operation, Knuth stores, for each
|
| 79 |
-
component of each part, not just the multiplicity of that component
|
| 80 |
-
in the part, but also the total multiplicity available for this
|
| 81 |
-
component in this part or any lesser part above it on the stack.
|
| 82 |
-
|
| 83 |
-
One added twist is that Knuth does not represent the part vectors as
|
| 84 |
-
arrays. Instead, he uses a sparse representation, in which a
|
| 85 |
-
component of a part is represented as a component number (c), plus
|
| 86 |
-
the multiplicity of the component in that part (v) as well as the
|
| 87 |
-
total multiplicity available for that component (u). This saves
|
| 88 |
-
time that would be spent skipping over zeros.
|
| 89 |
-
|
| 90 |
-
"""
|
| 91 |
-
|
| 92 |
-
class PartComponent:
|
| 93 |
-
"""Internal class used in support of the multiset partitions
|
| 94 |
-
enumerators and the associated visitor functions.
|
| 95 |
-
|
| 96 |
-
Represents one component of one part of the current partition.
|
| 97 |
-
|
| 98 |
-
A stack of these, plus an auxiliary frame array, f, represents a
|
| 99 |
-
partition of the multiset.
|
| 100 |
-
|
| 101 |
-
Knuth's pseudocode makes c, u, and v separate arrays.
|
| 102 |
-
"""
|
| 103 |
-
|
| 104 |
-
__slots__ = ('c', 'u', 'v')
|
| 105 |
-
|
| 106 |
-
def __init__(self):
|
| 107 |
-
self.c = 0 # Component number
|
| 108 |
-
self.u = 0 # The as yet unpartitioned amount in component c
|
| 109 |
-
# *before* it is allocated by this triple
|
| 110 |
-
self.v = 0 # Amount of c component in the current part
|
| 111 |
-
# (v<=u). An invariant of the representation is
|
| 112 |
-
# that the next higher triple for this component
|
| 113 |
-
# (if there is one) will have a value of u-v in
|
| 114 |
-
# its u attribute.
|
| 115 |
-
|
| 116 |
-
def __repr__(self):
|
| 117 |
-
"for debug/algorithm animation purposes"
|
| 118 |
-
return 'c:%d u:%d v:%d' % (self.c, self.u, self.v)
|
| 119 |
-
|
| 120 |
-
def __eq__(self, other):
|
| 121 |
-
"""Define value oriented equality, which is useful for testers"""
|
| 122 |
-
return (isinstance(other, self.__class__) and
|
| 123 |
-
self.c == other.c and
|
| 124 |
-
self.u == other.u and
|
| 125 |
-
self.v == other.v)
|
| 126 |
-
|
| 127 |
-
def __ne__(self, other):
|
| 128 |
-
"""Defined for consistency with __eq__"""
|
| 129 |
-
return not self == other
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
# This function tries to be a faithful implementation of algorithm
|
| 133 |
-
# 7.1.2.5M in Volume 4A, Combinatoral Algorithms, Part 1, of The Art
|
| 134 |
-
# of Computer Programming, by Donald Knuth. This includes using
|
| 135 |
-
# (mostly) the same variable names, etc. This makes for rather
|
| 136 |
-
# low-level Python.
|
| 137 |
-
|
| 138 |
-
# Changes from Knuth's pseudocode include
|
| 139 |
-
# - use PartComponent struct/object instead of 3 arrays
|
| 140 |
-
# - make the function a generator
|
| 141 |
-
# - map (with some difficulty) the GOTOs to Python control structures.
|
| 142 |
-
# - Knuth uses 1-based numbering for components, this code is 0-based
|
| 143 |
-
# - renamed variable l to lpart.
|
| 144 |
-
# - flag variable x takes on values True/False instead of 1/0
|
| 145 |
-
#
|
| 146 |
-
def multiset_partitions_taocp(multiplicities):
|
| 147 |
-
"""Enumerates partitions of a multiset.
|
| 148 |
-
|
| 149 |
-
Parameters
|
| 150 |
-
==========
|
| 151 |
-
|
| 152 |
-
multiplicities
|
| 153 |
-
list of integer multiplicities of the components of the multiset.
|
| 154 |
-
|
| 155 |
-
Yields
|
| 156 |
-
======
|
| 157 |
-
|
| 158 |
-
state
|
| 159 |
-
Internal data structure which encodes a particular partition.
|
| 160 |
-
This output is then usually processed by a visitor function
|
| 161 |
-
which combines the information from this data structure with
|
| 162 |
-
the components themselves to produce an actual partition.
|
| 163 |
-
|
| 164 |
-
Unless they wish to create their own visitor function, users will
|
| 165 |
-
have little need to look inside this data structure. But, for
|
| 166 |
-
reference, it is a 3-element list with components:
|
| 167 |
-
|
| 168 |
-
f
|
| 169 |
-
is a frame array, which is used to divide pstack into parts.
|
| 170 |
-
|
| 171 |
-
lpart
|
| 172 |
-
points to the base of the topmost part.
|
| 173 |
-
|
| 174 |
-
pstack
|
| 175 |
-
is an array of PartComponent objects.
|
| 176 |
-
|
| 177 |
-
The ``state`` output offers a peek into the internal data
|
| 178 |
-
structures of the enumeration function. The client should
|
| 179 |
-
treat this as read-only; any modification of the data
|
| 180 |
-
structure will cause unpredictable (and almost certainly
|
| 181 |
-
incorrect) results. Also, the components of ``state`` are
|
| 182 |
-
modified in place at each iteration. Hence, the visitor must
|
| 183 |
-
be called at each loop iteration. Accumulating the ``state``
|
| 184 |
-
instances and processing them later will not work.
|
| 185 |
-
|
| 186 |
-
Examples
|
| 187 |
-
========
|
| 188 |
-
|
| 189 |
-
>>> from sympy.utilities.enumerative import list_visitor
|
| 190 |
-
>>> from sympy.utilities.enumerative import multiset_partitions_taocp
|
| 191 |
-
>>> # variables components and multiplicities represent the multiset 'abb'
|
| 192 |
-
>>> components = 'ab'
|
| 193 |
-
>>> multiplicities = [1, 2]
|
| 194 |
-
>>> states = multiset_partitions_taocp(multiplicities)
|
| 195 |
-
>>> list(list_visitor(state, components) for state in states)
|
| 196 |
-
[[['a', 'b', 'b']],
|
| 197 |
-
[['a', 'b'], ['b']],
|
| 198 |
-
[['a'], ['b', 'b']],
|
| 199 |
-
[['a'], ['b'], ['b']]]
|
| 200 |
-
|
| 201 |
-
See Also
|
| 202 |
-
========
|
| 203 |
-
|
| 204 |
-
sympy.utilities.iterables.multiset_partitions: Takes a multiset
|
| 205 |
-
as input and directly yields multiset partitions. It
|
| 206 |
-
dispatches to a number of functions, including this one, for
|
| 207 |
-
implementation. Most users will find it more convenient to
|
| 208 |
-
use than multiset_partitions_taocp.
|
| 209 |
-
|
| 210 |
-
"""
|
| 211 |
-
|
| 212 |
-
# Important variables.
|
| 213 |
-
# m is the number of components, i.e., number of distinct elements
|
| 214 |
-
m = len(multiplicities)
|
| 215 |
-
# n is the cardinality, total number of elements whether or not distinct
|
| 216 |
-
n = sum(multiplicities)
|
| 217 |
-
|
| 218 |
-
# The main data structure, f segments pstack into parts. See
|
| 219 |
-
# list_visitor() for example code indicating how this internal
|
| 220 |
-
# state corresponds to a partition.
|
| 221 |
-
|
| 222 |
-
# Note: allocation of space for stack is conservative. Knuth's
|
| 223 |
-
# exercise 7.2.1.5.68 gives some indication of how to tighten this
|
| 224 |
-
# bound, but this is not implemented.
|
| 225 |
-
pstack = [PartComponent() for i in range(n * m + 1)]
|
| 226 |
-
f = [0] * (n + 1)
|
| 227 |
-
|
| 228 |
-
# Step M1 in Knuth (Initialize)
|
| 229 |
-
# Initial state - entire multiset in one part.
|
| 230 |
-
for j in range(m):
|
| 231 |
-
ps = pstack[j]
|
| 232 |
-
ps.c = j
|
| 233 |
-
ps.u = multiplicities[j]
|
| 234 |
-
ps.v = multiplicities[j]
|
| 235 |
-
|
| 236 |
-
# Other variables
|
| 237 |
-
f[0] = 0
|
| 238 |
-
a = 0
|
| 239 |
-
lpart = 0
|
| 240 |
-
f[1] = m
|
| 241 |
-
b = m # in general, current stack frame is from a to b - 1
|
| 242 |
-
|
| 243 |
-
while True:
|
| 244 |
-
while True:
|
| 245 |
-
# Step M2 (Subtract v from u)
|
| 246 |
-
k = b
|
| 247 |
-
x = False
|
| 248 |
-
for j in range(a, b):
|
| 249 |
-
pstack[k].u = pstack[j].u - pstack[j].v
|
| 250 |
-
if pstack[k].u == 0:
|
| 251 |
-
x = True
|
| 252 |
-
elif not x:
|
| 253 |
-
pstack[k].c = pstack[j].c
|
| 254 |
-
pstack[k].v = min(pstack[j].v, pstack[k].u)
|
| 255 |
-
x = pstack[k].u < pstack[j].v
|
| 256 |
-
k = k + 1
|
| 257 |
-
else: # x is True
|
| 258 |
-
pstack[k].c = pstack[j].c
|
| 259 |
-
pstack[k].v = pstack[k].u
|
| 260 |
-
k = k + 1
|
| 261 |
-
# Note: x is True iff v has changed
|
| 262 |
-
|
| 263 |
-
# Step M3 (Push if nonzero.)
|
| 264 |
-
if k > b:
|
| 265 |
-
a = b
|
| 266 |
-
b = k
|
| 267 |
-
lpart = lpart + 1
|
| 268 |
-
f[lpart + 1] = b
|
| 269 |
-
# Return to M2
|
| 270 |
-
else:
|
| 271 |
-
break # Continue to M4
|
| 272 |
-
|
| 273 |
-
# M4 Visit a partition
|
| 274 |
-
state = [f, lpart, pstack]
|
| 275 |
-
yield state
|
| 276 |
-
|
| 277 |
-
# M5 (Decrease v)
|
| 278 |
-
while True:
|
| 279 |
-
j = b-1
|
| 280 |
-
while (pstack[j].v == 0):
|
| 281 |
-
j = j - 1
|
| 282 |
-
if j == a and pstack[j].v == 1:
|
| 283 |
-
# M6 (Backtrack)
|
| 284 |
-
if lpart == 0:
|
| 285 |
-
return
|
| 286 |
-
lpart = lpart - 1
|
| 287 |
-
b = a
|
| 288 |
-
a = f[lpart]
|
| 289 |
-
# Return to M5
|
| 290 |
-
else:
|
| 291 |
-
pstack[j].v = pstack[j].v - 1
|
| 292 |
-
for k in range(j + 1, b):
|
| 293 |
-
pstack[k].v = pstack[k].u
|
| 294 |
-
break # GOTO M2
|
| 295 |
-
|
| 296 |
-
# --------------- Visitor functions for multiset partitions ---------------
|
| 297 |
-
# A visitor takes the partition state generated by
|
| 298 |
-
# multiset_partitions_taocp or other enumerator, and produces useful
|
| 299 |
-
# output (such as the actual partition).
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
def factoring_visitor(state, primes):
|
| 303 |
-
"""Use with multiset_partitions_taocp to enumerate the ways a
|
| 304 |
-
number can be expressed as a product of factors. For this usage,
|
| 305 |
-
the exponents of the prime factors of a number are arguments to
|
| 306 |
-
the partition enumerator, while the corresponding prime factors
|
| 307 |
-
are input here.
|
| 308 |
-
|
| 309 |
-
Examples
|
| 310 |
-
========
|
| 311 |
-
|
| 312 |
-
To enumerate the factorings of a number we can think of the elements of the
|
| 313 |
-
partition as being the prime factors and the multiplicities as being their
|
| 314 |
-
exponents.
|
| 315 |
-
|
| 316 |
-
>>> from sympy.utilities.enumerative import factoring_visitor
|
| 317 |
-
>>> from sympy.utilities.enumerative import multiset_partitions_taocp
|
| 318 |
-
>>> from sympy import factorint
|
| 319 |
-
>>> primes, multiplicities = zip(*factorint(24).items())
|
| 320 |
-
>>> primes
|
| 321 |
-
(2, 3)
|
| 322 |
-
>>> multiplicities
|
| 323 |
-
(3, 1)
|
| 324 |
-
>>> states = multiset_partitions_taocp(multiplicities)
|
| 325 |
-
>>> list(factoring_visitor(state, primes) for state in states)
|
| 326 |
-
[[24], [8, 3], [12, 2], [4, 6], [4, 2, 3], [6, 2, 2], [2, 2, 2, 3]]
|
| 327 |
-
"""
|
| 328 |
-
f, lpart, pstack = state
|
| 329 |
-
factoring = []
|
| 330 |
-
for i in range(lpart + 1):
|
| 331 |
-
factor = 1
|
| 332 |
-
for ps in pstack[f[i]: f[i + 1]]:
|
| 333 |
-
if ps.v > 0:
|
| 334 |
-
factor *= primes[ps.c] ** ps.v
|
| 335 |
-
factoring.append(factor)
|
| 336 |
-
return factoring
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
def list_visitor(state, components):
|
| 340 |
-
"""Return a list of lists to represent the partition.
|
| 341 |
-
|
| 342 |
-
Examples
|
| 343 |
-
========
|
| 344 |
-
|
| 345 |
-
>>> from sympy.utilities.enumerative import list_visitor
|
| 346 |
-
>>> from sympy.utilities.enumerative import multiset_partitions_taocp
|
| 347 |
-
>>> states = multiset_partitions_taocp([1, 2, 1])
|
| 348 |
-
>>> s = next(states)
|
| 349 |
-
>>> list_visitor(s, 'abc') # for multiset 'a b b c'
|
| 350 |
-
[['a', 'b', 'b', 'c']]
|
| 351 |
-
>>> s = next(states)
|
| 352 |
-
>>> list_visitor(s, [1, 2, 3]) # for multiset '1 2 2 3
|
| 353 |
-
[[1, 2, 2], [3]]
|
| 354 |
-
"""
|
| 355 |
-
f, lpart, pstack = state
|
| 356 |
-
|
| 357 |
-
partition = []
|
| 358 |
-
for i in range(lpart+1):
|
| 359 |
-
part = []
|
| 360 |
-
for ps in pstack[f[i]:f[i+1]]:
|
| 361 |
-
if ps.v > 0:
|
| 362 |
-
part.extend([components[ps.c]] * ps.v)
|
| 363 |
-
partition.append(part)
|
| 364 |
-
|
| 365 |
-
return partition
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
class MultisetPartitionTraverser():
|
| 369 |
-
"""
|
| 370 |
-
Has methods to ``enumerate`` and ``count`` the partitions of a multiset.
|
| 371 |
-
|
| 372 |
-
This implements a refactored and extended version of Knuth's algorithm
|
| 373 |
-
7.1.2.5M [AOCP]_."
|
| 374 |
-
|
| 375 |
-
The enumeration methods of this class are generators and return
|
| 376 |
-
data structures which can be interpreted by the same visitor
|
| 377 |
-
functions used for the output of ``multiset_partitions_taocp``.
|
| 378 |
-
|
| 379 |
-
Examples
|
| 380 |
-
========
|
| 381 |
-
|
| 382 |
-
>>> from sympy.utilities.enumerative import MultisetPartitionTraverser
|
| 383 |
-
>>> m = MultisetPartitionTraverser()
|
| 384 |
-
>>> m.count_partitions([4,4,4,2])
|
| 385 |
-
127750
|
| 386 |
-
>>> m.count_partitions([3,3,3])
|
| 387 |
-
686
|
| 388 |
-
|
| 389 |
-
See Also
|
| 390 |
-
========
|
| 391 |
-
|
| 392 |
-
multiset_partitions_taocp
|
| 393 |
-
sympy.utilities.iterables.multiset_partitions
|
| 394 |
-
|
| 395 |
-
References
|
| 396 |
-
==========
|
| 397 |
-
|
| 398 |
-
.. [AOCP] Algorithm 7.1.2.5M in Volume 4A, Combinatoral Algorithms,
|
| 399 |
-
Part 1, of The Art of Computer Programming, by Donald Knuth.
|
| 400 |
-
|
| 401 |
-
.. [Factorisatio] On a Problem of Oppenheim concerning
|
| 402 |
-
"Factorisatio Numerorum" E. R. Canfield, Paul Erdos, Carl
|
| 403 |
-
Pomerance, JOURNAL OF NUMBER THEORY, Vol. 17, No. 1. August
|
| 404 |
-
1983. See section 7 for a description of an algorithm
|
| 405 |
-
similar to Knuth's.
|
| 406 |
-
|
| 407 |
-
.. [Yorgey] Generating Multiset Partitions, Brent Yorgey, The
|
| 408 |
-
Monad.Reader, Issue 8, September 2007.
|
| 409 |
-
|
| 410 |
-
"""
|
| 411 |
-
|
| 412 |
-
def __init__(self):
|
| 413 |
-
self.debug = False
|
| 414 |
-
# TRACING variables. These are useful for gathering
|
| 415 |
-
# statistics on the algorithm itself, but have no particular
|
| 416 |
-
# benefit to a user of the code.
|
| 417 |
-
self.k1 = 0
|
| 418 |
-
self.k2 = 0
|
| 419 |
-
self.p1 = 0
|
| 420 |
-
self.pstack = None
|
| 421 |
-
self.f = None
|
| 422 |
-
self.lpart = 0
|
| 423 |
-
self.discarded = 0
|
| 424 |
-
# dp_stack is list of lists of (part_key, start_count) pairs
|
| 425 |
-
self.dp_stack = []
|
| 426 |
-
|
| 427 |
-
# dp_map is map part_key-> count, where count represents the
|
| 428 |
-
# number of multiset which are descendants of a part with this
|
| 429 |
-
# key, **or any of its decrements**
|
| 430 |
-
|
| 431 |
-
# Thus, when we find a part in the map, we add its count
|
| 432 |
-
# value to the running total, cut off the enumeration, and
|
| 433 |
-
# backtrack
|
| 434 |
-
|
| 435 |
-
if not hasattr(self, 'dp_map'):
|
| 436 |
-
self.dp_map = {}
|
| 437 |
-
|
| 438 |
-
def db_trace(self, msg):
|
| 439 |
-
"""Useful for understanding/debugging the algorithms. Not
|
| 440 |
-
generally activated in end-user code."""
|
| 441 |
-
if self.debug:
|
| 442 |
-
# XXX: animation_visitor is undefined... Clearly this does not
|
| 443 |
-
# work and was not tested. Previous code in comments below.
|
| 444 |
-
raise RuntimeError
|
| 445 |
-
#letters = 'abcdefghijklmnopqrstuvwxyz'
|
| 446 |
-
#state = [self.f, self.lpart, self.pstack]
|
| 447 |
-
#print("DBG:", msg,
|
| 448 |
-
# ["".join(part) for part in list_visitor(state, letters)],
|
| 449 |
-
# animation_visitor(state))
|
| 450 |
-
|
| 451 |
-
#
|
| 452 |
-
# Helper methods for enumeration
|
| 453 |
-
#
|
| 454 |
-
def _initialize_enumeration(self, multiplicities):
|
| 455 |
-
"""Allocates and initializes the partition stack.
|
| 456 |
-
|
| 457 |
-
This is called from the enumeration/counting routines, so
|
| 458 |
-
there is no need to call it separately."""
|
| 459 |
-
|
| 460 |
-
num_components = len(multiplicities)
|
| 461 |
-
# cardinality is the total number of elements, whether or not distinct
|
| 462 |
-
cardinality = sum(multiplicities)
|
| 463 |
-
|
| 464 |
-
# pstack is the partition stack, which is segmented by
|
| 465 |
-
# f into parts.
|
| 466 |
-
self.pstack = [PartComponent() for i in
|
| 467 |
-
range(num_components * cardinality + 1)]
|
| 468 |
-
self.f = [0] * (cardinality + 1)
|
| 469 |
-
|
| 470 |
-
# Initial state - entire multiset in one part.
|
| 471 |
-
for j in range(num_components):
|
| 472 |
-
ps = self.pstack[j]
|
| 473 |
-
ps.c = j
|
| 474 |
-
ps.u = multiplicities[j]
|
| 475 |
-
ps.v = multiplicities[j]
|
| 476 |
-
|
| 477 |
-
self.f[0] = 0
|
| 478 |
-
self.f[1] = num_components
|
| 479 |
-
self.lpart = 0
|
| 480 |
-
|
| 481 |
-
# The decrement_part() method corresponds to step M5 in Knuth's
|
| 482 |
-
# algorithm. This is the base version for enum_all(). Modified
|
| 483 |
-
# versions of this method are needed if we want to restrict
|
| 484 |
-
# sizes of the partitions produced.
|
| 485 |
-
def decrement_part(self, part):
|
| 486 |
-
"""Decrements part (a subrange of pstack), if possible, returning
|
| 487 |
-
True iff the part was successfully decremented.
|
| 488 |
-
|
| 489 |
-
If you think of the v values in the part as a multi-digit
|
| 490 |
-
integer (least significant digit on the right) this is
|
| 491 |
-
basically decrementing that integer, but with the extra
|
| 492 |
-
constraint that the leftmost digit cannot be decremented to 0.
|
| 493 |
-
|
| 494 |
-
Parameters
|
| 495 |
-
==========
|
| 496 |
-
|
| 497 |
-
part
|
| 498 |
-
The part, represented as a list of PartComponent objects,
|
| 499 |
-
which is to be decremented.
|
| 500 |
-
|
| 501 |
-
"""
|
| 502 |
-
plen = len(part)
|
| 503 |
-
for j in range(plen - 1, -1, -1):
|
| 504 |
-
if j == 0 and part[j].v > 1 or j > 0 and part[j].v > 0:
|
| 505 |
-
# found val to decrement
|
| 506 |
-
part[j].v -= 1
|
| 507 |
-
# Reset trailing parts back to maximum
|
| 508 |
-
for k in range(j + 1, plen):
|
| 509 |
-
part[k].v = part[k].u
|
| 510 |
-
return True
|
| 511 |
-
return False
|
| 512 |
-
|
| 513 |
-
# Version to allow number of parts to be bounded from above.
|
| 514 |
-
# Corresponds to (a modified) step M5.
|
| 515 |
-
def decrement_part_small(self, part, ub):
|
| 516 |
-
"""Decrements part (a subrange of pstack), if possible, returning
|
| 517 |
-
True iff the part was successfully decremented.
|
| 518 |
-
|
| 519 |
-
Parameters
|
| 520 |
-
==========
|
| 521 |
-
|
| 522 |
-
part
|
| 523 |
-
part to be decremented (topmost part on the stack)
|
| 524 |
-
|
| 525 |
-
ub
|
| 526 |
-
the maximum number of parts allowed in a partition
|
| 527 |
-
returned by the calling traversal.
|
| 528 |
-
|
| 529 |
-
Notes
|
| 530 |
-
=====
|
| 531 |
-
|
| 532 |
-
The goal of this modification of the ordinary decrement method
|
| 533 |
-
is to fail (meaning that the subtree rooted at this part is to
|
| 534 |
-
be skipped) when it can be proved that this part can only have
|
| 535 |
-
child partitions which are larger than allowed by ``ub``. If a
|
| 536 |
-
decision is made to fail, it must be accurate, otherwise the
|
| 537 |
-
enumeration will miss some partitions. But, it is OK not to
|
| 538 |
-
capture all the possible failures -- if a part is passed that
|
| 539 |
-
should not be, the resulting too-large partitions are filtered
|
| 540 |
-
by the enumeration one level up. However, as is usual in
|
| 541 |
-
constrained enumerations, failing early is advantageous.
|
| 542 |
-
|
| 543 |
-
The tests used by this method catch the most common cases,
|
| 544 |
-
although this implementation is by no means the last word on
|
| 545 |
-
this problem. The tests include:
|
| 546 |
-
|
| 547 |
-
1) ``lpart`` must be less than ``ub`` by at least 2. This is because
|
| 548 |
-
once a part has been decremented, the partition
|
| 549 |
-
will gain at least one child in the spread step.
|
| 550 |
-
|
| 551 |
-
2) If the leading component of the part is about to be
|
| 552 |
-
decremented, check for how many parts will be added in
|
| 553 |
-
order to use up the unallocated multiplicity in that
|
| 554 |
-
leading component, and fail if this number is greater than
|
| 555 |
-
allowed by ``ub``. (See code for the exact expression.) This
|
| 556 |
-
test is given in the answer to Knuth's problem 7.2.1.5.69.
|
| 557 |
-
|
| 558 |
-
3) If there is *exactly* enough room to expand the leading
|
| 559 |
-
component by the above test, check the next component (if
|
| 560 |
-
it exists) once decrementing has finished. If this has
|
| 561 |
-
``v == 0``, this next component will push the expansion over the
|
| 562 |
-
limit by 1, so fail.
|
| 563 |
-
"""
|
| 564 |
-
if self.lpart >= ub - 1:
|
| 565 |
-
self.p1 += 1 # increment to keep track of usefulness of tests
|
| 566 |
-
return False
|
| 567 |
-
plen = len(part)
|
| 568 |
-
for j in range(plen - 1, -1, -1):
|
| 569 |
-
# Knuth's mod, (answer to problem 7.2.1.5.69)
|
| 570 |
-
if j == 0 and (part[0].v - 1)*(ub - self.lpart) < part[0].u:
|
| 571 |
-
self.k1 += 1
|
| 572 |
-
return False
|
| 573 |
-
|
| 574 |
-
if j == 0 and part[j].v > 1 or j > 0 and part[j].v > 0:
|
| 575 |
-
# found val to decrement
|
| 576 |
-
part[j].v -= 1
|
| 577 |
-
# Reset trailing parts back to maximum
|
| 578 |
-
for k in range(j + 1, plen):
|
| 579 |
-
part[k].v = part[k].u
|
| 580 |
-
|
| 581 |
-
# Have now decremented part, but are we doomed to
|
| 582 |
-
# failure when it is expanded? Check one oddball case
|
| 583 |
-
# that turns out to be surprisingly common - exactly
|
| 584 |
-
# enough room to expand the leading component, but no
|
| 585 |
-
# room for the second component, which has v=0.
|
| 586 |
-
if (plen > 1 and part[1].v == 0 and
|
| 587 |
-
(part[0].u - part[0].v) ==
|
| 588 |
-
((ub - self.lpart - 1) * part[0].v)):
|
| 589 |
-
self.k2 += 1
|
| 590 |
-
self.db_trace("Decrement fails test 3")
|
| 591 |
-
return False
|
| 592 |
-
return True
|
| 593 |
-
return False
|
| 594 |
-
|
| 595 |
-
def decrement_part_large(self, part, amt, lb):
|
| 596 |
-
"""Decrements part, while respecting size constraint.
|
| 597 |
-
|
| 598 |
-
A part can have no children which are of sufficient size (as
|
| 599 |
-
indicated by ``lb``) unless that part has sufficient
|
| 600 |
-
unallocated multiplicity. When enforcing the size constraint,
|
| 601 |
-
this method will decrement the part (if necessary) by an
|
| 602 |
-
amount needed to ensure sufficient unallocated multiplicity.
|
| 603 |
-
|
| 604 |
-
Returns True iff the part was successfully decremented.
|
| 605 |
-
|
| 606 |
-
Parameters
|
| 607 |
-
==========
|
| 608 |
-
|
| 609 |
-
part
|
| 610 |
-
part to be decremented (topmost part on the stack)
|
| 611 |
-
|
| 612 |
-
amt
|
| 613 |
-
Can only take values 0 or 1. A value of 1 means that the
|
| 614 |
-
part must be decremented, and then the size constraint is
|
| 615 |
-
enforced. A value of 0 means just to enforce the ``lb``
|
| 616 |
-
size constraint.
|
| 617 |
-
|
| 618 |
-
lb
|
| 619 |
-
The partitions produced by the calling enumeration must
|
| 620 |
-
have more parts than this value.
|
| 621 |
-
|
| 622 |
-
"""
|
| 623 |
-
|
| 624 |
-
if amt == 1:
|
| 625 |
-
# In this case we always need to decrement, *before*
|
| 626 |
-
# enforcing the "sufficient unallocated multiplicity"
|
| 627 |
-
# constraint. Easiest for this is just to call the
|
| 628 |
-
# regular decrement method.
|
| 629 |
-
if not self.decrement_part(part):
|
| 630 |
-
return False
|
| 631 |
-
|
| 632 |
-
# Next, perform any needed additional decrementing to respect
|
| 633 |
-
# "sufficient unallocated multiplicity" (or fail if this is
|
| 634 |
-
# not possible).
|
| 635 |
-
min_unalloc = lb - self.lpart
|
| 636 |
-
if min_unalloc <= 0:
|
| 637 |
-
return True
|
| 638 |
-
total_mult = sum(pc.u for pc in part)
|
| 639 |
-
total_alloc = sum(pc.v for pc in part)
|
| 640 |
-
if total_mult <= min_unalloc:
|
| 641 |
-
return False
|
| 642 |
-
|
| 643 |
-
deficit = min_unalloc - (total_mult - total_alloc)
|
| 644 |
-
if deficit <= 0:
|
| 645 |
-
return True
|
| 646 |
-
|
| 647 |
-
for i in range(len(part) - 1, -1, -1):
|
| 648 |
-
if i == 0:
|
| 649 |
-
if part[0].v > deficit:
|
| 650 |
-
part[0].v -= deficit
|
| 651 |
-
return True
|
| 652 |
-
else:
|
| 653 |
-
return False # This shouldn't happen, due to above check
|
| 654 |
-
else:
|
| 655 |
-
if part[i].v >= deficit:
|
| 656 |
-
part[i].v -= deficit
|
| 657 |
-
return True
|
| 658 |
-
else:
|
| 659 |
-
deficit -= part[i].v
|
| 660 |
-
part[i].v = 0
|
| 661 |
-
|
| 662 |
-
def decrement_part_range(self, part, lb, ub):
|
| 663 |
-
"""Decrements part (a subrange of pstack), if possible, returning
|
| 664 |
-
True iff the part was successfully decremented.
|
| 665 |
-
|
| 666 |
-
Parameters
|
| 667 |
-
==========
|
| 668 |
-
|
| 669 |
-
part
|
| 670 |
-
part to be decremented (topmost part on the stack)
|
| 671 |
-
|
| 672 |
-
ub
|
| 673 |
-
the maximum number of parts allowed in a partition
|
| 674 |
-
returned by the calling traversal.
|
| 675 |
-
|
| 676 |
-
lb
|
| 677 |
-
The partitions produced by the calling enumeration must
|
| 678 |
-
have more parts than this value.
|
| 679 |
-
|
| 680 |
-
Notes
|
| 681 |
-
=====
|
| 682 |
-
|
| 683 |
-
Combines the constraints of _small and _large decrement
|
| 684 |
-
methods. If returns success, part has been decremented at
|
| 685 |
-
least once, but perhaps by quite a bit more if needed to meet
|
| 686 |
-
the lb constraint.
|
| 687 |
-
"""
|
| 688 |
-
|
| 689 |
-
# Constraint in the range case is just enforcing both the
|
| 690 |
-
# constraints from _small and _large cases. Note the 0 as the
|
| 691 |
-
# second argument to the _large call -- this is the signal to
|
| 692 |
-
# decrement only as needed to for constraint enforcement. The
|
| 693 |
-
# short circuiting and left-to-right order of the 'and'
|
| 694 |
-
# operator is important for this to work correctly.
|
| 695 |
-
return self.decrement_part_small(part, ub) and \
|
| 696 |
-
self.decrement_part_large(part, 0, lb)
|
| 697 |
-
|
| 698 |
-
def spread_part_multiplicity(self):
|
| 699 |
-
"""Returns True if a new part has been created, and
|
| 700 |
-
adjusts pstack, f and lpart as needed.
|
| 701 |
-
|
| 702 |
-
Notes
|
| 703 |
-
=====
|
| 704 |
-
|
| 705 |
-
Spreads unallocated multiplicity from the current top part
|
| 706 |
-
into a new part created above the current on the stack. This
|
| 707 |
-
new part is constrained to be less than or equal to the old in
|
| 708 |
-
terms of the part ordering.
|
| 709 |
-
|
| 710 |
-
This call does nothing (and returns False) if the current top
|
| 711 |
-
part has no unallocated multiplicity.
|
| 712 |
-
|
| 713 |
-
"""
|
| 714 |
-
j = self.f[self.lpart] # base of current top part
|
| 715 |
-
k = self.f[self.lpart + 1] # ub of current; potential base of next
|
| 716 |
-
base = k # save for later comparison
|
| 717 |
-
|
| 718 |
-
changed = False # Set to true when the new part (so far) is
|
| 719 |
-
# strictly less than (as opposed to less than
|
| 720 |
-
# or equal) to the old.
|
| 721 |
-
for j in range(self.f[self.lpart], self.f[self.lpart + 1]):
|
| 722 |
-
self.pstack[k].u = self.pstack[j].u - self.pstack[j].v
|
| 723 |
-
if self.pstack[k].u == 0:
|
| 724 |
-
changed = True
|
| 725 |
-
else:
|
| 726 |
-
self.pstack[k].c = self.pstack[j].c
|
| 727 |
-
if changed: # Put all available multiplicity in this part
|
| 728 |
-
self.pstack[k].v = self.pstack[k].u
|
| 729 |
-
else: # Still maintaining ordering constraint
|
| 730 |
-
if self.pstack[k].u < self.pstack[j].v:
|
| 731 |
-
self.pstack[k].v = self.pstack[k].u
|
| 732 |
-
changed = True
|
| 733 |
-
else:
|
| 734 |
-
self.pstack[k].v = self.pstack[j].v
|
| 735 |
-
k = k + 1
|
| 736 |
-
if k > base:
|
| 737 |
-
# Adjust for the new part on stack
|
| 738 |
-
self.lpart = self.lpart + 1
|
| 739 |
-
self.f[self.lpart + 1] = k
|
| 740 |
-
return True
|
| 741 |
-
return False
|
| 742 |
-
|
| 743 |
-
def top_part(self):
|
| 744 |
-
"""Return current top part on the stack, as a slice of pstack.
|
| 745 |
-
|
| 746 |
-
"""
|
| 747 |
-
return self.pstack[self.f[self.lpart]:self.f[self.lpart + 1]]
|
| 748 |
-
|
| 749 |
-
# Same interface and functionality as multiset_partitions_taocp(),
|
| 750 |
-
# but some might find this refactored version easier to follow.
|
| 751 |
-
def enum_all(self, multiplicities):
|
| 752 |
-
"""Enumerate the partitions of a multiset.
|
| 753 |
-
|
| 754 |
-
Examples
|
| 755 |
-
========
|
| 756 |
-
|
| 757 |
-
>>> from sympy.utilities.enumerative import list_visitor
|
| 758 |
-
>>> from sympy.utilities.enumerative import MultisetPartitionTraverser
|
| 759 |
-
>>> m = MultisetPartitionTraverser()
|
| 760 |
-
>>> states = m.enum_all([2,2])
|
| 761 |
-
>>> list(list_visitor(state, 'ab') for state in states)
|
| 762 |
-
[[['a', 'a', 'b', 'b']],
|
| 763 |
-
[['a', 'a', 'b'], ['b']],
|
| 764 |
-
[['a', 'a'], ['b', 'b']],
|
| 765 |
-
[['a', 'a'], ['b'], ['b']],
|
| 766 |
-
[['a', 'b', 'b'], ['a']],
|
| 767 |
-
[['a', 'b'], ['a', 'b']],
|
| 768 |
-
[['a', 'b'], ['a'], ['b']],
|
| 769 |
-
[['a'], ['a'], ['b', 'b']],
|
| 770 |
-
[['a'], ['a'], ['b'], ['b']]]
|
| 771 |
-
|
| 772 |
-
See Also
|
| 773 |
-
========
|
| 774 |
-
|
| 775 |
-
multiset_partitions_taocp:
|
| 776 |
-
which provides the same result as this method, but is
|
| 777 |
-
about twice as fast. Hence, enum_all is primarily useful
|
| 778 |
-
for testing. Also see the function for a discussion of
|
| 779 |
-
states and visitors.
|
| 780 |
-
|
| 781 |
-
"""
|
| 782 |
-
self._initialize_enumeration(multiplicities)
|
| 783 |
-
while True:
|
| 784 |
-
while self.spread_part_multiplicity():
|
| 785 |
-
pass
|
| 786 |
-
|
| 787 |
-
# M4 Visit a partition
|
| 788 |
-
state = [self.f, self.lpart, self.pstack]
|
| 789 |
-
yield state
|
| 790 |
-
|
| 791 |
-
# M5 (Decrease v)
|
| 792 |
-
while not self.decrement_part(self.top_part()):
|
| 793 |
-
# M6 (Backtrack)
|
| 794 |
-
if self.lpart == 0:
|
| 795 |
-
return
|
| 796 |
-
self.lpart -= 1
|
| 797 |
-
|
| 798 |
-
def enum_small(self, multiplicities, ub):
|
| 799 |
-
"""Enumerate multiset partitions with no more than ``ub`` parts.
|
| 800 |
-
|
| 801 |
-
Equivalent to enum_range(multiplicities, 0, ub)
|
| 802 |
-
|
| 803 |
-
Parameters
|
| 804 |
-
==========
|
| 805 |
-
|
| 806 |
-
multiplicities
|
| 807 |
-
list of multiplicities of the components of the multiset.
|
| 808 |
-
|
| 809 |
-
ub
|
| 810 |
-
Maximum number of parts
|
| 811 |
-
|
| 812 |
-
Examples
|
| 813 |
-
========
|
| 814 |
-
|
| 815 |
-
>>> from sympy.utilities.enumerative import list_visitor
|
| 816 |
-
>>> from sympy.utilities.enumerative import MultisetPartitionTraverser
|
| 817 |
-
>>> m = MultisetPartitionTraverser()
|
| 818 |
-
>>> states = m.enum_small([2,2], 2)
|
| 819 |
-
>>> list(list_visitor(state, 'ab') for state in states)
|
| 820 |
-
[[['a', 'a', 'b', 'b']],
|
| 821 |
-
[['a', 'a', 'b'], ['b']],
|
| 822 |
-
[['a', 'a'], ['b', 'b']],
|
| 823 |
-
[['a', 'b', 'b'], ['a']],
|
| 824 |
-
[['a', 'b'], ['a', 'b']]]
|
| 825 |
-
|
| 826 |
-
The implementation is based, in part, on the answer given to
|
| 827 |
-
exercise 69, in Knuth [AOCP]_.
|
| 828 |
-
|
| 829 |
-
See Also
|
| 830 |
-
========
|
| 831 |
-
|
| 832 |
-
enum_all, enum_large, enum_range
|
| 833 |
-
|
| 834 |
-
"""
|
| 835 |
-
|
| 836 |
-
# Keep track of iterations which do not yield a partition.
|
| 837 |
-
# Clearly, we would like to keep this number small.
|
| 838 |
-
self.discarded = 0
|
| 839 |
-
if ub <= 0:
|
| 840 |
-
return
|
| 841 |
-
self._initialize_enumeration(multiplicities)
|
| 842 |
-
while True:
|
| 843 |
-
while self.spread_part_multiplicity():
|
| 844 |
-
self.db_trace('spread 1')
|
| 845 |
-
if self.lpart >= ub:
|
| 846 |
-
self.discarded += 1
|
| 847 |
-
self.db_trace(' Discarding')
|
| 848 |
-
self.lpart = ub - 2
|
| 849 |
-
break
|
| 850 |
-
else:
|
| 851 |
-
# M4 Visit a partition
|
| 852 |
-
state = [self.f, self.lpart, self.pstack]
|
| 853 |
-
yield state
|
| 854 |
-
|
| 855 |
-
# M5 (Decrease v)
|
| 856 |
-
while not self.decrement_part_small(self.top_part(), ub):
|
| 857 |
-
self.db_trace("Failed decrement, going to backtrack")
|
| 858 |
-
# M6 (Backtrack)
|
| 859 |
-
if self.lpart == 0:
|
| 860 |
-
return
|
| 861 |
-
self.lpart -= 1
|
| 862 |
-
self.db_trace("Backtracked to")
|
| 863 |
-
self.db_trace("decrement ok, about to expand")
|
| 864 |
-
|
| 865 |
-
def enum_large(self, multiplicities, lb):
|
| 866 |
-
"""Enumerate the partitions of a multiset with lb < num(parts)
|
| 867 |
-
|
| 868 |
-
Equivalent to enum_range(multiplicities, lb, sum(multiplicities))
|
| 869 |
-
|
| 870 |
-
Parameters
|
| 871 |
-
==========
|
| 872 |
-
|
| 873 |
-
multiplicities
|
| 874 |
-
list of multiplicities of the components of the multiset.
|
| 875 |
-
|
| 876 |
-
lb
|
| 877 |
-
Number of parts in the partition must be greater than
|
| 878 |
-
this lower bound.
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
Examples
|
| 882 |
-
========
|
| 883 |
-
|
| 884 |
-
>>> from sympy.utilities.enumerative import list_visitor
|
| 885 |
-
>>> from sympy.utilities.enumerative import MultisetPartitionTraverser
|
| 886 |
-
>>> m = MultisetPartitionTraverser()
|
| 887 |
-
>>> states = m.enum_large([2,2], 2)
|
| 888 |
-
>>> list(list_visitor(state, 'ab') for state in states)
|
| 889 |
-
[[['a', 'a'], ['b'], ['b']],
|
| 890 |
-
[['a', 'b'], ['a'], ['b']],
|
| 891 |
-
[['a'], ['a'], ['b', 'b']],
|
| 892 |
-
[['a'], ['a'], ['b'], ['b']]]
|
| 893 |
-
|
| 894 |
-
See Also
|
| 895 |
-
========
|
| 896 |
-
|
| 897 |
-
enum_all, enum_small, enum_range
|
| 898 |
-
|
| 899 |
-
"""
|
| 900 |
-
self.discarded = 0
|
| 901 |
-
if lb >= sum(multiplicities):
|
| 902 |
-
return
|
| 903 |
-
self._initialize_enumeration(multiplicities)
|
| 904 |
-
self.decrement_part_large(self.top_part(), 0, lb)
|
| 905 |
-
while True:
|
| 906 |
-
good_partition = True
|
| 907 |
-
while self.spread_part_multiplicity():
|
| 908 |
-
if not self.decrement_part_large(self.top_part(), 0, lb):
|
| 909 |
-
# Failure here should be rare/impossible
|
| 910 |
-
self.discarded += 1
|
| 911 |
-
good_partition = False
|
| 912 |
-
break
|
| 913 |
-
|
| 914 |
-
# M4 Visit a partition
|
| 915 |
-
if good_partition:
|
| 916 |
-
state = [self.f, self.lpart, self.pstack]
|
| 917 |
-
yield state
|
| 918 |
-
|
| 919 |
-
# M5 (Decrease v)
|
| 920 |
-
while not self.decrement_part_large(self.top_part(), 1, lb):
|
| 921 |
-
# M6 (Backtrack)
|
| 922 |
-
if self.lpart == 0:
|
| 923 |
-
return
|
| 924 |
-
self.lpart -= 1
|
| 925 |
-
|
| 926 |
-
def enum_range(self, multiplicities, lb, ub):
|
| 927 |
-
|
| 928 |
-
"""Enumerate the partitions of a multiset with
|
| 929 |
-
``lb < num(parts) <= ub``.
|
| 930 |
-
|
| 931 |
-
In particular, if partitions with exactly ``k`` parts are
|
| 932 |
-
desired, call with ``(multiplicities, k - 1, k)``. This
|
| 933 |
-
method generalizes enum_all, enum_small, and enum_large.
|
| 934 |
-
|
| 935 |
-
Examples
|
| 936 |
-
========
|
| 937 |
-
|
| 938 |
-
>>> from sympy.utilities.enumerative import list_visitor
|
| 939 |
-
>>> from sympy.utilities.enumerative import MultisetPartitionTraverser
|
| 940 |
-
>>> m = MultisetPartitionTraverser()
|
| 941 |
-
>>> states = m.enum_range([2,2], 1, 2)
|
| 942 |
-
>>> list(list_visitor(state, 'ab') for state in states)
|
| 943 |
-
[[['a', 'a', 'b'], ['b']],
|
| 944 |
-
[['a', 'a'], ['b', 'b']],
|
| 945 |
-
[['a', 'b', 'b'], ['a']],
|
| 946 |
-
[['a', 'b'], ['a', 'b']]]
|
| 947 |
-
|
| 948 |
-
"""
|
| 949 |
-
# combine the constraints of the _large and _small
|
| 950 |
-
# enumerations.
|
| 951 |
-
self.discarded = 0
|
| 952 |
-
if ub <= 0 or lb >= sum(multiplicities):
|
| 953 |
-
return
|
| 954 |
-
self._initialize_enumeration(multiplicities)
|
| 955 |
-
self.decrement_part_large(self.top_part(), 0, lb)
|
| 956 |
-
while True:
|
| 957 |
-
good_partition = True
|
| 958 |
-
while self.spread_part_multiplicity():
|
| 959 |
-
self.db_trace("spread 1")
|
| 960 |
-
if not self.decrement_part_large(self.top_part(), 0, lb):
|
| 961 |
-
# Failure here - possible in range case?
|
| 962 |
-
self.db_trace(" Discarding (large cons)")
|
| 963 |
-
self.discarded += 1
|
| 964 |
-
good_partition = False
|
| 965 |
-
break
|
| 966 |
-
elif self.lpart >= ub:
|
| 967 |
-
self.discarded += 1
|
| 968 |
-
good_partition = False
|
| 969 |
-
self.db_trace(" Discarding small cons")
|
| 970 |
-
self.lpart = ub - 2
|
| 971 |
-
break
|
| 972 |
-
|
| 973 |
-
# M4 Visit a partition
|
| 974 |
-
if good_partition:
|
| 975 |
-
state = [self.f, self.lpart, self.pstack]
|
| 976 |
-
yield state
|
| 977 |
-
|
| 978 |
-
# M5 (Decrease v)
|
| 979 |
-
while not self.decrement_part_range(self.top_part(), lb, ub):
|
| 980 |
-
self.db_trace("Failed decrement, going to backtrack")
|
| 981 |
-
# M6 (Backtrack)
|
| 982 |
-
if self.lpart == 0:
|
| 983 |
-
return
|
| 984 |
-
self.lpart -= 1
|
| 985 |
-
self.db_trace("Backtracked to")
|
| 986 |
-
self.db_trace("decrement ok, about to expand")
|
| 987 |
-
|
| 988 |
-
def count_partitions_slow(self, multiplicities):
|
| 989 |
-
"""Returns the number of partitions of a multiset whose elements
|
| 990 |
-
have the multiplicities given in ``multiplicities``.
|
| 991 |
-
|
| 992 |
-
Primarily for comparison purposes. It follows the same path as
|
| 993 |
-
enumerate, and counts, rather than generates, the partitions.
|
| 994 |
-
|
| 995 |
-
See Also
|
| 996 |
-
========
|
| 997 |
-
|
| 998 |
-
count_partitions
|
| 999 |
-
Has the same calling interface, but is much faster.
|
| 1000 |
-
|
| 1001 |
-
"""
|
| 1002 |
-
# number of partitions so far in the enumeration
|
| 1003 |
-
self.pcount = 0
|
| 1004 |
-
self._initialize_enumeration(multiplicities)
|
| 1005 |
-
while True:
|
| 1006 |
-
while self.spread_part_multiplicity():
|
| 1007 |
-
pass
|
| 1008 |
-
|
| 1009 |
-
# M4 Visit (count) a partition
|
| 1010 |
-
self.pcount += 1
|
| 1011 |
-
|
| 1012 |
-
# M5 (Decrease v)
|
| 1013 |
-
while not self.decrement_part(self.top_part()):
|
| 1014 |
-
# M6 (Backtrack)
|
| 1015 |
-
if self.lpart == 0:
|
| 1016 |
-
return self.pcount
|
| 1017 |
-
self.lpart -= 1
|
| 1018 |
-
|
| 1019 |
-
def count_partitions(self, multiplicities):
|
| 1020 |
-
"""Returns the number of partitions of a multiset whose components
|
| 1021 |
-
have the multiplicities given in ``multiplicities``.
|
| 1022 |
-
|
| 1023 |
-
For larger counts, this method is much faster than calling one
|
| 1024 |
-
of the enumerators and counting the result. Uses dynamic
|
| 1025 |
-
programming to cut down on the number of nodes actually
|
| 1026 |
-
explored. The dictionary used in order to accelerate the
|
| 1027 |
-
counting process is stored in the ``MultisetPartitionTraverser``
|
| 1028 |
-
object and persists across calls. If the user does not
|
| 1029 |
-
expect to call ``count_partitions`` for any additional
|
| 1030 |
-
multisets, the object should be cleared to save memory. On
|
| 1031 |
-
the other hand, the cache built up from one count run can
|
| 1032 |
-
significantly speed up subsequent calls to ``count_partitions``,
|
| 1033 |
-
so it may be advantageous not to clear the object.
|
| 1034 |
-
|
| 1035 |
-
Examples
|
| 1036 |
-
========
|
| 1037 |
-
|
| 1038 |
-
>>> from sympy.utilities.enumerative import MultisetPartitionTraverser
|
| 1039 |
-
>>> m = MultisetPartitionTraverser()
|
| 1040 |
-
>>> m.count_partitions([9,8,2])
|
| 1041 |
-
288716
|
| 1042 |
-
>>> m.count_partitions([2,2])
|
| 1043 |
-
9
|
| 1044 |
-
>>> del m
|
| 1045 |
-
|
| 1046 |
-
Notes
|
| 1047 |
-
=====
|
| 1048 |
-
|
| 1049 |
-
If one looks at the workings of Knuth's algorithm M [AOCP]_, it
|
| 1050 |
-
can be viewed as a traversal of a binary tree of parts. A
|
| 1051 |
-
part has (up to) two children, the left child resulting from
|
| 1052 |
-
the spread operation, and the right child from the decrement
|
| 1053 |
-
operation. The ordinary enumeration of multiset partitions is
|
| 1054 |
-
an in-order traversal of this tree, and with the partitions
|
| 1055 |
-
corresponding to paths from the root to the leaves. The
|
| 1056 |
-
mapping from paths to partitions is a little complicated,
|
| 1057 |
-
since the partition would contain only those parts which are
|
| 1058 |
-
leaves or the parents of a spread link, not those which are
|
| 1059 |
-
parents of a decrement link.
|
| 1060 |
-
|
| 1061 |
-
For counting purposes, it is sufficient to count leaves, and
|
| 1062 |
-
this can be done with a recursive in-order traversal. The
|
| 1063 |
-
number of leaves of a subtree rooted at a particular part is a
|
| 1064 |
-
function only of that part itself, so memoizing has the
|
| 1065 |
-
potential to speed up the counting dramatically.
|
| 1066 |
-
|
| 1067 |
-
This method follows a computational approach which is similar
|
| 1068 |
-
to the hypothetical memoized recursive function, but with two
|
| 1069 |
-
differences:
|
| 1070 |
-
|
| 1071 |
-
1) This method is iterative, borrowing its structure from the
|
| 1072 |
-
other enumerations and maintaining an explicit stack of
|
| 1073 |
-
parts which are in the process of being counted. (There
|
| 1074 |
-
may be multisets which can be counted reasonably quickly by
|
| 1075 |
-
this implementation, but which would overflow the default
|
| 1076 |
-
Python recursion limit with a recursive implementation.)
|
| 1077 |
-
|
| 1078 |
-
2) Instead of using the part data structure directly, a more
|
| 1079 |
-
compact key is constructed. This saves space, but more
|
| 1080 |
-
importantly coalesces some parts which would remain
|
| 1081 |
-
separate with physical keys.
|
| 1082 |
-
|
| 1083 |
-
Unlike the enumeration functions, there is currently no _range
|
| 1084 |
-
version of count_partitions. If someone wants to stretch
|
| 1085 |
-
their brain, it should be possible to construct one by
|
| 1086 |
-
memoizing with a histogram of counts rather than a single
|
| 1087 |
-
count, and combining the histograms.
|
| 1088 |
-
"""
|
| 1089 |
-
# number of partitions so far in the enumeration
|
| 1090 |
-
self.pcount = 0
|
| 1091 |
-
|
| 1092 |
-
# dp_stack is list of lists of (part_key, start_count) pairs
|
| 1093 |
-
self.dp_stack = []
|
| 1094 |
-
|
| 1095 |
-
self._initialize_enumeration(multiplicities)
|
| 1096 |
-
pkey = part_key(self.top_part())
|
| 1097 |
-
self.dp_stack.append([(pkey, 0), ])
|
| 1098 |
-
while True:
|
| 1099 |
-
while self.spread_part_multiplicity():
|
| 1100 |
-
pkey = part_key(self.top_part())
|
| 1101 |
-
if pkey in self.dp_map:
|
| 1102 |
-
# Already have a cached value for the count of the
|
| 1103 |
-
# subtree rooted at this part. Add it to the
|
| 1104 |
-
# running counter, and break out of the spread
|
| 1105 |
-
# loop. The -1 below is to compensate for the
|
| 1106 |
-
# leaf that this code path would otherwise find,
|
| 1107 |
-
# and which gets incremented for below.
|
| 1108 |
-
|
| 1109 |
-
self.pcount += (self.dp_map[pkey] - 1)
|
| 1110 |
-
self.lpart -= 1
|
| 1111 |
-
break
|
| 1112 |
-
else:
|
| 1113 |
-
self.dp_stack.append([(pkey, self.pcount), ])
|
| 1114 |
-
|
| 1115 |
-
# M4 count a leaf partition
|
| 1116 |
-
self.pcount += 1
|
| 1117 |
-
|
| 1118 |
-
# M5 (Decrease v)
|
| 1119 |
-
while not self.decrement_part(self.top_part()):
|
| 1120 |
-
# M6 (Backtrack)
|
| 1121 |
-
for key, oldcount in self.dp_stack.pop():
|
| 1122 |
-
self.dp_map[key] = self.pcount - oldcount
|
| 1123 |
-
if self.lpart == 0:
|
| 1124 |
-
return self.pcount
|
| 1125 |
-
self.lpart -= 1
|
| 1126 |
-
|
| 1127 |
-
# At this point have successfully decremented the part on
|
| 1128 |
-
# the stack and it does not appear in the cache. It needs
|
| 1129 |
-
# to be added to the list at the top of dp_stack
|
| 1130 |
-
pkey = part_key(self.top_part())
|
| 1131 |
-
self.dp_stack[-1].append((pkey, self.pcount),)
|
| 1132 |
-
|
| 1133 |
-
|
| 1134 |
-
def part_key(part):
|
| 1135 |
-
"""Helper for MultisetPartitionTraverser.count_partitions that
|
| 1136 |
-
creates a key for ``part``, that only includes information which can
|
| 1137 |
-
affect the count for that part. (Any irrelevant information just
|
| 1138 |
-
reduces the effectiveness of dynamic programming.)
|
| 1139 |
-
|
| 1140 |
-
Notes
|
| 1141 |
-
=====
|
| 1142 |
-
|
| 1143 |
-
This member function is a candidate for future exploration. There
|
| 1144 |
-
are likely symmetries that can be exploited to coalesce some
|
| 1145 |
-
``part_key`` values, and thereby save space and improve
|
| 1146 |
-
performance.
|
| 1147 |
-
|
| 1148 |
-
"""
|
| 1149 |
-
# The component number is irrelevant for counting partitions, so
|
| 1150 |
-
# leave it out of the memo key.
|
| 1151 |
-
rval = []
|
| 1152 |
-
for ps in part:
|
| 1153 |
-
rval.append(ps.u)
|
| 1154 |
-
rval.append(ps.v)
|
| 1155 |
-
return tuple(rval)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/exceptions.py
DELETED
|
@@ -1,271 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
General SymPy exceptions and warnings.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import warnings
|
| 6 |
-
import contextlib
|
| 7 |
-
|
| 8 |
-
from textwrap import dedent
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class SymPyDeprecationWarning(DeprecationWarning):
|
| 12 |
-
r"""
|
| 13 |
-
A warning for deprecated features of SymPy.
|
| 14 |
-
|
| 15 |
-
See the :ref:`deprecation-policy` document for details on when and how
|
| 16 |
-
things should be deprecated in SymPy.
|
| 17 |
-
|
| 18 |
-
Note that simply constructing this class will not cause a warning to be
|
| 19 |
-
issued. To do that, you must call the :func`sympy_deprecation_warning`
|
| 20 |
-
function. For this reason, it is not recommended to ever construct this
|
| 21 |
-
class directly.
|
| 22 |
-
|
| 23 |
-
Explanation
|
| 24 |
-
===========
|
| 25 |
-
|
| 26 |
-
The ``SymPyDeprecationWarning`` class is a subclass of
|
| 27 |
-
``DeprecationWarning`` that is used for all deprecations in SymPy. A
|
| 28 |
-
special subclass is used so that we can automatically augment the warning
|
| 29 |
-
message with additional metadata about the version the deprecation was
|
| 30 |
-
introduced in and a link to the documentation. This also allows users to
|
| 31 |
-
explicitly filter deprecation warnings from SymPy using ``warnings``
|
| 32 |
-
filters (see :ref:`silencing-sympy-deprecation-warnings`).
|
| 33 |
-
|
| 34 |
-
Additionally, ``SymPyDeprecationWarning`` is enabled to be shown by
|
| 35 |
-
default, unlike normal ``DeprecationWarning``\s, which are only shown by
|
| 36 |
-
default in interactive sessions. This ensures that deprecation warnings in
|
| 37 |
-
SymPy will actually be seen by users.
|
| 38 |
-
|
| 39 |
-
See the documentation of :func:`sympy_deprecation_warning` for a
|
| 40 |
-
description of the parameters to this function.
|
| 41 |
-
|
| 42 |
-
To mark a function as deprecated, you can use the :func:`@deprecated
|
| 43 |
-
<sympy.utilities.decorator.deprecated>` decorator.
|
| 44 |
-
|
| 45 |
-
See Also
|
| 46 |
-
========
|
| 47 |
-
sympy.utilities.exceptions.sympy_deprecation_warning
|
| 48 |
-
sympy.utilities.exceptions.ignore_warnings
|
| 49 |
-
sympy.utilities.decorator.deprecated
|
| 50 |
-
sympy.testing.pytest.warns_deprecated_sympy
|
| 51 |
-
|
| 52 |
-
"""
|
| 53 |
-
def __init__(self, message, *, deprecated_since_version, active_deprecations_target):
|
| 54 |
-
|
| 55 |
-
super().__init__(message, deprecated_since_version,
|
| 56 |
-
active_deprecations_target)
|
| 57 |
-
self.message = message
|
| 58 |
-
if not isinstance(deprecated_since_version, str):
|
| 59 |
-
raise TypeError(f"'deprecated_since_version' should be a string, got {deprecated_since_version!r}")
|
| 60 |
-
self.deprecated_since_version = deprecated_since_version
|
| 61 |
-
self.active_deprecations_target = active_deprecations_target
|
| 62 |
-
if any(i in active_deprecations_target for i in '()='):
|
| 63 |
-
raise ValueError("active_deprecations_target be the part inside of the '(...)='")
|
| 64 |
-
|
| 65 |
-
self.full_message = f"""
|
| 66 |
-
|
| 67 |
-
{dedent(message).strip()}
|
| 68 |
-
|
| 69 |
-
See https://docs.sympy.org/latest/explanation/active-deprecations.html#{active_deprecations_target}
|
| 70 |
-
for details.
|
| 71 |
-
|
| 72 |
-
This has been deprecated since SymPy version {deprecated_since_version}. It
|
| 73 |
-
will be removed in a future version of SymPy.
|
| 74 |
-
"""
|
| 75 |
-
|
| 76 |
-
def __str__(self):
|
| 77 |
-
return self.full_message
|
| 78 |
-
|
| 79 |
-
def __repr__(self):
|
| 80 |
-
return f"{self.__class__.__name__}({self.message!r}, deprecated_since_version={self.deprecated_since_version!r}, active_deprecations_target={self.active_deprecations_target!r})"
|
| 81 |
-
|
| 82 |
-
def __eq__(self, other):
|
| 83 |
-
return isinstance(other, SymPyDeprecationWarning) and self.args == other.args
|
| 84 |
-
|
| 85 |
-
# Make pickling work. The by default, it tries to recreate the expression
|
| 86 |
-
# from its args, but this doesn't work because of our keyword-only
|
| 87 |
-
# arguments.
|
| 88 |
-
@classmethod
|
| 89 |
-
def _new(cls, message, deprecated_since_version,
|
| 90 |
-
active_deprecations_target):
|
| 91 |
-
return cls(message, deprecated_since_version=deprecated_since_version, active_deprecations_target=active_deprecations_target)
|
| 92 |
-
|
| 93 |
-
def __reduce__(self):
|
| 94 |
-
return (self._new, (self.message, self.deprecated_since_version, self.active_deprecations_target))
|
| 95 |
-
|
| 96 |
-
# Python by default hides DeprecationWarnings, which we do not want.
|
| 97 |
-
warnings.simplefilter("once", SymPyDeprecationWarning)
|
| 98 |
-
|
| 99 |
-
def sympy_deprecation_warning(message, *, deprecated_since_version,
|
| 100 |
-
active_deprecations_target, stacklevel=3):
|
| 101 |
-
r'''
|
| 102 |
-
Warn that a feature is deprecated in SymPy.
|
| 103 |
-
|
| 104 |
-
See the :ref:`deprecation-policy` document for details on when and how
|
| 105 |
-
things should be deprecated in SymPy.
|
| 106 |
-
|
| 107 |
-
To mark an entire function or class as deprecated, you can use the
|
| 108 |
-
:func:`@deprecated <sympy.utilities.decorator.deprecated>` decorator.
|
| 109 |
-
|
| 110 |
-
Parameters
|
| 111 |
-
==========
|
| 112 |
-
|
| 113 |
-
message : str
|
| 114 |
-
The deprecation message. This may span multiple lines and contain
|
| 115 |
-
code examples. Messages should be wrapped to 80 characters. The
|
| 116 |
-
message is automatically dedented and leading and trailing whitespace
|
| 117 |
-
stripped. Messages may include dynamic content based on the user
|
| 118 |
-
input, but avoid using ``str(expression)`` if an expression can be
|
| 119 |
-
arbitrary, as it might be huge and make the warning message
|
| 120 |
-
unreadable.
|
| 121 |
-
|
| 122 |
-
deprecated_since_version : str
|
| 123 |
-
The version of SymPy the feature has been deprecated since. For new
|
| 124 |
-
deprecations, this should be the version in `sympy/release.py
|
| 125 |
-
<https://github.com/sympy/sympy/blob/master/sympy/release.py>`_
|
| 126 |
-
without the ``.dev``. If the next SymPy version ends up being
|
| 127 |
-
different from this, the release manager will need to update any
|
| 128 |
-
``SymPyDeprecationWarning``\s using the incorrect version. This
|
| 129 |
-
argument is required and must be passed as a keyword argument.
|
| 130 |
-
(example: ``deprecated_since_version="1.10"``).
|
| 131 |
-
|
| 132 |
-
active_deprecations_target : str
|
| 133 |
-
The Sphinx target corresponding to the section for the deprecation in
|
| 134 |
-
the :ref:`active-deprecations` document (see
|
| 135 |
-
``doc/src/explanation/active-deprecations.md``). This is used to
|
| 136 |
-
automatically generate a URL to the page in the warning message. This
|
| 137 |
-
argument is required and must be passed as a keyword argument.
|
| 138 |
-
(example: ``active_deprecations_target="deprecated-feature-abc"``)
|
| 139 |
-
|
| 140 |
-
stacklevel : int, default: 3
|
| 141 |
-
The ``stacklevel`` parameter that is passed to ``warnings.warn``. If
|
| 142 |
-
you create a wrapper that calls this function, this should be
|
| 143 |
-
increased so that the warning message shows the user line of code that
|
| 144 |
-
produced the warning. Note that in some cases there will be multiple
|
| 145 |
-
possible different user code paths that could result in the warning.
|
| 146 |
-
In that case, just choose the smallest common stacklevel.
|
| 147 |
-
|
| 148 |
-
Examples
|
| 149 |
-
========
|
| 150 |
-
|
| 151 |
-
>>> from sympy.utilities.exceptions import sympy_deprecation_warning
|
| 152 |
-
>>> def is_this_zero(x, y=0):
|
| 153 |
-
... """
|
| 154 |
-
... Determine if x = 0.
|
| 155 |
-
...
|
| 156 |
-
... Parameters
|
| 157 |
-
... ==========
|
| 158 |
-
...
|
| 159 |
-
... x : Expr
|
| 160 |
-
... The expression to check.
|
| 161 |
-
...
|
| 162 |
-
... y : Expr, optional
|
| 163 |
-
... If provided, check if x = y.
|
| 164 |
-
...
|
| 165 |
-
... .. deprecated:: 1.1
|
| 166 |
-
...
|
| 167 |
-
... The ``y`` argument to ``is_this_zero`` is deprecated. Use
|
| 168 |
-
... ``is_this_zero(x - y)`` instead.
|
| 169 |
-
...
|
| 170 |
-
... """
|
| 171 |
-
... from sympy import simplify
|
| 172 |
-
...
|
| 173 |
-
... if y != 0:
|
| 174 |
-
... sympy_deprecation_warning("""
|
| 175 |
-
... The y argument to is_zero() is deprecated. Use is_zero(x - y) instead.""",
|
| 176 |
-
... deprecated_since_version="1.1",
|
| 177 |
-
... active_deprecations_target='is-this-zero-y-deprecation')
|
| 178 |
-
... return simplify(x - y) == 0
|
| 179 |
-
>>> is_this_zero(0)
|
| 180 |
-
True
|
| 181 |
-
>>> is_this_zero(1, 1) # doctest: +SKIP
|
| 182 |
-
<stdin>:1: SymPyDeprecationWarning:
|
| 183 |
-
<BLANKLINE>
|
| 184 |
-
The y argument to is_zero() is deprecated. Use is_zero(x - y) instead.
|
| 185 |
-
<BLANKLINE>
|
| 186 |
-
See https://docs.sympy.org/latest/explanation/active-deprecations.html#is-this-zero-y-deprecation
|
| 187 |
-
for details.
|
| 188 |
-
<BLANKLINE>
|
| 189 |
-
This has been deprecated since SymPy version 1.1. It
|
| 190 |
-
will be removed in a future version of SymPy.
|
| 191 |
-
<BLANKLINE>
|
| 192 |
-
is_this_zero(1, 1)
|
| 193 |
-
True
|
| 194 |
-
|
| 195 |
-
See Also
|
| 196 |
-
========
|
| 197 |
-
|
| 198 |
-
sympy.utilities.exceptions.SymPyDeprecationWarning
|
| 199 |
-
sympy.utilities.exceptions.ignore_warnings
|
| 200 |
-
sympy.utilities.decorator.deprecated
|
| 201 |
-
sympy.testing.pytest.warns_deprecated_sympy
|
| 202 |
-
|
| 203 |
-
'''
|
| 204 |
-
w = SymPyDeprecationWarning(message,
|
| 205 |
-
deprecated_since_version=deprecated_since_version,
|
| 206 |
-
active_deprecations_target=active_deprecations_target)
|
| 207 |
-
warnings.warn(w, stacklevel=stacklevel)
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
@contextlib.contextmanager
|
| 211 |
-
def ignore_warnings(warningcls):
|
| 212 |
-
'''
|
| 213 |
-
Context manager to suppress warnings during tests.
|
| 214 |
-
|
| 215 |
-
.. note::
|
| 216 |
-
|
| 217 |
-
Do not use this with SymPyDeprecationWarning in the tests.
|
| 218 |
-
warns_deprecated_sympy() should be used instead.
|
| 219 |
-
|
| 220 |
-
This function is useful for suppressing warnings during tests. The warns
|
| 221 |
-
function should be used to assert that a warning is raised. The
|
| 222 |
-
ignore_warnings function is useful in situation when the warning is not
|
| 223 |
-
guaranteed to be raised (e.g. on importing a module) or if the warning
|
| 224 |
-
comes from third-party code.
|
| 225 |
-
|
| 226 |
-
This function is also useful to prevent the same or similar warnings from
|
| 227 |
-
being issue twice due to recursive calls.
|
| 228 |
-
|
| 229 |
-
When the warning is coming (reliably) from SymPy the warns function should
|
| 230 |
-
be preferred to ignore_warnings.
|
| 231 |
-
|
| 232 |
-
>>> from sympy.utilities.exceptions import ignore_warnings
|
| 233 |
-
>>> import warnings
|
| 234 |
-
|
| 235 |
-
Here's a warning:
|
| 236 |
-
|
| 237 |
-
>>> with warnings.catch_warnings(): # reset warnings in doctest
|
| 238 |
-
... warnings.simplefilter('error')
|
| 239 |
-
... warnings.warn('deprecated', UserWarning)
|
| 240 |
-
Traceback (most recent call last):
|
| 241 |
-
...
|
| 242 |
-
UserWarning: deprecated
|
| 243 |
-
|
| 244 |
-
Let's suppress it with ignore_warnings:
|
| 245 |
-
|
| 246 |
-
>>> with warnings.catch_warnings(): # reset warnings in doctest
|
| 247 |
-
... warnings.simplefilter('error')
|
| 248 |
-
... with ignore_warnings(UserWarning):
|
| 249 |
-
... warnings.warn('deprecated', UserWarning)
|
| 250 |
-
|
| 251 |
-
(No warning emitted)
|
| 252 |
-
|
| 253 |
-
See Also
|
| 254 |
-
========
|
| 255 |
-
sympy.utilities.exceptions.SymPyDeprecationWarning
|
| 256 |
-
sympy.utilities.exceptions.sympy_deprecation_warning
|
| 257 |
-
sympy.utilities.decorator.deprecated
|
| 258 |
-
sympy.testing.pytest.warns_deprecated_sympy
|
| 259 |
-
|
| 260 |
-
'''
|
| 261 |
-
# Absorbs all warnings in warnrec
|
| 262 |
-
with warnings.catch_warnings(record=True) as warnrec:
|
| 263 |
-
# Make sure our warning doesn't get filtered
|
| 264 |
-
warnings.simplefilter("always", warningcls)
|
| 265 |
-
# Now run the test
|
| 266 |
-
yield
|
| 267 |
-
|
| 268 |
-
# Reissue any warnings that we aren't testing for
|
| 269 |
-
for w in warnrec:
|
| 270 |
-
if not issubclass(w.category, warningcls):
|
| 271 |
-
warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/iterables.py
DELETED
|
@@ -1,3179 +0,0 @@
|
|
| 1 |
-
from collections import Counter, defaultdict, OrderedDict
|
| 2 |
-
from itertools import (
|
| 3 |
-
chain, combinations, combinations_with_replacement, cycle, islice,
|
| 4 |
-
permutations, product, groupby
|
| 5 |
-
)
|
| 6 |
-
# For backwards compatibility
|
| 7 |
-
from itertools import product as cartes # noqa: F401
|
| 8 |
-
from operator import gt
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
# this is the logical location of these functions
|
| 13 |
-
from sympy.utilities.enumerative import (
|
| 14 |
-
multiset_partitions_taocp, list_visitor, MultisetPartitionTraverser)
|
| 15 |
-
|
| 16 |
-
from sympy.utilities.misc import as_int
|
| 17 |
-
from sympy.utilities.decorator import deprecated
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def is_palindromic(s, i=0, j=None):
|
| 21 |
-
"""
|
| 22 |
-
Return True if the sequence is the same from left to right as it
|
| 23 |
-
is from right to left in the whole sequence (default) or in the
|
| 24 |
-
Python slice ``s[i: j]``; else False.
|
| 25 |
-
|
| 26 |
-
Examples
|
| 27 |
-
========
|
| 28 |
-
|
| 29 |
-
>>> from sympy.utilities.iterables import is_palindromic
|
| 30 |
-
>>> is_palindromic([1, 0, 1])
|
| 31 |
-
True
|
| 32 |
-
>>> is_palindromic('abcbb')
|
| 33 |
-
False
|
| 34 |
-
>>> is_palindromic('abcbb', 1)
|
| 35 |
-
False
|
| 36 |
-
|
| 37 |
-
Normal Python slicing is performed in place so there is no need to
|
| 38 |
-
create a slice of the sequence for testing:
|
| 39 |
-
|
| 40 |
-
>>> is_palindromic('abcbb', 1, -1)
|
| 41 |
-
True
|
| 42 |
-
>>> is_palindromic('abcbb', -4, -1)
|
| 43 |
-
True
|
| 44 |
-
|
| 45 |
-
See Also
|
| 46 |
-
========
|
| 47 |
-
|
| 48 |
-
sympy.ntheory.digits.is_palindromic: tests integers
|
| 49 |
-
|
| 50 |
-
"""
|
| 51 |
-
i, j, _ = slice(i, j).indices(len(s))
|
| 52 |
-
m = (j - i)//2
|
| 53 |
-
# if length is odd, middle element will be ignored
|
| 54 |
-
return all(s[i + k] == s[j - 1 - k] for k in range(m))
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def flatten(iterable, levels=None, cls=None): # noqa: F811
|
| 58 |
-
"""
|
| 59 |
-
Recursively denest iterable containers.
|
| 60 |
-
|
| 61 |
-
>>> from sympy import flatten
|
| 62 |
-
|
| 63 |
-
>>> flatten([1, 2, 3])
|
| 64 |
-
[1, 2, 3]
|
| 65 |
-
>>> flatten([1, 2, [3]])
|
| 66 |
-
[1, 2, 3]
|
| 67 |
-
>>> flatten([1, [2, 3], [4, 5]])
|
| 68 |
-
[1, 2, 3, 4, 5]
|
| 69 |
-
>>> flatten([1.0, 2, (1, None)])
|
| 70 |
-
[1.0, 2, 1, None]
|
| 71 |
-
|
| 72 |
-
If you want to denest only a specified number of levels of
|
| 73 |
-
nested containers, then set ``levels`` flag to the desired
|
| 74 |
-
number of levels::
|
| 75 |
-
|
| 76 |
-
>>> ls = [[(-2, -1), (1, 2)], [(0, 0)]]
|
| 77 |
-
|
| 78 |
-
>>> flatten(ls, levels=1)
|
| 79 |
-
[(-2, -1), (1, 2), (0, 0)]
|
| 80 |
-
|
| 81 |
-
If cls argument is specified, it will only flatten instances of that
|
| 82 |
-
class, for example:
|
| 83 |
-
|
| 84 |
-
>>> from sympy import Basic, S
|
| 85 |
-
>>> class MyOp(Basic):
|
| 86 |
-
... pass
|
| 87 |
-
...
|
| 88 |
-
>>> flatten([MyOp(S(1), MyOp(S(2), S(3)))], cls=MyOp)
|
| 89 |
-
[1, 2, 3]
|
| 90 |
-
|
| 91 |
-
adapted from https://kogs-www.informatik.uni-hamburg.de/~meine/python_tricks
|
| 92 |
-
"""
|
| 93 |
-
from sympy.tensor.array import NDimArray
|
| 94 |
-
if levels is not None:
|
| 95 |
-
if not levels:
|
| 96 |
-
return iterable
|
| 97 |
-
elif levels > 0:
|
| 98 |
-
levels -= 1
|
| 99 |
-
else:
|
| 100 |
-
raise ValueError(
|
| 101 |
-
"expected non-negative number of levels, got %s" % levels)
|
| 102 |
-
|
| 103 |
-
if cls is None:
|
| 104 |
-
def reducible(x):
|
| 105 |
-
return is_sequence(x, set)
|
| 106 |
-
else:
|
| 107 |
-
def reducible(x):
|
| 108 |
-
return isinstance(x, cls)
|
| 109 |
-
|
| 110 |
-
result = []
|
| 111 |
-
|
| 112 |
-
for el in iterable:
|
| 113 |
-
if reducible(el):
|
| 114 |
-
if hasattr(el, 'args') and not isinstance(el, NDimArray):
|
| 115 |
-
el = el.args
|
| 116 |
-
result.extend(flatten(el, levels=levels, cls=cls))
|
| 117 |
-
else:
|
| 118 |
-
result.append(el)
|
| 119 |
-
|
| 120 |
-
return result
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def unflatten(iter, n=2):
|
| 124 |
-
"""Group ``iter`` into tuples of length ``n``. Raise an error if
|
| 125 |
-
the length of ``iter`` is not a multiple of ``n``.
|
| 126 |
-
"""
|
| 127 |
-
if n < 1 or len(iter) % n:
|
| 128 |
-
raise ValueError('iter length is not a multiple of %i' % n)
|
| 129 |
-
return list(zip(*(iter[i::n] for i in range(n))))
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def reshape(seq, how):
|
| 133 |
-
"""Reshape the sequence according to the template in ``how``.
|
| 134 |
-
|
| 135 |
-
Examples
|
| 136 |
-
========
|
| 137 |
-
|
| 138 |
-
>>> from sympy.utilities import reshape
|
| 139 |
-
>>> seq = list(range(1, 9))
|
| 140 |
-
|
| 141 |
-
>>> reshape(seq, [4]) # lists of 4
|
| 142 |
-
[[1, 2, 3, 4], [5, 6, 7, 8]]
|
| 143 |
-
|
| 144 |
-
>>> reshape(seq, (4,)) # tuples of 4
|
| 145 |
-
[(1, 2, 3, 4), (5, 6, 7, 8)]
|
| 146 |
-
|
| 147 |
-
>>> reshape(seq, (2, 2)) # tuples of 4
|
| 148 |
-
[(1, 2, 3, 4), (5, 6, 7, 8)]
|
| 149 |
-
|
| 150 |
-
>>> reshape(seq, (2, [2])) # (i, i, [i, i])
|
| 151 |
-
[(1, 2, [3, 4]), (5, 6, [7, 8])]
|
| 152 |
-
|
| 153 |
-
>>> reshape(seq, ((2,), [2])) # etc....
|
| 154 |
-
[((1, 2), [3, 4]), ((5, 6), [7, 8])]
|
| 155 |
-
|
| 156 |
-
>>> reshape(seq, (1, [2], 1))
|
| 157 |
-
[(1, [2, 3], 4), (5, [6, 7], 8)]
|
| 158 |
-
|
| 159 |
-
>>> reshape(tuple(seq), ([[1], 1, (2,)],))
|
| 160 |
-
(([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],))
|
| 161 |
-
|
| 162 |
-
>>> reshape(tuple(seq), ([1], 1, (2,)))
|
| 163 |
-
(([1], 2, (3, 4)), ([5], 6, (7, 8)))
|
| 164 |
-
|
| 165 |
-
>>> reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)])
|
| 166 |
-
[[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]]
|
| 167 |
-
|
| 168 |
-
"""
|
| 169 |
-
m = sum(flatten(how))
|
| 170 |
-
n, rem = divmod(len(seq), m)
|
| 171 |
-
if m < 0 or rem:
|
| 172 |
-
raise ValueError('template must sum to positive number '
|
| 173 |
-
'that divides the length of the sequence')
|
| 174 |
-
i = 0
|
| 175 |
-
container = type(how)
|
| 176 |
-
rv = [None]*n
|
| 177 |
-
for k in range(len(rv)):
|
| 178 |
-
_rv = []
|
| 179 |
-
for hi in how:
|
| 180 |
-
if isinstance(hi, int):
|
| 181 |
-
_rv.extend(seq[i: i + hi])
|
| 182 |
-
i += hi
|
| 183 |
-
else:
|
| 184 |
-
n = sum(flatten(hi))
|
| 185 |
-
hi_type = type(hi)
|
| 186 |
-
_rv.append(hi_type(reshape(seq[i: i + n], hi)[0]))
|
| 187 |
-
i += n
|
| 188 |
-
rv[k] = container(_rv)
|
| 189 |
-
return type(seq)(rv)
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
def group(seq, multiple=True):
|
| 193 |
-
"""
|
| 194 |
-
Splits a sequence into a list of lists of equal, adjacent elements.
|
| 195 |
-
|
| 196 |
-
Examples
|
| 197 |
-
========
|
| 198 |
-
|
| 199 |
-
>>> from sympy import group
|
| 200 |
-
|
| 201 |
-
>>> group([1, 1, 1, 2, 2, 3])
|
| 202 |
-
[[1, 1, 1], [2, 2], [3]]
|
| 203 |
-
>>> group([1, 1, 1, 2, 2, 3], multiple=False)
|
| 204 |
-
[(1, 3), (2, 2), (3, 1)]
|
| 205 |
-
>>> group([1, 1, 3, 2, 2, 1], multiple=False)
|
| 206 |
-
[(1, 2), (3, 1), (2, 2), (1, 1)]
|
| 207 |
-
|
| 208 |
-
See Also
|
| 209 |
-
========
|
| 210 |
-
|
| 211 |
-
multiset
|
| 212 |
-
|
| 213 |
-
"""
|
| 214 |
-
if multiple:
|
| 215 |
-
return [(list(g)) for _, g in groupby(seq)]
|
| 216 |
-
return [(k, len(list(g))) for k, g in groupby(seq)]
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
def _iproduct2(iterable1, iterable2):
|
| 220 |
-
'''Cartesian product of two possibly infinite iterables'''
|
| 221 |
-
|
| 222 |
-
it1 = iter(iterable1)
|
| 223 |
-
it2 = iter(iterable2)
|
| 224 |
-
|
| 225 |
-
elems1 = []
|
| 226 |
-
elems2 = []
|
| 227 |
-
|
| 228 |
-
sentinel = object()
|
| 229 |
-
def append(it, elems):
|
| 230 |
-
e = next(it, sentinel)
|
| 231 |
-
if e is not sentinel:
|
| 232 |
-
elems.append(e)
|
| 233 |
-
|
| 234 |
-
n = 0
|
| 235 |
-
append(it1, elems1)
|
| 236 |
-
append(it2, elems2)
|
| 237 |
-
|
| 238 |
-
while n <= len(elems1) + len(elems2):
|
| 239 |
-
for m in range(n-len(elems1)+1, len(elems2)):
|
| 240 |
-
yield (elems1[n-m], elems2[m])
|
| 241 |
-
n += 1
|
| 242 |
-
append(it1, elems1)
|
| 243 |
-
append(it2, elems2)
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
def iproduct(*iterables):
|
| 247 |
-
'''
|
| 248 |
-
Cartesian product of iterables.
|
| 249 |
-
|
| 250 |
-
Generator of the Cartesian product of iterables. This is analogous to
|
| 251 |
-
itertools.product except that it works with infinite iterables and will
|
| 252 |
-
yield any item from the infinite product eventually.
|
| 253 |
-
|
| 254 |
-
Examples
|
| 255 |
-
========
|
| 256 |
-
|
| 257 |
-
>>> from sympy.utilities.iterables import iproduct
|
| 258 |
-
>>> sorted(iproduct([1,2], [3,4]))
|
| 259 |
-
[(1, 3), (1, 4), (2, 3), (2, 4)]
|
| 260 |
-
|
| 261 |
-
With an infinite iterator:
|
| 262 |
-
|
| 263 |
-
>>> from sympy import S
|
| 264 |
-
>>> (3,) in iproduct(S.Integers)
|
| 265 |
-
True
|
| 266 |
-
>>> (3, 4) in iproduct(S.Integers, S.Integers)
|
| 267 |
-
True
|
| 268 |
-
|
| 269 |
-
.. seealso::
|
| 270 |
-
|
| 271 |
-
`itertools.product
|
| 272 |
-
<https://docs.python.org/3/library/itertools.html#itertools.product>`_
|
| 273 |
-
'''
|
| 274 |
-
if len(iterables) == 0:
|
| 275 |
-
yield ()
|
| 276 |
-
return
|
| 277 |
-
elif len(iterables) == 1:
|
| 278 |
-
for e in iterables[0]:
|
| 279 |
-
yield (e,)
|
| 280 |
-
elif len(iterables) == 2:
|
| 281 |
-
yield from _iproduct2(*iterables)
|
| 282 |
-
else:
|
| 283 |
-
first, others = iterables[0], iterables[1:]
|
| 284 |
-
for ef, eo in _iproduct2(first, iproduct(*others)):
|
| 285 |
-
yield (ef,) + eo
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
def multiset(seq):
|
| 289 |
-
"""Return the hashable sequence in multiset form with values being the
|
| 290 |
-
multiplicity of the item in the sequence.
|
| 291 |
-
|
| 292 |
-
Examples
|
| 293 |
-
========
|
| 294 |
-
|
| 295 |
-
>>> from sympy.utilities.iterables import multiset
|
| 296 |
-
>>> multiset('mississippi')
|
| 297 |
-
{'i': 4, 'm': 1, 'p': 2, 's': 4}
|
| 298 |
-
|
| 299 |
-
See Also
|
| 300 |
-
========
|
| 301 |
-
|
| 302 |
-
group
|
| 303 |
-
|
| 304 |
-
"""
|
| 305 |
-
return dict(Counter(seq).items())
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
def ibin(n, bits=None, str=False):
|
| 311 |
-
"""Return a list of length ``bits`` corresponding to the binary value
|
| 312 |
-
of ``n`` with small bits to the right (last). If bits is omitted, the
|
| 313 |
-
length will be the number required to represent ``n``. If the bits are
|
| 314 |
-
desired in reversed order, use the ``[::-1]`` slice of the returned list.
|
| 315 |
-
|
| 316 |
-
If a sequence of all bits-length lists starting from ``[0, 0,..., 0]``
|
| 317 |
-
through ``[1, 1, ..., 1]`` are desired, pass a non-integer for bits, e.g.
|
| 318 |
-
``'all'``.
|
| 319 |
-
|
| 320 |
-
If the bit *string* is desired pass ``str=True``.
|
| 321 |
-
|
| 322 |
-
Examples
|
| 323 |
-
========
|
| 324 |
-
|
| 325 |
-
>>> from sympy.utilities.iterables import ibin
|
| 326 |
-
>>> ibin(2)
|
| 327 |
-
[1, 0]
|
| 328 |
-
>>> ibin(2, 4)
|
| 329 |
-
[0, 0, 1, 0]
|
| 330 |
-
|
| 331 |
-
If all lists corresponding to 0 to 2**n - 1, pass a non-integer
|
| 332 |
-
for bits:
|
| 333 |
-
|
| 334 |
-
>>> bits = 2
|
| 335 |
-
>>> for i in ibin(2, 'all'):
|
| 336 |
-
... print(i)
|
| 337 |
-
(0, 0)
|
| 338 |
-
(0, 1)
|
| 339 |
-
(1, 0)
|
| 340 |
-
(1, 1)
|
| 341 |
-
|
| 342 |
-
If a bit string is desired of a given length, use str=True:
|
| 343 |
-
|
| 344 |
-
>>> n = 123
|
| 345 |
-
>>> bits = 10
|
| 346 |
-
>>> ibin(n, bits, str=True)
|
| 347 |
-
'0001111011'
|
| 348 |
-
>>> ibin(n, bits, str=True)[::-1] # small bits left
|
| 349 |
-
'1101111000'
|
| 350 |
-
>>> list(ibin(3, 'all', str=True))
|
| 351 |
-
['000', '001', '010', '011', '100', '101', '110', '111']
|
| 352 |
-
|
| 353 |
-
"""
|
| 354 |
-
if n < 0:
|
| 355 |
-
raise ValueError("negative numbers are not allowed")
|
| 356 |
-
n = as_int(n)
|
| 357 |
-
|
| 358 |
-
if bits is None:
|
| 359 |
-
bits = 0
|
| 360 |
-
else:
|
| 361 |
-
try:
|
| 362 |
-
bits = as_int(bits)
|
| 363 |
-
except ValueError:
|
| 364 |
-
bits = -1
|
| 365 |
-
else:
|
| 366 |
-
if n.bit_length() > bits:
|
| 367 |
-
raise ValueError(
|
| 368 |
-
"`bits` must be >= {}".format(n.bit_length()))
|
| 369 |
-
|
| 370 |
-
if not str:
|
| 371 |
-
if bits >= 0:
|
| 372 |
-
return [1 if i == "1" else 0 for i in bin(n)[2:].rjust(bits, "0")]
|
| 373 |
-
else:
|
| 374 |
-
return variations(range(2), n, repetition=True)
|
| 375 |
-
else:
|
| 376 |
-
if bits >= 0:
|
| 377 |
-
return bin(n)[2:].rjust(bits, "0")
|
| 378 |
-
else:
|
| 379 |
-
return (bin(i)[2:].rjust(n, "0") for i in range(2**n))
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
def variations(seq, n, repetition=False):
|
| 383 |
-
r"""Returns an iterator over the n-sized variations of ``seq`` (size N).
|
| 384 |
-
``repetition`` controls whether items in ``seq`` can appear more than once;
|
| 385 |
-
|
| 386 |
-
Examples
|
| 387 |
-
========
|
| 388 |
-
|
| 389 |
-
``variations(seq, n)`` will return `\frac{N!}{(N - n)!}` permutations without
|
| 390 |
-
repetition of ``seq``'s elements:
|
| 391 |
-
|
| 392 |
-
>>> from sympy import variations
|
| 393 |
-
>>> list(variations([1, 2], 2))
|
| 394 |
-
[(1, 2), (2, 1)]
|
| 395 |
-
|
| 396 |
-
``variations(seq, n, True)`` will return the `N^n` permutations obtained
|
| 397 |
-
by allowing repetition of elements:
|
| 398 |
-
|
| 399 |
-
>>> list(variations([1, 2], 2, repetition=True))
|
| 400 |
-
[(1, 1), (1, 2), (2, 1), (2, 2)]
|
| 401 |
-
|
| 402 |
-
If you ask for more items than are in the set you get the empty set unless
|
| 403 |
-
you allow repetitions:
|
| 404 |
-
|
| 405 |
-
>>> list(variations([0, 1], 3, repetition=False))
|
| 406 |
-
[]
|
| 407 |
-
>>> list(variations([0, 1], 3, repetition=True))[:4]
|
| 408 |
-
[(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)]
|
| 409 |
-
|
| 410 |
-
.. seealso::
|
| 411 |
-
|
| 412 |
-
`itertools.permutations
|
| 413 |
-
<https://docs.python.org/3/library/itertools.html#itertools.permutations>`_,
|
| 414 |
-
`itertools.product
|
| 415 |
-
<https://docs.python.org/3/library/itertools.html#itertools.product>`_
|
| 416 |
-
"""
|
| 417 |
-
if not repetition:
|
| 418 |
-
seq = tuple(seq)
|
| 419 |
-
if len(seq) < n:
|
| 420 |
-
return iter(()) # 0 length iterator
|
| 421 |
-
return permutations(seq, n)
|
| 422 |
-
else:
|
| 423 |
-
if n == 0:
|
| 424 |
-
return iter(((),)) # yields 1 empty tuple
|
| 425 |
-
else:
|
| 426 |
-
return product(seq, repeat=n)
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
def subsets(seq, k=None, repetition=False):
|
| 430 |
-
r"""Generates all `k`-subsets (combinations) from an `n`-element set, ``seq``.
|
| 431 |
-
|
| 432 |
-
A `k`-subset of an `n`-element set is any subset of length exactly `k`. The
|
| 433 |
-
number of `k`-subsets of an `n`-element set is given by ``binomial(n, k)``,
|
| 434 |
-
whereas there are `2^n` subsets all together. If `k` is ``None`` then all
|
| 435 |
-
`2^n` subsets will be returned from shortest to longest.
|
| 436 |
-
|
| 437 |
-
Examples
|
| 438 |
-
========
|
| 439 |
-
|
| 440 |
-
>>> from sympy import subsets
|
| 441 |
-
|
| 442 |
-
``subsets(seq, k)`` will return the
|
| 443 |
-
`\frac{n!}{k!(n - k)!}` `k`-subsets (combinations)
|
| 444 |
-
without repetition, i.e. once an item has been removed, it can no
|
| 445 |
-
longer be "taken":
|
| 446 |
-
|
| 447 |
-
>>> list(subsets([1, 2], 2))
|
| 448 |
-
[(1, 2)]
|
| 449 |
-
>>> list(subsets([1, 2]))
|
| 450 |
-
[(), (1,), (2,), (1, 2)]
|
| 451 |
-
>>> list(subsets([1, 2, 3], 2))
|
| 452 |
-
[(1, 2), (1, 3), (2, 3)]
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
``subsets(seq, k, repetition=True)`` will return the
|
| 456 |
-
`\frac{(n - 1 + k)!}{k!(n - 1)!}`
|
| 457 |
-
combinations *with* repetition:
|
| 458 |
-
|
| 459 |
-
>>> list(subsets([1, 2], 2, repetition=True))
|
| 460 |
-
[(1, 1), (1, 2), (2, 2)]
|
| 461 |
-
|
| 462 |
-
If you ask for more items than are in the set you get the empty set unless
|
| 463 |
-
you allow repetitions:
|
| 464 |
-
|
| 465 |
-
>>> list(subsets([0, 1], 3, repetition=False))
|
| 466 |
-
[]
|
| 467 |
-
>>> list(subsets([0, 1], 3, repetition=True))
|
| 468 |
-
[(0, 0, 0), (0, 0, 1), (0, 1, 1), (1, 1, 1)]
|
| 469 |
-
|
| 470 |
-
"""
|
| 471 |
-
if k is None:
|
| 472 |
-
if not repetition:
|
| 473 |
-
return chain.from_iterable((combinations(seq, k)
|
| 474 |
-
for k in range(len(seq) + 1)))
|
| 475 |
-
else:
|
| 476 |
-
return chain.from_iterable((combinations_with_replacement(seq, k)
|
| 477 |
-
for k in range(len(seq) + 1)))
|
| 478 |
-
else:
|
| 479 |
-
if not repetition:
|
| 480 |
-
return combinations(seq, k)
|
| 481 |
-
else:
|
| 482 |
-
return combinations_with_replacement(seq, k)
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
def filter_symbols(iterator, exclude):
|
| 486 |
-
"""
|
| 487 |
-
Only yield elements from `iterator` that do not occur in `exclude`.
|
| 488 |
-
|
| 489 |
-
Parameters
|
| 490 |
-
==========
|
| 491 |
-
|
| 492 |
-
iterator : iterable
|
| 493 |
-
iterator to take elements from
|
| 494 |
-
|
| 495 |
-
exclude : iterable
|
| 496 |
-
elements to exclude
|
| 497 |
-
|
| 498 |
-
Returns
|
| 499 |
-
=======
|
| 500 |
-
|
| 501 |
-
iterator : iterator
|
| 502 |
-
filtered iterator
|
| 503 |
-
"""
|
| 504 |
-
exclude = set(exclude)
|
| 505 |
-
for s in iterator:
|
| 506 |
-
if s not in exclude:
|
| 507 |
-
yield s
|
| 508 |
-
|
| 509 |
-
def numbered_symbols(prefix='x', cls=None, start=0, exclude=(), *args, **assumptions):
|
| 510 |
-
"""
|
| 511 |
-
Generate an infinite stream of Symbols consisting of a prefix and
|
| 512 |
-
increasing subscripts provided that they do not occur in ``exclude``.
|
| 513 |
-
|
| 514 |
-
Parameters
|
| 515 |
-
==========
|
| 516 |
-
|
| 517 |
-
prefix : str, optional
|
| 518 |
-
The prefix to use. By default, this function will generate symbols of
|
| 519 |
-
the form "x0", "x1", etc.
|
| 520 |
-
|
| 521 |
-
cls : class, optional
|
| 522 |
-
The class to use. By default, it uses ``Symbol``, but you can also use ``Wild``
|
| 523 |
-
or ``Dummy``.
|
| 524 |
-
|
| 525 |
-
start : int, optional
|
| 526 |
-
The start number. By default, it is 0.
|
| 527 |
-
|
| 528 |
-
exclude : list, tuple, set of cls, optional
|
| 529 |
-
Symbols to be excluded.
|
| 530 |
-
|
| 531 |
-
*args, **kwargs
|
| 532 |
-
Additional positional and keyword arguments are passed to the *cls* class.
|
| 533 |
-
|
| 534 |
-
Returns
|
| 535 |
-
=======
|
| 536 |
-
|
| 537 |
-
sym : Symbol
|
| 538 |
-
The subscripted symbols.
|
| 539 |
-
"""
|
| 540 |
-
exclude = set(exclude or [])
|
| 541 |
-
if cls is None:
|
| 542 |
-
# We can't just make the default cls=Symbol because it isn't
|
| 543 |
-
# imported yet.
|
| 544 |
-
from sympy.core import Symbol
|
| 545 |
-
cls = Symbol
|
| 546 |
-
|
| 547 |
-
while True:
|
| 548 |
-
name = '%s%s' % (prefix, start)
|
| 549 |
-
s = cls(name, *args, **assumptions)
|
| 550 |
-
if s not in exclude:
|
| 551 |
-
yield s
|
| 552 |
-
start += 1
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
def capture(func):
|
| 556 |
-
"""Return the printed output of func().
|
| 557 |
-
|
| 558 |
-
``func`` should be a function without arguments that produces output with
|
| 559 |
-
print statements.
|
| 560 |
-
|
| 561 |
-
>>> from sympy.utilities.iterables import capture
|
| 562 |
-
>>> from sympy import pprint
|
| 563 |
-
>>> from sympy.abc import x
|
| 564 |
-
>>> def foo():
|
| 565 |
-
... print('hello world!')
|
| 566 |
-
...
|
| 567 |
-
>>> 'hello' in capture(foo) # foo, not foo()
|
| 568 |
-
True
|
| 569 |
-
>>> capture(lambda: pprint(2/x))
|
| 570 |
-
'2\\n-\\nx\\n'
|
| 571 |
-
|
| 572 |
-
"""
|
| 573 |
-
from io import StringIO
|
| 574 |
-
import sys
|
| 575 |
-
|
| 576 |
-
stdout = sys.stdout
|
| 577 |
-
sys.stdout = file = StringIO()
|
| 578 |
-
try:
|
| 579 |
-
func()
|
| 580 |
-
finally:
|
| 581 |
-
sys.stdout = stdout
|
| 582 |
-
return file.getvalue()
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
def sift(seq, keyfunc, binary=False):
|
| 586 |
-
"""
|
| 587 |
-
Sift the sequence, ``seq`` according to ``keyfunc``.
|
| 588 |
-
|
| 589 |
-
Returns
|
| 590 |
-
=======
|
| 591 |
-
|
| 592 |
-
When ``binary`` is ``False`` (default), the output is a dictionary
|
| 593 |
-
where elements of ``seq`` are stored in a list keyed to the value
|
| 594 |
-
of keyfunc for that element. If ``binary`` is True then a tuple
|
| 595 |
-
with lists ``T`` and ``F`` are returned where ``T`` is a list
|
| 596 |
-
containing elements of seq for which ``keyfunc`` was ``True`` and
|
| 597 |
-
``F`` containing those elements for which ``keyfunc`` was ``False``;
|
| 598 |
-
a ValueError is raised if the ``keyfunc`` is not binary.
|
| 599 |
-
|
| 600 |
-
Examples
|
| 601 |
-
========
|
| 602 |
-
|
| 603 |
-
>>> from sympy.utilities import sift
|
| 604 |
-
>>> from sympy.abc import x, y
|
| 605 |
-
>>> from sympy import sqrt, exp, pi, Tuple
|
| 606 |
-
|
| 607 |
-
>>> sift(range(5), lambda x: x % 2)
|
| 608 |
-
{0: [0, 2, 4], 1: [1, 3]}
|
| 609 |
-
|
| 610 |
-
sift() returns a defaultdict() object, so any key that has no matches will
|
| 611 |
-
give [].
|
| 612 |
-
|
| 613 |
-
>>> sift([x], lambda x: x.is_commutative)
|
| 614 |
-
{True: [x]}
|
| 615 |
-
>>> _[False]
|
| 616 |
-
[]
|
| 617 |
-
|
| 618 |
-
Sometimes you will not know how many keys you will get:
|
| 619 |
-
|
| 620 |
-
>>> sift([sqrt(x), exp(x), (y**x)**2],
|
| 621 |
-
... lambda x: x.as_base_exp()[0])
|
| 622 |
-
{E: [exp(x)], x: [sqrt(x)], y: [y**(2*x)]}
|
| 623 |
-
|
| 624 |
-
Sometimes you expect the results to be binary; the
|
| 625 |
-
results can be unpacked by setting ``binary`` to True:
|
| 626 |
-
|
| 627 |
-
>>> sift(range(4), lambda x: x % 2, binary=True)
|
| 628 |
-
([1, 3], [0, 2])
|
| 629 |
-
>>> sift(Tuple(1, pi), lambda x: x.is_rational, binary=True)
|
| 630 |
-
([1], [pi])
|
| 631 |
-
|
| 632 |
-
A ValueError is raised if the predicate was not actually binary
|
| 633 |
-
(which is a good test for the logic where sifting is used and
|
| 634 |
-
binary results were expected):
|
| 635 |
-
|
| 636 |
-
>>> unknown = exp(1) - pi # the rationality of this is unknown
|
| 637 |
-
>>> args = Tuple(1, pi, unknown)
|
| 638 |
-
>>> sift(args, lambda x: x.is_rational, binary=True)
|
| 639 |
-
Traceback (most recent call last):
|
| 640 |
-
...
|
| 641 |
-
ValueError: keyfunc gave non-binary output
|
| 642 |
-
|
| 643 |
-
The non-binary sifting shows that there were 3 keys generated:
|
| 644 |
-
|
| 645 |
-
>>> set(sift(args, lambda x: x.is_rational).keys())
|
| 646 |
-
{None, False, True}
|
| 647 |
-
|
| 648 |
-
If you need to sort the sifted items it might be better to use
|
| 649 |
-
``ordered`` which can economically apply multiple sort keys
|
| 650 |
-
to a sequence while sorting.
|
| 651 |
-
|
| 652 |
-
See Also
|
| 653 |
-
========
|
| 654 |
-
|
| 655 |
-
ordered
|
| 656 |
-
|
| 657 |
-
"""
|
| 658 |
-
if not binary:
|
| 659 |
-
m = defaultdict(list)
|
| 660 |
-
for i in seq:
|
| 661 |
-
m[keyfunc(i)].append(i)
|
| 662 |
-
return m
|
| 663 |
-
sift = F, T = [], []
|
| 664 |
-
for i in seq:
|
| 665 |
-
try:
|
| 666 |
-
sift[keyfunc(i)].append(i)
|
| 667 |
-
except (IndexError, TypeError):
|
| 668 |
-
raise ValueError('keyfunc gave non-binary output')
|
| 669 |
-
return T, F
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
def take(iter, n):
|
| 673 |
-
"""Return ``n`` items from ``iter`` iterator. """
|
| 674 |
-
return [ value for _, value in zip(range(n), iter) ]
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
def dict_merge(*dicts):
|
| 678 |
-
"""Merge dictionaries into a single dictionary. """
|
| 679 |
-
merged = {}
|
| 680 |
-
|
| 681 |
-
for dict in dicts:
|
| 682 |
-
merged.update(dict)
|
| 683 |
-
|
| 684 |
-
return merged
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
def common_prefix(*seqs):
|
| 688 |
-
"""Return the subsequence that is a common start of sequences in ``seqs``.
|
| 689 |
-
|
| 690 |
-
>>> from sympy.utilities.iterables import common_prefix
|
| 691 |
-
>>> common_prefix(list(range(3)))
|
| 692 |
-
[0, 1, 2]
|
| 693 |
-
>>> common_prefix(list(range(3)), list(range(4)))
|
| 694 |
-
[0, 1, 2]
|
| 695 |
-
>>> common_prefix([1, 2, 3], [1, 2, 5])
|
| 696 |
-
[1, 2]
|
| 697 |
-
>>> common_prefix([1, 2, 3], [1, 3, 5])
|
| 698 |
-
[1]
|
| 699 |
-
"""
|
| 700 |
-
if not all(seqs):
|
| 701 |
-
return []
|
| 702 |
-
elif len(seqs) == 1:
|
| 703 |
-
return seqs[0]
|
| 704 |
-
i = 0
|
| 705 |
-
for i in range(min(len(s) for s in seqs)):
|
| 706 |
-
if not all(seqs[j][i] == seqs[0][i] for j in range(len(seqs))):
|
| 707 |
-
break
|
| 708 |
-
else:
|
| 709 |
-
i += 1
|
| 710 |
-
return seqs[0][:i]
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
def common_suffix(*seqs):
|
| 714 |
-
"""Return the subsequence that is a common ending of sequences in ``seqs``.
|
| 715 |
-
|
| 716 |
-
>>> from sympy.utilities.iterables import common_suffix
|
| 717 |
-
>>> common_suffix(list(range(3)))
|
| 718 |
-
[0, 1, 2]
|
| 719 |
-
>>> common_suffix(list(range(3)), list(range(4)))
|
| 720 |
-
[]
|
| 721 |
-
>>> common_suffix([1, 2, 3], [9, 2, 3])
|
| 722 |
-
[2, 3]
|
| 723 |
-
>>> common_suffix([1, 2, 3], [9, 7, 3])
|
| 724 |
-
[3]
|
| 725 |
-
"""
|
| 726 |
-
|
| 727 |
-
if not all(seqs):
|
| 728 |
-
return []
|
| 729 |
-
elif len(seqs) == 1:
|
| 730 |
-
return seqs[0]
|
| 731 |
-
i = 0
|
| 732 |
-
for i in range(-1, -min(len(s) for s in seqs) - 1, -1):
|
| 733 |
-
if not all(seqs[j][i] == seqs[0][i] for j in range(len(seqs))):
|
| 734 |
-
break
|
| 735 |
-
else:
|
| 736 |
-
i -= 1
|
| 737 |
-
if i == -1:
|
| 738 |
-
return []
|
| 739 |
-
else:
|
| 740 |
-
return seqs[0][i + 1:]
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
def prefixes(seq):
|
| 744 |
-
"""
|
| 745 |
-
Generate all prefixes of a sequence.
|
| 746 |
-
|
| 747 |
-
Examples
|
| 748 |
-
========
|
| 749 |
-
|
| 750 |
-
>>> from sympy.utilities.iterables import prefixes
|
| 751 |
-
|
| 752 |
-
>>> list(prefixes([1,2,3,4]))
|
| 753 |
-
[[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]
|
| 754 |
-
|
| 755 |
-
"""
|
| 756 |
-
n = len(seq)
|
| 757 |
-
|
| 758 |
-
for i in range(n):
|
| 759 |
-
yield seq[:i + 1]
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
def postfixes(seq):
|
| 763 |
-
"""
|
| 764 |
-
Generate all postfixes of a sequence.
|
| 765 |
-
|
| 766 |
-
Examples
|
| 767 |
-
========
|
| 768 |
-
|
| 769 |
-
>>> from sympy.utilities.iterables import postfixes
|
| 770 |
-
|
| 771 |
-
>>> list(postfixes([1,2,3,4]))
|
| 772 |
-
[[4], [3, 4], [2, 3, 4], [1, 2, 3, 4]]
|
| 773 |
-
|
| 774 |
-
"""
|
| 775 |
-
n = len(seq)
|
| 776 |
-
|
| 777 |
-
for i in range(n):
|
| 778 |
-
yield seq[n - i - 1:]
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
def topological_sort(graph, key=None):
|
| 782 |
-
r"""
|
| 783 |
-
Topological sort of graph's vertices.
|
| 784 |
-
|
| 785 |
-
Parameters
|
| 786 |
-
==========
|
| 787 |
-
|
| 788 |
-
graph : tuple[list, list[tuple[T, T]]
|
| 789 |
-
A tuple consisting of a list of vertices and a list of edges of
|
| 790 |
-
a graph to be sorted topologically.
|
| 791 |
-
|
| 792 |
-
key : callable[T] (optional)
|
| 793 |
-
Ordering key for vertices on the same level. By default the natural
|
| 794 |
-
(e.g. lexicographic) ordering is used (in this case the base type
|
| 795 |
-
must implement ordering relations).
|
| 796 |
-
|
| 797 |
-
Examples
|
| 798 |
-
========
|
| 799 |
-
|
| 800 |
-
Consider a graph::
|
| 801 |
-
|
| 802 |
-
+---+ +---+ +---+
|
| 803 |
-
| 7 |\ | 5 | | 3 |
|
| 804 |
-
+---+ \ +---+ +---+
|
| 805 |
-
| _\___/ ____ _/ |
|
| 806 |
-
| / \___/ \ / |
|
| 807 |
-
V V V V |
|
| 808 |
-
+----+ +---+ |
|
| 809 |
-
| 11 | | 8 | |
|
| 810 |
-
+----+ +---+ |
|
| 811 |
-
| | \____ ___/ _ |
|
| 812 |
-
| \ \ / / \ |
|
| 813 |
-
V \ V V / V V
|
| 814 |
-
+---+ \ +---+ | +----+
|
| 815 |
-
| 2 | | | 9 | | | 10 |
|
| 816 |
-
+---+ | +---+ | +----+
|
| 817 |
-
\________/
|
| 818 |
-
|
| 819 |
-
where vertices are integers. This graph can be encoded using
|
| 820 |
-
elementary Python's data structures as follows::
|
| 821 |
-
|
| 822 |
-
>>> V = [2, 3, 5, 7, 8, 9, 10, 11]
|
| 823 |
-
>>> E = [(7, 11), (7, 8), (5, 11), (3, 8), (3, 10),
|
| 824 |
-
... (11, 2), (11, 9), (11, 10), (8, 9)]
|
| 825 |
-
|
| 826 |
-
To compute a topological sort for graph ``(V, E)`` issue::
|
| 827 |
-
|
| 828 |
-
>>> from sympy.utilities.iterables import topological_sort
|
| 829 |
-
|
| 830 |
-
>>> topological_sort((V, E))
|
| 831 |
-
[3, 5, 7, 8, 11, 2, 9, 10]
|
| 832 |
-
|
| 833 |
-
If specific tie breaking approach is needed, use ``key`` parameter::
|
| 834 |
-
|
| 835 |
-
>>> topological_sort((V, E), key=lambda v: -v)
|
| 836 |
-
[7, 5, 11, 3, 10, 8, 9, 2]
|
| 837 |
-
|
| 838 |
-
Only acyclic graphs can be sorted. If the input graph has a cycle,
|
| 839 |
-
then ``ValueError`` will be raised::
|
| 840 |
-
|
| 841 |
-
>>> topological_sort((V, E + [(10, 7)]))
|
| 842 |
-
Traceback (most recent call last):
|
| 843 |
-
...
|
| 844 |
-
ValueError: cycle detected
|
| 845 |
-
|
| 846 |
-
References
|
| 847 |
-
==========
|
| 848 |
-
|
| 849 |
-
.. [1] https://en.wikipedia.org/wiki/Topological_sorting
|
| 850 |
-
|
| 851 |
-
"""
|
| 852 |
-
V, E = graph
|
| 853 |
-
|
| 854 |
-
L = []
|
| 855 |
-
S = set(V)
|
| 856 |
-
E = list(E)
|
| 857 |
-
|
| 858 |
-
S.difference_update(u for v, u in E)
|
| 859 |
-
|
| 860 |
-
if key is None:
|
| 861 |
-
def key(value):
|
| 862 |
-
return value
|
| 863 |
-
|
| 864 |
-
S = sorted(S, key=key, reverse=True)
|
| 865 |
-
|
| 866 |
-
while S:
|
| 867 |
-
node = S.pop()
|
| 868 |
-
L.append(node)
|
| 869 |
-
|
| 870 |
-
for u, v in list(E):
|
| 871 |
-
if u == node:
|
| 872 |
-
E.remove((u, v))
|
| 873 |
-
|
| 874 |
-
for _u, _v in E:
|
| 875 |
-
if v == _v:
|
| 876 |
-
break
|
| 877 |
-
else:
|
| 878 |
-
kv = key(v)
|
| 879 |
-
|
| 880 |
-
for i, s in enumerate(S):
|
| 881 |
-
ks = key(s)
|
| 882 |
-
|
| 883 |
-
if kv > ks:
|
| 884 |
-
S.insert(i, v)
|
| 885 |
-
break
|
| 886 |
-
else:
|
| 887 |
-
S.append(v)
|
| 888 |
-
|
| 889 |
-
if E:
|
| 890 |
-
raise ValueError("cycle detected")
|
| 891 |
-
else:
|
| 892 |
-
return L
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
def strongly_connected_components(G):
|
| 896 |
-
r"""
|
| 897 |
-
Strongly connected components of a directed graph in reverse topological
|
| 898 |
-
order.
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
Parameters
|
| 902 |
-
==========
|
| 903 |
-
|
| 904 |
-
G : tuple[list, list[tuple[T, T]]
|
| 905 |
-
A tuple consisting of a list of vertices and a list of edges of
|
| 906 |
-
a graph whose strongly connected components are to be found.
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
Examples
|
| 910 |
-
========
|
| 911 |
-
|
| 912 |
-
Consider a directed graph (in dot notation)::
|
| 913 |
-
|
| 914 |
-
digraph {
|
| 915 |
-
A -> B
|
| 916 |
-
A -> C
|
| 917 |
-
B -> C
|
| 918 |
-
C -> B
|
| 919 |
-
B -> D
|
| 920 |
-
}
|
| 921 |
-
|
| 922 |
-
.. graphviz::
|
| 923 |
-
|
| 924 |
-
digraph {
|
| 925 |
-
A -> B
|
| 926 |
-
A -> C
|
| 927 |
-
B -> C
|
| 928 |
-
C -> B
|
| 929 |
-
B -> D
|
| 930 |
-
}
|
| 931 |
-
|
| 932 |
-
where vertices are the letters A, B, C and D. This graph can be encoded
|
| 933 |
-
using Python's elementary data structures as follows::
|
| 934 |
-
|
| 935 |
-
>>> V = ['A', 'B', 'C', 'D']
|
| 936 |
-
>>> E = [('A', 'B'), ('A', 'C'), ('B', 'C'), ('C', 'B'), ('B', 'D')]
|
| 937 |
-
|
| 938 |
-
The strongly connected components of this graph can be computed as
|
| 939 |
-
|
| 940 |
-
>>> from sympy.utilities.iterables import strongly_connected_components
|
| 941 |
-
|
| 942 |
-
>>> strongly_connected_components((V, E))
|
| 943 |
-
[['D'], ['B', 'C'], ['A']]
|
| 944 |
-
|
| 945 |
-
This also gives the components in reverse topological order.
|
| 946 |
-
|
| 947 |
-
Since the subgraph containing B and C has a cycle they must be together in
|
| 948 |
-
a strongly connected component. A and D are connected to the rest of the
|
| 949 |
-
graph but not in a cyclic manner so they appear as their own strongly
|
| 950 |
-
connected components.
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
Notes
|
| 954 |
-
=====
|
| 955 |
-
|
| 956 |
-
The vertices of the graph must be hashable for the data structures used.
|
| 957 |
-
If the vertices are unhashable replace them with integer indices.
|
| 958 |
-
|
| 959 |
-
This function uses Tarjan's algorithm to compute the strongly connected
|
| 960 |
-
components in `O(|V|+|E|)` (linear) time.
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
References
|
| 964 |
-
==========
|
| 965 |
-
|
| 966 |
-
.. [1] https://en.wikipedia.org/wiki/Strongly_connected_component
|
| 967 |
-
.. [2] https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
See Also
|
| 971 |
-
========
|
| 972 |
-
|
| 973 |
-
sympy.utilities.iterables.connected_components
|
| 974 |
-
|
| 975 |
-
"""
|
| 976 |
-
# Map from a vertex to its neighbours
|
| 977 |
-
V, E = G
|
| 978 |
-
Gmap = {vi: [] for vi in V}
|
| 979 |
-
for v1, v2 in E:
|
| 980 |
-
Gmap[v1].append(v2)
|
| 981 |
-
return _strongly_connected_components(V, Gmap)
|
| 982 |
-
|
| 983 |
-
|
| 984 |
-
def _strongly_connected_components(V, Gmap):
|
| 985 |
-
"""More efficient internal routine for strongly_connected_components"""
|
| 986 |
-
#
|
| 987 |
-
# Here V is an iterable of vertices and Gmap is a dict mapping each vertex
|
| 988 |
-
# to a list of neighbours e.g.:
|
| 989 |
-
#
|
| 990 |
-
# V = [0, 1, 2, 3]
|
| 991 |
-
# Gmap = {0: [2, 3], 1: [0]}
|
| 992 |
-
#
|
| 993 |
-
# For a large graph these data structures can often be created more
|
| 994 |
-
# efficiently then those expected by strongly_connected_components() which
|
| 995 |
-
# in this case would be
|
| 996 |
-
#
|
| 997 |
-
# V = [0, 1, 2, 3]
|
| 998 |
-
# Gmap = [(0, 2), (0, 3), (1, 0)]
|
| 999 |
-
#
|
| 1000 |
-
# XXX: Maybe this should be the recommended function to use instead...
|
| 1001 |
-
#
|
| 1002 |
-
|
| 1003 |
-
# Non-recursive Tarjan's algorithm:
|
| 1004 |
-
lowlink = {}
|
| 1005 |
-
indices = {}
|
| 1006 |
-
stack = OrderedDict()
|
| 1007 |
-
callstack = []
|
| 1008 |
-
components = []
|
| 1009 |
-
nomore = object()
|
| 1010 |
-
|
| 1011 |
-
def start(v):
|
| 1012 |
-
index = len(stack)
|
| 1013 |
-
indices[v] = lowlink[v] = index
|
| 1014 |
-
stack[v] = None
|
| 1015 |
-
callstack.append((v, iter(Gmap[v])))
|
| 1016 |
-
|
| 1017 |
-
def finish(v1):
|
| 1018 |
-
# Finished a component?
|
| 1019 |
-
if lowlink[v1] == indices[v1]:
|
| 1020 |
-
component = [stack.popitem()[0]]
|
| 1021 |
-
while component[-1] is not v1:
|
| 1022 |
-
component.append(stack.popitem()[0])
|
| 1023 |
-
components.append(component[::-1])
|
| 1024 |
-
v2, _ = callstack.pop()
|
| 1025 |
-
if callstack:
|
| 1026 |
-
v1, _ = callstack[-1]
|
| 1027 |
-
lowlink[v1] = min(lowlink[v1], lowlink[v2])
|
| 1028 |
-
|
| 1029 |
-
for v in V:
|
| 1030 |
-
if v in indices:
|
| 1031 |
-
continue
|
| 1032 |
-
start(v)
|
| 1033 |
-
while callstack:
|
| 1034 |
-
v1, it1 = callstack[-1]
|
| 1035 |
-
v2 = next(it1, nomore)
|
| 1036 |
-
# Finished children of v1?
|
| 1037 |
-
if v2 is nomore:
|
| 1038 |
-
finish(v1)
|
| 1039 |
-
# Recurse on v2
|
| 1040 |
-
elif v2 not in indices:
|
| 1041 |
-
start(v2)
|
| 1042 |
-
elif v2 in stack:
|
| 1043 |
-
lowlink[v1] = min(lowlink[v1], indices[v2])
|
| 1044 |
-
|
| 1045 |
-
# Reverse topological sort order:
|
| 1046 |
-
return components
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
def connected_components(G):
|
| 1050 |
-
r"""
|
| 1051 |
-
Connected components of an undirected graph or weakly connected components
|
| 1052 |
-
of a directed graph.
|
| 1053 |
-
|
| 1054 |
-
|
| 1055 |
-
Parameters
|
| 1056 |
-
==========
|
| 1057 |
-
|
| 1058 |
-
G : tuple[list, list[tuple[T, T]]
|
| 1059 |
-
A tuple consisting of a list of vertices and a list of edges of
|
| 1060 |
-
a graph whose connected components are to be found.
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
Examples
|
| 1064 |
-
========
|
| 1065 |
-
|
| 1066 |
-
|
| 1067 |
-
Given an undirected graph::
|
| 1068 |
-
|
| 1069 |
-
graph {
|
| 1070 |
-
A -- B
|
| 1071 |
-
C -- D
|
| 1072 |
-
}
|
| 1073 |
-
|
| 1074 |
-
.. graphviz::
|
| 1075 |
-
|
| 1076 |
-
graph {
|
| 1077 |
-
A -- B
|
| 1078 |
-
C -- D
|
| 1079 |
-
}
|
| 1080 |
-
|
| 1081 |
-
We can find the connected components using this function if we include
|
| 1082 |
-
each edge in both directions::
|
| 1083 |
-
|
| 1084 |
-
>>> from sympy.utilities.iterables import connected_components
|
| 1085 |
-
|
| 1086 |
-
>>> V = ['A', 'B', 'C', 'D']
|
| 1087 |
-
>>> E = [('A', 'B'), ('B', 'A'), ('C', 'D'), ('D', 'C')]
|
| 1088 |
-
>>> connected_components((V, E))
|
| 1089 |
-
[['A', 'B'], ['C', 'D']]
|
| 1090 |
-
|
| 1091 |
-
The weakly connected components of a directed graph can found the same
|
| 1092 |
-
way.
|
| 1093 |
-
|
| 1094 |
-
|
| 1095 |
-
Notes
|
| 1096 |
-
=====
|
| 1097 |
-
|
| 1098 |
-
The vertices of the graph must be hashable for the data structures used.
|
| 1099 |
-
If the vertices are unhashable replace them with integer indices.
|
| 1100 |
-
|
| 1101 |
-
This function uses Tarjan's algorithm to compute the connected components
|
| 1102 |
-
in `O(|V|+|E|)` (linear) time.
|
| 1103 |
-
|
| 1104 |
-
|
| 1105 |
-
References
|
| 1106 |
-
==========
|
| 1107 |
-
|
| 1108 |
-
.. [1] https://en.wikipedia.org/wiki/Component_%28graph_theory%29
|
| 1109 |
-
.. [2] https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
|
| 1110 |
-
|
| 1111 |
-
|
| 1112 |
-
See Also
|
| 1113 |
-
========
|
| 1114 |
-
|
| 1115 |
-
sympy.utilities.iterables.strongly_connected_components
|
| 1116 |
-
|
| 1117 |
-
"""
|
| 1118 |
-
# Duplicate edges both ways so that the graph is effectively undirected
|
| 1119 |
-
# and return the strongly connected components:
|
| 1120 |
-
V, E = G
|
| 1121 |
-
E_undirected = []
|
| 1122 |
-
for v1, v2 in E:
|
| 1123 |
-
E_undirected.extend([(v1, v2), (v2, v1)])
|
| 1124 |
-
return strongly_connected_components((V, E_undirected))
|
| 1125 |
-
|
| 1126 |
-
|
| 1127 |
-
def rotate_left(x, y):
|
| 1128 |
-
"""
|
| 1129 |
-
Left rotates a list x by the number of steps specified
|
| 1130 |
-
in y.
|
| 1131 |
-
|
| 1132 |
-
Examples
|
| 1133 |
-
========
|
| 1134 |
-
|
| 1135 |
-
>>> from sympy.utilities.iterables import rotate_left
|
| 1136 |
-
>>> a = [0, 1, 2]
|
| 1137 |
-
>>> rotate_left(a, 1)
|
| 1138 |
-
[1, 2, 0]
|
| 1139 |
-
"""
|
| 1140 |
-
if len(x) == 0:
|
| 1141 |
-
return []
|
| 1142 |
-
y = y % len(x)
|
| 1143 |
-
return x[y:] + x[:y]
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
-
def rotate_right(x, y):
|
| 1147 |
-
"""
|
| 1148 |
-
Right rotates a list x by the number of steps specified
|
| 1149 |
-
in y.
|
| 1150 |
-
|
| 1151 |
-
Examples
|
| 1152 |
-
========
|
| 1153 |
-
|
| 1154 |
-
>>> from sympy.utilities.iterables import rotate_right
|
| 1155 |
-
>>> a = [0, 1, 2]
|
| 1156 |
-
>>> rotate_right(a, 1)
|
| 1157 |
-
[2, 0, 1]
|
| 1158 |
-
"""
|
| 1159 |
-
if len(x) == 0:
|
| 1160 |
-
return []
|
| 1161 |
-
y = len(x) - y % len(x)
|
| 1162 |
-
return x[y:] + x[:y]
|
| 1163 |
-
|
| 1164 |
-
|
| 1165 |
-
def least_rotation(x, key=None):
|
| 1166 |
-
'''
|
| 1167 |
-
Returns the number of steps of left rotation required to
|
| 1168 |
-
obtain lexicographically minimal string/list/tuple, etc.
|
| 1169 |
-
|
| 1170 |
-
Examples
|
| 1171 |
-
========
|
| 1172 |
-
|
| 1173 |
-
>>> from sympy.utilities.iterables import least_rotation, rotate_left
|
| 1174 |
-
>>> a = [3, 1, 5, 1, 2]
|
| 1175 |
-
>>> least_rotation(a)
|
| 1176 |
-
3
|
| 1177 |
-
>>> rotate_left(a, _)
|
| 1178 |
-
[1, 2, 3, 1, 5]
|
| 1179 |
-
|
| 1180 |
-
References
|
| 1181 |
-
==========
|
| 1182 |
-
|
| 1183 |
-
.. [1] https://en.wikipedia.org/wiki/Lexicographically_minimal_string_rotation
|
| 1184 |
-
|
| 1185 |
-
'''
|
| 1186 |
-
from sympy.functions.elementary.miscellaneous import Id
|
| 1187 |
-
if key is None: key = Id
|
| 1188 |
-
S = x + x # Concatenate string to it self to avoid modular arithmetic
|
| 1189 |
-
f = [-1] * len(S) # Failure function
|
| 1190 |
-
k = 0 # Least rotation of string found so far
|
| 1191 |
-
for j in range(1,len(S)):
|
| 1192 |
-
sj = S[j]
|
| 1193 |
-
i = f[j-k-1]
|
| 1194 |
-
while i != -1 and sj != S[k+i+1]:
|
| 1195 |
-
if key(sj) < key(S[k+i+1]):
|
| 1196 |
-
k = j-i-1
|
| 1197 |
-
i = f[i]
|
| 1198 |
-
if sj != S[k+i+1]:
|
| 1199 |
-
if key(sj) < key(S[k]):
|
| 1200 |
-
k = j
|
| 1201 |
-
f[j-k] = -1
|
| 1202 |
-
else:
|
| 1203 |
-
f[j-k] = i+1
|
| 1204 |
-
return k
|
| 1205 |
-
|
| 1206 |
-
|
| 1207 |
-
def multiset_combinations(m, n, g=None):
|
| 1208 |
-
"""
|
| 1209 |
-
Return the unique combinations of size ``n`` from multiset ``m``.
|
| 1210 |
-
|
| 1211 |
-
Examples
|
| 1212 |
-
========
|
| 1213 |
-
|
| 1214 |
-
>>> from sympy.utilities.iterables import multiset_combinations
|
| 1215 |
-
>>> from itertools import combinations
|
| 1216 |
-
>>> [''.join(i) for i in multiset_combinations('baby', 3)]
|
| 1217 |
-
['abb', 'aby', 'bby']
|
| 1218 |
-
|
| 1219 |
-
>>> def count(f, s): return len(list(f(s, 3)))
|
| 1220 |
-
|
| 1221 |
-
The number of combinations depends on the number of letters; the
|
| 1222 |
-
number of unique combinations depends on how the letters are
|
| 1223 |
-
repeated.
|
| 1224 |
-
|
| 1225 |
-
>>> s1 = 'abracadabra'
|
| 1226 |
-
>>> s2 = 'banana tree'
|
| 1227 |
-
>>> count(combinations, s1), count(multiset_combinations, s1)
|
| 1228 |
-
(165, 23)
|
| 1229 |
-
>>> count(combinations, s2), count(multiset_combinations, s2)
|
| 1230 |
-
(165, 54)
|
| 1231 |
-
|
| 1232 |
-
"""
|
| 1233 |
-
from sympy.core.sorting import ordered
|
| 1234 |
-
if g is None:
|
| 1235 |
-
if isinstance(m, dict):
|
| 1236 |
-
if any(as_int(v) < 0 for v in m.values()):
|
| 1237 |
-
raise ValueError('counts cannot be negative')
|
| 1238 |
-
N = sum(m.values())
|
| 1239 |
-
if n > N:
|
| 1240 |
-
return
|
| 1241 |
-
g = [[k, m[k]] for k in ordered(m)]
|
| 1242 |
-
else:
|
| 1243 |
-
m = list(m)
|
| 1244 |
-
N = len(m)
|
| 1245 |
-
if n > N:
|
| 1246 |
-
return
|
| 1247 |
-
try:
|
| 1248 |
-
m = multiset(m)
|
| 1249 |
-
g = [(k, m[k]) for k in ordered(m)]
|
| 1250 |
-
except TypeError:
|
| 1251 |
-
m = list(ordered(m))
|
| 1252 |
-
g = [list(i) for i in group(m, multiple=False)]
|
| 1253 |
-
del m
|
| 1254 |
-
else:
|
| 1255 |
-
# not checking counts since g is intended for internal use
|
| 1256 |
-
N = sum(v for k, v in g)
|
| 1257 |
-
if n > N or not n:
|
| 1258 |
-
yield []
|
| 1259 |
-
else:
|
| 1260 |
-
for i, (k, v) in enumerate(g):
|
| 1261 |
-
if v >= n:
|
| 1262 |
-
yield [k]*n
|
| 1263 |
-
v = n - 1
|
| 1264 |
-
for v in range(min(n, v), 0, -1):
|
| 1265 |
-
for j in multiset_combinations(None, n - v, g[i + 1:]):
|
| 1266 |
-
rv = [k]*v + j
|
| 1267 |
-
if len(rv) == n:
|
| 1268 |
-
yield rv
|
| 1269 |
-
|
| 1270 |
-
def multiset_permutations(m, size=None, g=None):
|
| 1271 |
-
"""
|
| 1272 |
-
Return the unique permutations of multiset ``m``.
|
| 1273 |
-
|
| 1274 |
-
Examples
|
| 1275 |
-
========
|
| 1276 |
-
|
| 1277 |
-
>>> from sympy.utilities.iterables import multiset_permutations
|
| 1278 |
-
>>> from sympy import factorial
|
| 1279 |
-
>>> [''.join(i) for i in multiset_permutations('aab')]
|
| 1280 |
-
['aab', 'aba', 'baa']
|
| 1281 |
-
>>> factorial(len('banana'))
|
| 1282 |
-
720
|
| 1283 |
-
>>> len(list(multiset_permutations('banana')))
|
| 1284 |
-
60
|
| 1285 |
-
"""
|
| 1286 |
-
from sympy.core.sorting import ordered
|
| 1287 |
-
if g is None:
|
| 1288 |
-
if isinstance(m, dict):
|
| 1289 |
-
if any(as_int(v) < 0 for v in m.values()):
|
| 1290 |
-
raise ValueError('counts cannot be negative')
|
| 1291 |
-
g = [[k, m[k]] for k in ordered(m)]
|
| 1292 |
-
else:
|
| 1293 |
-
m = list(ordered(m))
|
| 1294 |
-
g = [list(i) for i in group(m, multiple=False)]
|
| 1295 |
-
del m
|
| 1296 |
-
do = [gi for gi in g if gi[1] > 0]
|
| 1297 |
-
SUM = sum(gi[1] for gi in do)
|
| 1298 |
-
if not do or size is not None and (size > SUM or size < 1):
|
| 1299 |
-
if not do and size is None or size == 0:
|
| 1300 |
-
yield []
|
| 1301 |
-
return
|
| 1302 |
-
elif size == 1:
|
| 1303 |
-
for k, v in do:
|
| 1304 |
-
yield [k]
|
| 1305 |
-
elif len(do) == 1:
|
| 1306 |
-
k, v = do[0]
|
| 1307 |
-
v = v if size is None else (size if size <= v else 0)
|
| 1308 |
-
yield [k for i in range(v)]
|
| 1309 |
-
elif all(v == 1 for k, v in do):
|
| 1310 |
-
for p in permutations([k for k, v in do], size):
|
| 1311 |
-
yield list(p)
|
| 1312 |
-
else:
|
| 1313 |
-
size = size if size is not None else SUM
|
| 1314 |
-
for i, (k, v) in enumerate(do):
|
| 1315 |
-
do[i][1] -= 1
|
| 1316 |
-
for j in multiset_permutations(None, size - 1, do):
|
| 1317 |
-
if j:
|
| 1318 |
-
yield [k] + j
|
| 1319 |
-
do[i][1] += 1
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
def _partition(seq, vector, m=None):
|
| 1323 |
-
"""
|
| 1324 |
-
Return the partition of seq as specified by the partition vector.
|
| 1325 |
-
|
| 1326 |
-
Examples
|
| 1327 |
-
========
|
| 1328 |
-
|
| 1329 |
-
>>> from sympy.utilities.iterables import _partition
|
| 1330 |
-
>>> _partition('abcde', [1, 0, 1, 2, 0])
|
| 1331 |
-
[['b', 'e'], ['a', 'c'], ['d']]
|
| 1332 |
-
|
| 1333 |
-
Specifying the number of bins in the partition is optional:
|
| 1334 |
-
|
| 1335 |
-
>>> _partition('abcde', [1, 0, 1, 2, 0], 3)
|
| 1336 |
-
[['b', 'e'], ['a', 'c'], ['d']]
|
| 1337 |
-
|
| 1338 |
-
The output of _set_partitions can be passed as follows:
|
| 1339 |
-
|
| 1340 |
-
>>> output = (3, [1, 0, 1, 2, 0])
|
| 1341 |
-
>>> _partition('abcde', *output)
|
| 1342 |
-
[['b', 'e'], ['a', 'c'], ['d']]
|
| 1343 |
-
|
| 1344 |
-
See Also
|
| 1345 |
-
========
|
| 1346 |
-
|
| 1347 |
-
combinatorics.partitions.Partition.from_rgs
|
| 1348 |
-
|
| 1349 |
-
"""
|
| 1350 |
-
if m is None:
|
| 1351 |
-
m = max(vector) + 1
|
| 1352 |
-
elif isinstance(vector, int): # entered as m, vector
|
| 1353 |
-
vector, m = m, vector
|
| 1354 |
-
p = [[] for i in range(m)]
|
| 1355 |
-
for i, v in enumerate(vector):
|
| 1356 |
-
p[v].append(seq[i])
|
| 1357 |
-
return p
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
def _set_partitions(n):
|
| 1361 |
-
"""Cycle through all partitions of n elements, yielding the
|
| 1362 |
-
current number of partitions, ``m``, and a mutable list, ``q``
|
| 1363 |
-
such that ``element[i]`` is in part ``q[i]`` of the partition.
|
| 1364 |
-
|
| 1365 |
-
NOTE: ``q`` is modified in place and generally should not be changed
|
| 1366 |
-
between function calls.
|
| 1367 |
-
|
| 1368 |
-
Examples
|
| 1369 |
-
========
|
| 1370 |
-
|
| 1371 |
-
>>> from sympy.utilities.iterables import _set_partitions, _partition
|
| 1372 |
-
>>> for m, q in _set_partitions(3):
|
| 1373 |
-
... print('%s %s %s' % (m, q, _partition('abc', q, m)))
|
| 1374 |
-
1 [0, 0, 0] [['a', 'b', 'c']]
|
| 1375 |
-
2 [0, 0, 1] [['a', 'b'], ['c']]
|
| 1376 |
-
2 [0, 1, 0] [['a', 'c'], ['b']]
|
| 1377 |
-
2 [0, 1, 1] [['a'], ['b', 'c']]
|
| 1378 |
-
3 [0, 1, 2] [['a'], ['b'], ['c']]
|
| 1379 |
-
|
| 1380 |
-
Notes
|
| 1381 |
-
=====
|
| 1382 |
-
|
| 1383 |
-
This algorithm is similar to, and solves the same problem as,
|
| 1384 |
-
Algorithm 7.2.1.5H, from volume 4A of Knuth's The Art of Computer
|
| 1385 |
-
Programming. Knuth uses the term "restricted growth string" where
|
| 1386 |
-
this code refers to a "partition vector". In each case, the meaning is
|
| 1387 |
-
the same: the value in the ith element of the vector specifies to
|
| 1388 |
-
which part the ith set element is to be assigned.
|
| 1389 |
-
|
| 1390 |
-
At the lowest level, this code implements an n-digit big-endian
|
| 1391 |
-
counter (stored in the array q) which is incremented (with carries) to
|
| 1392 |
-
get the next partition in the sequence. A special twist is that a
|
| 1393 |
-
digit is constrained to be at most one greater than the maximum of all
|
| 1394 |
-
the digits to the left of it. The array p maintains this maximum, so
|
| 1395 |
-
that the code can efficiently decide when a digit can be incremented
|
| 1396 |
-
in place or whether it needs to be reset to 0 and trigger a carry to
|
| 1397 |
-
the next digit. The enumeration starts with all the digits 0 (which
|
| 1398 |
-
corresponds to all the set elements being assigned to the same 0th
|
| 1399 |
-
part), and ends with 0123...n, which corresponds to each set element
|
| 1400 |
-
being assigned to a different, singleton, part.
|
| 1401 |
-
|
| 1402 |
-
This routine was rewritten to use 0-based lists while trying to
|
| 1403 |
-
preserve the beauty and efficiency of the original algorithm.
|
| 1404 |
-
|
| 1405 |
-
References
|
| 1406 |
-
==========
|
| 1407 |
-
|
| 1408 |
-
.. [1] Nijenhuis, Albert and Wilf, Herbert. (1978) Combinatorial Algorithms,
|
| 1409 |
-
2nd Ed, p 91, algorithm "nexequ". Available online from
|
| 1410 |
-
https://www.math.upenn.edu/~wilf/website/CombAlgDownld.html (viewed
|
| 1411 |
-
November 17, 2012).
|
| 1412 |
-
|
| 1413 |
-
"""
|
| 1414 |
-
p = [0]*n
|
| 1415 |
-
q = [0]*n
|
| 1416 |
-
nc = 1
|
| 1417 |
-
yield nc, q
|
| 1418 |
-
while nc != n:
|
| 1419 |
-
m = n
|
| 1420 |
-
while 1:
|
| 1421 |
-
m -= 1
|
| 1422 |
-
i = q[m]
|
| 1423 |
-
if p[i] != 1:
|
| 1424 |
-
break
|
| 1425 |
-
q[m] = 0
|
| 1426 |
-
i += 1
|
| 1427 |
-
q[m] = i
|
| 1428 |
-
m += 1
|
| 1429 |
-
nc += m - n
|
| 1430 |
-
p[0] += n - m
|
| 1431 |
-
if i == nc:
|
| 1432 |
-
p[nc] = 0
|
| 1433 |
-
nc += 1
|
| 1434 |
-
p[i - 1] -= 1
|
| 1435 |
-
p[i] += 1
|
| 1436 |
-
yield nc, q
|
| 1437 |
-
|
| 1438 |
-
|
| 1439 |
-
def multiset_partitions(multiset, m=None):
|
| 1440 |
-
"""
|
| 1441 |
-
Return unique partitions of the given multiset (in list form).
|
| 1442 |
-
If ``m`` is None, all multisets will be returned, otherwise only
|
| 1443 |
-
partitions with ``m`` parts will be returned.
|
| 1444 |
-
|
| 1445 |
-
If ``multiset`` is an integer, a range [0, 1, ..., multiset - 1]
|
| 1446 |
-
will be supplied.
|
| 1447 |
-
|
| 1448 |
-
Examples
|
| 1449 |
-
========
|
| 1450 |
-
|
| 1451 |
-
>>> from sympy.utilities.iterables import multiset_partitions
|
| 1452 |
-
>>> list(multiset_partitions([1, 2, 3, 4], 2))
|
| 1453 |
-
[[[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]],
|
| 1454 |
-
[[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]],
|
| 1455 |
-
[[1], [2, 3, 4]]]
|
| 1456 |
-
>>> list(multiset_partitions([1, 2, 3, 4], 1))
|
| 1457 |
-
[[[1, 2, 3, 4]]]
|
| 1458 |
-
|
| 1459 |
-
Only unique partitions are returned and these will be returned in a
|
| 1460 |
-
canonical order regardless of the order of the input:
|
| 1461 |
-
|
| 1462 |
-
>>> a = [1, 2, 2, 1]
|
| 1463 |
-
>>> ans = list(multiset_partitions(a, 2))
|
| 1464 |
-
>>> a.sort()
|
| 1465 |
-
>>> list(multiset_partitions(a, 2)) == ans
|
| 1466 |
-
True
|
| 1467 |
-
>>> a = range(3, 1, -1)
|
| 1468 |
-
>>> (list(multiset_partitions(a)) ==
|
| 1469 |
-
... list(multiset_partitions(sorted(a))))
|
| 1470 |
-
True
|
| 1471 |
-
|
| 1472 |
-
If m is omitted then all partitions will be returned:
|
| 1473 |
-
|
| 1474 |
-
>>> list(multiset_partitions([1, 1, 2]))
|
| 1475 |
-
[[[1, 1, 2]], [[1, 1], [2]], [[1, 2], [1]], [[1], [1], [2]]]
|
| 1476 |
-
>>> list(multiset_partitions([1]*3))
|
| 1477 |
-
[[[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]]
|
| 1478 |
-
|
| 1479 |
-
Counting
|
| 1480 |
-
========
|
| 1481 |
-
|
| 1482 |
-
The number of partitions of a set is given by the bell number:
|
| 1483 |
-
|
| 1484 |
-
>>> from sympy import bell
|
| 1485 |
-
>>> len(list(multiset_partitions(5))) == bell(5) == 52
|
| 1486 |
-
True
|
| 1487 |
-
|
| 1488 |
-
The number of partitions of length k from a set of size n is given by the
|
| 1489 |
-
Stirling Number of the 2nd kind:
|
| 1490 |
-
|
| 1491 |
-
>>> from sympy.functions.combinatorial.numbers import stirling
|
| 1492 |
-
>>> stirling(5, 2) == len(list(multiset_partitions(5, 2))) == 15
|
| 1493 |
-
True
|
| 1494 |
-
|
| 1495 |
-
These comments on counting apply to *sets*, not multisets.
|
| 1496 |
-
|
| 1497 |
-
Notes
|
| 1498 |
-
=====
|
| 1499 |
-
|
| 1500 |
-
When all the elements are the same in the multiset, the order
|
| 1501 |
-
of the returned partitions is determined by the ``partitions``
|
| 1502 |
-
routine. If one is counting partitions then it is better to use
|
| 1503 |
-
the ``nT`` function.
|
| 1504 |
-
|
| 1505 |
-
See Also
|
| 1506 |
-
========
|
| 1507 |
-
|
| 1508 |
-
partitions
|
| 1509 |
-
sympy.combinatorics.partitions.Partition
|
| 1510 |
-
sympy.combinatorics.partitions.IntegerPartition
|
| 1511 |
-
sympy.functions.combinatorial.numbers.nT
|
| 1512 |
-
|
| 1513 |
-
"""
|
| 1514 |
-
# This function looks at the supplied input and dispatches to
|
| 1515 |
-
# several special-case routines as they apply.
|
| 1516 |
-
if isinstance(multiset, int):
|
| 1517 |
-
n = multiset
|
| 1518 |
-
if m and m > n:
|
| 1519 |
-
return
|
| 1520 |
-
multiset = list(range(n))
|
| 1521 |
-
if m == 1:
|
| 1522 |
-
yield [multiset[:]]
|
| 1523 |
-
return
|
| 1524 |
-
|
| 1525 |
-
# If m is not None, it can sometimes be faster to use
|
| 1526 |
-
# MultisetPartitionTraverser.enum_range() even for inputs
|
| 1527 |
-
# which are sets. Since the _set_partitions code is quite
|
| 1528 |
-
# fast, this is only advantageous when the overall set
|
| 1529 |
-
# partitions outnumber those with the desired number of parts
|
| 1530 |
-
# by a large factor. (At least 60.) Such a switch is not
|
| 1531 |
-
# currently implemented.
|
| 1532 |
-
for nc, q in _set_partitions(n):
|
| 1533 |
-
if m is None or nc == m:
|
| 1534 |
-
rv = [[] for i in range(nc)]
|
| 1535 |
-
for i in range(n):
|
| 1536 |
-
rv[q[i]].append(multiset[i])
|
| 1537 |
-
yield rv
|
| 1538 |
-
return
|
| 1539 |
-
|
| 1540 |
-
if len(multiset) == 1 and isinstance(multiset, str):
|
| 1541 |
-
multiset = [multiset]
|
| 1542 |
-
|
| 1543 |
-
if not has_variety(multiset):
|
| 1544 |
-
# Only one component, repeated n times. The resulting
|
| 1545 |
-
# partitions correspond to partitions of integer n.
|
| 1546 |
-
n = len(multiset)
|
| 1547 |
-
if m and m > n:
|
| 1548 |
-
return
|
| 1549 |
-
if m == 1:
|
| 1550 |
-
yield [multiset[:]]
|
| 1551 |
-
return
|
| 1552 |
-
x = multiset[:1]
|
| 1553 |
-
for size, p in partitions(n, m, size=True):
|
| 1554 |
-
if m is None or size == m:
|
| 1555 |
-
rv = []
|
| 1556 |
-
for k in sorted(p):
|
| 1557 |
-
rv.extend([x*k]*p[k])
|
| 1558 |
-
yield rv
|
| 1559 |
-
else:
|
| 1560 |
-
from sympy.core.sorting import ordered
|
| 1561 |
-
multiset = list(ordered(multiset))
|
| 1562 |
-
n = len(multiset)
|
| 1563 |
-
if m and m > n:
|
| 1564 |
-
return
|
| 1565 |
-
if m == 1:
|
| 1566 |
-
yield [multiset[:]]
|
| 1567 |
-
return
|
| 1568 |
-
|
| 1569 |
-
# Split the information of the multiset into two lists -
|
| 1570 |
-
# one of the elements themselves, and one (of the same length)
|
| 1571 |
-
# giving the number of repeats for the corresponding element.
|
| 1572 |
-
elements, multiplicities = zip(*group(multiset, False))
|
| 1573 |
-
|
| 1574 |
-
if len(elements) < len(multiset):
|
| 1575 |
-
# General case - multiset with more than one distinct element
|
| 1576 |
-
# and at least one element repeated more than once.
|
| 1577 |
-
if m:
|
| 1578 |
-
mpt = MultisetPartitionTraverser()
|
| 1579 |
-
for state in mpt.enum_range(multiplicities, m-1, m):
|
| 1580 |
-
yield list_visitor(state, elements)
|
| 1581 |
-
else:
|
| 1582 |
-
for state in multiset_partitions_taocp(multiplicities):
|
| 1583 |
-
yield list_visitor(state, elements)
|
| 1584 |
-
else:
|
| 1585 |
-
# Set partitions case - no repeated elements. Pretty much
|
| 1586 |
-
# same as int argument case above, with same possible, but
|
| 1587 |
-
# currently unimplemented optimization for some cases when
|
| 1588 |
-
# m is not None
|
| 1589 |
-
for nc, q in _set_partitions(n):
|
| 1590 |
-
if m is None or nc == m:
|
| 1591 |
-
rv = [[] for i in range(nc)]
|
| 1592 |
-
for i in range(n):
|
| 1593 |
-
rv[q[i]].append(i)
|
| 1594 |
-
yield [[multiset[j] for j in i] for i in rv]
|
| 1595 |
-
|
| 1596 |
-
|
| 1597 |
-
def partitions(n, m=None, k=None, size=False):
|
| 1598 |
-
"""Generate all partitions of positive integer, n.
|
| 1599 |
-
|
| 1600 |
-
Each partition is represented as a dictionary, mapping an integer
|
| 1601 |
-
to the number of copies of that integer in the partition. For example,
|
| 1602 |
-
the first partition of 4 returned is {4: 1}, "4: one of them".
|
| 1603 |
-
|
| 1604 |
-
Parameters
|
| 1605 |
-
==========
|
| 1606 |
-
n : int
|
| 1607 |
-
m : int, optional
|
| 1608 |
-
limits number of parts in partition (mnemonic: m, maximum parts)
|
| 1609 |
-
k : int, optional
|
| 1610 |
-
limits the numbers that are kept in the partition (mnemonic: k, keys)
|
| 1611 |
-
size : bool, default: False
|
| 1612 |
-
If ``True``, (M, P) is returned where M is the sum of the
|
| 1613 |
-
multiplicities and P is the generated partition.
|
| 1614 |
-
If ``False``, only the generated partition is returned.
|
| 1615 |
-
|
| 1616 |
-
Examples
|
| 1617 |
-
========
|
| 1618 |
-
|
| 1619 |
-
>>> from sympy.utilities.iterables import partitions
|
| 1620 |
-
|
| 1621 |
-
The numbers appearing in the partition (the key of the returned dict)
|
| 1622 |
-
are limited with k:
|
| 1623 |
-
|
| 1624 |
-
>>> for p in partitions(6, k=2): # doctest: +SKIP
|
| 1625 |
-
... print(p)
|
| 1626 |
-
{2: 3}
|
| 1627 |
-
{1: 2, 2: 2}
|
| 1628 |
-
{1: 4, 2: 1}
|
| 1629 |
-
{1: 6}
|
| 1630 |
-
|
| 1631 |
-
The maximum number of parts in the partition (the sum of the values in
|
| 1632 |
-
the returned dict) are limited with m (default value, None, gives
|
| 1633 |
-
partitions from 1 through n):
|
| 1634 |
-
|
| 1635 |
-
>>> for p in partitions(6, m=2): # doctest: +SKIP
|
| 1636 |
-
... print(p)
|
| 1637 |
-
...
|
| 1638 |
-
{6: 1}
|
| 1639 |
-
{1: 1, 5: 1}
|
| 1640 |
-
{2: 1, 4: 1}
|
| 1641 |
-
{3: 2}
|
| 1642 |
-
|
| 1643 |
-
References
|
| 1644 |
-
==========
|
| 1645 |
-
|
| 1646 |
-
.. [1] modified from Tim Peter's version to allow for k and m values:
|
| 1647 |
-
https://code.activestate.com/recipes/218332-generator-for-integer-partitions/
|
| 1648 |
-
|
| 1649 |
-
See Also
|
| 1650 |
-
========
|
| 1651 |
-
|
| 1652 |
-
sympy.combinatorics.partitions.Partition
|
| 1653 |
-
sympy.combinatorics.partitions.IntegerPartition
|
| 1654 |
-
|
| 1655 |
-
"""
|
| 1656 |
-
if (n <= 0 or
|
| 1657 |
-
m is not None and m < 1 or
|
| 1658 |
-
k is not None and k < 1 or
|
| 1659 |
-
m and k and m*k < n):
|
| 1660 |
-
# the empty set is the only way to handle these inputs
|
| 1661 |
-
# and returning {} to represent it is consistent with
|
| 1662 |
-
# the counting convention, e.g. nT(0) == 1.
|
| 1663 |
-
if size:
|
| 1664 |
-
yield 0, {}
|
| 1665 |
-
else:
|
| 1666 |
-
yield {}
|
| 1667 |
-
return
|
| 1668 |
-
|
| 1669 |
-
if m is None:
|
| 1670 |
-
m = n
|
| 1671 |
-
else:
|
| 1672 |
-
m = min(m, n)
|
| 1673 |
-
k = min(k or n, n)
|
| 1674 |
-
|
| 1675 |
-
n, m, k = as_int(n), as_int(m), as_int(k)
|
| 1676 |
-
q, r = divmod(n, k)
|
| 1677 |
-
ms = {k: q}
|
| 1678 |
-
keys = [k] # ms.keys(), from largest to smallest
|
| 1679 |
-
if r:
|
| 1680 |
-
ms[r] = 1
|
| 1681 |
-
keys.append(r)
|
| 1682 |
-
room = m - q - bool(r)
|
| 1683 |
-
if size:
|
| 1684 |
-
yield sum(ms.values()), ms.copy()
|
| 1685 |
-
else:
|
| 1686 |
-
yield ms.copy()
|
| 1687 |
-
|
| 1688 |
-
while keys != [1]:
|
| 1689 |
-
# Reuse any 1's.
|
| 1690 |
-
if keys[-1] == 1:
|
| 1691 |
-
del keys[-1]
|
| 1692 |
-
reuse = ms.pop(1)
|
| 1693 |
-
room += reuse
|
| 1694 |
-
else:
|
| 1695 |
-
reuse = 0
|
| 1696 |
-
|
| 1697 |
-
while 1:
|
| 1698 |
-
# Let i be the smallest key larger than 1. Reuse one
|
| 1699 |
-
# instance of i.
|
| 1700 |
-
i = keys[-1]
|
| 1701 |
-
newcount = ms[i] = ms[i] - 1
|
| 1702 |
-
reuse += i
|
| 1703 |
-
if newcount == 0:
|
| 1704 |
-
del keys[-1], ms[i]
|
| 1705 |
-
room += 1
|
| 1706 |
-
|
| 1707 |
-
# Break the remainder into pieces of size i-1.
|
| 1708 |
-
i -= 1
|
| 1709 |
-
q, r = divmod(reuse, i)
|
| 1710 |
-
need = q + bool(r)
|
| 1711 |
-
if need > room:
|
| 1712 |
-
if not keys:
|
| 1713 |
-
return
|
| 1714 |
-
continue
|
| 1715 |
-
|
| 1716 |
-
ms[i] = q
|
| 1717 |
-
keys.append(i)
|
| 1718 |
-
if r:
|
| 1719 |
-
ms[r] = 1
|
| 1720 |
-
keys.append(r)
|
| 1721 |
-
break
|
| 1722 |
-
room -= need
|
| 1723 |
-
if size:
|
| 1724 |
-
yield sum(ms.values()), ms.copy()
|
| 1725 |
-
else:
|
| 1726 |
-
yield ms.copy()
|
| 1727 |
-
|
| 1728 |
-
|
| 1729 |
-
def ordered_partitions(n, m=None, sort=True):
|
| 1730 |
-
"""Generates ordered partitions of integer *n*.
|
| 1731 |
-
|
| 1732 |
-
Parameters
|
| 1733 |
-
==========
|
| 1734 |
-
n : int
|
| 1735 |
-
m : int, optional
|
| 1736 |
-
The default value gives partitions of all sizes else only
|
| 1737 |
-
those with size m. In addition, if *m* is not None then
|
| 1738 |
-
partitions are generated *in place* (see examples).
|
| 1739 |
-
sort : bool, default: True
|
| 1740 |
-
Controls whether partitions are
|
| 1741 |
-
returned in sorted order when *m* is not None; when False,
|
| 1742 |
-
the partitions are returned as fast as possible with elements
|
| 1743 |
-
sorted, but when m|n the partitions will not be in
|
| 1744 |
-
ascending lexicographical order.
|
| 1745 |
-
|
| 1746 |
-
Examples
|
| 1747 |
-
========
|
| 1748 |
-
|
| 1749 |
-
>>> from sympy.utilities.iterables import ordered_partitions
|
| 1750 |
-
|
| 1751 |
-
All partitions of 5 in ascending lexicographical:
|
| 1752 |
-
|
| 1753 |
-
>>> for p in ordered_partitions(5):
|
| 1754 |
-
... print(p)
|
| 1755 |
-
[1, 1, 1, 1, 1]
|
| 1756 |
-
[1, 1, 1, 2]
|
| 1757 |
-
[1, 1, 3]
|
| 1758 |
-
[1, 2, 2]
|
| 1759 |
-
[1, 4]
|
| 1760 |
-
[2, 3]
|
| 1761 |
-
[5]
|
| 1762 |
-
|
| 1763 |
-
Only partitions of 5 with two parts:
|
| 1764 |
-
|
| 1765 |
-
>>> for p in ordered_partitions(5, 2):
|
| 1766 |
-
... print(p)
|
| 1767 |
-
[1, 4]
|
| 1768 |
-
[2, 3]
|
| 1769 |
-
|
| 1770 |
-
When ``m`` is given, a given list objects will be used more than
|
| 1771 |
-
once for speed reasons so you will not see the correct partitions
|
| 1772 |
-
unless you make a copy of each as it is generated:
|
| 1773 |
-
|
| 1774 |
-
>>> [p for p in ordered_partitions(7, 3)]
|
| 1775 |
-
[[1, 1, 1], [1, 1, 1], [1, 1, 1], [2, 2, 2]]
|
| 1776 |
-
>>> [list(p) for p in ordered_partitions(7, 3)]
|
| 1777 |
-
[[1, 1, 5], [1, 2, 4], [1, 3, 3], [2, 2, 3]]
|
| 1778 |
-
|
| 1779 |
-
When ``n`` is a multiple of ``m``, the elements are still sorted
|
| 1780 |
-
but the partitions themselves will be *unordered* if sort is False;
|
| 1781 |
-
the default is to return them in ascending lexicographical order.
|
| 1782 |
-
|
| 1783 |
-
>>> for p in ordered_partitions(6, 2):
|
| 1784 |
-
... print(p)
|
| 1785 |
-
[1, 5]
|
| 1786 |
-
[2, 4]
|
| 1787 |
-
[3, 3]
|
| 1788 |
-
|
| 1789 |
-
But if speed is more important than ordering, sort can be set to
|
| 1790 |
-
False:
|
| 1791 |
-
|
| 1792 |
-
>>> for p in ordered_partitions(6, 2, sort=False):
|
| 1793 |
-
... print(p)
|
| 1794 |
-
[1, 5]
|
| 1795 |
-
[3, 3]
|
| 1796 |
-
[2, 4]
|
| 1797 |
-
|
| 1798 |
-
References
|
| 1799 |
-
==========
|
| 1800 |
-
|
| 1801 |
-
.. [1] Generating Integer Partitions, [online],
|
| 1802 |
-
Available: https://jeromekelleher.net/generating-integer-partitions.html
|
| 1803 |
-
.. [2] Jerome Kelleher and Barry O'Sullivan, "Generating All
|
| 1804 |
-
Partitions: A Comparison Of Two Encodings", [online],
|
| 1805 |
-
Available: https://arxiv.org/pdf/0909.2331v2.pdf
|
| 1806 |
-
"""
|
| 1807 |
-
if n < 1 or m is not None and m < 1:
|
| 1808 |
-
# the empty set is the only way to handle these inputs
|
| 1809 |
-
# and returning {} to represent it is consistent with
|
| 1810 |
-
# the counting convention, e.g. nT(0) == 1.
|
| 1811 |
-
yield []
|
| 1812 |
-
return
|
| 1813 |
-
|
| 1814 |
-
if m is None:
|
| 1815 |
-
# The list `a`'s leading elements contain the partition in which
|
| 1816 |
-
# y is the biggest element and x is either the same as y or the
|
| 1817 |
-
# 2nd largest element; v and w are adjacent element indices
|
| 1818 |
-
# to which x and y are being assigned, respectively.
|
| 1819 |
-
a = [1]*n
|
| 1820 |
-
y = -1
|
| 1821 |
-
v = n
|
| 1822 |
-
while v > 0:
|
| 1823 |
-
v -= 1
|
| 1824 |
-
x = a[v] + 1
|
| 1825 |
-
while y >= 2 * x:
|
| 1826 |
-
a[v] = x
|
| 1827 |
-
y -= x
|
| 1828 |
-
v += 1
|
| 1829 |
-
w = v + 1
|
| 1830 |
-
while x <= y:
|
| 1831 |
-
a[v] = x
|
| 1832 |
-
a[w] = y
|
| 1833 |
-
yield a[:w + 1]
|
| 1834 |
-
x += 1
|
| 1835 |
-
y -= 1
|
| 1836 |
-
a[v] = x + y
|
| 1837 |
-
y = a[v] - 1
|
| 1838 |
-
yield a[:w]
|
| 1839 |
-
elif m == 1:
|
| 1840 |
-
yield [n]
|
| 1841 |
-
elif n == m:
|
| 1842 |
-
yield [1]*n
|
| 1843 |
-
else:
|
| 1844 |
-
# recursively generate partitions of size m
|
| 1845 |
-
for b in range(1, n//m + 1):
|
| 1846 |
-
a = [b]*m
|
| 1847 |
-
x = n - b*m
|
| 1848 |
-
if not x:
|
| 1849 |
-
if sort:
|
| 1850 |
-
yield a
|
| 1851 |
-
elif not sort and x <= m:
|
| 1852 |
-
for ax in ordered_partitions(x, sort=False):
|
| 1853 |
-
mi = len(ax)
|
| 1854 |
-
a[-mi:] = [i + b for i in ax]
|
| 1855 |
-
yield a
|
| 1856 |
-
a[-mi:] = [b]*mi
|
| 1857 |
-
else:
|
| 1858 |
-
for mi in range(1, m):
|
| 1859 |
-
for ax in ordered_partitions(x, mi, sort=True):
|
| 1860 |
-
a[-mi:] = [i + b for i in ax]
|
| 1861 |
-
yield a
|
| 1862 |
-
a[-mi:] = [b]*mi
|
| 1863 |
-
|
| 1864 |
-
|
| 1865 |
-
def binary_partitions(n):
|
| 1866 |
-
"""
|
| 1867 |
-
Generates the binary partition of *n*.
|
| 1868 |
-
|
| 1869 |
-
A binary partition consists only of numbers that are
|
| 1870 |
-
powers of two. Each step reduces a `2^{k+1}` to `2^k` and
|
| 1871 |
-
`2^k`. Thus 16 is converted to 8 and 8.
|
| 1872 |
-
|
| 1873 |
-
Examples
|
| 1874 |
-
========
|
| 1875 |
-
|
| 1876 |
-
>>> from sympy.utilities.iterables import binary_partitions
|
| 1877 |
-
>>> for i in binary_partitions(5):
|
| 1878 |
-
... print(i)
|
| 1879 |
-
...
|
| 1880 |
-
[4, 1]
|
| 1881 |
-
[2, 2, 1]
|
| 1882 |
-
[2, 1, 1, 1]
|
| 1883 |
-
[1, 1, 1, 1, 1]
|
| 1884 |
-
|
| 1885 |
-
References
|
| 1886 |
-
==========
|
| 1887 |
-
|
| 1888 |
-
.. [1] TAOCP 4, section 7.2.1.5, problem 64
|
| 1889 |
-
|
| 1890 |
-
"""
|
| 1891 |
-
from math import ceil, log2
|
| 1892 |
-
power = int(2**(ceil(log2(n))))
|
| 1893 |
-
acc = 0
|
| 1894 |
-
partition = []
|
| 1895 |
-
while power:
|
| 1896 |
-
if acc + power <= n:
|
| 1897 |
-
partition.append(power)
|
| 1898 |
-
acc += power
|
| 1899 |
-
power >>= 1
|
| 1900 |
-
|
| 1901 |
-
last_num = len(partition) - 1 - (n & 1)
|
| 1902 |
-
while last_num >= 0:
|
| 1903 |
-
yield partition
|
| 1904 |
-
if partition[last_num] == 2:
|
| 1905 |
-
partition[last_num] = 1
|
| 1906 |
-
partition.append(1)
|
| 1907 |
-
last_num -= 1
|
| 1908 |
-
continue
|
| 1909 |
-
partition.append(1)
|
| 1910 |
-
partition[last_num] >>= 1
|
| 1911 |
-
x = partition[last_num + 1] = partition[last_num]
|
| 1912 |
-
last_num += 1
|
| 1913 |
-
while x > 1:
|
| 1914 |
-
if x <= len(partition) - last_num - 1:
|
| 1915 |
-
del partition[-x + 1:]
|
| 1916 |
-
last_num += 1
|
| 1917 |
-
partition[last_num] = x
|
| 1918 |
-
else:
|
| 1919 |
-
x >>= 1
|
| 1920 |
-
yield [1]*n
|
| 1921 |
-
|
| 1922 |
-
|
| 1923 |
-
def has_dups(seq):
|
| 1924 |
-
"""Return True if there are any duplicate elements in ``seq``.
|
| 1925 |
-
|
| 1926 |
-
Examples
|
| 1927 |
-
========
|
| 1928 |
-
|
| 1929 |
-
>>> from sympy import has_dups, Dict, Set
|
| 1930 |
-
>>> has_dups((1, 2, 1))
|
| 1931 |
-
True
|
| 1932 |
-
>>> has_dups(range(3))
|
| 1933 |
-
False
|
| 1934 |
-
>>> all(has_dups(c) is False for c in (set(), Set(), dict(), Dict()))
|
| 1935 |
-
True
|
| 1936 |
-
"""
|
| 1937 |
-
from sympy.core.containers import Dict
|
| 1938 |
-
from sympy.sets.sets import Set
|
| 1939 |
-
if isinstance(seq, (dict, set, Dict, Set)):
|
| 1940 |
-
return False
|
| 1941 |
-
unique = set()
|
| 1942 |
-
try:
|
| 1943 |
-
return any(True for s in seq if s in unique or unique.add(s))
|
| 1944 |
-
except TypeError:
|
| 1945 |
-
return len(seq) != len(list(uniq(seq)))
|
| 1946 |
-
|
| 1947 |
-
|
| 1948 |
-
def has_variety(seq):
|
| 1949 |
-
"""Return True if there are any different elements in ``seq``.
|
| 1950 |
-
|
| 1951 |
-
Examples
|
| 1952 |
-
========
|
| 1953 |
-
|
| 1954 |
-
>>> from sympy import has_variety
|
| 1955 |
-
|
| 1956 |
-
>>> has_variety((1, 2, 1))
|
| 1957 |
-
True
|
| 1958 |
-
>>> has_variety((1, 1, 1))
|
| 1959 |
-
False
|
| 1960 |
-
"""
|
| 1961 |
-
for i, s in enumerate(seq):
|
| 1962 |
-
if i == 0:
|
| 1963 |
-
sentinel = s
|
| 1964 |
-
else:
|
| 1965 |
-
if s != sentinel:
|
| 1966 |
-
return True
|
| 1967 |
-
return False
|
| 1968 |
-
|
| 1969 |
-
|
| 1970 |
-
def uniq(seq, result=None):
|
| 1971 |
-
"""
|
| 1972 |
-
Yield unique elements from ``seq`` as an iterator. The second
|
| 1973 |
-
parameter ``result`` is used internally; it is not necessary
|
| 1974 |
-
to pass anything for this.
|
| 1975 |
-
|
| 1976 |
-
Note: changing the sequence during iteration will raise a
|
| 1977 |
-
RuntimeError if the size of the sequence is known; if you pass
|
| 1978 |
-
an iterator and advance the iterator you will change the
|
| 1979 |
-
output of this routine but there will be no warning.
|
| 1980 |
-
|
| 1981 |
-
Examples
|
| 1982 |
-
========
|
| 1983 |
-
|
| 1984 |
-
>>> from sympy.utilities.iterables import uniq
|
| 1985 |
-
>>> dat = [1, 4, 1, 5, 4, 2, 1, 2]
|
| 1986 |
-
>>> type(uniq(dat)) in (list, tuple)
|
| 1987 |
-
False
|
| 1988 |
-
|
| 1989 |
-
>>> list(uniq(dat))
|
| 1990 |
-
[1, 4, 5, 2]
|
| 1991 |
-
>>> list(uniq(x for x in dat))
|
| 1992 |
-
[1, 4, 5, 2]
|
| 1993 |
-
>>> list(uniq([[1], [2, 1], [1]]))
|
| 1994 |
-
[[1], [2, 1]]
|
| 1995 |
-
"""
|
| 1996 |
-
try:
|
| 1997 |
-
n = len(seq)
|
| 1998 |
-
except TypeError:
|
| 1999 |
-
n = None
|
| 2000 |
-
def check():
|
| 2001 |
-
# check that size of seq did not change during iteration;
|
| 2002 |
-
# if n == None the object won't support size changing, e.g.
|
| 2003 |
-
# an iterator can't be changed
|
| 2004 |
-
if n is not None and len(seq) != n:
|
| 2005 |
-
raise RuntimeError('sequence changed size during iteration')
|
| 2006 |
-
try:
|
| 2007 |
-
seen = set()
|
| 2008 |
-
result = result or []
|
| 2009 |
-
for i, s in enumerate(seq):
|
| 2010 |
-
if not (s in seen or seen.add(s)):
|
| 2011 |
-
yield s
|
| 2012 |
-
check()
|
| 2013 |
-
except TypeError:
|
| 2014 |
-
if s not in result:
|
| 2015 |
-
yield s
|
| 2016 |
-
check()
|
| 2017 |
-
result.append(s)
|
| 2018 |
-
if hasattr(seq, '__getitem__'):
|
| 2019 |
-
yield from uniq(seq[i + 1:], result)
|
| 2020 |
-
else:
|
| 2021 |
-
yield from uniq(seq, result)
|
| 2022 |
-
|
| 2023 |
-
|
| 2024 |
-
def generate_bell(n):
|
| 2025 |
-
"""Return permutations of [0, 1, ..., n - 1] such that each permutation
|
| 2026 |
-
differs from the last by the exchange of a single pair of neighbors.
|
| 2027 |
-
The ``n!`` permutations are returned as an iterator. In order to obtain
|
| 2028 |
-
the next permutation from a random starting permutation, use the
|
| 2029 |
-
``next_trotterjohnson`` method of the Permutation class (which generates
|
| 2030 |
-
the same sequence in a different manner).
|
| 2031 |
-
|
| 2032 |
-
Examples
|
| 2033 |
-
========
|
| 2034 |
-
|
| 2035 |
-
>>> from itertools import permutations
|
| 2036 |
-
>>> from sympy.utilities.iterables import generate_bell
|
| 2037 |
-
>>> from sympy import zeros, Matrix
|
| 2038 |
-
|
| 2039 |
-
This is the sort of permutation used in the ringing of physical bells,
|
| 2040 |
-
and does not produce permutations in lexicographical order. Rather, the
|
| 2041 |
-
permutations differ from each other by exactly one inversion, and the
|
| 2042 |
-
position at which the swapping occurs varies periodically in a simple
|
| 2043 |
-
fashion. Consider the first few permutations of 4 elements generated
|
| 2044 |
-
by ``permutations`` and ``generate_bell``:
|
| 2045 |
-
|
| 2046 |
-
>>> list(permutations(range(4)))[:5]
|
| 2047 |
-
[(0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3), (0, 2, 3, 1), (0, 3, 1, 2)]
|
| 2048 |
-
>>> list(generate_bell(4))[:5]
|
| 2049 |
-
[(0, 1, 2, 3), (0, 1, 3, 2), (0, 3, 1, 2), (3, 0, 1, 2), (3, 0, 2, 1)]
|
| 2050 |
-
|
| 2051 |
-
Notice how the 2nd and 3rd lexicographical permutations have 3 elements
|
| 2052 |
-
out of place whereas each "bell" permutation always has only two
|
| 2053 |
-
elements out of place relative to the previous permutation (and so the
|
| 2054 |
-
signature (+/-1) of a permutation is opposite of the signature of the
|
| 2055 |
-
previous permutation).
|
| 2056 |
-
|
| 2057 |
-
How the position of inversion varies across the elements can be seen
|
| 2058 |
-
by tracing out where the largest number appears in the permutations:
|
| 2059 |
-
|
| 2060 |
-
>>> m = zeros(4, 24)
|
| 2061 |
-
>>> for i, p in enumerate(generate_bell(4)):
|
| 2062 |
-
... m[:, i] = Matrix([j - 3 for j in list(p)]) # make largest zero
|
| 2063 |
-
>>> m.print_nonzero('X')
|
| 2064 |
-
[XXX XXXXXX XXXXXX XXX]
|
| 2065 |
-
[XX XX XXXX XX XXXX XX XX]
|
| 2066 |
-
[X XXXX XX XXXX XX XXXX X]
|
| 2067 |
-
[ XXXXXX XXXXXX XXXXXX ]
|
| 2068 |
-
|
| 2069 |
-
See Also
|
| 2070 |
-
========
|
| 2071 |
-
|
| 2072 |
-
sympy.combinatorics.permutations.Permutation.next_trotterjohnson
|
| 2073 |
-
|
| 2074 |
-
References
|
| 2075 |
-
==========
|
| 2076 |
-
|
| 2077 |
-
.. [1] https://en.wikipedia.org/wiki/Method_ringing
|
| 2078 |
-
|
| 2079 |
-
.. [2] https://stackoverflow.com/questions/4856615/recursive-permutation/4857018
|
| 2080 |
-
|
| 2081 |
-
.. [3] https://web.archive.org/web/20160313023044/http://programminggeeks.com/bell-algorithm-for-permutation/
|
| 2082 |
-
|
| 2083 |
-
.. [4] https://en.wikipedia.org/wiki/Steinhaus%E2%80%93Johnson%E2%80%93Trotter_algorithm
|
| 2084 |
-
|
| 2085 |
-
.. [5] Generating involutions, derangements, and relatives by ECO
|
| 2086 |
-
Vincent Vajnovszki, DMTCS vol 1 issue 12, 2010
|
| 2087 |
-
|
| 2088 |
-
"""
|
| 2089 |
-
n = as_int(n)
|
| 2090 |
-
if n < 1:
|
| 2091 |
-
raise ValueError('n must be a positive integer')
|
| 2092 |
-
if n == 1:
|
| 2093 |
-
yield (0,)
|
| 2094 |
-
elif n == 2:
|
| 2095 |
-
yield (0, 1)
|
| 2096 |
-
yield (1, 0)
|
| 2097 |
-
elif n == 3:
|
| 2098 |
-
yield from [(0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)]
|
| 2099 |
-
else:
|
| 2100 |
-
m = n - 1
|
| 2101 |
-
op = [0] + [-1]*m
|
| 2102 |
-
l = list(range(n))
|
| 2103 |
-
while True:
|
| 2104 |
-
yield tuple(l)
|
| 2105 |
-
# find biggest element with op
|
| 2106 |
-
big = None, -1 # idx, value
|
| 2107 |
-
for i in range(n):
|
| 2108 |
-
if op[i] and l[i] > big[1]:
|
| 2109 |
-
big = i, l[i]
|
| 2110 |
-
i, _ = big
|
| 2111 |
-
if i is None:
|
| 2112 |
-
break # there are no ops left
|
| 2113 |
-
# swap it with neighbor in the indicated direction
|
| 2114 |
-
j = i + op[i]
|
| 2115 |
-
l[i], l[j] = l[j], l[i]
|
| 2116 |
-
op[i], op[j] = op[j], op[i]
|
| 2117 |
-
# if it landed at the end or if the neighbor in the same
|
| 2118 |
-
# direction is bigger then turn off op
|
| 2119 |
-
if j == 0 or j == m or l[j + op[j]] > l[j]:
|
| 2120 |
-
op[j] = 0
|
| 2121 |
-
# any element bigger to the left gets +1 op
|
| 2122 |
-
for i in range(j):
|
| 2123 |
-
if l[i] > l[j]:
|
| 2124 |
-
op[i] = 1
|
| 2125 |
-
# any element bigger to the right gets -1 op
|
| 2126 |
-
for i in range(j + 1, n):
|
| 2127 |
-
if l[i] > l[j]:
|
| 2128 |
-
op[i] = -1
|
| 2129 |
-
|
| 2130 |
-
|
| 2131 |
-
def generate_involutions(n):
|
| 2132 |
-
"""
|
| 2133 |
-
Generates involutions.
|
| 2134 |
-
|
| 2135 |
-
An involution is a permutation that when multiplied
|
| 2136 |
-
by itself equals the identity permutation. In this
|
| 2137 |
-
implementation the involutions are generated using
|
| 2138 |
-
Fixed Points.
|
| 2139 |
-
|
| 2140 |
-
Alternatively, an involution can be considered as
|
| 2141 |
-
a permutation that does not contain any cycles with
|
| 2142 |
-
a length that is greater than two.
|
| 2143 |
-
|
| 2144 |
-
Examples
|
| 2145 |
-
========
|
| 2146 |
-
|
| 2147 |
-
>>> from sympy.utilities.iterables import generate_involutions
|
| 2148 |
-
>>> list(generate_involutions(3))
|
| 2149 |
-
[(0, 1, 2), (0, 2, 1), (1, 0, 2), (2, 1, 0)]
|
| 2150 |
-
>>> len(list(generate_involutions(4)))
|
| 2151 |
-
10
|
| 2152 |
-
|
| 2153 |
-
References
|
| 2154 |
-
==========
|
| 2155 |
-
|
| 2156 |
-
.. [1] https://mathworld.wolfram.com/PermutationInvolution.html
|
| 2157 |
-
|
| 2158 |
-
"""
|
| 2159 |
-
idx = list(range(n))
|
| 2160 |
-
for p in permutations(idx):
|
| 2161 |
-
for i in idx:
|
| 2162 |
-
if p[p[i]] != i:
|
| 2163 |
-
break
|
| 2164 |
-
else:
|
| 2165 |
-
yield p
|
| 2166 |
-
|
| 2167 |
-
|
| 2168 |
-
def multiset_derangements(s):
|
| 2169 |
-
"""Generate derangements of the elements of s *in place*.
|
| 2170 |
-
|
| 2171 |
-
Examples
|
| 2172 |
-
========
|
| 2173 |
-
|
| 2174 |
-
>>> from sympy.utilities.iterables import multiset_derangements, uniq
|
| 2175 |
-
|
| 2176 |
-
Because the derangements of multisets (not sets) are generated
|
| 2177 |
-
in place, copies of the return value must be made if a collection
|
| 2178 |
-
of derangements is desired or else all values will be the same:
|
| 2179 |
-
|
| 2180 |
-
>>> list(uniq([i for i in multiset_derangements('1233')]))
|
| 2181 |
-
[[None, None, None, None]]
|
| 2182 |
-
>>> [i.copy() for i in multiset_derangements('1233')]
|
| 2183 |
-
[['3', '3', '1', '2'], ['3', '3', '2', '1']]
|
| 2184 |
-
>>> [''.join(i) for i in multiset_derangements('1233')]
|
| 2185 |
-
['3312', '3321']
|
| 2186 |
-
"""
|
| 2187 |
-
from sympy.core.sorting import ordered
|
| 2188 |
-
# create multiset dictionary of hashable elements or else
|
| 2189 |
-
# remap elements to integers
|
| 2190 |
-
try:
|
| 2191 |
-
ms = multiset(s)
|
| 2192 |
-
except TypeError:
|
| 2193 |
-
# give each element a canonical integer value
|
| 2194 |
-
key = dict(enumerate(ordered(uniq(s))))
|
| 2195 |
-
h = []
|
| 2196 |
-
for si in s:
|
| 2197 |
-
for k in key:
|
| 2198 |
-
if key[k] == si:
|
| 2199 |
-
h.append(k)
|
| 2200 |
-
break
|
| 2201 |
-
for i in multiset_derangements(h):
|
| 2202 |
-
yield [key[j] for j in i]
|
| 2203 |
-
return
|
| 2204 |
-
|
| 2205 |
-
mx = max(ms.values()) # max repetition of any element
|
| 2206 |
-
n = len(s) # the number of elements
|
| 2207 |
-
|
| 2208 |
-
## special cases
|
| 2209 |
-
|
| 2210 |
-
# 1) one element has more than half the total cardinality of s: no
|
| 2211 |
-
# derangements are possible.
|
| 2212 |
-
if mx*2 > n:
|
| 2213 |
-
return
|
| 2214 |
-
|
| 2215 |
-
# 2) all elements appear once: singletons
|
| 2216 |
-
if len(ms) == n:
|
| 2217 |
-
yield from _set_derangements(s)
|
| 2218 |
-
return
|
| 2219 |
-
|
| 2220 |
-
# find the first element that is repeated the most to place
|
| 2221 |
-
# in the following two special cases where the selection
|
| 2222 |
-
# is unambiguous: either there are two elements with multiplicity
|
| 2223 |
-
# of mx or else there is only one with multiplicity mx
|
| 2224 |
-
for M in ms:
|
| 2225 |
-
if ms[M] == mx:
|
| 2226 |
-
break
|
| 2227 |
-
|
| 2228 |
-
inonM = [i for i in range(n) if s[i] != M] # location of non-M
|
| 2229 |
-
iM = [i for i in range(n) if s[i] == M] # locations of M
|
| 2230 |
-
rv = [None]*n
|
| 2231 |
-
|
| 2232 |
-
# 3) half are the same
|
| 2233 |
-
if 2*mx == n:
|
| 2234 |
-
# M goes into non-M locations
|
| 2235 |
-
for i in inonM:
|
| 2236 |
-
rv[i] = M
|
| 2237 |
-
# permutations of non-M go to M locations
|
| 2238 |
-
for p in multiset_permutations([s[i] for i in inonM]):
|
| 2239 |
-
for i, pi in zip(iM, p):
|
| 2240 |
-
rv[i] = pi
|
| 2241 |
-
yield rv
|
| 2242 |
-
# clean-up (and encourages proper use of routine)
|
| 2243 |
-
rv[:] = [None]*n
|
| 2244 |
-
return
|
| 2245 |
-
|
| 2246 |
-
# 4) single repeat covers all but 1 of the non-repeats:
|
| 2247 |
-
# if there is one repeat then the multiset of the values
|
| 2248 |
-
# of ms would be {mx: 1, 1: n - mx}, i.e. there would
|
| 2249 |
-
# be n - mx + 1 values with the condition that n - 2*mx = 1
|
| 2250 |
-
if n - 2*mx == 1 and len(ms.values()) == n - mx + 1:
|
| 2251 |
-
for i, i1 in enumerate(inonM):
|
| 2252 |
-
ifill = inonM[:i] + inonM[i+1:]
|
| 2253 |
-
for j in ifill:
|
| 2254 |
-
rv[j] = M
|
| 2255 |
-
for p in permutations([s[j] for j in ifill]):
|
| 2256 |
-
rv[i1] = s[i1]
|
| 2257 |
-
for j, pi in zip(iM, p):
|
| 2258 |
-
rv[j] = pi
|
| 2259 |
-
k = i1
|
| 2260 |
-
for j in iM:
|
| 2261 |
-
rv[j], rv[k] = rv[k], rv[j]
|
| 2262 |
-
yield rv
|
| 2263 |
-
k = j
|
| 2264 |
-
# clean-up (and encourages proper use of routine)
|
| 2265 |
-
rv[:] = [None]*n
|
| 2266 |
-
return
|
| 2267 |
-
|
| 2268 |
-
## general case is handled with 3 helpers:
|
| 2269 |
-
# 1) `finish_derangements` will place the last two elements
|
| 2270 |
-
# which have arbitrary multiplicities, e.g. for multiset
|
| 2271 |
-
# {c: 3, a: 2, b: 2}, the last two elements are a and b
|
| 2272 |
-
# 2) `iopen` will tell where a given element can be placed
|
| 2273 |
-
# 3) `do` will recursively place elements into subsets of
|
| 2274 |
-
# valid locations
|
| 2275 |
-
|
| 2276 |
-
def finish_derangements():
|
| 2277 |
-
"""Place the last two elements into the partially completed
|
| 2278 |
-
derangement, and yield the results.
|
| 2279 |
-
"""
|
| 2280 |
-
|
| 2281 |
-
a = take[1][0] # penultimate element
|
| 2282 |
-
a_ct = take[1][1]
|
| 2283 |
-
b = take[0][0] # last element to be placed
|
| 2284 |
-
b_ct = take[0][1]
|
| 2285 |
-
|
| 2286 |
-
# split the indexes of the not-already-assigned elements of rv into
|
| 2287 |
-
# three categories
|
| 2288 |
-
forced_a = [] # positions which must have an a
|
| 2289 |
-
forced_b = [] # positions which must have a b
|
| 2290 |
-
open_free = [] # positions which could take either
|
| 2291 |
-
for i in range(len(s)):
|
| 2292 |
-
if rv[i] is None:
|
| 2293 |
-
if s[i] == a:
|
| 2294 |
-
forced_b.append(i)
|
| 2295 |
-
elif s[i] == b:
|
| 2296 |
-
forced_a.append(i)
|
| 2297 |
-
else:
|
| 2298 |
-
open_free.append(i)
|
| 2299 |
-
|
| 2300 |
-
if len(forced_a) > a_ct or len(forced_b) > b_ct:
|
| 2301 |
-
# No derangement possible
|
| 2302 |
-
return
|
| 2303 |
-
|
| 2304 |
-
for i in forced_a:
|
| 2305 |
-
rv[i] = a
|
| 2306 |
-
for i in forced_b:
|
| 2307 |
-
rv[i] = b
|
| 2308 |
-
for a_place in combinations(open_free, a_ct - len(forced_a)):
|
| 2309 |
-
for a_pos in a_place:
|
| 2310 |
-
rv[a_pos] = a
|
| 2311 |
-
for i in open_free:
|
| 2312 |
-
if rv[i] is None: # anything not in the subset is set to b
|
| 2313 |
-
rv[i] = b
|
| 2314 |
-
yield rv
|
| 2315 |
-
# Clean up/undo the final placements
|
| 2316 |
-
for i in open_free:
|
| 2317 |
-
rv[i] = None
|
| 2318 |
-
|
| 2319 |
-
# additional cleanup - clear forced_a, forced_b
|
| 2320 |
-
for i in forced_a:
|
| 2321 |
-
rv[i] = None
|
| 2322 |
-
for i in forced_b:
|
| 2323 |
-
rv[i] = None
|
| 2324 |
-
|
| 2325 |
-
def iopen(v):
|
| 2326 |
-
# return indices at which element v can be placed in rv:
|
| 2327 |
-
# locations which are not already occupied if that location
|
| 2328 |
-
# does not already contain v in the same location of s
|
| 2329 |
-
return [i for i in range(n) if rv[i] is None and s[i] != v]
|
| 2330 |
-
|
| 2331 |
-
def do(j):
|
| 2332 |
-
if j == 1:
|
| 2333 |
-
# handle the last two elements (regardless of multiplicity)
|
| 2334 |
-
# with a special method
|
| 2335 |
-
yield from finish_derangements()
|
| 2336 |
-
else:
|
| 2337 |
-
# place the mx elements of M into a subset of places
|
| 2338 |
-
# into which it can be replaced
|
| 2339 |
-
M, mx = take[j]
|
| 2340 |
-
for i in combinations(iopen(M), mx):
|
| 2341 |
-
# place M
|
| 2342 |
-
for ii in i:
|
| 2343 |
-
rv[ii] = M
|
| 2344 |
-
# recursively place the next element
|
| 2345 |
-
yield from do(j - 1)
|
| 2346 |
-
# mark positions where M was placed as once again
|
| 2347 |
-
# open for placement of other elements
|
| 2348 |
-
for ii in i:
|
| 2349 |
-
rv[ii] = None
|
| 2350 |
-
|
| 2351 |
-
# process elements in order of canonically decreasing multiplicity
|
| 2352 |
-
take = sorted(ms.items(), key=lambda x:(x[1], x[0]))
|
| 2353 |
-
yield from do(len(take) - 1)
|
| 2354 |
-
rv[:] = [None]*n
|
| 2355 |
-
|
| 2356 |
-
|
| 2357 |
-
def random_derangement(t, choice=None, strict=True):
|
| 2358 |
-
"""Return a list of elements in which none are in the same positions
|
| 2359 |
-
as they were originally. If an element fills more than half of the positions
|
| 2360 |
-
then an error will be raised since no derangement is possible. To obtain
|
| 2361 |
-
a derangement of as many items as possible--with some of the most numerous
|
| 2362 |
-
remaining in their original positions--pass `strict=False`. To produce a
|
| 2363 |
-
pseudorandom derangment, pass a pseudorandom selector like `choice` (see
|
| 2364 |
-
below).
|
| 2365 |
-
|
| 2366 |
-
Examples
|
| 2367 |
-
========
|
| 2368 |
-
|
| 2369 |
-
>>> from sympy.utilities.iterables import random_derangement
|
| 2370 |
-
>>> t = 'SymPy: a CAS in pure Python'
|
| 2371 |
-
>>> d = random_derangement(t)
|
| 2372 |
-
>>> all(i != j for i, j in zip(d, t))
|
| 2373 |
-
True
|
| 2374 |
-
|
| 2375 |
-
A predictable result can be obtained by using a pseudorandom
|
| 2376 |
-
generator for the choice:
|
| 2377 |
-
|
| 2378 |
-
>>> from sympy.core.random import seed, choice as c
|
| 2379 |
-
>>> seed(1)
|
| 2380 |
-
>>> d = [''.join(random_derangement(t, c)) for i in range(5)]
|
| 2381 |
-
>>> assert len(set(d)) != 1 # we got different values
|
| 2382 |
-
|
| 2383 |
-
By reseeding, the same sequence can be obtained:
|
| 2384 |
-
|
| 2385 |
-
>>> seed(1)
|
| 2386 |
-
>>> d2 = [''.join(random_derangement(t, c)) for i in range(5)]
|
| 2387 |
-
>>> assert d == d2
|
| 2388 |
-
"""
|
| 2389 |
-
if choice is None:
|
| 2390 |
-
import secrets
|
| 2391 |
-
choice = secrets.choice
|
| 2392 |
-
def shuffle(rv):
|
| 2393 |
-
'''Knuth shuffle'''
|
| 2394 |
-
for i in range(len(rv) - 1, 0, -1):
|
| 2395 |
-
x = choice(rv[:i + 1])
|
| 2396 |
-
j = rv.index(x)
|
| 2397 |
-
rv[i], rv[j] = rv[j], rv[i]
|
| 2398 |
-
def pick(rv, n):
|
| 2399 |
-
'''shuffle rv and return the first n values
|
| 2400 |
-
'''
|
| 2401 |
-
shuffle(rv)
|
| 2402 |
-
return rv[:n]
|
| 2403 |
-
ms = multiset(t)
|
| 2404 |
-
tot = len(t)
|
| 2405 |
-
ms = sorted(ms.items(), key=lambda x: x[1])
|
| 2406 |
-
# if there are not enough spaces for the most
|
| 2407 |
-
# plentiful element to move to then some of them
|
| 2408 |
-
# will have to stay in place
|
| 2409 |
-
M, mx = ms[-1]
|
| 2410 |
-
n = len(t)
|
| 2411 |
-
xs = 2*mx - tot
|
| 2412 |
-
if xs > 0:
|
| 2413 |
-
if strict:
|
| 2414 |
-
raise ValueError('no derangement possible')
|
| 2415 |
-
opts = [i for (i, c) in enumerate(t) if c == ms[-1][0]]
|
| 2416 |
-
pick(opts, xs)
|
| 2417 |
-
stay = sorted(opts[:xs])
|
| 2418 |
-
rv = list(t)
|
| 2419 |
-
for i in reversed(stay):
|
| 2420 |
-
rv.pop(i)
|
| 2421 |
-
rv = random_derangement(rv, choice)
|
| 2422 |
-
for i in stay:
|
| 2423 |
-
rv.insert(i, ms[-1][0])
|
| 2424 |
-
return ''.join(rv) if type(t) is str else rv
|
| 2425 |
-
# the normal derangement calculated from here
|
| 2426 |
-
if n == len(ms):
|
| 2427 |
-
# approx 1/3 will succeed
|
| 2428 |
-
rv = list(t)
|
| 2429 |
-
while True:
|
| 2430 |
-
shuffle(rv)
|
| 2431 |
-
if all(i != j for i,j in zip(rv, t)):
|
| 2432 |
-
break
|
| 2433 |
-
else:
|
| 2434 |
-
# general case
|
| 2435 |
-
rv = [None]*n
|
| 2436 |
-
while True:
|
| 2437 |
-
j = 0
|
| 2438 |
-
while j > -len(ms): # do most numerous first
|
| 2439 |
-
j -= 1
|
| 2440 |
-
e, c = ms[j]
|
| 2441 |
-
opts = [i for i in range(n) if rv[i] is None and t[i] != e]
|
| 2442 |
-
if len(opts) < c:
|
| 2443 |
-
for i in range(n):
|
| 2444 |
-
rv[i] = None
|
| 2445 |
-
break # try again
|
| 2446 |
-
pick(opts, c)
|
| 2447 |
-
for i in range(c):
|
| 2448 |
-
rv[opts[i]] = e
|
| 2449 |
-
else:
|
| 2450 |
-
return rv
|
| 2451 |
-
return rv
|
| 2452 |
-
|
| 2453 |
-
|
| 2454 |
-
def _set_derangements(s):
|
| 2455 |
-
"""
|
| 2456 |
-
yield derangements of items in ``s`` which are assumed to contain
|
| 2457 |
-
no repeated elements
|
| 2458 |
-
"""
|
| 2459 |
-
if len(s) < 2:
|
| 2460 |
-
return
|
| 2461 |
-
if len(s) == 2:
|
| 2462 |
-
yield [s[1], s[0]]
|
| 2463 |
-
return
|
| 2464 |
-
if len(s) == 3:
|
| 2465 |
-
yield [s[1], s[2], s[0]]
|
| 2466 |
-
yield [s[2], s[0], s[1]]
|
| 2467 |
-
return
|
| 2468 |
-
for p in permutations(s):
|
| 2469 |
-
if not any(i == j for i, j in zip(p, s)):
|
| 2470 |
-
yield list(p)
|
| 2471 |
-
|
| 2472 |
-
|
| 2473 |
-
def generate_derangements(s):
|
| 2474 |
-
"""
|
| 2475 |
-
Return unique derangements of the elements of iterable ``s``.
|
| 2476 |
-
|
| 2477 |
-
Examples
|
| 2478 |
-
========
|
| 2479 |
-
|
| 2480 |
-
>>> from sympy.utilities.iterables import generate_derangements
|
| 2481 |
-
>>> list(generate_derangements([0, 1, 2]))
|
| 2482 |
-
[[1, 2, 0], [2, 0, 1]]
|
| 2483 |
-
>>> list(generate_derangements([0, 1, 2, 2]))
|
| 2484 |
-
[[2, 2, 0, 1], [2, 2, 1, 0]]
|
| 2485 |
-
>>> list(generate_derangements([0, 1, 1]))
|
| 2486 |
-
[]
|
| 2487 |
-
|
| 2488 |
-
See Also
|
| 2489 |
-
========
|
| 2490 |
-
|
| 2491 |
-
sympy.functions.combinatorial.factorials.subfactorial
|
| 2492 |
-
|
| 2493 |
-
"""
|
| 2494 |
-
if not has_dups(s):
|
| 2495 |
-
yield from _set_derangements(s)
|
| 2496 |
-
else:
|
| 2497 |
-
for p in multiset_derangements(s):
|
| 2498 |
-
yield list(p)
|
| 2499 |
-
|
| 2500 |
-
|
| 2501 |
-
def necklaces(n, k, free=False):
|
| 2502 |
-
"""
|
| 2503 |
-
A routine to generate necklaces that may (free=True) or may not
|
| 2504 |
-
(free=False) be turned over to be viewed. The "necklaces" returned
|
| 2505 |
-
are comprised of ``n`` integers (beads) with ``k`` different
|
| 2506 |
-
values (colors). Only unique necklaces are returned.
|
| 2507 |
-
|
| 2508 |
-
Examples
|
| 2509 |
-
========
|
| 2510 |
-
|
| 2511 |
-
>>> from sympy.utilities.iterables import necklaces, bracelets
|
| 2512 |
-
>>> def show(s, i):
|
| 2513 |
-
... return ''.join(s[j] for j in i)
|
| 2514 |
-
|
| 2515 |
-
The "unrestricted necklace" is sometimes also referred to as a
|
| 2516 |
-
"bracelet" (an object that can be turned over, a sequence that can
|
| 2517 |
-
be reversed) and the term "necklace" is used to imply a sequence
|
| 2518 |
-
that cannot be reversed. So ACB == ABC for a bracelet (rotate and
|
| 2519 |
-
reverse) while the two are different for a necklace since rotation
|
| 2520 |
-
alone cannot make the two sequences the same.
|
| 2521 |
-
|
| 2522 |
-
(mnemonic: Bracelets can be viewed Backwards, but Not Necklaces.)
|
| 2523 |
-
|
| 2524 |
-
>>> B = [show('ABC', i) for i in bracelets(3, 3)]
|
| 2525 |
-
>>> N = [show('ABC', i) for i in necklaces(3, 3)]
|
| 2526 |
-
>>> set(N) - set(B)
|
| 2527 |
-
{'ACB'}
|
| 2528 |
-
|
| 2529 |
-
>>> list(necklaces(4, 2))
|
| 2530 |
-
[(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 1),
|
| 2531 |
-
(0, 1, 0, 1), (0, 1, 1, 1), (1, 1, 1, 1)]
|
| 2532 |
-
|
| 2533 |
-
>>> [show('.o', i) for i in bracelets(4, 2)]
|
| 2534 |
-
['....', '...o', '..oo', '.o.o', '.ooo', 'oooo']
|
| 2535 |
-
|
| 2536 |
-
References
|
| 2537 |
-
==========
|
| 2538 |
-
|
| 2539 |
-
.. [1] https://mathworld.wolfram.com/Necklace.html
|
| 2540 |
-
|
| 2541 |
-
.. [2] Frank Ruskey, Carla Savage, and Terry Min Yih Wang,
|
| 2542 |
-
Generating necklaces, Journal of Algorithms 13 (1992), 414-430;
|
| 2543 |
-
https://doi.org/10.1016/0196-6774(92)90047-G
|
| 2544 |
-
|
| 2545 |
-
"""
|
| 2546 |
-
# The FKM algorithm
|
| 2547 |
-
if k == 0 and n > 0:
|
| 2548 |
-
return
|
| 2549 |
-
a = [0]*n
|
| 2550 |
-
yield tuple(a)
|
| 2551 |
-
if n == 0:
|
| 2552 |
-
return
|
| 2553 |
-
while True:
|
| 2554 |
-
i = n - 1
|
| 2555 |
-
while a[i] == k - 1:
|
| 2556 |
-
i -= 1
|
| 2557 |
-
if i == -1:
|
| 2558 |
-
return
|
| 2559 |
-
a[i] += 1
|
| 2560 |
-
for j in range(n - i - 1):
|
| 2561 |
-
a[j + i + 1] = a[j]
|
| 2562 |
-
if n % (i + 1) == 0 and (not free or all(a <= a[j::-1] + a[-1:j:-1] for j in range(n - 1))):
|
| 2563 |
-
# No need to test j = n - 1.
|
| 2564 |
-
yield tuple(a)
|
| 2565 |
-
|
| 2566 |
-
|
| 2567 |
-
def bracelets(n, k):
|
| 2568 |
-
"""Wrapper to necklaces to return a free (unrestricted) necklace."""
|
| 2569 |
-
return necklaces(n, k, free=True)
|
| 2570 |
-
|
| 2571 |
-
|
| 2572 |
-
def generate_oriented_forest(n):
|
| 2573 |
-
"""
|
| 2574 |
-
This algorithm generates oriented forests.
|
| 2575 |
-
|
| 2576 |
-
An oriented graph is a directed graph having no symmetric pair of directed
|
| 2577 |
-
edges. A forest is an acyclic graph, i.e., it has no cycles. A forest can
|
| 2578 |
-
also be described as a disjoint union of trees, which are graphs in which
|
| 2579 |
-
any two vertices are connected by exactly one simple path.
|
| 2580 |
-
|
| 2581 |
-
Examples
|
| 2582 |
-
========
|
| 2583 |
-
|
| 2584 |
-
>>> from sympy.utilities.iterables import generate_oriented_forest
|
| 2585 |
-
>>> list(generate_oriented_forest(4))
|
| 2586 |
-
[[0, 1, 2, 3], [0, 1, 2, 2], [0, 1, 2, 1], [0, 1, 2, 0], \
|
| 2587 |
-
[0, 1, 1, 1], [0, 1, 1, 0], [0, 1, 0, 1], [0, 1, 0, 0], [0, 0, 0, 0]]
|
| 2588 |
-
|
| 2589 |
-
References
|
| 2590 |
-
==========
|
| 2591 |
-
|
| 2592 |
-
.. [1] T. Beyer and S.M. Hedetniemi: constant time generation of
|
| 2593 |
-
rooted trees, SIAM J. Computing Vol. 9, No. 4, November 1980
|
| 2594 |
-
|
| 2595 |
-
.. [2] https://stackoverflow.com/questions/1633833/oriented-forest-taocp-algorithm-in-python
|
| 2596 |
-
|
| 2597 |
-
"""
|
| 2598 |
-
P = list(range(-1, n))
|
| 2599 |
-
while True:
|
| 2600 |
-
yield P[1:]
|
| 2601 |
-
if P[n] > 0:
|
| 2602 |
-
P[n] = P[P[n]]
|
| 2603 |
-
else:
|
| 2604 |
-
for p in range(n - 1, 0, -1):
|
| 2605 |
-
if P[p] != 0:
|
| 2606 |
-
target = P[p] - 1
|
| 2607 |
-
for q in range(p - 1, 0, -1):
|
| 2608 |
-
if P[q] == target:
|
| 2609 |
-
break
|
| 2610 |
-
offset = p - q
|
| 2611 |
-
for i in range(p, n + 1):
|
| 2612 |
-
P[i] = P[i - offset]
|
| 2613 |
-
break
|
| 2614 |
-
else:
|
| 2615 |
-
break
|
| 2616 |
-
|
| 2617 |
-
|
| 2618 |
-
def minlex(seq, directed=True, key=None):
|
| 2619 |
-
r"""
|
| 2620 |
-
Return the rotation of the sequence in which the lexically smallest
|
| 2621 |
-
elements appear first, e.g. `cba \rightarrow acb`.
|
| 2622 |
-
|
| 2623 |
-
The sequence returned is a tuple, unless the input sequence is a string
|
| 2624 |
-
in which case a string is returned.
|
| 2625 |
-
|
| 2626 |
-
If ``directed`` is False then the smaller of the sequence and the
|
| 2627 |
-
reversed sequence is returned, e.g. `cba \rightarrow abc`.
|
| 2628 |
-
|
| 2629 |
-
If ``key`` is not None then it is used to extract a comparison key from each element in iterable.
|
| 2630 |
-
|
| 2631 |
-
Examples
|
| 2632 |
-
========
|
| 2633 |
-
|
| 2634 |
-
>>> from sympy.combinatorics.polyhedron import minlex
|
| 2635 |
-
>>> minlex((1, 2, 0))
|
| 2636 |
-
(0, 1, 2)
|
| 2637 |
-
>>> minlex((1, 0, 2))
|
| 2638 |
-
(0, 2, 1)
|
| 2639 |
-
>>> minlex((1, 0, 2), directed=False)
|
| 2640 |
-
(0, 1, 2)
|
| 2641 |
-
|
| 2642 |
-
>>> minlex('11010011000', directed=True)
|
| 2643 |
-
'00011010011'
|
| 2644 |
-
>>> minlex('11010011000', directed=False)
|
| 2645 |
-
'00011001011'
|
| 2646 |
-
|
| 2647 |
-
>>> minlex(('bb', 'aaa', 'c', 'a'))
|
| 2648 |
-
('a', 'bb', 'aaa', 'c')
|
| 2649 |
-
>>> minlex(('bb', 'aaa', 'c', 'a'), key=len)
|
| 2650 |
-
('c', 'a', 'bb', 'aaa')
|
| 2651 |
-
|
| 2652 |
-
"""
|
| 2653 |
-
from sympy.functions.elementary.miscellaneous import Id
|
| 2654 |
-
if key is None: key = Id
|
| 2655 |
-
best = rotate_left(seq, least_rotation(seq, key=key))
|
| 2656 |
-
if not directed:
|
| 2657 |
-
rseq = seq[::-1]
|
| 2658 |
-
rbest = rotate_left(rseq, least_rotation(rseq, key=key))
|
| 2659 |
-
best = min(best, rbest, key=key)
|
| 2660 |
-
|
| 2661 |
-
# Convert to tuple, unless we started with a string.
|
| 2662 |
-
return tuple(best) if not isinstance(seq, str) else best
|
| 2663 |
-
|
| 2664 |
-
|
| 2665 |
-
def runs(seq, op=gt):
|
| 2666 |
-
"""Group the sequence into lists in which successive elements
|
| 2667 |
-
all compare the same with the comparison operator, ``op``:
|
| 2668 |
-
op(seq[i + 1], seq[i]) is True from all elements in a run.
|
| 2669 |
-
|
| 2670 |
-
Examples
|
| 2671 |
-
========
|
| 2672 |
-
|
| 2673 |
-
>>> from sympy.utilities.iterables import runs
|
| 2674 |
-
>>> from operator import ge
|
| 2675 |
-
>>> runs([0, 1, 2, 2, 1, 4, 3, 2, 2])
|
| 2676 |
-
[[0, 1, 2], [2], [1, 4], [3], [2], [2]]
|
| 2677 |
-
>>> runs([0, 1, 2, 2, 1, 4, 3, 2, 2], op=ge)
|
| 2678 |
-
[[0, 1, 2, 2], [1, 4], [3], [2, 2]]
|
| 2679 |
-
"""
|
| 2680 |
-
cycles = []
|
| 2681 |
-
seq = iter(seq)
|
| 2682 |
-
try:
|
| 2683 |
-
run = [next(seq)]
|
| 2684 |
-
except StopIteration:
|
| 2685 |
-
return []
|
| 2686 |
-
while True:
|
| 2687 |
-
try:
|
| 2688 |
-
ei = next(seq)
|
| 2689 |
-
except StopIteration:
|
| 2690 |
-
break
|
| 2691 |
-
if op(ei, run[-1]):
|
| 2692 |
-
run.append(ei)
|
| 2693 |
-
continue
|
| 2694 |
-
else:
|
| 2695 |
-
cycles.append(run)
|
| 2696 |
-
run = [ei]
|
| 2697 |
-
if run:
|
| 2698 |
-
cycles.append(run)
|
| 2699 |
-
return cycles
|
| 2700 |
-
|
| 2701 |
-
|
| 2702 |
-
def sequence_partitions(l, n, /):
|
| 2703 |
-
r"""Returns the partition of sequence $l$ into $n$ bins
|
| 2704 |
-
|
| 2705 |
-
Explanation
|
| 2706 |
-
===========
|
| 2707 |
-
|
| 2708 |
-
Given the sequence $l_1 \cdots l_m \in V^+$ where
|
| 2709 |
-
$V^+$ is the Kleene plus of $V$
|
| 2710 |
-
|
| 2711 |
-
The set of $n$ partitions of $l$ is defined as:
|
| 2712 |
-
|
| 2713 |
-
.. math::
|
| 2714 |
-
\{(s_1, \cdots, s_n) | s_1 \in V^+, \cdots, s_n \in V^+,
|
| 2715 |
-
s_1 \cdots s_n = l_1 \cdots l_m\}
|
| 2716 |
-
|
| 2717 |
-
Parameters
|
| 2718 |
-
==========
|
| 2719 |
-
|
| 2720 |
-
l : Sequence[T]
|
| 2721 |
-
A nonempty sequence of any Python objects
|
| 2722 |
-
|
| 2723 |
-
n : int
|
| 2724 |
-
A positive integer
|
| 2725 |
-
|
| 2726 |
-
Yields
|
| 2727 |
-
======
|
| 2728 |
-
|
| 2729 |
-
out : list[Sequence[T]]
|
| 2730 |
-
A list of sequences with concatenation equals $l$.
|
| 2731 |
-
This should conform with the type of $l$.
|
| 2732 |
-
|
| 2733 |
-
Examples
|
| 2734 |
-
========
|
| 2735 |
-
|
| 2736 |
-
>>> from sympy.utilities.iterables import sequence_partitions
|
| 2737 |
-
>>> for out in sequence_partitions([1, 2, 3, 4], 2):
|
| 2738 |
-
... print(out)
|
| 2739 |
-
[[1], [2, 3, 4]]
|
| 2740 |
-
[[1, 2], [3, 4]]
|
| 2741 |
-
[[1, 2, 3], [4]]
|
| 2742 |
-
|
| 2743 |
-
Notes
|
| 2744 |
-
=====
|
| 2745 |
-
|
| 2746 |
-
This is modified version of EnricoGiampieri's partition generator
|
| 2747 |
-
from https://stackoverflow.com/questions/13131491/partition-n-items-into-k-bins-in-python-lazily
|
| 2748 |
-
|
| 2749 |
-
See Also
|
| 2750 |
-
========
|
| 2751 |
-
|
| 2752 |
-
sequence_partitions_empty
|
| 2753 |
-
"""
|
| 2754 |
-
# Asserting l is nonempty is done only for sanity check
|
| 2755 |
-
if n == 1 and l:
|
| 2756 |
-
yield [l]
|
| 2757 |
-
return
|
| 2758 |
-
for i in range(1, len(l)):
|
| 2759 |
-
for part in sequence_partitions(l[i:], n - 1):
|
| 2760 |
-
yield [l[:i]] + part
|
| 2761 |
-
|
| 2762 |
-
|
| 2763 |
-
def sequence_partitions_empty(l, n, /):
|
| 2764 |
-
r"""Returns the partition of sequence $l$ into $n$ bins with
|
| 2765 |
-
empty sequence
|
| 2766 |
-
|
| 2767 |
-
Explanation
|
| 2768 |
-
===========
|
| 2769 |
-
|
| 2770 |
-
Given the sequence $l_1 \cdots l_m \in V^*$ where
|
| 2771 |
-
$V^*$ is the Kleene star of $V$
|
| 2772 |
-
|
| 2773 |
-
The set of $n$ partitions of $l$ is defined as:
|
| 2774 |
-
|
| 2775 |
-
.. math::
|
| 2776 |
-
\{(s_1, \cdots, s_n) | s_1 \in V^*, \cdots, s_n \in V^*,
|
| 2777 |
-
s_1 \cdots s_n = l_1 \cdots l_m\}
|
| 2778 |
-
|
| 2779 |
-
There are more combinations than :func:`sequence_partitions` because
|
| 2780 |
-
empty sequence can fill everywhere, so we try to provide different
|
| 2781 |
-
utility for this.
|
| 2782 |
-
|
| 2783 |
-
Parameters
|
| 2784 |
-
==========
|
| 2785 |
-
|
| 2786 |
-
l : Sequence[T]
|
| 2787 |
-
A sequence of any Python objects (can be possibly empty)
|
| 2788 |
-
|
| 2789 |
-
n : int
|
| 2790 |
-
A positive integer
|
| 2791 |
-
|
| 2792 |
-
Yields
|
| 2793 |
-
======
|
| 2794 |
-
|
| 2795 |
-
out : list[Sequence[T]]
|
| 2796 |
-
A list of sequences with concatenation equals $l$.
|
| 2797 |
-
This should conform with the type of $l$.
|
| 2798 |
-
|
| 2799 |
-
Examples
|
| 2800 |
-
========
|
| 2801 |
-
|
| 2802 |
-
>>> from sympy.utilities.iterables import sequence_partitions_empty
|
| 2803 |
-
>>> for out in sequence_partitions_empty([1, 2, 3, 4], 2):
|
| 2804 |
-
... print(out)
|
| 2805 |
-
[[], [1, 2, 3, 4]]
|
| 2806 |
-
[[1], [2, 3, 4]]
|
| 2807 |
-
[[1, 2], [3, 4]]
|
| 2808 |
-
[[1, 2, 3], [4]]
|
| 2809 |
-
[[1, 2, 3, 4], []]
|
| 2810 |
-
|
| 2811 |
-
See Also
|
| 2812 |
-
========
|
| 2813 |
-
|
| 2814 |
-
sequence_partitions
|
| 2815 |
-
"""
|
| 2816 |
-
if n < 1:
|
| 2817 |
-
return
|
| 2818 |
-
if n == 1:
|
| 2819 |
-
yield [l]
|
| 2820 |
-
return
|
| 2821 |
-
for i in range(0, len(l) + 1):
|
| 2822 |
-
for part in sequence_partitions_empty(l[i:], n - 1):
|
| 2823 |
-
yield [l[:i]] + part
|
| 2824 |
-
|
| 2825 |
-
|
| 2826 |
-
def kbins(l, k, ordered=None):
|
| 2827 |
-
"""
|
| 2828 |
-
Return sequence ``l`` partitioned into ``k`` bins.
|
| 2829 |
-
|
| 2830 |
-
Examples
|
| 2831 |
-
========
|
| 2832 |
-
|
| 2833 |
-
The default is to give the items in the same order, but grouped
|
| 2834 |
-
into k partitions without any reordering:
|
| 2835 |
-
|
| 2836 |
-
>>> from sympy.utilities.iterables import kbins
|
| 2837 |
-
>>> for p in kbins(list(range(5)), 2):
|
| 2838 |
-
... print(p)
|
| 2839 |
-
...
|
| 2840 |
-
[[0], [1, 2, 3, 4]]
|
| 2841 |
-
[[0, 1], [2, 3, 4]]
|
| 2842 |
-
[[0, 1, 2], [3, 4]]
|
| 2843 |
-
[[0, 1, 2, 3], [4]]
|
| 2844 |
-
|
| 2845 |
-
The ``ordered`` flag is either None (to give the simple partition
|
| 2846 |
-
of the elements) or is a 2 digit integer indicating whether the order of
|
| 2847 |
-
the bins and the order of the items in the bins matters. Given::
|
| 2848 |
-
|
| 2849 |
-
A = [[0], [1, 2]]
|
| 2850 |
-
B = [[1, 2], [0]]
|
| 2851 |
-
C = [[2, 1], [0]]
|
| 2852 |
-
D = [[0], [2, 1]]
|
| 2853 |
-
|
| 2854 |
-
the following values for ``ordered`` have the shown meanings::
|
| 2855 |
-
|
| 2856 |
-
00 means A == B == C == D
|
| 2857 |
-
01 means A == B
|
| 2858 |
-
10 means A == D
|
| 2859 |
-
11 means A == A
|
| 2860 |
-
|
| 2861 |
-
>>> for ordered_flag in [None, 0, 1, 10, 11]:
|
| 2862 |
-
... print('ordered = %s' % ordered_flag)
|
| 2863 |
-
... for p in kbins(list(range(3)), 2, ordered=ordered_flag):
|
| 2864 |
-
... print(' %s' % p)
|
| 2865 |
-
...
|
| 2866 |
-
ordered = None
|
| 2867 |
-
[[0], [1, 2]]
|
| 2868 |
-
[[0, 1], [2]]
|
| 2869 |
-
ordered = 0
|
| 2870 |
-
[[0, 1], [2]]
|
| 2871 |
-
[[0, 2], [1]]
|
| 2872 |
-
[[0], [1, 2]]
|
| 2873 |
-
ordered = 1
|
| 2874 |
-
[[0], [1, 2]]
|
| 2875 |
-
[[0], [2, 1]]
|
| 2876 |
-
[[1], [0, 2]]
|
| 2877 |
-
[[1], [2, 0]]
|
| 2878 |
-
[[2], [0, 1]]
|
| 2879 |
-
[[2], [1, 0]]
|
| 2880 |
-
ordered = 10
|
| 2881 |
-
[[0, 1], [2]]
|
| 2882 |
-
[[2], [0, 1]]
|
| 2883 |
-
[[0, 2], [1]]
|
| 2884 |
-
[[1], [0, 2]]
|
| 2885 |
-
[[0], [1, 2]]
|
| 2886 |
-
[[1, 2], [0]]
|
| 2887 |
-
ordered = 11
|
| 2888 |
-
[[0], [1, 2]]
|
| 2889 |
-
[[0, 1], [2]]
|
| 2890 |
-
[[0], [2, 1]]
|
| 2891 |
-
[[0, 2], [1]]
|
| 2892 |
-
[[1], [0, 2]]
|
| 2893 |
-
[[1, 0], [2]]
|
| 2894 |
-
[[1], [2, 0]]
|
| 2895 |
-
[[1, 2], [0]]
|
| 2896 |
-
[[2], [0, 1]]
|
| 2897 |
-
[[2, 0], [1]]
|
| 2898 |
-
[[2], [1, 0]]
|
| 2899 |
-
[[2, 1], [0]]
|
| 2900 |
-
|
| 2901 |
-
See Also
|
| 2902 |
-
========
|
| 2903 |
-
|
| 2904 |
-
partitions, multiset_partitions
|
| 2905 |
-
|
| 2906 |
-
"""
|
| 2907 |
-
if ordered is None:
|
| 2908 |
-
yield from sequence_partitions(l, k)
|
| 2909 |
-
elif ordered == 11:
|
| 2910 |
-
for pl in multiset_permutations(l):
|
| 2911 |
-
pl = list(pl)
|
| 2912 |
-
yield from sequence_partitions(pl, k)
|
| 2913 |
-
elif ordered == 00:
|
| 2914 |
-
yield from multiset_partitions(l, k)
|
| 2915 |
-
elif ordered == 10:
|
| 2916 |
-
for p in multiset_partitions(l, k):
|
| 2917 |
-
for perm in permutations(p):
|
| 2918 |
-
yield list(perm)
|
| 2919 |
-
elif ordered == 1:
|
| 2920 |
-
for kgot, p in partitions(len(l), k, size=True):
|
| 2921 |
-
if kgot != k:
|
| 2922 |
-
continue
|
| 2923 |
-
for li in multiset_permutations(l):
|
| 2924 |
-
rv = []
|
| 2925 |
-
i = j = 0
|
| 2926 |
-
li = list(li)
|
| 2927 |
-
for size, multiplicity in sorted(p.items()):
|
| 2928 |
-
for m in range(multiplicity):
|
| 2929 |
-
j = i + size
|
| 2930 |
-
rv.append(li[i: j])
|
| 2931 |
-
i = j
|
| 2932 |
-
yield rv
|
| 2933 |
-
else:
|
| 2934 |
-
raise ValueError(
|
| 2935 |
-
'ordered must be one of 00, 01, 10 or 11, not %s' % ordered)
|
| 2936 |
-
|
| 2937 |
-
|
| 2938 |
-
def permute_signs(t):
|
| 2939 |
-
"""Return iterator in which the signs of non-zero elements
|
| 2940 |
-
of t are permuted.
|
| 2941 |
-
|
| 2942 |
-
Examples
|
| 2943 |
-
========
|
| 2944 |
-
|
| 2945 |
-
>>> from sympy.utilities.iterables import permute_signs
|
| 2946 |
-
>>> list(permute_signs((0, 1, 2)))
|
| 2947 |
-
[(0, 1, 2), (0, -1, 2), (0, 1, -2), (0, -1, -2)]
|
| 2948 |
-
"""
|
| 2949 |
-
for signs in product(*[(1, -1)]*(len(t) - t.count(0))):
|
| 2950 |
-
signs = list(signs)
|
| 2951 |
-
yield type(t)([i*signs.pop() if i else i for i in t])
|
| 2952 |
-
|
| 2953 |
-
|
| 2954 |
-
def signed_permutations(t):
|
| 2955 |
-
"""Return iterator in which the signs of non-zero elements
|
| 2956 |
-
of t and the order of the elements are permuted and all
|
| 2957 |
-
returned values are unique.
|
| 2958 |
-
|
| 2959 |
-
Examples
|
| 2960 |
-
========
|
| 2961 |
-
|
| 2962 |
-
>>> from sympy.utilities.iterables import signed_permutations
|
| 2963 |
-
>>> list(signed_permutations((0, 1, 2)))
|
| 2964 |
-
[(0, 1, 2), (0, -1, 2), (0, 1, -2), (0, -1, -2), (0, 2, 1),
|
| 2965 |
-
(0, -2, 1), (0, 2, -1), (0, -2, -1), (1, 0, 2), (-1, 0, 2),
|
| 2966 |
-
(1, 0, -2), (-1, 0, -2), (1, 2, 0), (-1, 2, 0), (1, -2, 0),
|
| 2967 |
-
(-1, -2, 0), (2, 0, 1), (-2, 0, 1), (2, 0, -1), (-2, 0, -1),
|
| 2968 |
-
(2, 1, 0), (-2, 1, 0), (2, -1, 0), (-2, -1, 0)]
|
| 2969 |
-
"""
|
| 2970 |
-
return (type(t)(i) for j in multiset_permutations(t)
|
| 2971 |
-
for i in permute_signs(j))
|
| 2972 |
-
|
| 2973 |
-
|
| 2974 |
-
def rotations(s, dir=1):
|
| 2975 |
-
"""Return a generator giving the items in s as list where
|
| 2976 |
-
each subsequent list has the items rotated to the left (default)
|
| 2977 |
-
or right (``dir=-1``) relative to the previous list.
|
| 2978 |
-
|
| 2979 |
-
Examples
|
| 2980 |
-
========
|
| 2981 |
-
|
| 2982 |
-
>>> from sympy import rotations
|
| 2983 |
-
>>> list(rotations([1,2,3]))
|
| 2984 |
-
[[1, 2, 3], [2, 3, 1], [3, 1, 2]]
|
| 2985 |
-
>>> list(rotations([1,2,3], -1))
|
| 2986 |
-
[[1, 2, 3], [3, 1, 2], [2, 3, 1]]
|
| 2987 |
-
"""
|
| 2988 |
-
seq = list(s)
|
| 2989 |
-
for i in range(len(seq)):
|
| 2990 |
-
yield seq
|
| 2991 |
-
seq = rotate_left(seq, dir)
|
| 2992 |
-
|
| 2993 |
-
|
| 2994 |
-
def roundrobin(*iterables):
|
| 2995 |
-
"""roundrobin recipe taken from itertools documentation:
|
| 2996 |
-
https://docs.python.org/3/library/itertools.html#itertools-recipes
|
| 2997 |
-
|
| 2998 |
-
roundrobin('ABC', 'D', 'EF') --> A D E B F C
|
| 2999 |
-
|
| 3000 |
-
Recipe credited to George Sakkis
|
| 3001 |
-
"""
|
| 3002 |
-
nexts = cycle(iter(it).__next__ for it in iterables)
|
| 3003 |
-
|
| 3004 |
-
pending = len(iterables)
|
| 3005 |
-
while pending:
|
| 3006 |
-
try:
|
| 3007 |
-
for nxt in nexts:
|
| 3008 |
-
yield nxt()
|
| 3009 |
-
except StopIteration:
|
| 3010 |
-
pending -= 1
|
| 3011 |
-
nexts = cycle(islice(nexts, pending))
|
| 3012 |
-
|
| 3013 |
-
|
| 3014 |
-
|
| 3015 |
-
class NotIterable:
|
| 3016 |
-
"""
|
| 3017 |
-
Use this as mixin when creating a class which is not supposed to
|
| 3018 |
-
return true when iterable() is called on its instances because
|
| 3019 |
-
calling list() on the instance, for example, would result in
|
| 3020 |
-
an infinite loop.
|
| 3021 |
-
"""
|
| 3022 |
-
pass
|
| 3023 |
-
|
| 3024 |
-
|
| 3025 |
-
def iterable(i, exclude=(str, dict, NotIterable)):
|
| 3026 |
-
"""
|
| 3027 |
-
Return a boolean indicating whether ``i`` is SymPy iterable.
|
| 3028 |
-
True also indicates that the iterator is finite, e.g. you can
|
| 3029 |
-
call list(...) on the instance.
|
| 3030 |
-
|
| 3031 |
-
When SymPy is working with iterables, it is almost always assuming
|
| 3032 |
-
that the iterable is not a string or a mapping, so those are excluded
|
| 3033 |
-
by default. If you want a pure Python definition, make exclude=None. To
|
| 3034 |
-
exclude multiple items, pass them as a tuple.
|
| 3035 |
-
|
| 3036 |
-
You can also set the _iterable attribute to True or False on your class,
|
| 3037 |
-
which will override the checks here, including the exclude test.
|
| 3038 |
-
|
| 3039 |
-
As a rule of thumb, some SymPy functions use this to check if they should
|
| 3040 |
-
recursively map over an object. If an object is technically iterable in
|
| 3041 |
-
the Python sense but does not desire this behavior (e.g., because its
|
| 3042 |
-
iteration is not finite, or because iteration might induce an unwanted
|
| 3043 |
-
computation), it should disable it by setting the _iterable attribute to False.
|
| 3044 |
-
|
| 3045 |
-
See also: is_sequence
|
| 3046 |
-
|
| 3047 |
-
Examples
|
| 3048 |
-
========
|
| 3049 |
-
|
| 3050 |
-
>>> from sympy.utilities.iterables import iterable
|
| 3051 |
-
>>> from sympy import Tuple
|
| 3052 |
-
>>> things = [[1], (1,), set([1]), Tuple(1), (j for j in [1, 2]), {1:2}, '1', 1]
|
| 3053 |
-
>>> for i in things:
|
| 3054 |
-
... print('%s %s' % (iterable(i), type(i)))
|
| 3055 |
-
True <... 'list'>
|
| 3056 |
-
True <... 'tuple'>
|
| 3057 |
-
True <... 'set'>
|
| 3058 |
-
True <class 'sympy.core.containers.Tuple'>
|
| 3059 |
-
True <... 'generator'>
|
| 3060 |
-
False <... 'dict'>
|
| 3061 |
-
False <... 'str'>
|
| 3062 |
-
False <... 'int'>
|
| 3063 |
-
|
| 3064 |
-
>>> iterable({}, exclude=None)
|
| 3065 |
-
True
|
| 3066 |
-
>>> iterable({}, exclude=str)
|
| 3067 |
-
True
|
| 3068 |
-
>>> iterable("no", exclude=str)
|
| 3069 |
-
False
|
| 3070 |
-
|
| 3071 |
-
"""
|
| 3072 |
-
if hasattr(i, '_iterable'):
|
| 3073 |
-
return i._iterable
|
| 3074 |
-
try:
|
| 3075 |
-
iter(i)
|
| 3076 |
-
except TypeError:
|
| 3077 |
-
return False
|
| 3078 |
-
if exclude:
|
| 3079 |
-
return not isinstance(i, exclude)
|
| 3080 |
-
return True
|
| 3081 |
-
|
| 3082 |
-
|
| 3083 |
-
def is_sequence(i, include=None):
|
| 3084 |
-
"""
|
| 3085 |
-
Return a boolean indicating whether ``i`` is a sequence in the SymPy
|
| 3086 |
-
sense. If anything that fails the test below should be included as
|
| 3087 |
-
being a sequence for your application, set 'include' to that object's
|
| 3088 |
-
type; multiple types should be passed as a tuple of types.
|
| 3089 |
-
|
| 3090 |
-
Note: although generators can generate a sequence, they often need special
|
| 3091 |
-
handling to make sure their elements are captured before the generator is
|
| 3092 |
-
exhausted, so these are not included by default in the definition of a
|
| 3093 |
-
sequence.
|
| 3094 |
-
|
| 3095 |
-
See also: iterable
|
| 3096 |
-
|
| 3097 |
-
Examples
|
| 3098 |
-
========
|
| 3099 |
-
|
| 3100 |
-
>>> from sympy.utilities.iterables import is_sequence
|
| 3101 |
-
>>> from types import GeneratorType
|
| 3102 |
-
>>> is_sequence([])
|
| 3103 |
-
True
|
| 3104 |
-
>>> is_sequence(set())
|
| 3105 |
-
False
|
| 3106 |
-
>>> is_sequence('abc')
|
| 3107 |
-
False
|
| 3108 |
-
>>> is_sequence('abc', include=str)
|
| 3109 |
-
True
|
| 3110 |
-
>>> generator = (c for c in 'abc')
|
| 3111 |
-
>>> is_sequence(generator)
|
| 3112 |
-
False
|
| 3113 |
-
>>> is_sequence(generator, include=(str, GeneratorType))
|
| 3114 |
-
True
|
| 3115 |
-
|
| 3116 |
-
"""
|
| 3117 |
-
return (hasattr(i, '__getitem__') and
|
| 3118 |
-
iterable(i) or
|
| 3119 |
-
bool(include) and
|
| 3120 |
-
isinstance(i, include))
|
| 3121 |
-
|
| 3122 |
-
|
| 3123 |
-
@deprecated(
|
| 3124 |
-
"""
|
| 3125 |
-
Using postorder_traversal from the sympy.utilities.iterables submodule is
|
| 3126 |
-
deprecated.
|
| 3127 |
-
|
| 3128 |
-
Instead, use postorder_traversal from the top-level sympy namespace, like
|
| 3129 |
-
|
| 3130 |
-
sympy.postorder_traversal
|
| 3131 |
-
""",
|
| 3132 |
-
deprecated_since_version="1.10",
|
| 3133 |
-
active_deprecations_target="deprecated-traversal-functions-moved")
|
| 3134 |
-
def postorder_traversal(node, keys=None):
|
| 3135 |
-
from sympy.core.traversal import postorder_traversal as _postorder_traversal
|
| 3136 |
-
return _postorder_traversal(node, keys=keys)
|
| 3137 |
-
|
| 3138 |
-
|
| 3139 |
-
@deprecated(
|
| 3140 |
-
"""
|
| 3141 |
-
Using interactive_traversal from the sympy.utilities.iterables submodule
|
| 3142 |
-
is deprecated.
|
| 3143 |
-
|
| 3144 |
-
Instead, use interactive_traversal from the top-level sympy namespace,
|
| 3145 |
-
like
|
| 3146 |
-
|
| 3147 |
-
sympy.interactive_traversal
|
| 3148 |
-
""",
|
| 3149 |
-
deprecated_since_version="1.10",
|
| 3150 |
-
active_deprecations_target="deprecated-traversal-functions-moved")
|
| 3151 |
-
def interactive_traversal(expr):
|
| 3152 |
-
from sympy.interactive.traversal import interactive_traversal as _interactive_traversal
|
| 3153 |
-
return _interactive_traversal(expr)
|
| 3154 |
-
|
| 3155 |
-
|
| 3156 |
-
@deprecated(
|
| 3157 |
-
"""
|
| 3158 |
-
Importing default_sort_key from sympy.utilities.iterables is deprecated.
|
| 3159 |
-
Use from sympy import default_sort_key instead.
|
| 3160 |
-
""",
|
| 3161 |
-
deprecated_since_version="1.10",
|
| 3162 |
-
active_deprecations_target="deprecated-sympy-core-compatibility",
|
| 3163 |
-
)
|
| 3164 |
-
def default_sort_key(*args, **kwargs):
|
| 3165 |
-
from sympy import default_sort_key as _default_sort_key
|
| 3166 |
-
return _default_sort_key(*args, **kwargs)
|
| 3167 |
-
|
| 3168 |
-
|
| 3169 |
-
@deprecated(
|
| 3170 |
-
"""
|
| 3171 |
-
Importing default_sort_key from sympy.utilities.iterables is deprecated.
|
| 3172 |
-
Use from sympy import default_sort_key instead.
|
| 3173 |
-
""",
|
| 3174 |
-
deprecated_since_version="1.10",
|
| 3175 |
-
active_deprecations_target="deprecated-sympy-core-compatibility",
|
| 3176 |
-
)
|
| 3177 |
-
def ordered(*args, **kwargs):
|
| 3178 |
-
from sympy import ordered as _ordered
|
| 3179 |
-
return _ordered(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/lambdify.py
DELETED
|
@@ -1,1592 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
This module provides convenient functions to transform SymPy expressions to
|
| 3 |
-
lambda functions which can be used to calculate numerical values very fast.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
from __future__ import annotations
|
| 7 |
-
from typing import Any
|
| 8 |
-
|
| 9 |
-
import builtins
|
| 10 |
-
import inspect
|
| 11 |
-
import keyword
|
| 12 |
-
import textwrap
|
| 13 |
-
import linecache
|
| 14 |
-
import weakref
|
| 15 |
-
|
| 16 |
-
# Required despite static analysis claiming it is not used
|
| 17 |
-
from sympy.external import import_module # noqa:F401
|
| 18 |
-
from sympy.utilities.exceptions import sympy_deprecation_warning
|
| 19 |
-
from sympy.utilities.decorator import doctest_depends_on
|
| 20 |
-
from sympy.utilities.iterables import (is_sequence, iterable,
|
| 21 |
-
NotIterable, flatten)
|
| 22 |
-
from sympy.utilities.misc import filldedent
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
__doctest_requires__ = {('lambdify',): ['numpy', 'tensorflow']}
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# Default namespaces, letting us define translations that can't be defined
|
| 29 |
-
# by simple variable maps, like I => 1j
|
| 30 |
-
MATH_DEFAULT: dict[str, Any] = {}
|
| 31 |
-
CMATH_DEFAULT: dict[str,Any] = {}
|
| 32 |
-
MPMATH_DEFAULT: dict[str, Any] = {}
|
| 33 |
-
NUMPY_DEFAULT: dict[str, Any] = {"I": 1j}
|
| 34 |
-
SCIPY_DEFAULT: dict[str, Any] = {"I": 1j}
|
| 35 |
-
CUPY_DEFAULT: dict[str, Any] = {"I": 1j}
|
| 36 |
-
JAX_DEFAULT: dict[str, Any] = {"I": 1j}
|
| 37 |
-
TENSORFLOW_DEFAULT: dict[str, Any] = {}
|
| 38 |
-
TORCH_DEFAULT: dict[str, Any] = {"I": 1j}
|
| 39 |
-
SYMPY_DEFAULT: dict[str, Any] = {}
|
| 40 |
-
NUMEXPR_DEFAULT: dict[str, Any] = {}
|
| 41 |
-
|
| 42 |
-
# These are the namespaces the lambda functions will use.
|
| 43 |
-
# These are separate from the names above because they are modified
|
| 44 |
-
# throughout this file, whereas the defaults should remain unmodified.
|
| 45 |
-
|
| 46 |
-
MATH = MATH_DEFAULT.copy()
|
| 47 |
-
CMATH = CMATH_DEFAULT.copy()
|
| 48 |
-
MPMATH = MPMATH_DEFAULT.copy()
|
| 49 |
-
NUMPY = NUMPY_DEFAULT.copy()
|
| 50 |
-
SCIPY = SCIPY_DEFAULT.copy()
|
| 51 |
-
CUPY = CUPY_DEFAULT.copy()
|
| 52 |
-
JAX = JAX_DEFAULT.copy()
|
| 53 |
-
TENSORFLOW = TENSORFLOW_DEFAULT.copy()
|
| 54 |
-
TORCH = TORCH_DEFAULT.copy()
|
| 55 |
-
SYMPY = SYMPY_DEFAULT.copy()
|
| 56 |
-
NUMEXPR = NUMEXPR_DEFAULT.copy()
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# Mappings between SymPy and other modules function names.
|
| 60 |
-
MATH_TRANSLATIONS = {
|
| 61 |
-
"ceiling": "ceil",
|
| 62 |
-
"E": "e",
|
| 63 |
-
"ln": "log",
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
-
CMATH_TRANSLATIONS: dict[str, str] = {}
|
| 67 |
-
|
| 68 |
-
# NOTE: This dictionary is reused in Function._eval_evalf to allow subclasses
|
| 69 |
-
# of Function to automatically evalf.
|
| 70 |
-
MPMATH_TRANSLATIONS = {
|
| 71 |
-
"Abs": "fabs",
|
| 72 |
-
"elliptic_k": "ellipk",
|
| 73 |
-
"elliptic_f": "ellipf",
|
| 74 |
-
"elliptic_e": "ellipe",
|
| 75 |
-
"elliptic_pi": "ellippi",
|
| 76 |
-
"ceiling": "ceil",
|
| 77 |
-
"chebyshevt": "chebyt",
|
| 78 |
-
"chebyshevu": "chebyu",
|
| 79 |
-
"assoc_legendre": "legenp",
|
| 80 |
-
"E": "e",
|
| 81 |
-
"I": "j",
|
| 82 |
-
"ln": "log",
|
| 83 |
-
#"lowergamma":"lower_gamma",
|
| 84 |
-
"oo": "inf",
|
| 85 |
-
#"uppergamma":"upper_gamma",
|
| 86 |
-
"LambertW": "lambertw",
|
| 87 |
-
"MutableDenseMatrix": "matrix",
|
| 88 |
-
"ImmutableDenseMatrix": "matrix",
|
| 89 |
-
"conjugate": "conj",
|
| 90 |
-
"dirichlet_eta": "altzeta",
|
| 91 |
-
"Ei": "ei",
|
| 92 |
-
"Shi": "shi",
|
| 93 |
-
"Chi": "chi",
|
| 94 |
-
"Si": "si",
|
| 95 |
-
"Ci": "ci",
|
| 96 |
-
"RisingFactorial": "rf",
|
| 97 |
-
"FallingFactorial": "ff",
|
| 98 |
-
"betainc_regularized": "betainc",
|
| 99 |
-
}
|
| 100 |
-
|
| 101 |
-
NUMPY_TRANSLATIONS: dict[str, str] = {
|
| 102 |
-
"Heaviside": "heaviside",
|
| 103 |
-
}
|
| 104 |
-
SCIPY_TRANSLATIONS: dict[str, str] = {
|
| 105 |
-
"jn" : "spherical_jn",
|
| 106 |
-
"yn" : "spherical_yn"
|
| 107 |
-
}
|
| 108 |
-
CUPY_TRANSLATIONS: dict[str, str] = {}
|
| 109 |
-
JAX_TRANSLATIONS: dict[str, str] = {}
|
| 110 |
-
|
| 111 |
-
TENSORFLOW_TRANSLATIONS: dict[str, str] = {}
|
| 112 |
-
TORCH_TRANSLATIONS: dict[str, str] = {}
|
| 113 |
-
|
| 114 |
-
NUMEXPR_TRANSLATIONS: dict[str, str] = {}
|
| 115 |
-
|
| 116 |
-
# Available modules:
|
| 117 |
-
MODULES = {
|
| 118 |
-
"math": (MATH, MATH_DEFAULT, MATH_TRANSLATIONS, ("from math import *",)),
|
| 119 |
-
"cmath": (CMATH, CMATH_DEFAULT, CMATH_TRANSLATIONS, ("import cmath; from cmath import *",)),
|
| 120 |
-
"mpmath": (MPMATH, MPMATH_DEFAULT, MPMATH_TRANSLATIONS, ("from mpmath import *",)),
|
| 121 |
-
"numpy": (NUMPY, NUMPY_DEFAULT, NUMPY_TRANSLATIONS, ("import numpy; from numpy import *; from numpy.linalg import *",)),
|
| 122 |
-
"scipy": (SCIPY, SCIPY_DEFAULT, SCIPY_TRANSLATIONS, ("import scipy; import numpy; from scipy.special import *",)),
|
| 123 |
-
"cupy": (CUPY, CUPY_DEFAULT, CUPY_TRANSLATIONS, ("import cupy",)),
|
| 124 |
-
"jax": (JAX, JAX_DEFAULT, JAX_TRANSLATIONS, ("import jax",)),
|
| 125 |
-
"tensorflow": (TENSORFLOW, TENSORFLOW_DEFAULT, TENSORFLOW_TRANSLATIONS, ("import tensorflow",)),
|
| 126 |
-
"torch": (TORCH, TORCH_DEFAULT, TORCH_TRANSLATIONS, ("import torch",)),
|
| 127 |
-
"sympy": (SYMPY, SYMPY_DEFAULT, {}, (
|
| 128 |
-
"from sympy.functions import *",
|
| 129 |
-
"from sympy.matrices import *",
|
| 130 |
-
"from sympy import Integral, pi, oo, nan, zoo, E, I",)),
|
| 131 |
-
"numexpr" : (NUMEXPR, NUMEXPR_DEFAULT, NUMEXPR_TRANSLATIONS,
|
| 132 |
-
("import_module('numexpr')", )),
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def _import(module, reload=False):
|
| 137 |
-
"""
|
| 138 |
-
Creates a global translation dictionary for module.
|
| 139 |
-
|
| 140 |
-
The argument module has to be one of the following strings: "math","cmath"
|
| 141 |
-
"mpmath", "numpy", "sympy", "tensorflow", "jax".
|
| 142 |
-
These dictionaries map names of Python functions to their equivalent in
|
| 143 |
-
other modules.
|
| 144 |
-
"""
|
| 145 |
-
try:
|
| 146 |
-
namespace, namespace_default, translations, import_commands = MODULES[
|
| 147 |
-
module]
|
| 148 |
-
except KeyError:
|
| 149 |
-
raise NameError(
|
| 150 |
-
"'%s' module cannot be used for lambdification" % module)
|
| 151 |
-
|
| 152 |
-
# Clear namespace or exit
|
| 153 |
-
if namespace != namespace_default:
|
| 154 |
-
# The namespace was already generated, don't do it again if not forced.
|
| 155 |
-
if reload:
|
| 156 |
-
namespace.clear()
|
| 157 |
-
namespace.update(namespace_default)
|
| 158 |
-
else:
|
| 159 |
-
return
|
| 160 |
-
|
| 161 |
-
for import_command in import_commands:
|
| 162 |
-
if import_command.startswith('import_module'):
|
| 163 |
-
module = eval(import_command)
|
| 164 |
-
|
| 165 |
-
if module is not None:
|
| 166 |
-
namespace.update(module.__dict__)
|
| 167 |
-
continue
|
| 168 |
-
else:
|
| 169 |
-
try:
|
| 170 |
-
exec(import_command, {}, namespace)
|
| 171 |
-
continue
|
| 172 |
-
except ImportError:
|
| 173 |
-
pass
|
| 174 |
-
|
| 175 |
-
raise ImportError(
|
| 176 |
-
"Cannot import '%s' with '%s' command" % (module, import_command))
|
| 177 |
-
|
| 178 |
-
# Add translated names to namespace
|
| 179 |
-
for sympyname, translation in translations.items():
|
| 180 |
-
namespace[sympyname] = namespace[translation]
|
| 181 |
-
|
| 182 |
-
# For computing the modulus of a SymPy expression we use the builtin abs
|
| 183 |
-
# function, instead of the previously used fabs function for all
|
| 184 |
-
# translation modules. This is because the fabs function in the math
|
| 185 |
-
# module does not accept complex valued arguments. (see issue 9474). The
|
| 186 |
-
# only exception, where we don't use the builtin abs function is the
|
| 187 |
-
# mpmath translation module, because mpmath.fabs returns mpf objects in
|
| 188 |
-
# contrast to abs().
|
| 189 |
-
if 'Abs' not in namespace:
|
| 190 |
-
namespace['Abs'] = abs
|
| 191 |
-
|
| 192 |
-
# Used for dynamically generated filenames that are inserted into the
|
| 193 |
-
# linecache.
|
| 194 |
-
_lambdify_generated_counter = 1
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
@doctest_depends_on(modules=('numpy', 'scipy', 'tensorflow',), python_version=(3,))
|
| 198 |
-
def lambdify(args, expr, modules=None, printer=None, use_imps=True,
|
| 199 |
-
dummify=False, cse=False, docstring_limit=1000):
|
| 200 |
-
"""Convert a SymPy expression into a function that allows for fast
|
| 201 |
-
numeric evaluation.
|
| 202 |
-
|
| 203 |
-
.. warning::
|
| 204 |
-
This function uses ``exec``, and thus should not be used on
|
| 205 |
-
unsanitized input.
|
| 206 |
-
|
| 207 |
-
.. deprecated:: 1.7
|
| 208 |
-
Passing a set for the *args* parameter is deprecated as sets are
|
| 209 |
-
unordered. Use an ordered iterable such as a list or tuple.
|
| 210 |
-
|
| 211 |
-
Explanation
|
| 212 |
-
===========
|
| 213 |
-
|
| 214 |
-
For example, to convert the SymPy expression ``sin(x) + cos(x)`` to an
|
| 215 |
-
equivalent NumPy function that numerically evaluates it:
|
| 216 |
-
|
| 217 |
-
>>> from sympy import sin, cos, symbols, lambdify
|
| 218 |
-
>>> import numpy as np
|
| 219 |
-
>>> x = symbols('x')
|
| 220 |
-
>>> expr = sin(x) + cos(x)
|
| 221 |
-
>>> expr
|
| 222 |
-
sin(x) + cos(x)
|
| 223 |
-
>>> f = lambdify(x, expr, 'numpy')
|
| 224 |
-
>>> a = np.array([1, 2])
|
| 225 |
-
>>> f(a)
|
| 226 |
-
[1.38177329 0.49315059]
|
| 227 |
-
|
| 228 |
-
The primary purpose of this function is to provide a bridge from SymPy
|
| 229 |
-
expressions to numerical libraries such as NumPy, SciPy, NumExpr, mpmath,
|
| 230 |
-
and tensorflow. In general, SymPy functions do not work with objects from
|
| 231 |
-
other libraries, such as NumPy arrays, and functions from numeric
|
| 232 |
-
libraries like NumPy or mpmath do not work on SymPy expressions.
|
| 233 |
-
``lambdify`` bridges the two by converting a SymPy expression to an
|
| 234 |
-
equivalent numeric function.
|
| 235 |
-
|
| 236 |
-
The basic workflow with ``lambdify`` is to first create a SymPy expression
|
| 237 |
-
representing whatever mathematical function you wish to evaluate. This
|
| 238 |
-
should be done using only SymPy functions and expressions. Then, use
|
| 239 |
-
``lambdify`` to convert this to an equivalent function for numerical
|
| 240 |
-
evaluation. For instance, above we created ``expr`` using the SymPy symbol
|
| 241 |
-
``x`` and SymPy functions ``sin`` and ``cos``, then converted it to an
|
| 242 |
-
equivalent NumPy function ``f``, and called it on a NumPy array ``a``.
|
| 243 |
-
|
| 244 |
-
Parameters
|
| 245 |
-
==========
|
| 246 |
-
|
| 247 |
-
args : List[Symbol]
|
| 248 |
-
A variable or a list of variables whose nesting represents the
|
| 249 |
-
nesting of the arguments that will be passed to the function.
|
| 250 |
-
|
| 251 |
-
Variables can be symbols, undefined functions, or matrix symbols.
|
| 252 |
-
|
| 253 |
-
>>> from sympy import Eq
|
| 254 |
-
>>> from sympy.abc import x, y, z
|
| 255 |
-
|
| 256 |
-
The list of variables should match the structure of how the
|
| 257 |
-
arguments will be passed to the function. Simply enclose the
|
| 258 |
-
parameters as they will be passed in a list.
|
| 259 |
-
|
| 260 |
-
To call a function like ``f(x)`` then ``[x]``
|
| 261 |
-
should be the first argument to ``lambdify``; for this
|
| 262 |
-
case a single ``x`` can also be used:
|
| 263 |
-
|
| 264 |
-
>>> f = lambdify(x, x + 1)
|
| 265 |
-
>>> f(1)
|
| 266 |
-
2
|
| 267 |
-
>>> f = lambdify([x], x + 1)
|
| 268 |
-
>>> f(1)
|
| 269 |
-
2
|
| 270 |
-
|
| 271 |
-
To call a function like ``f(x, y)`` then ``[x, y]`` will
|
| 272 |
-
be the first argument of the ``lambdify``:
|
| 273 |
-
|
| 274 |
-
>>> f = lambdify([x, y], x + y)
|
| 275 |
-
>>> f(1, 1)
|
| 276 |
-
2
|
| 277 |
-
|
| 278 |
-
To call a function with a single 3-element tuple like
|
| 279 |
-
``f((x, y, z))`` then ``[(x, y, z)]`` will be the first
|
| 280 |
-
argument of the ``lambdify``:
|
| 281 |
-
|
| 282 |
-
>>> f = lambdify([(x, y, z)], Eq(z**2, x**2 + y**2))
|
| 283 |
-
>>> f((3, 4, 5))
|
| 284 |
-
True
|
| 285 |
-
|
| 286 |
-
If two args will be passed and the first is a scalar but
|
| 287 |
-
the second is a tuple with two arguments then the items
|
| 288 |
-
in the list should match that structure:
|
| 289 |
-
|
| 290 |
-
>>> f = lambdify([x, (y, z)], x + y + z)
|
| 291 |
-
>>> f(1, (2, 3))
|
| 292 |
-
6
|
| 293 |
-
|
| 294 |
-
expr : Expr
|
| 295 |
-
An expression, list of expressions, or matrix to be evaluated.
|
| 296 |
-
|
| 297 |
-
Lists may be nested.
|
| 298 |
-
If the expression is a list, the output will also be a list.
|
| 299 |
-
|
| 300 |
-
>>> f = lambdify(x, [x, [x + 1, x + 2]])
|
| 301 |
-
>>> f(1)
|
| 302 |
-
[1, [2, 3]]
|
| 303 |
-
|
| 304 |
-
If it is a matrix, an array will be returned (for the NumPy module).
|
| 305 |
-
|
| 306 |
-
>>> from sympy import Matrix
|
| 307 |
-
>>> f = lambdify(x, Matrix([x, x + 1]))
|
| 308 |
-
>>> f(1)
|
| 309 |
-
[[1]
|
| 310 |
-
[2]]
|
| 311 |
-
|
| 312 |
-
Note that the argument order here (variables then expression) is used
|
| 313 |
-
to emulate the Python ``lambda`` keyword. ``lambdify(x, expr)`` works
|
| 314 |
-
(roughly) like ``lambda x: expr``
|
| 315 |
-
(see :ref:`lambdify-how-it-works` below).
|
| 316 |
-
|
| 317 |
-
modules : str, optional
|
| 318 |
-
Specifies the numeric library to use.
|
| 319 |
-
|
| 320 |
-
If not specified, *modules* defaults to:
|
| 321 |
-
|
| 322 |
-
- ``["scipy", "numpy"]`` if SciPy is installed
|
| 323 |
-
- ``["numpy"]`` if only NumPy is installed
|
| 324 |
-
- ``["math","cmath", "mpmath", "sympy"]`` if neither is installed.
|
| 325 |
-
|
| 326 |
-
That is, SymPy functions are replaced as far as possible by
|
| 327 |
-
either ``scipy`` or ``numpy`` functions if available, and Python's
|
| 328 |
-
standard library ``math`` and ``cmath``, or ``mpmath`` functions otherwise.
|
| 329 |
-
|
| 330 |
-
*modules* can be one of the following types:
|
| 331 |
-
|
| 332 |
-
- The strings ``"math"``, ``"cmath"``, ``"mpmath"``, ``"numpy"``, ``"numexpr"``,
|
| 333 |
-
``"scipy"``, ``"sympy"``, or ``"tensorflow"`` or ``"jax"``. This uses the
|
| 334 |
-
corresponding printer and namespace mapping for that module.
|
| 335 |
-
- A module (e.g., ``math``). This uses the global namespace of the
|
| 336 |
-
module. If the module is one of the above known modules, it will
|
| 337 |
-
also use the corresponding printer and namespace mapping
|
| 338 |
-
(i.e., ``modules=numpy`` is equivalent to ``modules="numpy"``).
|
| 339 |
-
- A dictionary that maps names of SymPy functions to arbitrary
|
| 340 |
-
functions
|
| 341 |
-
(e.g., ``{'sin': custom_sin}``).
|
| 342 |
-
- A list that contains a mix of the arguments above, with higher
|
| 343 |
-
priority given to entries appearing first
|
| 344 |
-
(e.g., to use the NumPy module but override the ``sin`` function
|
| 345 |
-
with a custom version, you can use
|
| 346 |
-
``[{'sin': custom_sin}, 'numpy']``).
|
| 347 |
-
|
| 348 |
-
dummify : bool, optional
|
| 349 |
-
Whether or not the variables in the provided expression that are not
|
| 350 |
-
valid Python identifiers are substituted with dummy symbols.
|
| 351 |
-
|
| 352 |
-
This allows for undefined functions like ``Function('f')(t)`` to be
|
| 353 |
-
supplied as arguments. By default, the variables are only dummified
|
| 354 |
-
if they are not valid Python identifiers.
|
| 355 |
-
|
| 356 |
-
Set ``dummify=True`` to replace all arguments with dummy symbols
|
| 357 |
-
(if ``args`` is not a string) - for example, to ensure that the
|
| 358 |
-
arguments do not redefine any built-in names.
|
| 359 |
-
|
| 360 |
-
cse : bool, or callable, optional
|
| 361 |
-
Large expressions can be computed more efficiently when
|
| 362 |
-
common subexpressions are identified and precomputed before
|
| 363 |
-
being used multiple time. Finding the subexpressions will make
|
| 364 |
-
creation of the 'lambdify' function slower, however.
|
| 365 |
-
|
| 366 |
-
When ``True``, ``sympy.simplify.cse`` is used, otherwise (the default)
|
| 367 |
-
the user may pass a function matching the ``cse`` signature.
|
| 368 |
-
|
| 369 |
-
docstring_limit : int or None
|
| 370 |
-
When lambdifying large expressions, a significant proportion of the time
|
| 371 |
-
spent inside ``lambdify`` is spent producing a string representation of
|
| 372 |
-
the expression for use in the automatically generated docstring of the
|
| 373 |
-
returned function. For expressions containing hundreds or more nodes the
|
| 374 |
-
resulting docstring often becomes so long and dense that it is difficult
|
| 375 |
-
to read. To reduce the runtime of lambdify, the rendering of the full
|
| 376 |
-
expression inside the docstring can be disabled.
|
| 377 |
-
|
| 378 |
-
When ``None``, the full expression is rendered in the docstring. When
|
| 379 |
-
``0`` or a negative ``int``, an ellipsis is rendering in the docstring
|
| 380 |
-
instead of the expression. When a strictly positive ``int``, if the
|
| 381 |
-
number of nodes in the expression exceeds ``docstring_limit`` an
|
| 382 |
-
ellipsis is rendered in the docstring, otherwise a string representation
|
| 383 |
-
of the expression is rendered as normal. The default is ``1000``.
|
| 384 |
-
|
| 385 |
-
Examples
|
| 386 |
-
========
|
| 387 |
-
|
| 388 |
-
>>> from sympy.utilities.lambdify import implemented_function
|
| 389 |
-
>>> from sympy import sqrt, sin, Matrix
|
| 390 |
-
>>> from sympy import Function
|
| 391 |
-
>>> from sympy.abc import w, x, y, z
|
| 392 |
-
|
| 393 |
-
>>> f = lambdify(x, x**2)
|
| 394 |
-
>>> f(2)
|
| 395 |
-
4
|
| 396 |
-
>>> f = lambdify((x, y, z), [z, y, x])
|
| 397 |
-
>>> f(1,2,3)
|
| 398 |
-
[3, 2, 1]
|
| 399 |
-
>>> f = lambdify(x, sqrt(x))
|
| 400 |
-
>>> f(4)
|
| 401 |
-
2.0
|
| 402 |
-
>>> f = lambdify((x, y), sin(x*y)**2)
|
| 403 |
-
>>> f(0, 5)
|
| 404 |
-
0.0
|
| 405 |
-
>>> row = lambdify((x, y), Matrix((x, x + y)).T, modules='sympy')
|
| 406 |
-
>>> row(1, 2)
|
| 407 |
-
Matrix([[1, 3]])
|
| 408 |
-
|
| 409 |
-
``lambdify`` can be used to translate SymPy expressions into mpmath
|
| 410 |
-
functions. This may be preferable to using ``evalf`` (which uses mpmath on
|
| 411 |
-
the backend) in some cases.
|
| 412 |
-
|
| 413 |
-
>>> f = lambdify(x, sin(x), 'mpmath')
|
| 414 |
-
>>> f(1)
|
| 415 |
-
0.8414709848078965
|
| 416 |
-
|
| 417 |
-
Tuple arguments are handled and the lambdified function should
|
| 418 |
-
be called with the same type of arguments as were used to create
|
| 419 |
-
the function:
|
| 420 |
-
|
| 421 |
-
>>> f = lambdify((x, (y, z)), x + y)
|
| 422 |
-
>>> f(1, (2, 4))
|
| 423 |
-
3
|
| 424 |
-
|
| 425 |
-
The ``flatten`` function can be used to always work with flattened
|
| 426 |
-
arguments:
|
| 427 |
-
|
| 428 |
-
>>> from sympy.utilities.iterables import flatten
|
| 429 |
-
>>> args = w, (x, (y, z))
|
| 430 |
-
>>> vals = 1, (2, (3, 4))
|
| 431 |
-
>>> f = lambdify(flatten(args), w + x + y + z)
|
| 432 |
-
>>> f(*flatten(vals))
|
| 433 |
-
10
|
| 434 |
-
|
| 435 |
-
Functions present in ``expr`` can also carry their own numerical
|
| 436 |
-
implementations, in a callable attached to the ``_imp_`` attribute. This
|
| 437 |
-
can be used with undefined functions using the ``implemented_function``
|
| 438 |
-
factory:
|
| 439 |
-
|
| 440 |
-
>>> f = implemented_function(Function('f'), lambda x: x+1)
|
| 441 |
-
>>> func = lambdify(x, f(x))
|
| 442 |
-
>>> func(4)
|
| 443 |
-
5
|
| 444 |
-
|
| 445 |
-
``lambdify`` always prefers ``_imp_`` implementations to implementations
|
| 446 |
-
in other namespaces, unless the ``use_imps`` input parameter is False.
|
| 447 |
-
|
| 448 |
-
Usage with Tensorflow:
|
| 449 |
-
|
| 450 |
-
>>> import tensorflow as tf
|
| 451 |
-
>>> from sympy import Max, sin, lambdify
|
| 452 |
-
>>> from sympy.abc import x
|
| 453 |
-
|
| 454 |
-
>>> f = Max(x, sin(x))
|
| 455 |
-
>>> func = lambdify(x, f, 'tensorflow')
|
| 456 |
-
|
| 457 |
-
After tensorflow v2, eager execution is enabled by default.
|
| 458 |
-
If you want to get the compatible result across tensorflow v1 and v2
|
| 459 |
-
as same as this tutorial, run this line.
|
| 460 |
-
|
| 461 |
-
>>> tf.compat.v1.enable_eager_execution()
|
| 462 |
-
|
| 463 |
-
If you have eager execution enabled, you can get the result out
|
| 464 |
-
immediately as you can use numpy.
|
| 465 |
-
|
| 466 |
-
If you pass tensorflow objects, you may get an ``EagerTensor``
|
| 467 |
-
object instead of value.
|
| 468 |
-
|
| 469 |
-
>>> result = func(tf.constant(1.0))
|
| 470 |
-
>>> print(result)
|
| 471 |
-
tf.Tensor(1.0, shape=(), dtype=float32)
|
| 472 |
-
>>> print(result.__class__)
|
| 473 |
-
<class 'tensorflow.python.framework.ops.EagerTensor'>
|
| 474 |
-
|
| 475 |
-
You can use ``.numpy()`` to get the numpy value of the tensor.
|
| 476 |
-
|
| 477 |
-
>>> result.numpy()
|
| 478 |
-
1.0
|
| 479 |
-
|
| 480 |
-
>>> var = tf.Variable(2.0)
|
| 481 |
-
>>> result = func(var) # also works for tf.Variable and tf.Placeholder
|
| 482 |
-
>>> result.numpy()
|
| 483 |
-
2.0
|
| 484 |
-
|
| 485 |
-
And it works with any shape array.
|
| 486 |
-
|
| 487 |
-
>>> tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])
|
| 488 |
-
>>> result = func(tensor)
|
| 489 |
-
>>> result.numpy()
|
| 490 |
-
[[1. 2.]
|
| 491 |
-
[3. 4.]]
|
| 492 |
-
|
| 493 |
-
Notes
|
| 494 |
-
=====
|
| 495 |
-
|
| 496 |
-
- For functions involving large array calculations, numexpr can provide a
|
| 497 |
-
significant speedup over numpy. Please note that the available functions
|
| 498 |
-
for numexpr are more limited than numpy but can be expanded with
|
| 499 |
-
``implemented_function`` and user defined subclasses of Function. If
|
| 500 |
-
specified, numexpr may be the only option in modules. The official list
|
| 501 |
-
of numexpr functions can be found at:
|
| 502 |
-
https://numexpr.readthedocs.io/en/latest/user_guide.html#supported-functions
|
| 503 |
-
|
| 504 |
-
- In the above examples, the generated functions can accept scalar
|
| 505 |
-
values or numpy arrays as arguments. However, in some cases
|
| 506 |
-
the generated function relies on the input being a numpy array:
|
| 507 |
-
|
| 508 |
-
>>> import numpy
|
| 509 |
-
>>> from sympy import Piecewise
|
| 510 |
-
>>> from sympy.testing.pytest import ignore_warnings
|
| 511 |
-
>>> f = lambdify(x, Piecewise((x, x <= 1), (1/x, x > 1)), "numpy")
|
| 512 |
-
|
| 513 |
-
>>> with ignore_warnings(RuntimeWarning):
|
| 514 |
-
... f(numpy.array([-1, 0, 1, 2]))
|
| 515 |
-
[-1. 0. 1. 0.5]
|
| 516 |
-
|
| 517 |
-
>>> f(0)
|
| 518 |
-
Traceback (most recent call last):
|
| 519 |
-
...
|
| 520 |
-
ZeroDivisionError: division by zero
|
| 521 |
-
|
| 522 |
-
In such cases, the input should be wrapped in a numpy array:
|
| 523 |
-
|
| 524 |
-
>>> with ignore_warnings(RuntimeWarning):
|
| 525 |
-
... float(f(numpy.array([0])))
|
| 526 |
-
0.0
|
| 527 |
-
|
| 528 |
-
Or if numpy functionality is not required another module can be used:
|
| 529 |
-
|
| 530 |
-
>>> f = lambdify(x, Piecewise((x, x <= 1), (1/x, x > 1)), "math")
|
| 531 |
-
>>> f(0)
|
| 532 |
-
0
|
| 533 |
-
|
| 534 |
-
.. _lambdify-how-it-works:
|
| 535 |
-
|
| 536 |
-
How it works
|
| 537 |
-
============
|
| 538 |
-
|
| 539 |
-
When using this function, it helps a great deal to have an idea of what it
|
| 540 |
-
is doing. At its core, lambdify is nothing more than a namespace
|
| 541 |
-
translation, on top of a special printer that makes some corner cases work
|
| 542 |
-
properly.
|
| 543 |
-
|
| 544 |
-
To understand lambdify, first we must properly understand how Python
|
| 545 |
-
namespaces work. Say we had two files. One called ``sin_cos_sympy.py``,
|
| 546 |
-
with
|
| 547 |
-
|
| 548 |
-
.. code:: python
|
| 549 |
-
|
| 550 |
-
# sin_cos_sympy.py
|
| 551 |
-
|
| 552 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 553 |
-
|
| 554 |
-
def sin_cos(x):
|
| 555 |
-
return sin(x) + cos(x)
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
and one called ``sin_cos_numpy.py`` with
|
| 559 |
-
|
| 560 |
-
.. code:: python
|
| 561 |
-
|
| 562 |
-
# sin_cos_numpy.py
|
| 563 |
-
|
| 564 |
-
from numpy import sin, cos
|
| 565 |
-
|
| 566 |
-
def sin_cos(x):
|
| 567 |
-
return sin(x) + cos(x)
|
| 568 |
-
|
| 569 |
-
The two files define an identical function ``sin_cos``. However, in the
|
| 570 |
-
first file, ``sin`` and ``cos`` are defined as the SymPy ``sin`` and
|
| 571 |
-
``cos``. In the second, they are defined as the NumPy versions.
|
| 572 |
-
|
| 573 |
-
If we were to import the first file and use the ``sin_cos`` function, we
|
| 574 |
-
would get something like
|
| 575 |
-
|
| 576 |
-
>>> from sin_cos_sympy import sin_cos # doctest: +SKIP
|
| 577 |
-
>>> sin_cos(1) # doctest: +SKIP
|
| 578 |
-
cos(1) + sin(1)
|
| 579 |
-
|
| 580 |
-
On the other hand, if we imported ``sin_cos`` from the second file, we
|
| 581 |
-
would get
|
| 582 |
-
|
| 583 |
-
>>> from sin_cos_numpy import sin_cos # doctest: +SKIP
|
| 584 |
-
>>> sin_cos(1) # doctest: +SKIP
|
| 585 |
-
1.38177329068
|
| 586 |
-
|
| 587 |
-
In the first case we got a symbolic output, because it used the symbolic
|
| 588 |
-
``sin`` and ``cos`` functions from SymPy. In the second, we got a numeric
|
| 589 |
-
result, because ``sin_cos`` used the numeric ``sin`` and ``cos`` functions
|
| 590 |
-
from NumPy. But notice that the versions of ``sin`` and ``cos`` that were
|
| 591 |
-
used was not inherent to the ``sin_cos`` function definition. Both
|
| 592 |
-
``sin_cos`` definitions are exactly the same. Rather, it was based on the
|
| 593 |
-
names defined at the module where the ``sin_cos`` function was defined.
|
| 594 |
-
|
| 595 |
-
The key point here is that when function in Python references a name that
|
| 596 |
-
is not defined in the function, that name is looked up in the "global"
|
| 597 |
-
namespace of the module where that function is defined.
|
| 598 |
-
|
| 599 |
-
Now, in Python, we can emulate this behavior without actually writing a
|
| 600 |
-
file to disk using the ``exec`` function. ``exec`` takes a string
|
| 601 |
-
containing a block of Python code, and a dictionary that should contain
|
| 602 |
-
the global variables of the module. It then executes the code "in" that
|
| 603 |
-
dictionary, as if it were the module globals. The following is equivalent
|
| 604 |
-
to the ``sin_cos`` defined in ``sin_cos_sympy.py``:
|
| 605 |
-
|
| 606 |
-
>>> import sympy
|
| 607 |
-
>>> module_dictionary = {'sin': sympy.sin, 'cos': sympy.cos}
|
| 608 |
-
>>> exec('''
|
| 609 |
-
... def sin_cos(x):
|
| 610 |
-
... return sin(x) + cos(x)
|
| 611 |
-
... ''', module_dictionary)
|
| 612 |
-
>>> sin_cos = module_dictionary['sin_cos']
|
| 613 |
-
>>> sin_cos(1)
|
| 614 |
-
cos(1) + sin(1)
|
| 615 |
-
|
| 616 |
-
and similarly with ``sin_cos_numpy``:
|
| 617 |
-
|
| 618 |
-
>>> import numpy
|
| 619 |
-
>>> module_dictionary = {'sin': numpy.sin, 'cos': numpy.cos}
|
| 620 |
-
>>> exec('''
|
| 621 |
-
... def sin_cos(x):
|
| 622 |
-
... return sin(x) + cos(x)
|
| 623 |
-
... ''', module_dictionary)
|
| 624 |
-
>>> sin_cos = module_dictionary['sin_cos']
|
| 625 |
-
>>> sin_cos(1)
|
| 626 |
-
1.38177329068
|
| 627 |
-
|
| 628 |
-
So now we can get an idea of how ``lambdify`` works. The name "lambdify"
|
| 629 |
-
comes from the fact that we can think of something like ``lambdify(x,
|
| 630 |
-
sin(x) + cos(x), 'numpy')`` as ``lambda x: sin(x) + cos(x)``, where
|
| 631 |
-
``sin`` and ``cos`` come from the ``numpy`` namespace. This is also why
|
| 632 |
-
the symbols argument is first in ``lambdify``, as opposed to most SymPy
|
| 633 |
-
functions where it comes after the expression: to better mimic the
|
| 634 |
-
``lambda`` keyword.
|
| 635 |
-
|
| 636 |
-
``lambdify`` takes the input expression (like ``sin(x) + cos(x)``) and
|
| 637 |
-
|
| 638 |
-
1. Converts it to a string
|
| 639 |
-
2. Creates a module globals dictionary based on the modules that are
|
| 640 |
-
passed in (by default, it uses the NumPy module)
|
| 641 |
-
3. Creates the string ``"def func({vars}): return {expr}"``, where ``{vars}`` is the
|
| 642 |
-
list of variables separated by commas, and ``{expr}`` is the string
|
| 643 |
-
created in step 1., then ``exec``s that string with the module globals
|
| 644 |
-
namespace and returns ``func``.
|
| 645 |
-
|
| 646 |
-
In fact, functions returned by ``lambdify`` support inspection. So you can
|
| 647 |
-
see exactly how they are defined by using ``inspect.getsource``, or ``??`` if you
|
| 648 |
-
are using IPython or the Jupyter notebook.
|
| 649 |
-
|
| 650 |
-
>>> f = lambdify(x, sin(x) + cos(x))
|
| 651 |
-
>>> import inspect
|
| 652 |
-
>>> print(inspect.getsource(f))
|
| 653 |
-
def _lambdifygenerated(x):
|
| 654 |
-
return sin(x) + cos(x)
|
| 655 |
-
|
| 656 |
-
This shows us the source code of the function, but not the namespace it
|
| 657 |
-
was defined in. We can inspect that by looking at the ``__globals__``
|
| 658 |
-
attribute of ``f``:
|
| 659 |
-
|
| 660 |
-
>>> f.__globals__['sin']
|
| 661 |
-
<ufunc 'sin'>
|
| 662 |
-
>>> f.__globals__['cos']
|
| 663 |
-
<ufunc 'cos'>
|
| 664 |
-
>>> f.__globals__['sin'] is numpy.sin
|
| 665 |
-
True
|
| 666 |
-
|
| 667 |
-
This shows us that ``sin`` and ``cos`` in the namespace of ``f`` will be
|
| 668 |
-
``numpy.sin`` and ``numpy.cos``.
|
| 669 |
-
|
| 670 |
-
Note that there are some convenience layers in each of these steps, but at
|
| 671 |
-
the core, this is how ``lambdify`` works. Step 1 is done using the
|
| 672 |
-
``LambdaPrinter`` printers defined in the printing module (see
|
| 673 |
-
:mod:`sympy.printing.lambdarepr`). This allows different SymPy expressions
|
| 674 |
-
to define how they should be converted to a string for different modules.
|
| 675 |
-
You can change which printer ``lambdify`` uses by passing a custom printer
|
| 676 |
-
in to the ``printer`` argument.
|
| 677 |
-
|
| 678 |
-
Step 2 is augmented by certain translations. There are default
|
| 679 |
-
translations for each module, but you can provide your own by passing a
|
| 680 |
-
list to the ``modules`` argument. For instance,
|
| 681 |
-
|
| 682 |
-
>>> def mysin(x):
|
| 683 |
-
... print('taking the sin of', x)
|
| 684 |
-
... return numpy.sin(x)
|
| 685 |
-
...
|
| 686 |
-
>>> f = lambdify(x, sin(x), [{'sin': mysin}, 'numpy'])
|
| 687 |
-
>>> f(1)
|
| 688 |
-
taking the sin of 1
|
| 689 |
-
0.8414709848078965
|
| 690 |
-
|
| 691 |
-
The globals dictionary is generated from the list by merging the
|
| 692 |
-
dictionary ``{'sin': mysin}`` and the module dictionary for NumPy. The
|
| 693 |
-
merging is done so that earlier items take precedence, which is why
|
| 694 |
-
``mysin`` is used above instead of ``numpy.sin``.
|
| 695 |
-
|
| 696 |
-
If you want to modify the way ``lambdify`` works for a given function, it
|
| 697 |
-
is usually easiest to do so by modifying the globals dictionary as such.
|
| 698 |
-
In more complicated cases, it may be necessary to create and pass in a
|
| 699 |
-
custom printer.
|
| 700 |
-
|
| 701 |
-
Finally, step 3 is augmented with certain convenience operations, such as
|
| 702 |
-
the addition of a docstring.
|
| 703 |
-
|
| 704 |
-
Understanding how ``lambdify`` works can make it easier to avoid certain
|
| 705 |
-
gotchas when using it. For instance, a common mistake is to create a
|
| 706 |
-
lambdified function for one module (say, NumPy), and pass it objects from
|
| 707 |
-
another (say, a SymPy expression).
|
| 708 |
-
|
| 709 |
-
For instance, say we create
|
| 710 |
-
|
| 711 |
-
>>> from sympy.abc import x
|
| 712 |
-
>>> f = lambdify(x, x + 1, 'numpy')
|
| 713 |
-
|
| 714 |
-
Now if we pass in a NumPy array, we get that array plus 1
|
| 715 |
-
|
| 716 |
-
>>> import numpy
|
| 717 |
-
>>> a = numpy.array([1, 2])
|
| 718 |
-
>>> f(a)
|
| 719 |
-
[2 3]
|
| 720 |
-
|
| 721 |
-
But what happens if you make the mistake of passing in a SymPy expression
|
| 722 |
-
instead of a NumPy array:
|
| 723 |
-
|
| 724 |
-
>>> f(x + 1)
|
| 725 |
-
x + 2
|
| 726 |
-
|
| 727 |
-
This worked, but it was only by accident. Now take a different lambdified
|
| 728 |
-
function:
|
| 729 |
-
|
| 730 |
-
>>> from sympy import sin
|
| 731 |
-
>>> g = lambdify(x, x + sin(x), 'numpy')
|
| 732 |
-
|
| 733 |
-
This works as expected on NumPy arrays:
|
| 734 |
-
|
| 735 |
-
>>> g(a)
|
| 736 |
-
[1.84147098 2.90929743]
|
| 737 |
-
|
| 738 |
-
But if we try to pass in a SymPy expression, it fails
|
| 739 |
-
|
| 740 |
-
>>> g(x + 1)
|
| 741 |
-
Traceback (most recent call last):
|
| 742 |
-
...
|
| 743 |
-
TypeError: loop of ufunc does not support argument 0 of type Add which has
|
| 744 |
-
no callable sin method
|
| 745 |
-
|
| 746 |
-
Now, let's look at what happened. The reason this fails is that ``g``
|
| 747 |
-
calls ``numpy.sin`` on the input expression, and ``numpy.sin`` does not
|
| 748 |
-
know how to operate on a SymPy object. **As a general rule, NumPy
|
| 749 |
-
functions do not know how to operate on SymPy expressions, and SymPy
|
| 750 |
-
functions do not know how to operate on NumPy arrays. This is why lambdify
|
| 751 |
-
exists: to provide a bridge between SymPy and NumPy.**
|
| 752 |
-
|
| 753 |
-
However, why is it that ``f`` did work? That's because ``f`` does not call
|
| 754 |
-
any functions, it only adds 1. So the resulting function that is created,
|
| 755 |
-
``def _lambdifygenerated(x): return x + 1`` does not depend on the globals
|
| 756 |
-
namespace it is defined in. Thus it works, but only by accident. A future
|
| 757 |
-
version of ``lambdify`` may remove this behavior.
|
| 758 |
-
|
| 759 |
-
Be aware that certain implementation details described here may change in
|
| 760 |
-
future versions of SymPy. The API of passing in custom modules and
|
| 761 |
-
printers will not change, but the details of how a lambda function is
|
| 762 |
-
created may change. However, the basic idea will remain the same, and
|
| 763 |
-
understanding it will be helpful to understanding the behavior of
|
| 764 |
-
lambdify.
|
| 765 |
-
|
| 766 |
-
**In general: you should create lambdified functions for one module (say,
|
| 767 |
-
NumPy), and only pass it input types that are compatible with that module
|
| 768 |
-
(say, NumPy arrays).** Remember that by default, if the ``module``
|
| 769 |
-
argument is not provided, ``lambdify`` creates functions using the NumPy
|
| 770 |
-
and SciPy namespaces.
|
| 771 |
-
"""
|
| 772 |
-
from sympy.core.symbol import Symbol
|
| 773 |
-
from sympy.core.expr import Expr
|
| 774 |
-
|
| 775 |
-
# If the user hasn't specified any modules, use what is available.
|
| 776 |
-
if modules is None:
|
| 777 |
-
try:
|
| 778 |
-
_import("scipy")
|
| 779 |
-
except ImportError:
|
| 780 |
-
try:
|
| 781 |
-
_import("numpy")
|
| 782 |
-
except ImportError:
|
| 783 |
-
# Use either numpy (if available) or python.math where possible.
|
| 784 |
-
# XXX: This leads to different behaviour on different systems and
|
| 785 |
-
# might be the reason for irreproducible errors.
|
| 786 |
-
modules = ["math", "mpmath", "sympy"]
|
| 787 |
-
else:
|
| 788 |
-
modules = ["numpy"]
|
| 789 |
-
else:
|
| 790 |
-
modules = ["numpy", "scipy"]
|
| 791 |
-
|
| 792 |
-
# Get the needed namespaces.
|
| 793 |
-
namespaces = []
|
| 794 |
-
# First find any function implementations
|
| 795 |
-
if use_imps:
|
| 796 |
-
namespaces.append(_imp_namespace(expr))
|
| 797 |
-
# Check for dict before iterating
|
| 798 |
-
if isinstance(modules, (dict, str)) or not hasattr(modules, '__iter__'):
|
| 799 |
-
namespaces.append(modules)
|
| 800 |
-
else:
|
| 801 |
-
# consistency check
|
| 802 |
-
if _module_present('numexpr', modules) and len(modules) > 1:
|
| 803 |
-
raise TypeError("numexpr must be the only item in 'modules'")
|
| 804 |
-
namespaces += list(modules)
|
| 805 |
-
# fill namespace with first having highest priority
|
| 806 |
-
namespace = {}
|
| 807 |
-
for m in namespaces[::-1]:
|
| 808 |
-
buf = _get_namespace(m)
|
| 809 |
-
namespace.update(buf)
|
| 810 |
-
|
| 811 |
-
if hasattr(expr, "atoms"):
|
| 812 |
-
#Try if you can extract symbols from the expression.
|
| 813 |
-
#Move on if expr.atoms in not implemented.
|
| 814 |
-
syms = expr.atoms(Symbol)
|
| 815 |
-
for term in syms:
|
| 816 |
-
namespace.update({str(term): term})
|
| 817 |
-
|
| 818 |
-
if printer is None:
|
| 819 |
-
if _module_present('mpmath', namespaces):
|
| 820 |
-
from sympy.printing.pycode import MpmathPrinter as Printer # type: ignore
|
| 821 |
-
elif _module_present('scipy', namespaces):
|
| 822 |
-
from sympy.printing.numpy import SciPyPrinter as Printer # type: ignore
|
| 823 |
-
elif _module_present('numpy', namespaces):
|
| 824 |
-
from sympy.printing.numpy import NumPyPrinter as Printer # type: ignore
|
| 825 |
-
elif _module_present('cupy', namespaces):
|
| 826 |
-
from sympy.printing.numpy import CuPyPrinter as Printer # type: ignore
|
| 827 |
-
elif _module_present('jax', namespaces):
|
| 828 |
-
from sympy.printing.numpy import JaxPrinter as Printer # type: ignore
|
| 829 |
-
elif _module_present('numexpr', namespaces):
|
| 830 |
-
from sympy.printing.lambdarepr import NumExprPrinter as Printer # type: ignore
|
| 831 |
-
elif _module_present('tensorflow', namespaces):
|
| 832 |
-
from sympy.printing.tensorflow import TensorflowPrinter as Printer # type: ignore
|
| 833 |
-
elif _module_present('torch', namespaces):
|
| 834 |
-
from sympy.printing.pytorch import TorchPrinter as Printer # type: ignore
|
| 835 |
-
elif _module_present('sympy', namespaces):
|
| 836 |
-
from sympy.printing.pycode import SymPyPrinter as Printer # type: ignore
|
| 837 |
-
elif _module_present('cmath', namespaces):
|
| 838 |
-
from sympy.printing.pycode import CmathPrinter as Printer # type: ignore
|
| 839 |
-
else:
|
| 840 |
-
from sympy.printing.pycode import PythonCodePrinter as Printer # type: ignore
|
| 841 |
-
user_functions = {}
|
| 842 |
-
for m in namespaces[::-1]:
|
| 843 |
-
if isinstance(m, dict):
|
| 844 |
-
for k in m:
|
| 845 |
-
user_functions[k] = k
|
| 846 |
-
printer = Printer({'fully_qualified_modules': False, 'inline': True,
|
| 847 |
-
'allow_unknown_functions': True,
|
| 848 |
-
'user_functions': user_functions})
|
| 849 |
-
|
| 850 |
-
if isinstance(args, set):
|
| 851 |
-
sympy_deprecation_warning(
|
| 852 |
-
"""
|
| 853 |
-
Passing the function arguments to lambdify() as a set is deprecated. This
|
| 854 |
-
leads to unpredictable results since sets are unordered. Instead, use a list
|
| 855 |
-
or tuple for the function arguments.
|
| 856 |
-
""",
|
| 857 |
-
deprecated_since_version="1.6.3",
|
| 858 |
-
active_deprecations_target="deprecated-lambdify-arguments-set",
|
| 859 |
-
)
|
| 860 |
-
|
| 861 |
-
# Get the names of the args, for creating a docstring
|
| 862 |
-
iterable_args = (args,) if isinstance(args, Expr) else args
|
| 863 |
-
names = []
|
| 864 |
-
|
| 865 |
-
# Grab the callers frame, for getting the names by inspection (if needed)
|
| 866 |
-
callers_local_vars = inspect.currentframe().f_back.f_locals.items() # type: ignore
|
| 867 |
-
for n, var in enumerate(iterable_args):
|
| 868 |
-
if hasattr(var, 'name'):
|
| 869 |
-
names.append(var.name)
|
| 870 |
-
else:
|
| 871 |
-
# It's an iterable. Try to get name by inspection of calling frame.
|
| 872 |
-
name_list = [var_name for var_name, var_val in callers_local_vars
|
| 873 |
-
if var_val is var]
|
| 874 |
-
if len(name_list) == 1:
|
| 875 |
-
names.append(name_list[0])
|
| 876 |
-
else:
|
| 877 |
-
# Cannot infer name with certainty. arg_# will have to do.
|
| 878 |
-
names.append('arg_' + str(n))
|
| 879 |
-
|
| 880 |
-
# Create the function definition code and execute it
|
| 881 |
-
funcname = '_lambdifygenerated'
|
| 882 |
-
if _module_present('tensorflow', namespaces):
|
| 883 |
-
funcprinter = _TensorflowEvaluatorPrinter(printer, dummify)
|
| 884 |
-
else:
|
| 885 |
-
funcprinter = _EvaluatorPrinter(printer, dummify)
|
| 886 |
-
|
| 887 |
-
if cse == True:
|
| 888 |
-
from sympy.simplify.cse_main import cse as _cse
|
| 889 |
-
cses, _expr = _cse(expr, list=False)
|
| 890 |
-
elif callable(cse):
|
| 891 |
-
cses, _expr = cse(expr)
|
| 892 |
-
else:
|
| 893 |
-
cses, _expr = (), expr
|
| 894 |
-
funcstr = funcprinter.doprint(funcname, iterable_args, _expr, cses=cses)
|
| 895 |
-
|
| 896 |
-
# Collect the module imports from the code printers.
|
| 897 |
-
imp_mod_lines = []
|
| 898 |
-
for mod, keys in (getattr(printer, 'module_imports', None) or {}).items():
|
| 899 |
-
for k in keys:
|
| 900 |
-
if k not in namespace:
|
| 901 |
-
ln = "from %s import %s" % (mod, k)
|
| 902 |
-
try:
|
| 903 |
-
exec(ln, {}, namespace)
|
| 904 |
-
except ImportError:
|
| 905 |
-
# Tensorflow 2.0 has issues with importing a specific
|
| 906 |
-
# function from its submodule.
|
| 907 |
-
# https://github.com/tensorflow/tensorflow/issues/33022
|
| 908 |
-
ln = "%s = %s.%s" % (k, mod, k)
|
| 909 |
-
exec(ln, {}, namespace)
|
| 910 |
-
imp_mod_lines.append(ln)
|
| 911 |
-
|
| 912 |
-
# Provide lambda expression with builtins, and compatible implementation of range
|
| 913 |
-
namespace.update({'builtins':builtins, 'range':range})
|
| 914 |
-
|
| 915 |
-
funclocals = {}
|
| 916 |
-
global _lambdify_generated_counter
|
| 917 |
-
filename = '<lambdifygenerated-%s>' % _lambdify_generated_counter
|
| 918 |
-
_lambdify_generated_counter += 1
|
| 919 |
-
c = compile(funcstr, filename, 'exec')
|
| 920 |
-
exec(c, namespace, funclocals)
|
| 921 |
-
# mtime has to be None or else linecache.checkcache will remove it
|
| 922 |
-
linecache.cache[filename] = (len(funcstr), None, funcstr.splitlines(True), filename) # type: ignore
|
| 923 |
-
|
| 924 |
-
# Remove the entry from the linecache when the object is garbage collected
|
| 925 |
-
def cleanup_linecache(filename):
|
| 926 |
-
def _cleanup():
|
| 927 |
-
if filename in linecache.cache:
|
| 928 |
-
del linecache.cache[filename]
|
| 929 |
-
return _cleanup
|
| 930 |
-
|
| 931 |
-
func = funclocals[funcname]
|
| 932 |
-
|
| 933 |
-
weakref.finalize(func, cleanup_linecache(filename))
|
| 934 |
-
|
| 935 |
-
# Apply the docstring
|
| 936 |
-
sig = "func({})".format(", ".join(str(i) for i in names))
|
| 937 |
-
sig = textwrap.fill(sig, subsequent_indent=' '*8)
|
| 938 |
-
if _too_large_for_docstring(expr, docstring_limit):
|
| 939 |
-
expr_str = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
|
| 940 |
-
src_str = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
|
| 941 |
-
else:
|
| 942 |
-
expr_str = str(expr)
|
| 943 |
-
if len(expr_str) > 78:
|
| 944 |
-
expr_str = textwrap.wrap(expr_str, 75)[0] + '...'
|
| 945 |
-
src_str = funcstr
|
| 946 |
-
func.__doc__ = (
|
| 947 |
-
"Created with lambdify. Signature:\n\n"
|
| 948 |
-
"{sig}\n\n"
|
| 949 |
-
"Expression:\n\n"
|
| 950 |
-
"{expr}\n\n"
|
| 951 |
-
"Source code:\n\n"
|
| 952 |
-
"{src}\n\n"
|
| 953 |
-
"Imported modules:\n\n"
|
| 954 |
-
"{imp_mods}"
|
| 955 |
-
).format(sig=sig, expr=expr_str, src=src_str, imp_mods='\n'.join(imp_mod_lines))
|
| 956 |
-
return func
|
| 957 |
-
|
| 958 |
-
def _module_present(modname, modlist):
|
| 959 |
-
if modname in modlist:
|
| 960 |
-
return True
|
| 961 |
-
for m in modlist:
|
| 962 |
-
if hasattr(m, '__name__') and m.__name__ == modname:
|
| 963 |
-
return True
|
| 964 |
-
return False
|
| 965 |
-
|
| 966 |
-
def _get_namespace(m):
|
| 967 |
-
"""
|
| 968 |
-
This is used by _lambdify to parse its arguments.
|
| 969 |
-
"""
|
| 970 |
-
if isinstance(m, str):
|
| 971 |
-
_import(m)
|
| 972 |
-
return MODULES[m][0]
|
| 973 |
-
elif isinstance(m, dict):
|
| 974 |
-
return m
|
| 975 |
-
elif hasattr(m, "__dict__"):
|
| 976 |
-
return m.__dict__
|
| 977 |
-
else:
|
| 978 |
-
raise TypeError("Argument must be either a string, dict or module but it is: %s" % m)
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
def _recursive_to_string(doprint, arg):
|
| 982 |
-
"""Functions in lambdify accept both SymPy types and non-SymPy types such as python
|
| 983 |
-
lists and tuples. This method ensures that we only call the doprint method of the
|
| 984 |
-
printer with SymPy types (so that the printer safely can use SymPy-methods)."""
|
| 985 |
-
from sympy.matrices.matrixbase import MatrixBase
|
| 986 |
-
from sympy.core.basic import Basic
|
| 987 |
-
|
| 988 |
-
if isinstance(arg, (Basic, MatrixBase)):
|
| 989 |
-
return doprint(arg)
|
| 990 |
-
elif iterable(arg):
|
| 991 |
-
if isinstance(arg, list):
|
| 992 |
-
left, right = "[", "]"
|
| 993 |
-
elif isinstance(arg, tuple):
|
| 994 |
-
left, right = "(", ",)"
|
| 995 |
-
if not arg:
|
| 996 |
-
return "()"
|
| 997 |
-
else:
|
| 998 |
-
raise NotImplementedError("unhandled type: %s, %s" % (type(arg), arg))
|
| 999 |
-
return left +', '.join(_recursive_to_string(doprint, e) for e in arg) + right
|
| 1000 |
-
elif isinstance(arg, str):
|
| 1001 |
-
return arg
|
| 1002 |
-
else:
|
| 1003 |
-
return doprint(arg)
|
| 1004 |
-
|
| 1005 |
-
|
| 1006 |
-
def lambdastr(args, expr, printer=None, dummify=None):
|
| 1007 |
-
"""
|
| 1008 |
-
Returns a string that can be evaluated to a lambda function.
|
| 1009 |
-
|
| 1010 |
-
Examples
|
| 1011 |
-
========
|
| 1012 |
-
|
| 1013 |
-
>>> from sympy.abc import x, y, z
|
| 1014 |
-
>>> from sympy.utilities.lambdify import lambdastr
|
| 1015 |
-
>>> lambdastr(x, x**2)
|
| 1016 |
-
'lambda x: (x**2)'
|
| 1017 |
-
>>> lambdastr((x,y,z), [z,y,x])
|
| 1018 |
-
'lambda x,y,z: ([z, y, x])'
|
| 1019 |
-
|
| 1020 |
-
Although tuples may not appear as arguments to lambda in Python 3,
|
| 1021 |
-
lambdastr will create a lambda function that will unpack the original
|
| 1022 |
-
arguments so that nested arguments can be handled:
|
| 1023 |
-
|
| 1024 |
-
>>> lambdastr((x, (y, z)), x + y)
|
| 1025 |
-
'lambda _0,_1: (lambda x,y,z: (x + y))(_0,_1[0],_1[1])'
|
| 1026 |
-
"""
|
| 1027 |
-
# Transforming everything to strings.
|
| 1028 |
-
from sympy.matrices import DeferredVector
|
| 1029 |
-
from sympy.core.basic import Basic
|
| 1030 |
-
from sympy.core.function import (Derivative, Function)
|
| 1031 |
-
from sympy.core.symbol import (Dummy, Symbol)
|
| 1032 |
-
from sympy.core.sympify import sympify
|
| 1033 |
-
|
| 1034 |
-
if printer is not None:
|
| 1035 |
-
if inspect.isfunction(printer):
|
| 1036 |
-
lambdarepr = printer
|
| 1037 |
-
else:
|
| 1038 |
-
if inspect.isclass(printer):
|
| 1039 |
-
lambdarepr = lambda expr: printer().doprint(expr)
|
| 1040 |
-
else:
|
| 1041 |
-
lambdarepr = lambda expr: printer.doprint(expr)
|
| 1042 |
-
else:
|
| 1043 |
-
#XXX: This has to be done here because of circular imports
|
| 1044 |
-
from sympy.printing.lambdarepr import lambdarepr
|
| 1045 |
-
|
| 1046 |
-
def sub_args(args, dummies_dict):
|
| 1047 |
-
if isinstance(args, str):
|
| 1048 |
-
return args
|
| 1049 |
-
elif isinstance(args, DeferredVector):
|
| 1050 |
-
return str(args)
|
| 1051 |
-
elif iterable(args):
|
| 1052 |
-
dummies = flatten([sub_args(a, dummies_dict) for a in args])
|
| 1053 |
-
return ",".join(str(a) for a in dummies)
|
| 1054 |
-
else:
|
| 1055 |
-
# replace these with Dummy symbols
|
| 1056 |
-
if isinstance(args, (Function, Symbol, Derivative)):
|
| 1057 |
-
dummies = Dummy()
|
| 1058 |
-
dummies_dict.update({args : dummies})
|
| 1059 |
-
return str(dummies)
|
| 1060 |
-
else:
|
| 1061 |
-
return str(args)
|
| 1062 |
-
|
| 1063 |
-
def sub_expr(expr, dummies_dict):
|
| 1064 |
-
expr = sympify(expr)
|
| 1065 |
-
# dict/tuple are sympified to Basic
|
| 1066 |
-
if isinstance(expr, Basic):
|
| 1067 |
-
expr = expr.xreplace(dummies_dict)
|
| 1068 |
-
# list is not sympified to Basic
|
| 1069 |
-
elif isinstance(expr, list):
|
| 1070 |
-
expr = [sub_expr(a, dummies_dict) for a in expr]
|
| 1071 |
-
return expr
|
| 1072 |
-
|
| 1073 |
-
# Transform args
|
| 1074 |
-
def isiter(l):
|
| 1075 |
-
return iterable(l, exclude=(str, DeferredVector, NotIterable))
|
| 1076 |
-
|
| 1077 |
-
def flat_indexes(iterable):
|
| 1078 |
-
n = 0
|
| 1079 |
-
|
| 1080 |
-
for el in iterable:
|
| 1081 |
-
if isiter(el):
|
| 1082 |
-
for ndeep in flat_indexes(el):
|
| 1083 |
-
yield (n,) + ndeep
|
| 1084 |
-
else:
|
| 1085 |
-
yield (n,)
|
| 1086 |
-
|
| 1087 |
-
n += 1
|
| 1088 |
-
|
| 1089 |
-
if dummify is None:
|
| 1090 |
-
dummify = any(isinstance(a, Basic) and
|
| 1091 |
-
a.atoms(Function, Derivative) for a in (
|
| 1092 |
-
args if isiter(args) else [args]))
|
| 1093 |
-
|
| 1094 |
-
if isiter(args) and any(isiter(i) for i in args):
|
| 1095 |
-
dum_args = [str(Dummy(str(i))) for i in range(len(args))]
|
| 1096 |
-
|
| 1097 |
-
indexed_args = ','.join([
|
| 1098 |
-
dum_args[ind[0]] + ''.join(["[%s]" % k for k in ind[1:]])
|
| 1099 |
-
for ind in flat_indexes(args)])
|
| 1100 |
-
|
| 1101 |
-
lstr = lambdastr(flatten(args), expr, printer=printer, dummify=dummify)
|
| 1102 |
-
|
| 1103 |
-
return 'lambda %s: (%s)(%s)' % (','.join(dum_args), lstr, indexed_args)
|
| 1104 |
-
|
| 1105 |
-
dummies_dict = {}
|
| 1106 |
-
if dummify:
|
| 1107 |
-
args = sub_args(args, dummies_dict)
|
| 1108 |
-
else:
|
| 1109 |
-
if isinstance(args, str):
|
| 1110 |
-
pass
|
| 1111 |
-
elif iterable(args, exclude=DeferredVector):
|
| 1112 |
-
args = ",".join(str(a) for a in args)
|
| 1113 |
-
|
| 1114 |
-
# Transform expr
|
| 1115 |
-
if dummify:
|
| 1116 |
-
if isinstance(expr, str):
|
| 1117 |
-
pass
|
| 1118 |
-
else:
|
| 1119 |
-
expr = sub_expr(expr, dummies_dict)
|
| 1120 |
-
expr = _recursive_to_string(lambdarepr, expr)
|
| 1121 |
-
return "lambda %s: (%s)" % (args, expr)
|
| 1122 |
-
|
| 1123 |
-
class _EvaluatorPrinter:
|
| 1124 |
-
def __init__(self, printer=None, dummify=False):
|
| 1125 |
-
self._dummify = dummify
|
| 1126 |
-
|
| 1127 |
-
#XXX: This has to be done here because of circular imports
|
| 1128 |
-
from sympy.printing.lambdarepr import LambdaPrinter
|
| 1129 |
-
|
| 1130 |
-
if printer is None:
|
| 1131 |
-
printer = LambdaPrinter()
|
| 1132 |
-
|
| 1133 |
-
if inspect.isfunction(printer):
|
| 1134 |
-
self._exprrepr = printer
|
| 1135 |
-
else:
|
| 1136 |
-
if inspect.isclass(printer):
|
| 1137 |
-
printer = printer()
|
| 1138 |
-
|
| 1139 |
-
self._exprrepr = printer.doprint
|
| 1140 |
-
|
| 1141 |
-
#if hasattr(printer, '_print_Symbol'):
|
| 1142 |
-
# symbolrepr = printer._print_Symbol
|
| 1143 |
-
|
| 1144 |
-
#if hasattr(printer, '_print_Dummy'):
|
| 1145 |
-
# dummyrepr = printer._print_Dummy
|
| 1146 |
-
|
| 1147 |
-
# Used to print the generated function arguments in a standard way
|
| 1148 |
-
self._argrepr = LambdaPrinter().doprint
|
| 1149 |
-
|
| 1150 |
-
def doprint(self, funcname, args, expr, *, cses=()):
|
| 1151 |
-
"""
|
| 1152 |
-
Returns the function definition code as a string.
|
| 1153 |
-
"""
|
| 1154 |
-
from sympy.core.symbol import Dummy
|
| 1155 |
-
|
| 1156 |
-
funcbody = []
|
| 1157 |
-
|
| 1158 |
-
if not iterable(args):
|
| 1159 |
-
args = [args]
|
| 1160 |
-
|
| 1161 |
-
if cses:
|
| 1162 |
-
cses = list(cses)
|
| 1163 |
-
subvars, subexprs = zip(*cses)
|
| 1164 |
-
exprs = [expr] + list(subexprs)
|
| 1165 |
-
argstrs, exprs = self._preprocess(args, exprs, cses=cses)
|
| 1166 |
-
expr, subexprs = exprs[0], exprs[1:]
|
| 1167 |
-
cses = zip(subvars, subexprs)
|
| 1168 |
-
else:
|
| 1169 |
-
argstrs, expr = self._preprocess(args, expr)
|
| 1170 |
-
|
| 1171 |
-
# Generate argument unpacking and final argument list
|
| 1172 |
-
funcargs = []
|
| 1173 |
-
unpackings = []
|
| 1174 |
-
|
| 1175 |
-
for argstr in argstrs:
|
| 1176 |
-
if iterable(argstr):
|
| 1177 |
-
funcargs.append(self._argrepr(Dummy()))
|
| 1178 |
-
unpackings.extend(self._print_unpacking(argstr, funcargs[-1]))
|
| 1179 |
-
else:
|
| 1180 |
-
funcargs.append(argstr)
|
| 1181 |
-
|
| 1182 |
-
funcsig = 'def {}({}):'.format(funcname, ', '.join(funcargs))
|
| 1183 |
-
|
| 1184 |
-
# Wrap input arguments before unpacking
|
| 1185 |
-
funcbody.extend(self._print_funcargwrapping(funcargs))
|
| 1186 |
-
|
| 1187 |
-
funcbody.extend(unpackings)
|
| 1188 |
-
|
| 1189 |
-
for s, e in cses:
|
| 1190 |
-
if e is None:
|
| 1191 |
-
funcbody.append('del {}'.format(self._exprrepr(s)))
|
| 1192 |
-
else:
|
| 1193 |
-
funcbody.append('{} = {}'.format(self._exprrepr(s), self._exprrepr(e)))
|
| 1194 |
-
|
| 1195 |
-
# Subs may appear in expressions generated by .diff()
|
| 1196 |
-
subs_assignments = []
|
| 1197 |
-
expr = self._handle_Subs(expr, out=subs_assignments)
|
| 1198 |
-
for lhs, rhs in subs_assignments:
|
| 1199 |
-
funcbody.append('{} = {}'.format(self._exprrepr(lhs), self._exprrepr(rhs)))
|
| 1200 |
-
|
| 1201 |
-
str_expr = _recursive_to_string(self._exprrepr, expr)
|
| 1202 |
-
|
| 1203 |
-
if '\n' in str_expr:
|
| 1204 |
-
str_expr = '({})'.format(str_expr)
|
| 1205 |
-
funcbody.append('return {}'.format(str_expr))
|
| 1206 |
-
|
| 1207 |
-
funclines = [funcsig]
|
| 1208 |
-
funclines.extend([' ' + line for line in funcbody])
|
| 1209 |
-
|
| 1210 |
-
return '\n'.join(funclines) + '\n'
|
| 1211 |
-
|
| 1212 |
-
@classmethod
|
| 1213 |
-
def _is_safe_ident(cls, ident):
|
| 1214 |
-
return isinstance(ident, str) and ident.isidentifier() \
|
| 1215 |
-
and not keyword.iskeyword(ident)
|
| 1216 |
-
|
| 1217 |
-
def _preprocess(self, args, expr, cses=(), _dummies_dict=None):
|
| 1218 |
-
"""Preprocess args, expr to replace arguments that do not map
|
| 1219 |
-
to valid Python identifiers.
|
| 1220 |
-
|
| 1221 |
-
Returns string form of args, and updated expr.
|
| 1222 |
-
"""
|
| 1223 |
-
from sympy.core.basic import Basic
|
| 1224 |
-
from sympy.core.sorting import ordered
|
| 1225 |
-
from sympy.core.function import (Derivative, Function)
|
| 1226 |
-
from sympy.core.symbol import Dummy, uniquely_named_symbol
|
| 1227 |
-
from sympy.matrices import DeferredVector
|
| 1228 |
-
from sympy.core.expr import Expr
|
| 1229 |
-
|
| 1230 |
-
# Args of type Dummy can cause name collisions with args
|
| 1231 |
-
# of type Symbol. Force dummify of everything in this
|
| 1232 |
-
# situation.
|
| 1233 |
-
dummify = self._dummify or any(
|
| 1234 |
-
isinstance(arg, Dummy) for arg in flatten(args))
|
| 1235 |
-
|
| 1236 |
-
argstrs = [None]*len(args)
|
| 1237 |
-
if _dummies_dict is None:
|
| 1238 |
-
_dummies_dict = {}
|
| 1239 |
-
|
| 1240 |
-
def update_dummies(arg, dummy):
|
| 1241 |
-
_dummies_dict[arg] = dummy
|
| 1242 |
-
for repl, sub in cses:
|
| 1243 |
-
arg = arg.xreplace({sub: repl})
|
| 1244 |
-
_dummies_dict[arg] = dummy
|
| 1245 |
-
|
| 1246 |
-
for arg, i in reversed(list(ordered(zip(args, range(len(args)))))):
|
| 1247 |
-
if iterable(arg):
|
| 1248 |
-
s, expr = self._preprocess(arg, expr, cses=cses, _dummies_dict=_dummies_dict)
|
| 1249 |
-
elif isinstance(arg, DeferredVector):
|
| 1250 |
-
s = str(arg)
|
| 1251 |
-
elif isinstance(arg, Basic) and arg.is_symbol:
|
| 1252 |
-
s = str(arg)
|
| 1253 |
-
if dummify or not self._is_safe_ident(s):
|
| 1254 |
-
dummy = Dummy()
|
| 1255 |
-
if isinstance(expr, Expr):
|
| 1256 |
-
dummy = uniquely_named_symbol(
|
| 1257 |
-
dummy.name, expr, modify=lambda s: '_' + s)
|
| 1258 |
-
s = self._argrepr(dummy)
|
| 1259 |
-
update_dummies(arg, dummy)
|
| 1260 |
-
expr = self._subexpr(expr, _dummies_dict)
|
| 1261 |
-
elif dummify or isinstance(arg, (Function, Derivative)):
|
| 1262 |
-
dummy = Dummy()
|
| 1263 |
-
s = self._argrepr(dummy)
|
| 1264 |
-
update_dummies(arg, dummy)
|
| 1265 |
-
expr = self._subexpr(expr, _dummies_dict)
|
| 1266 |
-
else:
|
| 1267 |
-
s = str(arg)
|
| 1268 |
-
argstrs[i] = s
|
| 1269 |
-
return argstrs, expr
|
| 1270 |
-
|
| 1271 |
-
def _subexpr(self, expr, dummies_dict):
|
| 1272 |
-
from sympy.matrices import DeferredVector
|
| 1273 |
-
from sympy.core.sympify import sympify
|
| 1274 |
-
|
| 1275 |
-
expr = sympify(expr)
|
| 1276 |
-
xreplace = getattr(expr, 'xreplace', None)
|
| 1277 |
-
if xreplace is not None:
|
| 1278 |
-
expr = xreplace(dummies_dict)
|
| 1279 |
-
else:
|
| 1280 |
-
if isinstance(expr, DeferredVector):
|
| 1281 |
-
pass
|
| 1282 |
-
elif isinstance(expr, dict):
|
| 1283 |
-
k = [self._subexpr(sympify(a), dummies_dict) for a in expr.keys()]
|
| 1284 |
-
v = [self._subexpr(sympify(a), dummies_dict) for a in expr.values()]
|
| 1285 |
-
expr = dict(zip(k, v))
|
| 1286 |
-
elif isinstance(expr, tuple):
|
| 1287 |
-
expr = tuple(self._subexpr(sympify(a), dummies_dict) for a in expr)
|
| 1288 |
-
elif isinstance(expr, list):
|
| 1289 |
-
expr = [self._subexpr(sympify(a), dummies_dict) for a in expr]
|
| 1290 |
-
return expr
|
| 1291 |
-
|
| 1292 |
-
def _print_funcargwrapping(self, args):
|
| 1293 |
-
"""Generate argument wrapping code.
|
| 1294 |
-
|
| 1295 |
-
args is the argument list of the generated function (strings).
|
| 1296 |
-
|
| 1297 |
-
Return value is a list of lines of code that will be inserted at
|
| 1298 |
-
the beginning of the function definition.
|
| 1299 |
-
"""
|
| 1300 |
-
return []
|
| 1301 |
-
|
| 1302 |
-
def _print_unpacking(self, unpackto, arg):
|
| 1303 |
-
"""Generate argument unpacking code.
|
| 1304 |
-
|
| 1305 |
-
arg is the function argument to be unpacked (a string), and
|
| 1306 |
-
unpackto is a list or nested lists of the variable names (strings) to
|
| 1307 |
-
unpack to.
|
| 1308 |
-
"""
|
| 1309 |
-
def unpack_lhs(lvalues):
|
| 1310 |
-
return '[{}]'.format(', '.join(
|
| 1311 |
-
unpack_lhs(val) if iterable(val) else val for val in lvalues))
|
| 1312 |
-
|
| 1313 |
-
return ['{} = {}'.format(unpack_lhs(unpackto), arg)]
|
| 1314 |
-
|
| 1315 |
-
def _handle_Subs(self, expr, out):
|
| 1316 |
-
"""Any instance of Subs is extracted and returned as assignment pairs."""
|
| 1317 |
-
from sympy.core.basic import Basic
|
| 1318 |
-
from sympy.core.function import Subs
|
| 1319 |
-
from sympy.core.symbol import Dummy
|
| 1320 |
-
from sympy.matrices.matrixbase import MatrixBase
|
| 1321 |
-
|
| 1322 |
-
def _replace(ex, variables, point):
|
| 1323 |
-
safe = {}
|
| 1324 |
-
for lhs, rhs in zip(variables, point):
|
| 1325 |
-
dummy = Dummy()
|
| 1326 |
-
safe[lhs] = dummy
|
| 1327 |
-
out.append((dummy, rhs))
|
| 1328 |
-
return ex.xreplace(safe)
|
| 1329 |
-
|
| 1330 |
-
if isinstance(expr, (Basic, MatrixBase)):
|
| 1331 |
-
expr = expr.replace(Subs, _replace)
|
| 1332 |
-
elif iterable(expr):
|
| 1333 |
-
expr = type(expr)([self._handle_Subs(e, out) for e in expr])
|
| 1334 |
-
return expr
|
| 1335 |
-
|
| 1336 |
-
class _TensorflowEvaluatorPrinter(_EvaluatorPrinter):
|
| 1337 |
-
def _print_unpacking(self, lvalues, rvalue):
|
| 1338 |
-
"""Generate argument unpacking code.
|
| 1339 |
-
|
| 1340 |
-
This method is used when the input value is not iterable,
|
| 1341 |
-
but can be indexed (see issue #14655).
|
| 1342 |
-
"""
|
| 1343 |
-
|
| 1344 |
-
def flat_indexes(elems):
|
| 1345 |
-
n = 0
|
| 1346 |
-
|
| 1347 |
-
for el in elems:
|
| 1348 |
-
if iterable(el):
|
| 1349 |
-
for ndeep in flat_indexes(el):
|
| 1350 |
-
yield (n,) + ndeep
|
| 1351 |
-
else:
|
| 1352 |
-
yield (n,)
|
| 1353 |
-
|
| 1354 |
-
n += 1
|
| 1355 |
-
|
| 1356 |
-
indexed = ', '.join('{}[{}]'.format(rvalue, ']['.join(map(str, ind)))
|
| 1357 |
-
for ind in flat_indexes(lvalues))
|
| 1358 |
-
|
| 1359 |
-
return ['[{}] = [{}]'.format(', '.join(flatten(lvalues)), indexed)]
|
| 1360 |
-
|
| 1361 |
-
def _imp_namespace(expr, namespace=None):
|
| 1362 |
-
""" Return namespace dict with function implementations
|
| 1363 |
-
|
| 1364 |
-
We need to search for functions in anything that can be thrown at
|
| 1365 |
-
us - that is - anything that could be passed as ``expr``. Examples
|
| 1366 |
-
include SymPy expressions, as well as tuples, lists and dicts that may
|
| 1367 |
-
contain SymPy expressions.
|
| 1368 |
-
|
| 1369 |
-
Parameters
|
| 1370 |
-
----------
|
| 1371 |
-
expr : object
|
| 1372 |
-
Something passed to lambdify, that will generate valid code from
|
| 1373 |
-
``str(expr)``.
|
| 1374 |
-
namespace : None or mapping
|
| 1375 |
-
Namespace to fill. None results in new empty dict
|
| 1376 |
-
|
| 1377 |
-
Returns
|
| 1378 |
-
-------
|
| 1379 |
-
namespace : dict
|
| 1380 |
-
dict with keys of implemented function names within ``expr`` and
|
| 1381 |
-
corresponding values being the numerical implementation of
|
| 1382 |
-
function
|
| 1383 |
-
|
| 1384 |
-
Examples
|
| 1385 |
-
========
|
| 1386 |
-
|
| 1387 |
-
>>> from sympy.abc import x
|
| 1388 |
-
>>> from sympy.utilities.lambdify import implemented_function, _imp_namespace
|
| 1389 |
-
>>> from sympy import Function
|
| 1390 |
-
>>> f = implemented_function(Function('f'), lambda x: x+1)
|
| 1391 |
-
>>> g = implemented_function(Function('g'), lambda x: x*10)
|
| 1392 |
-
>>> namespace = _imp_namespace(f(g(x)))
|
| 1393 |
-
>>> sorted(namespace.keys())
|
| 1394 |
-
['f', 'g']
|
| 1395 |
-
"""
|
| 1396 |
-
# Delayed import to avoid circular imports
|
| 1397 |
-
from sympy.core.function import FunctionClass
|
| 1398 |
-
if namespace is None:
|
| 1399 |
-
namespace = {}
|
| 1400 |
-
# tuples, lists, dicts are valid expressions
|
| 1401 |
-
if is_sequence(expr):
|
| 1402 |
-
for arg in expr:
|
| 1403 |
-
_imp_namespace(arg, namespace)
|
| 1404 |
-
return namespace
|
| 1405 |
-
elif isinstance(expr, dict):
|
| 1406 |
-
for key, val in expr.items():
|
| 1407 |
-
# functions can be in dictionary keys
|
| 1408 |
-
_imp_namespace(key, namespace)
|
| 1409 |
-
_imp_namespace(val, namespace)
|
| 1410 |
-
return namespace
|
| 1411 |
-
# SymPy expressions may be Functions themselves
|
| 1412 |
-
func = getattr(expr, 'func', None)
|
| 1413 |
-
if isinstance(func, FunctionClass):
|
| 1414 |
-
imp = getattr(func, '_imp_', None)
|
| 1415 |
-
if imp is not None:
|
| 1416 |
-
name = expr.func.__name__
|
| 1417 |
-
if name in namespace and namespace[name] != imp:
|
| 1418 |
-
raise ValueError('We found more than one '
|
| 1419 |
-
'implementation with name '
|
| 1420 |
-
'"%s"' % name)
|
| 1421 |
-
namespace[name] = imp
|
| 1422 |
-
# and / or they may take Functions as arguments
|
| 1423 |
-
if hasattr(expr, 'args'):
|
| 1424 |
-
for arg in expr.args:
|
| 1425 |
-
_imp_namespace(arg, namespace)
|
| 1426 |
-
return namespace
|
| 1427 |
-
|
| 1428 |
-
|
| 1429 |
-
def implemented_function(symfunc, implementation):
|
| 1430 |
-
""" Add numerical ``implementation`` to function ``symfunc``.
|
| 1431 |
-
|
| 1432 |
-
``symfunc`` can be an ``UndefinedFunction`` instance, or a name string.
|
| 1433 |
-
In the latter case we create an ``UndefinedFunction`` instance with that
|
| 1434 |
-
name.
|
| 1435 |
-
|
| 1436 |
-
Be aware that this is a quick workaround, not a general method to create
|
| 1437 |
-
special symbolic functions. If you want to create a symbolic function to be
|
| 1438 |
-
used by all the machinery of SymPy you should subclass the ``Function``
|
| 1439 |
-
class.
|
| 1440 |
-
|
| 1441 |
-
Parameters
|
| 1442 |
-
----------
|
| 1443 |
-
symfunc : ``str`` or ``UndefinedFunction`` instance
|
| 1444 |
-
If ``str``, then create new ``UndefinedFunction`` with this as
|
| 1445 |
-
name. If ``symfunc`` is an Undefined function, create a new function
|
| 1446 |
-
with the same name and the implemented function attached.
|
| 1447 |
-
implementation : callable
|
| 1448 |
-
numerical implementation to be called by ``evalf()`` or ``lambdify``
|
| 1449 |
-
|
| 1450 |
-
Returns
|
| 1451 |
-
-------
|
| 1452 |
-
afunc : sympy.FunctionClass instance
|
| 1453 |
-
function with attached implementation
|
| 1454 |
-
|
| 1455 |
-
Examples
|
| 1456 |
-
========
|
| 1457 |
-
|
| 1458 |
-
>>> from sympy.abc import x
|
| 1459 |
-
>>> from sympy.utilities.lambdify import implemented_function
|
| 1460 |
-
>>> from sympy import lambdify
|
| 1461 |
-
>>> f = implemented_function('f', lambda x: x+1)
|
| 1462 |
-
>>> lam_f = lambdify(x, f(x))
|
| 1463 |
-
>>> lam_f(4)
|
| 1464 |
-
5
|
| 1465 |
-
"""
|
| 1466 |
-
# Delayed import to avoid circular imports
|
| 1467 |
-
from sympy.core.function import UndefinedFunction
|
| 1468 |
-
# if name, create function to hold implementation
|
| 1469 |
-
kwargs = {}
|
| 1470 |
-
if isinstance(symfunc, UndefinedFunction):
|
| 1471 |
-
kwargs = symfunc._kwargs
|
| 1472 |
-
symfunc = symfunc.__name__
|
| 1473 |
-
if isinstance(symfunc, str):
|
| 1474 |
-
# Keyword arguments to UndefinedFunction are added as attributes to
|
| 1475 |
-
# the created class.
|
| 1476 |
-
symfunc = UndefinedFunction(
|
| 1477 |
-
symfunc, _imp_=staticmethod(implementation), **kwargs)
|
| 1478 |
-
elif not isinstance(symfunc, UndefinedFunction):
|
| 1479 |
-
raise ValueError(filldedent('''
|
| 1480 |
-
symfunc should be either a string or
|
| 1481 |
-
an UndefinedFunction instance.'''))
|
| 1482 |
-
return symfunc
|
| 1483 |
-
|
| 1484 |
-
|
| 1485 |
-
def _too_large_for_docstring(expr, limit):
|
| 1486 |
-
"""Decide whether an ``Expr`` is too large to be fully rendered in a
|
| 1487 |
-
``lambdify`` docstring.
|
| 1488 |
-
|
| 1489 |
-
This is a fast alternative to ``count_ops``, which can become prohibitively
|
| 1490 |
-
slow for large expressions, because in this instance we only care whether
|
| 1491 |
-
``limit`` is exceeded rather than counting the exact number of nodes in the
|
| 1492 |
-
expression.
|
| 1493 |
-
|
| 1494 |
-
Parameters
|
| 1495 |
-
==========
|
| 1496 |
-
expr : ``Expr``, (nested) ``list`` of ``Expr``, or ``Matrix``
|
| 1497 |
-
The same objects that can be passed to the ``expr`` argument of
|
| 1498 |
-
``lambdify``.
|
| 1499 |
-
limit : ``int`` or ``None``
|
| 1500 |
-
The threshold above which an expression contains too many nodes to be
|
| 1501 |
-
usefully rendered in the docstring. If ``None`` then there is no limit.
|
| 1502 |
-
|
| 1503 |
-
Returns
|
| 1504 |
-
=======
|
| 1505 |
-
bool
|
| 1506 |
-
``True`` if the number of nodes in the expression exceeds the limit,
|
| 1507 |
-
``False`` otherwise.
|
| 1508 |
-
|
| 1509 |
-
Examples
|
| 1510 |
-
========
|
| 1511 |
-
|
| 1512 |
-
>>> from sympy.abc import x, y, z
|
| 1513 |
-
>>> from sympy.utilities.lambdify import _too_large_for_docstring
|
| 1514 |
-
>>> expr = x
|
| 1515 |
-
>>> _too_large_for_docstring(expr, None)
|
| 1516 |
-
False
|
| 1517 |
-
>>> _too_large_for_docstring(expr, 100)
|
| 1518 |
-
False
|
| 1519 |
-
>>> _too_large_for_docstring(expr, 1)
|
| 1520 |
-
False
|
| 1521 |
-
>>> _too_large_for_docstring(expr, 0)
|
| 1522 |
-
True
|
| 1523 |
-
>>> _too_large_for_docstring(expr, -1)
|
| 1524 |
-
True
|
| 1525 |
-
|
| 1526 |
-
Does this split it?
|
| 1527 |
-
|
| 1528 |
-
>>> expr = [x, y, z]
|
| 1529 |
-
>>> _too_large_for_docstring(expr, None)
|
| 1530 |
-
False
|
| 1531 |
-
>>> _too_large_for_docstring(expr, 100)
|
| 1532 |
-
False
|
| 1533 |
-
>>> _too_large_for_docstring(expr, 1)
|
| 1534 |
-
True
|
| 1535 |
-
>>> _too_large_for_docstring(expr, 0)
|
| 1536 |
-
True
|
| 1537 |
-
>>> _too_large_for_docstring(expr, -1)
|
| 1538 |
-
True
|
| 1539 |
-
|
| 1540 |
-
>>> expr = [x, [y], z, [[x+y], [x*y*z, [x+y+z]]]]
|
| 1541 |
-
>>> _too_large_for_docstring(expr, None)
|
| 1542 |
-
False
|
| 1543 |
-
>>> _too_large_for_docstring(expr, 100)
|
| 1544 |
-
False
|
| 1545 |
-
>>> _too_large_for_docstring(expr, 1)
|
| 1546 |
-
True
|
| 1547 |
-
>>> _too_large_for_docstring(expr, 0)
|
| 1548 |
-
True
|
| 1549 |
-
>>> _too_large_for_docstring(expr, -1)
|
| 1550 |
-
True
|
| 1551 |
-
|
| 1552 |
-
>>> expr = ((x + y + z)**5).expand()
|
| 1553 |
-
>>> _too_large_for_docstring(expr, None)
|
| 1554 |
-
False
|
| 1555 |
-
>>> _too_large_for_docstring(expr, 100)
|
| 1556 |
-
True
|
| 1557 |
-
>>> _too_large_for_docstring(expr, 1)
|
| 1558 |
-
True
|
| 1559 |
-
>>> _too_large_for_docstring(expr, 0)
|
| 1560 |
-
True
|
| 1561 |
-
>>> _too_large_for_docstring(expr, -1)
|
| 1562 |
-
True
|
| 1563 |
-
|
| 1564 |
-
>>> from sympy import Matrix
|
| 1565 |
-
>>> expr = Matrix([[(x + y + z), ((x + y + z)**2).expand(),
|
| 1566 |
-
... ((x + y + z)**3).expand(), ((x + y + z)**4).expand()]])
|
| 1567 |
-
>>> _too_large_for_docstring(expr, None)
|
| 1568 |
-
False
|
| 1569 |
-
>>> _too_large_for_docstring(expr, 1000)
|
| 1570 |
-
False
|
| 1571 |
-
>>> _too_large_for_docstring(expr, 100)
|
| 1572 |
-
True
|
| 1573 |
-
>>> _too_large_for_docstring(expr, 1)
|
| 1574 |
-
True
|
| 1575 |
-
>>> _too_large_for_docstring(expr, 0)
|
| 1576 |
-
True
|
| 1577 |
-
>>> _too_large_for_docstring(expr, -1)
|
| 1578 |
-
True
|
| 1579 |
-
|
| 1580 |
-
"""
|
| 1581 |
-
# Must be imported here to avoid a circular import error
|
| 1582 |
-
from sympy.core.traversal import postorder_traversal
|
| 1583 |
-
|
| 1584 |
-
if limit is None:
|
| 1585 |
-
return False
|
| 1586 |
-
|
| 1587 |
-
i = 0
|
| 1588 |
-
for _ in postorder_traversal(expr):
|
| 1589 |
-
i += 1
|
| 1590 |
-
if i > limit:
|
| 1591 |
-
return True
|
| 1592 |
-
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/magic.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
"""Functions that involve magic. """
|
| 2 |
-
|
| 3 |
-
def pollute(names, objects):
|
| 4 |
-
"""Pollute the global namespace with symbols -> objects mapping. """
|
| 5 |
-
from inspect import currentframe
|
| 6 |
-
frame = currentframe().f_back.f_back
|
| 7 |
-
|
| 8 |
-
try:
|
| 9 |
-
for name, obj in zip(names, objects):
|
| 10 |
-
frame.f_globals[name] = obj
|
| 11 |
-
finally:
|
| 12 |
-
del frame # break cyclic dependencies as stated in inspect docs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/matchpy_connector.py
DELETED
|
@@ -1,340 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
The objects in this module allow the usage of the MatchPy pattern matching
|
| 3 |
-
library on SymPy expressions.
|
| 4 |
-
"""
|
| 5 |
-
import re
|
| 6 |
-
from typing import List, Callable, NamedTuple, Any, Dict
|
| 7 |
-
|
| 8 |
-
from sympy.core.sympify import _sympify
|
| 9 |
-
from sympy.external import import_module
|
| 10 |
-
from sympy.functions import (log, sin, cos, tan, cot, csc, sec, erf, gamma, uppergamma)
|
| 11 |
-
from sympy.functions.elementary.hyperbolic import acosh, asinh, atanh, acoth, acsch, asech, cosh, sinh, tanh, coth, sech, csch
|
| 12 |
-
from sympy.functions.elementary.trigonometric import atan, acsc, asin, acot, acos, asec
|
| 13 |
-
from sympy.functions.special.error_functions import fresnelc, fresnels, erfc, erfi, Ei
|
| 14 |
-
from sympy.core.add import Add
|
| 15 |
-
from sympy.core.basic import Basic
|
| 16 |
-
from sympy.core.expr import Expr
|
| 17 |
-
from sympy.core.mul import Mul
|
| 18 |
-
from sympy.core.power import Pow
|
| 19 |
-
from sympy.core.relational import (Equality, Unequality)
|
| 20 |
-
from sympy.core.symbol import Symbol
|
| 21 |
-
from sympy.functions.elementary.exponential import exp
|
| 22 |
-
from sympy.integrals.integrals import Integral
|
| 23 |
-
from sympy.printing.repr import srepr
|
| 24 |
-
from sympy.utilities.decorator import doctest_depends_on
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
matchpy = import_module("matchpy")
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
__doctest_requires__ = {('*',): ['matchpy']}
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
if matchpy:
|
| 34 |
-
from matchpy import Operation, CommutativeOperation, AssociativeOperation, OneIdentityOperation
|
| 35 |
-
from matchpy.expressions.functions import op_iter, create_operation_expression, op_len
|
| 36 |
-
|
| 37 |
-
Operation.register(Integral)
|
| 38 |
-
Operation.register(Pow)
|
| 39 |
-
OneIdentityOperation.register(Pow)
|
| 40 |
-
|
| 41 |
-
Operation.register(Add)
|
| 42 |
-
OneIdentityOperation.register(Add)
|
| 43 |
-
CommutativeOperation.register(Add)
|
| 44 |
-
AssociativeOperation.register(Add)
|
| 45 |
-
|
| 46 |
-
Operation.register(Mul)
|
| 47 |
-
OneIdentityOperation.register(Mul)
|
| 48 |
-
CommutativeOperation.register(Mul)
|
| 49 |
-
AssociativeOperation.register(Mul)
|
| 50 |
-
|
| 51 |
-
Operation.register(Equality)
|
| 52 |
-
CommutativeOperation.register(Equality)
|
| 53 |
-
Operation.register(Unequality)
|
| 54 |
-
CommutativeOperation.register(Unequality)
|
| 55 |
-
|
| 56 |
-
Operation.register(exp)
|
| 57 |
-
Operation.register(log)
|
| 58 |
-
Operation.register(gamma)
|
| 59 |
-
Operation.register(uppergamma)
|
| 60 |
-
Operation.register(fresnels)
|
| 61 |
-
Operation.register(fresnelc)
|
| 62 |
-
Operation.register(erf)
|
| 63 |
-
Operation.register(Ei)
|
| 64 |
-
Operation.register(erfc)
|
| 65 |
-
Operation.register(erfi)
|
| 66 |
-
Operation.register(sin)
|
| 67 |
-
Operation.register(cos)
|
| 68 |
-
Operation.register(tan)
|
| 69 |
-
Operation.register(cot)
|
| 70 |
-
Operation.register(csc)
|
| 71 |
-
Operation.register(sec)
|
| 72 |
-
Operation.register(sinh)
|
| 73 |
-
Operation.register(cosh)
|
| 74 |
-
Operation.register(tanh)
|
| 75 |
-
Operation.register(coth)
|
| 76 |
-
Operation.register(csch)
|
| 77 |
-
Operation.register(sech)
|
| 78 |
-
Operation.register(asin)
|
| 79 |
-
Operation.register(acos)
|
| 80 |
-
Operation.register(atan)
|
| 81 |
-
Operation.register(acot)
|
| 82 |
-
Operation.register(acsc)
|
| 83 |
-
Operation.register(asec)
|
| 84 |
-
Operation.register(asinh)
|
| 85 |
-
Operation.register(acosh)
|
| 86 |
-
Operation.register(atanh)
|
| 87 |
-
Operation.register(acoth)
|
| 88 |
-
Operation.register(acsch)
|
| 89 |
-
Operation.register(asech)
|
| 90 |
-
|
| 91 |
-
@op_iter.register(Integral) # type: ignore
|
| 92 |
-
def _(operation):
|
| 93 |
-
return iter((operation._args[0],) + operation._args[1])
|
| 94 |
-
|
| 95 |
-
@op_iter.register(Basic) # type: ignore
|
| 96 |
-
def _(operation):
|
| 97 |
-
return iter(operation._args)
|
| 98 |
-
|
| 99 |
-
@op_len.register(Integral) # type: ignore
|
| 100 |
-
def _(operation):
|
| 101 |
-
return 1 + len(operation._args[1])
|
| 102 |
-
|
| 103 |
-
@op_len.register(Basic) # type: ignore
|
| 104 |
-
def _(operation):
|
| 105 |
-
return len(operation._args)
|
| 106 |
-
|
| 107 |
-
@create_operation_expression.register(Basic)
|
| 108 |
-
def sympy_op_factory(old_operation, new_operands, variable_name=True):
|
| 109 |
-
return type(old_operation)(*new_operands)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
if matchpy:
|
| 113 |
-
from matchpy import Wildcard
|
| 114 |
-
else:
|
| 115 |
-
class Wildcard: # type: ignore
|
| 116 |
-
def __init__(self, min_length, fixed_size, variable_name, optional):
|
| 117 |
-
self.min_count = min_length
|
| 118 |
-
self.fixed_size = fixed_size
|
| 119 |
-
self.variable_name = variable_name
|
| 120 |
-
self.optional = optional
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
@doctest_depends_on(modules=('matchpy',))
|
| 124 |
-
class _WildAbstract(Wildcard, Symbol):
|
| 125 |
-
min_length: int # abstract field required in subclasses
|
| 126 |
-
fixed_size: bool # abstract field required in subclasses
|
| 127 |
-
|
| 128 |
-
def __init__(self, variable_name=None, optional=None, **assumptions):
|
| 129 |
-
min_length = self.min_length
|
| 130 |
-
fixed_size = self.fixed_size
|
| 131 |
-
if optional is not None:
|
| 132 |
-
optional = _sympify(optional)
|
| 133 |
-
Wildcard.__init__(self, min_length, fixed_size, str(variable_name), optional)
|
| 134 |
-
|
| 135 |
-
def __getstate__(self):
|
| 136 |
-
return {
|
| 137 |
-
"min_length": self.min_length,
|
| 138 |
-
"fixed_size": self.fixed_size,
|
| 139 |
-
"min_count": self.min_count,
|
| 140 |
-
"variable_name": self.variable_name,
|
| 141 |
-
"optional": self.optional,
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
def __new__(cls, variable_name=None, optional=None, **assumptions):
|
| 145 |
-
cls._sanitize(assumptions, cls)
|
| 146 |
-
return _WildAbstract.__xnew__(cls, variable_name, optional, **assumptions)
|
| 147 |
-
|
| 148 |
-
def __getnewargs__(self):
|
| 149 |
-
return self.variable_name, self.optional
|
| 150 |
-
|
| 151 |
-
@staticmethod
|
| 152 |
-
def __xnew__(cls, variable_name=None, optional=None, **assumptions):
|
| 153 |
-
obj = Symbol.__xnew__(cls, variable_name, **assumptions)
|
| 154 |
-
return obj
|
| 155 |
-
|
| 156 |
-
def _hashable_content(self):
|
| 157 |
-
if self.optional:
|
| 158 |
-
return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name, self.optional)
|
| 159 |
-
else:
|
| 160 |
-
return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name)
|
| 161 |
-
|
| 162 |
-
def __copy__(self) -> '_WildAbstract':
|
| 163 |
-
return type(self)(variable_name=self.variable_name, optional=self.optional)
|
| 164 |
-
|
| 165 |
-
def __repr__(self):
|
| 166 |
-
return str(self)
|
| 167 |
-
|
| 168 |
-
def __str__(self):
|
| 169 |
-
return self.name
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
@doctest_depends_on(modules=('matchpy',))
|
| 173 |
-
class WildDot(_WildAbstract):
|
| 174 |
-
min_length = 1
|
| 175 |
-
fixed_size = True
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
@doctest_depends_on(modules=('matchpy',))
|
| 179 |
-
class WildPlus(_WildAbstract):
|
| 180 |
-
min_length = 1
|
| 181 |
-
fixed_size = False
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
@doctest_depends_on(modules=('matchpy',))
|
| 185 |
-
class WildStar(_WildAbstract):
|
| 186 |
-
min_length = 0
|
| 187 |
-
fixed_size = False
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def _get_srepr(expr):
|
| 191 |
-
s = srepr(expr)
|
| 192 |
-
s = re.sub(r"WildDot\('(\w+)'\)", r"\1", s)
|
| 193 |
-
s = re.sub(r"WildPlus\('(\w+)'\)", r"*\1", s)
|
| 194 |
-
s = re.sub(r"WildStar\('(\w+)'\)", r"*\1", s)
|
| 195 |
-
return s
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
class ReplacementInfo(NamedTuple):
|
| 199 |
-
replacement: Any
|
| 200 |
-
info: Any
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
@doctest_depends_on(modules=('matchpy',))
|
| 204 |
-
class Replacer:
|
| 205 |
-
"""
|
| 206 |
-
Replacer object to perform multiple pattern matching and subexpression
|
| 207 |
-
replacements in SymPy expressions.
|
| 208 |
-
|
| 209 |
-
Examples
|
| 210 |
-
========
|
| 211 |
-
|
| 212 |
-
Example to construct a simple first degree equation solver:
|
| 213 |
-
|
| 214 |
-
>>> from sympy.utilities.matchpy_connector import WildDot, Replacer
|
| 215 |
-
>>> from sympy import Equality, Symbol
|
| 216 |
-
>>> x = Symbol("x")
|
| 217 |
-
>>> a_ = WildDot("a_", optional=1)
|
| 218 |
-
>>> b_ = WildDot("b_", optional=0)
|
| 219 |
-
|
| 220 |
-
The lines above have defined two wildcards, ``a_`` and ``b_``, the
|
| 221 |
-
coefficients of the equation `a x + b = 0`. The optional values specified
|
| 222 |
-
indicate which expression to return in case no match is found, they are
|
| 223 |
-
necessary in equations like `a x = 0` and `x + b = 0`.
|
| 224 |
-
|
| 225 |
-
Create two constraints to make sure that ``a_`` and ``b_`` will not match
|
| 226 |
-
any expression containing ``x``:
|
| 227 |
-
|
| 228 |
-
>>> from matchpy import CustomConstraint
|
| 229 |
-
>>> free_x_a = CustomConstraint(lambda a_: not a_.has(x))
|
| 230 |
-
>>> free_x_b = CustomConstraint(lambda b_: not b_.has(x))
|
| 231 |
-
|
| 232 |
-
Now create the rule replacer with the constraints:
|
| 233 |
-
|
| 234 |
-
>>> replacer = Replacer(common_constraints=[free_x_a, free_x_b])
|
| 235 |
-
|
| 236 |
-
Add the matching rule:
|
| 237 |
-
|
| 238 |
-
>>> replacer.add(Equality(a_*x + b_, 0), -b_/a_)
|
| 239 |
-
|
| 240 |
-
Let's try it:
|
| 241 |
-
|
| 242 |
-
>>> replacer.replace(Equality(3*x + 4, 0))
|
| 243 |
-
-4/3
|
| 244 |
-
|
| 245 |
-
Notice that it will not match equations expressed with other patterns:
|
| 246 |
-
|
| 247 |
-
>>> eq = Equality(3*x, 4)
|
| 248 |
-
>>> replacer.replace(eq)
|
| 249 |
-
Eq(3*x, 4)
|
| 250 |
-
|
| 251 |
-
In order to extend the matching patterns, define another one (we also need
|
| 252 |
-
to clear the cache, because the previous result has already been memorized
|
| 253 |
-
and the pattern matcher will not iterate again if given the same expression)
|
| 254 |
-
|
| 255 |
-
>>> replacer.add(Equality(a_*x, b_), b_/a_)
|
| 256 |
-
>>> replacer._matcher.clear()
|
| 257 |
-
>>> replacer.replace(eq)
|
| 258 |
-
4/3
|
| 259 |
-
"""
|
| 260 |
-
|
| 261 |
-
def __init__(self, common_constraints: list = [], lambdify: bool = False, info: bool = False):
|
| 262 |
-
self._matcher = matchpy.ManyToOneMatcher()
|
| 263 |
-
self._common_constraint = common_constraints
|
| 264 |
-
self._lambdify = lambdify
|
| 265 |
-
self._info = info
|
| 266 |
-
self._wildcards: Dict[str, Wildcard] = {}
|
| 267 |
-
|
| 268 |
-
def _get_lambda(self, lambda_str: str) -> Callable[..., Expr]:
|
| 269 |
-
exec("from sympy import *")
|
| 270 |
-
return eval(lambda_str, locals())
|
| 271 |
-
|
| 272 |
-
def _get_custom_constraint(self, constraint_expr: Expr, condition_template: str) -> Callable[..., Expr]:
|
| 273 |
-
wilds = [x.name for x in constraint_expr.atoms(_WildAbstract)]
|
| 274 |
-
lambdaargs = ', '.join(wilds)
|
| 275 |
-
fullexpr = _get_srepr(constraint_expr)
|
| 276 |
-
condition = condition_template.format(fullexpr)
|
| 277 |
-
return matchpy.CustomConstraint(
|
| 278 |
-
self._get_lambda(f"lambda {lambdaargs}: ({condition})"))
|
| 279 |
-
|
| 280 |
-
def _get_custom_constraint_nonfalse(self, constraint_expr: Expr) -> Callable[..., Expr]:
|
| 281 |
-
return self._get_custom_constraint(constraint_expr, "({}) != False")
|
| 282 |
-
|
| 283 |
-
def _get_custom_constraint_true(self, constraint_expr: Expr) -> Callable[..., Expr]:
|
| 284 |
-
return self._get_custom_constraint(constraint_expr, "({}) == True")
|
| 285 |
-
|
| 286 |
-
def add(self, expr: Expr, replacement, conditions_true: List[Expr] = [],
|
| 287 |
-
conditions_nonfalse: List[Expr] = [], info: Any = None) -> None:
|
| 288 |
-
expr = _sympify(expr)
|
| 289 |
-
replacement = _sympify(replacement)
|
| 290 |
-
constraints = self._common_constraint[:]
|
| 291 |
-
constraint_conditions_true = [
|
| 292 |
-
self._get_custom_constraint_true(cond) for cond in conditions_true]
|
| 293 |
-
constraint_conditions_nonfalse = [
|
| 294 |
-
self._get_custom_constraint_nonfalse(cond) for cond in conditions_nonfalse]
|
| 295 |
-
constraints.extend(constraint_conditions_true)
|
| 296 |
-
constraints.extend(constraint_conditions_nonfalse)
|
| 297 |
-
pattern = matchpy.Pattern(expr, *constraints)
|
| 298 |
-
if self._lambdify:
|
| 299 |
-
lambda_str = f"lambda {', '.join((x.name for x in expr.atoms(_WildAbstract)))}: {_get_srepr(replacement)}"
|
| 300 |
-
lambda_expr = self._get_lambda(lambda_str)
|
| 301 |
-
replacement = lambda_expr
|
| 302 |
-
else:
|
| 303 |
-
self._wildcards.update({str(i): i for i in expr.atoms(Wildcard)})
|
| 304 |
-
if self._info:
|
| 305 |
-
replacement = ReplacementInfo(replacement, info)
|
| 306 |
-
self._matcher.add(pattern, replacement)
|
| 307 |
-
|
| 308 |
-
def replace(self, expression, max_count: int = -1):
|
| 309 |
-
# This method partly rewrites the .replace method of ManyToOneReplacer
|
| 310 |
-
# in MatchPy.
|
| 311 |
-
# License: https://github.com/HPAC/matchpy/blob/master/LICENSE
|
| 312 |
-
infos = []
|
| 313 |
-
replaced = True
|
| 314 |
-
replace_count = 0
|
| 315 |
-
while replaced and (max_count < 0 or replace_count < max_count):
|
| 316 |
-
replaced = False
|
| 317 |
-
for subexpr, pos in matchpy.preorder_iter_with_position(expression):
|
| 318 |
-
try:
|
| 319 |
-
replacement_data, subst = next(iter(self._matcher.match(subexpr)))
|
| 320 |
-
if self._info:
|
| 321 |
-
replacement = replacement_data.replacement
|
| 322 |
-
infos.append(replacement_data.info)
|
| 323 |
-
else:
|
| 324 |
-
replacement = replacement_data
|
| 325 |
-
|
| 326 |
-
if self._lambdify:
|
| 327 |
-
result = replacement(**subst)
|
| 328 |
-
else:
|
| 329 |
-
result = replacement.xreplace({self._wildcards[k]: v for k, v in subst.items()})
|
| 330 |
-
|
| 331 |
-
expression = matchpy.functions.replace(expression, pos, result)
|
| 332 |
-
replaced = True
|
| 333 |
-
break
|
| 334 |
-
except StopIteration:
|
| 335 |
-
pass
|
| 336 |
-
replace_count += 1
|
| 337 |
-
if self._info:
|
| 338 |
-
return expression, infos
|
| 339 |
-
else:
|
| 340 |
-
return expression
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/__init__.py
DELETED
|
@@ -1,122 +0,0 @@
|
|
| 1 |
-
"""Module with some functions for MathML, like transforming MathML
|
| 2 |
-
content in MathML presentation.
|
| 3 |
-
|
| 4 |
-
To use this module, you will need lxml.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
|
| 9 |
-
from sympy.utilities.decorator import doctest_depends_on
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
__doctest_requires__ = {('apply_xsl', 'c2p'): ['lxml']}
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def add_mathml_headers(s):
|
| 16 |
-
return """<math xmlns:mml="http://www.w3.org/1998/Math/MathML"
|
| 17 |
-
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
| 18 |
-
xsi:schemaLocation="http://www.w3.org/1998/Math/MathML
|
| 19 |
-
http://www.w3.org/Math/XMLSchema/mathml2/mathml2.xsd">""" + s + "</math>"
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def _read_binary(pkgname, filename):
|
| 23 |
-
import sys
|
| 24 |
-
|
| 25 |
-
if sys.version_info >= (3, 10):
|
| 26 |
-
# files was added in Python 3.9 but only seems to work here in 3.10+
|
| 27 |
-
from importlib.resources import files
|
| 28 |
-
return files(pkgname).joinpath(filename).read_bytes()
|
| 29 |
-
else:
|
| 30 |
-
# read_binary was deprecated in Python 3.11
|
| 31 |
-
from importlib.resources import read_binary
|
| 32 |
-
return read_binary(pkgname, filename)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def _read_xsl(xsl):
|
| 36 |
-
# Previously these values were allowed:
|
| 37 |
-
if xsl == 'mathml/data/simple_mmlctop.xsl':
|
| 38 |
-
xsl = 'simple_mmlctop.xsl'
|
| 39 |
-
elif xsl == 'mathml/data/mmlctop.xsl':
|
| 40 |
-
xsl = 'mmlctop.xsl'
|
| 41 |
-
elif xsl == 'mathml/data/mmltex.xsl':
|
| 42 |
-
xsl = 'mmltex.xsl'
|
| 43 |
-
|
| 44 |
-
if xsl in ['simple_mmlctop.xsl', 'mmlctop.xsl', 'mmltex.xsl']:
|
| 45 |
-
xslbytes = _read_binary('sympy.utilities.mathml.data', xsl)
|
| 46 |
-
else:
|
| 47 |
-
xslbytes = Path(xsl).read_bytes()
|
| 48 |
-
|
| 49 |
-
return xslbytes
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
@doctest_depends_on(modules=('lxml',))
|
| 53 |
-
def apply_xsl(mml, xsl):
|
| 54 |
-
"""Apply a xsl to a MathML string.
|
| 55 |
-
|
| 56 |
-
Parameters
|
| 57 |
-
==========
|
| 58 |
-
|
| 59 |
-
mml
|
| 60 |
-
A string with MathML code.
|
| 61 |
-
xsl
|
| 62 |
-
A string giving the name of an xsl (xml stylesheet) file which can be
|
| 63 |
-
found in sympy/utilities/mathml/data. The following files are supplied
|
| 64 |
-
with SymPy:
|
| 65 |
-
|
| 66 |
-
- mmlctop.xsl
|
| 67 |
-
- mmltex.xsl
|
| 68 |
-
- simple_mmlctop.xsl
|
| 69 |
-
|
| 70 |
-
Alternatively, a full path to an xsl file can be given.
|
| 71 |
-
|
| 72 |
-
Examples
|
| 73 |
-
========
|
| 74 |
-
|
| 75 |
-
>>> from sympy.utilities.mathml import apply_xsl
|
| 76 |
-
>>> xsl = 'simple_mmlctop.xsl'
|
| 77 |
-
>>> mml = '<apply> <plus/> <ci>a</ci> <ci>b</ci> </apply>'
|
| 78 |
-
>>> res = apply_xsl(mml,xsl)
|
| 79 |
-
>>> print(res)
|
| 80 |
-
<?xml version="1.0"?>
|
| 81 |
-
<mrow xmlns="http://www.w3.org/1998/Math/MathML">
|
| 82 |
-
<mi>a</mi>
|
| 83 |
-
<mo> + </mo>
|
| 84 |
-
<mi>b</mi>
|
| 85 |
-
</mrow>
|
| 86 |
-
"""
|
| 87 |
-
from lxml import etree
|
| 88 |
-
|
| 89 |
-
parser = etree.XMLParser(resolve_entities=False)
|
| 90 |
-
ac = etree.XSLTAccessControl.DENY_ALL
|
| 91 |
-
|
| 92 |
-
s = etree.XML(_read_xsl(xsl), parser=parser)
|
| 93 |
-
transform = etree.XSLT(s, access_control=ac)
|
| 94 |
-
doc = etree.XML(mml, parser=parser)
|
| 95 |
-
result = transform(doc)
|
| 96 |
-
s = str(result)
|
| 97 |
-
return s
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
@doctest_depends_on(modules=('lxml',))
|
| 101 |
-
def c2p(mml, simple=False):
|
| 102 |
-
"""Transforms a document in MathML content (like the one that sympy produces)
|
| 103 |
-
in one document in MathML presentation, more suitable for printing, and more
|
| 104 |
-
widely accepted
|
| 105 |
-
|
| 106 |
-
Examples
|
| 107 |
-
========
|
| 108 |
-
|
| 109 |
-
>>> from sympy.utilities.mathml import c2p
|
| 110 |
-
>>> mml = '<apply> <exp/> <cn>2</cn> </apply>'
|
| 111 |
-
>>> c2p(mml,simple=True) != c2p(mml,simple=False)
|
| 112 |
-
True
|
| 113 |
-
|
| 114 |
-
"""
|
| 115 |
-
|
| 116 |
-
if not mml.startswith('<math'):
|
| 117 |
-
mml = add_mathml_headers(mml)
|
| 118 |
-
|
| 119 |
-
if simple:
|
| 120 |
-
return apply_xsl(mml, 'mathml/data/simple_mmlctop.xsl')
|
| 121 |
-
|
| 122 |
-
return apply_xsl(mml, 'mathml/data/mmlctop.xsl')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/__init__.py
DELETED
|
File without changes
|
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmlctop.xsl
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmltex.xsl
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/simple_mmlctop.xsl
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/memoization.py
DELETED
|
@@ -1,76 +0,0 @@
|
|
| 1 |
-
from functools import wraps
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def recurrence_memo(initial):
|
| 5 |
-
"""
|
| 6 |
-
Memo decorator for sequences defined by recurrence
|
| 7 |
-
|
| 8 |
-
Examples
|
| 9 |
-
========
|
| 10 |
-
|
| 11 |
-
>>> from sympy.utilities.memoization import recurrence_memo
|
| 12 |
-
>>> @recurrence_memo([1]) # 0! = 1
|
| 13 |
-
... def factorial(n, prev):
|
| 14 |
-
... return n * prev[-1]
|
| 15 |
-
>>> factorial(4)
|
| 16 |
-
24
|
| 17 |
-
>>> factorial(3) # use cache values
|
| 18 |
-
6
|
| 19 |
-
>>> factorial.cache_length() # cache length can be obtained
|
| 20 |
-
5
|
| 21 |
-
>>> factorial.fetch_item(slice(2, 4))
|
| 22 |
-
[2, 6]
|
| 23 |
-
|
| 24 |
-
"""
|
| 25 |
-
cache = initial
|
| 26 |
-
|
| 27 |
-
def decorator(f):
|
| 28 |
-
@wraps(f)
|
| 29 |
-
def g(n):
|
| 30 |
-
L = len(cache)
|
| 31 |
-
if n < L:
|
| 32 |
-
return cache[n]
|
| 33 |
-
for i in range(L, n + 1):
|
| 34 |
-
cache.append(f(i, cache))
|
| 35 |
-
return cache[-1]
|
| 36 |
-
g.cache_length = lambda: len(cache)
|
| 37 |
-
g.fetch_item = lambda x: cache[x]
|
| 38 |
-
return g
|
| 39 |
-
return decorator
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def assoc_recurrence_memo(base_seq):
|
| 43 |
-
"""
|
| 44 |
-
Memo decorator for associated sequences defined by recurrence starting from base
|
| 45 |
-
|
| 46 |
-
base_seq(n) -- callable to get base sequence elements
|
| 47 |
-
|
| 48 |
-
XXX works only for Pn0 = base_seq(0) cases
|
| 49 |
-
XXX works only for m <= n cases
|
| 50 |
-
"""
|
| 51 |
-
|
| 52 |
-
cache = []
|
| 53 |
-
|
| 54 |
-
def decorator(f):
|
| 55 |
-
@wraps(f)
|
| 56 |
-
def g(n, m):
|
| 57 |
-
L = len(cache)
|
| 58 |
-
if n < L:
|
| 59 |
-
return cache[n][m]
|
| 60 |
-
|
| 61 |
-
for i in range(L, n + 1):
|
| 62 |
-
# get base sequence
|
| 63 |
-
F_i0 = base_seq(i)
|
| 64 |
-
F_i_cache = [F_i0]
|
| 65 |
-
cache.append(F_i_cache)
|
| 66 |
-
|
| 67 |
-
# XXX only works for m <= n cases
|
| 68 |
-
# generate assoc sequence
|
| 69 |
-
for j in range(1, i + 1):
|
| 70 |
-
F_ij = f(i, j, cache)
|
| 71 |
-
F_i_cache.append(F_ij)
|
| 72 |
-
|
| 73 |
-
return cache[n][m]
|
| 74 |
-
|
| 75 |
-
return g
|
| 76 |
-
return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/misc.py
DELETED
|
@@ -1,564 +0,0 @@
|
|
| 1 |
-
"""Miscellaneous stuff that does not really fit anywhere else."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import operator
|
| 6 |
-
import sys
|
| 7 |
-
import os
|
| 8 |
-
import re as _re
|
| 9 |
-
import struct
|
| 10 |
-
from textwrap import fill, dedent
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class Undecidable(ValueError):
|
| 14 |
-
# an error to be raised when a decision cannot be made definitively
|
| 15 |
-
# where a definitive answer is needed
|
| 16 |
-
pass
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def filldedent(s, w=70, **kwargs):
|
| 20 |
-
"""
|
| 21 |
-
Strips leading and trailing empty lines from a copy of ``s``, then dedents,
|
| 22 |
-
fills and returns it.
|
| 23 |
-
|
| 24 |
-
Empty line stripping serves to deal with docstrings like this one that
|
| 25 |
-
start with a newline after the initial triple quote, inserting an empty
|
| 26 |
-
line at the beginning of the string.
|
| 27 |
-
|
| 28 |
-
Additional keyword arguments will be passed to ``textwrap.fill()``.
|
| 29 |
-
|
| 30 |
-
See Also
|
| 31 |
-
========
|
| 32 |
-
strlines, rawlines
|
| 33 |
-
|
| 34 |
-
"""
|
| 35 |
-
return '\n' + fill(dedent(str(s)).strip('\n'), width=w, **kwargs)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def strlines(s, c=64, short=False):
|
| 39 |
-
"""Return a cut-and-pastable string that, when printed, is
|
| 40 |
-
equivalent to the input. The lines will be surrounded by
|
| 41 |
-
parentheses and no line will be longer than c (default 64)
|
| 42 |
-
characters. If the line contains newlines characters, the
|
| 43 |
-
`rawlines` result will be returned. If ``short`` is True
|
| 44 |
-
(default is False) then if there is one line it will be
|
| 45 |
-
returned without bounding parentheses.
|
| 46 |
-
|
| 47 |
-
Examples
|
| 48 |
-
========
|
| 49 |
-
|
| 50 |
-
>>> from sympy.utilities.misc import strlines
|
| 51 |
-
>>> q = 'this is a long string that should be broken into shorter lines'
|
| 52 |
-
>>> print(strlines(q, 40))
|
| 53 |
-
(
|
| 54 |
-
'this is a long string that should be b'
|
| 55 |
-
'roken into shorter lines'
|
| 56 |
-
)
|
| 57 |
-
>>> q == (
|
| 58 |
-
... 'this is a long string that should be b'
|
| 59 |
-
... 'roken into shorter lines'
|
| 60 |
-
... )
|
| 61 |
-
True
|
| 62 |
-
|
| 63 |
-
See Also
|
| 64 |
-
========
|
| 65 |
-
filldedent, rawlines
|
| 66 |
-
"""
|
| 67 |
-
if not isinstance(s, str):
|
| 68 |
-
raise ValueError('expecting string input')
|
| 69 |
-
if '\n' in s:
|
| 70 |
-
return rawlines(s)
|
| 71 |
-
q = '"' if repr(s).startswith('"') else "'"
|
| 72 |
-
q = (q,)*2
|
| 73 |
-
if '\\' in s: # use r-string
|
| 74 |
-
m = '(\nr%s%%s%s\n)' % q
|
| 75 |
-
j = '%s\nr%s' % q
|
| 76 |
-
c -= 3
|
| 77 |
-
else:
|
| 78 |
-
m = '(\n%s%%s%s\n)' % q
|
| 79 |
-
j = '%s\n%s' % q
|
| 80 |
-
c -= 2
|
| 81 |
-
out = []
|
| 82 |
-
while s:
|
| 83 |
-
out.append(s[:c])
|
| 84 |
-
s=s[c:]
|
| 85 |
-
if short and len(out) == 1:
|
| 86 |
-
return (m % out[0]).splitlines()[1] # strip bounding (\n...\n)
|
| 87 |
-
return m % j.join(out)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def rawlines(s):
|
| 91 |
-
"""Return a cut-and-pastable string that, when printed, is equivalent
|
| 92 |
-
to the input. Use this when there is more than one line in the
|
| 93 |
-
string. The string returned is formatted so it can be indented
|
| 94 |
-
nicely within tests; in some cases it is wrapped in the dedent
|
| 95 |
-
function which has to be imported from textwrap.
|
| 96 |
-
|
| 97 |
-
Examples
|
| 98 |
-
========
|
| 99 |
-
|
| 100 |
-
Note: because there are characters in the examples below that need
|
| 101 |
-
to be escaped because they are themselves within a triple quoted
|
| 102 |
-
docstring, expressions below look more complicated than they would
|
| 103 |
-
be if they were printed in an interpreter window.
|
| 104 |
-
|
| 105 |
-
>>> from sympy.utilities.misc import rawlines
|
| 106 |
-
>>> from sympy import TableForm
|
| 107 |
-
>>> s = str(TableForm([[1, 10]], headings=(None, ['a', 'bee'])))
|
| 108 |
-
>>> print(rawlines(s))
|
| 109 |
-
(
|
| 110 |
-
'a bee\\n'
|
| 111 |
-
'-----\\n'
|
| 112 |
-
'1 10 '
|
| 113 |
-
)
|
| 114 |
-
>>> print(rawlines('''this
|
| 115 |
-
... that'''))
|
| 116 |
-
dedent('''\\
|
| 117 |
-
this
|
| 118 |
-
that''')
|
| 119 |
-
|
| 120 |
-
>>> print(rawlines('''this
|
| 121 |
-
... that
|
| 122 |
-
... '''))
|
| 123 |
-
dedent('''\\
|
| 124 |
-
this
|
| 125 |
-
that
|
| 126 |
-
''')
|
| 127 |
-
|
| 128 |
-
>>> s = \"\"\"this
|
| 129 |
-
... is a triple '''
|
| 130 |
-
... \"\"\"
|
| 131 |
-
>>> print(rawlines(s))
|
| 132 |
-
dedent(\"\"\"\\
|
| 133 |
-
this
|
| 134 |
-
is a triple '''
|
| 135 |
-
\"\"\")
|
| 136 |
-
|
| 137 |
-
>>> print(rawlines('''this
|
| 138 |
-
... that
|
| 139 |
-
... '''))
|
| 140 |
-
(
|
| 141 |
-
'this\\n'
|
| 142 |
-
'that\\n'
|
| 143 |
-
' '
|
| 144 |
-
)
|
| 145 |
-
|
| 146 |
-
See Also
|
| 147 |
-
========
|
| 148 |
-
filldedent, strlines
|
| 149 |
-
"""
|
| 150 |
-
lines = s.split('\n')
|
| 151 |
-
if len(lines) == 1:
|
| 152 |
-
return repr(lines[0])
|
| 153 |
-
triple = ["'''" in s, '"""' in s]
|
| 154 |
-
if any(li.endswith(' ') for li in lines) or '\\' in s or all(triple):
|
| 155 |
-
rv = []
|
| 156 |
-
# add on the newlines
|
| 157 |
-
trailing = s.endswith('\n')
|
| 158 |
-
last = len(lines) - 1
|
| 159 |
-
for i, li in enumerate(lines):
|
| 160 |
-
if i != last or trailing:
|
| 161 |
-
rv.append(repr(li + '\n'))
|
| 162 |
-
else:
|
| 163 |
-
rv.append(repr(li))
|
| 164 |
-
return '(\n %s\n)' % '\n '.join(rv)
|
| 165 |
-
else:
|
| 166 |
-
rv = '\n '.join(lines)
|
| 167 |
-
if triple[0]:
|
| 168 |
-
return 'dedent("""\\\n %s""")' % rv
|
| 169 |
-
else:
|
| 170 |
-
return "dedent('''\\\n %s''')" % rv
|
| 171 |
-
|
| 172 |
-
ARCH = str(struct.calcsize('P') * 8) + "-bit"
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
# XXX: PyPy does not support hash randomization
|
| 176 |
-
HASH_RANDOMIZATION = getattr(sys.flags, 'hash_randomization', False)
|
| 177 |
-
|
| 178 |
-
_debug_tmp: list[str] = []
|
| 179 |
-
_debug_iter = 0
|
| 180 |
-
|
| 181 |
-
def debug_decorator(func):
|
| 182 |
-
"""If SYMPY_DEBUG is True, it will print a nice execution tree with
|
| 183 |
-
arguments and results of all decorated functions, else do nothing.
|
| 184 |
-
"""
|
| 185 |
-
from sympy import SYMPY_DEBUG
|
| 186 |
-
|
| 187 |
-
if not SYMPY_DEBUG:
|
| 188 |
-
return func
|
| 189 |
-
|
| 190 |
-
def maketree(f, *args, **kw):
|
| 191 |
-
global _debug_tmp, _debug_iter
|
| 192 |
-
oldtmp = _debug_tmp
|
| 193 |
-
_debug_tmp = []
|
| 194 |
-
_debug_iter += 1
|
| 195 |
-
|
| 196 |
-
def tree(subtrees):
|
| 197 |
-
def indent(s, variant=1):
|
| 198 |
-
x = s.split("\n")
|
| 199 |
-
r = "+-%s\n" % x[0]
|
| 200 |
-
for a in x[1:]:
|
| 201 |
-
if a == "":
|
| 202 |
-
continue
|
| 203 |
-
if variant == 1:
|
| 204 |
-
r += "| %s\n" % a
|
| 205 |
-
else:
|
| 206 |
-
r += " %s\n" % a
|
| 207 |
-
return r
|
| 208 |
-
if len(subtrees) == 0:
|
| 209 |
-
return ""
|
| 210 |
-
f = []
|
| 211 |
-
for a in subtrees[:-1]:
|
| 212 |
-
f.append(indent(a))
|
| 213 |
-
f.append(indent(subtrees[-1], 2))
|
| 214 |
-
return ''.join(f)
|
| 215 |
-
|
| 216 |
-
# If there is a bug and the algorithm enters an infinite loop, enable the
|
| 217 |
-
# following lines. It will print the names and parameters of all major functions
|
| 218 |
-
# that are called, *before* they are called
|
| 219 |
-
#from functools import reduce
|
| 220 |
-
#print("%s%s %s%s" % (_debug_iter, reduce(lambda x, y: x + y, \
|
| 221 |
-
# map(lambda x: '-', range(1, 2 + _debug_iter))), f.__name__, args))
|
| 222 |
-
|
| 223 |
-
r = f(*args, **kw)
|
| 224 |
-
|
| 225 |
-
_debug_iter -= 1
|
| 226 |
-
s = "%s%s = %s\n" % (f.__name__, args, r)
|
| 227 |
-
if _debug_tmp != []:
|
| 228 |
-
s += tree(_debug_tmp)
|
| 229 |
-
_debug_tmp = oldtmp
|
| 230 |
-
_debug_tmp.append(s)
|
| 231 |
-
if _debug_iter == 0:
|
| 232 |
-
print(_debug_tmp[0])
|
| 233 |
-
_debug_tmp = []
|
| 234 |
-
return r
|
| 235 |
-
|
| 236 |
-
def decorated(*args, **kwargs):
|
| 237 |
-
return maketree(func, *args, **kwargs)
|
| 238 |
-
|
| 239 |
-
return decorated
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
def debug(*args):
|
| 243 |
-
"""
|
| 244 |
-
Print ``*args`` if SYMPY_DEBUG is True, else do nothing.
|
| 245 |
-
"""
|
| 246 |
-
from sympy import SYMPY_DEBUG
|
| 247 |
-
if SYMPY_DEBUG:
|
| 248 |
-
print(*args, file=sys.stderr)
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
def debugf(string, args):
|
| 252 |
-
"""
|
| 253 |
-
Print ``string%args`` if SYMPY_DEBUG is True, else do nothing. This is
|
| 254 |
-
intended for debug messages using formatted strings.
|
| 255 |
-
"""
|
| 256 |
-
from sympy import SYMPY_DEBUG
|
| 257 |
-
if SYMPY_DEBUG:
|
| 258 |
-
print(string%args, file=sys.stderr)
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
def find_executable(executable, path=None):
|
| 262 |
-
"""Try to find 'executable' in the directories listed in 'path' (a
|
| 263 |
-
string listing directories separated by 'os.pathsep'; defaults to
|
| 264 |
-
os.environ['PATH']). Returns the complete filename or None if not
|
| 265 |
-
found
|
| 266 |
-
"""
|
| 267 |
-
from .exceptions import sympy_deprecation_warning
|
| 268 |
-
sympy_deprecation_warning(
|
| 269 |
-
"""
|
| 270 |
-
sympy.utilities.misc.find_executable() is deprecated. Use the standard
|
| 271 |
-
library shutil.which() function instead.
|
| 272 |
-
""",
|
| 273 |
-
deprecated_since_version="1.7",
|
| 274 |
-
active_deprecations_target="deprecated-find-executable",
|
| 275 |
-
)
|
| 276 |
-
if path is None:
|
| 277 |
-
path = os.environ['PATH']
|
| 278 |
-
paths = path.split(os.pathsep)
|
| 279 |
-
extlist = ['']
|
| 280 |
-
if os.name == 'os2':
|
| 281 |
-
(base, ext) = os.path.splitext(executable)
|
| 282 |
-
# executable files on OS/2 can have an arbitrary extension, but
|
| 283 |
-
# .exe is automatically appended if no dot is present in the name
|
| 284 |
-
if not ext:
|
| 285 |
-
executable = executable + ".exe"
|
| 286 |
-
elif sys.platform == 'win32':
|
| 287 |
-
pathext = os.environ['PATHEXT'].lower().split(os.pathsep)
|
| 288 |
-
(base, ext) = os.path.splitext(executable)
|
| 289 |
-
if ext.lower() not in pathext:
|
| 290 |
-
extlist = pathext
|
| 291 |
-
for ext in extlist:
|
| 292 |
-
execname = executable + ext
|
| 293 |
-
if os.path.isfile(execname):
|
| 294 |
-
return execname
|
| 295 |
-
else:
|
| 296 |
-
for p in paths:
|
| 297 |
-
f = os.path.join(p, execname)
|
| 298 |
-
if os.path.isfile(f):
|
| 299 |
-
return f
|
| 300 |
-
|
| 301 |
-
return None
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
def func_name(x, short=False):
|
| 305 |
-
"""Return function name of `x` (if defined) else the `type(x)`.
|
| 306 |
-
If short is True and there is a shorter alias for the result,
|
| 307 |
-
return the alias.
|
| 308 |
-
|
| 309 |
-
Examples
|
| 310 |
-
========
|
| 311 |
-
|
| 312 |
-
>>> from sympy.utilities.misc import func_name
|
| 313 |
-
>>> from sympy import Matrix
|
| 314 |
-
>>> from sympy.abc import x
|
| 315 |
-
>>> func_name(Matrix.eye(3))
|
| 316 |
-
'MutableDenseMatrix'
|
| 317 |
-
>>> func_name(x < 1)
|
| 318 |
-
'StrictLessThan'
|
| 319 |
-
>>> func_name(x < 1, short=True)
|
| 320 |
-
'Lt'
|
| 321 |
-
"""
|
| 322 |
-
alias = {
|
| 323 |
-
'GreaterThan': 'Ge',
|
| 324 |
-
'StrictGreaterThan': 'Gt',
|
| 325 |
-
'LessThan': 'Le',
|
| 326 |
-
'StrictLessThan': 'Lt',
|
| 327 |
-
'Equality': 'Eq',
|
| 328 |
-
'Unequality': 'Ne',
|
| 329 |
-
}
|
| 330 |
-
typ = type(x)
|
| 331 |
-
if str(typ).startswith("<type '"):
|
| 332 |
-
typ = str(typ).split("'")[1].split("'")[0]
|
| 333 |
-
elif str(typ).startswith("<class '"):
|
| 334 |
-
typ = str(typ).split("'")[1].split("'")[0]
|
| 335 |
-
rv = getattr(getattr(x, 'func', x), '__name__', typ)
|
| 336 |
-
if '.' in rv:
|
| 337 |
-
rv = rv.split('.')[-1]
|
| 338 |
-
if short:
|
| 339 |
-
rv = alias.get(rv, rv)
|
| 340 |
-
return rv
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
def _replace(reps):
|
| 344 |
-
"""Return a function that can make the replacements, given in
|
| 345 |
-
``reps``, on a string. The replacements should be given as mapping.
|
| 346 |
-
|
| 347 |
-
Examples
|
| 348 |
-
========
|
| 349 |
-
|
| 350 |
-
>>> from sympy.utilities.misc import _replace
|
| 351 |
-
>>> f = _replace(dict(foo='bar', d='t'))
|
| 352 |
-
>>> f('food')
|
| 353 |
-
'bart'
|
| 354 |
-
>>> f = _replace({})
|
| 355 |
-
>>> f('food')
|
| 356 |
-
'food'
|
| 357 |
-
"""
|
| 358 |
-
if not reps:
|
| 359 |
-
return lambda x: x
|
| 360 |
-
D = lambda match: reps[match.group(0)]
|
| 361 |
-
pattern = _re.compile("|".join(
|
| 362 |
-
[_re.escape(k) for k, v in reps.items()]), _re.MULTILINE)
|
| 363 |
-
return lambda string: pattern.sub(D, string)
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
def replace(string, *reps):
|
| 367 |
-
"""Return ``string`` with all keys in ``reps`` replaced with
|
| 368 |
-
their corresponding values, longer strings first, irrespective
|
| 369 |
-
of the order they are given. ``reps`` may be passed as tuples
|
| 370 |
-
or a single mapping.
|
| 371 |
-
|
| 372 |
-
Examples
|
| 373 |
-
========
|
| 374 |
-
|
| 375 |
-
>>> from sympy.utilities.misc import replace
|
| 376 |
-
>>> replace('foo', {'oo': 'ar', 'f': 'b'})
|
| 377 |
-
'bar'
|
| 378 |
-
>>> replace("spamham sha", ("spam", "eggs"), ("sha","md5"))
|
| 379 |
-
'eggsham md5'
|
| 380 |
-
|
| 381 |
-
There is no guarantee that a unique answer will be
|
| 382 |
-
obtained if keys in a mapping overlap (i.e. are the same
|
| 383 |
-
length and have some identical sequence at the
|
| 384 |
-
beginning/end):
|
| 385 |
-
|
| 386 |
-
>>> reps = [
|
| 387 |
-
... ('ab', 'x'),
|
| 388 |
-
... ('bc', 'y')]
|
| 389 |
-
>>> replace('abc', *reps) in ('xc', 'ay')
|
| 390 |
-
True
|
| 391 |
-
|
| 392 |
-
References
|
| 393 |
-
==========
|
| 394 |
-
|
| 395 |
-
.. [1] https://stackoverflow.com/questions/6116978/how-to-replace-multiple-substrings-of-a-string
|
| 396 |
-
"""
|
| 397 |
-
if len(reps) == 1:
|
| 398 |
-
kv = reps[0]
|
| 399 |
-
if isinstance(kv, dict):
|
| 400 |
-
reps = kv
|
| 401 |
-
else:
|
| 402 |
-
return string.replace(*kv)
|
| 403 |
-
else:
|
| 404 |
-
reps = dict(reps)
|
| 405 |
-
return _replace(reps)(string)
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
def translate(s, a, b=None, c=None):
|
| 409 |
-
"""Return ``s`` where characters have been replaced or deleted.
|
| 410 |
-
|
| 411 |
-
SYNTAX
|
| 412 |
-
======
|
| 413 |
-
|
| 414 |
-
translate(s, None, deletechars):
|
| 415 |
-
all characters in ``deletechars`` are deleted
|
| 416 |
-
translate(s, map [,deletechars]):
|
| 417 |
-
all characters in ``deletechars`` (if provided) are deleted
|
| 418 |
-
then the replacements defined by map are made; if the keys
|
| 419 |
-
of map are strings then the longer ones are handled first.
|
| 420 |
-
Multicharacter deletions should have a value of ''.
|
| 421 |
-
translate(s, oldchars, newchars, deletechars)
|
| 422 |
-
all characters in ``deletechars`` are deleted
|
| 423 |
-
then each character in ``oldchars`` is replaced with the
|
| 424 |
-
corresponding character in ``newchars``
|
| 425 |
-
|
| 426 |
-
Examples
|
| 427 |
-
========
|
| 428 |
-
|
| 429 |
-
>>> from sympy.utilities.misc import translate
|
| 430 |
-
>>> abc = 'abc'
|
| 431 |
-
>>> translate(abc, None, 'a')
|
| 432 |
-
'bc'
|
| 433 |
-
>>> translate(abc, {'a': 'x'}, 'c')
|
| 434 |
-
'xb'
|
| 435 |
-
>>> translate(abc, {'abc': 'x', 'a': 'y'})
|
| 436 |
-
'x'
|
| 437 |
-
|
| 438 |
-
>>> translate('abcd', 'ac', 'AC', 'd')
|
| 439 |
-
'AbC'
|
| 440 |
-
|
| 441 |
-
There is no guarantee that a unique answer will be
|
| 442 |
-
obtained if keys in a mapping overlap are the same
|
| 443 |
-
length and have some identical sequences at the
|
| 444 |
-
beginning/end:
|
| 445 |
-
|
| 446 |
-
>>> translate(abc, {'ab': 'x', 'bc': 'y'}) in ('xc', 'ay')
|
| 447 |
-
True
|
| 448 |
-
"""
|
| 449 |
-
|
| 450 |
-
mr = {}
|
| 451 |
-
if a is None:
|
| 452 |
-
if c is not None:
|
| 453 |
-
raise ValueError('c should be None when a=None is passed, instead got %s' % c)
|
| 454 |
-
if b is None:
|
| 455 |
-
return s
|
| 456 |
-
c = b
|
| 457 |
-
a = b = ''
|
| 458 |
-
else:
|
| 459 |
-
if isinstance(a, dict):
|
| 460 |
-
short = {}
|
| 461 |
-
for k in list(a.keys()):
|
| 462 |
-
if len(k) == 1 and len(a[k]) == 1:
|
| 463 |
-
short[k] = a.pop(k)
|
| 464 |
-
mr = a
|
| 465 |
-
c = b
|
| 466 |
-
if short:
|
| 467 |
-
a, b = [''.join(i) for i in list(zip(*short.items()))]
|
| 468 |
-
else:
|
| 469 |
-
a = b = ''
|
| 470 |
-
elif len(a) != len(b):
|
| 471 |
-
raise ValueError('oldchars and newchars have different lengths')
|
| 472 |
-
|
| 473 |
-
if c:
|
| 474 |
-
val = str.maketrans('', '', c)
|
| 475 |
-
s = s.translate(val)
|
| 476 |
-
s = replace(s, mr)
|
| 477 |
-
n = str.maketrans(a, b)
|
| 478 |
-
return s.translate(n)
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
def ordinal(num):
|
| 482 |
-
"""Return ordinal number string of num, e.g. 1 becomes 1st.
|
| 483 |
-
"""
|
| 484 |
-
# modified from https://codereview.stackexchange.com/questions/41298/producing-ordinal-numbers
|
| 485 |
-
n = as_int(num)
|
| 486 |
-
k = abs(n) % 100
|
| 487 |
-
if 11 <= k <= 13:
|
| 488 |
-
suffix = 'th'
|
| 489 |
-
elif k % 10 == 1:
|
| 490 |
-
suffix = 'st'
|
| 491 |
-
elif k % 10 == 2:
|
| 492 |
-
suffix = 'nd'
|
| 493 |
-
elif k % 10 == 3:
|
| 494 |
-
suffix = 'rd'
|
| 495 |
-
else:
|
| 496 |
-
suffix = 'th'
|
| 497 |
-
return str(n) + suffix
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
def as_int(n, strict=True):
|
| 501 |
-
"""
|
| 502 |
-
Convert the argument to a builtin integer.
|
| 503 |
-
|
| 504 |
-
The return value is guaranteed to be equal to the input. ValueError is
|
| 505 |
-
raised if the input has a non-integral value. When ``strict`` is True, this
|
| 506 |
-
uses `__index__ <https://docs.python.org/3/reference/datamodel.html#object.__index__>`_
|
| 507 |
-
and when it is False it uses ``int``.
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
Examples
|
| 511 |
-
========
|
| 512 |
-
|
| 513 |
-
>>> from sympy.utilities.misc import as_int
|
| 514 |
-
>>> from sympy import sqrt, S
|
| 515 |
-
|
| 516 |
-
The function is primarily concerned with sanitizing input for
|
| 517 |
-
functions that need to work with builtin integers, so anything that
|
| 518 |
-
is unambiguously an integer should be returned as an int:
|
| 519 |
-
|
| 520 |
-
>>> as_int(S(3))
|
| 521 |
-
3
|
| 522 |
-
|
| 523 |
-
Floats, being of limited precision, are not assumed to be exact and
|
| 524 |
-
will raise an error unless the ``strict`` flag is False. This
|
| 525 |
-
precision issue becomes apparent for large floating point numbers:
|
| 526 |
-
|
| 527 |
-
>>> big = 1e23
|
| 528 |
-
>>> type(big) is float
|
| 529 |
-
True
|
| 530 |
-
>>> big == int(big)
|
| 531 |
-
True
|
| 532 |
-
>>> as_int(big)
|
| 533 |
-
Traceback (most recent call last):
|
| 534 |
-
...
|
| 535 |
-
ValueError: ... is not an integer
|
| 536 |
-
>>> as_int(big, strict=False)
|
| 537 |
-
99999999999999991611392
|
| 538 |
-
|
| 539 |
-
Input that might be a complex representation of an integer value is
|
| 540 |
-
also rejected by default:
|
| 541 |
-
|
| 542 |
-
>>> one = sqrt(3 + 2*sqrt(2)) - sqrt(2)
|
| 543 |
-
>>> int(one) == 1
|
| 544 |
-
True
|
| 545 |
-
>>> as_int(one)
|
| 546 |
-
Traceback (most recent call last):
|
| 547 |
-
...
|
| 548 |
-
ValueError: ... is not an integer
|
| 549 |
-
"""
|
| 550 |
-
if strict:
|
| 551 |
-
try:
|
| 552 |
-
if isinstance(n, bool):
|
| 553 |
-
raise TypeError
|
| 554 |
-
return operator.index(n)
|
| 555 |
-
except TypeError:
|
| 556 |
-
raise ValueError('%s is not an integer' % (n,))
|
| 557 |
-
else:
|
| 558 |
-
try:
|
| 559 |
-
result = int(n)
|
| 560 |
-
except TypeError:
|
| 561 |
-
raise ValueError('%s is not an integer' % (n,))
|
| 562 |
-
if n - result:
|
| 563 |
-
raise ValueError('%s is not an integer' % (n,))
|
| 564 |
-
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/pkgdata.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
# This module is deprecated and will be removed.
|
| 2 |
-
|
| 3 |
-
import sys
|
| 4 |
-
import os
|
| 5 |
-
from io import StringIO
|
| 6 |
-
|
| 7 |
-
from sympy.utilities.decorator import deprecated
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
@deprecated(
|
| 11 |
-
"""
|
| 12 |
-
The sympy.utilities.pkgdata module and its get_resource function are
|
| 13 |
-
deprecated. Use the stdlib importlib.resources module instead.
|
| 14 |
-
""",
|
| 15 |
-
deprecated_since_version="1.12",
|
| 16 |
-
active_deprecations_target="pkgdata",
|
| 17 |
-
)
|
| 18 |
-
def get_resource(identifier, pkgname=__name__):
|
| 19 |
-
|
| 20 |
-
mod = sys.modules[pkgname]
|
| 21 |
-
fn = getattr(mod, '__file__', None)
|
| 22 |
-
if fn is None:
|
| 23 |
-
raise OSError("%r has no __file__!")
|
| 24 |
-
path = os.path.join(os.path.dirname(fn), identifier)
|
| 25 |
-
loader = getattr(mod, '__loader__', None)
|
| 26 |
-
if loader is not None:
|
| 27 |
-
try:
|
| 28 |
-
data = loader.get_data(path)
|
| 29 |
-
except (OSError, AttributeError):
|
| 30 |
-
pass
|
| 31 |
-
else:
|
| 32 |
-
return StringIO(data.decode('utf-8'))
|
| 33 |
-
return open(os.path.normpath(path), 'rb')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/pytest.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
.. deprecated:: 1.6
|
| 3 |
-
|
| 4 |
-
sympy.utilities.pytest has been renamed to sympy.testing.pytest.
|
| 5 |
-
"""
|
| 6 |
-
from sympy.utilities.exceptions import sympy_deprecation_warning
|
| 7 |
-
|
| 8 |
-
sympy_deprecation_warning("The sympy.utilities.pytest submodule is deprecated. Use sympy.testing.pytest instead.",
|
| 9 |
-
deprecated_since_version="1.6",
|
| 10 |
-
active_deprecations_target="deprecated-sympy-utilities-submodules")
|
| 11 |
-
|
| 12 |
-
from sympy.testing.pytest import * # noqa:F401,F403
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/randtest.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
.. deprecated:: 1.6
|
| 3 |
-
|
| 4 |
-
sympy.utilities.randtest has been renamed to sympy.core.random.
|
| 5 |
-
"""
|
| 6 |
-
from sympy.utilities.exceptions import sympy_deprecation_warning
|
| 7 |
-
|
| 8 |
-
sympy_deprecation_warning("The sympy.utilities.randtest submodule is deprecated. Use sympy.core.random instead.",
|
| 9 |
-
deprecated_since_version="1.6",
|
| 10 |
-
active_deprecations_target="deprecated-sympy-utilities-submodules")
|
| 11 |
-
|
| 12 |
-
from sympy.core.random import * # noqa:F401,F403
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/runtests.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
.. deprecated:: 1.6
|
| 3 |
-
|
| 4 |
-
sympy.utilities.runtests has been renamed to sympy.testing.runtests.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from sympy.utilities.exceptions import sympy_deprecation_warning
|
| 8 |
-
|
| 9 |
-
sympy_deprecation_warning("The sympy.utilities.runtests submodule is deprecated. Use sympy.testing.runtests instead.",
|
| 10 |
-
deprecated_since_version="1.6",
|
| 11 |
-
active_deprecations_target="deprecated-sympy-utilities-submodules")
|
| 12 |
-
|
| 13 |
-
from sympy.testing.runtests import * # noqa: F401,F403
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/source.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
This module adds several functions for interactive source code inspection.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
def get_class(lookup_view):
|
| 7 |
-
"""
|
| 8 |
-
Convert a string version of a class name to the object.
|
| 9 |
-
|
| 10 |
-
For example, get_class('sympy.core.Basic') will return
|
| 11 |
-
class Basic located in module sympy.core
|
| 12 |
-
"""
|
| 13 |
-
if isinstance(lookup_view, str):
|
| 14 |
-
mod_name, func_name = get_mod_func(lookup_view)
|
| 15 |
-
if func_name != '':
|
| 16 |
-
lookup_view = getattr(
|
| 17 |
-
__import__(mod_name, {}, {}, ['*']), func_name)
|
| 18 |
-
if not callable(lookup_view):
|
| 19 |
-
raise AttributeError(
|
| 20 |
-
"'%s.%s' is not a callable." % (mod_name, func_name))
|
| 21 |
-
return lookup_view
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def get_mod_func(callback):
|
| 25 |
-
"""
|
| 26 |
-
splits the string path to a class into a string path to the module
|
| 27 |
-
and the name of the class.
|
| 28 |
-
|
| 29 |
-
Examples
|
| 30 |
-
========
|
| 31 |
-
|
| 32 |
-
>>> from sympy.utilities.source import get_mod_func
|
| 33 |
-
>>> get_mod_func('sympy.core.basic.Basic')
|
| 34 |
-
('sympy.core.basic', 'Basic')
|
| 35 |
-
|
| 36 |
-
"""
|
| 37 |
-
dot = callback.rfind('.')
|
| 38 |
-
if dot == -1:
|
| 39 |
-
return callback, ''
|
| 40 |
-
return callback[:dot], callback[dot + 1:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/__init__.py
DELETED
|
File without changes
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_autowrap.py
DELETED
|
@@ -1,467 +0,0 @@
|
|
| 1 |
-
# Tests that require installed backends go into
|
| 2 |
-
# sympy/test_external/test_autowrap
|
| 3 |
-
|
| 4 |
-
import os
|
| 5 |
-
import tempfile
|
| 6 |
-
import shutil
|
| 7 |
-
from io import StringIO
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
from sympy.core import symbols, Eq
|
| 11 |
-
from sympy.utilities.autowrap import (autowrap, binary_function,
|
| 12 |
-
CythonCodeWrapper, UfuncifyCodeWrapper, CodeWrapper)
|
| 13 |
-
from sympy.utilities.codegen import (
|
| 14 |
-
CCodeGen, C99CodeGen, CodeGenArgumentListError, make_routine
|
| 15 |
-
)
|
| 16 |
-
from sympy.testing.pytest import raises
|
| 17 |
-
from sympy.testing.tmpfiles import TmpFileManager
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def get_string(dump_fn, routines, prefix="file", **kwargs):
|
| 21 |
-
"""Wrapper for dump_fn. dump_fn writes its results to a stream object and
|
| 22 |
-
this wrapper returns the contents of that stream as a string. This
|
| 23 |
-
auxiliary function is used by many tests below.
|
| 24 |
-
|
| 25 |
-
The header and the empty lines are not generator to facilitate the
|
| 26 |
-
testing of the output.
|
| 27 |
-
"""
|
| 28 |
-
output = StringIO()
|
| 29 |
-
dump_fn(routines, output, prefix, **kwargs)
|
| 30 |
-
source = output.getvalue()
|
| 31 |
-
output.close()
|
| 32 |
-
return source
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def test_cython_wrapper_scalar_function():
|
| 36 |
-
x, y, z = symbols('x,y,z')
|
| 37 |
-
expr = (x + y)*z
|
| 38 |
-
routine = make_routine("test", expr)
|
| 39 |
-
code_gen = CythonCodeWrapper(CCodeGen())
|
| 40 |
-
source = get_string(code_gen.dump_pyx, [routine])
|
| 41 |
-
|
| 42 |
-
expected = (
|
| 43 |
-
"cdef extern from 'file.h':\n"
|
| 44 |
-
" double test(double x, double y, double z)\n"
|
| 45 |
-
"\n"
|
| 46 |
-
"def test_c(double x, double y, double z):\n"
|
| 47 |
-
"\n"
|
| 48 |
-
" return test(x, y, z)")
|
| 49 |
-
assert source == expected
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def test_cython_wrapper_outarg():
|
| 53 |
-
from sympy.core.relational import Equality
|
| 54 |
-
x, y, z = symbols('x,y,z')
|
| 55 |
-
code_gen = CythonCodeWrapper(C99CodeGen())
|
| 56 |
-
|
| 57 |
-
routine = make_routine("test", Equality(z, x + y))
|
| 58 |
-
source = get_string(code_gen.dump_pyx, [routine])
|
| 59 |
-
expected = (
|
| 60 |
-
"cdef extern from 'file.h':\n"
|
| 61 |
-
" void test(double x, double y, double *z)\n"
|
| 62 |
-
"\n"
|
| 63 |
-
"def test_c(double x, double y):\n"
|
| 64 |
-
"\n"
|
| 65 |
-
" cdef double z = 0\n"
|
| 66 |
-
" test(x, y, &z)\n"
|
| 67 |
-
" return z")
|
| 68 |
-
assert source == expected
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def test_cython_wrapper_inoutarg():
|
| 72 |
-
from sympy.core.relational import Equality
|
| 73 |
-
x, y, z = symbols('x,y,z')
|
| 74 |
-
code_gen = CythonCodeWrapper(C99CodeGen())
|
| 75 |
-
routine = make_routine("test", Equality(z, x + y + z))
|
| 76 |
-
source = get_string(code_gen.dump_pyx, [routine])
|
| 77 |
-
expected = (
|
| 78 |
-
"cdef extern from 'file.h':\n"
|
| 79 |
-
" void test(double x, double y, double *z)\n"
|
| 80 |
-
"\n"
|
| 81 |
-
"def test_c(double x, double y, double z):\n"
|
| 82 |
-
"\n"
|
| 83 |
-
" test(x, y, &z)\n"
|
| 84 |
-
" return z")
|
| 85 |
-
assert source == expected
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def test_cython_wrapper_compile_flags():
|
| 89 |
-
from sympy.core.relational import Equality
|
| 90 |
-
x, y, z = symbols('x,y,z')
|
| 91 |
-
routine = make_routine("test", Equality(z, x + y))
|
| 92 |
-
|
| 93 |
-
code_gen = CythonCodeWrapper(CCodeGen())
|
| 94 |
-
|
| 95 |
-
expected = """\
|
| 96 |
-
from setuptools import setup
|
| 97 |
-
from setuptools import Extension
|
| 98 |
-
from Cython.Build import cythonize
|
| 99 |
-
cy_opts = {'compiler_directives': {'language_level': '3'}}
|
| 100 |
-
|
| 101 |
-
ext_mods = [Extension(
|
| 102 |
-
'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
|
| 103 |
-
include_dirs=[],
|
| 104 |
-
library_dirs=[],
|
| 105 |
-
libraries=[],
|
| 106 |
-
extra_compile_args=['-std=c99'],
|
| 107 |
-
extra_link_args=[]
|
| 108 |
-
)]
|
| 109 |
-
setup(ext_modules=cythonize(ext_mods, **cy_opts))
|
| 110 |
-
""" % {'num': CodeWrapper._module_counter}
|
| 111 |
-
|
| 112 |
-
temp_dir = tempfile.mkdtemp()
|
| 113 |
-
TmpFileManager.tmp_folder(temp_dir)
|
| 114 |
-
setup_file_path = os.path.join(temp_dir, 'setup.py')
|
| 115 |
-
|
| 116 |
-
code_gen._prepare_files(routine, build_dir=temp_dir)
|
| 117 |
-
setup_text = Path(setup_file_path).read_text()
|
| 118 |
-
assert setup_text == expected
|
| 119 |
-
|
| 120 |
-
code_gen = CythonCodeWrapper(CCodeGen(),
|
| 121 |
-
include_dirs=['/usr/local/include', '/opt/booger/include'],
|
| 122 |
-
library_dirs=['/user/local/lib'],
|
| 123 |
-
libraries=['thelib', 'nilib'],
|
| 124 |
-
extra_compile_args=['-slow-math'],
|
| 125 |
-
extra_link_args=['-lswamp', '-ltrident'],
|
| 126 |
-
cythonize_options={'compiler_directives': {'boundscheck': False}}
|
| 127 |
-
)
|
| 128 |
-
expected = """\
|
| 129 |
-
from setuptools import setup
|
| 130 |
-
from setuptools import Extension
|
| 131 |
-
from Cython.Build import cythonize
|
| 132 |
-
cy_opts = {'compiler_directives': {'boundscheck': False}}
|
| 133 |
-
|
| 134 |
-
ext_mods = [Extension(
|
| 135 |
-
'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
|
| 136 |
-
include_dirs=['/usr/local/include', '/opt/booger/include'],
|
| 137 |
-
library_dirs=['/user/local/lib'],
|
| 138 |
-
libraries=['thelib', 'nilib'],
|
| 139 |
-
extra_compile_args=['-slow-math', '-std=c99'],
|
| 140 |
-
extra_link_args=['-lswamp', '-ltrident']
|
| 141 |
-
)]
|
| 142 |
-
setup(ext_modules=cythonize(ext_mods, **cy_opts))
|
| 143 |
-
""" % {'num': CodeWrapper._module_counter}
|
| 144 |
-
|
| 145 |
-
code_gen._prepare_files(routine, build_dir=temp_dir)
|
| 146 |
-
setup_text = Path(setup_file_path).read_text()
|
| 147 |
-
assert setup_text == expected
|
| 148 |
-
|
| 149 |
-
expected = """\
|
| 150 |
-
from setuptools import setup
|
| 151 |
-
from setuptools import Extension
|
| 152 |
-
from Cython.Build import cythonize
|
| 153 |
-
cy_opts = {'compiler_directives': {'boundscheck': False}}
|
| 154 |
-
import numpy as np
|
| 155 |
-
|
| 156 |
-
ext_mods = [Extension(
|
| 157 |
-
'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
|
| 158 |
-
include_dirs=['/usr/local/include', '/opt/booger/include', np.get_include()],
|
| 159 |
-
library_dirs=['/user/local/lib'],
|
| 160 |
-
libraries=['thelib', 'nilib'],
|
| 161 |
-
extra_compile_args=['-slow-math', '-std=c99'],
|
| 162 |
-
extra_link_args=['-lswamp', '-ltrident']
|
| 163 |
-
)]
|
| 164 |
-
setup(ext_modules=cythonize(ext_mods, **cy_opts))
|
| 165 |
-
""" % {'num': CodeWrapper._module_counter}
|
| 166 |
-
|
| 167 |
-
code_gen._need_numpy = True
|
| 168 |
-
code_gen._prepare_files(routine, build_dir=temp_dir)
|
| 169 |
-
setup_text = Path(setup_file_path).read_text()
|
| 170 |
-
assert setup_text == expected
|
| 171 |
-
|
| 172 |
-
TmpFileManager.cleanup()
|
| 173 |
-
|
| 174 |
-
def test_cython_wrapper_unique_dummyvars():
|
| 175 |
-
from sympy.core.relational import Equality
|
| 176 |
-
from sympy.core.symbol import Dummy
|
| 177 |
-
x, y, z = Dummy('x'), Dummy('y'), Dummy('z')
|
| 178 |
-
x_id, y_id, z_id = [str(d.dummy_index) for d in [x, y, z]]
|
| 179 |
-
expr = Equality(z, x + y)
|
| 180 |
-
routine = make_routine("test", expr)
|
| 181 |
-
code_gen = CythonCodeWrapper(CCodeGen())
|
| 182 |
-
source = get_string(code_gen.dump_pyx, [routine])
|
| 183 |
-
expected_template = (
|
| 184 |
-
"cdef extern from 'file.h':\n"
|
| 185 |
-
" void test(double x_{x_id}, double y_{y_id}, double *z_{z_id})\n"
|
| 186 |
-
"\n"
|
| 187 |
-
"def test_c(double x_{x_id}, double y_{y_id}):\n"
|
| 188 |
-
"\n"
|
| 189 |
-
" cdef double z_{z_id} = 0\n"
|
| 190 |
-
" test(x_{x_id}, y_{y_id}, &z_{z_id})\n"
|
| 191 |
-
" return z_{z_id}")
|
| 192 |
-
expected = expected_template.format(x_id=x_id, y_id=y_id, z_id=z_id)
|
| 193 |
-
assert source == expected
|
| 194 |
-
|
| 195 |
-
def test_autowrap_dummy():
|
| 196 |
-
x, y, z = symbols('x y z')
|
| 197 |
-
|
| 198 |
-
# Uses DummyWrapper to test that codegen works as expected
|
| 199 |
-
|
| 200 |
-
f = autowrap(x + y, backend='dummy')
|
| 201 |
-
assert f() == str(x + y)
|
| 202 |
-
assert f.args == "x, y"
|
| 203 |
-
assert f.returns == "nameless"
|
| 204 |
-
f = autowrap(Eq(z, x + y), backend='dummy')
|
| 205 |
-
assert f() == str(x + y)
|
| 206 |
-
assert f.args == "x, y"
|
| 207 |
-
assert f.returns == "z"
|
| 208 |
-
f = autowrap(Eq(z, x + y + z), backend='dummy')
|
| 209 |
-
assert f() == str(x + y + z)
|
| 210 |
-
assert f.args == "x, y, z"
|
| 211 |
-
assert f.returns == "z"
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
def test_autowrap_args():
|
| 215 |
-
x, y, z = symbols('x y z')
|
| 216 |
-
|
| 217 |
-
raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y),
|
| 218 |
-
backend='dummy', args=[x]))
|
| 219 |
-
f = autowrap(Eq(z, x + y), backend='dummy', args=[y, x])
|
| 220 |
-
assert f() == str(x + y)
|
| 221 |
-
assert f.args == "y, x"
|
| 222 |
-
assert f.returns == "z"
|
| 223 |
-
|
| 224 |
-
raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y + z),
|
| 225 |
-
backend='dummy', args=[x, y]))
|
| 226 |
-
f = autowrap(Eq(z, x + y + z), backend='dummy', args=[y, x, z])
|
| 227 |
-
assert f() == str(x + y + z)
|
| 228 |
-
assert f.args == "y, x, z"
|
| 229 |
-
assert f.returns == "z"
|
| 230 |
-
|
| 231 |
-
f = autowrap(Eq(z, x + y + z), backend='dummy', args=(y, x, z))
|
| 232 |
-
assert f() == str(x + y + z)
|
| 233 |
-
assert f.args == "y, x, z"
|
| 234 |
-
assert f.returns == "z"
|
| 235 |
-
|
| 236 |
-
def test_autowrap_store_files():
|
| 237 |
-
x, y = symbols('x y')
|
| 238 |
-
tmp = tempfile.mkdtemp()
|
| 239 |
-
TmpFileManager.tmp_folder(tmp)
|
| 240 |
-
|
| 241 |
-
f = autowrap(x + y, backend='dummy', tempdir=tmp)
|
| 242 |
-
assert f() == str(x + y)
|
| 243 |
-
assert os.access(tmp, os.F_OK)
|
| 244 |
-
|
| 245 |
-
TmpFileManager.cleanup()
|
| 246 |
-
|
| 247 |
-
def test_autowrap_store_files_issue_gh12939():
|
| 248 |
-
x, y = symbols('x y')
|
| 249 |
-
tmp = './tmp'
|
| 250 |
-
saved_cwd = os.getcwd()
|
| 251 |
-
temp_cwd = tempfile.mkdtemp()
|
| 252 |
-
try:
|
| 253 |
-
os.chdir(temp_cwd)
|
| 254 |
-
f = autowrap(x + y, backend='dummy', tempdir=tmp)
|
| 255 |
-
assert f() == str(x + y)
|
| 256 |
-
assert os.access(tmp, os.F_OK)
|
| 257 |
-
finally:
|
| 258 |
-
os.chdir(saved_cwd)
|
| 259 |
-
shutil.rmtree(temp_cwd)
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
def test_binary_function():
|
| 263 |
-
x, y = symbols('x y')
|
| 264 |
-
f = binary_function('f', x + y, backend='dummy')
|
| 265 |
-
assert f._imp_() == str(x + y)
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
def test_ufuncify_source():
|
| 269 |
-
x, y, z = symbols('x,y,z')
|
| 270 |
-
code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
|
| 271 |
-
routine = make_routine("test", x + y + z)
|
| 272 |
-
source = get_string(code_wrapper.dump_c, [routine])
|
| 273 |
-
expected = """\
|
| 274 |
-
#include "Python.h"
|
| 275 |
-
#include "math.h"
|
| 276 |
-
#include "numpy/ndarraytypes.h"
|
| 277 |
-
#include "numpy/ufuncobject.h"
|
| 278 |
-
#include "numpy/halffloat.h"
|
| 279 |
-
#include "file.h"
|
| 280 |
-
|
| 281 |
-
static PyMethodDef wrapper_module_%(num)sMethods[] = {
|
| 282 |
-
{NULL, NULL, 0, NULL}
|
| 283 |
-
};
|
| 284 |
-
|
| 285 |
-
#ifdef NPY_1_19_API_VERSION
|
| 286 |
-
static void test_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
|
| 287 |
-
#else
|
| 288 |
-
static void test_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
|
| 289 |
-
#endif
|
| 290 |
-
{
|
| 291 |
-
npy_intp i;
|
| 292 |
-
npy_intp n = dimensions[0];
|
| 293 |
-
char *in0 = args[0];
|
| 294 |
-
char *in1 = args[1];
|
| 295 |
-
char *in2 = args[2];
|
| 296 |
-
char *out0 = args[3];
|
| 297 |
-
npy_intp in0_step = steps[0];
|
| 298 |
-
npy_intp in1_step = steps[1];
|
| 299 |
-
npy_intp in2_step = steps[2];
|
| 300 |
-
npy_intp out0_step = steps[3];
|
| 301 |
-
for (i = 0; i < n; i++) {
|
| 302 |
-
*((double *)out0) = test(*(double *)in0, *(double *)in1, *(double *)in2);
|
| 303 |
-
in0 += in0_step;
|
| 304 |
-
in1 += in1_step;
|
| 305 |
-
in2 += in2_step;
|
| 306 |
-
out0 += out0_step;
|
| 307 |
-
}
|
| 308 |
-
}
|
| 309 |
-
PyUFuncGenericFunction test_funcs[1] = {&test_ufunc};
|
| 310 |
-
static char test_types[4] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
|
| 311 |
-
static void *test_data[1] = {NULL};
|
| 312 |
-
|
| 313 |
-
#if PY_VERSION_HEX >= 0x03000000
|
| 314 |
-
static struct PyModuleDef moduledef = {
|
| 315 |
-
PyModuleDef_HEAD_INIT,
|
| 316 |
-
"wrapper_module_%(num)s",
|
| 317 |
-
NULL,
|
| 318 |
-
-1,
|
| 319 |
-
wrapper_module_%(num)sMethods,
|
| 320 |
-
NULL,
|
| 321 |
-
NULL,
|
| 322 |
-
NULL,
|
| 323 |
-
NULL
|
| 324 |
-
};
|
| 325 |
-
|
| 326 |
-
PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
|
| 327 |
-
{
|
| 328 |
-
PyObject *m, *d;
|
| 329 |
-
PyObject *ufunc0;
|
| 330 |
-
m = PyModule_Create(&moduledef);
|
| 331 |
-
if (!m) {
|
| 332 |
-
return NULL;
|
| 333 |
-
}
|
| 334 |
-
import_array();
|
| 335 |
-
import_umath();
|
| 336 |
-
d = PyModule_GetDict(m);
|
| 337 |
-
ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
|
| 338 |
-
PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
|
| 339 |
-
PyDict_SetItemString(d, "test", ufunc0);
|
| 340 |
-
Py_DECREF(ufunc0);
|
| 341 |
-
return m;
|
| 342 |
-
}
|
| 343 |
-
#else
|
| 344 |
-
PyMODINIT_FUNC initwrapper_module_%(num)s(void)
|
| 345 |
-
{
|
| 346 |
-
PyObject *m, *d;
|
| 347 |
-
PyObject *ufunc0;
|
| 348 |
-
m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
|
| 349 |
-
if (m == NULL) {
|
| 350 |
-
return;
|
| 351 |
-
}
|
| 352 |
-
import_array();
|
| 353 |
-
import_umath();
|
| 354 |
-
d = PyModule_GetDict(m);
|
| 355 |
-
ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
|
| 356 |
-
PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
|
| 357 |
-
PyDict_SetItemString(d, "test", ufunc0);
|
| 358 |
-
Py_DECREF(ufunc0);
|
| 359 |
-
}
|
| 360 |
-
#endif""" % {'num': CodeWrapper._module_counter}
|
| 361 |
-
assert source == expected
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
def test_ufuncify_source_multioutput():
|
| 365 |
-
x, y, z = symbols('x,y,z')
|
| 366 |
-
var_symbols = (x, y, z)
|
| 367 |
-
expr = x + y**3 + 10*z**2
|
| 368 |
-
code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
|
| 369 |
-
routines = [make_routine("func{}".format(i), expr.diff(var_symbols[i]), var_symbols) for i in range(len(var_symbols))]
|
| 370 |
-
source = get_string(code_wrapper.dump_c, routines, funcname='multitest')
|
| 371 |
-
expected = """\
|
| 372 |
-
#include "Python.h"
|
| 373 |
-
#include "math.h"
|
| 374 |
-
#include "numpy/ndarraytypes.h"
|
| 375 |
-
#include "numpy/ufuncobject.h"
|
| 376 |
-
#include "numpy/halffloat.h"
|
| 377 |
-
#include "file.h"
|
| 378 |
-
|
| 379 |
-
static PyMethodDef wrapper_module_%(num)sMethods[] = {
|
| 380 |
-
{NULL, NULL, 0, NULL}
|
| 381 |
-
};
|
| 382 |
-
|
| 383 |
-
#ifdef NPY_1_19_API_VERSION
|
| 384 |
-
static void multitest_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
|
| 385 |
-
#else
|
| 386 |
-
static void multitest_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
|
| 387 |
-
#endif
|
| 388 |
-
{
|
| 389 |
-
npy_intp i;
|
| 390 |
-
npy_intp n = dimensions[0];
|
| 391 |
-
char *in0 = args[0];
|
| 392 |
-
char *in1 = args[1];
|
| 393 |
-
char *in2 = args[2];
|
| 394 |
-
char *out0 = args[3];
|
| 395 |
-
char *out1 = args[4];
|
| 396 |
-
char *out2 = args[5];
|
| 397 |
-
npy_intp in0_step = steps[0];
|
| 398 |
-
npy_intp in1_step = steps[1];
|
| 399 |
-
npy_intp in2_step = steps[2];
|
| 400 |
-
npy_intp out0_step = steps[3];
|
| 401 |
-
npy_intp out1_step = steps[4];
|
| 402 |
-
npy_intp out2_step = steps[5];
|
| 403 |
-
for (i = 0; i < n; i++) {
|
| 404 |
-
*((double *)out0) = func0(*(double *)in0, *(double *)in1, *(double *)in2);
|
| 405 |
-
*((double *)out1) = func1(*(double *)in0, *(double *)in1, *(double *)in2);
|
| 406 |
-
*((double *)out2) = func2(*(double *)in0, *(double *)in1, *(double *)in2);
|
| 407 |
-
in0 += in0_step;
|
| 408 |
-
in1 += in1_step;
|
| 409 |
-
in2 += in2_step;
|
| 410 |
-
out0 += out0_step;
|
| 411 |
-
out1 += out1_step;
|
| 412 |
-
out2 += out2_step;
|
| 413 |
-
}
|
| 414 |
-
}
|
| 415 |
-
PyUFuncGenericFunction multitest_funcs[1] = {&multitest_ufunc};
|
| 416 |
-
static char multitest_types[6] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
|
| 417 |
-
static void *multitest_data[1] = {NULL};
|
| 418 |
-
|
| 419 |
-
#if PY_VERSION_HEX >= 0x03000000
|
| 420 |
-
static struct PyModuleDef moduledef = {
|
| 421 |
-
PyModuleDef_HEAD_INIT,
|
| 422 |
-
"wrapper_module_%(num)s",
|
| 423 |
-
NULL,
|
| 424 |
-
-1,
|
| 425 |
-
wrapper_module_%(num)sMethods,
|
| 426 |
-
NULL,
|
| 427 |
-
NULL,
|
| 428 |
-
NULL,
|
| 429 |
-
NULL
|
| 430 |
-
};
|
| 431 |
-
|
| 432 |
-
PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
|
| 433 |
-
{
|
| 434 |
-
PyObject *m, *d;
|
| 435 |
-
PyObject *ufunc0;
|
| 436 |
-
m = PyModule_Create(&moduledef);
|
| 437 |
-
if (!m) {
|
| 438 |
-
return NULL;
|
| 439 |
-
}
|
| 440 |
-
import_array();
|
| 441 |
-
import_umath();
|
| 442 |
-
d = PyModule_GetDict(m);
|
| 443 |
-
ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
|
| 444 |
-
PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
|
| 445 |
-
PyDict_SetItemString(d, "multitest", ufunc0);
|
| 446 |
-
Py_DECREF(ufunc0);
|
| 447 |
-
return m;
|
| 448 |
-
}
|
| 449 |
-
#else
|
| 450 |
-
PyMODINIT_FUNC initwrapper_module_%(num)s(void)
|
| 451 |
-
{
|
| 452 |
-
PyObject *m, *d;
|
| 453 |
-
PyObject *ufunc0;
|
| 454 |
-
m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
|
| 455 |
-
if (m == NULL) {
|
| 456 |
-
return;
|
| 457 |
-
}
|
| 458 |
-
import_array();
|
| 459 |
-
import_umath();
|
| 460 |
-
d = PyModule_GetDict(m);
|
| 461 |
-
ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
|
| 462 |
-
PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
|
| 463 |
-
PyDict_SetItemString(d, "multitest", ufunc0);
|
| 464 |
-
Py_DECREF(ufunc0);
|
| 465 |
-
}
|
| 466 |
-
#endif""" % {'num': CodeWrapper._module_counter}
|
| 467 |
-
assert source == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen.py
DELETED
|
@@ -1,1632 +0,0 @@
|
|
| 1 |
-
from io import StringIO
|
| 2 |
-
|
| 3 |
-
from sympy.core import symbols, Eq, pi, Catalan, Lambda, Dummy
|
| 4 |
-
from sympy.core.relational import Equality
|
| 5 |
-
from sympy.core.symbol import Symbol
|
| 6 |
-
from sympy.functions.special.error_functions import erf
|
| 7 |
-
from sympy.integrals.integrals import Integral
|
| 8 |
-
from sympy.matrices import Matrix, MatrixSymbol
|
| 9 |
-
from sympy.utilities.codegen import (
|
| 10 |
-
codegen, make_routine, CCodeGen, C89CodeGen, C99CodeGen, InputArgument,
|
| 11 |
-
CodeGenError, FCodeGen, CodeGenArgumentListError, OutputArgument,
|
| 12 |
-
InOutArgument)
|
| 13 |
-
from sympy.testing.pytest import raises
|
| 14 |
-
from sympy.utilities.lambdify import implemented_function
|
| 15 |
-
|
| 16 |
-
#FIXME: Fails due to circular import in with core
|
| 17 |
-
# from sympy import codegen
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def get_string(dump_fn, routines, prefix="file", header=False, empty=False):
|
| 21 |
-
"""Wrapper for dump_fn. dump_fn writes its results to a stream object and
|
| 22 |
-
this wrapper returns the contents of that stream as a string. This
|
| 23 |
-
auxiliary function is used by many tests below.
|
| 24 |
-
|
| 25 |
-
The header and the empty lines are not generated to facilitate the
|
| 26 |
-
testing of the output.
|
| 27 |
-
"""
|
| 28 |
-
output = StringIO()
|
| 29 |
-
dump_fn(routines, output, prefix, header, empty)
|
| 30 |
-
source = output.getvalue()
|
| 31 |
-
output.close()
|
| 32 |
-
return source
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def test_Routine_argument_order():
|
| 36 |
-
a, x, y, z = symbols('a x y z')
|
| 37 |
-
expr = (x + y)*z
|
| 38 |
-
raises(CodeGenArgumentListError, lambda: make_routine("test", expr,
|
| 39 |
-
argument_sequence=[z, x]))
|
| 40 |
-
raises(CodeGenArgumentListError, lambda: make_routine("test", Eq(a,
|
| 41 |
-
expr), argument_sequence=[z, x, y]))
|
| 42 |
-
r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
|
| 43 |
-
assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
|
| 44 |
-
assert [ type(arg) for arg in r.arguments ] == [
|
| 45 |
-
InputArgument, InputArgument, OutputArgument, InputArgument ]
|
| 46 |
-
r = make_routine('test', Eq(z, expr), argument_sequence=[z, x, y])
|
| 47 |
-
assert [ type(arg) for arg in r.arguments ] == [
|
| 48 |
-
InOutArgument, InputArgument, InputArgument ]
|
| 49 |
-
|
| 50 |
-
from sympy.tensor import IndexedBase, Idx
|
| 51 |
-
A, B = map(IndexedBase, ['A', 'B'])
|
| 52 |
-
m = symbols('m', integer=True)
|
| 53 |
-
i = Idx('i', m)
|
| 54 |
-
r = make_routine('test', Eq(A[i], B[i]), argument_sequence=[B, A, m])
|
| 55 |
-
assert [ arg.name for arg in r.arguments ] == [B.label, A.label, m]
|
| 56 |
-
|
| 57 |
-
expr = Integral(x*y*z, (x, 1, 2), (y, 1, 3))
|
| 58 |
-
r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
|
| 59 |
-
assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def test_empty_c_code():
|
| 63 |
-
code_gen = C89CodeGen()
|
| 64 |
-
source = get_string(code_gen.dump_c, [])
|
| 65 |
-
assert source == "#include \"file.h\"\n#include <math.h>\n"
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def test_empty_c_code_with_comment():
|
| 69 |
-
code_gen = C89CodeGen()
|
| 70 |
-
source = get_string(code_gen.dump_c, [], header=True)
|
| 71 |
-
assert source[:82] == (
|
| 72 |
-
"/******************************************************************************\n *"
|
| 73 |
-
)
|
| 74 |
-
# " Code generated with SymPy 0.7.2-git "
|
| 75 |
-
assert source[158:] == ( "*\n"
|
| 76 |
-
" * *\n"
|
| 77 |
-
" * See http://www.sympy.org/ for more information. *\n"
|
| 78 |
-
" * *\n"
|
| 79 |
-
" * This file is part of 'project' *\n"
|
| 80 |
-
" ******************************************************************************/\n"
|
| 81 |
-
"#include \"file.h\"\n"
|
| 82 |
-
"#include <math.h>\n"
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def test_empty_c_header():
|
| 87 |
-
code_gen = C99CodeGen()
|
| 88 |
-
source = get_string(code_gen.dump_h, [])
|
| 89 |
-
assert source == "#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n#endif\n"
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def test_simple_c_code():
|
| 93 |
-
x, y, z = symbols('x,y,z')
|
| 94 |
-
expr = (x + y)*z
|
| 95 |
-
routine = make_routine("test", expr)
|
| 96 |
-
code_gen = C89CodeGen()
|
| 97 |
-
source = get_string(code_gen.dump_c, [routine])
|
| 98 |
-
expected = (
|
| 99 |
-
"#include \"file.h\"\n"
|
| 100 |
-
"#include <math.h>\n"
|
| 101 |
-
"double test(double x, double y, double z) {\n"
|
| 102 |
-
" double test_result;\n"
|
| 103 |
-
" test_result = z*(x + y);\n"
|
| 104 |
-
" return test_result;\n"
|
| 105 |
-
"}\n"
|
| 106 |
-
)
|
| 107 |
-
assert source == expected
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def test_c_code_reserved_words():
|
| 111 |
-
x, y, z = symbols('if, typedef, while')
|
| 112 |
-
expr = (x + y) * z
|
| 113 |
-
routine = make_routine("test", expr)
|
| 114 |
-
code_gen = C99CodeGen()
|
| 115 |
-
source = get_string(code_gen.dump_c, [routine])
|
| 116 |
-
expected = (
|
| 117 |
-
"#include \"file.h\"\n"
|
| 118 |
-
"#include <math.h>\n"
|
| 119 |
-
"double test(double if_, double typedef_, double while_) {\n"
|
| 120 |
-
" double test_result;\n"
|
| 121 |
-
" test_result = while_*(if_ + typedef_);\n"
|
| 122 |
-
" return test_result;\n"
|
| 123 |
-
"}\n"
|
| 124 |
-
)
|
| 125 |
-
assert source == expected
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def test_numbersymbol_c_code():
|
| 129 |
-
routine = make_routine("test", pi**Catalan)
|
| 130 |
-
code_gen = C89CodeGen()
|
| 131 |
-
source = get_string(code_gen.dump_c, [routine])
|
| 132 |
-
expected = (
|
| 133 |
-
"#include \"file.h\"\n"
|
| 134 |
-
"#include <math.h>\n"
|
| 135 |
-
"double test() {\n"
|
| 136 |
-
" double test_result;\n"
|
| 137 |
-
" double const Catalan = %s;\n"
|
| 138 |
-
" test_result = pow(M_PI, Catalan);\n"
|
| 139 |
-
" return test_result;\n"
|
| 140 |
-
"}\n"
|
| 141 |
-
) % Catalan.evalf(17)
|
| 142 |
-
assert source == expected
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
def test_c_code_argument_order():
|
| 146 |
-
x, y, z = symbols('x,y,z')
|
| 147 |
-
expr = x + y
|
| 148 |
-
routine = make_routine("test", expr, argument_sequence=[z, x, y])
|
| 149 |
-
code_gen = C89CodeGen()
|
| 150 |
-
source = get_string(code_gen.dump_c, [routine])
|
| 151 |
-
expected = (
|
| 152 |
-
"#include \"file.h\"\n"
|
| 153 |
-
"#include <math.h>\n"
|
| 154 |
-
"double test(double z, double x, double y) {\n"
|
| 155 |
-
" double test_result;\n"
|
| 156 |
-
" test_result = x + y;\n"
|
| 157 |
-
" return test_result;\n"
|
| 158 |
-
"}\n"
|
| 159 |
-
)
|
| 160 |
-
assert source == expected
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
def test_simple_c_header():
|
| 164 |
-
x, y, z = symbols('x,y,z')
|
| 165 |
-
expr = (x + y)*z
|
| 166 |
-
routine = make_routine("test", expr)
|
| 167 |
-
code_gen = C89CodeGen()
|
| 168 |
-
source = get_string(code_gen.dump_h, [routine])
|
| 169 |
-
expected = (
|
| 170 |
-
"#ifndef PROJECT__FILE__H\n"
|
| 171 |
-
"#define PROJECT__FILE__H\n"
|
| 172 |
-
"double test(double x, double y, double z);\n"
|
| 173 |
-
"#endif\n"
|
| 174 |
-
)
|
| 175 |
-
assert source == expected
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def test_simple_c_codegen():
|
| 179 |
-
x, y, z = symbols('x,y,z')
|
| 180 |
-
expr = (x + y)*z
|
| 181 |
-
expected = [
|
| 182 |
-
("file.c",
|
| 183 |
-
"#include \"file.h\"\n"
|
| 184 |
-
"#include <math.h>\n"
|
| 185 |
-
"double test(double x, double y, double z) {\n"
|
| 186 |
-
" double test_result;\n"
|
| 187 |
-
" test_result = z*(x + y);\n"
|
| 188 |
-
" return test_result;\n"
|
| 189 |
-
"}\n"),
|
| 190 |
-
("file.h",
|
| 191 |
-
"#ifndef PROJECT__FILE__H\n"
|
| 192 |
-
"#define PROJECT__FILE__H\n"
|
| 193 |
-
"double test(double x, double y, double z);\n"
|
| 194 |
-
"#endif\n")
|
| 195 |
-
]
|
| 196 |
-
result = codegen(("test", expr), "C", "file", header=False, empty=False)
|
| 197 |
-
assert result == expected
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
def test_multiple_results_c():
|
| 201 |
-
x, y, z = symbols('x,y,z')
|
| 202 |
-
expr1 = (x + y)*z
|
| 203 |
-
expr2 = (x - y)*z
|
| 204 |
-
routine = make_routine(
|
| 205 |
-
"test",
|
| 206 |
-
[expr1, expr2]
|
| 207 |
-
)
|
| 208 |
-
code_gen = C99CodeGen()
|
| 209 |
-
raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
def test_no_results_c():
|
| 213 |
-
raises(ValueError, lambda: make_routine("test", []))
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
def test_ansi_math1_codegen():
|
| 217 |
-
# not included: log10
|
| 218 |
-
from sympy.functions.elementary.complexes import Abs
|
| 219 |
-
from sympy.functions.elementary.exponential import log
|
| 220 |
-
from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
|
| 221 |
-
from sympy.functions.elementary.integers import (ceiling, floor)
|
| 222 |
-
from sympy.functions.elementary.miscellaneous import sqrt
|
| 223 |
-
from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
|
| 224 |
-
x = symbols('x')
|
| 225 |
-
name_expr = [
|
| 226 |
-
("test_fabs", Abs(x)),
|
| 227 |
-
("test_acos", acos(x)),
|
| 228 |
-
("test_asin", asin(x)),
|
| 229 |
-
("test_atan", atan(x)),
|
| 230 |
-
("test_ceil", ceiling(x)),
|
| 231 |
-
("test_cos", cos(x)),
|
| 232 |
-
("test_cosh", cosh(x)),
|
| 233 |
-
("test_floor", floor(x)),
|
| 234 |
-
("test_log", log(x)),
|
| 235 |
-
("test_ln", log(x)),
|
| 236 |
-
("test_sin", sin(x)),
|
| 237 |
-
("test_sinh", sinh(x)),
|
| 238 |
-
("test_sqrt", sqrt(x)),
|
| 239 |
-
("test_tan", tan(x)),
|
| 240 |
-
("test_tanh", tanh(x)),
|
| 241 |
-
]
|
| 242 |
-
result = codegen(name_expr, "C89", "file", header=False, empty=False)
|
| 243 |
-
assert result[0][0] == "file.c"
|
| 244 |
-
assert result[0][1] == (
|
| 245 |
-
'#include "file.h"\n#include <math.h>\n'
|
| 246 |
-
'double test_fabs(double x) {\n double test_fabs_result;\n test_fabs_result = fabs(x);\n return test_fabs_result;\n}\n'
|
| 247 |
-
'double test_acos(double x) {\n double test_acos_result;\n test_acos_result = acos(x);\n return test_acos_result;\n}\n'
|
| 248 |
-
'double test_asin(double x) {\n double test_asin_result;\n test_asin_result = asin(x);\n return test_asin_result;\n}\n'
|
| 249 |
-
'double test_atan(double x) {\n double test_atan_result;\n test_atan_result = atan(x);\n return test_atan_result;\n}\n'
|
| 250 |
-
'double test_ceil(double x) {\n double test_ceil_result;\n test_ceil_result = ceil(x);\n return test_ceil_result;\n}\n'
|
| 251 |
-
'double test_cos(double x) {\n double test_cos_result;\n test_cos_result = cos(x);\n return test_cos_result;\n}\n'
|
| 252 |
-
'double test_cosh(double x) {\n double test_cosh_result;\n test_cosh_result = cosh(x);\n return test_cosh_result;\n}\n'
|
| 253 |
-
'double test_floor(double x) {\n double test_floor_result;\n test_floor_result = floor(x);\n return test_floor_result;\n}\n'
|
| 254 |
-
'double test_log(double x) {\n double test_log_result;\n test_log_result = log(x);\n return test_log_result;\n}\n'
|
| 255 |
-
'double test_ln(double x) {\n double test_ln_result;\n test_ln_result = log(x);\n return test_ln_result;\n}\n'
|
| 256 |
-
'double test_sin(double x) {\n double test_sin_result;\n test_sin_result = sin(x);\n return test_sin_result;\n}\n'
|
| 257 |
-
'double test_sinh(double x) {\n double test_sinh_result;\n test_sinh_result = sinh(x);\n return test_sinh_result;\n}\n'
|
| 258 |
-
'double test_sqrt(double x) {\n double test_sqrt_result;\n test_sqrt_result = sqrt(x);\n return test_sqrt_result;\n}\n'
|
| 259 |
-
'double test_tan(double x) {\n double test_tan_result;\n test_tan_result = tan(x);\n return test_tan_result;\n}\n'
|
| 260 |
-
'double test_tanh(double x) {\n double test_tanh_result;\n test_tanh_result = tanh(x);\n return test_tanh_result;\n}\n'
|
| 261 |
-
)
|
| 262 |
-
assert result[1][0] == "file.h"
|
| 263 |
-
assert result[1][1] == (
|
| 264 |
-
'#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
|
| 265 |
-
'double test_fabs(double x);\ndouble test_acos(double x);\n'
|
| 266 |
-
'double test_asin(double x);\ndouble test_atan(double x);\n'
|
| 267 |
-
'double test_ceil(double x);\ndouble test_cos(double x);\n'
|
| 268 |
-
'double test_cosh(double x);\ndouble test_floor(double x);\n'
|
| 269 |
-
'double test_log(double x);\ndouble test_ln(double x);\n'
|
| 270 |
-
'double test_sin(double x);\ndouble test_sinh(double x);\n'
|
| 271 |
-
'double test_sqrt(double x);\ndouble test_tan(double x);\n'
|
| 272 |
-
'double test_tanh(double x);\n#endif\n'
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
def test_ansi_math2_codegen():
|
| 277 |
-
# not included: frexp, ldexp, modf, fmod
|
| 278 |
-
from sympy.functions.elementary.trigonometric import atan2
|
| 279 |
-
x, y = symbols('x,y')
|
| 280 |
-
name_expr = [
|
| 281 |
-
("test_atan2", atan2(x, y)),
|
| 282 |
-
("test_pow", x**y),
|
| 283 |
-
]
|
| 284 |
-
result = codegen(name_expr, "C89", "file", header=False, empty=False)
|
| 285 |
-
assert result[0][0] == "file.c"
|
| 286 |
-
assert result[0][1] == (
|
| 287 |
-
'#include "file.h"\n#include <math.h>\n'
|
| 288 |
-
'double test_atan2(double x, double y) {\n double test_atan2_result;\n test_atan2_result = atan2(x, y);\n return test_atan2_result;\n}\n'
|
| 289 |
-
'double test_pow(double x, double y) {\n double test_pow_result;\n test_pow_result = pow(x, y);\n return test_pow_result;\n}\n'
|
| 290 |
-
)
|
| 291 |
-
assert result[1][0] == "file.h"
|
| 292 |
-
assert result[1][1] == (
|
| 293 |
-
'#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
|
| 294 |
-
'double test_atan2(double x, double y);\n'
|
| 295 |
-
'double test_pow(double x, double y);\n'
|
| 296 |
-
'#endif\n'
|
| 297 |
-
)
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
def test_complicated_codegen():
|
| 301 |
-
from sympy.functions.elementary.trigonometric import (cos, sin, tan)
|
| 302 |
-
x, y, z = symbols('x,y,z')
|
| 303 |
-
name_expr = [
|
| 304 |
-
("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
|
| 305 |
-
("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
|
| 306 |
-
]
|
| 307 |
-
result = codegen(name_expr, "C89", "file", header=False, empty=False)
|
| 308 |
-
assert result[0][0] == "file.c"
|
| 309 |
-
assert result[0][1] == (
|
| 310 |
-
'#include "file.h"\n#include <math.h>\n'
|
| 311 |
-
'double test1(double x, double y, double z) {\n'
|
| 312 |
-
' double test1_result;\n'
|
| 313 |
-
' test1_result = '
|
| 314 |
-
'pow(sin(x), 7) + '
|
| 315 |
-
'7*pow(sin(x), 6)*cos(y) + '
|
| 316 |
-
'7*pow(sin(x), 6)*tan(z) + '
|
| 317 |
-
'21*pow(sin(x), 5)*pow(cos(y), 2) + '
|
| 318 |
-
'42*pow(sin(x), 5)*cos(y)*tan(z) + '
|
| 319 |
-
'21*pow(sin(x), 5)*pow(tan(z), 2) + '
|
| 320 |
-
'35*pow(sin(x), 4)*pow(cos(y), 3) + '
|
| 321 |
-
'105*pow(sin(x), 4)*pow(cos(y), 2)*tan(z) + '
|
| 322 |
-
'105*pow(sin(x), 4)*cos(y)*pow(tan(z), 2) + '
|
| 323 |
-
'35*pow(sin(x), 4)*pow(tan(z), 3) + '
|
| 324 |
-
'35*pow(sin(x), 3)*pow(cos(y), 4) + '
|
| 325 |
-
'140*pow(sin(x), 3)*pow(cos(y), 3)*tan(z) + '
|
| 326 |
-
'210*pow(sin(x), 3)*pow(cos(y), 2)*pow(tan(z), 2) + '
|
| 327 |
-
'140*pow(sin(x), 3)*cos(y)*pow(tan(z), 3) + '
|
| 328 |
-
'35*pow(sin(x), 3)*pow(tan(z), 4) + '
|
| 329 |
-
'21*pow(sin(x), 2)*pow(cos(y), 5) + '
|
| 330 |
-
'105*pow(sin(x), 2)*pow(cos(y), 4)*tan(z) + '
|
| 331 |
-
'210*pow(sin(x), 2)*pow(cos(y), 3)*pow(tan(z), 2) + '
|
| 332 |
-
'210*pow(sin(x), 2)*pow(cos(y), 2)*pow(tan(z), 3) + '
|
| 333 |
-
'105*pow(sin(x), 2)*cos(y)*pow(tan(z), 4) + '
|
| 334 |
-
'21*pow(sin(x), 2)*pow(tan(z), 5) + '
|
| 335 |
-
'7*sin(x)*pow(cos(y), 6) + '
|
| 336 |
-
'42*sin(x)*pow(cos(y), 5)*tan(z) + '
|
| 337 |
-
'105*sin(x)*pow(cos(y), 4)*pow(tan(z), 2) + '
|
| 338 |
-
'140*sin(x)*pow(cos(y), 3)*pow(tan(z), 3) + '
|
| 339 |
-
'105*sin(x)*pow(cos(y), 2)*pow(tan(z), 4) + '
|
| 340 |
-
'42*sin(x)*cos(y)*pow(tan(z), 5) + '
|
| 341 |
-
'7*sin(x)*pow(tan(z), 6) + '
|
| 342 |
-
'pow(cos(y), 7) + '
|
| 343 |
-
'7*pow(cos(y), 6)*tan(z) + '
|
| 344 |
-
'21*pow(cos(y), 5)*pow(tan(z), 2) + '
|
| 345 |
-
'35*pow(cos(y), 4)*pow(tan(z), 3) + '
|
| 346 |
-
'35*pow(cos(y), 3)*pow(tan(z), 4) + '
|
| 347 |
-
'21*pow(cos(y), 2)*pow(tan(z), 5) + '
|
| 348 |
-
'7*cos(y)*pow(tan(z), 6) + '
|
| 349 |
-
'pow(tan(z), 7);\n'
|
| 350 |
-
' return test1_result;\n'
|
| 351 |
-
'}\n'
|
| 352 |
-
'double test2(double x, double y, double z) {\n'
|
| 353 |
-
' double test2_result;\n'
|
| 354 |
-
' test2_result = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n'
|
| 355 |
-
' return test2_result;\n'
|
| 356 |
-
'}\n'
|
| 357 |
-
)
|
| 358 |
-
assert result[1][0] == "file.h"
|
| 359 |
-
assert result[1][1] == (
|
| 360 |
-
'#ifndef PROJECT__FILE__H\n'
|
| 361 |
-
'#define PROJECT__FILE__H\n'
|
| 362 |
-
'double test1(double x, double y, double z);\n'
|
| 363 |
-
'double test2(double x, double y, double z);\n'
|
| 364 |
-
'#endif\n'
|
| 365 |
-
)
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
def test_loops_c():
|
| 369 |
-
from sympy.tensor import IndexedBase, Idx
|
| 370 |
-
from sympy.core.symbol import symbols
|
| 371 |
-
n, m = symbols('n m', integer=True)
|
| 372 |
-
A = IndexedBase('A')
|
| 373 |
-
x = IndexedBase('x')
|
| 374 |
-
y = IndexedBase('y')
|
| 375 |
-
i = Idx('i', m)
|
| 376 |
-
j = Idx('j', n)
|
| 377 |
-
|
| 378 |
-
(f1, code), (f2, interface) = codegen(
|
| 379 |
-
('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
|
| 380 |
-
|
| 381 |
-
assert f1 == 'file.c'
|
| 382 |
-
expected = (
|
| 383 |
-
'#include "file.h"\n'
|
| 384 |
-
'#include <math.h>\n'
|
| 385 |
-
'void matrix_vector(double *A, int m, int n, double *x, double *y) {\n'
|
| 386 |
-
' for (int i=0; i<m; i++){\n'
|
| 387 |
-
' y[i] = 0;\n'
|
| 388 |
-
' }\n'
|
| 389 |
-
' for (int i=0; i<m; i++){\n'
|
| 390 |
-
' for (int j=0; j<n; j++){\n'
|
| 391 |
-
' y[i] = %(rhs)s + y[i];\n'
|
| 392 |
-
' }\n'
|
| 393 |
-
' }\n'
|
| 394 |
-
'}\n'
|
| 395 |
-
)
|
| 396 |
-
|
| 397 |
-
assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*n + j)} or
|
| 398 |
-
code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*n)} or
|
| 399 |
-
code == expected % {'rhs': 'x[j]*A[%s]' % (i*n + j)} or
|
| 400 |
-
code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*n)})
|
| 401 |
-
assert f2 == 'file.h'
|
| 402 |
-
assert interface == (
|
| 403 |
-
'#ifndef PROJECT__FILE__H\n'
|
| 404 |
-
'#define PROJECT__FILE__H\n'
|
| 405 |
-
'void matrix_vector(double *A, int m, int n, double *x, double *y);\n'
|
| 406 |
-
'#endif\n'
|
| 407 |
-
)
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
def test_dummy_loops_c():
|
| 411 |
-
from sympy.tensor import IndexedBase, Idx
|
| 412 |
-
i, m = symbols('i m', integer=True, cls=Dummy)
|
| 413 |
-
x = IndexedBase('x')
|
| 414 |
-
y = IndexedBase('y')
|
| 415 |
-
i = Idx(i, m)
|
| 416 |
-
expected = (
|
| 417 |
-
'#include "file.h"\n'
|
| 418 |
-
'#include <math.h>\n'
|
| 419 |
-
'void test_dummies(int m_%(mno)i, double *x, double *y) {\n'
|
| 420 |
-
' for (int i_%(ino)i=0; i_%(ino)i<m_%(mno)i; i_%(ino)i++){\n'
|
| 421 |
-
' y[i_%(ino)i] = x[i_%(ino)i];\n'
|
| 422 |
-
' }\n'
|
| 423 |
-
'}\n'
|
| 424 |
-
) % {'ino': i.label.dummy_index, 'mno': m.dummy_index}
|
| 425 |
-
r = make_routine('test_dummies', Eq(y[i], x[i]))
|
| 426 |
-
c89 = C89CodeGen()
|
| 427 |
-
c99 = C99CodeGen()
|
| 428 |
-
code = get_string(c99.dump_c, [r])
|
| 429 |
-
assert code == expected
|
| 430 |
-
with raises(NotImplementedError):
|
| 431 |
-
get_string(c89.dump_c, [r])
|
| 432 |
-
|
| 433 |
-
def test_partial_loops_c():
|
| 434 |
-
# check that loop boundaries are determined by Idx, and array strides
|
| 435 |
-
# determined by shape of IndexedBase object.
|
| 436 |
-
from sympy.tensor import IndexedBase, Idx
|
| 437 |
-
from sympy.core.symbol import symbols
|
| 438 |
-
n, m, o, p = symbols('n m o p', integer=True)
|
| 439 |
-
A = IndexedBase('A', shape=(m, p))
|
| 440 |
-
x = IndexedBase('x')
|
| 441 |
-
y = IndexedBase('y')
|
| 442 |
-
i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
|
| 443 |
-
j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
|
| 444 |
-
|
| 445 |
-
(f1, code), (f2, interface) = codegen(
|
| 446 |
-
('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
|
| 447 |
-
|
| 448 |
-
assert f1 == 'file.c'
|
| 449 |
-
expected = (
|
| 450 |
-
'#include "file.h"\n'
|
| 451 |
-
'#include <math.h>\n'
|
| 452 |
-
'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y) {\n'
|
| 453 |
-
' for (int i=o; i<%(upperi)s; i++){\n'
|
| 454 |
-
' y[i] = 0;\n'
|
| 455 |
-
' }\n'
|
| 456 |
-
' for (int i=o; i<%(upperi)s; i++){\n'
|
| 457 |
-
' for (int j=0; j<n; j++){\n'
|
| 458 |
-
' y[i] = %(rhs)s + y[i];\n'
|
| 459 |
-
' }\n'
|
| 460 |
-
' }\n'
|
| 461 |
-
'}\n'
|
| 462 |
-
) % {'upperi': m - 4, 'rhs': '%(rhs)s'}
|
| 463 |
-
|
| 464 |
-
assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*p + j)} or
|
| 465 |
-
code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*p)} or
|
| 466 |
-
code == expected % {'rhs': 'x[j]*A[%s]' % (i*p + j)} or
|
| 467 |
-
code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*p)})
|
| 468 |
-
assert f2 == 'file.h'
|
| 469 |
-
assert interface == (
|
| 470 |
-
'#ifndef PROJECT__FILE__H\n'
|
| 471 |
-
'#define PROJECT__FILE__H\n'
|
| 472 |
-
'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y);\n'
|
| 473 |
-
'#endif\n'
|
| 474 |
-
)
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
def test_output_arg_c():
|
| 478 |
-
from sympy.core.relational import Equality
|
| 479 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 480 |
-
x, y, z = symbols("x,y,z")
|
| 481 |
-
r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
|
| 482 |
-
c = C89CodeGen()
|
| 483 |
-
result = c.write([r], "test", header=False, empty=False)
|
| 484 |
-
assert result[0][0] == "test.c"
|
| 485 |
-
expected = (
|
| 486 |
-
'#include "test.h"\n'
|
| 487 |
-
'#include <math.h>\n'
|
| 488 |
-
'double foo(double x, double *y) {\n'
|
| 489 |
-
' (*y) = sin(x);\n'
|
| 490 |
-
' double foo_result;\n'
|
| 491 |
-
' foo_result = cos(x);\n'
|
| 492 |
-
' return foo_result;\n'
|
| 493 |
-
'}\n'
|
| 494 |
-
)
|
| 495 |
-
assert result[0][1] == expected
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
def test_output_arg_c_reserved_words():
|
| 499 |
-
from sympy.core.relational import Equality
|
| 500 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 501 |
-
x, y, z = symbols("if, while, z")
|
| 502 |
-
r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
|
| 503 |
-
c = C89CodeGen()
|
| 504 |
-
result = c.write([r], "test", header=False, empty=False)
|
| 505 |
-
assert result[0][0] == "test.c"
|
| 506 |
-
expected = (
|
| 507 |
-
'#include "test.h"\n'
|
| 508 |
-
'#include <math.h>\n'
|
| 509 |
-
'double foo(double if_, double *while_) {\n'
|
| 510 |
-
' (*while_) = sin(if_);\n'
|
| 511 |
-
' double foo_result;\n'
|
| 512 |
-
' foo_result = cos(if_);\n'
|
| 513 |
-
' return foo_result;\n'
|
| 514 |
-
'}\n'
|
| 515 |
-
)
|
| 516 |
-
assert result[0][1] == expected
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
def test_multidim_c_argument_cse():
|
| 520 |
-
A_sym = MatrixSymbol('A', 3, 3)
|
| 521 |
-
b_sym = MatrixSymbol('b', 3, 1)
|
| 522 |
-
A = Matrix(A_sym)
|
| 523 |
-
b = Matrix(b_sym)
|
| 524 |
-
c = A*b
|
| 525 |
-
cgen = CCodeGen(project="test", cse=True)
|
| 526 |
-
r = cgen.routine("c", c)
|
| 527 |
-
r.arguments[-1].result_var = "out"
|
| 528 |
-
r.arguments[-1]._name = "out"
|
| 529 |
-
code = get_string(cgen.dump_c, [r], prefix="test")
|
| 530 |
-
expected = (
|
| 531 |
-
'#include "test.h"\n'
|
| 532 |
-
"#include <math.h>\n"
|
| 533 |
-
"void c(double *A, double *b, double *out) {\n"
|
| 534 |
-
" out[0] = A[0]*b[0] + A[1]*b[1] + A[2]*b[2];\n"
|
| 535 |
-
" out[1] = A[3]*b[0] + A[4]*b[1] + A[5]*b[2];\n"
|
| 536 |
-
" out[2] = A[6]*b[0] + A[7]*b[1] + A[8]*b[2];\n"
|
| 537 |
-
"}\n"
|
| 538 |
-
)
|
| 539 |
-
assert code == expected
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
def test_ccode_results_named_ordered():
|
| 543 |
-
x, y, z = symbols('x,y,z')
|
| 544 |
-
B, C = symbols('B,C')
|
| 545 |
-
A = MatrixSymbol('A', 1, 3)
|
| 546 |
-
expr1 = Equality(A, Matrix([[1, 2, x]]))
|
| 547 |
-
expr2 = Equality(C, (x + y)*z)
|
| 548 |
-
expr3 = Equality(B, 2*x)
|
| 549 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 550 |
-
expected = (
|
| 551 |
-
'#include "test.h"\n'
|
| 552 |
-
'#include <math.h>\n'
|
| 553 |
-
'void test(double x, double *C, double z, double y, double *A, double *B) {\n'
|
| 554 |
-
' (*C) = z*(x + y);\n'
|
| 555 |
-
' A[0] = 1;\n'
|
| 556 |
-
' A[1] = 2;\n'
|
| 557 |
-
' A[2] = x;\n'
|
| 558 |
-
' (*B) = 2*x;\n'
|
| 559 |
-
'}\n'
|
| 560 |
-
)
|
| 561 |
-
|
| 562 |
-
result = codegen(name_expr, "c", "test", header=False, empty=False,
|
| 563 |
-
argument_sequence=(x, C, z, y, A, B))
|
| 564 |
-
source = result[0][1]
|
| 565 |
-
assert source == expected
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
def test_ccode_matrixsymbol_slice():
|
| 569 |
-
A = MatrixSymbol('A', 5, 3)
|
| 570 |
-
B = MatrixSymbol('B', 1, 3)
|
| 571 |
-
C = MatrixSymbol('C', 1, 3)
|
| 572 |
-
D = MatrixSymbol('D', 5, 1)
|
| 573 |
-
name_expr = ("test", [Equality(B, A[0, :]),
|
| 574 |
-
Equality(C, A[1, :]),
|
| 575 |
-
Equality(D, A[:, 2])])
|
| 576 |
-
result = codegen(name_expr, "c99", "test", header=False, empty=False)
|
| 577 |
-
source = result[0][1]
|
| 578 |
-
expected = (
|
| 579 |
-
'#include "test.h"\n'
|
| 580 |
-
'#include <math.h>\n'
|
| 581 |
-
'void test(double *A, double *B, double *C, double *D) {\n'
|
| 582 |
-
' B[0] = A[0];\n'
|
| 583 |
-
' B[1] = A[1];\n'
|
| 584 |
-
' B[2] = A[2];\n'
|
| 585 |
-
' C[0] = A[3];\n'
|
| 586 |
-
' C[1] = A[4];\n'
|
| 587 |
-
' C[2] = A[5];\n'
|
| 588 |
-
' D[0] = A[2];\n'
|
| 589 |
-
' D[1] = A[5];\n'
|
| 590 |
-
' D[2] = A[8];\n'
|
| 591 |
-
' D[3] = A[11];\n'
|
| 592 |
-
' D[4] = A[14];\n'
|
| 593 |
-
'}\n'
|
| 594 |
-
)
|
| 595 |
-
assert source == expected
|
| 596 |
-
|
| 597 |
-
def test_ccode_cse():
|
| 598 |
-
a, b, c, d = symbols('a b c d')
|
| 599 |
-
e = MatrixSymbol('e', 3, 1)
|
| 600 |
-
name_expr = ("test", [Equality(e, Matrix([[a*b], [a*b + c*d], [a*b*c*d]]))])
|
| 601 |
-
generator = CCodeGen(cse=True)
|
| 602 |
-
result = codegen(name_expr, code_gen=generator, header=False, empty=False)
|
| 603 |
-
source = result[0][1]
|
| 604 |
-
expected = (
|
| 605 |
-
'#include "test.h"\n'
|
| 606 |
-
'#include <math.h>\n'
|
| 607 |
-
'void test(double a, double b, double c, double d, double *e) {\n'
|
| 608 |
-
' const double x0 = a*b;\n'
|
| 609 |
-
' const double x1 = c*d;\n'
|
| 610 |
-
' e[0] = x0;\n'
|
| 611 |
-
' e[1] = x0 + x1;\n'
|
| 612 |
-
' e[2] = x0*x1;\n'
|
| 613 |
-
'}\n'
|
| 614 |
-
)
|
| 615 |
-
assert source == expected
|
| 616 |
-
|
| 617 |
-
def test_ccode_unused_array_arg():
|
| 618 |
-
x = MatrixSymbol('x', 2, 1)
|
| 619 |
-
# x does not appear in output
|
| 620 |
-
name_expr = ("test", 1.0)
|
| 621 |
-
generator = CCodeGen()
|
| 622 |
-
result = codegen(name_expr, code_gen=generator, header=False, empty=False, argument_sequence=(x,))
|
| 623 |
-
source = result[0][1]
|
| 624 |
-
# note: x should appear as (double *)
|
| 625 |
-
expected = (
|
| 626 |
-
'#include "test.h"\n'
|
| 627 |
-
'#include <math.h>\n'
|
| 628 |
-
'double test(double *x) {\n'
|
| 629 |
-
' double test_result;\n'
|
| 630 |
-
' test_result = 1.0;\n'
|
| 631 |
-
' return test_result;\n'
|
| 632 |
-
'}\n'
|
| 633 |
-
)
|
| 634 |
-
assert source == expected
|
| 635 |
-
|
| 636 |
-
def test_ccode_unused_array_arg_func():
|
| 637 |
-
# issue 16689
|
| 638 |
-
X = MatrixSymbol('X',3,1)
|
| 639 |
-
Y = MatrixSymbol('Y',3,1)
|
| 640 |
-
z = symbols('z',integer = True)
|
| 641 |
-
name_expr = ('testBug', X[0] + X[1])
|
| 642 |
-
result = codegen(name_expr, language='C', header=False, empty=False, argument_sequence=(X, Y, z))
|
| 643 |
-
source = result[0][1]
|
| 644 |
-
expected = (
|
| 645 |
-
'#include "testBug.h"\n'
|
| 646 |
-
'#include <math.h>\n'
|
| 647 |
-
'double testBug(double *X, double *Y, int z) {\n'
|
| 648 |
-
' double testBug_result;\n'
|
| 649 |
-
' testBug_result = X[0] + X[1];\n'
|
| 650 |
-
' return testBug_result;\n'
|
| 651 |
-
'}\n'
|
| 652 |
-
)
|
| 653 |
-
assert source == expected
|
| 654 |
-
|
| 655 |
-
def test_empty_f_code():
|
| 656 |
-
code_gen = FCodeGen()
|
| 657 |
-
source = get_string(code_gen.dump_f95, [])
|
| 658 |
-
assert source == ""
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
def test_empty_f_code_with_header():
|
| 662 |
-
code_gen = FCodeGen()
|
| 663 |
-
source = get_string(code_gen.dump_f95, [], header=True)
|
| 664 |
-
assert source[:82] == (
|
| 665 |
-
"!******************************************************************************\n!*"
|
| 666 |
-
)
|
| 667 |
-
# " Code generated with SymPy 0.7.2-git "
|
| 668 |
-
assert source[158:] == ( "*\n"
|
| 669 |
-
"!* *\n"
|
| 670 |
-
"!* See http://www.sympy.org/ for more information. *\n"
|
| 671 |
-
"!* *\n"
|
| 672 |
-
"!* This file is part of 'project' *\n"
|
| 673 |
-
"!******************************************************************************\n"
|
| 674 |
-
)
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
def test_empty_f_header():
|
| 678 |
-
code_gen = FCodeGen()
|
| 679 |
-
source = get_string(code_gen.dump_h, [])
|
| 680 |
-
assert source == ""
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
def test_simple_f_code():
|
| 684 |
-
x, y, z = symbols('x,y,z')
|
| 685 |
-
expr = (x + y)*z
|
| 686 |
-
routine = make_routine("test", expr)
|
| 687 |
-
code_gen = FCodeGen()
|
| 688 |
-
source = get_string(code_gen.dump_f95, [routine])
|
| 689 |
-
expected = (
|
| 690 |
-
"REAL*8 function test(x, y, z)\n"
|
| 691 |
-
"implicit none\n"
|
| 692 |
-
"REAL*8, intent(in) :: x\n"
|
| 693 |
-
"REAL*8, intent(in) :: y\n"
|
| 694 |
-
"REAL*8, intent(in) :: z\n"
|
| 695 |
-
"test = z*(x + y)\n"
|
| 696 |
-
"end function\n"
|
| 697 |
-
)
|
| 698 |
-
assert source == expected
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
def test_numbersymbol_f_code():
|
| 702 |
-
routine = make_routine("test", pi**Catalan)
|
| 703 |
-
code_gen = FCodeGen()
|
| 704 |
-
source = get_string(code_gen.dump_f95, [routine])
|
| 705 |
-
expected = (
|
| 706 |
-
"REAL*8 function test()\n"
|
| 707 |
-
"implicit none\n"
|
| 708 |
-
"REAL*8, parameter :: Catalan = %sd0\n"
|
| 709 |
-
"REAL*8, parameter :: pi = %sd0\n"
|
| 710 |
-
"test = pi**Catalan\n"
|
| 711 |
-
"end function\n"
|
| 712 |
-
) % (Catalan.evalf(17), pi.evalf(17))
|
| 713 |
-
assert source == expected
|
| 714 |
-
|
| 715 |
-
def test_erf_f_code():
|
| 716 |
-
x = symbols('x')
|
| 717 |
-
routine = make_routine("test", erf(x) - erf(-2 * x))
|
| 718 |
-
code_gen = FCodeGen()
|
| 719 |
-
source = get_string(code_gen.dump_f95, [routine])
|
| 720 |
-
expected = (
|
| 721 |
-
"REAL*8 function test(x)\n"
|
| 722 |
-
"implicit none\n"
|
| 723 |
-
"REAL*8, intent(in) :: x\n"
|
| 724 |
-
"test = erf(x) + erf(2.0d0*x)\n"
|
| 725 |
-
"end function\n"
|
| 726 |
-
)
|
| 727 |
-
assert source == expected, source
|
| 728 |
-
|
| 729 |
-
def test_f_code_argument_order():
|
| 730 |
-
x, y, z = symbols('x,y,z')
|
| 731 |
-
expr = x + y
|
| 732 |
-
routine = make_routine("test", expr, argument_sequence=[z, x, y])
|
| 733 |
-
code_gen = FCodeGen()
|
| 734 |
-
source = get_string(code_gen.dump_f95, [routine])
|
| 735 |
-
expected = (
|
| 736 |
-
"REAL*8 function test(z, x, y)\n"
|
| 737 |
-
"implicit none\n"
|
| 738 |
-
"REAL*8, intent(in) :: z\n"
|
| 739 |
-
"REAL*8, intent(in) :: x\n"
|
| 740 |
-
"REAL*8, intent(in) :: y\n"
|
| 741 |
-
"test = x + y\n"
|
| 742 |
-
"end function\n"
|
| 743 |
-
)
|
| 744 |
-
assert source == expected
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
def test_simple_f_header():
|
| 748 |
-
x, y, z = symbols('x,y,z')
|
| 749 |
-
expr = (x + y)*z
|
| 750 |
-
routine = make_routine("test", expr)
|
| 751 |
-
code_gen = FCodeGen()
|
| 752 |
-
source = get_string(code_gen.dump_h, [routine])
|
| 753 |
-
expected = (
|
| 754 |
-
"interface\n"
|
| 755 |
-
"REAL*8 function test(x, y, z)\n"
|
| 756 |
-
"implicit none\n"
|
| 757 |
-
"REAL*8, intent(in) :: x\n"
|
| 758 |
-
"REAL*8, intent(in) :: y\n"
|
| 759 |
-
"REAL*8, intent(in) :: z\n"
|
| 760 |
-
"end function\n"
|
| 761 |
-
"end interface\n"
|
| 762 |
-
)
|
| 763 |
-
assert source == expected
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
def test_simple_f_codegen():
|
| 767 |
-
x, y, z = symbols('x,y,z')
|
| 768 |
-
expr = (x + y)*z
|
| 769 |
-
result = codegen(
|
| 770 |
-
("test", expr), "F95", "file", header=False, empty=False)
|
| 771 |
-
expected = [
|
| 772 |
-
("file.f90",
|
| 773 |
-
"REAL*8 function test(x, y, z)\n"
|
| 774 |
-
"implicit none\n"
|
| 775 |
-
"REAL*8, intent(in) :: x\n"
|
| 776 |
-
"REAL*8, intent(in) :: y\n"
|
| 777 |
-
"REAL*8, intent(in) :: z\n"
|
| 778 |
-
"test = z*(x + y)\n"
|
| 779 |
-
"end function\n"),
|
| 780 |
-
("file.h",
|
| 781 |
-
"interface\n"
|
| 782 |
-
"REAL*8 function test(x, y, z)\n"
|
| 783 |
-
"implicit none\n"
|
| 784 |
-
"REAL*8, intent(in) :: x\n"
|
| 785 |
-
"REAL*8, intent(in) :: y\n"
|
| 786 |
-
"REAL*8, intent(in) :: z\n"
|
| 787 |
-
"end function\n"
|
| 788 |
-
"end interface\n")
|
| 789 |
-
]
|
| 790 |
-
assert result == expected
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
def test_multiple_results_f():
|
| 794 |
-
x, y, z = symbols('x,y,z')
|
| 795 |
-
expr1 = (x + y)*z
|
| 796 |
-
expr2 = (x - y)*z
|
| 797 |
-
routine = make_routine(
|
| 798 |
-
"test",
|
| 799 |
-
[expr1, expr2]
|
| 800 |
-
)
|
| 801 |
-
code_gen = FCodeGen()
|
| 802 |
-
raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
def test_no_results_f():
|
| 806 |
-
raises(ValueError, lambda: make_routine("test", []))
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
def test_intrinsic_math_codegen():
|
| 810 |
-
# not included: log10
|
| 811 |
-
from sympy.functions.elementary.complexes import Abs
|
| 812 |
-
from sympy.functions.elementary.exponential import log
|
| 813 |
-
from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
|
| 814 |
-
from sympy.functions.elementary.miscellaneous import sqrt
|
| 815 |
-
from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
|
| 816 |
-
x = symbols('x')
|
| 817 |
-
name_expr = [
|
| 818 |
-
("test_abs", Abs(x)),
|
| 819 |
-
("test_acos", acos(x)),
|
| 820 |
-
("test_asin", asin(x)),
|
| 821 |
-
("test_atan", atan(x)),
|
| 822 |
-
("test_cos", cos(x)),
|
| 823 |
-
("test_cosh", cosh(x)),
|
| 824 |
-
("test_log", log(x)),
|
| 825 |
-
("test_ln", log(x)),
|
| 826 |
-
("test_sin", sin(x)),
|
| 827 |
-
("test_sinh", sinh(x)),
|
| 828 |
-
("test_sqrt", sqrt(x)),
|
| 829 |
-
("test_tan", tan(x)),
|
| 830 |
-
("test_tanh", tanh(x)),
|
| 831 |
-
]
|
| 832 |
-
result = codegen(name_expr, "F95", "file", header=False, empty=False)
|
| 833 |
-
assert result[0][0] == "file.f90"
|
| 834 |
-
expected = (
|
| 835 |
-
'REAL*8 function test_abs(x)\n'
|
| 836 |
-
'implicit none\n'
|
| 837 |
-
'REAL*8, intent(in) :: x\n'
|
| 838 |
-
'test_abs = abs(x)\n'
|
| 839 |
-
'end function\n'
|
| 840 |
-
'REAL*8 function test_acos(x)\n'
|
| 841 |
-
'implicit none\n'
|
| 842 |
-
'REAL*8, intent(in) :: x\n'
|
| 843 |
-
'test_acos = acos(x)\n'
|
| 844 |
-
'end function\n'
|
| 845 |
-
'REAL*8 function test_asin(x)\n'
|
| 846 |
-
'implicit none\n'
|
| 847 |
-
'REAL*8, intent(in) :: x\n'
|
| 848 |
-
'test_asin = asin(x)\n'
|
| 849 |
-
'end function\n'
|
| 850 |
-
'REAL*8 function test_atan(x)\n'
|
| 851 |
-
'implicit none\n'
|
| 852 |
-
'REAL*8, intent(in) :: x\n'
|
| 853 |
-
'test_atan = atan(x)\n'
|
| 854 |
-
'end function\n'
|
| 855 |
-
'REAL*8 function test_cos(x)\n'
|
| 856 |
-
'implicit none\n'
|
| 857 |
-
'REAL*8, intent(in) :: x\n'
|
| 858 |
-
'test_cos = cos(x)\n'
|
| 859 |
-
'end function\n'
|
| 860 |
-
'REAL*8 function test_cosh(x)\n'
|
| 861 |
-
'implicit none\n'
|
| 862 |
-
'REAL*8, intent(in) :: x\n'
|
| 863 |
-
'test_cosh = cosh(x)\n'
|
| 864 |
-
'end function\n'
|
| 865 |
-
'REAL*8 function test_log(x)\n'
|
| 866 |
-
'implicit none\n'
|
| 867 |
-
'REAL*8, intent(in) :: x\n'
|
| 868 |
-
'test_log = log(x)\n'
|
| 869 |
-
'end function\n'
|
| 870 |
-
'REAL*8 function test_ln(x)\n'
|
| 871 |
-
'implicit none\n'
|
| 872 |
-
'REAL*8, intent(in) :: x\n'
|
| 873 |
-
'test_ln = log(x)\n'
|
| 874 |
-
'end function\n'
|
| 875 |
-
'REAL*8 function test_sin(x)\n'
|
| 876 |
-
'implicit none\n'
|
| 877 |
-
'REAL*8, intent(in) :: x\n'
|
| 878 |
-
'test_sin = sin(x)\n'
|
| 879 |
-
'end function\n'
|
| 880 |
-
'REAL*8 function test_sinh(x)\n'
|
| 881 |
-
'implicit none\n'
|
| 882 |
-
'REAL*8, intent(in) :: x\n'
|
| 883 |
-
'test_sinh = sinh(x)\n'
|
| 884 |
-
'end function\n'
|
| 885 |
-
'REAL*8 function test_sqrt(x)\n'
|
| 886 |
-
'implicit none\n'
|
| 887 |
-
'REAL*8, intent(in) :: x\n'
|
| 888 |
-
'test_sqrt = sqrt(x)\n'
|
| 889 |
-
'end function\n'
|
| 890 |
-
'REAL*8 function test_tan(x)\n'
|
| 891 |
-
'implicit none\n'
|
| 892 |
-
'REAL*8, intent(in) :: x\n'
|
| 893 |
-
'test_tan = tan(x)\n'
|
| 894 |
-
'end function\n'
|
| 895 |
-
'REAL*8 function test_tanh(x)\n'
|
| 896 |
-
'implicit none\n'
|
| 897 |
-
'REAL*8, intent(in) :: x\n'
|
| 898 |
-
'test_tanh = tanh(x)\n'
|
| 899 |
-
'end function\n'
|
| 900 |
-
)
|
| 901 |
-
assert result[0][1] == expected
|
| 902 |
-
|
| 903 |
-
assert result[1][0] == "file.h"
|
| 904 |
-
expected = (
|
| 905 |
-
'interface\n'
|
| 906 |
-
'REAL*8 function test_abs(x)\n'
|
| 907 |
-
'implicit none\n'
|
| 908 |
-
'REAL*8, intent(in) :: x\n'
|
| 909 |
-
'end function\n'
|
| 910 |
-
'end interface\n'
|
| 911 |
-
'interface\n'
|
| 912 |
-
'REAL*8 function test_acos(x)\n'
|
| 913 |
-
'implicit none\n'
|
| 914 |
-
'REAL*8, intent(in) :: x\n'
|
| 915 |
-
'end function\n'
|
| 916 |
-
'end interface\n'
|
| 917 |
-
'interface\n'
|
| 918 |
-
'REAL*8 function test_asin(x)\n'
|
| 919 |
-
'implicit none\n'
|
| 920 |
-
'REAL*8, intent(in) :: x\n'
|
| 921 |
-
'end function\n'
|
| 922 |
-
'end interface\n'
|
| 923 |
-
'interface\n'
|
| 924 |
-
'REAL*8 function test_atan(x)\n'
|
| 925 |
-
'implicit none\n'
|
| 926 |
-
'REAL*8, intent(in) :: x\n'
|
| 927 |
-
'end function\n'
|
| 928 |
-
'end interface\n'
|
| 929 |
-
'interface\n'
|
| 930 |
-
'REAL*8 function test_cos(x)\n'
|
| 931 |
-
'implicit none\n'
|
| 932 |
-
'REAL*8, intent(in) :: x\n'
|
| 933 |
-
'end function\n'
|
| 934 |
-
'end interface\n'
|
| 935 |
-
'interface\n'
|
| 936 |
-
'REAL*8 function test_cosh(x)\n'
|
| 937 |
-
'implicit none\n'
|
| 938 |
-
'REAL*8, intent(in) :: x\n'
|
| 939 |
-
'end function\n'
|
| 940 |
-
'end interface\n'
|
| 941 |
-
'interface\n'
|
| 942 |
-
'REAL*8 function test_log(x)\n'
|
| 943 |
-
'implicit none\n'
|
| 944 |
-
'REAL*8, intent(in) :: x\n'
|
| 945 |
-
'end function\n'
|
| 946 |
-
'end interface\n'
|
| 947 |
-
'interface\n'
|
| 948 |
-
'REAL*8 function test_ln(x)\n'
|
| 949 |
-
'implicit none\n'
|
| 950 |
-
'REAL*8, intent(in) :: x\n'
|
| 951 |
-
'end function\n'
|
| 952 |
-
'end interface\n'
|
| 953 |
-
'interface\n'
|
| 954 |
-
'REAL*8 function test_sin(x)\n'
|
| 955 |
-
'implicit none\n'
|
| 956 |
-
'REAL*8, intent(in) :: x\n'
|
| 957 |
-
'end function\n'
|
| 958 |
-
'end interface\n'
|
| 959 |
-
'interface\n'
|
| 960 |
-
'REAL*8 function test_sinh(x)\n'
|
| 961 |
-
'implicit none\n'
|
| 962 |
-
'REAL*8, intent(in) :: x\n'
|
| 963 |
-
'end function\n'
|
| 964 |
-
'end interface\n'
|
| 965 |
-
'interface\n'
|
| 966 |
-
'REAL*8 function test_sqrt(x)\n'
|
| 967 |
-
'implicit none\n'
|
| 968 |
-
'REAL*8, intent(in) :: x\n'
|
| 969 |
-
'end function\n'
|
| 970 |
-
'end interface\n'
|
| 971 |
-
'interface\n'
|
| 972 |
-
'REAL*8 function test_tan(x)\n'
|
| 973 |
-
'implicit none\n'
|
| 974 |
-
'REAL*8, intent(in) :: x\n'
|
| 975 |
-
'end function\n'
|
| 976 |
-
'end interface\n'
|
| 977 |
-
'interface\n'
|
| 978 |
-
'REAL*8 function test_tanh(x)\n'
|
| 979 |
-
'implicit none\n'
|
| 980 |
-
'REAL*8, intent(in) :: x\n'
|
| 981 |
-
'end function\n'
|
| 982 |
-
'end interface\n'
|
| 983 |
-
)
|
| 984 |
-
assert result[1][1] == expected
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
def test_intrinsic_math2_codegen():
|
| 988 |
-
# not included: frexp, ldexp, modf, fmod
|
| 989 |
-
from sympy.functions.elementary.trigonometric import atan2
|
| 990 |
-
x, y = symbols('x,y')
|
| 991 |
-
name_expr = [
|
| 992 |
-
("test_atan2", atan2(x, y)),
|
| 993 |
-
("test_pow", x**y),
|
| 994 |
-
]
|
| 995 |
-
result = codegen(name_expr, "F95", "file", header=False, empty=False)
|
| 996 |
-
assert result[0][0] == "file.f90"
|
| 997 |
-
expected = (
|
| 998 |
-
'REAL*8 function test_atan2(x, y)\n'
|
| 999 |
-
'implicit none\n'
|
| 1000 |
-
'REAL*8, intent(in) :: x\n'
|
| 1001 |
-
'REAL*8, intent(in) :: y\n'
|
| 1002 |
-
'test_atan2 = atan2(x, y)\n'
|
| 1003 |
-
'end function\n'
|
| 1004 |
-
'REAL*8 function test_pow(x, y)\n'
|
| 1005 |
-
'implicit none\n'
|
| 1006 |
-
'REAL*8, intent(in) :: x\n'
|
| 1007 |
-
'REAL*8, intent(in) :: y\n'
|
| 1008 |
-
'test_pow = x**y\n'
|
| 1009 |
-
'end function\n'
|
| 1010 |
-
)
|
| 1011 |
-
assert result[0][1] == expected
|
| 1012 |
-
|
| 1013 |
-
assert result[1][0] == "file.h"
|
| 1014 |
-
expected = (
|
| 1015 |
-
'interface\n'
|
| 1016 |
-
'REAL*8 function test_atan2(x, y)\n'
|
| 1017 |
-
'implicit none\n'
|
| 1018 |
-
'REAL*8, intent(in) :: x\n'
|
| 1019 |
-
'REAL*8, intent(in) :: y\n'
|
| 1020 |
-
'end function\n'
|
| 1021 |
-
'end interface\n'
|
| 1022 |
-
'interface\n'
|
| 1023 |
-
'REAL*8 function test_pow(x, y)\n'
|
| 1024 |
-
'implicit none\n'
|
| 1025 |
-
'REAL*8, intent(in) :: x\n'
|
| 1026 |
-
'REAL*8, intent(in) :: y\n'
|
| 1027 |
-
'end function\n'
|
| 1028 |
-
'end interface\n'
|
| 1029 |
-
)
|
| 1030 |
-
assert result[1][1] == expected
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
def test_complicated_codegen_f95():
|
| 1034 |
-
from sympy.functions.elementary.trigonometric import (cos, sin, tan)
|
| 1035 |
-
x, y, z = symbols('x,y,z')
|
| 1036 |
-
name_expr = [
|
| 1037 |
-
("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
|
| 1038 |
-
("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
|
| 1039 |
-
]
|
| 1040 |
-
result = codegen(name_expr, "F95", "file", header=False, empty=False)
|
| 1041 |
-
assert result[0][0] == "file.f90"
|
| 1042 |
-
expected = (
|
| 1043 |
-
'REAL*8 function test1(x, y, z)\n'
|
| 1044 |
-
'implicit none\n'
|
| 1045 |
-
'REAL*8, intent(in) :: x\n'
|
| 1046 |
-
'REAL*8, intent(in) :: y\n'
|
| 1047 |
-
'REAL*8, intent(in) :: z\n'
|
| 1048 |
-
'test1 = sin(x)**7 + 7*sin(x)**6*cos(y) + 7*sin(x)**6*tan(z) + 21*sin(x) &\n'
|
| 1049 |
-
' **5*cos(y)**2 + 42*sin(x)**5*cos(y)*tan(z) + 21*sin(x)**5*tan(z) &\n'
|
| 1050 |
-
' **2 + 35*sin(x)**4*cos(y)**3 + 105*sin(x)**4*cos(y)**2*tan(z) + &\n'
|
| 1051 |
-
' 105*sin(x)**4*cos(y)*tan(z)**2 + 35*sin(x)**4*tan(z)**3 + 35*sin( &\n'
|
| 1052 |
-
' x)**3*cos(y)**4 + 140*sin(x)**3*cos(y)**3*tan(z) + 210*sin(x)**3* &\n'
|
| 1053 |
-
' cos(y)**2*tan(z)**2 + 140*sin(x)**3*cos(y)*tan(z)**3 + 35*sin(x) &\n'
|
| 1054 |
-
' **3*tan(z)**4 + 21*sin(x)**2*cos(y)**5 + 105*sin(x)**2*cos(y)**4* &\n'
|
| 1055 |
-
' tan(z) + 210*sin(x)**2*cos(y)**3*tan(z)**2 + 210*sin(x)**2*cos(y) &\n'
|
| 1056 |
-
' **2*tan(z)**3 + 105*sin(x)**2*cos(y)*tan(z)**4 + 21*sin(x)**2*tan &\n'
|
| 1057 |
-
' (z)**5 + 7*sin(x)*cos(y)**6 + 42*sin(x)*cos(y)**5*tan(z) + 105* &\n'
|
| 1058 |
-
' sin(x)*cos(y)**4*tan(z)**2 + 140*sin(x)*cos(y)**3*tan(z)**3 + 105 &\n'
|
| 1059 |
-
' *sin(x)*cos(y)**2*tan(z)**4 + 42*sin(x)*cos(y)*tan(z)**5 + 7*sin( &\n'
|
| 1060 |
-
' x)*tan(z)**6 + cos(y)**7 + 7*cos(y)**6*tan(z) + 21*cos(y)**5*tan( &\n'
|
| 1061 |
-
' z)**2 + 35*cos(y)**4*tan(z)**3 + 35*cos(y)**3*tan(z)**4 + 21*cos( &\n'
|
| 1062 |
-
' y)**2*tan(z)**5 + 7*cos(y)*tan(z)**6 + tan(z)**7\n'
|
| 1063 |
-
'end function\n'
|
| 1064 |
-
'REAL*8 function test2(x, y, z)\n'
|
| 1065 |
-
'implicit none\n'
|
| 1066 |
-
'REAL*8, intent(in) :: x\n'
|
| 1067 |
-
'REAL*8, intent(in) :: y\n'
|
| 1068 |
-
'REAL*8, intent(in) :: z\n'
|
| 1069 |
-
'test2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n'
|
| 1070 |
-
'end function\n'
|
| 1071 |
-
)
|
| 1072 |
-
assert result[0][1] == expected
|
| 1073 |
-
assert result[1][0] == "file.h"
|
| 1074 |
-
expected = (
|
| 1075 |
-
'interface\n'
|
| 1076 |
-
'REAL*8 function test1(x, y, z)\n'
|
| 1077 |
-
'implicit none\n'
|
| 1078 |
-
'REAL*8, intent(in) :: x\n'
|
| 1079 |
-
'REAL*8, intent(in) :: y\n'
|
| 1080 |
-
'REAL*8, intent(in) :: z\n'
|
| 1081 |
-
'end function\n'
|
| 1082 |
-
'end interface\n'
|
| 1083 |
-
'interface\n'
|
| 1084 |
-
'REAL*8 function test2(x, y, z)\n'
|
| 1085 |
-
'implicit none\n'
|
| 1086 |
-
'REAL*8, intent(in) :: x\n'
|
| 1087 |
-
'REAL*8, intent(in) :: y\n'
|
| 1088 |
-
'REAL*8, intent(in) :: z\n'
|
| 1089 |
-
'end function\n'
|
| 1090 |
-
'end interface\n'
|
| 1091 |
-
)
|
| 1092 |
-
assert result[1][1] == expected
|
| 1093 |
-
|
| 1094 |
-
|
| 1095 |
-
def test_loops():
|
| 1096 |
-
from sympy.tensor import IndexedBase, Idx
|
| 1097 |
-
from sympy.core.symbol import symbols
|
| 1098 |
-
|
| 1099 |
-
n, m = symbols('n,m', integer=True)
|
| 1100 |
-
A, x, y = map(IndexedBase, 'Axy')
|
| 1101 |
-
i = Idx('i', m)
|
| 1102 |
-
j = Idx('j', n)
|
| 1103 |
-
|
| 1104 |
-
(f1, code), (f2, interface) = codegen(
|
| 1105 |
-
('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
|
| 1106 |
-
|
| 1107 |
-
assert f1 == 'file.f90'
|
| 1108 |
-
expected = (
|
| 1109 |
-
'subroutine matrix_vector(A, m, n, x, y)\n'
|
| 1110 |
-
'implicit none\n'
|
| 1111 |
-
'INTEGER*4, intent(in) :: m\n'
|
| 1112 |
-
'INTEGER*4, intent(in) :: n\n'
|
| 1113 |
-
'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
|
| 1114 |
-
'REAL*8, intent(in), dimension(1:n) :: x\n'
|
| 1115 |
-
'REAL*8, intent(out), dimension(1:m) :: y\n'
|
| 1116 |
-
'INTEGER*4 :: i\n'
|
| 1117 |
-
'INTEGER*4 :: j\n'
|
| 1118 |
-
'do i = 1, m\n'
|
| 1119 |
-
' y(i) = 0\n'
|
| 1120 |
-
'end do\n'
|
| 1121 |
-
'do i = 1, m\n'
|
| 1122 |
-
' do j = 1, n\n'
|
| 1123 |
-
' y(i) = %(rhs)s + y(i)\n'
|
| 1124 |
-
' end do\n'
|
| 1125 |
-
'end do\n'
|
| 1126 |
-
'end subroutine\n'
|
| 1127 |
-
)
|
| 1128 |
-
|
| 1129 |
-
assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
|
| 1130 |
-
code == expected % {'rhs': 'x(j)*A(i, j)'}
|
| 1131 |
-
assert f2 == 'file.h'
|
| 1132 |
-
assert interface == (
|
| 1133 |
-
'interface\n'
|
| 1134 |
-
'subroutine matrix_vector(A, m, n, x, y)\n'
|
| 1135 |
-
'implicit none\n'
|
| 1136 |
-
'INTEGER*4, intent(in) :: m\n'
|
| 1137 |
-
'INTEGER*4, intent(in) :: n\n'
|
| 1138 |
-
'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
|
| 1139 |
-
'REAL*8, intent(in), dimension(1:n) :: x\n'
|
| 1140 |
-
'REAL*8, intent(out), dimension(1:m) :: y\n'
|
| 1141 |
-
'end subroutine\n'
|
| 1142 |
-
'end interface\n'
|
| 1143 |
-
)
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
-
def test_dummy_loops_f95():
|
| 1147 |
-
from sympy.tensor import IndexedBase, Idx
|
| 1148 |
-
i, m = symbols('i m', integer=True, cls=Dummy)
|
| 1149 |
-
x = IndexedBase('x')
|
| 1150 |
-
y = IndexedBase('y')
|
| 1151 |
-
i = Idx(i, m)
|
| 1152 |
-
expected = (
|
| 1153 |
-
'subroutine test_dummies(m_%(mcount)i, x, y)\n'
|
| 1154 |
-
'implicit none\n'
|
| 1155 |
-
'INTEGER*4, intent(in) :: m_%(mcount)i\n'
|
| 1156 |
-
'REAL*8, intent(in), dimension(1:m_%(mcount)i) :: x\n'
|
| 1157 |
-
'REAL*8, intent(out), dimension(1:m_%(mcount)i) :: y\n'
|
| 1158 |
-
'INTEGER*4 :: i_%(icount)i\n'
|
| 1159 |
-
'do i_%(icount)i = 1, m_%(mcount)i\n'
|
| 1160 |
-
' y(i_%(icount)i) = x(i_%(icount)i)\n'
|
| 1161 |
-
'end do\n'
|
| 1162 |
-
'end subroutine\n'
|
| 1163 |
-
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
|
| 1164 |
-
r = make_routine('test_dummies', Eq(y[i], x[i]))
|
| 1165 |
-
c = FCodeGen()
|
| 1166 |
-
code = get_string(c.dump_f95, [r])
|
| 1167 |
-
assert code == expected
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
def test_loops_InOut():
|
| 1171 |
-
from sympy.tensor import IndexedBase, Idx
|
| 1172 |
-
from sympy.core.symbol import symbols
|
| 1173 |
-
|
| 1174 |
-
i, j, n, m = symbols('i,j,n,m', integer=True)
|
| 1175 |
-
A, x, y = symbols('A,x,y')
|
| 1176 |
-
A = IndexedBase(A)[Idx(i, m), Idx(j, n)]
|
| 1177 |
-
x = IndexedBase(x)[Idx(j, n)]
|
| 1178 |
-
y = IndexedBase(y)[Idx(i, m)]
|
| 1179 |
-
|
| 1180 |
-
(f1, code), (f2, interface) = codegen(
|
| 1181 |
-
('matrix_vector', Eq(y, y + A*x)), "F95", "file", header=False, empty=False)
|
| 1182 |
-
|
| 1183 |
-
assert f1 == 'file.f90'
|
| 1184 |
-
expected = (
|
| 1185 |
-
'subroutine matrix_vector(A, m, n, x, y)\n'
|
| 1186 |
-
'implicit none\n'
|
| 1187 |
-
'INTEGER*4, intent(in) :: m\n'
|
| 1188 |
-
'INTEGER*4, intent(in) :: n\n'
|
| 1189 |
-
'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
|
| 1190 |
-
'REAL*8, intent(in), dimension(1:n) :: x\n'
|
| 1191 |
-
'REAL*8, intent(inout), dimension(1:m) :: y\n'
|
| 1192 |
-
'INTEGER*4 :: i\n'
|
| 1193 |
-
'INTEGER*4 :: j\n'
|
| 1194 |
-
'do i = 1, m\n'
|
| 1195 |
-
' do j = 1, n\n'
|
| 1196 |
-
' y(i) = %(rhs)s + y(i)\n'
|
| 1197 |
-
' end do\n'
|
| 1198 |
-
'end do\n'
|
| 1199 |
-
'end subroutine\n'
|
| 1200 |
-
)
|
| 1201 |
-
|
| 1202 |
-
assert (code == expected % {'rhs': 'A(i, j)*x(j)'} or
|
| 1203 |
-
code == expected % {'rhs': 'x(j)*A(i, j)'})
|
| 1204 |
-
assert f2 == 'file.h'
|
| 1205 |
-
assert interface == (
|
| 1206 |
-
'interface\n'
|
| 1207 |
-
'subroutine matrix_vector(A, m, n, x, y)\n'
|
| 1208 |
-
'implicit none\n'
|
| 1209 |
-
'INTEGER*4, intent(in) :: m\n'
|
| 1210 |
-
'INTEGER*4, intent(in) :: n\n'
|
| 1211 |
-
'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
|
| 1212 |
-
'REAL*8, intent(in), dimension(1:n) :: x\n'
|
| 1213 |
-
'REAL*8, intent(inout), dimension(1:m) :: y\n'
|
| 1214 |
-
'end subroutine\n'
|
| 1215 |
-
'end interface\n'
|
| 1216 |
-
)
|
| 1217 |
-
|
| 1218 |
-
|
| 1219 |
-
def test_partial_loops_f():
|
| 1220 |
-
# check that loop boundaries are determined by Idx, and array strides
|
| 1221 |
-
# determined by shape of IndexedBase object.
|
| 1222 |
-
from sympy.tensor import IndexedBase, Idx
|
| 1223 |
-
from sympy.core.symbol import symbols
|
| 1224 |
-
n, m, o, p = symbols('n m o p', integer=True)
|
| 1225 |
-
A = IndexedBase('A', shape=(m, p))
|
| 1226 |
-
x = IndexedBase('x')
|
| 1227 |
-
y = IndexedBase('y')
|
| 1228 |
-
i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
|
| 1229 |
-
j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
|
| 1230 |
-
|
| 1231 |
-
(f1, code), (f2, interface) = codegen(
|
| 1232 |
-
('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
|
| 1233 |
-
|
| 1234 |
-
expected = (
|
| 1235 |
-
'subroutine matrix_vector(A, m, n, o, p, x, y)\n'
|
| 1236 |
-
'implicit none\n'
|
| 1237 |
-
'INTEGER*4, intent(in) :: m\n'
|
| 1238 |
-
'INTEGER*4, intent(in) :: n\n'
|
| 1239 |
-
'INTEGER*4, intent(in) :: o\n'
|
| 1240 |
-
'INTEGER*4, intent(in) :: p\n'
|
| 1241 |
-
'REAL*8, intent(in), dimension(1:m, 1:p) :: A\n'
|
| 1242 |
-
'REAL*8, intent(in), dimension(1:n) :: x\n'
|
| 1243 |
-
'REAL*8, intent(out), dimension(1:%(iup-ilow)s) :: y\n'
|
| 1244 |
-
'INTEGER*4 :: i\n'
|
| 1245 |
-
'INTEGER*4 :: j\n'
|
| 1246 |
-
'do i = %(ilow)s, %(iup)s\n'
|
| 1247 |
-
' y(i) = 0\n'
|
| 1248 |
-
'end do\n'
|
| 1249 |
-
'do i = %(ilow)s, %(iup)s\n'
|
| 1250 |
-
' do j = 1, n\n'
|
| 1251 |
-
' y(i) = %(rhs)s + y(i)\n'
|
| 1252 |
-
' end do\n'
|
| 1253 |
-
'end do\n'
|
| 1254 |
-
'end subroutine\n'
|
| 1255 |
-
) % {
|
| 1256 |
-
'rhs': '%(rhs)s',
|
| 1257 |
-
'iup': str(m - 4),
|
| 1258 |
-
'ilow': str(1 + o),
|
| 1259 |
-
'iup-ilow': str(m - 4 - o)
|
| 1260 |
-
}
|
| 1261 |
-
|
| 1262 |
-
assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
|
| 1263 |
-
code == expected % {'rhs': 'x(j)*A(i, j)'}
|
| 1264 |
-
|
| 1265 |
-
|
| 1266 |
-
def test_output_arg_f():
|
| 1267 |
-
from sympy.core.relational import Equality
|
| 1268 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 1269 |
-
x, y, z = symbols("x,y,z")
|
| 1270 |
-
r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
|
| 1271 |
-
c = FCodeGen()
|
| 1272 |
-
result = c.write([r], "test", header=False, empty=False)
|
| 1273 |
-
assert result[0][0] == "test.f90"
|
| 1274 |
-
assert result[0][1] == (
|
| 1275 |
-
'REAL*8 function foo(x, y)\n'
|
| 1276 |
-
'implicit none\n'
|
| 1277 |
-
'REAL*8, intent(in) :: x\n'
|
| 1278 |
-
'REAL*8, intent(out) :: y\n'
|
| 1279 |
-
'y = sin(x)\n'
|
| 1280 |
-
'foo = cos(x)\n'
|
| 1281 |
-
'end function\n'
|
| 1282 |
-
)
|
| 1283 |
-
|
| 1284 |
-
|
| 1285 |
-
def test_inline_function():
|
| 1286 |
-
from sympy.tensor import IndexedBase, Idx
|
| 1287 |
-
from sympy.core.symbol import symbols
|
| 1288 |
-
n, m = symbols('n m', integer=True)
|
| 1289 |
-
A, x, y = map(IndexedBase, 'Axy')
|
| 1290 |
-
i = Idx('i', m)
|
| 1291 |
-
p = FCodeGen()
|
| 1292 |
-
func = implemented_function('func', Lambda(n, n*(n + 1)))
|
| 1293 |
-
routine = make_routine('test_inline', Eq(y[i], func(x[i])))
|
| 1294 |
-
code = get_string(p.dump_f95, [routine])
|
| 1295 |
-
expected = (
|
| 1296 |
-
'subroutine test_inline(m, x, y)\n'
|
| 1297 |
-
'implicit none\n'
|
| 1298 |
-
'INTEGER*4, intent(in) :: m\n'
|
| 1299 |
-
'REAL*8, intent(in), dimension(1:m) :: x\n'
|
| 1300 |
-
'REAL*8, intent(out), dimension(1:m) :: y\n'
|
| 1301 |
-
'INTEGER*4 :: i\n'
|
| 1302 |
-
'do i = 1, m\n'
|
| 1303 |
-
' y(i) = %s*%s\n'
|
| 1304 |
-
'end do\n'
|
| 1305 |
-
'end subroutine\n'
|
| 1306 |
-
)
|
| 1307 |
-
args = ('x(i)', '(x(i) + 1)')
|
| 1308 |
-
assert code == expected % args or\
|
| 1309 |
-
code == expected % args[::-1]
|
| 1310 |
-
|
| 1311 |
-
|
| 1312 |
-
def test_f_code_call_signature_wrap():
|
| 1313 |
-
# Issue #7934
|
| 1314 |
-
x = symbols('x:20')
|
| 1315 |
-
expr = 0
|
| 1316 |
-
for sym in x:
|
| 1317 |
-
expr += sym
|
| 1318 |
-
routine = make_routine("test", expr)
|
| 1319 |
-
code_gen = FCodeGen()
|
| 1320 |
-
source = get_string(code_gen.dump_f95, [routine])
|
| 1321 |
-
expected = """\
|
| 1322 |
-
REAL*8 function test(x0, x1, x10, x11, x12, x13, x14, x15, x16, x17, x18, &
|
| 1323 |
-
x19, x2, x3, x4, x5, x6, x7, x8, x9)
|
| 1324 |
-
implicit none
|
| 1325 |
-
REAL*8, intent(in) :: x0
|
| 1326 |
-
REAL*8, intent(in) :: x1
|
| 1327 |
-
REAL*8, intent(in) :: x10
|
| 1328 |
-
REAL*8, intent(in) :: x11
|
| 1329 |
-
REAL*8, intent(in) :: x12
|
| 1330 |
-
REAL*8, intent(in) :: x13
|
| 1331 |
-
REAL*8, intent(in) :: x14
|
| 1332 |
-
REAL*8, intent(in) :: x15
|
| 1333 |
-
REAL*8, intent(in) :: x16
|
| 1334 |
-
REAL*8, intent(in) :: x17
|
| 1335 |
-
REAL*8, intent(in) :: x18
|
| 1336 |
-
REAL*8, intent(in) :: x19
|
| 1337 |
-
REAL*8, intent(in) :: x2
|
| 1338 |
-
REAL*8, intent(in) :: x3
|
| 1339 |
-
REAL*8, intent(in) :: x4
|
| 1340 |
-
REAL*8, intent(in) :: x5
|
| 1341 |
-
REAL*8, intent(in) :: x6
|
| 1342 |
-
REAL*8, intent(in) :: x7
|
| 1343 |
-
REAL*8, intent(in) :: x8
|
| 1344 |
-
REAL*8, intent(in) :: x9
|
| 1345 |
-
test = x0 + x1 + x10 + x11 + x12 + x13 + x14 + x15 + x16 + x17 + x18 + &
|
| 1346 |
-
x19 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
|
| 1347 |
-
end function
|
| 1348 |
-
"""
|
| 1349 |
-
assert source == expected
|
| 1350 |
-
|
| 1351 |
-
|
| 1352 |
-
def test_check_case():
|
| 1353 |
-
x, X = symbols('x,X')
|
| 1354 |
-
raises(CodeGenError, lambda: codegen(('test', x*X), 'f95', 'prefix'))
|
| 1355 |
-
|
| 1356 |
-
|
| 1357 |
-
def test_check_case_false_positive():
|
| 1358 |
-
# The upper case/lower case exception should not be triggered by SymPy
|
| 1359 |
-
# objects that differ only because of assumptions. (It may be useful to
|
| 1360 |
-
# have a check for that as well, but here we only want to test against
|
| 1361 |
-
# false positives with respect to case checking.)
|
| 1362 |
-
x1 = symbols('x')
|
| 1363 |
-
x2 = symbols('x', my_assumption=True)
|
| 1364 |
-
try:
|
| 1365 |
-
codegen(('test', x1*x2), 'f95', 'prefix')
|
| 1366 |
-
except CodeGenError as e:
|
| 1367 |
-
if e.args[0].startswith("Fortran ignores case."):
|
| 1368 |
-
raise AssertionError("This exception should not be raised!")
|
| 1369 |
-
|
| 1370 |
-
|
| 1371 |
-
def test_c_fortran_omit_routine_name():
|
| 1372 |
-
x, y = symbols("x,y")
|
| 1373 |
-
name_expr = [("foo", 2*x)]
|
| 1374 |
-
result = codegen(name_expr, "F95", header=False, empty=False)
|
| 1375 |
-
expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
|
| 1376 |
-
assert result[0][1] == expresult[0][1]
|
| 1377 |
-
|
| 1378 |
-
name_expr = ("foo", x*y)
|
| 1379 |
-
result = codegen(name_expr, "F95", header=False, empty=False)
|
| 1380 |
-
expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
|
| 1381 |
-
assert result[0][1] == expresult[0][1]
|
| 1382 |
-
|
| 1383 |
-
name_expr = ("foo", Matrix([[x, y], [x+y, x-y]]))
|
| 1384 |
-
result = codegen(name_expr, "C89", header=False, empty=False)
|
| 1385 |
-
expresult = codegen(name_expr, "C89", "foo", header=False, empty=False)
|
| 1386 |
-
assert result[0][1] == expresult[0][1]
|
| 1387 |
-
|
| 1388 |
-
|
| 1389 |
-
def test_fcode_matrix_output():
|
| 1390 |
-
x, y, z = symbols('x,y,z')
|
| 1391 |
-
e1 = x + y
|
| 1392 |
-
e2 = Matrix([[x, y], [z, 16]])
|
| 1393 |
-
name_expr = ("test", (e1, e2))
|
| 1394 |
-
result = codegen(name_expr, "f95", "test", header=False, empty=False)
|
| 1395 |
-
source = result[0][1]
|
| 1396 |
-
expected = (
|
| 1397 |
-
"REAL*8 function test(x, y, z, out_%(hash)s)\n"
|
| 1398 |
-
"implicit none\n"
|
| 1399 |
-
"REAL*8, intent(in) :: x\n"
|
| 1400 |
-
"REAL*8, intent(in) :: y\n"
|
| 1401 |
-
"REAL*8, intent(in) :: z\n"
|
| 1402 |
-
"REAL*8, intent(out), dimension(1:2, 1:2) :: out_%(hash)s\n"
|
| 1403 |
-
"out_%(hash)s(1, 1) = x\n"
|
| 1404 |
-
"out_%(hash)s(2, 1) = z\n"
|
| 1405 |
-
"out_%(hash)s(1, 2) = y\n"
|
| 1406 |
-
"out_%(hash)s(2, 2) = 16\n"
|
| 1407 |
-
"test = x + y\n"
|
| 1408 |
-
"end function\n"
|
| 1409 |
-
)
|
| 1410 |
-
# look for the magic number
|
| 1411 |
-
a = source.splitlines()[5]
|
| 1412 |
-
b = a.split('_')
|
| 1413 |
-
out = b[1]
|
| 1414 |
-
expected = expected % {'hash': out}
|
| 1415 |
-
assert source == expected
|
| 1416 |
-
|
| 1417 |
-
|
| 1418 |
-
def test_fcode_results_named_ordered():
|
| 1419 |
-
x, y, z = symbols('x,y,z')
|
| 1420 |
-
B, C = symbols('B,C')
|
| 1421 |
-
A = MatrixSymbol('A', 1, 3)
|
| 1422 |
-
expr1 = Equality(A, Matrix([[1, 2, x]]))
|
| 1423 |
-
expr2 = Equality(C, (x + y)*z)
|
| 1424 |
-
expr3 = Equality(B, 2*x)
|
| 1425 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 1426 |
-
result = codegen(name_expr, "f95", "test", header=False, empty=False,
|
| 1427 |
-
argument_sequence=(x, z, y, C, A, B))
|
| 1428 |
-
source = result[0][1]
|
| 1429 |
-
expected = (
|
| 1430 |
-
"subroutine test(x, z, y, C, A, B)\n"
|
| 1431 |
-
"implicit none\n"
|
| 1432 |
-
"REAL*8, intent(in) :: x\n"
|
| 1433 |
-
"REAL*8, intent(in) :: z\n"
|
| 1434 |
-
"REAL*8, intent(in) :: y\n"
|
| 1435 |
-
"REAL*8, intent(out) :: C\n"
|
| 1436 |
-
"REAL*8, intent(out) :: B\n"
|
| 1437 |
-
"REAL*8, intent(out), dimension(1:1, 1:3) :: A\n"
|
| 1438 |
-
"C = z*(x + y)\n"
|
| 1439 |
-
"A(1, 1) = 1\n"
|
| 1440 |
-
"A(1, 2) = 2\n"
|
| 1441 |
-
"A(1, 3) = x\n"
|
| 1442 |
-
"B = 2*x\n"
|
| 1443 |
-
"end subroutine\n"
|
| 1444 |
-
)
|
| 1445 |
-
assert source == expected
|
| 1446 |
-
|
| 1447 |
-
|
| 1448 |
-
def test_fcode_matrixsymbol_slice():
|
| 1449 |
-
A = MatrixSymbol('A', 2, 3)
|
| 1450 |
-
B = MatrixSymbol('B', 1, 3)
|
| 1451 |
-
C = MatrixSymbol('C', 1, 3)
|
| 1452 |
-
D = MatrixSymbol('D', 2, 1)
|
| 1453 |
-
name_expr = ("test", [Equality(B, A[0, :]),
|
| 1454 |
-
Equality(C, A[1, :]),
|
| 1455 |
-
Equality(D, A[:, 2])])
|
| 1456 |
-
result = codegen(name_expr, "f95", "test", header=False, empty=False)
|
| 1457 |
-
source = result[0][1]
|
| 1458 |
-
expected = (
|
| 1459 |
-
"subroutine test(A, B, C, D)\n"
|
| 1460 |
-
"implicit none\n"
|
| 1461 |
-
"REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
|
| 1462 |
-
"REAL*8, intent(out), dimension(1:1, 1:3) :: B\n"
|
| 1463 |
-
"REAL*8, intent(out), dimension(1:1, 1:3) :: C\n"
|
| 1464 |
-
"REAL*8, intent(out), dimension(1:2, 1:1) :: D\n"
|
| 1465 |
-
"B(1, 1) = A(1, 1)\n"
|
| 1466 |
-
"B(1, 2) = A(1, 2)\n"
|
| 1467 |
-
"B(1, 3) = A(1, 3)\n"
|
| 1468 |
-
"C(1, 1) = A(2, 1)\n"
|
| 1469 |
-
"C(1, 2) = A(2, 2)\n"
|
| 1470 |
-
"C(1, 3) = A(2, 3)\n"
|
| 1471 |
-
"D(1, 1) = A(1, 3)\n"
|
| 1472 |
-
"D(2, 1) = A(2, 3)\n"
|
| 1473 |
-
"end subroutine\n"
|
| 1474 |
-
)
|
| 1475 |
-
assert source == expected
|
| 1476 |
-
|
| 1477 |
-
|
| 1478 |
-
def test_fcode_matrixsymbol_slice_autoname():
|
| 1479 |
-
# see issue #8093
|
| 1480 |
-
A = MatrixSymbol('A', 2, 3)
|
| 1481 |
-
name_expr = ("test", A[:, 1])
|
| 1482 |
-
result = codegen(name_expr, "f95", "test", header=False, empty=False)
|
| 1483 |
-
source = result[0][1]
|
| 1484 |
-
expected = (
|
| 1485 |
-
"subroutine test(A, out_%(hash)s)\n"
|
| 1486 |
-
"implicit none\n"
|
| 1487 |
-
"REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
|
| 1488 |
-
"REAL*8, intent(out), dimension(1:2, 1:1) :: out_%(hash)s\n"
|
| 1489 |
-
"out_%(hash)s(1, 1) = A(1, 2)\n"
|
| 1490 |
-
"out_%(hash)s(2, 1) = A(2, 2)\n"
|
| 1491 |
-
"end subroutine\n"
|
| 1492 |
-
)
|
| 1493 |
-
# look for the magic number
|
| 1494 |
-
a = source.splitlines()[3]
|
| 1495 |
-
b = a.split('_')
|
| 1496 |
-
out = b[1]
|
| 1497 |
-
expected = expected % {'hash': out}
|
| 1498 |
-
assert source == expected
|
| 1499 |
-
|
| 1500 |
-
|
| 1501 |
-
def test_global_vars():
|
| 1502 |
-
x, y, z, t = symbols("x y z t")
|
| 1503 |
-
result = codegen(('f', x*y), "F95", header=False, empty=False,
|
| 1504 |
-
global_vars=(y,))
|
| 1505 |
-
source = result[0][1]
|
| 1506 |
-
expected = (
|
| 1507 |
-
"REAL*8 function f(x)\n"
|
| 1508 |
-
"implicit none\n"
|
| 1509 |
-
"REAL*8, intent(in) :: x\n"
|
| 1510 |
-
"f = x*y\n"
|
| 1511 |
-
"end function\n"
|
| 1512 |
-
)
|
| 1513 |
-
assert source == expected
|
| 1514 |
-
|
| 1515 |
-
expected = (
|
| 1516 |
-
'#include "f.h"\n'
|
| 1517 |
-
'#include <math.h>\n'
|
| 1518 |
-
'double f(double x, double y) {\n'
|
| 1519 |
-
' double f_result;\n'
|
| 1520 |
-
' f_result = x*y + z;\n'
|
| 1521 |
-
' return f_result;\n'
|
| 1522 |
-
'}\n'
|
| 1523 |
-
)
|
| 1524 |
-
result = codegen(('f', x*y+z), "C", header=False, empty=False,
|
| 1525 |
-
global_vars=(z, t))
|
| 1526 |
-
source = result[0][1]
|
| 1527 |
-
assert source == expected
|
| 1528 |
-
|
| 1529 |
-
def test_custom_codegen():
|
| 1530 |
-
from sympy.printing.c import C99CodePrinter
|
| 1531 |
-
from sympy.functions.elementary.exponential import exp
|
| 1532 |
-
|
| 1533 |
-
printer = C99CodePrinter(settings={'user_functions': {'exp': 'fastexp'}})
|
| 1534 |
-
|
| 1535 |
-
x, y = symbols('x y')
|
| 1536 |
-
expr = exp(x + y)
|
| 1537 |
-
|
| 1538 |
-
# replace math.h with a different header
|
| 1539 |
-
gen = C99CodeGen(printer=printer,
|
| 1540 |
-
preprocessor_statements=['#include "fastexp.h"'])
|
| 1541 |
-
|
| 1542 |
-
expected = (
|
| 1543 |
-
'#include "expr.h"\n'
|
| 1544 |
-
'#include "fastexp.h"\n'
|
| 1545 |
-
'double expr(double x, double y) {\n'
|
| 1546 |
-
' double expr_result;\n'
|
| 1547 |
-
' expr_result = fastexp(x + y);\n'
|
| 1548 |
-
' return expr_result;\n'
|
| 1549 |
-
'}\n'
|
| 1550 |
-
)
|
| 1551 |
-
|
| 1552 |
-
result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
|
| 1553 |
-
source = result[0][1]
|
| 1554 |
-
assert source == expected
|
| 1555 |
-
|
| 1556 |
-
# use both math.h and an external header
|
| 1557 |
-
gen = C99CodeGen(printer=printer)
|
| 1558 |
-
gen.preprocessor_statements.append('#include "fastexp.h"')
|
| 1559 |
-
|
| 1560 |
-
expected = (
|
| 1561 |
-
'#include "expr.h"\n'
|
| 1562 |
-
'#include <math.h>\n'
|
| 1563 |
-
'#include "fastexp.h"\n'
|
| 1564 |
-
'double expr(double x, double y) {\n'
|
| 1565 |
-
' double expr_result;\n'
|
| 1566 |
-
' expr_result = fastexp(x + y);\n'
|
| 1567 |
-
' return expr_result;\n'
|
| 1568 |
-
'}\n'
|
| 1569 |
-
)
|
| 1570 |
-
|
| 1571 |
-
result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
|
| 1572 |
-
source = result[0][1]
|
| 1573 |
-
assert source == expected
|
| 1574 |
-
|
| 1575 |
-
def test_c_with_printer():
|
| 1576 |
-
# issue 13586
|
| 1577 |
-
from sympy.printing.c import C99CodePrinter
|
| 1578 |
-
class CustomPrinter(C99CodePrinter):
|
| 1579 |
-
def _print_Pow(self, expr):
|
| 1580 |
-
return "fastpow({}, {})".format(self._print(expr.base),
|
| 1581 |
-
self._print(expr.exp))
|
| 1582 |
-
|
| 1583 |
-
x = symbols('x')
|
| 1584 |
-
expr = x**3
|
| 1585 |
-
expected =[
|
| 1586 |
-
("file.c",
|
| 1587 |
-
"#include \"file.h\"\n"
|
| 1588 |
-
"#include <math.h>\n"
|
| 1589 |
-
"double test(double x) {\n"
|
| 1590 |
-
" double test_result;\n"
|
| 1591 |
-
" test_result = fastpow(x, 3);\n"
|
| 1592 |
-
" return test_result;\n"
|
| 1593 |
-
"}\n"),
|
| 1594 |
-
("file.h",
|
| 1595 |
-
"#ifndef PROJECT__FILE__H\n"
|
| 1596 |
-
"#define PROJECT__FILE__H\n"
|
| 1597 |
-
"double test(double x);\n"
|
| 1598 |
-
"#endif\n")
|
| 1599 |
-
]
|
| 1600 |
-
result = codegen(("test", expr), "C","file", header=False, empty=False, printer = CustomPrinter())
|
| 1601 |
-
assert result == expected
|
| 1602 |
-
|
| 1603 |
-
|
| 1604 |
-
def test_fcode_complex():
|
| 1605 |
-
import sympy.utilities.codegen
|
| 1606 |
-
sympy.utilities.codegen.COMPLEX_ALLOWED = True
|
| 1607 |
-
x = Symbol('x', real=True)
|
| 1608 |
-
y = Symbol('y',real=True)
|
| 1609 |
-
result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
|
| 1610 |
-
source = (result[0][1])
|
| 1611 |
-
expected = (
|
| 1612 |
-
"REAL*8 function test(x, y)\n"
|
| 1613 |
-
"implicit none\n"
|
| 1614 |
-
"REAL*8, intent(in) :: x\n"
|
| 1615 |
-
"REAL*8, intent(in) :: y\n"
|
| 1616 |
-
"test = x + y\n"
|
| 1617 |
-
"end function\n")
|
| 1618 |
-
assert source == expected
|
| 1619 |
-
x = Symbol('x')
|
| 1620 |
-
y = Symbol('y',real=True)
|
| 1621 |
-
result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
|
| 1622 |
-
source = (result[0][1])
|
| 1623 |
-
expected = (
|
| 1624 |
-
"COMPLEX*16 function test(x, y)\n"
|
| 1625 |
-
"implicit none\n"
|
| 1626 |
-
"COMPLEX*16, intent(in) :: x\n"
|
| 1627 |
-
"REAL*8, intent(in) :: y\n"
|
| 1628 |
-
"test = x + y\n"
|
| 1629 |
-
"end function\n"
|
| 1630 |
-
)
|
| 1631 |
-
assert source==expected
|
| 1632 |
-
sympy.utilities.codegen.COMPLEX_ALLOWED = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_julia.py
DELETED
|
@@ -1,620 +0,0 @@
|
|
| 1 |
-
from io import StringIO
|
| 2 |
-
|
| 3 |
-
from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function
|
| 4 |
-
from sympy.core.relational import Equality
|
| 5 |
-
from sympy.functions.elementary.piecewise import Piecewise
|
| 6 |
-
from sympy.matrices import Matrix, MatrixSymbol
|
| 7 |
-
from sympy.utilities.codegen import JuliaCodeGen, codegen, make_routine
|
| 8 |
-
from sympy.testing.pytest import XFAIL
|
| 9 |
-
import sympy
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
x, y, z = symbols('x,y,z')
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def test_empty_jl_code():
|
| 16 |
-
code_gen = JuliaCodeGen()
|
| 17 |
-
output = StringIO()
|
| 18 |
-
code_gen.dump_jl([], output, "file", header=False, empty=False)
|
| 19 |
-
source = output.getvalue()
|
| 20 |
-
assert source == ""
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def test_jl_simple_code():
|
| 24 |
-
name_expr = ("test", (x + y)*z)
|
| 25 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 26 |
-
assert result[0] == "test.jl"
|
| 27 |
-
source = result[1]
|
| 28 |
-
expected = (
|
| 29 |
-
"function test(x, y, z)\n"
|
| 30 |
-
" out1 = z .* (x + y)\n"
|
| 31 |
-
" return out1\n"
|
| 32 |
-
"end\n"
|
| 33 |
-
)
|
| 34 |
-
assert source == expected
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def test_jl_simple_code_with_header():
|
| 38 |
-
name_expr = ("test", (x + y)*z)
|
| 39 |
-
result, = codegen(name_expr, "Julia", header=True, empty=False)
|
| 40 |
-
assert result[0] == "test.jl"
|
| 41 |
-
source = result[1]
|
| 42 |
-
expected = (
|
| 43 |
-
"# Code generated with SymPy " + sympy.__version__ + "\n"
|
| 44 |
-
"#\n"
|
| 45 |
-
"# See http://www.sympy.org/ for more information.\n"
|
| 46 |
-
"#\n"
|
| 47 |
-
"# This file is part of 'project'\n"
|
| 48 |
-
"function test(x, y, z)\n"
|
| 49 |
-
" out1 = z .* (x + y)\n"
|
| 50 |
-
" return out1\n"
|
| 51 |
-
"end\n"
|
| 52 |
-
)
|
| 53 |
-
assert source == expected
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def test_jl_simple_code_nameout():
|
| 57 |
-
expr = Equality(z, (x + y))
|
| 58 |
-
name_expr = ("test", expr)
|
| 59 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 60 |
-
source = result[1]
|
| 61 |
-
expected = (
|
| 62 |
-
"function test(x, y)\n"
|
| 63 |
-
" z = x + y\n"
|
| 64 |
-
" return z\n"
|
| 65 |
-
"end\n"
|
| 66 |
-
)
|
| 67 |
-
assert source == expected
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def test_jl_numbersymbol():
|
| 71 |
-
name_expr = ("test", pi**Catalan)
|
| 72 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 73 |
-
source = result[1]
|
| 74 |
-
expected = (
|
| 75 |
-
"function test()\n"
|
| 76 |
-
" out1 = pi ^ catalan\n"
|
| 77 |
-
" return out1\n"
|
| 78 |
-
"end\n"
|
| 79 |
-
)
|
| 80 |
-
assert source == expected
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
@XFAIL
|
| 84 |
-
def test_jl_numbersymbol_no_inline():
|
| 85 |
-
# FIXME: how to pass inline=False to the JuliaCodePrinter?
|
| 86 |
-
name_expr = ("test", [pi**Catalan, EulerGamma])
|
| 87 |
-
result, = codegen(name_expr, "Julia", header=False,
|
| 88 |
-
empty=False, inline=False)
|
| 89 |
-
source = result[1]
|
| 90 |
-
expected = (
|
| 91 |
-
"function test()\n"
|
| 92 |
-
" Catalan = 0.915965594177219\n"
|
| 93 |
-
" EulerGamma = 0.5772156649015329\n"
|
| 94 |
-
" out1 = pi ^ Catalan\n"
|
| 95 |
-
" out2 = EulerGamma\n"
|
| 96 |
-
" return out1, out2\n"
|
| 97 |
-
"end\n"
|
| 98 |
-
)
|
| 99 |
-
assert source == expected
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def test_jl_code_argument_order():
|
| 103 |
-
expr = x + y
|
| 104 |
-
routine = make_routine("test", expr, argument_sequence=[z, x, y], language="julia")
|
| 105 |
-
code_gen = JuliaCodeGen()
|
| 106 |
-
output = StringIO()
|
| 107 |
-
code_gen.dump_jl([routine], output, "test", header=False, empty=False)
|
| 108 |
-
source = output.getvalue()
|
| 109 |
-
expected = (
|
| 110 |
-
"function test(z, x, y)\n"
|
| 111 |
-
" out1 = x + y\n"
|
| 112 |
-
" return out1\n"
|
| 113 |
-
"end\n"
|
| 114 |
-
)
|
| 115 |
-
assert source == expected
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def test_multiple_results_m():
|
| 119 |
-
# Here the output order is the input order
|
| 120 |
-
expr1 = (x + y)*z
|
| 121 |
-
expr2 = (x - y)*z
|
| 122 |
-
name_expr = ("test", [expr1, expr2])
|
| 123 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 124 |
-
source = result[1]
|
| 125 |
-
expected = (
|
| 126 |
-
"function test(x, y, z)\n"
|
| 127 |
-
" out1 = z .* (x + y)\n"
|
| 128 |
-
" out2 = z .* (x - y)\n"
|
| 129 |
-
" return out1, out2\n"
|
| 130 |
-
"end\n"
|
| 131 |
-
)
|
| 132 |
-
assert source == expected
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
def test_results_named_unordered():
|
| 136 |
-
# Here output order is based on name_expr
|
| 137 |
-
A, B, C = symbols('A,B,C')
|
| 138 |
-
expr1 = Equality(C, (x + y)*z)
|
| 139 |
-
expr2 = Equality(A, (x - y)*z)
|
| 140 |
-
expr3 = Equality(B, 2*x)
|
| 141 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 142 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 143 |
-
source = result[1]
|
| 144 |
-
expected = (
|
| 145 |
-
"function test(x, y, z)\n"
|
| 146 |
-
" C = z .* (x + y)\n"
|
| 147 |
-
" A = z .* (x - y)\n"
|
| 148 |
-
" B = 2 * x\n"
|
| 149 |
-
" return C, A, B\n"
|
| 150 |
-
"end\n"
|
| 151 |
-
)
|
| 152 |
-
assert source == expected
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def test_results_named_ordered():
|
| 156 |
-
A, B, C = symbols('A,B,C')
|
| 157 |
-
expr1 = Equality(C, (x + y)*z)
|
| 158 |
-
expr2 = Equality(A, (x - y)*z)
|
| 159 |
-
expr3 = Equality(B, 2*x)
|
| 160 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 161 |
-
result = codegen(name_expr, "Julia", header=False, empty=False,
|
| 162 |
-
argument_sequence=(x, z, y))
|
| 163 |
-
assert result[0][0] == "test.jl"
|
| 164 |
-
source = result[0][1]
|
| 165 |
-
expected = (
|
| 166 |
-
"function test(x, z, y)\n"
|
| 167 |
-
" C = z .* (x + y)\n"
|
| 168 |
-
" A = z .* (x - y)\n"
|
| 169 |
-
" B = 2 * x\n"
|
| 170 |
-
" return C, A, B\n"
|
| 171 |
-
"end\n"
|
| 172 |
-
)
|
| 173 |
-
assert source == expected
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def test_complicated_jl_codegen():
|
| 177 |
-
from sympy.functions.elementary.trigonometric import (cos, sin, tan)
|
| 178 |
-
name_expr = ("testlong",
|
| 179 |
-
[ ((sin(x) + cos(y) + tan(z))**3).expand(),
|
| 180 |
-
cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
|
| 181 |
-
])
|
| 182 |
-
result = codegen(name_expr, "Julia", header=False, empty=False)
|
| 183 |
-
assert result[0][0] == "testlong.jl"
|
| 184 |
-
source = result[0][1]
|
| 185 |
-
expected = (
|
| 186 |
-
"function testlong(x, y, z)\n"
|
| 187 |
-
" out1 = sin(x) .^ 3 + 3 * sin(x) .^ 2 .* cos(y) + 3 * sin(x) .^ 2 .* tan(z)"
|
| 188 |
-
" + 3 * sin(x) .* cos(y) .^ 2 + 6 * sin(x) .* cos(y) .* tan(z) + 3 * sin(x) .* tan(z) .^ 2"
|
| 189 |
-
" + cos(y) .^ 3 + 3 * cos(y) .^ 2 .* tan(z) + 3 * cos(y) .* tan(z) .^ 2 + tan(z) .^ 3\n"
|
| 190 |
-
" out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n"
|
| 191 |
-
" return out1, out2\n"
|
| 192 |
-
"end\n"
|
| 193 |
-
)
|
| 194 |
-
assert source == expected
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def test_jl_output_arg_mixed_unordered():
|
| 198 |
-
# named outputs are alphabetical, unnamed output appear in the given order
|
| 199 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 200 |
-
a = symbols("a")
|
| 201 |
-
name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
|
| 202 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 203 |
-
assert result[0] == "foo.jl"
|
| 204 |
-
source = result[1]
|
| 205 |
-
expected = (
|
| 206 |
-
'function foo(x)\n'
|
| 207 |
-
' out1 = cos(2 * x)\n'
|
| 208 |
-
' y = sin(x)\n'
|
| 209 |
-
' out3 = cos(x)\n'
|
| 210 |
-
' a = sin(2 * x)\n'
|
| 211 |
-
' return out1, y, out3, a\n'
|
| 212 |
-
'end\n'
|
| 213 |
-
)
|
| 214 |
-
assert source == expected
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
def test_jl_piecewise_():
|
| 218 |
-
pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
|
| 219 |
-
name_expr = ("pwtest", pw)
|
| 220 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 221 |
-
source = result[1]
|
| 222 |
-
expected = (
|
| 223 |
-
"function pwtest(x)\n"
|
| 224 |
-
" out1 = ((x < -1) ? (0) :\n"
|
| 225 |
-
" (x <= 1) ? (x .^ 2) :\n"
|
| 226 |
-
" (x > 1) ? (2 - x) : (1))\n"
|
| 227 |
-
" return out1\n"
|
| 228 |
-
"end\n"
|
| 229 |
-
)
|
| 230 |
-
assert source == expected
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
@XFAIL
|
| 234 |
-
def test_jl_piecewise_no_inline():
|
| 235 |
-
# FIXME: how to pass inline=False to the JuliaCodePrinter?
|
| 236 |
-
pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
|
| 237 |
-
name_expr = ("pwtest", pw)
|
| 238 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False,
|
| 239 |
-
inline=False)
|
| 240 |
-
source = result[1]
|
| 241 |
-
expected = (
|
| 242 |
-
"function pwtest(x)\n"
|
| 243 |
-
" if (x < -1)\n"
|
| 244 |
-
" out1 = 0\n"
|
| 245 |
-
" elseif (x <= 1)\n"
|
| 246 |
-
" out1 = x .^ 2\n"
|
| 247 |
-
" elseif (x > 1)\n"
|
| 248 |
-
" out1 = -x + 2\n"
|
| 249 |
-
" else\n"
|
| 250 |
-
" out1 = 1\n"
|
| 251 |
-
" end\n"
|
| 252 |
-
" return out1\n"
|
| 253 |
-
"end\n"
|
| 254 |
-
)
|
| 255 |
-
assert source == expected
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
def test_jl_multifcns_per_file():
|
| 259 |
-
name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
|
| 260 |
-
result = codegen(name_expr, "Julia", header=False, empty=False)
|
| 261 |
-
assert result[0][0] == "foo.jl"
|
| 262 |
-
source = result[0][1]
|
| 263 |
-
expected = (
|
| 264 |
-
"function foo(x, y)\n"
|
| 265 |
-
" out1 = 2 * x\n"
|
| 266 |
-
" out2 = 3 * y\n"
|
| 267 |
-
" return out1, out2\n"
|
| 268 |
-
"end\n"
|
| 269 |
-
"function bar(y)\n"
|
| 270 |
-
" out1 = y .^ 2\n"
|
| 271 |
-
" out2 = 4 * y\n"
|
| 272 |
-
" return out1, out2\n"
|
| 273 |
-
"end\n"
|
| 274 |
-
)
|
| 275 |
-
assert source == expected
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
def test_jl_multifcns_per_file_w_header():
|
| 279 |
-
name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
|
| 280 |
-
result = codegen(name_expr, "Julia", header=True, empty=False)
|
| 281 |
-
assert result[0][0] == "foo.jl"
|
| 282 |
-
source = result[0][1]
|
| 283 |
-
expected = (
|
| 284 |
-
"# Code generated with SymPy " + sympy.__version__ + "\n"
|
| 285 |
-
"#\n"
|
| 286 |
-
"# See http://www.sympy.org/ for more information.\n"
|
| 287 |
-
"#\n"
|
| 288 |
-
"# This file is part of 'project'\n"
|
| 289 |
-
"function foo(x, y)\n"
|
| 290 |
-
" out1 = 2 * x\n"
|
| 291 |
-
" out2 = 3 * y\n"
|
| 292 |
-
" return out1, out2\n"
|
| 293 |
-
"end\n"
|
| 294 |
-
"function bar(y)\n"
|
| 295 |
-
" out1 = y .^ 2\n"
|
| 296 |
-
" out2 = 4 * y\n"
|
| 297 |
-
" return out1, out2\n"
|
| 298 |
-
"end\n"
|
| 299 |
-
)
|
| 300 |
-
assert source == expected
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
def test_jl_filename_match_prefix():
|
| 304 |
-
name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
|
| 305 |
-
result, = codegen(name_expr, "Julia", prefix="baz", header=False,
|
| 306 |
-
empty=False)
|
| 307 |
-
assert result[0] == "baz.jl"
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
def test_jl_matrix_named():
|
| 311 |
-
e2 = Matrix([[x, 2*y, pi*z]])
|
| 312 |
-
name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2))
|
| 313 |
-
result = codegen(name_expr, "Julia", header=False, empty=False)
|
| 314 |
-
assert result[0][0] == "test.jl"
|
| 315 |
-
source = result[0][1]
|
| 316 |
-
expected = (
|
| 317 |
-
"function test(x, y, z)\n"
|
| 318 |
-
" myout1 = [x 2 * y pi * z]\n"
|
| 319 |
-
" return myout1\n"
|
| 320 |
-
"end\n"
|
| 321 |
-
)
|
| 322 |
-
assert source == expected
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
def test_jl_matrix_named_matsym():
|
| 326 |
-
myout1 = MatrixSymbol('myout1', 1, 3)
|
| 327 |
-
e2 = Matrix([[x, 2*y, pi*z]])
|
| 328 |
-
name_expr = ("test", Equality(myout1, e2, evaluate=False))
|
| 329 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 330 |
-
source = result[1]
|
| 331 |
-
expected = (
|
| 332 |
-
"function test(x, y, z)\n"
|
| 333 |
-
" myout1 = [x 2 * y pi * z]\n"
|
| 334 |
-
" return myout1\n"
|
| 335 |
-
"end\n"
|
| 336 |
-
)
|
| 337 |
-
assert source == expected
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
def test_jl_matrix_output_autoname():
|
| 341 |
-
expr = Matrix([[x, x+y, 3]])
|
| 342 |
-
name_expr = ("test", expr)
|
| 343 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 344 |
-
source = result[1]
|
| 345 |
-
expected = (
|
| 346 |
-
"function test(x, y)\n"
|
| 347 |
-
" out1 = [x x + y 3]\n"
|
| 348 |
-
" return out1\n"
|
| 349 |
-
"end\n"
|
| 350 |
-
)
|
| 351 |
-
assert source == expected
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
def test_jl_matrix_output_autoname_2():
|
| 355 |
-
e1 = (x + y)
|
| 356 |
-
e2 = Matrix([[2*x, 2*y, 2*z]])
|
| 357 |
-
e3 = Matrix([[x], [y], [z]])
|
| 358 |
-
e4 = Matrix([[x, y], [z, 16]])
|
| 359 |
-
name_expr = ("test", (e1, e2, e3, e4))
|
| 360 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 361 |
-
source = result[1]
|
| 362 |
-
expected = (
|
| 363 |
-
"function test(x, y, z)\n"
|
| 364 |
-
" out1 = x + y\n"
|
| 365 |
-
" out2 = [2 * x 2 * y 2 * z]\n"
|
| 366 |
-
" out3 = [x, y, z]\n"
|
| 367 |
-
" out4 = [x y;\n"
|
| 368 |
-
" z 16]\n"
|
| 369 |
-
" return out1, out2, out3, out4\n"
|
| 370 |
-
"end\n"
|
| 371 |
-
)
|
| 372 |
-
assert source == expected
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
def test_jl_results_matrix_named_ordered():
|
| 376 |
-
B, C = symbols('B,C')
|
| 377 |
-
A = MatrixSymbol('A', 1, 3)
|
| 378 |
-
expr1 = Equality(C, (x + y)*z)
|
| 379 |
-
expr2 = Equality(A, Matrix([[1, 2, x]]))
|
| 380 |
-
expr3 = Equality(B, 2*x)
|
| 381 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 382 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False,
|
| 383 |
-
argument_sequence=(x, z, y))
|
| 384 |
-
source = result[1]
|
| 385 |
-
expected = (
|
| 386 |
-
"function test(x, z, y)\n"
|
| 387 |
-
" C = z .* (x + y)\n"
|
| 388 |
-
" A = [1 2 x]\n"
|
| 389 |
-
" B = 2 * x\n"
|
| 390 |
-
" return C, A, B\n"
|
| 391 |
-
"end\n"
|
| 392 |
-
)
|
| 393 |
-
assert source == expected
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
def test_jl_matrixsymbol_slice():
|
| 397 |
-
A = MatrixSymbol('A', 2, 3)
|
| 398 |
-
B = MatrixSymbol('B', 1, 3)
|
| 399 |
-
C = MatrixSymbol('C', 1, 3)
|
| 400 |
-
D = MatrixSymbol('D', 2, 1)
|
| 401 |
-
name_expr = ("test", [Equality(B, A[0, :]),
|
| 402 |
-
Equality(C, A[1, :]),
|
| 403 |
-
Equality(D, A[:, 2])])
|
| 404 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 405 |
-
source = result[1]
|
| 406 |
-
expected = (
|
| 407 |
-
"function test(A)\n"
|
| 408 |
-
" B = A[1,:]\n"
|
| 409 |
-
" C = A[2,:]\n"
|
| 410 |
-
" D = A[:,3]\n"
|
| 411 |
-
" return B, C, D\n"
|
| 412 |
-
"end\n"
|
| 413 |
-
)
|
| 414 |
-
assert source == expected
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
def test_jl_matrixsymbol_slice2():
|
| 418 |
-
A = MatrixSymbol('A', 3, 4)
|
| 419 |
-
B = MatrixSymbol('B', 2, 2)
|
| 420 |
-
C = MatrixSymbol('C', 2, 2)
|
| 421 |
-
name_expr = ("test", [Equality(B, A[0:2, 0:2]),
|
| 422 |
-
Equality(C, A[0:2, 1:3])])
|
| 423 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 424 |
-
source = result[1]
|
| 425 |
-
expected = (
|
| 426 |
-
"function test(A)\n"
|
| 427 |
-
" B = A[1:2,1:2]\n"
|
| 428 |
-
" C = A[1:2,2:3]\n"
|
| 429 |
-
" return B, C\n"
|
| 430 |
-
"end\n"
|
| 431 |
-
)
|
| 432 |
-
assert source == expected
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
def test_jl_matrixsymbol_slice3():
|
| 436 |
-
A = MatrixSymbol('A', 8, 7)
|
| 437 |
-
B = MatrixSymbol('B', 2, 2)
|
| 438 |
-
C = MatrixSymbol('C', 4, 2)
|
| 439 |
-
name_expr = ("test", [Equality(B, A[6:, 1::3]),
|
| 440 |
-
Equality(C, A[::2, ::3])])
|
| 441 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 442 |
-
source = result[1]
|
| 443 |
-
expected = (
|
| 444 |
-
"function test(A)\n"
|
| 445 |
-
" B = A[7:end,2:3:end]\n"
|
| 446 |
-
" C = A[1:2:end,1:3:end]\n"
|
| 447 |
-
" return B, C\n"
|
| 448 |
-
"end\n"
|
| 449 |
-
)
|
| 450 |
-
assert source == expected
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
def test_jl_matrixsymbol_slice_autoname():
|
| 454 |
-
A = MatrixSymbol('A', 2, 3)
|
| 455 |
-
B = MatrixSymbol('B', 1, 3)
|
| 456 |
-
name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]])
|
| 457 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 458 |
-
source = result[1]
|
| 459 |
-
expected = (
|
| 460 |
-
"function test(A)\n"
|
| 461 |
-
" B = A[1,:]\n"
|
| 462 |
-
" out2 = A[2,:]\n"
|
| 463 |
-
" out3 = A[:,1]\n"
|
| 464 |
-
" out4 = A[:,2]\n"
|
| 465 |
-
" return B, out2, out3, out4\n"
|
| 466 |
-
"end\n"
|
| 467 |
-
)
|
| 468 |
-
assert source == expected
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
def test_jl_loops():
|
| 472 |
-
# Note: an Julia programmer would probably vectorize this across one or
|
| 473 |
-
# more dimensions. Also, size(A) would be used rather than passing in m
|
| 474 |
-
# and n. Perhaps users would expect us to vectorize automatically here?
|
| 475 |
-
# Or is it possible to represent such things using IndexedBase?
|
| 476 |
-
from sympy.tensor import IndexedBase, Idx
|
| 477 |
-
from sympy.core.symbol import symbols
|
| 478 |
-
n, m = symbols('n m', integer=True)
|
| 479 |
-
A = IndexedBase('A')
|
| 480 |
-
x = IndexedBase('x')
|
| 481 |
-
y = IndexedBase('y')
|
| 482 |
-
i = Idx('i', m)
|
| 483 |
-
j = Idx('j', n)
|
| 484 |
-
result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Julia",
|
| 485 |
-
header=False, empty=False)
|
| 486 |
-
source = result[1]
|
| 487 |
-
expected = (
|
| 488 |
-
'function mat_vec_mult(y, A, m, n, x)\n'
|
| 489 |
-
' for i = 1:m\n'
|
| 490 |
-
' y[i] = 0\n'
|
| 491 |
-
' end\n'
|
| 492 |
-
' for i = 1:m\n'
|
| 493 |
-
' for j = 1:n\n'
|
| 494 |
-
' y[i] = %(rhs)s + y[i]\n'
|
| 495 |
-
' end\n'
|
| 496 |
-
' end\n'
|
| 497 |
-
' return y\n'
|
| 498 |
-
'end\n'
|
| 499 |
-
)
|
| 500 |
-
assert (source == expected % {'rhs': 'A[%s,%s] .* x[j]' % (i, j)} or
|
| 501 |
-
source == expected % {'rhs': 'x[j] .* A[%s,%s]' % (i, j)})
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
def test_jl_tensor_loops_multiple_contractions():
|
| 505 |
-
# see comments in previous test about vectorizing
|
| 506 |
-
from sympy.tensor import IndexedBase, Idx
|
| 507 |
-
from sympy.core.symbol import symbols
|
| 508 |
-
n, m, o, p = symbols('n m o p', integer=True)
|
| 509 |
-
A = IndexedBase('A')
|
| 510 |
-
B = IndexedBase('B')
|
| 511 |
-
y = IndexedBase('y')
|
| 512 |
-
i = Idx('i', m)
|
| 513 |
-
j = Idx('j', n)
|
| 514 |
-
k = Idx('k', o)
|
| 515 |
-
l = Idx('l', p)
|
| 516 |
-
result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])),
|
| 517 |
-
"Julia", header=False, empty=False)
|
| 518 |
-
source = result[1]
|
| 519 |
-
expected = (
|
| 520 |
-
'function tensorthing(y, A, B, m, n, o, p)\n'
|
| 521 |
-
' for i = 1:m\n'
|
| 522 |
-
' y[i] = 0\n'
|
| 523 |
-
' end\n'
|
| 524 |
-
' for i = 1:m\n'
|
| 525 |
-
' for j = 1:n\n'
|
| 526 |
-
' for k = 1:o\n'
|
| 527 |
-
' for l = 1:p\n'
|
| 528 |
-
' y[i] = A[i,j,k,l] .* B[j,k,l] + y[i]\n'
|
| 529 |
-
' end\n'
|
| 530 |
-
' end\n'
|
| 531 |
-
' end\n'
|
| 532 |
-
' end\n'
|
| 533 |
-
' return y\n'
|
| 534 |
-
'end\n'
|
| 535 |
-
)
|
| 536 |
-
assert source == expected
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
def test_jl_InOutArgument():
|
| 540 |
-
expr = Equality(x, x**2)
|
| 541 |
-
name_expr = ("mysqr", expr)
|
| 542 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 543 |
-
source = result[1]
|
| 544 |
-
expected = (
|
| 545 |
-
"function mysqr(x)\n"
|
| 546 |
-
" x = x .^ 2\n"
|
| 547 |
-
" return x\n"
|
| 548 |
-
"end\n"
|
| 549 |
-
)
|
| 550 |
-
assert source == expected
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
def test_jl_InOutArgument_order():
|
| 554 |
-
# can specify the order as (x, y)
|
| 555 |
-
expr = Equality(x, x**2 + y)
|
| 556 |
-
name_expr = ("test", expr)
|
| 557 |
-
result, = codegen(name_expr, "Julia", header=False,
|
| 558 |
-
empty=False, argument_sequence=(x,y))
|
| 559 |
-
source = result[1]
|
| 560 |
-
expected = (
|
| 561 |
-
"function test(x, y)\n"
|
| 562 |
-
" x = x .^ 2 + y\n"
|
| 563 |
-
" return x\n"
|
| 564 |
-
"end\n"
|
| 565 |
-
)
|
| 566 |
-
assert source == expected
|
| 567 |
-
# make sure it gives (x, y) not (y, x)
|
| 568 |
-
expr = Equality(x, x**2 + y)
|
| 569 |
-
name_expr = ("test", expr)
|
| 570 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 571 |
-
source = result[1]
|
| 572 |
-
expected = (
|
| 573 |
-
"function test(x, y)\n"
|
| 574 |
-
" x = x .^ 2 + y\n"
|
| 575 |
-
" return x\n"
|
| 576 |
-
"end\n"
|
| 577 |
-
)
|
| 578 |
-
assert source == expected
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
def test_jl_not_supported():
|
| 582 |
-
f = Function('f')
|
| 583 |
-
name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
|
| 584 |
-
result, = codegen(name_expr, "Julia", header=False, empty=False)
|
| 585 |
-
source = result[1]
|
| 586 |
-
expected = (
|
| 587 |
-
"function test(x)\n"
|
| 588 |
-
" # unsupported: Derivative(f(x), x)\n"
|
| 589 |
-
" # unsupported: zoo\n"
|
| 590 |
-
" out1 = Derivative(f(x), x)\n"
|
| 591 |
-
" out2 = zoo\n"
|
| 592 |
-
" return out1, out2\n"
|
| 593 |
-
"end\n"
|
| 594 |
-
)
|
| 595 |
-
assert source == expected
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
def test_global_vars_octave():
|
| 599 |
-
x, y, z, t = symbols("x y z t")
|
| 600 |
-
result = codegen(('f', x*y), "Julia", header=False, empty=False,
|
| 601 |
-
global_vars=(y,))
|
| 602 |
-
source = result[0][1]
|
| 603 |
-
expected = (
|
| 604 |
-
"function f(x)\n"
|
| 605 |
-
" out1 = x .* y\n"
|
| 606 |
-
" return out1\n"
|
| 607 |
-
"end\n"
|
| 608 |
-
)
|
| 609 |
-
assert source == expected
|
| 610 |
-
|
| 611 |
-
result = codegen(('f', x*y+z), "Julia", header=False, empty=False,
|
| 612 |
-
argument_sequence=(x, y), global_vars=(z, t))
|
| 613 |
-
source = result[0][1]
|
| 614 |
-
expected = (
|
| 615 |
-
"function f(x, y)\n"
|
| 616 |
-
" out1 = x .* y + z\n"
|
| 617 |
-
" return out1\n"
|
| 618 |
-
"end\n"
|
| 619 |
-
)
|
| 620 |
-
assert source == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_octave.py
DELETED
|
@@ -1,589 +0,0 @@
|
|
| 1 |
-
from io import StringIO
|
| 2 |
-
|
| 3 |
-
from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function
|
| 4 |
-
from sympy.core.relational import Equality
|
| 5 |
-
from sympy.functions.elementary.piecewise import Piecewise
|
| 6 |
-
from sympy.matrices import Matrix, MatrixSymbol
|
| 7 |
-
from sympy.utilities.codegen import OctaveCodeGen, codegen, make_routine
|
| 8 |
-
from sympy.testing.pytest import raises
|
| 9 |
-
from sympy.testing.pytest import XFAIL
|
| 10 |
-
import sympy
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
x, y, z = symbols('x,y,z')
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def test_empty_m_code():
|
| 17 |
-
code_gen = OctaveCodeGen()
|
| 18 |
-
output = StringIO()
|
| 19 |
-
code_gen.dump_m([], output, "file", header=False, empty=False)
|
| 20 |
-
source = output.getvalue()
|
| 21 |
-
assert source == ""
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def test_m_simple_code():
|
| 25 |
-
name_expr = ("test", (x + y)*z)
|
| 26 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 27 |
-
assert result[0] == "test.m"
|
| 28 |
-
source = result[1]
|
| 29 |
-
expected = (
|
| 30 |
-
"function out1 = test(x, y, z)\n"
|
| 31 |
-
" out1 = z.*(x + y);\n"
|
| 32 |
-
"end\n"
|
| 33 |
-
)
|
| 34 |
-
assert source == expected
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def test_m_simple_code_with_header():
|
| 38 |
-
name_expr = ("test", (x + y)*z)
|
| 39 |
-
result, = codegen(name_expr, "Octave", header=True, empty=False)
|
| 40 |
-
assert result[0] == "test.m"
|
| 41 |
-
source = result[1]
|
| 42 |
-
expected = (
|
| 43 |
-
"function out1 = test(x, y, z)\n"
|
| 44 |
-
" %TEST Autogenerated by SymPy\n"
|
| 45 |
-
" % Code generated with SymPy " + sympy.__version__ + "\n"
|
| 46 |
-
" %\n"
|
| 47 |
-
" % See http://www.sympy.org/ for more information.\n"
|
| 48 |
-
" %\n"
|
| 49 |
-
" % This file is part of 'project'\n"
|
| 50 |
-
" out1 = z.*(x + y);\n"
|
| 51 |
-
"end\n"
|
| 52 |
-
)
|
| 53 |
-
assert source == expected
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def test_m_simple_code_nameout():
|
| 57 |
-
expr = Equality(z, (x + y))
|
| 58 |
-
name_expr = ("test", expr)
|
| 59 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 60 |
-
source = result[1]
|
| 61 |
-
expected = (
|
| 62 |
-
"function z = test(x, y)\n"
|
| 63 |
-
" z = x + y;\n"
|
| 64 |
-
"end\n"
|
| 65 |
-
)
|
| 66 |
-
assert source == expected
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def test_m_numbersymbol():
|
| 70 |
-
name_expr = ("test", pi**Catalan)
|
| 71 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 72 |
-
source = result[1]
|
| 73 |
-
expected = (
|
| 74 |
-
"function out1 = test()\n"
|
| 75 |
-
" out1 = pi^%s;\n"
|
| 76 |
-
"end\n"
|
| 77 |
-
) % Catalan.evalf(17)
|
| 78 |
-
assert source == expected
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
@XFAIL
|
| 82 |
-
def test_m_numbersymbol_no_inline():
|
| 83 |
-
# FIXME: how to pass inline=False to the OctaveCodePrinter?
|
| 84 |
-
name_expr = ("test", [pi**Catalan, EulerGamma])
|
| 85 |
-
result, = codegen(name_expr, "Octave", header=False,
|
| 86 |
-
empty=False, inline=False)
|
| 87 |
-
source = result[1]
|
| 88 |
-
expected = (
|
| 89 |
-
"function [out1, out2] = test()\n"
|
| 90 |
-
" Catalan = 0.915965594177219; % constant\n"
|
| 91 |
-
" EulerGamma = 0.5772156649015329; % constant\n"
|
| 92 |
-
" out1 = pi^Catalan;\n"
|
| 93 |
-
" out2 = EulerGamma;\n"
|
| 94 |
-
"end\n"
|
| 95 |
-
)
|
| 96 |
-
assert source == expected
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def test_m_code_argument_order():
|
| 100 |
-
expr = x + y
|
| 101 |
-
routine = make_routine("test", expr, argument_sequence=[z, x, y], language="octave")
|
| 102 |
-
code_gen = OctaveCodeGen()
|
| 103 |
-
output = StringIO()
|
| 104 |
-
code_gen.dump_m([routine], output, "test", header=False, empty=False)
|
| 105 |
-
source = output.getvalue()
|
| 106 |
-
expected = (
|
| 107 |
-
"function out1 = test(z, x, y)\n"
|
| 108 |
-
" out1 = x + y;\n"
|
| 109 |
-
"end\n"
|
| 110 |
-
)
|
| 111 |
-
assert source == expected
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def test_multiple_results_m():
|
| 115 |
-
# Here the output order is the input order
|
| 116 |
-
expr1 = (x + y)*z
|
| 117 |
-
expr2 = (x - y)*z
|
| 118 |
-
name_expr = ("test", [expr1, expr2])
|
| 119 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 120 |
-
source = result[1]
|
| 121 |
-
expected = (
|
| 122 |
-
"function [out1, out2] = test(x, y, z)\n"
|
| 123 |
-
" out1 = z.*(x + y);\n"
|
| 124 |
-
" out2 = z.*(x - y);\n"
|
| 125 |
-
"end\n"
|
| 126 |
-
)
|
| 127 |
-
assert source == expected
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
def test_results_named_unordered():
|
| 131 |
-
# Here output order is based on name_expr
|
| 132 |
-
A, B, C = symbols('A,B,C')
|
| 133 |
-
expr1 = Equality(C, (x + y)*z)
|
| 134 |
-
expr2 = Equality(A, (x - y)*z)
|
| 135 |
-
expr3 = Equality(B, 2*x)
|
| 136 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 137 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 138 |
-
source = result[1]
|
| 139 |
-
expected = (
|
| 140 |
-
"function [C, A, B] = test(x, y, z)\n"
|
| 141 |
-
" C = z.*(x + y);\n"
|
| 142 |
-
" A = z.*(x - y);\n"
|
| 143 |
-
" B = 2*x;\n"
|
| 144 |
-
"end\n"
|
| 145 |
-
)
|
| 146 |
-
assert source == expected
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def test_results_named_ordered():
|
| 150 |
-
A, B, C = symbols('A,B,C')
|
| 151 |
-
expr1 = Equality(C, (x + y)*z)
|
| 152 |
-
expr2 = Equality(A, (x - y)*z)
|
| 153 |
-
expr3 = Equality(B, 2*x)
|
| 154 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 155 |
-
result = codegen(name_expr, "Octave", header=False, empty=False,
|
| 156 |
-
argument_sequence=(x, z, y))
|
| 157 |
-
assert result[0][0] == "test.m"
|
| 158 |
-
source = result[0][1]
|
| 159 |
-
expected = (
|
| 160 |
-
"function [C, A, B] = test(x, z, y)\n"
|
| 161 |
-
" C = z.*(x + y);\n"
|
| 162 |
-
" A = z.*(x - y);\n"
|
| 163 |
-
" B = 2*x;\n"
|
| 164 |
-
"end\n"
|
| 165 |
-
)
|
| 166 |
-
assert source == expected
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
def test_complicated_m_codegen():
|
| 170 |
-
from sympy.functions.elementary.trigonometric import (cos, sin, tan)
|
| 171 |
-
name_expr = ("testlong",
|
| 172 |
-
[ ((sin(x) + cos(y) + tan(z))**3).expand(),
|
| 173 |
-
cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
|
| 174 |
-
])
|
| 175 |
-
result = codegen(name_expr, "Octave", header=False, empty=False)
|
| 176 |
-
assert result[0][0] == "testlong.m"
|
| 177 |
-
source = result[0][1]
|
| 178 |
-
expected = (
|
| 179 |
-
"function [out1, out2] = testlong(x, y, z)\n"
|
| 180 |
-
" out1 = sin(x).^3 + 3*sin(x).^2.*cos(y) + 3*sin(x).^2.*tan(z)"
|
| 181 |
-
" + 3*sin(x).*cos(y).^2 + 6*sin(x).*cos(y).*tan(z) + 3*sin(x).*tan(z).^2"
|
| 182 |
-
" + cos(y).^3 + 3*cos(y).^2.*tan(z) + 3*cos(y).*tan(z).^2 + tan(z).^3;\n"
|
| 183 |
-
" out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n"
|
| 184 |
-
"end\n"
|
| 185 |
-
)
|
| 186 |
-
assert source == expected
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
def test_m_output_arg_mixed_unordered():
|
| 190 |
-
# named outputs are alphabetical, unnamed output appear in the given order
|
| 191 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 192 |
-
a = symbols("a")
|
| 193 |
-
name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
|
| 194 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 195 |
-
assert result[0] == "foo.m"
|
| 196 |
-
source = result[1]
|
| 197 |
-
expected = (
|
| 198 |
-
'function [out1, y, out3, a] = foo(x)\n'
|
| 199 |
-
' out1 = cos(2*x);\n'
|
| 200 |
-
' y = sin(x);\n'
|
| 201 |
-
' out3 = cos(x);\n'
|
| 202 |
-
' a = sin(2*x);\n'
|
| 203 |
-
'end\n'
|
| 204 |
-
)
|
| 205 |
-
assert source == expected
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
def test_m_piecewise_():
|
| 209 |
-
pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
|
| 210 |
-
name_expr = ("pwtest", pw)
|
| 211 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 212 |
-
source = result[1]
|
| 213 |
-
expected = (
|
| 214 |
-
"function out1 = pwtest(x)\n"
|
| 215 |
-
" out1 = ((x < -1).*(0) + (~(x < -1)).*( ...\n"
|
| 216 |
-
" (x <= 1).*(x.^2) + (~(x <= 1)).*( ...\n"
|
| 217 |
-
" (x > 1).*(2 - x) + (~(x > 1)).*(1))));\n"
|
| 218 |
-
"end\n"
|
| 219 |
-
)
|
| 220 |
-
assert source == expected
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
@XFAIL
|
| 224 |
-
def test_m_piecewise_no_inline():
|
| 225 |
-
# FIXME: how to pass inline=False to the OctaveCodePrinter?
|
| 226 |
-
pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
|
| 227 |
-
name_expr = ("pwtest", pw)
|
| 228 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False,
|
| 229 |
-
inline=False)
|
| 230 |
-
source = result[1]
|
| 231 |
-
expected = (
|
| 232 |
-
"function out1 = pwtest(x)\n"
|
| 233 |
-
" if (x < -1)\n"
|
| 234 |
-
" out1 = 0;\n"
|
| 235 |
-
" elseif (x <= 1)\n"
|
| 236 |
-
" out1 = x.^2;\n"
|
| 237 |
-
" elseif (x > 1)\n"
|
| 238 |
-
" out1 = -x + 2;\n"
|
| 239 |
-
" else\n"
|
| 240 |
-
" out1 = 1;\n"
|
| 241 |
-
" end\n"
|
| 242 |
-
"end\n"
|
| 243 |
-
)
|
| 244 |
-
assert source == expected
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
def test_m_multifcns_per_file():
|
| 248 |
-
name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
|
| 249 |
-
result = codegen(name_expr, "Octave", header=False, empty=False)
|
| 250 |
-
assert result[0][0] == "foo.m"
|
| 251 |
-
source = result[0][1]
|
| 252 |
-
expected = (
|
| 253 |
-
"function [out1, out2] = foo(x, y)\n"
|
| 254 |
-
" out1 = 2*x;\n"
|
| 255 |
-
" out2 = 3*y;\n"
|
| 256 |
-
"end\n"
|
| 257 |
-
"function [out1, out2] = bar(y)\n"
|
| 258 |
-
" out1 = y.^2;\n"
|
| 259 |
-
" out2 = 4*y;\n"
|
| 260 |
-
"end\n"
|
| 261 |
-
)
|
| 262 |
-
assert source == expected
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
def test_m_multifcns_per_file_w_header():
|
| 266 |
-
name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
|
| 267 |
-
result = codegen(name_expr, "Octave", header=True, empty=False)
|
| 268 |
-
assert result[0][0] == "foo.m"
|
| 269 |
-
source = result[0][1]
|
| 270 |
-
expected = (
|
| 271 |
-
"function [out1, out2] = foo(x, y)\n"
|
| 272 |
-
" %FOO Autogenerated by SymPy\n"
|
| 273 |
-
" % Code generated with SymPy " + sympy.__version__ + "\n"
|
| 274 |
-
" %\n"
|
| 275 |
-
" % See http://www.sympy.org/ for more information.\n"
|
| 276 |
-
" %\n"
|
| 277 |
-
" % This file is part of 'project'\n"
|
| 278 |
-
" out1 = 2*x;\n"
|
| 279 |
-
" out2 = 3*y;\n"
|
| 280 |
-
"end\n"
|
| 281 |
-
"function [out1, out2] = bar(y)\n"
|
| 282 |
-
" out1 = y.^2;\n"
|
| 283 |
-
" out2 = 4*y;\n"
|
| 284 |
-
"end\n"
|
| 285 |
-
)
|
| 286 |
-
assert source == expected
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
def test_m_filename_match_first_fcn():
|
| 290 |
-
name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
|
| 291 |
-
raises(ValueError, lambda: codegen(name_expr,
|
| 292 |
-
"Octave", prefix="bar", header=False, empty=False))
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
def test_m_matrix_named():
|
| 296 |
-
e2 = Matrix([[x, 2*y, pi*z]])
|
| 297 |
-
name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2))
|
| 298 |
-
result = codegen(name_expr, "Octave", header=False, empty=False)
|
| 299 |
-
assert result[0][0] == "test.m"
|
| 300 |
-
source = result[0][1]
|
| 301 |
-
expected = (
|
| 302 |
-
"function myout1 = test(x, y, z)\n"
|
| 303 |
-
" myout1 = [x 2*y pi*z];\n"
|
| 304 |
-
"end\n"
|
| 305 |
-
)
|
| 306 |
-
assert source == expected
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
def test_m_matrix_named_matsym():
|
| 310 |
-
myout1 = MatrixSymbol('myout1', 1, 3)
|
| 311 |
-
e2 = Matrix([[x, 2*y, pi*z]])
|
| 312 |
-
name_expr = ("test", Equality(myout1, e2, evaluate=False))
|
| 313 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 314 |
-
source = result[1]
|
| 315 |
-
expected = (
|
| 316 |
-
"function myout1 = test(x, y, z)\n"
|
| 317 |
-
" myout1 = [x 2*y pi*z];\n"
|
| 318 |
-
"end\n"
|
| 319 |
-
)
|
| 320 |
-
assert source == expected
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
def test_m_matrix_output_autoname():
|
| 324 |
-
expr = Matrix([[x, x+y, 3]])
|
| 325 |
-
name_expr = ("test", expr)
|
| 326 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 327 |
-
source = result[1]
|
| 328 |
-
expected = (
|
| 329 |
-
"function out1 = test(x, y)\n"
|
| 330 |
-
" out1 = [x x + y 3];\n"
|
| 331 |
-
"end\n"
|
| 332 |
-
)
|
| 333 |
-
assert source == expected
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
def test_m_matrix_output_autoname_2():
|
| 337 |
-
e1 = (x + y)
|
| 338 |
-
e2 = Matrix([[2*x, 2*y, 2*z]])
|
| 339 |
-
e3 = Matrix([[x], [y], [z]])
|
| 340 |
-
e4 = Matrix([[x, y], [z, 16]])
|
| 341 |
-
name_expr = ("test", (e1, e2, e3, e4))
|
| 342 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 343 |
-
source = result[1]
|
| 344 |
-
expected = (
|
| 345 |
-
"function [out1, out2, out3, out4] = test(x, y, z)\n"
|
| 346 |
-
" out1 = x + y;\n"
|
| 347 |
-
" out2 = [2*x 2*y 2*z];\n"
|
| 348 |
-
" out3 = [x; y; z];\n"
|
| 349 |
-
" out4 = [x y; z 16];\n"
|
| 350 |
-
"end\n"
|
| 351 |
-
)
|
| 352 |
-
assert source == expected
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
def test_m_results_matrix_named_ordered():
|
| 356 |
-
B, C = symbols('B,C')
|
| 357 |
-
A = MatrixSymbol('A', 1, 3)
|
| 358 |
-
expr1 = Equality(C, (x + y)*z)
|
| 359 |
-
expr2 = Equality(A, Matrix([[1, 2, x]]))
|
| 360 |
-
expr3 = Equality(B, 2*x)
|
| 361 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 362 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False,
|
| 363 |
-
argument_sequence=(x, z, y))
|
| 364 |
-
source = result[1]
|
| 365 |
-
expected = (
|
| 366 |
-
"function [C, A, B] = test(x, z, y)\n"
|
| 367 |
-
" C = z.*(x + y);\n"
|
| 368 |
-
" A = [1 2 x];\n"
|
| 369 |
-
" B = 2*x;\n"
|
| 370 |
-
"end\n"
|
| 371 |
-
)
|
| 372 |
-
assert source == expected
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
def test_m_matrixsymbol_slice():
|
| 376 |
-
A = MatrixSymbol('A', 2, 3)
|
| 377 |
-
B = MatrixSymbol('B', 1, 3)
|
| 378 |
-
C = MatrixSymbol('C', 1, 3)
|
| 379 |
-
D = MatrixSymbol('D', 2, 1)
|
| 380 |
-
name_expr = ("test", [Equality(B, A[0, :]),
|
| 381 |
-
Equality(C, A[1, :]),
|
| 382 |
-
Equality(D, A[:, 2])])
|
| 383 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 384 |
-
source = result[1]
|
| 385 |
-
expected = (
|
| 386 |
-
"function [B, C, D] = test(A)\n"
|
| 387 |
-
" B = A(1, :);\n"
|
| 388 |
-
" C = A(2, :);\n"
|
| 389 |
-
" D = A(:, 3);\n"
|
| 390 |
-
"end\n"
|
| 391 |
-
)
|
| 392 |
-
assert source == expected
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
def test_m_matrixsymbol_slice2():
|
| 396 |
-
A = MatrixSymbol('A', 3, 4)
|
| 397 |
-
B = MatrixSymbol('B', 2, 2)
|
| 398 |
-
C = MatrixSymbol('C', 2, 2)
|
| 399 |
-
name_expr = ("test", [Equality(B, A[0:2, 0:2]),
|
| 400 |
-
Equality(C, A[0:2, 1:3])])
|
| 401 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 402 |
-
source = result[1]
|
| 403 |
-
expected = (
|
| 404 |
-
"function [B, C] = test(A)\n"
|
| 405 |
-
" B = A(1:2, 1:2);\n"
|
| 406 |
-
" C = A(1:2, 2:3);\n"
|
| 407 |
-
"end\n"
|
| 408 |
-
)
|
| 409 |
-
assert source == expected
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
def test_m_matrixsymbol_slice3():
|
| 413 |
-
A = MatrixSymbol('A', 8, 7)
|
| 414 |
-
B = MatrixSymbol('B', 2, 2)
|
| 415 |
-
C = MatrixSymbol('C', 4, 2)
|
| 416 |
-
name_expr = ("test", [Equality(B, A[6:, 1::3]),
|
| 417 |
-
Equality(C, A[::2, ::3])])
|
| 418 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 419 |
-
source = result[1]
|
| 420 |
-
expected = (
|
| 421 |
-
"function [B, C] = test(A)\n"
|
| 422 |
-
" B = A(7:end, 2:3:end);\n"
|
| 423 |
-
" C = A(1:2:end, 1:3:end);\n"
|
| 424 |
-
"end\n"
|
| 425 |
-
)
|
| 426 |
-
assert source == expected
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
def test_m_matrixsymbol_slice_autoname():
|
| 430 |
-
A = MatrixSymbol('A', 2, 3)
|
| 431 |
-
B = MatrixSymbol('B', 1, 3)
|
| 432 |
-
name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]])
|
| 433 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 434 |
-
source = result[1]
|
| 435 |
-
expected = (
|
| 436 |
-
"function [B, out2, out3, out4] = test(A)\n"
|
| 437 |
-
" B = A(1, :);\n"
|
| 438 |
-
" out2 = A(2, :);\n"
|
| 439 |
-
" out3 = A(:, 1);\n"
|
| 440 |
-
" out4 = A(:, 2);\n"
|
| 441 |
-
"end\n"
|
| 442 |
-
)
|
| 443 |
-
assert source == expected
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
def test_m_loops():
|
| 447 |
-
# Note: an Octave programmer would probably vectorize this across one or
|
| 448 |
-
# more dimensions. Also, size(A) would be used rather than passing in m
|
| 449 |
-
# and n. Perhaps users would expect us to vectorize automatically here?
|
| 450 |
-
# Or is it possible to represent such things using IndexedBase?
|
| 451 |
-
from sympy.tensor import IndexedBase, Idx
|
| 452 |
-
from sympy.core.symbol import symbols
|
| 453 |
-
n, m = symbols('n m', integer=True)
|
| 454 |
-
A = IndexedBase('A')
|
| 455 |
-
x = IndexedBase('x')
|
| 456 |
-
y = IndexedBase('y')
|
| 457 |
-
i = Idx('i', m)
|
| 458 |
-
j = Idx('j', n)
|
| 459 |
-
result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Octave",
|
| 460 |
-
header=False, empty=False)
|
| 461 |
-
source = result[1]
|
| 462 |
-
expected = (
|
| 463 |
-
'function y = mat_vec_mult(A, m, n, x)\n'
|
| 464 |
-
' for i = 1:m\n'
|
| 465 |
-
' y(i) = 0;\n'
|
| 466 |
-
' end\n'
|
| 467 |
-
' for i = 1:m\n'
|
| 468 |
-
' for j = 1:n\n'
|
| 469 |
-
' y(i) = %(rhs)s + y(i);\n'
|
| 470 |
-
' end\n'
|
| 471 |
-
' end\n'
|
| 472 |
-
'end\n'
|
| 473 |
-
)
|
| 474 |
-
assert (source == expected % {'rhs': 'A(%s, %s).*x(j)' % (i, j)} or
|
| 475 |
-
source == expected % {'rhs': 'x(j).*A(%s, %s)' % (i, j)})
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
def test_m_tensor_loops_multiple_contractions():
|
| 479 |
-
# see comments in previous test about vectorizing
|
| 480 |
-
from sympy.tensor import IndexedBase, Idx
|
| 481 |
-
from sympy.core.symbol import symbols
|
| 482 |
-
n, m, o, p = symbols('n m o p', integer=True)
|
| 483 |
-
A = IndexedBase('A')
|
| 484 |
-
B = IndexedBase('B')
|
| 485 |
-
y = IndexedBase('y')
|
| 486 |
-
i = Idx('i', m)
|
| 487 |
-
j = Idx('j', n)
|
| 488 |
-
k = Idx('k', o)
|
| 489 |
-
l = Idx('l', p)
|
| 490 |
-
result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])),
|
| 491 |
-
"Octave", header=False, empty=False)
|
| 492 |
-
source = result[1]
|
| 493 |
-
expected = (
|
| 494 |
-
'function y = tensorthing(A, B, m, n, o, p)\n'
|
| 495 |
-
' for i = 1:m\n'
|
| 496 |
-
' y(i) = 0;\n'
|
| 497 |
-
' end\n'
|
| 498 |
-
' for i = 1:m\n'
|
| 499 |
-
' for j = 1:n\n'
|
| 500 |
-
' for k = 1:o\n'
|
| 501 |
-
' for l = 1:p\n'
|
| 502 |
-
' y(i) = A(i, j, k, l).*B(j, k, l) + y(i);\n'
|
| 503 |
-
' end\n'
|
| 504 |
-
' end\n'
|
| 505 |
-
' end\n'
|
| 506 |
-
' end\n'
|
| 507 |
-
'end\n'
|
| 508 |
-
)
|
| 509 |
-
assert source == expected
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
def test_m_InOutArgument():
|
| 513 |
-
expr = Equality(x, x**2)
|
| 514 |
-
name_expr = ("mysqr", expr)
|
| 515 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 516 |
-
source = result[1]
|
| 517 |
-
expected = (
|
| 518 |
-
"function x = mysqr(x)\n"
|
| 519 |
-
" x = x.^2;\n"
|
| 520 |
-
"end\n"
|
| 521 |
-
)
|
| 522 |
-
assert source == expected
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
def test_m_InOutArgument_order():
|
| 526 |
-
# can specify the order as (x, y)
|
| 527 |
-
expr = Equality(x, x**2 + y)
|
| 528 |
-
name_expr = ("test", expr)
|
| 529 |
-
result, = codegen(name_expr, "Octave", header=False,
|
| 530 |
-
empty=False, argument_sequence=(x,y))
|
| 531 |
-
source = result[1]
|
| 532 |
-
expected = (
|
| 533 |
-
"function x = test(x, y)\n"
|
| 534 |
-
" x = x.^2 + y;\n"
|
| 535 |
-
"end\n"
|
| 536 |
-
)
|
| 537 |
-
assert source == expected
|
| 538 |
-
# make sure it gives (x, y) not (y, x)
|
| 539 |
-
expr = Equality(x, x**2 + y)
|
| 540 |
-
name_expr = ("test", expr)
|
| 541 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 542 |
-
source = result[1]
|
| 543 |
-
expected = (
|
| 544 |
-
"function x = test(x, y)\n"
|
| 545 |
-
" x = x.^2 + y;\n"
|
| 546 |
-
"end\n"
|
| 547 |
-
)
|
| 548 |
-
assert source == expected
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
def test_m_not_supported():
|
| 552 |
-
f = Function('f')
|
| 553 |
-
name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
|
| 554 |
-
result, = codegen(name_expr, "Octave", header=False, empty=False)
|
| 555 |
-
source = result[1]
|
| 556 |
-
expected = (
|
| 557 |
-
"function [out1, out2] = test(x)\n"
|
| 558 |
-
" % unsupported: Derivative(f(x), x)\n"
|
| 559 |
-
" % unsupported: zoo\n"
|
| 560 |
-
" out1 = Derivative(f(x), x);\n"
|
| 561 |
-
" out2 = zoo;\n"
|
| 562 |
-
"end\n"
|
| 563 |
-
)
|
| 564 |
-
assert source == expected
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
def test_global_vars_octave():
|
| 568 |
-
x, y, z, t = symbols("x y z t")
|
| 569 |
-
result = codegen(('f', x*y), "Octave", header=False, empty=False,
|
| 570 |
-
global_vars=(y,))
|
| 571 |
-
source = result[0][1]
|
| 572 |
-
expected = (
|
| 573 |
-
"function out1 = f(x)\n"
|
| 574 |
-
" global y\n"
|
| 575 |
-
" out1 = x.*y;\n"
|
| 576 |
-
"end\n"
|
| 577 |
-
)
|
| 578 |
-
assert source == expected
|
| 579 |
-
|
| 580 |
-
result = codegen(('f', x*y+z), "Octave", header=False, empty=False,
|
| 581 |
-
argument_sequence=(x, y), global_vars=(z, t))
|
| 582 |
-
source = result[0][1]
|
| 583 |
-
expected = (
|
| 584 |
-
"function out1 = f(x, y)\n"
|
| 585 |
-
" global t z\n"
|
| 586 |
-
" out1 = x.*y + z;\n"
|
| 587 |
-
"end\n"
|
| 588 |
-
)
|
| 589 |
-
assert source == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_rust.py
DELETED
|
@@ -1,401 +0,0 @@
|
|
| 1 |
-
from io import StringIO
|
| 2 |
-
|
| 3 |
-
from sympy.core import S, symbols, pi, Catalan, EulerGamma, Function
|
| 4 |
-
from sympy.core.relational import Equality
|
| 5 |
-
from sympy.functions.elementary.piecewise import Piecewise
|
| 6 |
-
from sympy.utilities.codegen import RustCodeGen, codegen, make_routine
|
| 7 |
-
from sympy.testing.pytest import XFAIL
|
| 8 |
-
import sympy
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
x, y, z = symbols('x,y,z')
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def test_empty_rust_code():
|
| 15 |
-
code_gen = RustCodeGen()
|
| 16 |
-
output = StringIO()
|
| 17 |
-
code_gen.dump_rs([], output, "file", header=False, empty=False)
|
| 18 |
-
source = output.getvalue()
|
| 19 |
-
assert source == ""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def test_simple_rust_code():
|
| 23 |
-
name_expr = ("test", (x + y)*z)
|
| 24 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 25 |
-
assert result[0] == "test.rs"
|
| 26 |
-
source = result[1]
|
| 27 |
-
expected = (
|
| 28 |
-
"fn test(x: f64, y: f64, z: f64) -> f64 {\n"
|
| 29 |
-
" let out1 = z*(x + y);\n"
|
| 30 |
-
" out1\n"
|
| 31 |
-
"}\n"
|
| 32 |
-
)
|
| 33 |
-
assert source == expected
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def test_simple_code_with_header():
|
| 37 |
-
name_expr = ("test", (x + y)*z)
|
| 38 |
-
result, = codegen(name_expr, "Rust", header=True, empty=False)
|
| 39 |
-
assert result[0] == "test.rs"
|
| 40 |
-
source = result[1]
|
| 41 |
-
version_str = "Code generated with SymPy %s" % sympy.__version__
|
| 42 |
-
version_line = version_str.center(76).rstrip()
|
| 43 |
-
expected = (
|
| 44 |
-
"/*\n"
|
| 45 |
-
" *%(version_line)s\n"
|
| 46 |
-
" *\n"
|
| 47 |
-
" * See http://www.sympy.org/ for more information.\n"
|
| 48 |
-
" *\n"
|
| 49 |
-
" * This file is part of 'project'\n"
|
| 50 |
-
" */\n"
|
| 51 |
-
"fn test(x: f64, y: f64, z: f64) -> f64 {\n"
|
| 52 |
-
" let out1 = z*(x + y);\n"
|
| 53 |
-
" out1\n"
|
| 54 |
-
"}\n"
|
| 55 |
-
) % {'version_line': version_line}
|
| 56 |
-
assert source == expected
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def test_simple_code_nameout():
|
| 60 |
-
expr = Equality(z, (x + y))
|
| 61 |
-
name_expr = ("test", expr)
|
| 62 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 63 |
-
source = result[1]
|
| 64 |
-
expected = (
|
| 65 |
-
"fn test(x: f64, y: f64) -> f64 {\n"
|
| 66 |
-
" let z = x + y;\n"
|
| 67 |
-
" z\n"
|
| 68 |
-
"}\n"
|
| 69 |
-
)
|
| 70 |
-
assert source == expected
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def test_numbersymbol():
|
| 74 |
-
name_expr = ("test", pi**Catalan)
|
| 75 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 76 |
-
source = result[1]
|
| 77 |
-
expected = (
|
| 78 |
-
"fn test() -> f64 {\n"
|
| 79 |
-
" const Catalan: f64 = %s;\n"
|
| 80 |
-
" let out1 = PI.powf(Catalan);\n"
|
| 81 |
-
" out1\n"
|
| 82 |
-
"}\n"
|
| 83 |
-
) % Catalan.evalf(17)
|
| 84 |
-
assert source == expected
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
@XFAIL
|
| 88 |
-
def test_numbersymbol_inline():
|
| 89 |
-
# FIXME: how to pass inline to the RustCodePrinter?
|
| 90 |
-
name_expr = ("test", [pi**Catalan, EulerGamma])
|
| 91 |
-
result, = codegen(name_expr, "Rust", header=False,
|
| 92 |
-
empty=False, inline=True)
|
| 93 |
-
source = result[1]
|
| 94 |
-
expected = (
|
| 95 |
-
"fn test() -> (f64, f64) {\n"
|
| 96 |
-
" const Catalan: f64 = %s;\n"
|
| 97 |
-
" const EulerGamma: f64 = %s;\n"
|
| 98 |
-
" let out1 = PI.powf(Catalan);\n"
|
| 99 |
-
" let out2 = EulerGamma);\n"
|
| 100 |
-
" (out1, out2)\n"
|
| 101 |
-
"}\n"
|
| 102 |
-
) % (Catalan.evalf(17), EulerGamma.evalf(17))
|
| 103 |
-
assert source == expected
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
def test_argument_order():
|
| 107 |
-
expr = x + y
|
| 108 |
-
routine = make_routine("test", expr, argument_sequence=[z, x, y], language="rust")
|
| 109 |
-
code_gen = RustCodeGen()
|
| 110 |
-
output = StringIO()
|
| 111 |
-
code_gen.dump_rs([routine], output, "test", header=False, empty=False)
|
| 112 |
-
source = output.getvalue()
|
| 113 |
-
expected = (
|
| 114 |
-
"fn test(z: f64, x: f64, y: f64) -> f64 {\n"
|
| 115 |
-
" let out1 = x + y;\n"
|
| 116 |
-
" out1\n"
|
| 117 |
-
"}\n"
|
| 118 |
-
)
|
| 119 |
-
assert source == expected
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
def test_multiple_results_rust():
|
| 123 |
-
# Here the output order is the input order
|
| 124 |
-
expr1 = (x + y)*z
|
| 125 |
-
expr2 = (x - y)*z
|
| 126 |
-
name_expr = ("test", [expr1, expr2])
|
| 127 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 128 |
-
source = result[1]
|
| 129 |
-
expected = (
|
| 130 |
-
"fn test(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
|
| 131 |
-
" let out1 = z*(x + y);\n"
|
| 132 |
-
" let out2 = z*(x - y);\n"
|
| 133 |
-
" (out1, out2)\n"
|
| 134 |
-
"}\n"
|
| 135 |
-
)
|
| 136 |
-
assert source == expected
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def test_results_named_unordered():
|
| 140 |
-
# Here output order is based on name_expr
|
| 141 |
-
A, B, C = symbols('A,B,C')
|
| 142 |
-
expr1 = Equality(C, (x + y)*z)
|
| 143 |
-
expr2 = Equality(A, (x - y)*z)
|
| 144 |
-
expr3 = Equality(B, 2*x)
|
| 145 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 146 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 147 |
-
source = result[1]
|
| 148 |
-
expected = (
|
| 149 |
-
"fn test(x: f64, y: f64, z: f64) -> (f64, f64, f64) {\n"
|
| 150 |
-
" let C = z*(x + y);\n"
|
| 151 |
-
" let A = z*(x - y);\n"
|
| 152 |
-
" let B = 2*x;\n"
|
| 153 |
-
" (C, A, B)\n"
|
| 154 |
-
"}\n"
|
| 155 |
-
)
|
| 156 |
-
assert source == expected
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
def test_results_named_ordered():
|
| 160 |
-
A, B, C = symbols('A,B,C')
|
| 161 |
-
expr1 = Equality(C, (x + y)*z)
|
| 162 |
-
expr2 = Equality(A, (x - y)*z)
|
| 163 |
-
expr3 = Equality(B, 2*x)
|
| 164 |
-
name_expr = ("test", [expr1, expr2, expr3])
|
| 165 |
-
result = codegen(name_expr, "Rust", header=False, empty=False,
|
| 166 |
-
argument_sequence=(x, z, y))
|
| 167 |
-
assert result[0][0] == "test.rs"
|
| 168 |
-
source = result[0][1]
|
| 169 |
-
expected = (
|
| 170 |
-
"fn test(x: f64, z: f64, y: f64) -> (f64, f64, f64) {\n"
|
| 171 |
-
" let C = z*(x + y);\n"
|
| 172 |
-
" let A = z*(x - y);\n"
|
| 173 |
-
" let B = 2*x;\n"
|
| 174 |
-
" (C, A, B)\n"
|
| 175 |
-
"}\n"
|
| 176 |
-
)
|
| 177 |
-
assert source == expected
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
def test_complicated_rs_codegen():
|
| 181 |
-
from sympy.functions.elementary.trigonometric import (cos, sin, tan)
|
| 182 |
-
name_expr = ("testlong",
|
| 183 |
-
[ ((sin(x) + cos(y) + tan(z))**3).expand(),
|
| 184 |
-
cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
|
| 185 |
-
])
|
| 186 |
-
result = codegen(name_expr, "Rust", header=False, empty=False)
|
| 187 |
-
assert result[0][0] == "testlong.rs"
|
| 188 |
-
source = result[0][1]
|
| 189 |
-
expected = (
|
| 190 |
-
"fn testlong(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
|
| 191 |
-
" let out1 = x.sin().powi(3) + 3*x.sin().powi(2)*y.cos()"
|
| 192 |
-
" + 3*x.sin().powi(2)*z.tan() + 3*x.sin()*y.cos().powi(2)"
|
| 193 |
-
" + 6*x.sin()*y.cos()*z.tan() + 3*x.sin()*z.tan().powi(2)"
|
| 194 |
-
" + y.cos().powi(3) + 3*y.cos().powi(2)*z.tan()"
|
| 195 |
-
" + 3*y.cos()*z.tan().powi(2) + z.tan().powi(3);\n"
|
| 196 |
-
" let out2 = (x + y + z).cos().cos().cos().cos()"
|
| 197 |
-
".cos().cos().cos().cos();\n"
|
| 198 |
-
" (out1, out2)\n"
|
| 199 |
-
"}\n"
|
| 200 |
-
)
|
| 201 |
-
assert source == expected
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
def test_output_arg_mixed_unordered():
|
| 205 |
-
# named outputs are alphabetical, unnamed output appear in the given order
|
| 206 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 207 |
-
a = symbols("a")
|
| 208 |
-
name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
|
| 209 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 210 |
-
assert result[0] == "foo.rs"
|
| 211 |
-
source = result[1]
|
| 212 |
-
expected = (
|
| 213 |
-
"fn foo(x: f64) -> (f64, f64, f64, f64) {\n"
|
| 214 |
-
" let out1 = (2*x).cos();\n"
|
| 215 |
-
" let y = x.sin();\n"
|
| 216 |
-
" let out3 = x.cos();\n"
|
| 217 |
-
" let a = (2*x).sin();\n"
|
| 218 |
-
" (out1, y, out3, a)\n"
|
| 219 |
-
"}\n"
|
| 220 |
-
)
|
| 221 |
-
assert source == expected
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
def test_piecewise_():
|
| 225 |
-
pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
|
| 226 |
-
name_expr = ("pwtest", pw)
|
| 227 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 228 |
-
source = result[1]
|
| 229 |
-
expected = (
|
| 230 |
-
"fn pwtest(x: f64) -> f64 {\n"
|
| 231 |
-
" let out1 = if (x < -1.0) {\n"
|
| 232 |
-
" 0\n"
|
| 233 |
-
" } else if (x <= 1.0) {\n"
|
| 234 |
-
" x.powi(2)\n"
|
| 235 |
-
" } else if (x > 1.0) {\n"
|
| 236 |
-
" 2 - x\n"
|
| 237 |
-
" } else {\n"
|
| 238 |
-
" 1\n"
|
| 239 |
-
" };\n"
|
| 240 |
-
" out1\n"
|
| 241 |
-
"}\n"
|
| 242 |
-
)
|
| 243 |
-
assert source == expected
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
@XFAIL
|
| 247 |
-
def test_piecewise_inline():
|
| 248 |
-
# FIXME: how to pass inline to the RustCodePrinter?
|
| 249 |
-
pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
|
| 250 |
-
name_expr = ("pwtest", pw)
|
| 251 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False,
|
| 252 |
-
inline=True)
|
| 253 |
-
source = result[1]
|
| 254 |
-
expected = (
|
| 255 |
-
"fn pwtest(x: f64) -> f64 {\n"
|
| 256 |
-
" let out1 = if (x < -1) { 0 } else if (x <= 1) { x.powi(2) }"
|
| 257 |
-
" else if (x > 1) { -x + 2 } else { 1 };\n"
|
| 258 |
-
" out1\n"
|
| 259 |
-
"}\n"
|
| 260 |
-
)
|
| 261 |
-
assert source == expected
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
def test_multifcns_per_file():
|
| 265 |
-
name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
|
| 266 |
-
result = codegen(name_expr, "Rust", header=False, empty=False)
|
| 267 |
-
assert result[0][0] == "foo.rs"
|
| 268 |
-
source = result[0][1]
|
| 269 |
-
expected = (
|
| 270 |
-
"fn foo(x: f64, y: f64) -> (f64, f64) {\n"
|
| 271 |
-
" let out1 = 2*x;\n"
|
| 272 |
-
" let out2 = 3*y;\n"
|
| 273 |
-
" (out1, out2)\n"
|
| 274 |
-
"}\n"
|
| 275 |
-
"fn bar(y: f64) -> (f64, f64) {\n"
|
| 276 |
-
" let out1 = y.powi(2);\n"
|
| 277 |
-
" let out2 = 4*y;\n"
|
| 278 |
-
" (out1, out2)\n"
|
| 279 |
-
"}\n"
|
| 280 |
-
)
|
| 281 |
-
assert source == expected
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
def test_multifcns_per_file_w_header():
|
| 285 |
-
name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
|
| 286 |
-
result = codegen(name_expr, "Rust", header=True, empty=False)
|
| 287 |
-
assert result[0][0] == "foo.rs"
|
| 288 |
-
source = result[0][1]
|
| 289 |
-
version_str = "Code generated with SymPy %s" % sympy.__version__
|
| 290 |
-
version_line = version_str.center(76).rstrip()
|
| 291 |
-
expected = (
|
| 292 |
-
"/*\n"
|
| 293 |
-
" *%(version_line)s\n"
|
| 294 |
-
" *\n"
|
| 295 |
-
" * See http://www.sympy.org/ for more information.\n"
|
| 296 |
-
" *\n"
|
| 297 |
-
" * This file is part of 'project'\n"
|
| 298 |
-
" */\n"
|
| 299 |
-
"fn foo(x: f64, y: f64) -> (f64, f64) {\n"
|
| 300 |
-
" let out1 = 2*x;\n"
|
| 301 |
-
" let out2 = 3*y;\n"
|
| 302 |
-
" (out1, out2)\n"
|
| 303 |
-
"}\n"
|
| 304 |
-
"fn bar(y: f64) -> (f64, f64) {\n"
|
| 305 |
-
" let out1 = y.powi(2);\n"
|
| 306 |
-
" let out2 = 4*y;\n"
|
| 307 |
-
" (out1, out2)\n"
|
| 308 |
-
"}\n"
|
| 309 |
-
) % {'version_line': version_line}
|
| 310 |
-
assert source == expected
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
def test_filename_match_prefix():
|
| 314 |
-
name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
|
| 315 |
-
result, = codegen(name_expr, "Rust", prefix="baz", header=False,
|
| 316 |
-
empty=False)
|
| 317 |
-
assert result[0] == "baz.rs"
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
def test_InOutArgument():
|
| 321 |
-
expr = Equality(x, x**2)
|
| 322 |
-
name_expr = ("mysqr", expr)
|
| 323 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 324 |
-
source = result[1]
|
| 325 |
-
expected = (
|
| 326 |
-
"fn mysqr(x: f64) -> f64 {\n"
|
| 327 |
-
" let x = x.powi(2);\n"
|
| 328 |
-
" x\n"
|
| 329 |
-
"}\n"
|
| 330 |
-
)
|
| 331 |
-
assert source == expected
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
def test_InOutArgument_order():
|
| 335 |
-
# can specify the order as (x, y)
|
| 336 |
-
expr = Equality(x, x**2 + y)
|
| 337 |
-
name_expr = ("test", expr)
|
| 338 |
-
result, = codegen(name_expr, "Rust", header=False,
|
| 339 |
-
empty=False, argument_sequence=(x,y))
|
| 340 |
-
source = result[1]
|
| 341 |
-
expected = (
|
| 342 |
-
"fn test(x: f64, y: f64) -> f64 {\n"
|
| 343 |
-
" let x = x.powi(2) + y;\n"
|
| 344 |
-
" x\n"
|
| 345 |
-
"}\n"
|
| 346 |
-
)
|
| 347 |
-
assert source == expected
|
| 348 |
-
# make sure it gives (x, y) not (y, x)
|
| 349 |
-
expr = Equality(x, x**2 + y)
|
| 350 |
-
name_expr = ("test", expr)
|
| 351 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 352 |
-
source = result[1]
|
| 353 |
-
expected = (
|
| 354 |
-
"fn test(x: f64, y: f64) -> f64 {\n"
|
| 355 |
-
" let x = x.powi(2) + y;\n"
|
| 356 |
-
" x\n"
|
| 357 |
-
"}\n"
|
| 358 |
-
)
|
| 359 |
-
assert source == expected
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
def test_not_supported():
|
| 363 |
-
f = Function('f')
|
| 364 |
-
name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
|
| 365 |
-
result, = codegen(name_expr, "Rust", header=False, empty=False)
|
| 366 |
-
source = result[1]
|
| 367 |
-
expected = (
|
| 368 |
-
"fn test(x: f64) -> (f64, f64) {\n"
|
| 369 |
-
" // unsupported: Derivative(f(x), x)\n"
|
| 370 |
-
" // unsupported: zoo\n"
|
| 371 |
-
" let out1 = Derivative(f(x), x);\n"
|
| 372 |
-
" let out2 = zoo;\n"
|
| 373 |
-
" (out1, out2)\n"
|
| 374 |
-
"}\n"
|
| 375 |
-
)
|
| 376 |
-
assert source == expected
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
def test_global_vars_rust():
|
| 380 |
-
x, y, z, t = symbols("x y z t")
|
| 381 |
-
result = codegen(('f', x*y), "Rust", header=False, empty=False,
|
| 382 |
-
global_vars=(y,))
|
| 383 |
-
source = result[0][1]
|
| 384 |
-
expected = (
|
| 385 |
-
"fn f(x: f64) -> f64 {\n"
|
| 386 |
-
" let out1 = x*y;\n"
|
| 387 |
-
" out1\n"
|
| 388 |
-
"}\n"
|
| 389 |
-
)
|
| 390 |
-
assert source == expected
|
| 391 |
-
|
| 392 |
-
result = codegen(('f', x*y+z), "Rust", header=False, empty=False,
|
| 393 |
-
argument_sequence=(x, y), global_vars=(z, t))
|
| 394 |
-
source = result[0][1]
|
| 395 |
-
expected = (
|
| 396 |
-
"fn f(x: f64, y: f64) -> f64 {\n"
|
| 397 |
-
" let out1 = x*y + z;\n"
|
| 398 |
-
" out1\n"
|
| 399 |
-
"}\n"
|
| 400 |
-
)
|
| 401 |
-
assert source == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_decorator.py
DELETED
|
@@ -1,129 +0,0 @@
|
|
| 1 |
-
from functools import wraps
|
| 2 |
-
|
| 3 |
-
from sympy.utilities.decorator import threaded, xthreaded, memoize_property, deprecated
|
| 4 |
-
from sympy.testing.pytest import warns_deprecated_sympy
|
| 5 |
-
|
| 6 |
-
from sympy.core.basic import Basic
|
| 7 |
-
from sympy.core.relational import Eq
|
| 8 |
-
from sympy.matrices.dense import Matrix
|
| 9 |
-
|
| 10 |
-
from sympy.abc import x, y
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def test_threaded():
|
| 14 |
-
@threaded
|
| 15 |
-
def function(expr, *args):
|
| 16 |
-
return 2*expr + sum(args)
|
| 17 |
-
|
| 18 |
-
assert function(Matrix([[x, y], [1, x]]), 1, 2) == \
|
| 19 |
-
Matrix([[2*x + 3, 2*y + 3], [5, 2*x + 3]])
|
| 20 |
-
|
| 21 |
-
assert function(Eq(x, y), 1, 2) == Eq(2*x + 3, 2*y + 3)
|
| 22 |
-
|
| 23 |
-
assert function([x, y], 1, 2) == [2*x + 3, 2*y + 3]
|
| 24 |
-
assert function((x, y), 1, 2) == (2*x + 3, 2*y + 3)
|
| 25 |
-
|
| 26 |
-
assert function({x, y}, 1, 2) == {2*x + 3, 2*y + 3}
|
| 27 |
-
|
| 28 |
-
@threaded
|
| 29 |
-
def function(expr, n):
|
| 30 |
-
return expr**n
|
| 31 |
-
|
| 32 |
-
assert function(x + y, 2) == x**2 + y**2
|
| 33 |
-
assert function(x, 2) == x**2
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def test_xthreaded():
|
| 37 |
-
@xthreaded
|
| 38 |
-
def function(expr, n):
|
| 39 |
-
return expr**n
|
| 40 |
-
|
| 41 |
-
assert function(x + y, 2) == (x + y)**2
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def test_wraps():
|
| 45 |
-
def my_func(x):
|
| 46 |
-
"""My function. """
|
| 47 |
-
|
| 48 |
-
my_func.is_my_func = True
|
| 49 |
-
|
| 50 |
-
new_my_func = threaded(my_func)
|
| 51 |
-
new_my_func = wraps(my_func)(new_my_func)
|
| 52 |
-
|
| 53 |
-
assert new_my_func.__name__ == 'my_func'
|
| 54 |
-
assert new_my_func.__doc__ == 'My function. '
|
| 55 |
-
assert hasattr(new_my_func, 'is_my_func')
|
| 56 |
-
assert new_my_func.is_my_func is True
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def test_memoize_property():
|
| 60 |
-
class TestMemoize(Basic):
|
| 61 |
-
@memoize_property
|
| 62 |
-
def prop(self):
|
| 63 |
-
return Basic()
|
| 64 |
-
|
| 65 |
-
member = TestMemoize()
|
| 66 |
-
obj1 = member.prop
|
| 67 |
-
obj2 = member.prop
|
| 68 |
-
assert obj1 is obj2
|
| 69 |
-
|
| 70 |
-
def test_deprecated():
|
| 71 |
-
@deprecated('deprecated_function is deprecated',
|
| 72 |
-
deprecated_since_version='1.10',
|
| 73 |
-
# This is the target at the top of the file, which will never
|
| 74 |
-
# go away.
|
| 75 |
-
active_deprecations_target='active-deprecations')
|
| 76 |
-
def deprecated_function(x):
|
| 77 |
-
return x
|
| 78 |
-
|
| 79 |
-
with warns_deprecated_sympy():
|
| 80 |
-
assert deprecated_function(1) == 1
|
| 81 |
-
|
| 82 |
-
@deprecated('deprecated_class is deprecated',
|
| 83 |
-
deprecated_since_version='1.10',
|
| 84 |
-
active_deprecations_target='active-deprecations')
|
| 85 |
-
class deprecated_class:
|
| 86 |
-
pass
|
| 87 |
-
|
| 88 |
-
with warns_deprecated_sympy():
|
| 89 |
-
assert isinstance(deprecated_class(), deprecated_class)
|
| 90 |
-
|
| 91 |
-
# Ensure the class decorator works even when the class never returns
|
| 92 |
-
# itself
|
| 93 |
-
@deprecated('deprecated_class_new is deprecated',
|
| 94 |
-
deprecated_since_version='1.10',
|
| 95 |
-
active_deprecations_target='active-deprecations')
|
| 96 |
-
class deprecated_class_new:
|
| 97 |
-
def __new__(cls, arg):
|
| 98 |
-
return arg
|
| 99 |
-
|
| 100 |
-
with warns_deprecated_sympy():
|
| 101 |
-
assert deprecated_class_new(1) == 1
|
| 102 |
-
|
| 103 |
-
@deprecated('deprecated_class_init is deprecated',
|
| 104 |
-
deprecated_since_version='1.10',
|
| 105 |
-
active_deprecations_target='active-deprecations')
|
| 106 |
-
class deprecated_class_init:
|
| 107 |
-
def __init__(self, arg):
|
| 108 |
-
self.arg = 1
|
| 109 |
-
|
| 110 |
-
with warns_deprecated_sympy():
|
| 111 |
-
assert deprecated_class_init(1).arg == 1
|
| 112 |
-
|
| 113 |
-
@deprecated('deprecated_class_new_init is deprecated',
|
| 114 |
-
deprecated_since_version='1.10',
|
| 115 |
-
active_deprecations_target='active-deprecations')
|
| 116 |
-
class deprecated_class_new_init:
|
| 117 |
-
def __new__(cls, arg):
|
| 118 |
-
if arg == 0:
|
| 119 |
-
return arg
|
| 120 |
-
return object.__new__(cls)
|
| 121 |
-
|
| 122 |
-
def __init__(self, arg):
|
| 123 |
-
self.arg = 1
|
| 124 |
-
|
| 125 |
-
with warns_deprecated_sympy():
|
| 126 |
-
assert deprecated_class_new_init(0) == 0
|
| 127 |
-
|
| 128 |
-
with warns_deprecated_sympy():
|
| 129 |
-
assert deprecated_class_new_init(1).arg == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_deprecated.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from sympy.testing.pytest import warns_deprecated_sympy
|
| 2 |
-
|
| 3 |
-
# See https://github.com/sympy/sympy/pull/18095
|
| 4 |
-
|
| 5 |
-
def test_deprecated_utilities():
|
| 6 |
-
with warns_deprecated_sympy():
|
| 7 |
-
import sympy.utilities.pytest # noqa:F401
|
| 8 |
-
with warns_deprecated_sympy():
|
| 9 |
-
import sympy.utilities.runtests # noqa:F401
|
| 10 |
-
with warns_deprecated_sympy():
|
| 11 |
-
import sympy.utilities.randtest # noqa:F401
|
| 12 |
-
with warns_deprecated_sympy():
|
| 13 |
-
import sympy.utilities.tmpfiles # noqa:F401
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_enumerative.py
DELETED
|
@@ -1,179 +0,0 @@
|
|
| 1 |
-
import string
|
| 2 |
-
from itertools import zip_longest
|
| 3 |
-
|
| 4 |
-
from sympy.utilities.enumerative import (
|
| 5 |
-
list_visitor,
|
| 6 |
-
MultisetPartitionTraverser,
|
| 7 |
-
multiset_partitions_taocp
|
| 8 |
-
)
|
| 9 |
-
from sympy.utilities.iterables import _set_partitions
|
| 10 |
-
|
| 11 |
-
# first some functions only useful as test scaffolding - these provide
|
| 12 |
-
# straightforward, but slow reference implementations against which to
|
| 13 |
-
# compare the real versions, and also a comparison to verify that
|
| 14 |
-
# different versions are giving identical results.
|
| 15 |
-
|
| 16 |
-
def part_range_filter(partition_iterator, lb, ub):
|
| 17 |
-
"""
|
| 18 |
-
Filters (on the number of parts) a multiset partition enumeration
|
| 19 |
-
|
| 20 |
-
Arguments
|
| 21 |
-
=========
|
| 22 |
-
|
| 23 |
-
lb, and ub are a range (in the Python slice sense) on the lpart
|
| 24 |
-
variable returned from a multiset partition enumeration. Recall
|
| 25 |
-
that lpart is 0-based (it points to the topmost part on the part
|
| 26 |
-
stack), so if you want to return parts of sizes 2,3,4,5 you would
|
| 27 |
-
use lb=1 and ub=5.
|
| 28 |
-
"""
|
| 29 |
-
for state in partition_iterator:
|
| 30 |
-
f, lpart, pstack = state
|
| 31 |
-
if lpart >= lb and lpart < ub:
|
| 32 |
-
yield state
|
| 33 |
-
|
| 34 |
-
def multiset_partitions_baseline(multiplicities, components):
|
| 35 |
-
"""Enumerates partitions of a multiset
|
| 36 |
-
|
| 37 |
-
Parameters
|
| 38 |
-
==========
|
| 39 |
-
|
| 40 |
-
multiplicities
|
| 41 |
-
list of integer multiplicities of the components of the multiset.
|
| 42 |
-
|
| 43 |
-
components
|
| 44 |
-
the components (elements) themselves
|
| 45 |
-
|
| 46 |
-
Returns
|
| 47 |
-
=======
|
| 48 |
-
|
| 49 |
-
Set of partitions. Each partition is tuple of parts, and each
|
| 50 |
-
part is a tuple of components (with repeats to indicate
|
| 51 |
-
multiplicity)
|
| 52 |
-
|
| 53 |
-
Notes
|
| 54 |
-
=====
|
| 55 |
-
|
| 56 |
-
Multiset partitions can be created as equivalence classes of set
|
| 57 |
-
partitions, and this function does just that. This approach is
|
| 58 |
-
slow and memory intensive compared to the more advanced algorithms
|
| 59 |
-
available, but the code is simple and easy to understand. Hence
|
| 60 |
-
this routine is strictly for testing -- to provide a
|
| 61 |
-
straightforward baseline against which to regress the production
|
| 62 |
-
versions. (This code is a simplified version of an earlier
|
| 63 |
-
production implementation.)
|
| 64 |
-
"""
|
| 65 |
-
|
| 66 |
-
canon = [] # list of components with repeats
|
| 67 |
-
for ct, elem in zip(multiplicities, components):
|
| 68 |
-
canon.extend([elem]*ct)
|
| 69 |
-
|
| 70 |
-
# accumulate the multiset partitions in a set to eliminate dups
|
| 71 |
-
cache = set()
|
| 72 |
-
n = len(canon)
|
| 73 |
-
for nc, q in _set_partitions(n):
|
| 74 |
-
rv = [[] for i in range(nc)]
|
| 75 |
-
for i in range(n):
|
| 76 |
-
rv[q[i]].append(canon[i])
|
| 77 |
-
canonical = tuple(
|
| 78 |
-
sorted([tuple(p) for p in rv]))
|
| 79 |
-
cache.add(canonical)
|
| 80 |
-
return cache
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def compare_multiset_w_baseline(multiplicities):
|
| 84 |
-
"""
|
| 85 |
-
Enumerates the partitions of multiset with AOCP algorithm and
|
| 86 |
-
baseline implementation, and compare the results.
|
| 87 |
-
|
| 88 |
-
"""
|
| 89 |
-
letters = string.ascii_lowercase
|
| 90 |
-
bl_partitions = multiset_partitions_baseline(multiplicities, letters)
|
| 91 |
-
|
| 92 |
-
# The partitions returned by the different algorithms may have
|
| 93 |
-
# their parts in different orders. Also, they generate partitions
|
| 94 |
-
# in different orders. Hence the sorting, and set comparison.
|
| 95 |
-
|
| 96 |
-
aocp_partitions = set()
|
| 97 |
-
for state in multiset_partitions_taocp(multiplicities):
|
| 98 |
-
p1 = tuple(sorted(
|
| 99 |
-
[tuple(p) for p in list_visitor(state, letters)]))
|
| 100 |
-
aocp_partitions.add(p1)
|
| 101 |
-
|
| 102 |
-
assert bl_partitions == aocp_partitions
|
| 103 |
-
|
| 104 |
-
def compare_multiset_states(s1, s2):
|
| 105 |
-
"""compare for equality two instances of multiset partition states
|
| 106 |
-
|
| 107 |
-
This is useful for comparing different versions of the algorithm
|
| 108 |
-
to verify correctness."""
|
| 109 |
-
# Comparison is physical, the only use of semantics is to ignore
|
| 110 |
-
# trash off the top of the stack.
|
| 111 |
-
f1, lpart1, pstack1 = s1
|
| 112 |
-
f2, lpart2, pstack2 = s2
|
| 113 |
-
|
| 114 |
-
if (lpart1 == lpart2) and (f1[0:lpart1+1] == f2[0:lpart2+1]):
|
| 115 |
-
if pstack1[0:f1[lpart1+1]] == pstack2[0:f2[lpart2+1]]:
|
| 116 |
-
return True
|
| 117 |
-
return False
|
| 118 |
-
|
| 119 |
-
def test_multiset_partitions_taocp():
|
| 120 |
-
"""Compares the output of multiset_partitions_taocp with a baseline
|
| 121 |
-
(set partition based) implementation."""
|
| 122 |
-
|
| 123 |
-
# Test cases should not be too large, since the baseline
|
| 124 |
-
# implementation is fairly slow.
|
| 125 |
-
multiplicities = [2,2]
|
| 126 |
-
compare_multiset_w_baseline(multiplicities)
|
| 127 |
-
|
| 128 |
-
multiplicities = [4,3,1]
|
| 129 |
-
compare_multiset_w_baseline(multiplicities)
|
| 130 |
-
|
| 131 |
-
def test_multiset_partitions_versions():
|
| 132 |
-
"""Compares Knuth-based versions of multiset_partitions"""
|
| 133 |
-
multiplicities = [5,2,2,1]
|
| 134 |
-
m = MultisetPartitionTraverser()
|
| 135 |
-
for s1, s2 in zip_longest(m.enum_all(multiplicities),
|
| 136 |
-
multiset_partitions_taocp(multiplicities)):
|
| 137 |
-
assert compare_multiset_states(s1, s2)
|
| 138 |
-
|
| 139 |
-
def subrange_exercise(mult, lb, ub):
|
| 140 |
-
"""Compare filter-based and more optimized subrange implementations
|
| 141 |
-
|
| 142 |
-
Helper for tests, called with both small and larger multisets.
|
| 143 |
-
"""
|
| 144 |
-
m = MultisetPartitionTraverser()
|
| 145 |
-
assert m.count_partitions(mult) == \
|
| 146 |
-
m.count_partitions_slow(mult)
|
| 147 |
-
|
| 148 |
-
# Note - multiple traversals from the same
|
| 149 |
-
# MultisetPartitionTraverser object cannot execute at the same
|
| 150 |
-
# time, hence make several instances here.
|
| 151 |
-
ma = MultisetPartitionTraverser()
|
| 152 |
-
mc = MultisetPartitionTraverser()
|
| 153 |
-
md = MultisetPartitionTraverser()
|
| 154 |
-
|
| 155 |
-
# Several paths to compute just the size two partitions
|
| 156 |
-
a_it = ma.enum_range(mult, lb, ub)
|
| 157 |
-
b_it = part_range_filter(multiset_partitions_taocp(mult), lb, ub)
|
| 158 |
-
c_it = part_range_filter(mc.enum_small(mult, ub), lb, sum(mult))
|
| 159 |
-
d_it = part_range_filter(md.enum_large(mult, lb), 0, ub)
|
| 160 |
-
|
| 161 |
-
for sa, sb, sc, sd in zip_longest(a_it, b_it, c_it, d_it):
|
| 162 |
-
assert compare_multiset_states(sa, sb)
|
| 163 |
-
assert compare_multiset_states(sa, sc)
|
| 164 |
-
assert compare_multiset_states(sa, sd)
|
| 165 |
-
|
| 166 |
-
def test_subrange():
|
| 167 |
-
# Quick, but doesn't hit some of the corner cases
|
| 168 |
-
mult = [4,4,2,1] # mississippi
|
| 169 |
-
lb = 1
|
| 170 |
-
ub = 2
|
| 171 |
-
subrange_exercise(mult, lb, ub)
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
def test_subrange_large():
|
| 175 |
-
# takes a second or so, depending on cpu, Python version, etc.
|
| 176 |
-
mult = [6,3,2,1]
|
| 177 |
-
lb = 4
|
| 178 |
-
ub = 7
|
| 179 |
-
subrange_exercise(mult, lb, ub)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_exceptions.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
from sympy.testing.pytest import raises
|
| 2 |
-
from sympy.utilities.exceptions import sympy_deprecation_warning
|
| 3 |
-
|
| 4 |
-
# Only test exceptions here because the other cases are tested in the
|
| 5 |
-
# warns_deprecated_sympy tests
|
| 6 |
-
def test_sympy_deprecation_warning():
|
| 7 |
-
raises(TypeError, lambda: sympy_deprecation_warning('test',
|
| 8 |
-
deprecated_since_version=1.10,
|
| 9 |
-
active_deprecations_target='active-deprecations'))
|
| 10 |
-
|
| 11 |
-
raises(ValueError, lambda: sympy_deprecation_warning('test',
|
| 12 |
-
deprecated_since_version="1.10", active_deprecations_target='(active-deprecations)='))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_iterables.py
DELETED
|
@@ -1,945 +0,0 @@
|
|
| 1 |
-
from textwrap import dedent
|
| 2 |
-
from itertools import islice, product
|
| 3 |
-
|
| 4 |
-
from sympy.core.basic import Basic
|
| 5 |
-
from sympy.core.numbers import Integer
|
| 6 |
-
from sympy.core.sorting import ordered
|
| 7 |
-
from sympy.core.symbol import (Dummy, symbols)
|
| 8 |
-
from sympy.functions.combinatorial.factorials import factorial
|
| 9 |
-
from sympy.matrices.dense import Matrix
|
| 10 |
-
from sympy.combinatorics import RGS_enum, RGS_unrank, Permutation
|
| 11 |
-
from sympy.utilities.iterables import (
|
| 12 |
-
_partition, _set_partitions, binary_partitions, bracelets, capture,
|
| 13 |
-
cartes, common_prefix, common_suffix, connected_components, dict_merge,
|
| 14 |
-
filter_symbols, flatten, generate_bell, generate_derangements,
|
| 15 |
-
generate_involutions, generate_oriented_forest, group, has_dups, ibin,
|
| 16 |
-
iproduct, kbins, minlex, multiset, multiset_combinations,
|
| 17 |
-
multiset_partitions, multiset_permutations, necklaces, numbered_symbols,
|
| 18 |
-
partitions, permutations, postfixes,
|
| 19 |
-
prefixes, reshape, rotate_left, rotate_right, runs, sift,
|
| 20 |
-
strongly_connected_components, subsets, take, topological_sort, unflatten,
|
| 21 |
-
uniq, variations, ordered_partitions, rotations, is_palindromic, iterable,
|
| 22 |
-
NotIterable, multiset_derangements, signed_permutations,
|
| 23 |
-
sequence_partitions, sequence_partitions_empty)
|
| 24 |
-
from sympy.utilities.enumerative import (
|
| 25 |
-
factoring_visitor, multiset_partitions_taocp )
|
| 26 |
-
|
| 27 |
-
from sympy.core.singleton import S
|
| 28 |
-
from sympy.testing.pytest import raises, warns_deprecated_sympy
|
| 29 |
-
|
| 30 |
-
w, x, y, z = symbols('w,x,y,z')
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def test_deprecated_iterables():
|
| 34 |
-
from sympy.utilities.iterables import default_sort_key, ordered
|
| 35 |
-
with warns_deprecated_sympy():
|
| 36 |
-
assert list(ordered([y, x])) == [x, y]
|
| 37 |
-
with warns_deprecated_sympy():
|
| 38 |
-
assert sorted([y, x], key=default_sort_key) == [x, y]
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def test_is_palindromic():
|
| 42 |
-
assert is_palindromic('')
|
| 43 |
-
assert is_palindromic('x')
|
| 44 |
-
assert is_palindromic('xx')
|
| 45 |
-
assert is_palindromic('xyx')
|
| 46 |
-
assert not is_palindromic('xy')
|
| 47 |
-
assert not is_palindromic('xyzx')
|
| 48 |
-
assert is_palindromic('xxyzzyx', 1)
|
| 49 |
-
assert not is_palindromic('xxyzzyx', 2)
|
| 50 |
-
assert is_palindromic('xxyzzyx', 2, -1)
|
| 51 |
-
assert is_palindromic('xxyzzyx', 2, 6)
|
| 52 |
-
assert is_palindromic('xxyzyx', 1)
|
| 53 |
-
assert not is_palindromic('xxyzyx', 2)
|
| 54 |
-
assert is_palindromic('xxyzyx', 2, 2 + 3)
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def test_flatten():
|
| 58 |
-
assert flatten((1, (1,))) == [1, 1]
|
| 59 |
-
assert flatten((x, (x,))) == [x, x]
|
| 60 |
-
|
| 61 |
-
ls = [[(-2, -1), (1, 2)], [(0, 0)]]
|
| 62 |
-
|
| 63 |
-
assert flatten(ls, levels=0) == ls
|
| 64 |
-
assert flatten(ls, levels=1) == [(-2, -1), (1, 2), (0, 0)]
|
| 65 |
-
assert flatten(ls, levels=2) == [-2, -1, 1, 2, 0, 0]
|
| 66 |
-
assert flatten(ls, levels=3) == [-2, -1, 1, 2, 0, 0]
|
| 67 |
-
|
| 68 |
-
raises(ValueError, lambda: flatten(ls, levels=-1))
|
| 69 |
-
|
| 70 |
-
class MyOp(Basic):
|
| 71 |
-
pass
|
| 72 |
-
|
| 73 |
-
assert flatten([MyOp(x, y), z]) == [MyOp(x, y), z]
|
| 74 |
-
assert flatten([MyOp(x, y), z], cls=MyOp) == [x, y, z]
|
| 75 |
-
|
| 76 |
-
assert flatten({1, 11, 2}) == list({1, 11, 2})
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def test_iproduct():
|
| 80 |
-
assert list(iproduct()) == [()]
|
| 81 |
-
assert list(iproduct([])) == []
|
| 82 |
-
assert list(iproduct([1,2,3])) == [(1,),(2,),(3,)]
|
| 83 |
-
assert sorted(iproduct([1, 2], [3, 4, 5])) == [
|
| 84 |
-
(1,3),(1,4),(1,5),(2,3),(2,4),(2,5)]
|
| 85 |
-
assert sorted(iproduct([0,1],[0,1],[0,1])) == [
|
| 86 |
-
(0,0,0),(0,0,1),(0,1,0),(0,1,1),(1,0,0),(1,0,1),(1,1,0),(1,1,1)]
|
| 87 |
-
assert iterable(iproduct(S.Integers)) is True
|
| 88 |
-
assert iterable(iproduct(S.Integers, S.Integers)) is True
|
| 89 |
-
assert (3,) in iproduct(S.Integers)
|
| 90 |
-
assert (4, 5) in iproduct(S.Integers, S.Integers)
|
| 91 |
-
assert (1, 2, 3) in iproduct(S.Integers, S.Integers, S.Integers)
|
| 92 |
-
triples = set(islice(iproduct(S.Integers, S.Integers, S.Integers), 1000))
|
| 93 |
-
for n1, n2, n3 in triples:
|
| 94 |
-
assert isinstance(n1, Integer)
|
| 95 |
-
assert isinstance(n2, Integer)
|
| 96 |
-
assert isinstance(n3, Integer)
|
| 97 |
-
for t in set(product(*([range(-2, 3)]*3))):
|
| 98 |
-
assert t in iproduct(S.Integers, S.Integers, S.Integers)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def test_group():
|
| 102 |
-
assert group([]) == []
|
| 103 |
-
assert group([], multiple=False) == []
|
| 104 |
-
|
| 105 |
-
assert group([1]) == [[1]]
|
| 106 |
-
assert group([1], multiple=False) == [(1, 1)]
|
| 107 |
-
|
| 108 |
-
assert group([1, 1]) == [[1, 1]]
|
| 109 |
-
assert group([1, 1], multiple=False) == [(1, 2)]
|
| 110 |
-
|
| 111 |
-
assert group([1, 1, 1]) == [[1, 1, 1]]
|
| 112 |
-
assert group([1, 1, 1], multiple=False) == [(1, 3)]
|
| 113 |
-
|
| 114 |
-
assert group([1, 2, 1]) == [[1], [2], [1]]
|
| 115 |
-
assert group([1, 2, 1], multiple=False) == [(1, 1), (2, 1), (1, 1)]
|
| 116 |
-
|
| 117 |
-
assert group([1, 1, 2, 2, 2, 1, 3, 3]) == [[1, 1], [2, 2, 2], [1], [3, 3]]
|
| 118 |
-
assert group([1, 1, 2, 2, 2, 1, 3, 3], multiple=False) == [(1, 2),
|
| 119 |
-
(2, 3), (1, 1), (3, 2)]
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
def test_subsets():
|
| 123 |
-
# combinations
|
| 124 |
-
assert list(subsets([1, 2, 3], 0)) == [()]
|
| 125 |
-
assert list(subsets([1, 2, 3], 1)) == [(1,), (2,), (3,)]
|
| 126 |
-
assert list(subsets([1, 2, 3], 2)) == [(1, 2), (1, 3), (2, 3)]
|
| 127 |
-
assert list(subsets([1, 2, 3], 3)) == [(1, 2, 3)]
|
| 128 |
-
l = list(range(4))
|
| 129 |
-
assert list(subsets(l, 0, repetition=True)) == [()]
|
| 130 |
-
assert list(subsets(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
|
| 131 |
-
assert list(subsets(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
|
| 132 |
-
(0, 3), (1, 1), (1, 2),
|
| 133 |
-
(1, 3), (2, 2), (2, 3),
|
| 134 |
-
(3, 3)]
|
| 135 |
-
assert list(subsets(l, 3, repetition=True)) == [(0, 0, 0), (0, 0, 1),
|
| 136 |
-
(0, 0, 2), (0, 0, 3),
|
| 137 |
-
(0, 1, 1), (0, 1, 2),
|
| 138 |
-
(0, 1, 3), (0, 2, 2),
|
| 139 |
-
(0, 2, 3), (0, 3, 3),
|
| 140 |
-
(1, 1, 1), (1, 1, 2),
|
| 141 |
-
(1, 1, 3), (1, 2, 2),
|
| 142 |
-
(1, 2, 3), (1, 3, 3),
|
| 143 |
-
(2, 2, 2), (2, 2, 3),
|
| 144 |
-
(2, 3, 3), (3, 3, 3)]
|
| 145 |
-
assert len(list(subsets(l, 4, repetition=True))) == 35
|
| 146 |
-
|
| 147 |
-
assert list(subsets(l[:2], 3, repetition=False)) == []
|
| 148 |
-
assert list(subsets(l[:2], 3, repetition=True)) == [(0, 0, 0),
|
| 149 |
-
(0, 0, 1),
|
| 150 |
-
(0, 1, 1),
|
| 151 |
-
(1, 1, 1)]
|
| 152 |
-
assert list(subsets([1, 2], repetition=True)) == \
|
| 153 |
-
[(), (1,), (2,), (1, 1), (1, 2), (2, 2)]
|
| 154 |
-
assert list(subsets([1, 2], repetition=False)) == \
|
| 155 |
-
[(), (1,), (2,), (1, 2)]
|
| 156 |
-
assert list(subsets([1, 2, 3], 2)) == \
|
| 157 |
-
[(1, 2), (1, 3), (2, 3)]
|
| 158 |
-
assert list(subsets([1, 2, 3], 2, repetition=True)) == \
|
| 159 |
-
[(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
def test_variations():
|
| 163 |
-
# permutations
|
| 164 |
-
l = list(range(4))
|
| 165 |
-
assert list(variations(l, 0, repetition=False)) == [()]
|
| 166 |
-
assert list(variations(l, 1, repetition=False)) == [(0,), (1,), (2,), (3,)]
|
| 167 |
-
assert list(variations(l, 2, repetition=False)) == [(0, 1), (0, 2), (0, 3), (1, 0), (1, 2), (1, 3), (2, 0), (2, 1), (2, 3), (3, 0), (3, 1), (3, 2)]
|
| 168 |
-
assert list(variations(l, 3, repetition=False)) == [(0, 1, 2), (0, 1, 3), (0, 2, 1), (0, 2, 3), (0, 3, 1), (0, 3, 2), (1, 0, 2), (1, 0, 3), (1, 2, 0), (1, 2, 3), (1, 3, 0), (1, 3, 2), (2, 0, 1), (2, 0, 3), (2, 1, 0), (2, 1, 3), (2, 3, 0), (2, 3, 1), (3, 0, 1), (3, 0, 2), (3, 1, 0), (3, 1, 2), (3, 2, 0), (3, 2, 1)]
|
| 169 |
-
assert list(variations(l, 0, repetition=True)) == [()]
|
| 170 |
-
assert list(variations(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
|
| 171 |
-
assert list(variations(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
|
| 172 |
-
(0, 3), (1, 0), (1, 1),
|
| 173 |
-
(1, 2), (1, 3), (2, 0),
|
| 174 |
-
(2, 1), (2, 2), (2, 3),
|
| 175 |
-
(3, 0), (3, 1), (3, 2),
|
| 176 |
-
(3, 3)]
|
| 177 |
-
assert len(list(variations(l, 3, repetition=True))) == 64
|
| 178 |
-
assert len(list(variations(l, 4, repetition=True))) == 256
|
| 179 |
-
assert list(variations(l[:2], 3, repetition=False)) == []
|
| 180 |
-
assert list(variations(l[:2], 3, repetition=True)) == [
|
| 181 |
-
(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
|
| 182 |
-
(1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)
|
| 183 |
-
]
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
def test_cartes():
|
| 187 |
-
assert list(cartes([1, 2], [3, 4, 5])) == \
|
| 188 |
-
[(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5)]
|
| 189 |
-
assert list(cartes()) == [()]
|
| 190 |
-
assert list(cartes('a')) == [('a',)]
|
| 191 |
-
assert list(cartes('a', repeat=2)) == [('a', 'a')]
|
| 192 |
-
assert list(cartes(list(range(2)))) == [(0,), (1,)]
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
def test_filter_symbols():
|
| 196 |
-
s = numbered_symbols()
|
| 197 |
-
filtered = filter_symbols(s, symbols("x0 x2 x3"))
|
| 198 |
-
assert take(filtered, 3) == list(symbols("x1 x4 x5"))
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def test_numbered_symbols():
|
| 202 |
-
s = numbered_symbols(cls=Dummy)
|
| 203 |
-
assert isinstance(next(s), Dummy)
|
| 204 |
-
assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \
|
| 205 |
-
symbols('C2')
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
def test_sift():
|
| 209 |
-
assert sift(list(range(5)), lambda _: _ % 2) == {1: [1, 3], 0: [0, 2, 4]}
|
| 210 |
-
assert sift([x, y], lambda _: _.has(x)) == {False: [y], True: [x]}
|
| 211 |
-
assert sift([S.One], lambda _: _.has(x)) == {False: [1]}
|
| 212 |
-
assert sift([0, 1, 2, 3], lambda x: x % 2, binary=True) == (
|
| 213 |
-
[1, 3], [0, 2])
|
| 214 |
-
assert sift([0, 1, 2, 3], lambda x: x % 3 == 1, binary=True) == (
|
| 215 |
-
[1], [0, 2, 3])
|
| 216 |
-
raises(ValueError, lambda:
|
| 217 |
-
sift([0, 1, 2, 3], lambda x: x % 3, binary=True))
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
def test_take():
|
| 221 |
-
X = numbered_symbols()
|
| 222 |
-
|
| 223 |
-
assert take(X, 5) == list(symbols('x0:5'))
|
| 224 |
-
assert take(X, 5) == list(symbols('x5:10'))
|
| 225 |
-
|
| 226 |
-
assert take([1, 2, 3, 4, 5], 5) == [1, 2, 3, 4, 5]
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
def test_dict_merge():
|
| 230 |
-
assert dict_merge({}, {1: x, y: z}) == {1: x, y: z}
|
| 231 |
-
assert dict_merge({1: x, y: z}, {}) == {1: x, y: z}
|
| 232 |
-
|
| 233 |
-
assert dict_merge({2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
|
| 234 |
-
assert dict_merge({1: x, y: z}, {2: z}) == {1: x, 2: z, y: z}
|
| 235 |
-
|
| 236 |
-
assert dict_merge({1: y, 2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
|
| 237 |
-
assert dict_merge({1: x, y: z}, {1: y, 2: z}) == {1: y, 2: z, y: z}
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
def test_prefixes():
|
| 241 |
-
assert list(prefixes([])) == []
|
| 242 |
-
assert list(prefixes([1])) == [[1]]
|
| 243 |
-
assert list(prefixes([1, 2])) == [[1], [1, 2]]
|
| 244 |
-
|
| 245 |
-
assert list(prefixes([1, 2, 3, 4, 5])) == \
|
| 246 |
-
[[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
def test_postfixes():
|
| 250 |
-
assert list(postfixes([])) == []
|
| 251 |
-
assert list(postfixes([1])) == [[1]]
|
| 252 |
-
assert list(postfixes([1, 2])) == [[2], [1, 2]]
|
| 253 |
-
|
| 254 |
-
assert list(postfixes([1, 2, 3, 4, 5])) == \
|
| 255 |
-
[[5], [4, 5], [3, 4, 5], [2, 3, 4, 5], [1, 2, 3, 4, 5]]
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
def test_topological_sort():
|
| 259 |
-
V = [2, 3, 5, 7, 8, 9, 10, 11]
|
| 260 |
-
E = [(7, 11), (7, 8), (5, 11),
|
| 261 |
-
(3, 8), (3, 10), (11, 2),
|
| 262 |
-
(11, 9), (11, 10), (8, 9)]
|
| 263 |
-
|
| 264 |
-
assert topological_sort((V, E)) == [3, 5, 7, 8, 11, 2, 9, 10]
|
| 265 |
-
assert topological_sort((V, E), key=lambda v: -v) == \
|
| 266 |
-
[7, 5, 11, 3, 10, 8, 9, 2]
|
| 267 |
-
|
| 268 |
-
raises(ValueError, lambda: topological_sort((V, E + [(10, 7)])))
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
def test_strongly_connected_components():
|
| 272 |
-
assert strongly_connected_components(([], [])) == []
|
| 273 |
-
assert strongly_connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
|
| 274 |
-
|
| 275 |
-
V = [1, 2, 3]
|
| 276 |
-
E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
|
| 277 |
-
assert strongly_connected_components((V, E)) == [[1, 2, 3]]
|
| 278 |
-
|
| 279 |
-
V = [1, 2, 3, 4]
|
| 280 |
-
E = [(1, 2), (2, 3), (3, 2), (3, 4)]
|
| 281 |
-
assert strongly_connected_components((V, E)) == [[4], [2, 3], [1]]
|
| 282 |
-
|
| 283 |
-
V = [1, 2, 3, 4]
|
| 284 |
-
E = [(1, 2), (2, 1), (3, 4), (4, 3)]
|
| 285 |
-
assert strongly_connected_components((V, E)) == [[1, 2], [3, 4]]
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
def test_connected_components():
|
| 289 |
-
assert connected_components(([], [])) == []
|
| 290 |
-
assert connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
|
| 291 |
-
|
| 292 |
-
V = [1, 2, 3]
|
| 293 |
-
E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
|
| 294 |
-
assert connected_components((V, E)) == [[1, 2, 3]]
|
| 295 |
-
|
| 296 |
-
V = [1, 2, 3, 4]
|
| 297 |
-
E = [(1, 2), (2, 3), (3, 2), (3, 4)]
|
| 298 |
-
assert connected_components((V, E)) == [[1, 2, 3, 4]]
|
| 299 |
-
|
| 300 |
-
V = [1, 2, 3, 4]
|
| 301 |
-
E = [(1, 2), (3, 4)]
|
| 302 |
-
assert connected_components((V, E)) == [[1, 2], [3, 4]]
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
def test_rotate():
|
| 306 |
-
A = [0, 1, 2, 3, 4]
|
| 307 |
-
|
| 308 |
-
assert rotate_left(A, 2) == [2, 3, 4, 0, 1]
|
| 309 |
-
assert rotate_right(A, 1) == [4, 0, 1, 2, 3]
|
| 310 |
-
A = []
|
| 311 |
-
B = rotate_right(A, 1)
|
| 312 |
-
assert B == []
|
| 313 |
-
B.append(1)
|
| 314 |
-
assert A == []
|
| 315 |
-
B = rotate_left(A, 1)
|
| 316 |
-
assert B == []
|
| 317 |
-
B.append(1)
|
| 318 |
-
assert A == []
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
def test_multiset_partitions():
|
| 322 |
-
A = [0, 1, 2, 3, 4]
|
| 323 |
-
|
| 324 |
-
assert list(multiset_partitions(A, 5)) == [[[0], [1], [2], [3], [4]]]
|
| 325 |
-
assert len(list(multiset_partitions(A, 4))) == 10
|
| 326 |
-
assert len(list(multiset_partitions(A, 3))) == 25
|
| 327 |
-
|
| 328 |
-
assert list(multiset_partitions([1, 1, 1, 2, 2], 2)) == [
|
| 329 |
-
[[1, 1, 1, 2], [2]], [[1, 1, 1], [2, 2]], [[1, 1, 2, 2], [1]],
|
| 330 |
-
[[1, 1, 2], [1, 2]], [[1, 1], [1, 2, 2]]]
|
| 331 |
-
|
| 332 |
-
assert list(multiset_partitions([1, 1, 2, 2], 2)) == [
|
| 333 |
-
[[1, 1, 2], [2]], [[1, 1], [2, 2]], [[1, 2, 2], [1]],
|
| 334 |
-
[[1, 2], [1, 2]]]
|
| 335 |
-
|
| 336 |
-
assert list(multiset_partitions([1, 2, 3, 4], 2)) == [
|
| 337 |
-
[[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]],
|
| 338 |
-
[[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]],
|
| 339 |
-
[[1], [2, 3, 4]]]
|
| 340 |
-
|
| 341 |
-
assert list(multiset_partitions([1, 2, 2], 2)) == [
|
| 342 |
-
[[1, 2], [2]], [[1], [2, 2]]]
|
| 343 |
-
|
| 344 |
-
assert list(multiset_partitions(3)) == [
|
| 345 |
-
[[0, 1, 2]], [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]],
|
| 346 |
-
[[0], [1], [2]]]
|
| 347 |
-
assert list(multiset_partitions(3, 2)) == [
|
| 348 |
-
[[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]]]
|
| 349 |
-
assert list(multiset_partitions([1] * 3, 2)) == [[[1], [1, 1]]]
|
| 350 |
-
assert list(multiset_partitions([1] * 3)) == [
|
| 351 |
-
[[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]]
|
| 352 |
-
a = [3, 2, 1]
|
| 353 |
-
assert list(multiset_partitions(a)) == \
|
| 354 |
-
list(multiset_partitions(sorted(a)))
|
| 355 |
-
assert list(multiset_partitions(a, 5)) == []
|
| 356 |
-
assert list(multiset_partitions(a, 1)) == [[[1, 2, 3]]]
|
| 357 |
-
assert list(multiset_partitions(a + [4], 5)) == []
|
| 358 |
-
assert list(multiset_partitions(a + [4], 1)) == [[[1, 2, 3, 4]]]
|
| 359 |
-
assert list(multiset_partitions(2, 5)) == []
|
| 360 |
-
assert list(multiset_partitions(2, 1)) == [[[0, 1]]]
|
| 361 |
-
assert list(multiset_partitions('a')) == [[['a']]]
|
| 362 |
-
assert list(multiset_partitions('a', 2)) == []
|
| 363 |
-
assert list(multiset_partitions('ab')) == [[['a', 'b']], [['a'], ['b']]]
|
| 364 |
-
assert list(multiset_partitions('ab', 1)) == [[['a', 'b']]]
|
| 365 |
-
assert list(multiset_partitions('aaa', 1)) == [['aaa']]
|
| 366 |
-
assert list(multiset_partitions([1, 1], 1)) == [[[1, 1]]]
|
| 367 |
-
ans = [('mpsyy',), ('mpsy', 'y'), ('mps', 'yy'), ('mps', 'y', 'y'),
|
| 368 |
-
('mpyy', 's'), ('mpy', 'sy'), ('mpy', 's', 'y'), ('mp', 'syy'),
|
| 369 |
-
('mp', 'sy', 'y'), ('mp', 's', 'yy'), ('mp', 's', 'y', 'y'),
|
| 370 |
-
('msyy', 'p'), ('msy', 'py'), ('msy', 'p', 'y'), ('ms', 'pyy'),
|
| 371 |
-
('ms', 'py', 'y'), ('ms', 'p', 'yy'), ('ms', 'p', 'y', 'y'),
|
| 372 |
-
('myy', 'ps'), ('myy', 'p', 's'), ('my', 'psy'), ('my', 'ps', 'y'),
|
| 373 |
-
('my', 'py', 's'), ('my', 'p', 'sy'), ('my', 'p', 's', 'y'),
|
| 374 |
-
('m', 'psyy'), ('m', 'psy', 'y'), ('m', 'ps', 'yy'),
|
| 375 |
-
('m', 'ps', 'y', 'y'), ('m', 'pyy', 's'), ('m', 'py', 'sy'),
|
| 376 |
-
('m', 'py', 's', 'y'), ('m', 'p', 'syy'),
|
| 377 |
-
('m', 'p', 'sy', 'y'), ('m', 'p', 's', 'yy'),
|
| 378 |
-
('m', 'p', 's', 'y', 'y')]
|
| 379 |
-
assert [tuple("".join(part) for part in p)
|
| 380 |
-
for p in multiset_partitions('sympy')] == ans
|
| 381 |
-
factorings = [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3],
|
| 382 |
-
[6, 2, 2], [2, 2, 2, 3]]
|
| 383 |
-
assert [factoring_visitor(p, [2,3]) for
|
| 384 |
-
p in multiset_partitions_taocp([3, 1])] == factorings
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
def test_multiset_combinations():
|
| 388 |
-
ans = ['iii', 'iim', 'iip', 'iis', 'imp', 'ims', 'ipp', 'ips',
|
| 389 |
-
'iss', 'mpp', 'mps', 'mss', 'pps', 'pss', 'sss']
|
| 390 |
-
assert [''.join(i) for i in
|
| 391 |
-
list(multiset_combinations('mississippi', 3))] == ans
|
| 392 |
-
M = multiset('mississippi')
|
| 393 |
-
assert [''.join(i) for i in
|
| 394 |
-
list(multiset_combinations(M, 3))] == ans
|
| 395 |
-
assert [''.join(i) for i in multiset_combinations(M, 30)] == []
|
| 396 |
-
assert list(multiset_combinations([[1], [2, 3]], 2)) == [[[1], [2, 3]]]
|
| 397 |
-
assert len(list(multiset_combinations('a', 3))) == 0
|
| 398 |
-
assert len(list(multiset_combinations('a', 0))) == 1
|
| 399 |
-
assert list(multiset_combinations('abc', 1)) == [['a'], ['b'], ['c']]
|
| 400 |
-
raises(ValueError, lambda: list(multiset_combinations({0: 3, 1: -1}, 2)))
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
def test_multiset_permutations():
|
| 404 |
-
ans = ['abby', 'abyb', 'aybb', 'baby', 'bayb', 'bbay', 'bbya', 'byab',
|
| 405 |
-
'byba', 'yabb', 'ybab', 'ybba']
|
| 406 |
-
assert [''.join(i) for i in multiset_permutations('baby')] == ans
|
| 407 |
-
assert [''.join(i) for i in multiset_permutations(multiset('baby'))] == ans
|
| 408 |
-
assert list(multiset_permutations([0, 0, 0], 2)) == [[0, 0]]
|
| 409 |
-
assert list(multiset_permutations([0, 2, 1], 2)) == [
|
| 410 |
-
[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]
|
| 411 |
-
assert len(list(multiset_permutations('a', 0))) == 1
|
| 412 |
-
assert len(list(multiset_permutations('a', 3))) == 0
|
| 413 |
-
for nul in ([], {}, ''):
|
| 414 |
-
assert list(multiset_permutations(nul)) == [[]]
|
| 415 |
-
assert list(multiset_permutations(nul, 0)) == [[]]
|
| 416 |
-
# impossible requests give no result
|
| 417 |
-
assert list(multiset_permutations(nul, 1)) == []
|
| 418 |
-
assert list(multiset_permutations(nul, -1)) == []
|
| 419 |
-
|
| 420 |
-
def test():
|
| 421 |
-
for i in range(1, 7):
|
| 422 |
-
print(i)
|
| 423 |
-
for p in multiset_permutations([0, 0, 1, 0, 1], i):
|
| 424 |
-
print(p)
|
| 425 |
-
assert capture(lambda: test()) == dedent('''\
|
| 426 |
-
1
|
| 427 |
-
[0]
|
| 428 |
-
[1]
|
| 429 |
-
2
|
| 430 |
-
[0, 0]
|
| 431 |
-
[0, 1]
|
| 432 |
-
[1, 0]
|
| 433 |
-
[1, 1]
|
| 434 |
-
3
|
| 435 |
-
[0, 0, 0]
|
| 436 |
-
[0, 0, 1]
|
| 437 |
-
[0, 1, 0]
|
| 438 |
-
[0, 1, 1]
|
| 439 |
-
[1, 0, 0]
|
| 440 |
-
[1, 0, 1]
|
| 441 |
-
[1, 1, 0]
|
| 442 |
-
4
|
| 443 |
-
[0, 0, 0, 1]
|
| 444 |
-
[0, 0, 1, 0]
|
| 445 |
-
[0, 0, 1, 1]
|
| 446 |
-
[0, 1, 0, 0]
|
| 447 |
-
[0, 1, 0, 1]
|
| 448 |
-
[0, 1, 1, 0]
|
| 449 |
-
[1, 0, 0, 0]
|
| 450 |
-
[1, 0, 0, 1]
|
| 451 |
-
[1, 0, 1, 0]
|
| 452 |
-
[1, 1, 0, 0]
|
| 453 |
-
5
|
| 454 |
-
[0, 0, 0, 1, 1]
|
| 455 |
-
[0, 0, 1, 0, 1]
|
| 456 |
-
[0, 0, 1, 1, 0]
|
| 457 |
-
[0, 1, 0, 0, 1]
|
| 458 |
-
[0, 1, 0, 1, 0]
|
| 459 |
-
[0, 1, 1, 0, 0]
|
| 460 |
-
[1, 0, 0, 0, 1]
|
| 461 |
-
[1, 0, 0, 1, 0]
|
| 462 |
-
[1, 0, 1, 0, 0]
|
| 463 |
-
[1, 1, 0, 0, 0]
|
| 464 |
-
6\n''')
|
| 465 |
-
raises(ValueError, lambda: list(multiset_permutations({0: 3, 1: -1})))
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
def test_partitions():
|
| 469 |
-
ans = [[{}], [(0, {})]]
|
| 470 |
-
for i in range(2):
|
| 471 |
-
assert list(partitions(0, size=i)) == ans[i]
|
| 472 |
-
assert list(partitions(1, 0, size=i)) == ans[i]
|
| 473 |
-
assert list(partitions(6, 2, 2, size=i)) == ans[i]
|
| 474 |
-
assert list(partitions(6, 2, None, size=i)) != ans[i]
|
| 475 |
-
assert list(partitions(6, None, 2, size=i)) != ans[i]
|
| 476 |
-
assert list(partitions(6, 2, 0, size=i)) == ans[i]
|
| 477 |
-
|
| 478 |
-
assert list(partitions(6, k=2)) == [
|
| 479 |
-
{2: 3}, {1: 2, 2: 2}, {1: 4, 2: 1}, {1: 6}]
|
| 480 |
-
|
| 481 |
-
assert list(partitions(6, k=3)) == [
|
| 482 |
-
{3: 2}, {1: 1, 2: 1, 3: 1}, {1: 3, 3: 1}, {2: 3}, {1: 2, 2: 2},
|
| 483 |
-
{1: 4, 2: 1}, {1: 6}]
|
| 484 |
-
|
| 485 |
-
assert list(partitions(8, k=4, m=3)) == [
|
| 486 |
-
{4: 2}, {1: 1, 3: 1, 4: 1}, {2: 2, 4: 1}, {2: 1, 3: 2}] == [
|
| 487 |
-
i for i in partitions(8, k=4, m=3) if all(k <= 4 for k in i)
|
| 488 |
-
and sum(i.values()) <=3]
|
| 489 |
-
|
| 490 |
-
assert list(partitions(S(3), m=2)) == [
|
| 491 |
-
{3: 1}, {1: 1, 2: 1}]
|
| 492 |
-
|
| 493 |
-
assert list(partitions(4, k=3)) == [
|
| 494 |
-
{1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}] == [
|
| 495 |
-
i for i in partitions(4) if all(k <= 3 for k in i)]
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
# Consistency check on output of _partitions and RGS_unrank.
|
| 499 |
-
# This provides a sanity test on both routines. Also verifies that
|
| 500 |
-
# the total number of partitions is the same in each case.
|
| 501 |
-
# (from pkrathmann2)
|
| 502 |
-
|
| 503 |
-
for n in range(2, 6):
|
| 504 |
-
i = 0
|
| 505 |
-
for m, q in _set_partitions(n):
|
| 506 |
-
assert q == RGS_unrank(i, n)
|
| 507 |
-
i += 1
|
| 508 |
-
assert i == RGS_enum(n)
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
def test_binary_partitions():
|
| 512 |
-
assert [i[:] for i in binary_partitions(10)] == [[8, 2], [8, 1, 1],
|
| 513 |
-
[4, 4, 2], [4, 4, 1, 1], [4, 2, 2, 2], [4, 2, 2, 1, 1],
|
| 514 |
-
[4, 2, 1, 1, 1, 1], [4, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2],
|
| 515 |
-
[2, 2, 2, 2, 1, 1], [2, 2, 2, 1, 1, 1, 1], [2, 2, 1, 1, 1, 1, 1, 1],
|
| 516 |
-
[2, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
| 517 |
-
|
| 518 |
-
assert len([j[:] for j in binary_partitions(16)]) == 36
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
def test_bell_perm():
|
| 522 |
-
assert [len(set(generate_bell(i))) for i in range(1, 7)] == [
|
| 523 |
-
factorial(i) for i in range(1, 7)]
|
| 524 |
-
assert list(generate_bell(3)) == [
|
| 525 |
-
(0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)]
|
| 526 |
-
# generate_bell and trotterjohnson are advertised to return the same
|
| 527 |
-
# permutations; this is not technically necessary so this test could
|
| 528 |
-
# be removed
|
| 529 |
-
for n in range(1, 5):
|
| 530 |
-
p = Permutation(range(n))
|
| 531 |
-
b = generate_bell(n)
|
| 532 |
-
for bi in b:
|
| 533 |
-
assert bi == tuple(p.array_form)
|
| 534 |
-
p = p.next_trotterjohnson()
|
| 535 |
-
raises(ValueError, lambda: list(generate_bell(0))) # XXX is this consistent with other permutation algorithms?
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
def test_involutions():
|
| 539 |
-
lengths = [1, 2, 4, 10, 26, 76]
|
| 540 |
-
for n, N in enumerate(lengths):
|
| 541 |
-
i = list(generate_involutions(n + 1))
|
| 542 |
-
assert len(i) == N
|
| 543 |
-
assert len({Permutation(j)**2 for j in i}) == 1
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
def test_derangements():
|
| 547 |
-
assert len(list(generate_derangements(list(range(6))))) == 265
|
| 548 |
-
assert ''.join(''.join(i) for i in generate_derangements('abcde')) == (
|
| 549 |
-
'badecbaecdbcaedbcdeabceadbdaecbdeacbdecabeacdbedacbedcacabedcadebcaebd'
|
| 550 |
-
'cdaebcdbeacdeabcdebaceabdcebadcedabcedbadabecdaebcdaecbdcaebdcbeadceab'
|
| 551 |
-
'dcebadeabcdeacbdebacdebcaeabcdeadbceadcbecabdecbadecdabecdbaedabcedacb'
|
| 552 |
-
'edbacedbca')
|
| 553 |
-
assert list(generate_derangements([0, 1, 2, 3])) == [
|
| 554 |
-
[1, 0, 3, 2], [1, 2, 3, 0], [1, 3, 0, 2], [2, 0, 3, 1],
|
| 555 |
-
[2, 3, 0, 1], [2, 3, 1, 0], [3, 0, 1, 2], [3, 2, 0, 1], [3, 2, 1, 0]]
|
| 556 |
-
assert list(generate_derangements([0, 1, 2, 2])) == [
|
| 557 |
-
[2, 2, 0, 1], [2, 2, 1, 0]]
|
| 558 |
-
assert list(generate_derangements('ba')) == [list('ab')]
|
| 559 |
-
# multiset_derangements
|
| 560 |
-
D = multiset_derangements
|
| 561 |
-
assert list(D('abb')) == []
|
| 562 |
-
assert [''.join(i) for i in D('ab')] == ['ba']
|
| 563 |
-
assert [''.join(i) for i in D('abc')] == ['bca', 'cab']
|
| 564 |
-
assert [''.join(i) for i in D('aabb')] == ['bbaa']
|
| 565 |
-
assert [''.join(i) for i in D('aabbcccc')] == [
|
| 566 |
-
'ccccaabb', 'ccccabab', 'ccccabba', 'ccccbaab', 'ccccbaba',
|
| 567 |
-
'ccccbbaa']
|
| 568 |
-
assert [''.join(i) for i in D('aabbccc')] == [
|
| 569 |
-
'cccabba', 'cccabab', 'cccaabb', 'ccacbba', 'ccacbab',
|
| 570 |
-
'ccacabb', 'cbccbaa', 'cbccaba', 'cbccaab', 'bcccbaa',
|
| 571 |
-
'bcccaba', 'bcccaab']
|
| 572 |
-
assert [''.join(i) for i in D('books')] == ['kbsoo', 'ksboo',
|
| 573 |
-
'sbkoo', 'skboo', 'oksbo', 'oskbo', 'okbso', 'obkso', 'oskob',
|
| 574 |
-
'oksob', 'osbok', 'obsok']
|
| 575 |
-
assert list(generate_derangements([[3], [2], [2], [1]])) == [
|
| 576 |
-
[[2], [1], [3], [2]], [[2], [3], [1], [2]]]
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
def test_necklaces():
|
| 580 |
-
def count(n, k, f):
|
| 581 |
-
return len(list(necklaces(n, k, f)))
|
| 582 |
-
m = []
|
| 583 |
-
for i in range(1, 8):
|
| 584 |
-
m.append((
|
| 585 |
-
i, count(i, 2, 0), count(i, 2, 1), count(i, 3, 1)))
|
| 586 |
-
assert Matrix(m) == Matrix([
|
| 587 |
-
[1, 2, 2, 3],
|
| 588 |
-
[2, 3, 3, 6],
|
| 589 |
-
[3, 4, 4, 10],
|
| 590 |
-
[4, 6, 6, 21],
|
| 591 |
-
[5, 8, 8, 39],
|
| 592 |
-
[6, 14, 13, 92],
|
| 593 |
-
[7, 20, 18, 198]])
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
def test_bracelets():
|
| 597 |
-
bc = list(bracelets(2, 4))
|
| 598 |
-
assert Matrix(bc) == Matrix([
|
| 599 |
-
[0, 0],
|
| 600 |
-
[0, 1],
|
| 601 |
-
[0, 2],
|
| 602 |
-
[0, 3],
|
| 603 |
-
[1, 1],
|
| 604 |
-
[1, 2],
|
| 605 |
-
[1, 3],
|
| 606 |
-
[2, 2],
|
| 607 |
-
[2, 3],
|
| 608 |
-
[3, 3]
|
| 609 |
-
])
|
| 610 |
-
bc = list(bracelets(4, 2))
|
| 611 |
-
assert Matrix(bc) == Matrix([
|
| 612 |
-
[0, 0, 0, 0],
|
| 613 |
-
[0, 0, 0, 1],
|
| 614 |
-
[0, 0, 1, 1],
|
| 615 |
-
[0, 1, 0, 1],
|
| 616 |
-
[0, 1, 1, 1],
|
| 617 |
-
[1, 1, 1, 1]
|
| 618 |
-
])
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
def test_generate_oriented_forest():
|
| 622 |
-
assert list(generate_oriented_forest(5)) == [[0, 1, 2, 3, 4],
|
| 623 |
-
[0, 1, 2, 3, 3], [0, 1, 2, 3, 2], [0, 1, 2, 3, 1], [0, 1, 2, 3, 0],
|
| 624 |
-
[0, 1, 2, 2, 2], [0, 1, 2, 2, 1], [0, 1, 2, 2, 0], [0, 1, 2, 1, 2],
|
| 625 |
-
[0, 1, 2, 1, 1], [0, 1, 2, 1, 0], [0, 1, 2, 0, 1], [0, 1, 2, 0, 0],
|
| 626 |
-
[0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 1], [0, 1, 1, 0, 0],
|
| 627 |
-
[0, 1, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]]
|
| 628 |
-
assert len(list(generate_oriented_forest(10))) == 1842
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
def test_unflatten():
|
| 632 |
-
r = list(range(10))
|
| 633 |
-
assert unflatten(r) == list(zip(r[::2], r[1::2]))
|
| 634 |
-
assert unflatten(r, 5) == [tuple(r[:5]), tuple(r[5:])]
|
| 635 |
-
raises(ValueError, lambda: unflatten(list(range(10)), 3))
|
| 636 |
-
raises(ValueError, lambda: unflatten(list(range(10)), -2))
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
def test_common_prefix_suffix():
|
| 640 |
-
assert common_prefix([], [1]) == []
|
| 641 |
-
assert common_prefix(list(range(3))) == [0, 1, 2]
|
| 642 |
-
assert common_prefix(list(range(3)), list(range(4))) == [0, 1, 2]
|
| 643 |
-
assert common_prefix([1, 2, 3], [1, 2, 5]) == [1, 2]
|
| 644 |
-
assert common_prefix([1, 2, 3], [1, 3, 5]) == [1]
|
| 645 |
-
|
| 646 |
-
assert common_suffix([], [1]) == []
|
| 647 |
-
assert common_suffix(list(range(3))) == [0, 1, 2]
|
| 648 |
-
assert common_suffix(list(range(3)), list(range(3))) == [0, 1, 2]
|
| 649 |
-
assert common_suffix(list(range(3)), list(range(4))) == []
|
| 650 |
-
assert common_suffix([1, 2, 3], [9, 2, 3]) == [2, 3]
|
| 651 |
-
assert common_suffix([1, 2, 3], [9, 7, 3]) == [3]
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
def test_minlex():
|
| 655 |
-
assert minlex([1, 2, 0]) == (0, 1, 2)
|
| 656 |
-
assert minlex((1, 2, 0)) == (0, 1, 2)
|
| 657 |
-
assert minlex((1, 0, 2)) == (0, 2, 1)
|
| 658 |
-
assert minlex((1, 0, 2), directed=False) == (0, 1, 2)
|
| 659 |
-
assert minlex('aba') == 'aab'
|
| 660 |
-
assert minlex(('bb', 'aaa', 'c', 'a'), key=len) == ('c', 'a', 'bb', 'aaa')
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
def test_ordered():
|
| 664 |
-
assert list(ordered((x, y), hash, default=False)) in [[x, y], [y, x]]
|
| 665 |
-
assert list(ordered((x, y), hash, default=False)) == \
|
| 666 |
-
list(ordered((y, x), hash, default=False))
|
| 667 |
-
assert list(ordered((x, y))) == [x, y]
|
| 668 |
-
|
| 669 |
-
seq, keys = [[[1, 2, 1], [0, 3, 1], [1, 1, 3], [2], [1]],
|
| 670 |
-
(lambda x: len(x), lambda x: sum(x))]
|
| 671 |
-
assert list(ordered(seq, keys, default=False, warn=False)) == \
|
| 672 |
-
[[1], [2], [1, 2, 1], [0, 3, 1], [1, 1, 3]]
|
| 673 |
-
raises(ValueError, lambda:
|
| 674 |
-
list(ordered(seq, keys, default=False, warn=True)))
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
def test_runs():
|
| 678 |
-
assert runs([]) == []
|
| 679 |
-
assert runs([1]) == [[1]]
|
| 680 |
-
assert runs([1, 1]) == [[1], [1]]
|
| 681 |
-
assert runs([1, 1, 2]) == [[1], [1, 2]]
|
| 682 |
-
assert runs([1, 2, 1]) == [[1, 2], [1]]
|
| 683 |
-
assert runs([2, 1, 1]) == [[2], [1], [1]]
|
| 684 |
-
from operator import lt
|
| 685 |
-
assert runs([2, 1, 1], lt) == [[2, 1], [1]]
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
def test_reshape():
|
| 689 |
-
seq = list(range(1, 9))
|
| 690 |
-
assert reshape(seq, [4]) == \
|
| 691 |
-
[[1, 2, 3, 4], [5, 6, 7, 8]]
|
| 692 |
-
assert reshape(seq, (4,)) == \
|
| 693 |
-
[(1, 2, 3, 4), (5, 6, 7, 8)]
|
| 694 |
-
assert reshape(seq, (2, 2)) == \
|
| 695 |
-
[(1, 2, 3, 4), (5, 6, 7, 8)]
|
| 696 |
-
assert reshape(seq, (2, [2])) == \
|
| 697 |
-
[(1, 2, [3, 4]), (5, 6, [7, 8])]
|
| 698 |
-
assert reshape(seq, ((2,), [2])) == \
|
| 699 |
-
[((1, 2), [3, 4]), ((5, 6), [7, 8])]
|
| 700 |
-
assert reshape(seq, (1, [2], 1)) == \
|
| 701 |
-
[(1, [2, 3], 4), (5, [6, 7], 8)]
|
| 702 |
-
assert reshape(tuple(seq), ([[1], 1, (2,)],)) == \
|
| 703 |
-
(([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],))
|
| 704 |
-
assert reshape(tuple(seq), ([1], 1, (2,))) == \
|
| 705 |
-
(([1], 2, (3, 4)), ([5], 6, (7, 8)))
|
| 706 |
-
assert reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)]) == \
|
| 707 |
-
[[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]]
|
| 708 |
-
raises(ValueError, lambda: reshape([0, 1], [-1]))
|
| 709 |
-
raises(ValueError, lambda: reshape([0, 1], [3]))
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
def test_uniq():
|
| 713 |
-
assert list(uniq(p for p in partitions(4))) == \
|
| 714 |
-
[{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}]
|
| 715 |
-
assert list(uniq(x % 2 for x in range(5))) == [0, 1]
|
| 716 |
-
assert list(uniq('a')) == ['a']
|
| 717 |
-
assert list(uniq('ababc')) == list('abc')
|
| 718 |
-
assert list(uniq([[1], [2, 1], [1]])) == [[1], [2, 1]]
|
| 719 |
-
assert list(uniq(permutations(i for i in [[1], 2, 2]))) == \
|
| 720 |
-
[([1], 2, 2), (2, [1], 2), (2, 2, [1])]
|
| 721 |
-
assert list(uniq([2, 3, 2, 4, [2], [1], [2], [3], [1]])) == \
|
| 722 |
-
[2, 3, 4, [2], [1], [3]]
|
| 723 |
-
f = [1]
|
| 724 |
-
raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
|
| 725 |
-
f = [[1]]
|
| 726 |
-
raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
def test_kbins():
|
| 730 |
-
assert len(list(kbins('1123', 2, ordered=1))) == 24
|
| 731 |
-
assert len(list(kbins('1123', 2, ordered=11))) == 36
|
| 732 |
-
assert len(list(kbins('1123', 2, ordered=10))) == 10
|
| 733 |
-
assert len(list(kbins('1123', 2, ordered=0))) == 5
|
| 734 |
-
assert len(list(kbins('1123', 2, ordered=None))) == 3
|
| 735 |
-
|
| 736 |
-
def test1():
|
| 737 |
-
for orderedval in [None, 0, 1, 10, 11]:
|
| 738 |
-
print('ordered =', orderedval)
|
| 739 |
-
for p in kbins([0, 0, 1], 2, ordered=orderedval):
|
| 740 |
-
print(' ', p)
|
| 741 |
-
assert capture(lambda : test1()) == dedent('''\
|
| 742 |
-
ordered = None
|
| 743 |
-
[[0], [0, 1]]
|
| 744 |
-
[[0, 0], [1]]
|
| 745 |
-
ordered = 0
|
| 746 |
-
[[0, 0], [1]]
|
| 747 |
-
[[0, 1], [0]]
|
| 748 |
-
ordered = 1
|
| 749 |
-
[[0], [0, 1]]
|
| 750 |
-
[[0], [1, 0]]
|
| 751 |
-
[[1], [0, 0]]
|
| 752 |
-
ordered = 10
|
| 753 |
-
[[0, 0], [1]]
|
| 754 |
-
[[1], [0, 0]]
|
| 755 |
-
[[0, 1], [0]]
|
| 756 |
-
[[0], [0, 1]]
|
| 757 |
-
ordered = 11
|
| 758 |
-
[[0], [0, 1]]
|
| 759 |
-
[[0, 0], [1]]
|
| 760 |
-
[[0], [1, 0]]
|
| 761 |
-
[[0, 1], [0]]
|
| 762 |
-
[[1], [0, 0]]
|
| 763 |
-
[[1, 0], [0]]\n''')
|
| 764 |
-
|
| 765 |
-
def test2():
|
| 766 |
-
for orderedval in [None, 0, 1, 10, 11]:
|
| 767 |
-
print('ordered =', orderedval)
|
| 768 |
-
for p in kbins(list(range(3)), 2, ordered=orderedval):
|
| 769 |
-
print(' ', p)
|
| 770 |
-
assert capture(lambda : test2()) == dedent('''\
|
| 771 |
-
ordered = None
|
| 772 |
-
[[0], [1, 2]]
|
| 773 |
-
[[0, 1], [2]]
|
| 774 |
-
ordered = 0
|
| 775 |
-
[[0, 1], [2]]
|
| 776 |
-
[[0, 2], [1]]
|
| 777 |
-
[[0], [1, 2]]
|
| 778 |
-
ordered = 1
|
| 779 |
-
[[0], [1, 2]]
|
| 780 |
-
[[0], [2, 1]]
|
| 781 |
-
[[1], [0, 2]]
|
| 782 |
-
[[1], [2, 0]]
|
| 783 |
-
[[2], [0, 1]]
|
| 784 |
-
[[2], [1, 0]]
|
| 785 |
-
ordered = 10
|
| 786 |
-
[[0, 1], [2]]
|
| 787 |
-
[[2], [0, 1]]
|
| 788 |
-
[[0, 2], [1]]
|
| 789 |
-
[[1], [0, 2]]
|
| 790 |
-
[[0], [1, 2]]
|
| 791 |
-
[[1, 2], [0]]
|
| 792 |
-
ordered = 11
|
| 793 |
-
[[0], [1, 2]]
|
| 794 |
-
[[0, 1], [2]]
|
| 795 |
-
[[0], [2, 1]]
|
| 796 |
-
[[0, 2], [1]]
|
| 797 |
-
[[1], [0, 2]]
|
| 798 |
-
[[1, 0], [2]]
|
| 799 |
-
[[1], [2, 0]]
|
| 800 |
-
[[1, 2], [0]]
|
| 801 |
-
[[2], [0, 1]]
|
| 802 |
-
[[2, 0], [1]]
|
| 803 |
-
[[2], [1, 0]]
|
| 804 |
-
[[2, 1], [0]]\n''')
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
def test_has_dups():
|
| 808 |
-
assert has_dups(set()) is False
|
| 809 |
-
assert has_dups(list(range(3))) is False
|
| 810 |
-
assert has_dups([1, 2, 1]) is True
|
| 811 |
-
assert has_dups([[1], [1]]) is True
|
| 812 |
-
assert has_dups([[1], [2]]) is False
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
def test__partition():
|
| 816 |
-
assert _partition('abcde', [1, 0, 1, 2, 0]) == [
|
| 817 |
-
['b', 'e'], ['a', 'c'], ['d']]
|
| 818 |
-
assert _partition('abcde', [1, 0, 1, 2, 0], 3) == [
|
| 819 |
-
['b', 'e'], ['a', 'c'], ['d']]
|
| 820 |
-
output = (3, [1, 0, 1, 2, 0])
|
| 821 |
-
assert _partition('abcde', *output) == [['b', 'e'], ['a', 'c'], ['d']]
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
def test_ordered_partitions():
|
| 825 |
-
from sympy.functions.combinatorial.numbers import nT
|
| 826 |
-
f = ordered_partitions
|
| 827 |
-
assert list(f(0, 1)) == [[]]
|
| 828 |
-
assert list(f(1, 0)) == [[]]
|
| 829 |
-
for i in range(1, 7):
|
| 830 |
-
for j in [None] + list(range(1, i)):
|
| 831 |
-
assert (
|
| 832 |
-
sum(1 for p in f(i, j, 1)) ==
|
| 833 |
-
sum(1 for p in f(i, j, 0)) ==
|
| 834 |
-
nT(i, j))
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
def test_rotations():
|
| 838 |
-
assert list(rotations('ab')) == [['a', 'b'], ['b', 'a']]
|
| 839 |
-
assert list(rotations(range(3))) == [[0, 1, 2], [1, 2, 0], [2, 0, 1]]
|
| 840 |
-
assert list(rotations(range(3), dir=-1)) == [[0, 1, 2], [2, 0, 1], [1, 2, 0]]
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
def test_ibin():
|
| 844 |
-
assert ibin(3) == [1, 1]
|
| 845 |
-
assert ibin(3, 3) == [0, 1, 1]
|
| 846 |
-
assert ibin(3, str=True) == '11'
|
| 847 |
-
assert ibin(3, 3, str=True) == '011'
|
| 848 |
-
assert list(ibin(2, 'all')) == [(0, 0), (0, 1), (1, 0), (1, 1)]
|
| 849 |
-
assert list(ibin(2, '', str=True)) == ['00', '01', '10', '11']
|
| 850 |
-
raises(ValueError, lambda: ibin(-.5))
|
| 851 |
-
raises(ValueError, lambda: ibin(2, 1))
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
def test_iterable():
|
| 855 |
-
assert iterable(0) is False
|
| 856 |
-
assert iterable(1) is False
|
| 857 |
-
assert iterable(None) is False
|
| 858 |
-
|
| 859 |
-
class Test1(NotIterable):
|
| 860 |
-
pass
|
| 861 |
-
|
| 862 |
-
assert iterable(Test1()) is False
|
| 863 |
-
|
| 864 |
-
class Test2(NotIterable):
|
| 865 |
-
_iterable = True
|
| 866 |
-
|
| 867 |
-
assert iterable(Test2()) is True
|
| 868 |
-
|
| 869 |
-
class Test3:
|
| 870 |
-
pass
|
| 871 |
-
|
| 872 |
-
assert iterable(Test3()) is False
|
| 873 |
-
|
| 874 |
-
class Test4:
|
| 875 |
-
_iterable = True
|
| 876 |
-
|
| 877 |
-
assert iterable(Test4()) is True
|
| 878 |
-
|
| 879 |
-
class Test5:
|
| 880 |
-
def __iter__(self):
|
| 881 |
-
yield 1
|
| 882 |
-
|
| 883 |
-
assert iterable(Test5()) is True
|
| 884 |
-
|
| 885 |
-
class Test6(Test5):
|
| 886 |
-
_iterable = False
|
| 887 |
-
|
| 888 |
-
assert iterable(Test6()) is False
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
def test_sequence_partitions():
|
| 892 |
-
assert list(sequence_partitions([1], 1)) == [[[1]]]
|
| 893 |
-
assert list(sequence_partitions([1, 2], 1)) == [[[1, 2]]]
|
| 894 |
-
assert list(sequence_partitions([1, 2], 2)) == [[[1], [2]]]
|
| 895 |
-
assert list(sequence_partitions([1, 2, 3], 1)) == [[[1, 2, 3]]]
|
| 896 |
-
assert list(sequence_partitions([1, 2, 3], 2)) == \
|
| 897 |
-
[[[1], [2, 3]], [[1, 2], [3]]]
|
| 898 |
-
assert list(sequence_partitions([1, 2, 3], 3)) == [[[1], [2], [3]]]
|
| 899 |
-
|
| 900 |
-
# Exceptional cases
|
| 901 |
-
assert list(sequence_partitions([], 0)) == []
|
| 902 |
-
assert list(sequence_partitions([], 1)) == []
|
| 903 |
-
assert list(sequence_partitions([1, 2], 0)) == []
|
| 904 |
-
assert list(sequence_partitions([1, 2], 3)) == []
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
def test_sequence_partitions_empty():
|
| 908 |
-
assert list(sequence_partitions_empty([], 1)) == [[[]]]
|
| 909 |
-
assert list(sequence_partitions_empty([], 2)) == [[[], []]]
|
| 910 |
-
assert list(sequence_partitions_empty([], 3)) == [[[], [], []]]
|
| 911 |
-
assert list(sequence_partitions_empty([1], 1)) == [[[1]]]
|
| 912 |
-
assert list(sequence_partitions_empty([1], 2)) == [[[], [1]], [[1], []]]
|
| 913 |
-
assert list(sequence_partitions_empty([1], 3)) == \
|
| 914 |
-
[[[], [], [1]], [[], [1], []], [[1], [], []]]
|
| 915 |
-
assert list(sequence_partitions_empty([1, 2], 1)) == [[[1, 2]]]
|
| 916 |
-
assert list(sequence_partitions_empty([1, 2], 2)) == \
|
| 917 |
-
[[[], [1, 2]], [[1], [2]], [[1, 2], []]]
|
| 918 |
-
assert list(sequence_partitions_empty([1, 2], 3)) == [
|
| 919 |
-
[[], [], [1, 2]], [[], [1], [2]], [[], [1, 2], []],
|
| 920 |
-
[[1], [], [2]], [[1], [2], []], [[1, 2], [], []]
|
| 921 |
-
]
|
| 922 |
-
assert list(sequence_partitions_empty([1, 2, 3], 1)) == [[[1, 2, 3]]]
|
| 923 |
-
assert list(sequence_partitions_empty([1, 2, 3], 2)) == \
|
| 924 |
-
[[[], [1, 2, 3]], [[1], [2, 3]], [[1, 2], [3]], [[1, 2, 3], []]]
|
| 925 |
-
assert list(sequence_partitions_empty([1, 2, 3], 3)) == [
|
| 926 |
-
[[], [], [1, 2, 3]], [[], [1], [2, 3]],
|
| 927 |
-
[[], [1, 2], [3]], [[], [1, 2, 3], []],
|
| 928 |
-
[[1], [], [2, 3]], [[1], [2], [3]],
|
| 929 |
-
[[1], [2, 3], []], [[1, 2], [], [3]],
|
| 930 |
-
[[1, 2], [3], []], [[1, 2, 3], [], []]
|
| 931 |
-
]
|
| 932 |
-
|
| 933 |
-
# Exceptional cases
|
| 934 |
-
assert list(sequence_partitions([], 0)) == []
|
| 935 |
-
assert list(sequence_partitions([1], 0)) == []
|
| 936 |
-
assert list(sequence_partitions([1, 2], 0)) == []
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
def test_signed_permutations():
|
| 940 |
-
ans = [(0, 1, 1), (0, -1, 1), (0, 1, -1), (0, -1, -1),
|
| 941 |
-
(1, 0, 1), (-1, 0, 1), (1, 0, -1), (-1, 0, -1),
|
| 942 |
-
(1, 1, 0), (-1, 1, 0), (1, -1, 0), (-1, -1, 0)]
|
| 943 |
-
assert list(signed_permutations((0, 1, 1))) == ans
|
| 944 |
-
assert list(signed_permutations((1, 0, 1))) == ans
|
| 945 |
-
assert list(signed_permutations((1, 1, 0))) == ans
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_lambdify.py
DELETED
|
@@ -1,2263 +0,0 @@
|
|
| 1 |
-
from itertools import product
|
| 2 |
-
import math
|
| 3 |
-
import inspect
|
| 4 |
-
import linecache
|
| 5 |
-
import gc
|
| 6 |
-
|
| 7 |
-
import mpmath
|
| 8 |
-
import cmath
|
| 9 |
-
|
| 10 |
-
from sympy.testing.pytest import raises, warns_deprecated_sympy
|
| 11 |
-
from sympy.concrete.summations import Sum
|
| 12 |
-
from sympy.core.function import (Function, Lambda, diff)
|
| 13 |
-
from sympy.core.numbers import (E, Float, I, Rational, all_close, oo, pi)
|
| 14 |
-
from sympy.core.relational import Eq
|
| 15 |
-
from sympy.core.singleton import S
|
| 16 |
-
from sympy.core.symbol import (Dummy, symbols)
|
| 17 |
-
from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial)
|
| 18 |
-
from sympy.functions.combinatorial.numbers import bernoulli, harmonic
|
| 19 |
-
from sympy.functions.elementary.complexes import Abs, sign
|
| 20 |
-
from sympy.functions.elementary.exponential import exp, log
|
| 21 |
-
from sympy.functions.elementary.hyperbolic import asinh,acosh,atanh
|
| 22 |
-
from sympy.functions.elementary.integers import floor
|
| 23 |
-
from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt)
|
| 24 |
-
from sympy.functions.elementary.piecewise import Piecewise
|
| 25 |
-
from sympy.functions.elementary.trigonometric import (asin, acos, atan, cos, cot, sin,
|
| 26 |
-
sinc, tan)
|
| 27 |
-
from sympy.functions import sinh,cosh,tanh
|
| 28 |
-
from sympy.functions.special.bessel import (besseli, besselj, besselk, bessely, jn, yn)
|
| 29 |
-
from sympy.functions.special.beta_functions import (beta, betainc, betainc_regularized)
|
| 30 |
-
from sympy.functions.special.delta_functions import (Heaviside)
|
| 31 |
-
from sympy.functions.special.error_functions import (Ei, erf, erfc, fresnelc, fresnels, Si, Ci)
|
| 32 |
-
from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma, polygamma)
|
| 33 |
-
from sympy.functions.special.zeta_functions import zeta
|
| 34 |
-
from sympy.integrals.integrals import Integral
|
| 35 |
-
from sympy.logic.boolalg import (And, false, ITE, Not, Or, true)
|
| 36 |
-
from sympy.matrices.expressions.dotproduct import DotProduct
|
| 37 |
-
from sympy.simplify.cse_main import cse
|
| 38 |
-
from sympy.tensor.array import derive_by_array, Array
|
| 39 |
-
from sympy.tensor.array.expressions import ArraySymbol
|
| 40 |
-
from sympy.tensor.indexed import IndexedBase, Idx
|
| 41 |
-
from sympy.utilities.lambdify import lambdify
|
| 42 |
-
from sympy.utilities.iterables import numbered_symbols
|
| 43 |
-
from sympy.vector import CoordSys3D
|
| 44 |
-
from sympy.core.expr import UnevaluatedExpr
|
| 45 |
-
from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, log10, hypot, isnan, isinf
|
| 46 |
-
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, amin, amax, minimum, maximum
|
| 47 |
-
from sympy.codegen.scipy_nodes import cosm1, powm1
|
| 48 |
-
from sympy.functions.elementary.complexes import re, im, arg
|
| 49 |
-
from sympy.functions.special.polynomials import \
|
| 50 |
-
chebyshevt, chebyshevu, legendre, hermite, laguerre, gegenbauer, \
|
| 51 |
-
assoc_legendre, assoc_laguerre, jacobi
|
| 52 |
-
from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix
|
| 53 |
-
from sympy.printing.codeprinter import PrintMethodNotImplementedError
|
| 54 |
-
from sympy.printing.lambdarepr import LambdaPrinter
|
| 55 |
-
from sympy.printing.numpy import NumPyPrinter
|
| 56 |
-
from sympy.utilities.lambdify import implemented_function, lambdastr
|
| 57 |
-
from sympy.testing.pytest import skip
|
| 58 |
-
from sympy.utilities.decorator import conserve_mpmath_dps
|
| 59 |
-
from sympy.utilities.exceptions import ignore_warnings
|
| 60 |
-
from sympy.external import import_module
|
| 61 |
-
from sympy.functions.special.gamma_functions import uppergamma, lowergamma
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
import sympy
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
MutableDenseMatrix = Matrix
|
| 68 |
-
|
| 69 |
-
numpy = import_module('numpy')
|
| 70 |
-
scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']})
|
| 71 |
-
numexpr = import_module('numexpr')
|
| 72 |
-
tensorflow = import_module('tensorflow')
|
| 73 |
-
cupy = import_module('cupy')
|
| 74 |
-
jax = import_module('jax')
|
| 75 |
-
numba = import_module('numba')
|
| 76 |
-
|
| 77 |
-
if tensorflow:
|
| 78 |
-
# Hide Tensorflow warnings
|
| 79 |
-
import os
|
| 80 |
-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
| 81 |
-
|
| 82 |
-
w, x, y, z = symbols('w,x,y,z')
|
| 83 |
-
|
| 84 |
-
#================== Test different arguments =======================
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def test_no_args():
|
| 88 |
-
f = lambdify([], 1)
|
| 89 |
-
raises(TypeError, lambda: f(-1))
|
| 90 |
-
assert f() == 1
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def test_single_arg():
|
| 94 |
-
f = lambdify(x, 2*x)
|
| 95 |
-
assert f(1) == 2
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def test_list_args():
|
| 99 |
-
f = lambdify([x, y], x + y)
|
| 100 |
-
assert f(1, 2) == 3
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def test_nested_args():
|
| 104 |
-
f1 = lambdify([[w, x]], [w, x])
|
| 105 |
-
assert f1([91, 2]) == [91, 2]
|
| 106 |
-
raises(TypeError, lambda: f1(1, 2))
|
| 107 |
-
|
| 108 |
-
f2 = lambdify([(w, x), (y, z)], [w, x, y, z])
|
| 109 |
-
assert f2((18, 12), (73, 4)) == [18, 12, 73, 4]
|
| 110 |
-
raises(TypeError, lambda: f2(3, 4))
|
| 111 |
-
|
| 112 |
-
f3 = lambdify([w, [[[x]], y], z], [w, x, y, z])
|
| 113 |
-
assert f3(10, [[[52]], 31], 44) == [10, 52, 31, 44]
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def test_str_args():
|
| 117 |
-
f = lambdify('x,y,z', 'z,y,x')
|
| 118 |
-
assert f(3, 2, 1) == (1, 2, 3)
|
| 119 |
-
assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
|
| 120 |
-
# make sure correct number of args required
|
| 121 |
-
raises(TypeError, lambda: f(0))
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def test_own_namespace_1():
|
| 125 |
-
myfunc = lambda x: 1
|
| 126 |
-
f = lambdify(x, sin(x), {"sin": myfunc})
|
| 127 |
-
assert f(0.1) == 1
|
| 128 |
-
assert f(100) == 1
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def test_own_namespace_2():
|
| 132 |
-
def myfunc(x):
|
| 133 |
-
return 1
|
| 134 |
-
f = lambdify(x, sin(x), {'sin': myfunc})
|
| 135 |
-
assert f(0.1) == 1
|
| 136 |
-
assert f(100) == 1
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def test_own_module():
|
| 140 |
-
f = lambdify(x, sin(x), math)
|
| 141 |
-
assert f(0) == 0.0
|
| 142 |
-
|
| 143 |
-
p, q, r = symbols("p q r", real=True)
|
| 144 |
-
ae = abs(exp(p+UnevaluatedExpr(q+r)))
|
| 145 |
-
f = lambdify([p, q, r], [ae, ae], modules=math)
|
| 146 |
-
results = f(1.0, 1e18, -1e18)
|
| 147 |
-
refvals = [math.exp(1.0)]*2
|
| 148 |
-
for res, ref in zip(results, refvals):
|
| 149 |
-
assert abs((res-ref)/ref) < 1e-15
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
def test_bad_args():
|
| 153 |
-
# no vargs given
|
| 154 |
-
raises(TypeError, lambda: lambdify(1))
|
| 155 |
-
# same with vector exprs
|
| 156 |
-
raises(TypeError, lambda: lambdify([1, 2]))
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
def test_atoms():
|
| 160 |
-
# Non-Symbol atoms should not be pulled out from the expression namespace
|
| 161 |
-
f = lambdify(x, pi + x, {"pi": 3.14})
|
| 162 |
-
assert f(0) == 3.14
|
| 163 |
-
f = lambdify(x, I + x, {"I": 1j})
|
| 164 |
-
assert f(1) == 1 + 1j
|
| 165 |
-
|
| 166 |
-
#================== Test different modules =========================
|
| 167 |
-
|
| 168 |
-
# high precision output of sin(0.2*pi) is used to detect if precision is lost unwanted
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
@conserve_mpmath_dps
|
| 172 |
-
def test_sympy_lambda():
|
| 173 |
-
mpmath.mp.dps = 50
|
| 174 |
-
sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
|
| 175 |
-
f = lambdify(x, sin(x), "sympy")
|
| 176 |
-
assert f(x) == sin(x)
|
| 177 |
-
prec = 1e-15
|
| 178 |
-
assert -prec < f(Rational(1, 5)).evalf() - Float(str(sin02)) < prec
|
| 179 |
-
# arctan is in numpy module and should not be available
|
| 180 |
-
# The arctan below gives NameError. What is this supposed to test?
|
| 181 |
-
# raises(NameError, lambda: lambdify(x, arctan(x), "sympy"))
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
@conserve_mpmath_dps
|
| 185 |
-
def test_math_lambda():
|
| 186 |
-
mpmath.mp.dps = 50
|
| 187 |
-
sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
|
| 188 |
-
f = lambdify(x, sin(x), "math")
|
| 189 |
-
prec = 1e-15
|
| 190 |
-
assert -prec < f(0.2) - sin02 < prec
|
| 191 |
-
raises(TypeError, lambda: f(x))
|
| 192 |
-
# if this succeeds, it can't be a Python math function
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
@conserve_mpmath_dps
|
| 196 |
-
def test_mpmath_lambda():
|
| 197 |
-
mpmath.mp.dps = 50
|
| 198 |
-
sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
|
| 199 |
-
f = lambdify(x, sin(x), "mpmath")
|
| 200 |
-
prec = 1e-49 # mpmath precision is around 50 decimal places
|
| 201 |
-
assert -prec < f(mpmath.mpf("0.2")) - sin02 < prec
|
| 202 |
-
raises(TypeError, lambda: f(x))
|
| 203 |
-
# if this succeeds, it can't be a mpmath function
|
| 204 |
-
|
| 205 |
-
ref2 = (mpmath.mpf("1e-30")
|
| 206 |
-
- mpmath.mpf("1e-45")/2
|
| 207 |
-
+ 5*mpmath.mpf("1e-60")/6
|
| 208 |
-
- 3*mpmath.mpf("1e-75")/4
|
| 209 |
-
+ 33*mpmath.mpf("1e-90")/40
|
| 210 |
-
)
|
| 211 |
-
f2a = lambdify((x, y), x**y - 1, "mpmath")
|
| 212 |
-
f2b = lambdify((x, y), powm1(x, y), "mpmath")
|
| 213 |
-
f2c = lambdify((x,), expm1(x*log1p(x)), "mpmath")
|
| 214 |
-
ans2a = f2a(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15"))
|
| 215 |
-
ans2b = f2b(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15"))
|
| 216 |
-
ans2c = f2c(mpmath.mpf("1e-15"))
|
| 217 |
-
assert abs(ans2a - ref2) < 1e-51
|
| 218 |
-
assert abs(ans2b - ref2) < 1e-67
|
| 219 |
-
assert abs(ans2c - ref2) < 1e-80
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
@conserve_mpmath_dps
|
| 223 |
-
def test_number_precision():
|
| 224 |
-
mpmath.mp.dps = 50
|
| 225 |
-
sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
|
| 226 |
-
f = lambdify(x, sin02, "mpmath")
|
| 227 |
-
prec = 1e-49 # mpmath precision is around 50 decimal places
|
| 228 |
-
assert -prec < f(0) - sin02 < prec
|
| 229 |
-
|
| 230 |
-
@conserve_mpmath_dps
|
| 231 |
-
def test_mpmath_precision():
|
| 232 |
-
mpmath.mp.dps = 100
|
| 233 |
-
assert str(lambdify((), pi.evalf(100), 'mpmath')()) == str(pi.evalf(100))
|
| 234 |
-
|
| 235 |
-
#================== Test Translations ==============================
|
| 236 |
-
# We can only check if all translated functions are valid. It has to be checked
|
| 237 |
-
# by hand if they are complete.
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
def test_math_transl():
|
| 241 |
-
from sympy.utilities.lambdify import MATH_TRANSLATIONS
|
| 242 |
-
for sym, mat in MATH_TRANSLATIONS.items():
|
| 243 |
-
assert sym in sympy.__dict__
|
| 244 |
-
assert mat in math.__dict__
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
def test_mpmath_transl():
|
| 248 |
-
from sympy.utilities.lambdify import MPMATH_TRANSLATIONS
|
| 249 |
-
for sym, mat in MPMATH_TRANSLATIONS.items():
|
| 250 |
-
assert sym in sympy.__dict__ or sym == 'Matrix'
|
| 251 |
-
assert mat in mpmath.__dict__
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
def test_numpy_transl():
|
| 255 |
-
if not numpy:
|
| 256 |
-
skip("numpy not installed.")
|
| 257 |
-
|
| 258 |
-
from sympy.utilities.lambdify import NUMPY_TRANSLATIONS
|
| 259 |
-
for sym, nump in NUMPY_TRANSLATIONS.items():
|
| 260 |
-
assert sym in sympy.__dict__
|
| 261 |
-
assert nump in numpy.__dict__
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
def test_scipy_transl():
|
| 265 |
-
if not scipy:
|
| 266 |
-
skip("scipy not installed.")
|
| 267 |
-
|
| 268 |
-
from sympy.utilities.lambdify import SCIPY_TRANSLATIONS
|
| 269 |
-
for sym, scip in SCIPY_TRANSLATIONS.items():
|
| 270 |
-
assert sym in sympy.__dict__
|
| 271 |
-
assert scip in scipy.__dict__ or scip in scipy.special.__dict__
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
def test_numpy_translation_abs():
|
| 275 |
-
if not numpy:
|
| 276 |
-
skip("numpy not installed.")
|
| 277 |
-
|
| 278 |
-
f = lambdify(x, Abs(x), "numpy")
|
| 279 |
-
assert f(-1) == 1
|
| 280 |
-
assert f(1) == 1
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
def test_numexpr_printer():
|
| 284 |
-
if not numexpr:
|
| 285 |
-
skip("numexpr not installed.")
|
| 286 |
-
|
| 287 |
-
# if translation/printing is done incorrectly then evaluating
|
| 288 |
-
# a lambdified numexpr expression will throw an exception
|
| 289 |
-
from sympy.printing.lambdarepr import NumExprPrinter
|
| 290 |
-
|
| 291 |
-
blacklist = ('where', 'complex', 'contains')
|
| 292 |
-
arg_tuple = (x, y, z) # some functions take more than one argument
|
| 293 |
-
for sym in NumExprPrinter._numexpr_functions.keys():
|
| 294 |
-
if sym in blacklist:
|
| 295 |
-
continue
|
| 296 |
-
ssym = S(sym)
|
| 297 |
-
if hasattr(ssym, '_nargs'):
|
| 298 |
-
nargs = ssym._nargs[0]
|
| 299 |
-
else:
|
| 300 |
-
nargs = 1
|
| 301 |
-
args = arg_tuple[:nargs]
|
| 302 |
-
f = lambdify(args, ssym(*args), modules='numexpr')
|
| 303 |
-
assert f(*(1, )*nargs) is not None
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
def test_cmath_sqrt():
|
| 307 |
-
f = lambdify(x, sqrt(x), "cmath")
|
| 308 |
-
assert f(0) == 0
|
| 309 |
-
assert f(1) == 1
|
| 310 |
-
assert f(4) == 2
|
| 311 |
-
assert abs(f(2) - 1.414) < 0.001
|
| 312 |
-
assert f(-1) == 1j
|
| 313 |
-
assert f(-4) == 2j
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
def test_cmath_log():
|
| 317 |
-
f = lambdify(x, log(x), "cmath")
|
| 318 |
-
assert abs(f(1) - 0) < 1e-15
|
| 319 |
-
assert abs(f(cmath.e) - 1) < 1e-15
|
| 320 |
-
assert abs(f(-1) - cmath.log(-1)) < 1e-15
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
def test_cmath_sinh():
|
| 324 |
-
f = lambdify(x, sinh(x), "cmath")
|
| 325 |
-
assert abs(f(0) - cmath.sinh(0)) < 1e-15
|
| 326 |
-
assert abs(f(pi) - cmath.sinh(pi)) < 1e-15
|
| 327 |
-
assert abs(f(-pi) - cmath.sinh(-pi)) < 1e-15
|
| 328 |
-
assert abs(f(1j) - cmath.sinh(1j)) < 1e-15
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
def test_cmath_cosh():
|
| 332 |
-
f = lambdify(x, cosh(x), "cmath")
|
| 333 |
-
assert abs(f(0) - cmath.cosh(0)) < 1e-15
|
| 334 |
-
assert abs(f(pi) - cmath.cosh(pi)) < 1e-15
|
| 335 |
-
assert abs(f(-pi) - cmath.cosh(-pi)) < 1e-15
|
| 336 |
-
assert abs(f(1j) - cmath.cosh(1j)) < 1e-15
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
def test_cmath_tanh():
|
| 340 |
-
f = lambdify(x, tanh(x), "cmath")
|
| 341 |
-
assert abs(f(0) - cmath.tanh(0)) < 1e-15
|
| 342 |
-
assert abs(f(pi) - cmath.tanh(pi)) < 1e-15
|
| 343 |
-
assert abs(f(-pi) - cmath.tanh(-pi)) < 1e-15
|
| 344 |
-
assert abs(f(1j) - cmath.tanh(1j)) < 1e-15
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
def test_cmath_sin():
|
| 348 |
-
f = lambdify(x, sin(x), "cmath")
|
| 349 |
-
assert abs(f(0) - cmath.sin(0)) < 1e-15
|
| 350 |
-
assert abs(f(pi) - cmath.sin(pi)) < 1e-15
|
| 351 |
-
assert abs(f(-pi) - cmath.sin(-pi)) < 1e-15
|
| 352 |
-
assert abs(f(1j) - cmath.sin(1j)) < 1e-15
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
def test_cmath_cos():
|
| 356 |
-
f = lambdify(x, cos(x), "cmath")
|
| 357 |
-
assert abs(f(0) - cmath.cos(0)) < 1e-15
|
| 358 |
-
assert abs(f(pi) - cmath.cos(pi)) < 1e-15
|
| 359 |
-
assert abs(f(-pi) - cmath.cos(-pi)) < 1e-15
|
| 360 |
-
assert abs(f(1j) - cmath.cos(1j)) < 1e-15
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
def test_cmath_tan():
|
| 364 |
-
f = lambdify(x, tan(x), "cmath")
|
| 365 |
-
assert abs(f(0) - cmath.tan(0)) < 1e-15
|
| 366 |
-
assert abs(f(1j) - cmath.tan(1j)) < 1e-15
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
def test_cmath_asin():
|
| 370 |
-
f = lambdify(x, asin(x), "cmath")
|
| 371 |
-
assert abs(f(0) - cmath.asin(0)) < 1e-15
|
| 372 |
-
assert abs(f(1) - cmath.asin(1)) < 1e-15
|
| 373 |
-
assert abs(f(-1) - cmath.asin(-1)) < 1e-15
|
| 374 |
-
assert abs(f(2) - cmath.asin(2)) < 1e-15
|
| 375 |
-
assert abs(f(1j) - cmath.asin(1j)) < 1e-15
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
def test_cmath_acos():
|
| 379 |
-
f = lambdify(x, acos(x), "cmath")
|
| 380 |
-
assert abs(f(1) - cmath.acos(1)) < 1e-15
|
| 381 |
-
assert abs(f(-1) - cmath.acos(-1)) < 1e-15
|
| 382 |
-
assert abs(f(2) - cmath.acos(2)) < 1e-15
|
| 383 |
-
assert abs(f(1j) - cmath.acos(1j)) < 1e-15
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
def test_cmath_atan():
|
| 387 |
-
f = lambdify(x, atan(x), "cmath")
|
| 388 |
-
assert abs(f(0) - cmath.atan(0)) < 1e-15
|
| 389 |
-
assert abs(f(1) - cmath.atan(1)) < 1e-15
|
| 390 |
-
assert abs(f(-1) - cmath.atan(-1)) < 1e-15
|
| 391 |
-
assert abs(f(2) - cmath.atan(2)) < 1e-15
|
| 392 |
-
assert abs(f(2j) - cmath.atan(2j)) < 1e-15
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
def test_cmath_asinh():
|
| 396 |
-
f = lambdify(x, asinh(x), "cmath")
|
| 397 |
-
assert abs(f(0) - cmath.asinh(0)) < 1e-15
|
| 398 |
-
assert abs(f(1) - cmath.asinh(1)) < 1e-15
|
| 399 |
-
assert abs(f(-1) - cmath.asinh(-1)) < 1e-15
|
| 400 |
-
assert abs(f(2) - cmath.asinh(2)) < 1e-15
|
| 401 |
-
assert abs(f(2j) - cmath.asinh(2j)) < 1e-15
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
def test_cmath_acosh():
|
| 405 |
-
f = lambdify(x, acosh(x), "cmath")
|
| 406 |
-
assert abs(f(1) - cmath.acosh(1)) < 1e-15
|
| 407 |
-
assert abs(f(2) - cmath.acosh(2)) < 1e-15
|
| 408 |
-
assert abs(f(-1) - cmath.acosh(-1)) < 1e-15
|
| 409 |
-
assert abs(f(2j) - cmath.acosh(2j)) < 1e-15
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
def test_cmath_atanh():
|
| 413 |
-
f = lambdify(x, atanh(x), "cmath")
|
| 414 |
-
assert abs(f(0) - cmath.atanh(0)) < 1e-15
|
| 415 |
-
assert abs(f(0.5) - cmath.atanh(0.5)) < 1e-15
|
| 416 |
-
assert abs(f(-0.5) - cmath.atanh(-0.5)) < 1e-15
|
| 417 |
-
assert abs(f(2) - cmath.atanh(2)) < 1e-15
|
| 418 |
-
assert abs(f(-2) - cmath.atanh(-2)) < 1e-15
|
| 419 |
-
assert abs(f(2j) - cmath.atanh(2j)) < 1e-15
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
def test_cmath_complex_identities():
|
| 423 |
-
# Define symbol
|
| 424 |
-
z = symbols('z')
|
| 425 |
-
|
| 426 |
-
# Trigonometric identity using re(z) and im(z)
|
| 427 |
-
expr = cos(z) - cos(re(z)) * cosh(im(z)) + I * sin(re(z)) * sinh(im(z))
|
| 428 |
-
func = lambdify([z], expr, modules=["cmath", "math"])
|
| 429 |
-
hpi = math.pi / 2
|
| 430 |
-
assert abs(func(hpi + 1j * hpi)) < 4e-16
|
| 431 |
-
|
| 432 |
-
# Euler's Formula: e^(i*z) = cos(z) + i*sin(z)
|
| 433 |
-
func = lambdify([z], exp(I * z) - (cos(z) + I * sin(z)), modules=["cmath", "math"])
|
| 434 |
-
assert abs(func(hpi)) < 4e-16
|
| 435 |
-
|
| 436 |
-
# Exponential Identity: e^z = e^(Re(z)) * (cos(Im(z)) + i*sin(Im(z)))
|
| 437 |
-
func_exp = lambdify([z], exp(z) - exp(re(z)) * (cos(im(z)) + I * sin(im(z))),
|
| 438 |
-
modules=["cmath", "math"])
|
| 439 |
-
assert abs(func_exp(hpi + 1j * hpi)) < 4e-16
|
| 440 |
-
|
| 441 |
-
# Complex Cosine Identity: cos(z) = cos(Re(z)) * cosh(Im(z)) - i*sin(Re(z)) * sinh(Im(z))
|
| 442 |
-
func_cos = lambdify([z], cos(z) - (cos(re(z)) * cosh(im(z)) - I * sin(re(z)) * sinh(im(z))),
|
| 443 |
-
modules=["cmath", "math"])
|
| 444 |
-
assert abs(func_cos(hpi + 1j * hpi)) < 4e-16
|
| 445 |
-
|
| 446 |
-
# Complex Sine Identity: sin(z) = sin(Re(z)) * cosh(Im(z)) + i*cos(Re(z)) * sinh(Im(z))
|
| 447 |
-
func_sin = lambdify([z], sin(z) - (sin(re(z)) * cosh(im(z)) + I * cos(re(z)) * sinh(im(z))),
|
| 448 |
-
modules=["cmath", "math"])
|
| 449 |
-
assert abs(func_sin(hpi + 1j * hpi)) < 4e-16
|
| 450 |
-
|
| 451 |
-
# Complex Hyperbolic Cosine Identity: cosh(z) = cosh(Re(z)) * cos(Im(z)) + i*sinh(Re(z)) * sin(Im(z))
|
| 452 |
-
func_cosh_1 = lambdify([z], cosh(z) - (cosh(re(z)) * cos(im(z)) + I * sinh(re(z)) * sin(im(z))),
|
| 453 |
-
modules=["cmath", "math"])
|
| 454 |
-
assert abs(func_cosh_1(hpi + 1j * hpi)) < 4e-16
|
| 455 |
-
|
| 456 |
-
# Complex Hyperbolic Sine Identity: sinh(z) = sinh(Re(z)) * cos(Im(z)) + i*cosh(Re(z)) * sin(Im(z))
|
| 457 |
-
func_sinh = lambdify([z], sinh(z) - (sinh(re(z)) * cos(im(z)) + I * cosh(re(z)) * sin(im(z))),
|
| 458 |
-
modules=["cmath", "math"])
|
| 459 |
-
assert abs(func_sinh(hpi + 1j * hpi)) < 4e-16
|
| 460 |
-
|
| 461 |
-
# cosh(z) = (e^z + e^(-z)) / 2
|
| 462 |
-
func_cosh_2 = lambdify([z], cosh(z) - (exp(z) + exp(-z)) / 2, modules=["cmath", "math"])
|
| 463 |
-
assert abs(func_cosh_2(hpi)) < 4e-16
|
| 464 |
-
|
| 465 |
-
# Additional expressions testing log and exp with real and imaginary parts
|
| 466 |
-
expr1 = log(re(z)) + log(im(z)) - log(re(z) * im(z))
|
| 467 |
-
expr2 = exp(re(z)) * exp(im(z) * I) - exp(z)
|
| 468 |
-
expr3 = log(exp(re(z))) - re(z)
|
| 469 |
-
expr4 = exp(log(re(z))) - re(z)
|
| 470 |
-
expr5 = log(exp(re(z) + im(z))) - (re(z) + im(z))
|
| 471 |
-
expr6 = exp(log(re(z) + im(z))) - (re(z) + im(z))
|
| 472 |
-
func1 = lambdify([z], expr1, modules=["cmath", "math"])
|
| 473 |
-
func2 = lambdify([z], expr2, modules=["cmath", "math"])
|
| 474 |
-
func3 = lambdify([z], expr3, modules=["cmath", "math"])
|
| 475 |
-
func4 = lambdify([z], expr4, modules=["cmath", "math"])
|
| 476 |
-
func5 = lambdify([z], expr5, modules=["cmath", "math"])
|
| 477 |
-
func6 = lambdify([z], expr6, modules=["cmath", "math"])
|
| 478 |
-
test_value = 3 + 4j
|
| 479 |
-
assert abs(func1(test_value)) < 4e-16
|
| 480 |
-
assert abs(func2(test_value)) < 4e-16
|
| 481 |
-
assert abs(func3(test_value)) < 4e-16
|
| 482 |
-
assert abs(func4(test_value)) < 4e-16
|
| 483 |
-
assert abs(func5(test_value)) < 4e-16
|
| 484 |
-
assert abs(func6(test_value)) < 4e-16
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
def test_issue_9334():
|
| 488 |
-
if not numexpr:
|
| 489 |
-
skip("numexpr not installed.")
|
| 490 |
-
if not numpy:
|
| 491 |
-
skip("numpy not installed.")
|
| 492 |
-
expr = S('b*a - sqrt(a**2)')
|
| 493 |
-
a, b = sorted(expr.free_symbols, key=lambda s: s.name)
|
| 494 |
-
func_numexpr = lambdify((a,b), expr, modules=[numexpr], dummify=False)
|
| 495 |
-
foo, bar = numpy.random.random((2, 4))
|
| 496 |
-
func_numexpr(foo, bar)
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
def test_issue_12984():
|
| 500 |
-
if not numexpr:
|
| 501 |
-
skip("numexpr not installed.")
|
| 502 |
-
func_numexpr = lambdify((x,y,z), Piecewise((y, x >= 0), (z, x > -1)), numexpr)
|
| 503 |
-
with ignore_warnings(RuntimeWarning):
|
| 504 |
-
assert func_numexpr(1, 24, 42) == 24
|
| 505 |
-
assert str(func_numexpr(-1, 24, 42)) == 'nan'
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
def test_empty_modules():
|
| 509 |
-
x, y = symbols('x y')
|
| 510 |
-
expr = -(x % y)
|
| 511 |
-
|
| 512 |
-
no_modules = lambdify([x, y], expr)
|
| 513 |
-
empty_modules = lambdify([x, y], expr, modules=[])
|
| 514 |
-
assert no_modules(3, 7) == empty_modules(3, 7)
|
| 515 |
-
assert no_modules(3, 7) == -3
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
def test_exponentiation():
|
| 519 |
-
f = lambdify(x, x**2)
|
| 520 |
-
assert f(-1) == 1
|
| 521 |
-
assert f(0) == 0
|
| 522 |
-
assert f(1) == 1
|
| 523 |
-
assert f(-2) == 4
|
| 524 |
-
assert f(2) == 4
|
| 525 |
-
assert f(2.5) == 6.25
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
def test_sqrt():
|
| 529 |
-
f = lambdify(x, sqrt(x))
|
| 530 |
-
assert f(0) == 0.0
|
| 531 |
-
assert f(1) == 1.0
|
| 532 |
-
assert f(4) == 2.0
|
| 533 |
-
assert abs(f(2) - 1.414) < 0.001
|
| 534 |
-
assert f(6.25) == 2.5
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
def test_trig():
|
| 538 |
-
f = lambdify([x], [cos(x), sin(x)], 'math')
|
| 539 |
-
d = f(pi)
|
| 540 |
-
prec = 1e-11
|
| 541 |
-
assert -prec < d[0] + 1 < prec
|
| 542 |
-
assert -prec < d[1] < prec
|
| 543 |
-
d = f(3.14159)
|
| 544 |
-
prec = 1e-5
|
| 545 |
-
assert -prec < d[0] + 1 < prec
|
| 546 |
-
assert -prec < d[1] < prec
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
def test_integral():
|
| 550 |
-
if numpy and not scipy:
|
| 551 |
-
skip("scipy not installed.")
|
| 552 |
-
f = Lambda(x, exp(-x**2))
|
| 553 |
-
l = lambdify(y, Integral(f(x), (x, y, oo)))
|
| 554 |
-
d = l(-oo)
|
| 555 |
-
assert 1.77245385 < d < 1.772453851
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
def test_double_integral():
|
| 559 |
-
if numpy and not scipy:
|
| 560 |
-
skip("scipy not installed.")
|
| 561 |
-
# example from http://mpmath.org/doc/current/calculus/integration.html
|
| 562 |
-
i = Integral(1/(1 - x**2*y**2), (x, 0, 1), (y, 0, z))
|
| 563 |
-
l = lambdify([z], i)
|
| 564 |
-
d = l(1)
|
| 565 |
-
assert 1.23370055 < d < 1.233700551
|
| 566 |
-
|
| 567 |
-
def test_spherical_bessel():
|
| 568 |
-
if numpy and not scipy:
|
| 569 |
-
skip("scipy not installed.")
|
| 570 |
-
test_point = 4.2 #randomly selected
|
| 571 |
-
x = symbols("x")
|
| 572 |
-
jtest = jn(2, x)
|
| 573 |
-
assert abs(lambdify(x,jtest)(test_point) -
|
| 574 |
-
jtest.subs(x,test_point).evalf()) < 1e-8
|
| 575 |
-
ytest = yn(2, x)
|
| 576 |
-
assert abs(lambdify(x,ytest)(test_point) -
|
| 577 |
-
ytest.subs(x,test_point).evalf()) < 1e-8
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
#================== Test vectors ===================================
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
def test_vector_simple():
|
| 584 |
-
f = lambdify((x, y, z), (z, y, x))
|
| 585 |
-
assert f(3, 2, 1) == (1, 2, 3)
|
| 586 |
-
assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
|
| 587 |
-
# make sure correct number of args required
|
| 588 |
-
raises(TypeError, lambda: f(0))
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
def test_vector_discontinuous():
|
| 592 |
-
f = lambdify(x, (-1/x, 1/x))
|
| 593 |
-
raises(ZeroDivisionError, lambda: f(0))
|
| 594 |
-
assert f(1) == (-1.0, 1.0)
|
| 595 |
-
assert f(2) == (-0.5, 0.5)
|
| 596 |
-
assert f(-2) == (0.5, -0.5)
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
def test_trig_symbolic():
|
| 600 |
-
f = lambdify([x], [cos(x), sin(x)], 'math')
|
| 601 |
-
d = f(pi)
|
| 602 |
-
assert abs(d[0] + 1) < 0.0001
|
| 603 |
-
assert abs(d[1] - 0) < 0.0001
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
def test_trig_float():
|
| 607 |
-
f = lambdify([x], [cos(x), sin(x)])
|
| 608 |
-
d = f(3.14159)
|
| 609 |
-
assert abs(d[0] + 1) < 0.0001
|
| 610 |
-
assert abs(d[1] - 0) < 0.0001
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
def test_docs():
|
| 614 |
-
f = lambdify(x, x**2)
|
| 615 |
-
assert f(2) == 4
|
| 616 |
-
f = lambdify([x, y, z], [z, y, x])
|
| 617 |
-
assert f(1, 2, 3) == [3, 2, 1]
|
| 618 |
-
f = lambdify(x, sqrt(x))
|
| 619 |
-
assert f(4) == 2.0
|
| 620 |
-
f = lambdify((x, y), sin(x*y)**2)
|
| 621 |
-
assert f(0, 5) == 0
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
def test_math():
|
| 625 |
-
f = lambdify((x, y), sin(x), modules="math")
|
| 626 |
-
assert f(0, 5) == 0
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
def test_sin():
|
| 630 |
-
f = lambdify(x, sin(x)**2)
|
| 631 |
-
assert isinstance(f(2), float)
|
| 632 |
-
f = lambdify(x, sin(x)**2, modules="math")
|
| 633 |
-
assert isinstance(f(2), float)
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
def test_matrix():
|
| 637 |
-
A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
|
| 638 |
-
sol = Matrix([[1, 2], [sin(3) + 4, 1]])
|
| 639 |
-
f = lambdify((x, y, z), A, modules="sympy")
|
| 640 |
-
assert f(1, 2, 3) == sol
|
| 641 |
-
f = lambdify((x, y, z), (A, [A]), modules="sympy")
|
| 642 |
-
assert f(1, 2, 3) == (sol, [sol])
|
| 643 |
-
J = Matrix((x, x + y)).jacobian((x, y))
|
| 644 |
-
v = Matrix((x, y))
|
| 645 |
-
sol = Matrix([[1, 0], [1, 1]])
|
| 646 |
-
assert lambdify(v, J, modules='sympy')(1, 2) == sol
|
| 647 |
-
assert lambdify(v.T, J, modules='sympy')(1, 2) == sol
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
def test_numpy_matrix():
|
| 651 |
-
if not numpy:
|
| 652 |
-
skip("numpy not installed.")
|
| 653 |
-
A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
|
| 654 |
-
sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
|
| 655 |
-
#Lambdify array first, to ensure return to array as default
|
| 656 |
-
f = lambdify((x, y, z), A, ['numpy'])
|
| 657 |
-
numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
|
| 658 |
-
#Check that the types are arrays and matrices
|
| 659 |
-
assert isinstance(f(1, 2, 3), numpy.ndarray)
|
| 660 |
-
|
| 661 |
-
# gh-15071
|
| 662 |
-
class dot(Function):
|
| 663 |
-
pass
|
| 664 |
-
x_dot_mtx = dot(x, Matrix([[2], [1], [0]]))
|
| 665 |
-
f_dot1 = lambdify(x, x_dot_mtx)
|
| 666 |
-
inp = numpy.zeros((17, 3))
|
| 667 |
-
assert numpy.all(f_dot1(inp) == 0)
|
| 668 |
-
|
| 669 |
-
strict_kw = {"allow_unknown_functions": False, "inline": True, "fully_qualified_modules": False}
|
| 670 |
-
p2 = NumPyPrinter(dict(user_functions={'dot': 'dot'}, **strict_kw))
|
| 671 |
-
f_dot2 = lambdify(x, x_dot_mtx, printer=p2)
|
| 672 |
-
assert numpy.all(f_dot2(inp) == 0)
|
| 673 |
-
|
| 674 |
-
p3 = NumPyPrinter(strict_kw)
|
| 675 |
-
# The line below should probably fail upon construction (before calling with "(inp)"):
|
| 676 |
-
raises(Exception, lambda: lambdify(x, x_dot_mtx, printer=p3)(inp))
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
def test_numpy_transpose():
|
| 680 |
-
if not numpy:
|
| 681 |
-
skip("numpy not installed.")
|
| 682 |
-
A = Matrix([[1, x], [0, 1]])
|
| 683 |
-
f = lambdify((x), A.T, modules="numpy")
|
| 684 |
-
numpy.testing.assert_array_equal(f(2), numpy.array([[1, 0], [2, 1]]))
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
def test_numpy_dotproduct():
|
| 688 |
-
if not numpy:
|
| 689 |
-
skip("numpy not installed")
|
| 690 |
-
A = Matrix([x, y, z])
|
| 691 |
-
f1 = lambdify([x, y, z], DotProduct(A, A), modules='numpy')
|
| 692 |
-
f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy')
|
| 693 |
-
f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='numpy')
|
| 694 |
-
f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy')
|
| 695 |
-
|
| 696 |
-
assert f1(1, 2, 3) == \
|
| 697 |
-
f2(1, 2, 3) == \
|
| 698 |
-
f3(1, 2, 3) == \
|
| 699 |
-
f4(1, 2, 3) == \
|
| 700 |
-
numpy.array([14])
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
def test_numpy_inverse():
|
| 704 |
-
if not numpy:
|
| 705 |
-
skip("numpy not installed.")
|
| 706 |
-
A = Matrix([[1, x], [0, 1]])
|
| 707 |
-
f = lambdify((x), A**-1, modules="numpy")
|
| 708 |
-
numpy.testing.assert_array_equal(f(2), numpy.array([[1, -2], [0, 1]]))
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
def test_numpy_old_matrix():
|
| 712 |
-
if not numpy:
|
| 713 |
-
skip("numpy not installed.")
|
| 714 |
-
A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
|
| 715 |
-
sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
|
| 716 |
-
f = lambdify((x, y, z), A, [{'ImmutableDenseMatrix': numpy.matrix}, 'numpy'])
|
| 717 |
-
with ignore_warnings(PendingDeprecationWarning):
|
| 718 |
-
numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
|
| 719 |
-
assert isinstance(f(1, 2, 3), numpy.matrix)
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
def test_scipy_sparse_matrix():
|
| 723 |
-
if not scipy:
|
| 724 |
-
skip("scipy not installed.")
|
| 725 |
-
A = SparseMatrix([[x, 0], [0, y]])
|
| 726 |
-
f = lambdify((x, y), A, modules="scipy")
|
| 727 |
-
B = f(1, 2)
|
| 728 |
-
assert isinstance(B, scipy.sparse.coo_matrix)
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
def test_python_div_zero_issue_11306():
|
| 732 |
-
if not numpy:
|
| 733 |
-
skip("numpy not installed.")
|
| 734 |
-
p = Piecewise((1 / x, y < -1), (x, y < 1), (1 / x, True))
|
| 735 |
-
f = lambdify([x, y], p, modules='numpy')
|
| 736 |
-
with numpy.errstate(divide='ignore'):
|
| 737 |
-
assert float(f(numpy.array(0), numpy.array(0.5))) == 0
|
| 738 |
-
assert float(f(numpy.array(0), numpy.array(1))) == float('inf')
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
def test_issue9474():
|
| 742 |
-
mods = [None, 'math']
|
| 743 |
-
if numpy:
|
| 744 |
-
mods.append('numpy')
|
| 745 |
-
if mpmath:
|
| 746 |
-
mods.append('mpmath')
|
| 747 |
-
for mod in mods:
|
| 748 |
-
f = lambdify(x, S.One/x, modules=mod)
|
| 749 |
-
assert f(2) == 0.5
|
| 750 |
-
f = lambdify(x, floor(S.One/x), modules=mod)
|
| 751 |
-
assert f(2) == 0
|
| 752 |
-
|
| 753 |
-
for absfunc, modules in product([Abs, abs], mods):
|
| 754 |
-
f = lambdify(x, absfunc(x), modules=modules)
|
| 755 |
-
assert f(-1) == 1
|
| 756 |
-
assert f(1) == 1
|
| 757 |
-
assert f(3+4j) == 5
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
def test_issue_9871():
|
| 761 |
-
if not numexpr:
|
| 762 |
-
skip("numexpr not installed.")
|
| 763 |
-
if not numpy:
|
| 764 |
-
skip("numpy not installed.")
|
| 765 |
-
|
| 766 |
-
r = sqrt(x**2 + y**2)
|
| 767 |
-
expr = diff(1/r, x)
|
| 768 |
-
|
| 769 |
-
xn = yn = numpy.linspace(1, 10, 16)
|
| 770 |
-
# expr(xn, xn) = -xn/(sqrt(2)*xn)^3
|
| 771 |
-
fv_exact = -numpy.sqrt(2.)**-3 * xn**-2
|
| 772 |
-
|
| 773 |
-
fv_numpy = lambdify((x, y), expr, modules='numpy')(xn, yn)
|
| 774 |
-
fv_numexpr = lambdify((x, y), expr, modules='numexpr')(xn, yn)
|
| 775 |
-
numpy.testing.assert_allclose(fv_numpy, fv_exact, rtol=1e-10)
|
| 776 |
-
numpy.testing.assert_allclose(fv_numexpr, fv_exact, rtol=1e-10)
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
def test_numpy_piecewise():
|
| 780 |
-
if not numpy:
|
| 781 |
-
skip("numpy not installed.")
|
| 782 |
-
pieces = Piecewise((x, x < 3), (x**2, x > 5), (0, True))
|
| 783 |
-
f = lambdify(x, pieces, modules="numpy")
|
| 784 |
-
numpy.testing.assert_array_equal(f(numpy.arange(10)),
|
| 785 |
-
numpy.array([0, 1, 2, 0, 0, 0, 36, 49, 64, 81]))
|
| 786 |
-
# If we evaluate somewhere all conditions are False, we should get back NaN
|
| 787 |
-
nodef_func = lambdify(x, Piecewise((x, x > 0), (-x, x < 0)))
|
| 788 |
-
numpy.testing.assert_array_equal(nodef_func(numpy.array([-1, 0, 1])),
|
| 789 |
-
numpy.array([1, numpy.nan, 1]))
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
def test_numpy_logical_ops():
|
| 793 |
-
if not numpy:
|
| 794 |
-
skip("numpy not installed.")
|
| 795 |
-
and_func = lambdify((x, y), And(x, y), modules="numpy")
|
| 796 |
-
and_func_3 = lambdify((x, y, z), And(x, y, z), modules="numpy")
|
| 797 |
-
or_func = lambdify((x, y), Or(x, y), modules="numpy")
|
| 798 |
-
or_func_3 = lambdify((x, y, z), Or(x, y, z), modules="numpy")
|
| 799 |
-
not_func = lambdify((x), Not(x), modules="numpy")
|
| 800 |
-
arr1 = numpy.array([True, True])
|
| 801 |
-
arr2 = numpy.array([False, True])
|
| 802 |
-
arr3 = numpy.array([True, False])
|
| 803 |
-
numpy.testing.assert_array_equal(and_func(arr1, arr2), numpy.array([False, True]))
|
| 804 |
-
numpy.testing.assert_array_equal(and_func_3(arr1, arr2, arr3), numpy.array([False, False]))
|
| 805 |
-
numpy.testing.assert_array_equal(or_func(arr1, arr2), numpy.array([True, True]))
|
| 806 |
-
numpy.testing.assert_array_equal(or_func_3(arr1, arr2, arr3), numpy.array([True, True]))
|
| 807 |
-
numpy.testing.assert_array_equal(not_func(arr2), numpy.array([True, False]))
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
def test_numpy_matmul():
|
| 811 |
-
if not numpy:
|
| 812 |
-
skip("numpy not installed.")
|
| 813 |
-
xmat = Matrix([[x, y], [z, 1+z]])
|
| 814 |
-
ymat = Matrix([[x**2], [Abs(x)]])
|
| 815 |
-
mat_func = lambdify((x, y, z), xmat*ymat, modules="numpy")
|
| 816 |
-
numpy.testing.assert_array_equal(mat_func(0.5, 3, 4), numpy.array([[1.625], [3.5]]))
|
| 817 |
-
numpy.testing.assert_array_equal(mat_func(-0.5, 3, 4), numpy.array([[1.375], [3.5]]))
|
| 818 |
-
# Multiple matrices chained together in multiplication
|
| 819 |
-
f = lambdify((x, y, z), xmat*xmat*xmat, modules="numpy")
|
| 820 |
-
numpy.testing.assert_array_equal(f(0.5, 3, 4), numpy.array([[72.125, 119.25],
|
| 821 |
-
[159, 251]]))
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
def test_numpy_numexpr():
|
| 825 |
-
if not numpy:
|
| 826 |
-
skip("numpy not installed.")
|
| 827 |
-
if not numexpr:
|
| 828 |
-
skip("numexpr not installed.")
|
| 829 |
-
a, b, c = numpy.random.randn(3, 128, 128)
|
| 830 |
-
# ensure that numpy and numexpr return same value for complicated expression
|
| 831 |
-
expr = sin(x) + cos(y) + tan(z)**2 + Abs(z-y)*acos(sin(y*z)) + \
|
| 832 |
-
Abs(y-z)*acosh(2+exp(y-x))- sqrt(x**2+I*y**2)
|
| 833 |
-
npfunc = lambdify((x, y, z), expr, modules='numpy')
|
| 834 |
-
nefunc = lambdify((x, y, z), expr, modules='numexpr')
|
| 835 |
-
assert numpy.allclose(npfunc(a, b, c), nefunc(a, b, c))
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
def test_numexpr_userfunctions():
|
| 839 |
-
if not numpy:
|
| 840 |
-
skip("numpy not installed.")
|
| 841 |
-
if not numexpr:
|
| 842 |
-
skip("numexpr not installed.")
|
| 843 |
-
a, b = numpy.random.randn(2, 10)
|
| 844 |
-
uf = type('uf', (Function, ),
|
| 845 |
-
{'eval' : classmethod(lambda x, y : y**2+1)})
|
| 846 |
-
func = lambdify(x, 1-uf(x), modules='numexpr')
|
| 847 |
-
assert numpy.allclose(func(a), -(a**2))
|
| 848 |
-
|
| 849 |
-
uf = implemented_function(Function('uf'), lambda x, y : 2*x*y+1)
|
| 850 |
-
func = lambdify((x, y), uf(x, y), modules='numexpr')
|
| 851 |
-
assert numpy.allclose(func(a, b), 2*a*b+1)
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
def test_tensorflow_basic_math():
|
| 855 |
-
if not tensorflow:
|
| 856 |
-
skip("tensorflow not installed.")
|
| 857 |
-
expr = Max(sin(x), Abs(1/(x+2)))
|
| 858 |
-
func = lambdify(x, expr, modules="tensorflow")
|
| 859 |
-
|
| 860 |
-
with tensorflow.compat.v1.Session() as s:
|
| 861 |
-
a = tensorflow.constant(0, dtype=tensorflow.float32)
|
| 862 |
-
assert func(a).eval(session=s) == 0.5
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
def test_tensorflow_placeholders():
|
| 866 |
-
if not tensorflow:
|
| 867 |
-
skip("tensorflow not installed.")
|
| 868 |
-
expr = Max(sin(x), Abs(1/(x+2)))
|
| 869 |
-
func = lambdify(x, expr, modules="tensorflow")
|
| 870 |
-
|
| 871 |
-
with tensorflow.compat.v1.Session() as s:
|
| 872 |
-
a = tensorflow.compat.v1.placeholder(dtype=tensorflow.float32)
|
| 873 |
-
assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
def test_tensorflow_variables():
|
| 877 |
-
if not tensorflow:
|
| 878 |
-
skip("tensorflow not installed.")
|
| 879 |
-
expr = Max(sin(x), Abs(1/(x+2)))
|
| 880 |
-
func = lambdify(x, expr, modules="tensorflow")
|
| 881 |
-
|
| 882 |
-
with tensorflow.compat.v1.Session() as s:
|
| 883 |
-
a = tensorflow.Variable(0, dtype=tensorflow.float32)
|
| 884 |
-
s.run(a.initializer)
|
| 885 |
-
assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5
|
| 886 |
-
|
| 887 |
-
|
| 888 |
-
def test_tensorflow_logical_operations():
|
| 889 |
-
if not tensorflow:
|
| 890 |
-
skip("tensorflow not installed.")
|
| 891 |
-
expr = Not(And(Or(x, y), y))
|
| 892 |
-
func = lambdify([x, y], expr, modules="tensorflow")
|
| 893 |
-
|
| 894 |
-
with tensorflow.compat.v1.Session() as s:
|
| 895 |
-
assert func(False, True).eval(session=s) == False
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
def test_tensorflow_piecewise():
|
| 899 |
-
if not tensorflow:
|
| 900 |
-
skip("tensorflow not installed.")
|
| 901 |
-
expr = Piecewise((0, Eq(x,0)), (-1, x < 0), (1, x > 0))
|
| 902 |
-
func = lambdify(x, expr, modules="tensorflow")
|
| 903 |
-
|
| 904 |
-
with tensorflow.compat.v1.Session() as s:
|
| 905 |
-
assert func(-1).eval(session=s) == -1
|
| 906 |
-
assert func(0).eval(session=s) == 0
|
| 907 |
-
assert func(1).eval(session=s) == 1
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
def test_tensorflow_multi_max():
|
| 911 |
-
if not tensorflow:
|
| 912 |
-
skip("tensorflow not installed.")
|
| 913 |
-
expr = Max(x, -x, x**2)
|
| 914 |
-
func = lambdify(x, expr, modules="tensorflow")
|
| 915 |
-
|
| 916 |
-
with tensorflow.compat.v1.Session() as s:
|
| 917 |
-
assert func(-2).eval(session=s) == 4
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
def test_tensorflow_multi_min():
|
| 921 |
-
if not tensorflow:
|
| 922 |
-
skip("tensorflow not installed.")
|
| 923 |
-
expr = Min(x, -x, x**2)
|
| 924 |
-
func = lambdify(x, expr, modules="tensorflow")
|
| 925 |
-
|
| 926 |
-
with tensorflow.compat.v1.Session() as s:
|
| 927 |
-
assert func(-2).eval(session=s) == -2
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
def test_tensorflow_relational():
|
| 931 |
-
if not tensorflow:
|
| 932 |
-
skip("tensorflow not installed.")
|
| 933 |
-
expr = x >= 0
|
| 934 |
-
func = lambdify(x, expr, modules="tensorflow")
|
| 935 |
-
|
| 936 |
-
with tensorflow.compat.v1.Session() as s:
|
| 937 |
-
assert func(1).eval(session=s) == True
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
def test_tensorflow_complexes():
|
| 941 |
-
if not tensorflow:
|
| 942 |
-
skip("tensorflow not installed")
|
| 943 |
-
|
| 944 |
-
func1 = lambdify(x, re(x), modules="tensorflow")
|
| 945 |
-
func2 = lambdify(x, im(x), modules="tensorflow")
|
| 946 |
-
func3 = lambdify(x, Abs(x), modules="tensorflow")
|
| 947 |
-
func4 = lambdify(x, arg(x), modules="tensorflow")
|
| 948 |
-
|
| 949 |
-
with tensorflow.compat.v1.Session() as s:
|
| 950 |
-
# For versions before
|
| 951 |
-
# https://github.com/tensorflow/tensorflow/issues/30029
|
| 952 |
-
# resolved, using Python numeric types may not work
|
| 953 |
-
a = tensorflow.constant(1+2j)
|
| 954 |
-
assert func1(a).eval(session=s) == 1
|
| 955 |
-
assert func2(a).eval(session=s) == 2
|
| 956 |
-
|
| 957 |
-
tensorflow_result = func3(a).eval(session=s)
|
| 958 |
-
sympy_result = Abs(1 + 2j).evalf()
|
| 959 |
-
assert abs(tensorflow_result-sympy_result) < 10**-6
|
| 960 |
-
|
| 961 |
-
tensorflow_result = func4(a).eval(session=s)
|
| 962 |
-
sympy_result = arg(1 + 2j).evalf()
|
| 963 |
-
assert abs(tensorflow_result-sympy_result) < 10**-6
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
def test_tensorflow_array_arg():
|
| 967 |
-
# Test for issue 14655 (tensorflow part)
|
| 968 |
-
if not tensorflow:
|
| 969 |
-
skip("tensorflow not installed.")
|
| 970 |
-
|
| 971 |
-
f = lambdify([[x, y]], x*x + y, 'tensorflow')
|
| 972 |
-
|
| 973 |
-
with tensorflow.compat.v1.Session() as s:
|
| 974 |
-
fcall = f(tensorflow.constant([2.0, 1.0]))
|
| 975 |
-
assert fcall.eval(session=s) == 5.0
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
#================== Test symbolic ==================================
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
def test_sym_single_arg():
|
| 982 |
-
f = lambdify(x, x * y)
|
| 983 |
-
assert f(z) == z * y
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
def test_sym_list_args():
|
| 987 |
-
f = lambdify([x, y], x + y + z)
|
| 988 |
-
assert f(1, 2) == 3 + z
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
def test_sym_integral():
|
| 992 |
-
f = Lambda(x, exp(-x**2))
|
| 993 |
-
l = lambdify(x, Integral(f(x), (x, -oo, oo)), modules="sympy")
|
| 994 |
-
assert l(y) == Integral(exp(-y**2), (y, -oo, oo))
|
| 995 |
-
assert l(y).doit() == sqrt(pi)
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
def test_namespace_order():
|
| 999 |
-
# lambdify had a bug, such that module dictionaries or cached module
|
| 1000 |
-
# dictionaries would pull earlier namespaces into themselves.
|
| 1001 |
-
# Because the module dictionaries form the namespace of the
|
| 1002 |
-
# generated lambda, this meant that the behavior of a previously
|
| 1003 |
-
# generated lambda function could change as a result of later calls
|
| 1004 |
-
# to lambdify.
|
| 1005 |
-
n1 = {'f': lambda x: 'first f'}
|
| 1006 |
-
n2 = {'f': lambda x: 'second f',
|
| 1007 |
-
'g': lambda x: 'function g'}
|
| 1008 |
-
f = sympy.Function('f')
|
| 1009 |
-
g = sympy.Function('g')
|
| 1010 |
-
if1 = lambdify(x, f(x), modules=(n1, "sympy"))
|
| 1011 |
-
assert if1(1) == 'first f'
|
| 1012 |
-
if2 = lambdify(x, g(x), modules=(n2, "sympy"))
|
| 1013 |
-
# previously gave 'second f'
|
| 1014 |
-
assert if1(1) == 'first f'
|
| 1015 |
-
|
| 1016 |
-
assert if2(1) == 'function g'
|
| 1017 |
-
|
| 1018 |
-
|
| 1019 |
-
def test_imps():
|
| 1020 |
-
# Here we check if the default returned functions are anonymous - in
|
| 1021 |
-
# the sense that we can have more than one function with the same name
|
| 1022 |
-
f = implemented_function('f', lambda x: 2*x)
|
| 1023 |
-
g = implemented_function('f', lambda x: math.sqrt(x))
|
| 1024 |
-
l1 = lambdify(x, f(x))
|
| 1025 |
-
l2 = lambdify(x, g(x))
|
| 1026 |
-
assert str(f(x)) == str(g(x))
|
| 1027 |
-
assert l1(3) == 6
|
| 1028 |
-
assert l2(3) == math.sqrt(3)
|
| 1029 |
-
# check that we can pass in a Function as input
|
| 1030 |
-
func = sympy.Function('myfunc')
|
| 1031 |
-
assert not hasattr(func, '_imp_')
|
| 1032 |
-
my_f = implemented_function(func, lambda x: 2*x)
|
| 1033 |
-
assert hasattr(my_f, '_imp_')
|
| 1034 |
-
# Error for functions with same name and different implementation
|
| 1035 |
-
f2 = implemented_function("f", lambda x: x + 101)
|
| 1036 |
-
raises(ValueError, lambda: lambdify(x, f(f2(x))))
|
| 1037 |
-
|
| 1038 |
-
|
| 1039 |
-
def test_imps_errors():
|
| 1040 |
-
# Test errors that implemented functions can return, and still be able to
|
| 1041 |
-
# form expressions.
|
| 1042 |
-
# See: https://github.com/sympy/sympy/issues/10810
|
| 1043 |
-
#
|
| 1044 |
-
# XXX: Removed AttributeError here. This test was added due to issue 10810
|
| 1045 |
-
# but that issue was about ValueError. It doesn't seem reasonable to
|
| 1046 |
-
# "support" catching AttributeError in the same context...
|
| 1047 |
-
for val, error_class in product((0, 0., 2, 2.0), (TypeError, ValueError)):
|
| 1048 |
-
|
| 1049 |
-
def myfunc(a):
|
| 1050 |
-
if a == 0:
|
| 1051 |
-
raise error_class
|
| 1052 |
-
return 1
|
| 1053 |
-
|
| 1054 |
-
f = implemented_function('f', myfunc)
|
| 1055 |
-
expr = f(val)
|
| 1056 |
-
assert expr == f(val)
|
| 1057 |
-
|
| 1058 |
-
|
| 1059 |
-
def test_imps_wrong_args():
|
| 1060 |
-
raises(ValueError, lambda: implemented_function(sin, lambda x: x))
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
def test_lambdify_imps():
|
| 1064 |
-
# Test lambdify with implemented functions
|
| 1065 |
-
# first test basic (sympy) lambdify
|
| 1066 |
-
f = sympy.cos
|
| 1067 |
-
assert lambdify(x, f(x))(0) == 1
|
| 1068 |
-
assert lambdify(x, 1 + f(x))(0) == 2
|
| 1069 |
-
assert lambdify((x, y), y + f(x))(0, 1) == 2
|
| 1070 |
-
# make an implemented function and test
|
| 1071 |
-
f = implemented_function("f", lambda x: x + 100)
|
| 1072 |
-
assert lambdify(x, f(x))(0) == 100
|
| 1073 |
-
assert lambdify(x, 1 + f(x))(0) == 101
|
| 1074 |
-
assert lambdify((x, y), y + f(x))(0, 1) == 101
|
| 1075 |
-
# Can also handle tuples, lists, dicts as expressions
|
| 1076 |
-
lam = lambdify(x, (f(x), x))
|
| 1077 |
-
assert lam(3) == (103, 3)
|
| 1078 |
-
lam = lambdify(x, [f(x), x])
|
| 1079 |
-
assert lam(3) == [103, 3]
|
| 1080 |
-
lam = lambdify(x, [f(x), (f(x), x)])
|
| 1081 |
-
assert lam(3) == [103, (103, 3)]
|
| 1082 |
-
lam = lambdify(x, {f(x): x})
|
| 1083 |
-
assert lam(3) == {103: 3}
|
| 1084 |
-
lam = lambdify(x, {f(x): x})
|
| 1085 |
-
assert lam(3) == {103: 3}
|
| 1086 |
-
lam = lambdify(x, {x: f(x)})
|
| 1087 |
-
assert lam(3) == {3: 103}
|
| 1088 |
-
# Check that imp preferred to other namespaces by default
|
| 1089 |
-
d = {'f': lambda x: x + 99}
|
| 1090 |
-
lam = lambdify(x, f(x), d)
|
| 1091 |
-
assert lam(3) == 103
|
| 1092 |
-
# Unless flag passed
|
| 1093 |
-
lam = lambdify(x, f(x), d, use_imps=False)
|
| 1094 |
-
assert lam(3) == 102
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
def test_dummification():
|
| 1098 |
-
t = symbols('t')
|
| 1099 |
-
F = Function('F')
|
| 1100 |
-
G = Function('G')
|
| 1101 |
-
#"\alpha" is not a valid Python variable name
|
| 1102 |
-
#lambdify should sub in a dummy for it, and return
|
| 1103 |
-
#without a syntax error
|
| 1104 |
-
alpha = symbols(r'\alpha')
|
| 1105 |
-
some_expr = 2 * F(t)**2 / G(t)
|
| 1106 |
-
lam = lambdify((F(t), G(t)), some_expr)
|
| 1107 |
-
assert lam(3, 9) == 2
|
| 1108 |
-
lam = lambdify(sin(t), 2 * sin(t)**2)
|
| 1109 |
-
assert lam(F(t)) == 2 * F(t)**2
|
| 1110 |
-
#Test that \alpha was properly dummified
|
| 1111 |
-
lam = lambdify((alpha, t), 2*alpha + t)
|
| 1112 |
-
assert lam(2, 1) == 5
|
| 1113 |
-
raises(SyntaxError, lambda: lambdify(F(t) * G(t), F(t) * G(t) + 5))
|
| 1114 |
-
raises(SyntaxError, lambda: lambdify(2 * F(t), 2 * F(t) + 5))
|
| 1115 |
-
raises(SyntaxError, lambda: lambdify(2 * F(t), 4 * F(t) + 5))
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
def test_lambdify__arguments_with_invalid_python_identifiers():
|
| 1119 |
-
# see sympy/sympy#26690
|
| 1120 |
-
N = CoordSys3D('N')
|
| 1121 |
-
xn, yn, zn = N.base_scalars()
|
| 1122 |
-
expr = xn + yn
|
| 1123 |
-
f = lambdify([xn, yn], expr)
|
| 1124 |
-
res = f(0.2, 0.3)
|
| 1125 |
-
ref = 0.2 + 0.3
|
| 1126 |
-
assert abs(res-ref) < 1e-15
|
| 1127 |
-
|
| 1128 |
-
|
| 1129 |
-
def test_curly_matrix_symbol():
|
| 1130 |
-
# Issue #15009
|
| 1131 |
-
curlyv = sympy.MatrixSymbol("{v}", 2, 1)
|
| 1132 |
-
lam = lambdify(curlyv, curlyv)
|
| 1133 |
-
assert lam(1)==1
|
| 1134 |
-
lam = lambdify(curlyv, curlyv, dummify=True)
|
| 1135 |
-
assert lam(1)==1
|
| 1136 |
-
|
| 1137 |
-
|
| 1138 |
-
def test_python_keywords():
|
| 1139 |
-
# Test for issue 7452. The automatic dummification should ensure use of
|
| 1140 |
-
# Python reserved keywords as symbol names will create valid lambda
|
| 1141 |
-
# functions. This is an additional regression test.
|
| 1142 |
-
python_if = symbols('if')
|
| 1143 |
-
expr = python_if / 2
|
| 1144 |
-
f = lambdify(python_if, expr)
|
| 1145 |
-
assert f(4.0) == 2.0
|
| 1146 |
-
|
| 1147 |
-
|
| 1148 |
-
def test_lambdify_docstring():
|
| 1149 |
-
func = lambdify((w, x, y, z), w + x + y + z)
|
| 1150 |
-
ref = (
|
| 1151 |
-
"Created with lambdify. Signature:\n\n"
|
| 1152 |
-
"func(w, x, y, z)\n\n"
|
| 1153 |
-
"Expression:\n\n"
|
| 1154 |
-
"w + x + y + z"
|
| 1155 |
-
).splitlines()
|
| 1156 |
-
assert func.__doc__.splitlines()[:len(ref)] == ref
|
| 1157 |
-
syms = symbols('a1:26')
|
| 1158 |
-
func = lambdify(syms, sum(syms))
|
| 1159 |
-
ref = (
|
| 1160 |
-
"Created with lambdify. Signature:\n\n"
|
| 1161 |
-
"func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n"
|
| 1162 |
-
" a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n"
|
| 1163 |
-
"Expression:\n\n"
|
| 1164 |
-
"a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +..."
|
| 1165 |
-
).splitlines()
|
| 1166 |
-
assert func.__doc__.splitlines()[:len(ref)] == ref
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
def test_lambdify_linecache():
|
| 1170 |
-
func = lambdify(x, x + 1)
|
| 1171 |
-
source = 'def _lambdifygenerated(x):\n return x + 1\n'
|
| 1172 |
-
assert inspect.getsource(func) == source
|
| 1173 |
-
filename = inspect.getsourcefile(func)
|
| 1174 |
-
assert filename.startswith('<lambdifygenerated-')
|
| 1175 |
-
assert filename in linecache.cache
|
| 1176 |
-
assert linecache.cache[filename] == (len(source), None, source.splitlines(True), filename)
|
| 1177 |
-
del func
|
| 1178 |
-
gc.collect()
|
| 1179 |
-
assert filename not in linecache.cache
|
| 1180 |
-
|
| 1181 |
-
#================== Test special printers ==========================
|
| 1182 |
-
|
| 1183 |
-
|
| 1184 |
-
def test_special_printers():
|
| 1185 |
-
from sympy.printing.lambdarepr import IntervalPrinter
|
| 1186 |
-
|
| 1187 |
-
def intervalrepr(expr):
|
| 1188 |
-
return IntervalPrinter().doprint(expr)
|
| 1189 |
-
|
| 1190 |
-
expr = sqrt(sqrt(2) + sqrt(3)) + S.Half
|
| 1191 |
-
|
| 1192 |
-
func0 = lambdify((), expr, modules="mpmath", printer=intervalrepr)
|
| 1193 |
-
func1 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter)
|
| 1194 |
-
func2 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter())
|
| 1195 |
-
|
| 1196 |
-
mpi = type(mpmath.mpi(1, 2))
|
| 1197 |
-
|
| 1198 |
-
assert isinstance(func0(), mpi)
|
| 1199 |
-
assert isinstance(func1(), mpi)
|
| 1200 |
-
assert isinstance(func2(), mpi)
|
| 1201 |
-
|
| 1202 |
-
# To check Is lambdify loggamma works for mpmath or not
|
| 1203 |
-
exp1 = lambdify(x, loggamma(x), 'mpmath')(5)
|
| 1204 |
-
exp2 = lambdify(x, loggamma(x), 'mpmath')(1.8)
|
| 1205 |
-
exp3 = lambdify(x, loggamma(x), 'mpmath')(15)
|
| 1206 |
-
exp_ls = [exp1, exp2, exp3]
|
| 1207 |
-
|
| 1208 |
-
sol1 = mpmath.loggamma(5)
|
| 1209 |
-
sol2 = mpmath.loggamma(1.8)
|
| 1210 |
-
sol3 = mpmath.loggamma(15)
|
| 1211 |
-
sol_ls = [sol1, sol2, sol3]
|
| 1212 |
-
|
| 1213 |
-
assert exp_ls == sol_ls
|
| 1214 |
-
|
| 1215 |
-
|
| 1216 |
-
def test_true_false():
|
| 1217 |
-
# We want exact is comparison here, not just ==
|
| 1218 |
-
assert lambdify([], true)() is True
|
| 1219 |
-
assert lambdify([], false)() is False
|
| 1220 |
-
|
| 1221 |
-
|
| 1222 |
-
def test_issue_2790():
|
| 1223 |
-
assert lambdify((x, (y, z)), x + y)(1, (2, 4)) == 3
|
| 1224 |
-
assert lambdify((x, (y, (w, z))), w + x + y + z)(1, (2, (3, 4))) == 10
|
| 1225 |
-
assert lambdify(x, x + 1, dummify=False)(1) == 2
|
| 1226 |
-
|
| 1227 |
-
|
| 1228 |
-
def test_issue_12092():
|
| 1229 |
-
f = implemented_function('f', lambda x: x**2)
|
| 1230 |
-
assert f(f(2)).evalf() == Float(16)
|
| 1231 |
-
|
| 1232 |
-
|
| 1233 |
-
def test_issue_14911():
|
| 1234 |
-
class Variable(sympy.Symbol):
|
| 1235 |
-
def _sympystr(self, printer):
|
| 1236 |
-
return printer.doprint(self.name)
|
| 1237 |
-
|
| 1238 |
-
_lambdacode = _sympystr
|
| 1239 |
-
_numpycode = _sympystr
|
| 1240 |
-
|
| 1241 |
-
x = Variable('x')
|
| 1242 |
-
y = 2 * x
|
| 1243 |
-
code = LambdaPrinter().doprint(y)
|
| 1244 |
-
assert code.replace(' ', '') == '2*x'
|
| 1245 |
-
|
| 1246 |
-
|
| 1247 |
-
def test_ITE():
|
| 1248 |
-
assert lambdify((x, y, z), ITE(x, y, z))(True, 5, 3) == 5
|
| 1249 |
-
assert lambdify((x, y, z), ITE(x, y, z))(False, 5, 3) == 3
|
| 1250 |
-
|
| 1251 |
-
|
| 1252 |
-
def test_Min_Max():
|
| 1253 |
-
# see gh-10375
|
| 1254 |
-
assert lambdify((x, y, z), Min(x, y, z))(1, 2, 3) == 1
|
| 1255 |
-
assert lambdify((x, y, z), Max(x, y, z))(1, 2, 3) == 3
|
| 1256 |
-
|
| 1257 |
-
|
| 1258 |
-
def test_amin_amax_minimum_maximum():
|
| 1259 |
-
if not numpy:
|
| 1260 |
-
skip("numpy not installed")
|
| 1261 |
-
|
| 1262 |
-
a234 = numpy.array([2, 3, 4])
|
| 1263 |
-
a152 = numpy.array([1, 5, 2])
|
| 1264 |
-
|
| 1265 |
-
a254 = numpy.array([2, 5, 4])
|
| 1266 |
-
a132 = numpy.array([1, 3, 2])
|
| 1267 |
-
# 2 args
|
| 1268 |
-
assert numpy.all(lambdify((x, y), maximum(x, y))(a234, a152) == a254)
|
| 1269 |
-
assert numpy.all(lambdify((x, y), minimum(x, y))(a234, a152) == a132)
|
| 1270 |
-
|
| 1271 |
-
# 3 args
|
| 1272 |
-
assert numpy.all(lambdify((x, y, z), maximum(x, y, z))(a234, a152, a234) == a254)
|
| 1273 |
-
assert numpy.all(lambdify((x, y, z), minimum(x, y, z))(a234, a152, a234) == a132)
|
| 1274 |
-
|
| 1275 |
-
# 1 arg
|
| 1276 |
-
assert numpy.all(lambdify((x,), maximum(x))(a234) == a234)
|
| 1277 |
-
assert numpy.all(lambdify((x,), minimum(x))(a234) == a234)
|
| 1278 |
-
|
| 1279 |
-
# 4 args, mixed length
|
| 1280 |
-
assert numpy.all(lambdify((x, y, z, w), maximum(x, y, z, w))(a234, a152, a234, 3) == [3, 5, 4])
|
| 1281 |
-
assert numpy.all(lambdify((x, y, z, w), minimum(x, y, z, w))(a234, a152, a234, 2) == [1, 2, 2])
|
| 1282 |
-
|
| 1283 |
-
# amin & amax
|
| 1284 |
-
assert lambdify((x, y), [amin(x), amax(y)])(a234, a152) == [2, 5]
|
| 1285 |
-
A = numpy.array([
|
| 1286 |
-
[0, 4, 8],
|
| 1287 |
-
[1, 5, 9],
|
| 1288 |
-
[2, 6, 10],
|
| 1289 |
-
])
|
| 1290 |
-
min_, max_ = lambdify((x,), [amin(x, axis=0), amax(x, axis=1)])(A)
|
| 1291 |
-
assert numpy.all(min_ == numpy.amin(A, axis=0))
|
| 1292 |
-
assert numpy.all(max_ == numpy.amax(A, axis=1))
|
| 1293 |
-
|
| 1294 |
-
# see gh-25659
|
| 1295 |
-
assert numpy.all(lambdify((x, y), Max(x, y))([1, 2, 3], [3, 2, 1]) == [3, 2, 3])
|
| 1296 |
-
assert numpy.all(lambdify((x), Min(2, x))([1, 2, 3]) == [1, 2, 2])
|
| 1297 |
-
|
| 1298 |
-
|
| 1299 |
-
|
| 1300 |
-
def test_Indexed():
|
| 1301 |
-
# Issue #10934
|
| 1302 |
-
if not numpy:
|
| 1303 |
-
skip("numpy not installed")
|
| 1304 |
-
|
| 1305 |
-
a = IndexedBase('a')
|
| 1306 |
-
i, j = symbols('i j')
|
| 1307 |
-
b = numpy.array([[1, 2], [3, 4]])
|
| 1308 |
-
assert lambdify(a, Sum(a[x, y], (x, 0, 1), (y, 0, 1)))(b) == 10
|
| 1309 |
-
|
| 1310 |
-
def test_Sum():
|
| 1311 |
-
e = Sum(z, (y, 0, x), (x, 0, 10))
|
| 1312 |
-
ref = 66*z
|
| 1313 |
-
assert e.doit() == ref
|
| 1314 |
-
assert lambdify([z], e)(7) == ref.subs(z, 7)
|
| 1315 |
-
|
| 1316 |
-
def test_Idx():
|
| 1317 |
-
# Issue 26888
|
| 1318 |
-
a = IndexedBase('a')
|
| 1319 |
-
i = Idx('i')
|
| 1320 |
-
b = [1,2,3]
|
| 1321 |
-
assert lambdify([a, i], a[i])(b, 2) == 3
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
def test_issue_12173():
|
| 1325 |
-
#test for issue 12173
|
| 1326 |
-
expr1 = lambdify((x, y), uppergamma(x, y),"mpmath")(1, 2)
|
| 1327 |
-
expr2 = lambdify((x, y), lowergamma(x, y),"mpmath")(1, 2)
|
| 1328 |
-
assert expr1 == uppergamma(1, 2).evalf()
|
| 1329 |
-
assert expr2 == lowergamma(1, 2).evalf()
|
| 1330 |
-
|
| 1331 |
-
|
| 1332 |
-
def test_issue_13642():
|
| 1333 |
-
if not numpy:
|
| 1334 |
-
skip("numpy not installed")
|
| 1335 |
-
f = lambdify(x, sinc(x))
|
| 1336 |
-
assert Abs(f(1) - sinc(1)).n() < 1e-15
|
| 1337 |
-
|
| 1338 |
-
|
| 1339 |
-
def test_sinc_mpmath():
|
| 1340 |
-
f = lambdify(x, sinc(x), "mpmath")
|
| 1341 |
-
assert Abs(f(1) - sinc(1)).n() < 1e-15
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
def test_lambdify_dummy_arg():
|
| 1345 |
-
d1 = Dummy()
|
| 1346 |
-
f1 = lambdify(d1, d1 + 1, dummify=False)
|
| 1347 |
-
assert f1(2) == 3
|
| 1348 |
-
f1b = lambdify(d1, d1 + 1)
|
| 1349 |
-
assert f1b(2) == 3
|
| 1350 |
-
d2 = Dummy('x')
|
| 1351 |
-
f2 = lambdify(d2, d2 + 1)
|
| 1352 |
-
assert f2(2) == 3
|
| 1353 |
-
f3 = lambdify([[d2]], d2 + 1)
|
| 1354 |
-
assert f3([2]) == 3
|
| 1355 |
-
|
| 1356 |
-
|
| 1357 |
-
def test_lambdify_mixed_symbol_dummy_args():
|
| 1358 |
-
d = Dummy()
|
| 1359 |
-
# Contrived example of name clash
|
| 1360 |
-
dsym = symbols(str(d))
|
| 1361 |
-
f = lambdify([d, dsym], d - dsym)
|
| 1362 |
-
assert f(4, 1) == 3
|
| 1363 |
-
|
| 1364 |
-
|
| 1365 |
-
def test_numpy_array_arg():
|
| 1366 |
-
# Test for issue 14655 (numpy part)
|
| 1367 |
-
if not numpy:
|
| 1368 |
-
skip("numpy not installed")
|
| 1369 |
-
|
| 1370 |
-
f = lambdify([[x, y]], x*x + y, 'numpy')
|
| 1371 |
-
|
| 1372 |
-
assert f(numpy.array([2.0, 1.0])) == 5
|
| 1373 |
-
|
| 1374 |
-
|
| 1375 |
-
def test_scipy_fns():
|
| 1376 |
-
if not scipy:
|
| 1377 |
-
skip("scipy not installed")
|
| 1378 |
-
|
| 1379 |
-
single_arg_sympy_fns = [Ei, erf, erfc, factorial, gamma, loggamma, digamma, Si, Ci]
|
| 1380 |
-
single_arg_scipy_fns = [scipy.special.expi, scipy.special.erf, scipy.special.erfc,
|
| 1381 |
-
scipy.special.factorial, scipy.special.gamma, scipy.special.gammaln,
|
| 1382 |
-
scipy.special.psi, scipy.special.sici, scipy.special.sici]
|
| 1383 |
-
numpy.random.seed(0)
|
| 1384 |
-
for (sympy_fn, scipy_fn) in zip(single_arg_sympy_fns, single_arg_scipy_fns):
|
| 1385 |
-
f = lambdify(x, sympy_fn(x), modules="scipy")
|
| 1386 |
-
for i in range(20):
|
| 1387 |
-
tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
|
| 1388 |
-
# SciPy thinks that factorial(z) is 0 when re(z) < 0 and
|
| 1389 |
-
# does not support complex numbers.
|
| 1390 |
-
# SymPy does not think so.
|
| 1391 |
-
if sympy_fn == factorial:
|
| 1392 |
-
tv = numpy.abs(tv)
|
| 1393 |
-
# SciPy supports gammaln for real arguments only,
|
| 1394 |
-
# and there is also a branch cut along the negative real axis
|
| 1395 |
-
if sympy_fn == loggamma:
|
| 1396 |
-
tv = numpy.abs(tv)
|
| 1397 |
-
# SymPy's digamma evaluates as polygamma(0, z)
|
| 1398 |
-
# which SciPy supports for real arguments only
|
| 1399 |
-
if sympy_fn == digamma:
|
| 1400 |
-
tv = numpy.real(tv)
|
| 1401 |
-
sympy_result = sympy_fn(tv).evalf()
|
| 1402 |
-
scipy_result = scipy_fn(tv)
|
| 1403 |
-
# SciPy's sici returns a tuple with both Si and Ci present in it
|
| 1404 |
-
# which needs to be unpacked
|
| 1405 |
-
if sympy_fn == Si:
|
| 1406 |
-
scipy_result = scipy_fn(tv)[0]
|
| 1407 |
-
if sympy_fn == Ci:
|
| 1408 |
-
scipy_result = scipy_fn(tv)[1]
|
| 1409 |
-
assert abs(f(tv) - sympy_result) < 1e-13*(1 + abs(sympy_result))
|
| 1410 |
-
assert abs(f(tv) - scipy_result) < 1e-13*(1 + abs(sympy_result))
|
| 1411 |
-
|
| 1412 |
-
double_arg_sympy_fns = [RisingFactorial, besselj, bessely, besseli,
|
| 1413 |
-
besselk, polygamma]
|
| 1414 |
-
double_arg_scipy_fns = [scipy.special.poch, scipy.special.jv,
|
| 1415 |
-
scipy.special.yv, scipy.special.iv, scipy.special.kv, scipy.special.polygamma]
|
| 1416 |
-
for (sympy_fn, scipy_fn) in zip(double_arg_sympy_fns, double_arg_scipy_fns):
|
| 1417 |
-
f = lambdify((x, y), sympy_fn(x, y), modules="scipy")
|
| 1418 |
-
for i in range(20):
|
| 1419 |
-
# SciPy supports only real orders of Bessel functions
|
| 1420 |
-
tv1 = numpy.random.uniform(-10, 10)
|
| 1421 |
-
tv2 = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
|
| 1422 |
-
# SciPy requires a real valued 2nd argument for: poch, polygamma
|
| 1423 |
-
if sympy_fn in (RisingFactorial, polygamma):
|
| 1424 |
-
tv2 = numpy.real(tv2)
|
| 1425 |
-
if sympy_fn == polygamma:
|
| 1426 |
-
tv1 = abs(int(tv1)) # first argument to polygamma must be a non-negative integer.
|
| 1427 |
-
sympy_result = sympy_fn(tv1, tv2).evalf()
|
| 1428 |
-
assert abs(f(tv1, tv2) - sympy_result) < 1e-13*(1 + abs(sympy_result))
|
| 1429 |
-
assert abs(f(tv1, tv2) - scipy_fn(tv1, tv2)) < 1e-13*(1 + abs(sympy_result))
|
| 1430 |
-
|
| 1431 |
-
|
| 1432 |
-
def test_scipy_polys():
|
| 1433 |
-
if not scipy:
|
| 1434 |
-
skip("scipy not installed")
|
| 1435 |
-
numpy.random.seed(0)
|
| 1436 |
-
|
| 1437 |
-
params = symbols('n k a b')
|
| 1438 |
-
# list polynomials with the number of parameters
|
| 1439 |
-
polys = [
|
| 1440 |
-
(chebyshevt, 1),
|
| 1441 |
-
(chebyshevu, 1),
|
| 1442 |
-
(legendre, 1),
|
| 1443 |
-
(hermite, 1),
|
| 1444 |
-
(laguerre, 1),
|
| 1445 |
-
(gegenbauer, 2),
|
| 1446 |
-
(assoc_legendre, 2),
|
| 1447 |
-
(assoc_laguerre, 2),
|
| 1448 |
-
(jacobi, 3)
|
| 1449 |
-
]
|
| 1450 |
-
|
| 1451 |
-
msg = \
|
| 1452 |
-
"The random test of the function {func} with the arguments " \
|
| 1453 |
-
"{args} had failed because the SymPy result {sympy_result} " \
|
| 1454 |
-
"and SciPy result {scipy_result} had failed to converge " \
|
| 1455 |
-
"within the tolerance {tol} " \
|
| 1456 |
-
"(Actual absolute difference : {diff})"
|
| 1457 |
-
|
| 1458 |
-
for sympy_fn, num_params in polys:
|
| 1459 |
-
args = params[:num_params] + (x,)
|
| 1460 |
-
f = lambdify(args, sympy_fn(*args))
|
| 1461 |
-
for _ in range(10):
|
| 1462 |
-
tn = numpy.random.randint(3, 10)
|
| 1463 |
-
tparams = tuple(numpy.random.uniform(0, 5, size=num_params-1))
|
| 1464 |
-
tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
|
| 1465 |
-
# SciPy supports hermite for real arguments only
|
| 1466 |
-
if sympy_fn == hermite:
|
| 1467 |
-
tv = numpy.real(tv)
|
| 1468 |
-
# assoc_legendre needs x in (-1, 1) and integer param at most n
|
| 1469 |
-
if sympy_fn == assoc_legendre:
|
| 1470 |
-
tv = numpy.random.uniform(-1, 1)
|
| 1471 |
-
tparams = tuple(numpy.random.randint(1, tn, size=1))
|
| 1472 |
-
|
| 1473 |
-
vals = (tn,) + tparams + (tv,)
|
| 1474 |
-
scipy_result = f(*vals)
|
| 1475 |
-
sympy_result = sympy_fn(*vals).evalf()
|
| 1476 |
-
atol = 1e-9*(1 + abs(sympy_result))
|
| 1477 |
-
diff = abs(scipy_result - sympy_result)
|
| 1478 |
-
try:
|
| 1479 |
-
assert diff < atol
|
| 1480 |
-
except TypeError:
|
| 1481 |
-
raise AssertionError(
|
| 1482 |
-
msg.format(
|
| 1483 |
-
func=repr(sympy_fn),
|
| 1484 |
-
args=repr(vals),
|
| 1485 |
-
sympy_result=repr(sympy_result),
|
| 1486 |
-
scipy_result=repr(scipy_result),
|
| 1487 |
-
diff=diff,
|
| 1488 |
-
tol=atol)
|
| 1489 |
-
)
|
| 1490 |
-
|
| 1491 |
-
|
| 1492 |
-
def test_lambdify_inspect():
|
| 1493 |
-
f = lambdify(x, x**2)
|
| 1494 |
-
# Test that inspect.getsource works but don't hard-code implementation
|
| 1495 |
-
# details
|
| 1496 |
-
assert 'x**2' in inspect.getsource(f)
|
| 1497 |
-
|
| 1498 |
-
|
| 1499 |
-
def test_issue_14941():
|
| 1500 |
-
x, y = Dummy(), Dummy()
|
| 1501 |
-
|
| 1502 |
-
# test dict
|
| 1503 |
-
f1 = lambdify([x, y], {x: 3, y: 3}, 'sympy')
|
| 1504 |
-
assert f1(2, 3) == {2: 3, 3: 3}
|
| 1505 |
-
|
| 1506 |
-
# test tuple
|
| 1507 |
-
f2 = lambdify([x, y], (y, x), 'sympy')
|
| 1508 |
-
assert f2(2, 3) == (3, 2)
|
| 1509 |
-
f2b = lambdify([], (1,)) # gh-23224
|
| 1510 |
-
assert f2b() == (1,)
|
| 1511 |
-
|
| 1512 |
-
# test list
|
| 1513 |
-
f3 = lambdify([x, y], [y, x], 'sympy')
|
| 1514 |
-
assert f3(2, 3) == [3, 2]
|
| 1515 |
-
|
| 1516 |
-
|
| 1517 |
-
def test_lambdify_Derivative_arg_issue_16468():
|
| 1518 |
-
f = Function('f')(x)
|
| 1519 |
-
fx = f.diff()
|
| 1520 |
-
assert lambdify((f, fx), f + fx)(10, 5) == 15
|
| 1521 |
-
assert eval(lambdastr((f, fx), f/fx))(10, 5) == 2
|
| 1522 |
-
raises(Exception, lambda:
|
| 1523 |
-
eval(lambdastr((f, fx), f/fx, dummify=False)))
|
| 1524 |
-
assert eval(lambdastr((f, fx), f/fx, dummify=True))(10, 5) == 2
|
| 1525 |
-
assert eval(lambdastr((fx, f), f/fx, dummify=True))(S(10), 5) == S.Half
|
| 1526 |
-
assert lambdify(fx, 1 + fx)(41) == 42
|
| 1527 |
-
assert eval(lambdastr(fx, 1 + fx, dummify=True))(41) == 42
|
| 1528 |
-
|
| 1529 |
-
|
| 1530 |
-
def test_lambdify_Derivative_zeta():
|
| 1531 |
-
# This is related to gh-11802 (and to lesser extent gh-26663)
|
| 1532 |
-
expr1 = zeta(x).diff(x, evaluate=False)
|
| 1533 |
-
f1 = lambdify(x, expr1, modules=['mpmath'])
|
| 1534 |
-
ans1 = f1(2)
|
| 1535 |
-
ref1 = (zeta(2+1e-8).evalf()-zeta(2).evalf())/1e-8
|
| 1536 |
-
assert abs(ans1 - ref1)/abs(ref1) < 1e-7
|
| 1537 |
-
|
| 1538 |
-
expr2 = zeta(x**2).diff(x)
|
| 1539 |
-
f2 = lambdify(x, expr2, modules=['mpmath'])
|
| 1540 |
-
ans2 = f2(2**0.5)
|
| 1541 |
-
ref2 = 2*2**0.5*ref1
|
| 1542 |
-
assert abs(ans2-ref2)/abs(ref2) < 1e-7
|
| 1543 |
-
|
| 1544 |
-
|
| 1545 |
-
def test_lambdify_Derivative_custom_printer():
|
| 1546 |
-
func1 = Function('func1')
|
| 1547 |
-
func2 = Function('func2')
|
| 1548 |
-
|
| 1549 |
-
class MyPrinter(NumPyPrinter):
|
| 1550 |
-
|
| 1551 |
-
def _print_Derivative_func1(self, args, seq_orders):
|
| 1552 |
-
arg, = args
|
| 1553 |
-
order, = seq_orders
|
| 1554 |
-
return '42'
|
| 1555 |
-
|
| 1556 |
-
expr1 = func1(x).diff(x)
|
| 1557 |
-
raises(PrintMethodNotImplementedError, lambda: lambdify([x], expr1))
|
| 1558 |
-
f1 = lambdify([x], expr1, printer=MyPrinter)
|
| 1559 |
-
assert f1(7) == 42
|
| 1560 |
-
|
| 1561 |
-
expr2 = func2(x).diff(x)
|
| 1562 |
-
raises(PrintMethodNotImplementedError, lambda: lambdify([x], expr2, printer=MyPrinter))
|
| 1563 |
-
|
| 1564 |
-
|
| 1565 |
-
def test_lambdify_derivative_and_functions_as_arguments():
|
| 1566 |
-
# see: https://github.com/sympy/sympy/issues/26663#issuecomment-2157179517
|
| 1567 |
-
t, a, b = symbols('t, a, b')
|
| 1568 |
-
f = Function('f')(t)
|
| 1569 |
-
args = f.diff(t, 2), f.diff(t), f, a, b
|
| 1570 |
-
expr1 = a*f.diff(t, 2) + b*f.diff(t) + a*b*f + a**2
|
| 1571 |
-
num_args = 2.0, 3.0, 4.0, 5.0, 6.0
|
| 1572 |
-
ref1 = 5*2 + 6*3 + 5*6*4 + 5**2
|
| 1573 |
-
|
| 1574 |
-
expr2 = a*f.diff(t, 2) + b*f.diff(t) - a*b*f + b**2 - a**2
|
| 1575 |
-
ref2 = 5*2 + 6*3 - 5*6*4 + 6**2 - 5**2
|
| 1576 |
-
|
| 1577 |
-
for dummify, _cse in product([False, None, True], [False, True]):
|
| 1578 |
-
func1 = lambdify(args, expr1, cse=_cse, dummify=dummify)
|
| 1579 |
-
res1 = func1(*num_args)
|
| 1580 |
-
assert abs(res1 - ref1) < 1e-12
|
| 1581 |
-
|
| 1582 |
-
func12 = lambdify(args, [expr1, expr2], cse=_cse, dummify=dummify)
|
| 1583 |
-
res12 = func12(*num_args)
|
| 1584 |
-
assert len(res12) == 2
|
| 1585 |
-
assert abs(res12[0] - ref1) < 1e-12
|
| 1586 |
-
assert abs(res12[1] - ref2) < 1e-12
|
| 1587 |
-
|
| 1588 |
-
|
| 1589 |
-
def test_imag_real():
|
| 1590 |
-
f_re = lambdify([z], sympy.re(z))
|
| 1591 |
-
val = 3+2j
|
| 1592 |
-
assert f_re(val) == val.real
|
| 1593 |
-
|
| 1594 |
-
f_im = lambdify([z], sympy.im(z)) # see #15400
|
| 1595 |
-
assert f_im(val) == val.imag
|
| 1596 |
-
|
| 1597 |
-
|
| 1598 |
-
def test_MatrixSymbol_issue_15578():
|
| 1599 |
-
if not numpy:
|
| 1600 |
-
skip("numpy not installed")
|
| 1601 |
-
A = MatrixSymbol('A', 2, 2)
|
| 1602 |
-
A0 = numpy.array([[1, 2], [3, 4]])
|
| 1603 |
-
f = lambdify(A, A**(-1))
|
| 1604 |
-
assert numpy.allclose(f(A0), numpy.array([[-2., 1.], [1.5, -0.5]]))
|
| 1605 |
-
g = lambdify(A, A**3)
|
| 1606 |
-
assert numpy.allclose(g(A0), numpy.array([[37, 54], [81, 118]]))
|
| 1607 |
-
|
| 1608 |
-
|
| 1609 |
-
def test_issue_15654():
|
| 1610 |
-
if not scipy:
|
| 1611 |
-
skip("scipy not installed")
|
| 1612 |
-
from sympy.abc import n, l, r, Z
|
| 1613 |
-
from sympy.physics import hydrogen
|
| 1614 |
-
nv, lv, rv, Zv = 1, 0, 3, 1
|
| 1615 |
-
sympy_value = hydrogen.R_nl(nv, lv, rv, Zv).evalf()
|
| 1616 |
-
f = lambdify((n, l, r, Z), hydrogen.R_nl(n, l, r, Z))
|
| 1617 |
-
scipy_value = f(nv, lv, rv, Zv)
|
| 1618 |
-
assert abs(sympy_value - scipy_value) < 1e-15
|
| 1619 |
-
|
| 1620 |
-
|
| 1621 |
-
def test_issue_15827():
|
| 1622 |
-
if not numpy:
|
| 1623 |
-
skip("numpy not installed")
|
| 1624 |
-
A = MatrixSymbol("A", 3, 3)
|
| 1625 |
-
B = MatrixSymbol("B", 2, 3)
|
| 1626 |
-
C = MatrixSymbol("C", 3, 4)
|
| 1627 |
-
D = MatrixSymbol("D", 4, 5)
|
| 1628 |
-
k=symbols("k")
|
| 1629 |
-
f = lambdify(A, (2*k)*A)
|
| 1630 |
-
g = lambdify(A, (2+k)*A)
|
| 1631 |
-
h = lambdify(A, 2*A)
|
| 1632 |
-
i = lambdify((B, C, D), 2*B*C*D)
|
| 1633 |
-
assert numpy.array_equal(f(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
|
| 1634 |
-
numpy.array([[2*k, 4*k, 6*k], [2*k, 4*k, 6*k], [2*k, 4*k, 6*k]], dtype=object))
|
| 1635 |
-
|
| 1636 |
-
assert numpy.array_equal(g(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
|
| 1637 |
-
numpy.array([[k + 2, 2*k + 4, 3*k + 6], [k + 2, 2*k + 4, 3*k + 6], \
|
| 1638 |
-
[k + 2, 2*k + 4, 3*k + 6]], dtype=object))
|
| 1639 |
-
|
| 1640 |
-
assert numpy.array_equal(h(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
|
| 1641 |
-
numpy.array([[2, 4, 6], [2, 4, 6], [2, 4, 6]]))
|
| 1642 |
-
|
| 1643 |
-
assert numpy.array_equal(i(numpy.array([[1, 2, 3], [1, 2, 3]]), numpy.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]), \
|
| 1644 |
-
numpy.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])), numpy.array([[ 120, 240, 360, 480, 600], \
|
| 1645 |
-
[ 120, 240, 360, 480, 600]]))
|
| 1646 |
-
|
| 1647 |
-
|
| 1648 |
-
def test_issue_16930():
|
| 1649 |
-
if not scipy:
|
| 1650 |
-
skip("scipy not installed")
|
| 1651 |
-
|
| 1652 |
-
x = symbols("x")
|
| 1653 |
-
f = lambda x: S.GoldenRatio * x**2
|
| 1654 |
-
f_ = lambdify(x, f(x), modules='scipy')
|
| 1655 |
-
assert f_(1) == scipy.constants.golden_ratio
|
| 1656 |
-
|
| 1657 |
-
def test_issue_17898():
|
| 1658 |
-
if not scipy:
|
| 1659 |
-
skip("scipy not installed")
|
| 1660 |
-
x = symbols("x")
|
| 1661 |
-
f_ = lambdify([x], sympy.LambertW(x,-1), modules='scipy')
|
| 1662 |
-
assert f_(0.1) == mpmath.lambertw(0.1, -1)
|
| 1663 |
-
|
| 1664 |
-
def test_issue_13167_21411():
|
| 1665 |
-
if not numpy:
|
| 1666 |
-
skip("numpy not installed")
|
| 1667 |
-
f1 = lambdify(x, sympy.Heaviside(x))
|
| 1668 |
-
f2 = lambdify(x, sympy.Heaviside(x, 1))
|
| 1669 |
-
res1 = f1([-1, 0, 1])
|
| 1670 |
-
res2 = f2([-1, 0, 1])
|
| 1671 |
-
assert Abs(res1[0]).n() < 1e-15 # First functionality: only one argument passed
|
| 1672 |
-
assert Abs(res1[1] - 1/2).n() < 1e-15
|
| 1673 |
-
assert Abs(res1[2] - 1).n() < 1e-15
|
| 1674 |
-
assert Abs(res2[0]).n() < 1e-15 # Second functionality: two arguments passed
|
| 1675 |
-
assert Abs(res2[1] - 1).n() < 1e-15
|
| 1676 |
-
assert Abs(res2[2] - 1).n() < 1e-15
|
| 1677 |
-
|
| 1678 |
-
def test_single_e():
|
| 1679 |
-
f = lambdify(x, E)
|
| 1680 |
-
assert f(23) == exp(1.0)
|
| 1681 |
-
|
| 1682 |
-
def test_issue_16536():
|
| 1683 |
-
if not scipy:
|
| 1684 |
-
skip("scipy not installed")
|
| 1685 |
-
|
| 1686 |
-
a = symbols('a')
|
| 1687 |
-
f1 = lowergamma(a, x)
|
| 1688 |
-
F = lambdify((a, x), f1, modules='scipy')
|
| 1689 |
-
assert abs(lowergamma(1, 3) - F(1, 3)) <= 1e-10
|
| 1690 |
-
|
| 1691 |
-
f2 = uppergamma(a, x)
|
| 1692 |
-
F = lambdify((a, x), f2, modules='scipy')
|
| 1693 |
-
assert abs(uppergamma(1, 3) - F(1, 3)) <= 1e-10
|
| 1694 |
-
|
| 1695 |
-
|
| 1696 |
-
def test_issue_22726():
|
| 1697 |
-
if not numpy:
|
| 1698 |
-
skip("numpy not installed")
|
| 1699 |
-
|
| 1700 |
-
x1, x2 = symbols('x1 x2')
|
| 1701 |
-
f = Max(S.Zero, Min(x1, x2))
|
| 1702 |
-
g = derive_by_array(f, (x1, x2))
|
| 1703 |
-
G = lambdify((x1, x2), g, modules='numpy')
|
| 1704 |
-
point = {x1: 1, x2: 2}
|
| 1705 |
-
assert (abs(g.subs(point) - G(*point.values())) <= 1e-10).all()
|
| 1706 |
-
|
| 1707 |
-
|
| 1708 |
-
def test_issue_22739():
|
| 1709 |
-
if not numpy:
|
| 1710 |
-
skip("numpy not installed")
|
| 1711 |
-
|
| 1712 |
-
x1, x2 = symbols('x1 x2')
|
| 1713 |
-
f = Heaviside(Min(x1, x2))
|
| 1714 |
-
F = lambdify((x1, x2), f, modules='numpy')
|
| 1715 |
-
point = {x1: 1, x2: 2}
|
| 1716 |
-
assert abs(f.subs(point) - F(*point.values())) <= 1e-10
|
| 1717 |
-
|
| 1718 |
-
|
| 1719 |
-
def test_issue_22992():
|
| 1720 |
-
if not numpy:
|
| 1721 |
-
skip("numpy not installed")
|
| 1722 |
-
|
| 1723 |
-
a, t = symbols('a t')
|
| 1724 |
-
expr = a*(log(cot(t/2)) - cos(t))
|
| 1725 |
-
F = lambdify([a, t], expr, 'numpy')
|
| 1726 |
-
|
| 1727 |
-
point = {a: 10, t: 2}
|
| 1728 |
-
|
| 1729 |
-
assert abs(expr.subs(point) - F(*point.values())) <= 1e-10
|
| 1730 |
-
|
| 1731 |
-
# Standard math
|
| 1732 |
-
F = lambdify([a, t], expr)
|
| 1733 |
-
|
| 1734 |
-
assert abs(expr.subs(point) - F(*point.values())) <= 1e-10
|
| 1735 |
-
|
| 1736 |
-
|
| 1737 |
-
def test_issue_19764():
|
| 1738 |
-
if not numpy:
|
| 1739 |
-
skip("numpy not installed")
|
| 1740 |
-
|
| 1741 |
-
expr = Array([x, x**2])
|
| 1742 |
-
f = lambdify(x, expr, 'numpy')
|
| 1743 |
-
|
| 1744 |
-
assert f(1).__class__ == numpy.ndarray
|
| 1745 |
-
|
| 1746 |
-
def test_issue_20070():
|
| 1747 |
-
if not numba:
|
| 1748 |
-
skip("numba not installed")
|
| 1749 |
-
|
| 1750 |
-
f = lambdify(x, sin(x), 'numpy')
|
| 1751 |
-
assert numba.jit(f, nopython=True)(1)==0.8414709848078965
|
| 1752 |
-
|
| 1753 |
-
|
| 1754 |
-
def test_fresnel_integrals_scipy():
|
| 1755 |
-
if not scipy:
|
| 1756 |
-
skip("scipy not installed")
|
| 1757 |
-
|
| 1758 |
-
f1 = fresnelc(x)
|
| 1759 |
-
f2 = fresnels(x)
|
| 1760 |
-
F1 = lambdify(x, f1, modules='scipy')
|
| 1761 |
-
F2 = lambdify(x, f2, modules='scipy')
|
| 1762 |
-
|
| 1763 |
-
assert abs(fresnelc(1.3) - F1(1.3)) <= 1e-10
|
| 1764 |
-
assert abs(fresnels(1.3) - F2(1.3)) <= 1e-10
|
| 1765 |
-
|
| 1766 |
-
|
| 1767 |
-
def test_beta_scipy():
|
| 1768 |
-
if not scipy:
|
| 1769 |
-
skip("scipy not installed")
|
| 1770 |
-
|
| 1771 |
-
f = beta(x, y)
|
| 1772 |
-
F = lambdify((x, y), f, modules='scipy')
|
| 1773 |
-
|
| 1774 |
-
assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10
|
| 1775 |
-
|
| 1776 |
-
|
| 1777 |
-
def test_beta_math():
|
| 1778 |
-
f = beta(x, y)
|
| 1779 |
-
F = lambdify((x, y), f, modules='math')
|
| 1780 |
-
|
| 1781 |
-
assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10
|
| 1782 |
-
|
| 1783 |
-
|
| 1784 |
-
def test_betainc_scipy():
|
| 1785 |
-
if not scipy:
|
| 1786 |
-
skip("scipy not installed")
|
| 1787 |
-
|
| 1788 |
-
f = betainc(w, x, y, z)
|
| 1789 |
-
F = lambdify((w, x, y, z), f, modules='scipy')
|
| 1790 |
-
|
| 1791 |
-
assert abs(betainc(1.4, 3.1, 0.1, 0.5) - F(1.4, 3.1, 0.1, 0.5)) <= 1e-10
|
| 1792 |
-
|
| 1793 |
-
|
| 1794 |
-
def test_betainc_regularized_scipy():
|
| 1795 |
-
if not scipy:
|
| 1796 |
-
skip("scipy not installed")
|
| 1797 |
-
|
| 1798 |
-
f = betainc_regularized(w, x, y, z)
|
| 1799 |
-
F = lambdify((w, x, y, z), f, modules='scipy')
|
| 1800 |
-
|
| 1801 |
-
assert abs(betainc_regularized(0.2, 3.5, 0.1, 1) - F(0.2, 3.5, 0.1, 1)) <= 1e-10
|
| 1802 |
-
|
| 1803 |
-
|
| 1804 |
-
def test_numpy_special_math():
|
| 1805 |
-
if not numpy:
|
| 1806 |
-
skip("numpy not installed")
|
| 1807 |
-
|
| 1808 |
-
funcs = [expm1, log1p, exp2, log2, log10, hypot, logaddexp, logaddexp2]
|
| 1809 |
-
for func in funcs:
|
| 1810 |
-
if 2 in func.nargs:
|
| 1811 |
-
expr = func(x, y)
|
| 1812 |
-
args = (x, y)
|
| 1813 |
-
num_args = (0.3, 0.4)
|
| 1814 |
-
elif 1 in func.nargs:
|
| 1815 |
-
expr = func(x)
|
| 1816 |
-
args = (x,)
|
| 1817 |
-
num_args = (0.3,)
|
| 1818 |
-
else:
|
| 1819 |
-
raise NotImplementedError("Need to handle other than unary & binary functions in test")
|
| 1820 |
-
f = lambdify(args, expr)
|
| 1821 |
-
result = f(*num_args)
|
| 1822 |
-
reference = expr.subs(dict(zip(args, num_args))).evalf()
|
| 1823 |
-
assert numpy.allclose(result, float(reference))
|
| 1824 |
-
|
| 1825 |
-
lae2 = lambdify((x, y), logaddexp2(log2(x), log2(y)))
|
| 1826 |
-
assert abs(2.0**lae2(1e-50, 2.5e-50) - 3.5e-50) < 1e-62 # from NumPy's docstring
|
| 1827 |
-
|
| 1828 |
-
|
| 1829 |
-
def test_scipy_special_math():
|
| 1830 |
-
if not scipy:
|
| 1831 |
-
skip("scipy not installed")
|
| 1832 |
-
|
| 1833 |
-
cm1 = lambdify((x,), cosm1(x), modules='scipy')
|
| 1834 |
-
assert abs(cm1(1e-20) + 5e-41) < 1e-200
|
| 1835 |
-
|
| 1836 |
-
have_scipy_1_10plus = tuple(map(int, scipy.version.version.split('.')[:2])) >= (1, 10)
|
| 1837 |
-
|
| 1838 |
-
if have_scipy_1_10plus:
|
| 1839 |
-
cm2 = lambdify((x, y), powm1(x, y), modules='scipy')
|
| 1840 |
-
assert abs(cm2(1.2, 1e-9) - 1.82321557e-10) < 1e-17
|
| 1841 |
-
|
| 1842 |
-
|
| 1843 |
-
def test_scipy_bernoulli():
|
| 1844 |
-
if not scipy:
|
| 1845 |
-
skip("scipy not installed")
|
| 1846 |
-
|
| 1847 |
-
bern = lambdify((x,), bernoulli(x), modules='scipy')
|
| 1848 |
-
assert bern(1) == 0.5
|
| 1849 |
-
|
| 1850 |
-
|
| 1851 |
-
def test_scipy_harmonic():
|
| 1852 |
-
if not scipy:
|
| 1853 |
-
skip("scipy not installed")
|
| 1854 |
-
|
| 1855 |
-
hn = lambdify((x,), harmonic(x), modules='scipy')
|
| 1856 |
-
assert hn(2) == 1.5
|
| 1857 |
-
hnm = lambdify((x, y), harmonic(x, y), modules='scipy')
|
| 1858 |
-
assert hnm(2, 2) == 1.25
|
| 1859 |
-
|
| 1860 |
-
|
| 1861 |
-
def test_cupy_array_arg():
|
| 1862 |
-
if not cupy:
|
| 1863 |
-
skip("CuPy not installed")
|
| 1864 |
-
|
| 1865 |
-
f = lambdify([[x, y]], x*x + y, 'cupy')
|
| 1866 |
-
result = f(cupy.array([2.0, 1.0]))
|
| 1867 |
-
assert result == 5
|
| 1868 |
-
assert "cupy" in str(type(result))
|
| 1869 |
-
|
| 1870 |
-
|
| 1871 |
-
def test_cupy_array_arg_using_numpy():
|
| 1872 |
-
# numpy functions can be run on cupy arrays
|
| 1873 |
-
# unclear if we can "officially" support this,
|
| 1874 |
-
# depends on numpy __array_function__ support
|
| 1875 |
-
if not cupy:
|
| 1876 |
-
skip("CuPy not installed")
|
| 1877 |
-
|
| 1878 |
-
f = lambdify([[x, y]], x*x + y, 'numpy')
|
| 1879 |
-
result = f(cupy.array([2.0, 1.0]))
|
| 1880 |
-
assert result == 5
|
| 1881 |
-
assert "cupy" in str(type(result))
|
| 1882 |
-
|
| 1883 |
-
def test_cupy_dotproduct():
|
| 1884 |
-
if not cupy:
|
| 1885 |
-
skip("CuPy not installed")
|
| 1886 |
-
|
| 1887 |
-
A = Matrix([x, y, z])
|
| 1888 |
-
f1 = lambdify([x, y, z], DotProduct(A, A), modules='cupy')
|
| 1889 |
-
f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy')
|
| 1890 |
-
f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='cupy')
|
| 1891 |
-
f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy')
|
| 1892 |
-
|
| 1893 |
-
assert f1(1, 2, 3) == \
|
| 1894 |
-
f2(1, 2, 3) == \
|
| 1895 |
-
f3(1, 2, 3) == \
|
| 1896 |
-
f4(1, 2, 3) == \
|
| 1897 |
-
cupy.array([14])
|
| 1898 |
-
|
| 1899 |
-
|
| 1900 |
-
def test_jax_array_arg():
|
| 1901 |
-
if not jax:
|
| 1902 |
-
skip("JAX not installed")
|
| 1903 |
-
|
| 1904 |
-
f = lambdify([[x, y]], x*x + y, 'jax')
|
| 1905 |
-
result = f(jax.numpy.array([2.0, 1.0]))
|
| 1906 |
-
assert result == 5
|
| 1907 |
-
assert "jax" in str(type(result))
|
| 1908 |
-
|
| 1909 |
-
|
| 1910 |
-
def test_jax_array_arg_using_numpy():
|
| 1911 |
-
if not jax:
|
| 1912 |
-
skip("JAX not installed")
|
| 1913 |
-
|
| 1914 |
-
f = lambdify([[x, y]], x*x + y, 'numpy')
|
| 1915 |
-
result = f(jax.numpy.array([2.0, 1.0]))
|
| 1916 |
-
assert result == 5
|
| 1917 |
-
assert "jax" in str(type(result))
|
| 1918 |
-
|
| 1919 |
-
|
| 1920 |
-
def test_jax_dotproduct():
|
| 1921 |
-
if not jax:
|
| 1922 |
-
skip("JAX not installed")
|
| 1923 |
-
|
| 1924 |
-
A = Matrix([x, y, z])
|
| 1925 |
-
f1 = lambdify([x, y, z], DotProduct(A, A), modules='jax')
|
| 1926 |
-
f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax')
|
| 1927 |
-
f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='jax')
|
| 1928 |
-
f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax')
|
| 1929 |
-
|
| 1930 |
-
assert f1(1, 2, 3) == \
|
| 1931 |
-
f2(1, 2, 3) == \
|
| 1932 |
-
f3(1, 2, 3) == \
|
| 1933 |
-
f4(1, 2, 3) == \
|
| 1934 |
-
jax.numpy.array([14])
|
| 1935 |
-
|
| 1936 |
-
|
| 1937 |
-
def test_lambdify_cse():
|
| 1938 |
-
def no_op_cse(exprs):
|
| 1939 |
-
return (), exprs
|
| 1940 |
-
|
| 1941 |
-
def dummy_cse(exprs):
|
| 1942 |
-
from sympy.simplify.cse_main import cse
|
| 1943 |
-
return cse(exprs, symbols=numbered_symbols(cls=Dummy))
|
| 1944 |
-
|
| 1945 |
-
def minmem(exprs):
|
| 1946 |
-
from sympy.simplify.cse_main import cse_release_variables, cse
|
| 1947 |
-
return cse(exprs, postprocess=cse_release_variables)
|
| 1948 |
-
|
| 1949 |
-
class Case:
|
| 1950 |
-
def __init__(self, *, args, exprs, num_args, requires_numpy=False):
|
| 1951 |
-
self.args = args
|
| 1952 |
-
self.exprs = exprs
|
| 1953 |
-
self.num_args = num_args
|
| 1954 |
-
subs_dict = dict(zip(self.args, self.num_args))
|
| 1955 |
-
self.ref = [e.subs(subs_dict).evalf() for e in exprs]
|
| 1956 |
-
self.requires_numpy = requires_numpy
|
| 1957 |
-
|
| 1958 |
-
def lambdify(self, *, cse):
|
| 1959 |
-
return lambdify(self.args, self.exprs, cse=cse)
|
| 1960 |
-
|
| 1961 |
-
def assertAllClose(self, result, *, abstol=1e-15, reltol=1e-15):
|
| 1962 |
-
if self.requires_numpy:
|
| 1963 |
-
assert all(numpy.allclose(result[i], numpy.asarray(r, dtype=float),
|
| 1964 |
-
rtol=reltol, atol=abstol)
|
| 1965 |
-
for i, r in enumerate(self.ref))
|
| 1966 |
-
return
|
| 1967 |
-
|
| 1968 |
-
for i, r in enumerate(self.ref):
|
| 1969 |
-
abs_err = abs(result[i] - r)
|
| 1970 |
-
if r == 0:
|
| 1971 |
-
assert abs_err < abstol
|
| 1972 |
-
else:
|
| 1973 |
-
assert abs_err/abs(r) < reltol
|
| 1974 |
-
|
| 1975 |
-
cases = [
|
| 1976 |
-
Case(
|
| 1977 |
-
args=(x, y, z),
|
| 1978 |
-
exprs=[
|
| 1979 |
-
x + y + z,
|
| 1980 |
-
x + y - z,
|
| 1981 |
-
2*x + 2*y - z,
|
| 1982 |
-
(x+y)**2 + (y+z)**2,
|
| 1983 |
-
],
|
| 1984 |
-
num_args=(2., 3., 4.)
|
| 1985 |
-
),
|
| 1986 |
-
Case(
|
| 1987 |
-
args=(x, y, z),
|
| 1988 |
-
exprs=[
|
| 1989 |
-
x + sympy.Heaviside(x),
|
| 1990 |
-
y + sympy.Heaviside(x),
|
| 1991 |
-
z + sympy.Heaviside(x, 1),
|
| 1992 |
-
z/sympy.Heaviside(x, 1)
|
| 1993 |
-
],
|
| 1994 |
-
num_args=(0., 3., 4.)
|
| 1995 |
-
),
|
| 1996 |
-
Case(
|
| 1997 |
-
args=(x, y, z),
|
| 1998 |
-
exprs=[
|
| 1999 |
-
x + sinc(y),
|
| 2000 |
-
y + sinc(y),
|
| 2001 |
-
z - sinc(y)
|
| 2002 |
-
],
|
| 2003 |
-
num_args=(0.1, 0.2, 0.3)
|
| 2004 |
-
),
|
| 2005 |
-
Case(
|
| 2006 |
-
args=(x, y, z),
|
| 2007 |
-
exprs=[
|
| 2008 |
-
Matrix([[x, x*y], [sin(z) + 4, x**z]]),
|
| 2009 |
-
x*y+sin(z)-x**z,
|
| 2010 |
-
Matrix([x*x, sin(z), x**z])
|
| 2011 |
-
],
|
| 2012 |
-
num_args=(1.,2.,3.),
|
| 2013 |
-
requires_numpy=True
|
| 2014 |
-
),
|
| 2015 |
-
Case(
|
| 2016 |
-
args=(x, y),
|
| 2017 |
-
exprs=[(x + y - 1)**2, x, x + y,
|
| 2018 |
-
(x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)],
|
| 2019 |
-
num_args=(1,2)
|
| 2020 |
-
)
|
| 2021 |
-
]
|
| 2022 |
-
for case in cases:
|
| 2023 |
-
if not numpy and case.requires_numpy:
|
| 2024 |
-
continue
|
| 2025 |
-
for _cse in [False, True, minmem, no_op_cse, dummy_cse]:
|
| 2026 |
-
f = case.lambdify(cse=_cse)
|
| 2027 |
-
result = f(*case.num_args)
|
| 2028 |
-
case.assertAllClose(result)
|
| 2029 |
-
|
| 2030 |
-
def test_issue_25288():
|
| 2031 |
-
syms = numbered_symbols(cls=Dummy)
|
| 2032 |
-
ok = lambdify(x, [x**2, sin(x**2)], cse=lambda e: cse(e, symbols=syms))(2)
|
| 2033 |
-
assert ok
|
| 2034 |
-
|
| 2035 |
-
|
| 2036 |
-
def test_deprecated_set():
|
| 2037 |
-
with warns_deprecated_sympy():
|
| 2038 |
-
lambdify({x, y}, x + y)
|
| 2039 |
-
|
| 2040 |
-
def test_issue_13881():
|
| 2041 |
-
if not numpy:
|
| 2042 |
-
skip("numpy not installed.")
|
| 2043 |
-
|
| 2044 |
-
X = MatrixSymbol('X', 3, 1)
|
| 2045 |
-
|
| 2046 |
-
f = lambdify(X, X.T*X, 'numpy')
|
| 2047 |
-
assert f(numpy.array([1, 2, 3])) == 14
|
| 2048 |
-
assert f(numpy.array([3, 2, 1])) == 14
|
| 2049 |
-
|
| 2050 |
-
f = lambdify(X, X*X.T, 'numpy')
|
| 2051 |
-
assert f(numpy.array([1, 2, 3])) == 14
|
| 2052 |
-
assert f(numpy.array([3, 2, 1])) == 14
|
| 2053 |
-
|
| 2054 |
-
f = lambdify(X, (X*X.T)*X, 'numpy')
|
| 2055 |
-
arr1 = numpy.array([[1], [2], [3]])
|
| 2056 |
-
arr2 = numpy.array([[14],[28],[42]])
|
| 2057 |
-
|
| 2058 |
-
assert numpy.array_equal(f(arr1), arr2)
|
| 2059 |
-
|
| 2060 |
-
|
| 2061 |
-
def test_23536_lambdify_cse_dummy():
|
| 2062 |
-
|
| 2063 |
-
f = Function('x')(y)
|
| 2064 |
-
g = Function('w')(y)
|
| 2065 |
-
expr = z + (f**4 + g**5)*(f**3 + (g*f)**3)
|
| 2066 |
-
expr = expr.expand()
|
| 2067 |
-
eval_expr = lambdify(((f, g), z), expr, cse=True)
|
| 2068 |
-
ans = eval_expr((1.0, 2.0), 3.0) # shouldn't raise NameError
|
| 2069 |
-
assert ans == 300.0 # not a list and value is 300
|
| 2070 |
-
|
| 2071 |
-
|
| 2072 |
-
class LambdifyDocstringTestCase:
|
| 2073 |
-
SIGNATURE = None
|
| 2074 |
-
EXPR = None
|
| 2075 |
-
SRC = None
|
| 2076 |
-
|
| 2077 |
-
def __init__(self, docstring_limit, expected_redacted):
|
| 2078 |
-
self.docstring_limit = docstring_limit
|
| 2079 |
-
self.expected_redacted = expected_redacted
|
| 2080 |
-
|
| 2081 |
-
@property
|
| 2082 |
-
def expected_expr(self):
|
| 2083 |
-
expr_redacted_msg = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
|
| 2084 |
-
return self.EXPR if not self.expected_redacted else expr_redacted_msg
|
| 2085 |
-
|
| 2086 |
-
@property
|
| 2087 |
-
def expected_src(self):
|
| 2088 |
-
src_redacted_msg = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
|
| 2089 |
-
return self.SRC if not self.expected_redacted else src_redacted_msg
|
| 2090 |
-
|
| 2091 |
-
@property
|
| 2092 |
-
def expected_docstring(self):
|
| 2093 |
-
expected_docstring = (
|
| 2094 |
-
f'Created with lambdify. Signature:\n\n'
|
| 2095 |
-
f'func({self.SIGNATURE})\n\n'
|
| 2096 |
-
f'Expression:\n\n'
|
| 2097 |
-
f'{self.expected_expr}\n\n'
|
| 2098 |
-
f'Source code:\n\n'
|
| 2099 |
-
f'{self.expected_src}\n\n'
|
| 2100 |
-
f'Imported modules:\n\n'
|
| 2101 |
-
)
|
| 2102 |
-
return expected_docstring
|
| 2103 |
-
|
| 2104 |
-
def __len__(self):
|
| 2105 |
-
return len(self.expected_docstring)
|
| 2106 |
-
|
| 2107 |
-
def __repr__(self):
|
| 2108 |
-
return (
|
| 2109 |
-
f'{self.__class__.__name__}('
|
| 2110 |
-
f'docstring_limit={self.docstring_limit}, '
|
| 2111 |
-
f'expected_redacted={self.expected_redacted})'
|
| 2112 |
-
)
|
| 2113 |
-
|
| 2114 |
-
|
| 2115 |
-
def test_lambdify_docstring_size_limit_simple_symbol():
|
| 2116 |
-
|
| 2117 |
-
class SimpleSymbolTestCase(LambdifyDocstringTestCase):
|
| 2118 |
-
SIGNATURE = 'x'
|
| 2119 |
-
EXPR = 'x'
|
| 2120 |
-
SRC = (
|
| 2121 |
-
'def _lambdifygenerated(x):\n'
|
| 2122 |
-
' return x\n'
|
| 2123 |
-
)
|
| 2124 |
-
|
| 2125 |
-
x = symbols('x')
|
| 2126 |
-
|
| 2127 |
-
test_cases = (
|
| 2128 |
-
SimpleSymbolTestCase(docstring_limit=None, expected_redacted=False),
|
| 2129 |
-
SimpleSymbolTestCase(docstring_limit=100, expected_redacted=False),
|
| 2130 |
-
SimpleSymbolTestCase(docstring_limit=1, expected_redacted=False),
|
| 2131 |
-
SimpleSymbolTestCase(docstring_limit=0, expected_redacted=True),
|
| 2132 |
-
SimpleSymbolTestCase(docstring_limit=-1, expected_redacted=True),
|
| 2133 |
-
)
|
| 2134 |
-
for test_case in test_cases:
|
| 2135 |
-
lambdified_expr = lambdify(
|
| 2136 |
-
[x],
|
| 2137 |
-
x,
|
| 2138 |
-
'sympy',
|
| 2139 |
-
docstring_limit=test_case.docstring_limit,
|
| 2140 |
-
)
|
| 2141 |
-
assert lambdified_expr.__doc__ == test_case.expected_docstring
|
| 2142 |
-
|
| 2143 |
-
|
| 2144 |
-
def test_lambdify_docstring_size_limit_nested_expr():
|
| 2145 |
-
|
| 2146 |
-
class ExprListTestCase(LambdifyDocstringTestCase):
|
| 2147 |
-
SIGNATURE = 'x, y, z'
|
| 2148 |
-
EXPR = (
|
| 2149 |
-
'[x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z '
|
| 2150 |
-
'+ 3*x*z**2 +...'
|
| 2151 |
-
)
|
| 2152 |
-
SRC = (
|
| 2153 |
-
'def _lambdifygenerated(x, y, z):\n'
|
| 2154 |
-
' return [x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 '
|
| 2155 |
-
'+ 6*x*y*z + 3*x*z**2 + y**3 + 3*y**2*z + 3*y*z**2 + z**3]\n'
|
| 2156 |
-
)
|
| 2157 |
-
|
| 2158 |
-
x, y, z = symbols('x, y, z')
|
| 2159 |
-
expr = [x, [y], z, ((x + y + z)**3).expand()]
|
| 2160 |
-
|
| 2161 |
-
test_cases = (
|
| 2162 |
-
ExprListTestCase(docstring_limit=None, expected_redacted=False),
|
| 2163 |
-
ExprListTestCase(docstring_limit=200, expected_redacted=False),
|
| 2164 |
-
ExprListTestCase(docstring_limit=50, expected_redacted=True),
|
| 2165 |
-
ExprListTestCase(docstring_limit=0, expected_redacted=True),
|
| 2166 |
-
ExprListTestCase(docstring_limit=-1, expected_redacted=True),
|
| 2167 |
-
)
|
| 2168 |
-
for test_case in test_cases:
|
| 2169 |
-
lambdified_expr = lambdify(
|
| 2170 |
-
[x, y, z],
|
| 2171 |
-
expr,
|
| 2172 |
-
'sympy',
|
| 2173 |
-
docstring_limit=test_case.docstring_limit,
|
| 2174 |
-
)
|
| 2175 |
-
assert lambdified_expr.__doc__ == test_case.expected_docstring
|
| 2176 |
-
|
| 2177 |
-
|
| 2178 |
-
def test_lambdify_docstring_size_limit_matrix():
|
| 2179 |
-
|
| 2180 |
-
class MatrixTestCase(LambdifyDocstringTestCase):
|
| 2181 |
-
SIGNATURE = 'x, y, z'
|
| 2182 |
-
EXPR = (
|
| 2183 |
-
'Matrix([[0, x], [x + y + z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 '
|
| 2184 |
-
'+ 6*x*y*z...'
|
| 2185 |
-
)
|
| 2186 |
-
SRC = (
|
| 2187 |
-
'def _lambdifygenerated(x, y, z):\n'
|
| 2188 |
-
' return ImmutableDenseMatrix([[0, x], [x + y + z, x**3 '
|
| 2189 |
-
'+ 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z + 3*x*z**2 + y**3 '
|
| 2190 |
-
'+ 3*y**2*z + 3*y*z**2 + z**3]])\n'
|
| 2191 |
-
)
|
| 2192 |
-
|
| 2193 |
-
x, y, z = symbols('x, y, z')
|
| 2194 |
-
expr = Matrix([[S.Zero, x], [x + y + z, ((x + y + z)**3).expand()]])
|
| 2195 |
-
|
| 2196 |
-
test_cases = (
|
| 2197 |
-
MatrixTestCase(docstring_limit=None, expected_redacted=False),
|
| 2198 |
-
MatrixTestCase(docstring_limit=200, expected_redacted=False),
|
| 2199 |
-
MatrixTestCase(docstring_limit=50, expected_redacted=True),
|
| 2200 |
-
MatrixTestCase(docstring_limit=0, expected_redacted=True),
|
| 2201 |
-
MatrixTestCase(docstring_limit=-1, expected_redacted=True),
|
| 2202 |
-
)
|
| 2203 |
-
for test_case in test_cases:
|
| 2204 |
-
lambdified_expr = lambdify(
|
| 2205 |
-
[x, y, z],
|
| 2206 |
-
expr,
|
| 2207 |
-
'sympy',
|
| 2208 |
-
docstring_limit=test_case.docstring_limit,
|
| 2209 |
-
)
|
| 2210 |
-
assert lambdified_expr.__doc__ == test_case.expected_docstring
|
| 2211 |
-
|
| 2212 |
-
|
| 2213 |
-
def test_lambdify_empty_tuple():
|
| 2214 |
-
a = symbols("a")
|
| 2215 |
-
expr = ((), (a,))
|
| 2216 |
-
f = lambdify(a, expr)
|
| 2217 |
-
result = f(1)
|
| 2218 |
-
assert result == ((), (1,)), "Lambdify did not handle the empty tuple correctly."
|
| 2219 |
-
|
| 2220 |
-
|
| 2221 |
-
def test_assoc_legendre_numerical_evaluation():
|
| 2222 |
-
|
| 2223 |
-
tol = 1e-10
|
| 2224 |
-
|
| 2225 |
-
sympy_result_integer = assoc_legendre(1, 1/2, 0.1).evalf()
|
| 2226 |
-
sympy_result_complex = assoc_legendre(2, 1, 3).evalf()
|
| 2227 |
-
mpmath_result_integer = -0.474572528387641
|
| 2228 |
-
mpmath_result_complex = -25.45584412271571*I
|
| 2229 |
-
|
| 2230 |
-
assert all_close(sympy_result_integer, mpmath_result_integer, tol)
|
| 2231 |
-
assert all_close(sympy_result_complex, mpmath_result_complex, tol)
|
| 2232 |
-
|
| 2233 |
-
|
| 2234 |
-
def test_Piecewise():
|
| 2235 |
-
|
| 2236 |
-
modules = [math]
|
| 2237 |
-
if numpy:
|
| 2238 |
-
modules.append('numpy')
|
| 2239 |
-
|
| 2240 |
-
for mod in modules:
|
| 2241 |
-
# test isinf
|
| 2242 |
-
f = lambdify(x, Piecewise((7.0, isinf(x)), (3.0, True)), mod)
|
| 2243 |
-
assert f(+float('inf')) == +7.0
|
| 2244 |
-
assert f(-float('inf')) == +7.0
|
| 2245 |
-
assert f(42.) == 3.0
|
| 2246 |
-
|
| 2247 |
-
f2 = lambdify(x, Piecewise((7.0*sign(x), isinf(x)), (3.0, True)), mod)
|
| 2248 |
-
assert f2(+float('inf')) == +7.0
|
| 2249 |
-
assert f2(-float('inf')) == -7.0
|
| 2250 |
-
assert f2(42.) == 3.0
|
| 2251 |
-
|
| 2252 |
-
# test isnan (gh-26784)
|
| 2253 |
-
g = lambdify(x, Piecewise((7.0, isnan(x)), (3.0, True)), mod)
|
| 2254 |
-
assert g(float('nan')) == 7.0
|
| 2255 |
-
assert g(42.) == 3.0
|
| 2256 |
-
|
| 2257 |
-
|
| 2258 |
-
def test_array_symbol():
|
| 2259 |
-
if not numpy:
|
| 2260 |
-
skip("numpy not installed.")
|
| 2261 |
-
a = ArraySymbol('a', (3,))
|
| 2262 |
-
f = lambdify((a), a)
|
| 2263 |
-
assert numpy.all(f(numpy.array([1,2,3])) == numpy.array([1,2,3]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_matchpy_connector.py
DELETED
|
@@ -1,164 +0,0 @@
|
|
| 1 |
-
import pickle
|
| 2 |
-
|
| 3 |
-
from sympy.core.relational import (Eq, Ne)
|
| 4 |
-
from sympy.core.singleton import S
|
| 5 |
-
from sympy.core.symbol import symbols
|
| 6 |
-
from sympy.functions.elementary.miscellaneous import sqrt
|
| 7 |
-
from sympy.functions.elementary.trigonometric import (cos, sin)
|
| 8 |
-
from sympy.external import import_module
|
| 9 |
-
from sympy.testing.pytest import skip
|
| 10 |
-
from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar, Replacer
|
| 11 |
-
|
| 12 |
-
matchpy = import_module("matchpy")
|
| 13 |
-
|
| 14 |
-
x, y, z = symbols("x y z")
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def _get_first_match(expr, pattern):
|
| 18 |
-
from matchpy import ManyToOneMatcher, Pattern
|
| 19 |
-
|
| 20 |
-
matcher = ManyToOneMatcher()
|
| 21 |
-
matcher.add(Pattern(pattern))
|
| 22 |
-
return next(iter(matcher.match(expr)))
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def test_matchpy_connector():
|
| 26 |
-
if matchpy is None:
|
| 27 |
-
skip("matchpy not installed")
|
| 28 |
-
|
| 29 |
-
from multiset import Multiset
|
| 30 |
-
from matchpy import Pattern, Substitution
|
| 31 |
-
|
| 32 |
-
w_ = WildDot("w_")
|
| 33 |
-
w__ = WildPlus("w__")
|
| 34 |
-
w___ = WildStar("w___")
|
| 35 |
-
|
| 36 |
-
expr = x + y
|
| 37 |
-
pattern = x + w_
|
| 38 |
-
p, subst = _get_first_match(expr, pattern)
|
| 39 |
-
assert p == Pattern(pattern)
|
| 40 |
-
assert subst == Substitution({'w_': y})
|
| 41 |
-
|
| 42 |
-
expr = x + y + z
|
| 43 |
-
pattern = x + w__
|
| 44 |
-
p, subst = _get_first_match(expr, pattern)
|
| 45 |
-
assert p == Pattern(pattern)
|
| 46 |
-
assert subst == Substitution({'w__': Multiset([y, z])})
|
| 47 |
-
|
| 48 |
-
expr = x + y + z
|
| 49 |
-
pattern = x + y + z + w___
|
| 50 |
-
p, subst = _get_first_match(expr, pattern)
|
| 51 |
-
assert p == Pattern(pattern)
|
| 52 |
-
assert subst == Substitution({'w___': Multiset()})
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def test_matchpy_optional():
|
| 56 |
-
if matchpy is None:
|
| 57 |
-
skip("matchpy not installed")
|
| 58 |
-
|
| 59 |
-
from matchpy import Pattern, Substitution
|
| 60 |
-
from matchpy import ManyToOneReplacer, ReplacementRule
|
| 61 |
-
|
| 62 |
-
p = WildDot("p", optional=1)
|
| 63 |
-
q = WildDot("q", optional=0)
|
| 64 |
-
|
| 65 |
-
pattern = p*x + q
|
| 66 |
-
|
| 67 |
-
expr1 = 2*x
|
| 68 |
-
pa, subst = _get_first_match(expr1, pattern)
|
| 69 |
-
assert pa == Pattern(pattern)
|
| 70 |
-
assert subst == Substitution({'p': 2, 'q': 0})
|
| 71 |
-
|
| 72 |
-
expr2 = x + 3
|
| 73 |
-
pa, subst = _get_first_match(expr2, pattern)
|
| 74 |
-
assert pa == Pattern(pattern)
|
| 75 |
-
assert subst == Substitution({'p': 1, 'q': 3})
|
| 76 |
-
|
| 77 |
-
expr3 = x
|
| 78 |
-
pa, subst = _get_first_match(expr3, pattern)
|
| 79 |
-
assert pa == Pattern(pattern)
|
| 80 |
-
assert subst == Substitution({'p': 1, 'q': 0})
|
| 81 |
-
|
| 82 |
-
expr4 = x*y + z
|
| 83 |
-
pa, subst = _get_first_match(expr4, pattern)
|
| 84 |
-
assert pa == Pattern(pattern)
|
| 85 |
-
assert subst == Substitution({'p': y, 'q': z})
|
| 86 |
-
|
| 87 |
-
replacer = ManyToOneReplacer()
|
| 88 |
-
replacer.add(ReplacementRule(Pattern(pattern), lambda p, q: sin(p)*cos(q)))
|
| 89 |
-
assert replacer.replace(expr1) == sin(2)*cos(0)
|
| 90 |
-
assert replacer.replace(expr2) == sin(1)*cos(3)
|
| 91 |
-
assert replacer.replace(expr3) == sin(1)*cos(0)
|
| 92 |
-
assert replacer.replace(expr4) == sin(y)*cos(z)
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def test_replacer():
|
| 96 |
-
if matchpy is None:
|
| 97 |
-
skip("matchpy not installed")
|
| 98 |
-
|
| 99 |
-
for info in [True, False]:
|
| 100 |
-
for lambdify in [True, False]:
|
| 101 |
-
_perform_test_replacer(info, lambdify)
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def _perform_test_replacer(info, lambdify):
|
| 105 |
-
|
| 106 |
-
x1_ = WildDot("x1_")
|
| 107 |
-
x2_ = WildDot("x2_")
|
| 108 |
-
|
| 109 |
-
a_ = WildDot("a_", optional=S.One)
|
| 110 |
-
b_ = WildDot("b_", optional=S.One)
|
| 111 |
-
c_ = WildDot("c_", optional=S.Zero)
|
| 112 |
-
|
| 113 |
-
replacer = Replacer(common_constraints=[
|
| 114 |
-
matchpy.CustomConstraint(lambda a_: not a_.has(x)),
|
| 115 |
-
matchpy.CustomConstraint(lambda b_: not b_.has(x)),
|
| 116 |
-
matchpy.CustomConstraint(lambda c_: not c_.has(x)),
|
| 117 |
-
], lambdify=lambdify, info=info)
|
| 118 |
-
|
| 119 |
-
# Rewrite the equation into implicit form, unless it's already solved:
|
| 120 |
-
replacer.add(Eq(x1_, x2_), Eq(x1_ - x2_, 0), conditions_nonfalse=[Ne(x2_, 0), Ne(x1_, 0), Ne(x1_, x), Ne(x2_, x)], info=1)
|
| 121 |
-
|
| 122 |
-
# Simple equation solver for real numbers:
|
| 123 |
-
replacer.add(Eq(a_*x + b_, 0), Eq(x, -b_/a_), info=2)
|
| 124 |
-
disc = b_**2 - 4*a_*c_
|
| 125 |
-
replacer.add(
|
| 126 |
-
Eq(a_*x**2 + b_*x + c_, 0),
|
| 127 |
-
Eq(x, (-b_ - sqrt(disc))/(2*a_)) | Eq(x, (-b_ + sqrt(disc))/(2*a_)),
|
| 128 |
-
conditions_nonfalse=[disc >= 0],
|
| 129 |
-
info=3
|
| 130 |
-
)
|
| 131 |
-
replacer.add(
|
| 132 |
-
Eq(a_*x**2 + c_, 0),
|
| 133 |
-
Eq(x, sqrt(-c_/a_)) | Eq(x, -sqrt(-c_/a_)),
|
| 134 |
-
conditions_nonfalse=[-c_*a_ > 0],
|
| 135 |
-
info=4
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
g = lambda expr, infos: (expr, infos) if info else expr
|
| 139 |
-
|
| 140 |
-
assert replacer.replace(Eq(3*x, y)) == g(Eq(x, y/3), [1, 2])
|
| 141 |
-
assert replacer.replace(Eq(x**2 + 1, 0)) == g(Eq(x**2 + 1, 0), [])
|
| 142 |
-
assert replacer.replace(Eq(x**2, 4)) == g((Eq(x, 2) | Eq(x, -2)), [1, 4])
|
| 143 |
-
assert replacer.replace(Eq(x**2 + 4*y*x + 4*y**2, 0)) == g(Eq(x, -2*y), [3])
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def test_matchpy_object_pickle():
|
| 147 |
-
if matchpy is None:
|
| 148 |
-
return
|
| 149 |
-
|
| 150 |
-
a1 = WildDot("a")
|
| 151 |
-
a2 = pickle.loads(pickle.dumps(a1))
|
| 152 |
-
assert a1 == a2
|
| 153 |
-
|
| 154 |
-
a1 = WildDot("a", S(1))
|
| 155 |
-
a2 = pickle.loads(pickle.dumps(a1))
|
| 156 |
-
assert a1 == a2
|
| 157 |
-
|
| 158 |
-
a1 = WildPlus("a", S(1))
|
| 159 |
-
a2 = pickle.loads(pickle.dumps(a1))
|
| 160 |
-
assert a1 == a2
|
| 161 |
-
|
| 162 |
-
a1 = WildStar("a", S(1))
|
| 163 |
-
a2 = pickle.loads(pickle.dumps(a1))
|
| 164 |
-
assert a1 == a2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_mathml.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from textwrap import dedent
|
| 3 |
-
from sympy.external import import_module
|
| 4 |
-
from sympy.testing.pytest import skip
|
| 5 |
-
from sympy.utilities.mathml import apply_xsl
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
lxml = import_module('lxml')
|
| 10 |
-
|
| 11 |
-
path = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_xxe.py"))
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def test_xxe():
|
| 15 |
-
assert os.path.isfile(path)
|
| 16 |
-
if not lxml:
|
| 17 |
-
skip("lxml not installed.")
|
| 18 |
-
|
| 19 |
-
mml = dedent(
|
| 20 |
-
rf"""
|
| 21 |
-
<!--?xml version="1.0" ?-->
|
| 22 |
-
<!DOCTYPE replace [<!ENTITY ent SYSTEM "file://{path}"> ]>
|
| 23 |
-
<userInfo>
|
| 24 |
-
<firstName>John</firstName>
|
| 25 |
-
<lastName>&ent;</lastName>
|
| 26 |
-
</userInfo>
|
| 27 |
-
"""
|
| 28 |
-
)
|
| 29 |
-
xsl = 'mathml/data/simple_mmlctop.xsl'
|
| 30 |
-
|
| 31 |
-
res = apply_xsl(mml, xsl)
|
| 32 |
-
assert res == \
|
| 33 |
-
'<?xml version="1.0"?>\n<userInfo>\n<firstName>John</firstName>\n<lastName/>\n</userInfo>\n'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_misc.py
DELETED
|
@@ -1,148 +0,0 @@
|
|
| 1 |
-
from textwrap import dedent
|
| 2 |
-
import sys
|
| 3 |
-
from subprocess import Popen, PIPE
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
from sympy.core.singleton import S
|
| 7 |
-
from sympy.testing.pytest import (raises, warns_deprecated_sympy,
|
| 8 |
-
skip_under_pyodide)
|
| 9 |
-
from sympy.utilities.misc import (translate, replace, ordinal, rawlines,
|
| 10 |
-
strlines, as_int, find_executable)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def test_translate():
|
| 14 |
-
abc = 'abc'
|
| 15 |
-
assert translate(abc, None, 'a') == 'bc'
|
| 16 |
-
assert translate(abc, None, '') == 'abc'
|
| 17 |
-
assert translate(abc, {'a': 'x'}, 'c') == 'xb'
|
| 18 |
-
assert translate(abc, {'a': 'bc'}, 'c') == 'bcb'
|
| 19 |
-
assert translate(abc, {'ab': 'x'}, 'c') == 'x'
|
| 20 |
-
assert translate(abc, {'ab': ''}, 'c') == ''
|
| 21 |
-
assert translate(abc, {'bc': 'x'}, 'c') == 'ab'
|
| 22 |
-
assert translate(abc, {'abc': 'x', 'a': 'y'}) == 'x'
|
| 23 |
-
u = chr(4096)
|
| 24 |
-
assert translate(abc, 'a', 'x', u) == 'xbc'
|
| 25 |
-
assert (u in translate(abc, 'a', u, u)) is True
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def test_replace():
|
| 29 |
-
assert replace('abc', ('a', 'b')) == 'bbc'
|
| 30 |
-
assert replace('abc', {'a': 'Aa'}) == 'Aabc'
|
| 31 |
-
assert replace('abc', ('a', 'b'), ('c', 'C')) == 'bbC'
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def test_ordinal():
|
| 35 |
-
assert ordinal(-1) == '-1st'
|
| 36 |
-
assert ordinal(0) == '0th'
|
| 37 |
-
assert ordinal(1) == '1st'
|
| 38 |
-
assert ordinal(2) == '2nd'
|
| 39 |
-
assert ordinal(3) == '3rd'
|
| 40 |
-
assert all(ordinal(i).endswith('th') for i in range(4, 21))
|
| 41 |
-
assert ordinal(100) == '100th'
|
| 42 |
-
assert ordinal(101) == '101st'
|
| 43 |
-
assert ordinal(102) == '102nd'
|
| 44 |
-
assert ordinal(103) == '103rd'
|
| 45 |
-
assert ordinal(104) == '104th'
|
| 46 |
-
assert ordinal(200) == '200th'
|
| 47 |
-
assert all(ordinal(i) == str(i) + 'th' for i in range(-220, -203))
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def test_rawlines():
|
| 51 |
-
assert rawlines('a a\na') == "dedent('''\\\n a a\n a''')"
|
| 52 |
-
assert rawlines('a a') == "'a a'"
|
| 53 |
-
assert rawlines(strlines('\\le"ft')) == (
|
| 54 |
-
'(\n'
|
| 55 |
-
" '(\\n'\n"
|
| 56 |
-
' \'r\\\'\\\\le"ft\\\'\\n\'\n'
|
| 57 |
-
" ')'\n"
|
| 58 |
-
')')
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def test_strlines():
|
| 62 |
-
q = 'this quote (") is in the middle'
|
| 63 |
-
# the following assert rhs was prepared with
|
| 64 |
-
# print(rawlines(strlines(q, 10)))
|
| 65 |
-
assert strlines(q, 10) == dedent('''\
|
| 66 |
-
(
|
| 67 |
-
'this quo'
|
| 68 |
-
'te (") i'
|
| 69 |
-
's in the'
|
| 70 |
-
' middle'
|
| 71 |
-
)''')
|
| 72 |
-
assert q == (
|
| 73 |
-
'this quo'
|
| 74 |
-
'te (") i'
|
| 75 |
-
's in the'
|
| 76 |
-
' middle'
|
| 77 |
-
)
|
| 78 |
-
q = "this quote (') is in the middle"
|
| 79 |
-
assert strlines(q, 20) == dedent('''\
|
| 80 |
-
(
|
| 81 |
-
"this quote (') is "
|
| 82 |
-
"in the middle"
|
| 83 |
-
)''')
|
| 84 |
-
assert strlines('\\left') == (
|
| 85 |
-
'(\n'
|
| 86 |
-
"r'\\left'\n"
|
| 87 |
-
')')
|
| 88 |
-
assert strlines('\\left', short=True) == r"r'\left'"
|
| 89 |
-
assert strlines('\\le"ft') == (
|
| 90 |
-
'(\n'
|
| 91 |
-
'r\'\\le"ft\'\n'
|
| 92 |
-
')')
|
| 93 |
-
q = 'this\nother line'
|
| 94 |
-
assert strlines(q) == rawlines(q)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def test_translate_args():
|
| 98 |
-
try:
|
| 99 |
-
translate(None, None, None, 'not_none')
|
| 100 |
-
except ValueError:
|
| 101 |
-
pass # Exception raised successfully
|
| 102 |
-
else:
|
| 103 |
-
assert False
|
| 104 |
-
|
| 105 |
-
assert translate('s', None, None, None) == 's'
|
| 106 |
-
|
| 107 |
-
try:
|
| 108 |
-
translate('s', 'a', 'bc')
|
| 109 |
-
except ValueError:
|
| 110 |
-
pass # Exception raised successfully
|
| 111 |
-
else:
|
| 112 |
-
assert False
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
@skip_under_pyodide("Cannot create subprocess under pyodide.")
|
| 116 |
-
def test_debug_output():
|
| 117 |
-
env = os.environ.copy()
|
| 118 |
-
env['SYMPY_DEBUG'] = 'True'
|
| 119 |
-
cmd = 'from sympy import *; x = Symbol("x"); print(integrate((1-cos(x))/x, x))'
|
| 120 |
-
cmdline = [sys.executable, '-c', cmd]
|
| 121 |
-
proc = Popen(cmdline, env=env, stdout=PIPE, stderr=PIPE)
|
| 122 |
-
out, err = proc.communicate()
|
| 123 |
-
out = out.decode('ascii') # utf-8?
|
| 124 |
-
err = err.decode('ascii')
|
| 125 |
-
expected = 'substituted: -x*(1 - cos(x)), u: 1/x, u_var: _u'
|
| 126 |
-
assert expected in err, err
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def test_as_int():
|
| 130 |
-
raises(ValueError, lambda : as_int(True))
|
| 131 |
-
raises(ValueError, lambda : as_int(1.1))
|
| 132 |
-
raises(ValueError, lambda : as_int([]))
|
| 133 |
-
raises(ValueError, lambda : as_int(S.NaN))
|
| 134 |
-
raises(ValueError, lambda : as_int(S.Infinity))
|
| 135 |
-
raises(ValueError, lambda : as_int(S.NegativeInfinity))
|
| 136 |
-
raises(ValueError, lambda : as_int(S.ComplexInfinity))
|
| 137 |
-
# for the following, limited precision makes int(arg) == arg
|
| 138 |
-
# but the int value is not necessarily what a user might have
|
| 139 |
-
# expected; Q.prime is more nuanced in its response for
|
| 140 |
-
# expressions which might be complex representations of an
|
| 141 |
-
# integer. This is not -- by design -- as_ints role.
|
| 142 |
-
raises(ValueError, lambda : as_int(1e23))
|
| 143 |
-
raises(ValueError, lambda : as_int(S('1.'+'0'*20+'1')))
|
| 144 |
-
assert as_int(True, strict=False) == 1
|
| 145 |
-
|
| 146 |
-
def test_deprecated_find_executable():
|
| 147 |
-
with warns_deprecated_sympy():
|
| 148 |
-
find_executable('python')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_pickling.py
DELETED
|
@@ -1,723 +0,0 @@
|
|
| 1 |
-
import inspect
|
| 2 |
-
import copy
|
| 3 |
-
import pickle
|
| 4 |
-
|
| 5 |
-
from sympy.physics.units import meter
|
| 6 |
-
|
| 7 |
-
from sympy.testing.pytest import XFAIL, raises, ignore_warnings
|
| 8 |
-
|
| 9 |
-
from sympy.core.basic import Atom, Basic
|
| 10 |
-
from sympy.core.singleton import SingletonRegistry
|
| 11 |
-
from sympy.core.symbol import Str, Dummy, Symbol, Wild
|
| 12 |
-
from sympy.core.numbers import (E, I, pi, oo, zoo, nan, Integer,
|
| 13 |
-
Rational, Float, AlgebraicNumber)
|
| 14 |
-
from sympy.core.relational import (Equality, GreaterThan, LessThan, Relational,
|
| 15 |
-
StrictGreaterThan, StrictLessThan, Unequality)
|
| 16 |
-
from sympy.core.add import Add
|
| 17 |
-
from sympy.core.mul import Mul
|
| 18 |
-
from sympy.core.power import Pow
|
| 19 |
-
from sympy.core.function import Derivative, Function, FunctionClass, Lambda, \
|
| 20 |
-
WildFunction
|
| 21 |
-
from sympy.sets.sets import Interval
|
| 22 |
-
from sympy.core.multidimensional import vectorize
|
| 23 |
-
|
| 24 |
-
from sympy.external.gmpy import gmpy as _gmpy
|
| 25 |
-
from sympy.utilities.exceptions import SymPyDeprecationWarning
|
| 26 |
-
|
| 27 |
-
from sympy.core.singleton import S
|
| 28 |
-
from sympy.core.symbol import symbols
|
| 29 |
-
|
| 30 |
-
from sympy.external import import_module
|
| 31 |
-
cloudpickle = import_module('cloudpickle')
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
not_equal_attrs = {
|
| 35 |
-
'_assumptions', # This is a local cache that isn't automatically filled on creation
|
| 36 |
-
'_mhash', # Cached after __hash__ is called but set to None after creation
|
| 37 |
-
}
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
deprecated_attrs = {
|
| 41 |
-
'is_EmptySet', # Deprecated from SymPy 1.5. This can be removed when is_EmptySet is removed.
|
| 42 |
-
'expr_free_symbols', # Deprecated from SymPy 1.9. This can be removed when exr_free_symbols is removed.
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
dont_check_attrs = {
|
| 46 |
-
'_sage_', # Fails because Sage is not installed
|
| 47 |
-
}
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def check(a, exclude=[], check_attr=True, deprecated=()):
|
| 51 |
-
""" Check that pickling and copying round-trips.
|
| 52 |
-
"""
|
| 53 |
-
# Pickling with protocols 0 and 1 is disabled for Basic instances:
|
| 54 |
-
if isinstance(a, Basic):
|
| 55 |
-
for protocol in [0, 1]:
|
| 56 |
-
raises(NotImplementedError, lambda: pickle.dumps(a, protocol))
|
| 57 |
-
|
| 58 |
-
protocols = [2, copy.copy, copy.deepcopy, 3, 4]
|
| 59 |
-
if cloudpickle:
|
| 60 |
-
protocols.extend([cloudpickle])
|
| 61 |
-
|
| 62 |
-
for protocol in protocols:
|
| 63 |
-
if protocol in exclude:
|
| 64 |
-
continue
|
| 65 |
-
|
| 66 |
-
if callable(protocol):
|
| 67 |
-
if isinstance(a, type):
|
| 68 |
-
# Classes can't be copied, but that's okay.
|
| 69 |
-
continue
|
| 70 |
-
b = protocol(a)
|
| 71 |
-
elif inspect.ismodule(protocol):
|
| 72 |
-
b = protocol.loads(protocol.dumps(a))
|
| 73 |
-
else:
|
| 74 |
-
b = pickle.loads(pickle.dumps(a, protocol))
|
| 75 |
-
|
| 76 |
-
d1 = dir(a)
|
| 77 |
-
d2 = dir(b)
|
| 78 |
-
assert set(d1) == set(d2)
|
| 79 |
-
|
| 80 |
-
if not check_attr:
|
| 81 |
-
continue
|
| 82 |
-
|
| 83 |
-
def c(a, b, d):
|
| 84 |
-
for i in d:
|
| 85 |
-
if i in dont_check_attrs:
|
| 86 |
-
continue
|
| 87 |
-
elif i in not_equal_attrs:
|
| 88 |
-
if hasattr(a, i):
|
| 89 |
-
assert hasattr(b, i), i
|
| 90 |
-
elif i in deprecated_attrs or i in deprecated:
|
| 91 |
-
with ignore_warnings(SymPyDeprecationWarning):
|
| 92 |
-
assert getattr(a, i) == getattr(b, i), i
|
| 93 |
-
elif not hasattr(a, i):
|
| 94 |
-
continue
|
| 95 |
-
else:
|
| 96 |
-
attr = getattr(a, i)
|
| 97 |
-
if not hasattr(attr, "__call__"):
|
| 98 |
-
assert hasattr(b, i), i
|
| 99 |
-
assert getattr(b, i) == attr, "%s != %s, protocol: %s" % (getattr(b, i), attr, protocol)
|
| 100 |
-
|
| 101 |
-
c(a, b, d1)
|
| 102 |
-
c(b, a, d2)
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
#================== core =========================
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def test_core_basic():
|
| 110 |
-
for c in (Atom, Atom(), Basic, Basic(), SingletonRegistry, S):
|
| 111 |
-
check(c)
|
| 112 |
-
|
| 113 |
-
def test_core_Str():
|
| 114 |
-
check(Str('x'))
|
| 115 |
-
|
| 116 |
-
def test_core_symbol():
|
| 117 |
-
# make the Symbol a unique name that doesn't class with any other
|
| 118 |
-
# testing variable in this file since after this test the symbol
|
| 119 |
-
# having the same name will be cached as noncommutative
|
| 120 |
-
for c in (Dummy, Dummy("x", commutative=False), Symbol,
|
| 121 |
-
Symbol("_issue_3130", commutative=False), Wild, Wild("x")):
|
| 122 |
-
check(c)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def test_core_numbers():
|
| 126 |
-
for c in (Integer(2), Rational(2, 3), Float("1.2")):
|
| 127 |
-
check(c)
|
| 128 |
-
for c in (AlgebraicNumber, AlgebraicNumber(sqrt(3))):
|
| 129 |
-
check(c, check_attr=False)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def test_core_float_copy():
|
| 133 |
-
# See gh-7457
|
| 134 |
-
y = Symbol("x") + 1.0
|
| 135 |
-
check(y) # does not raise TypeError ("argument is not an mpz")
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def test_core_relational():
|
| 139 |
-
x = Symbol("x")
|
| 140 |
-
y = Symbol("y")
|
| 141 |
-
for c in (Equality, Equality(x, y), GreaterThan, GreaterThan(x, y),
|
| 142 |
-
LessThan, LessThan(x, y), Relational, Relational(x, y),
|
| 143 |
-
StrictGreaterThan, StrictGreaterThan(x, y), StrictLessThan,
|
| 144 |
-
StrictLessThan(x, y), Unequality, Unequality(x, y)):
|
| 145 |
-
check(c)
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
def test_core_add():
|
| 149 |
-
x = Symbol("x")
|
| 150 |
-
for c in (Add, Add(x, 4)):
|
| 151 |
-
check(c)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def test_core_mul():
|
| 155 |
-
x = Symbol("x")
|
| 156 |
-
for c in (Mul, Mul(x, 4)):
|
| 157 |
-
check(c)
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def test_core_power():
|
| 161 |
-
x = Symbol("x")
|
| 162 |
-
for c in (Pow, Pow(x, 4)):
|
| 163 |
-
check(c)
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def test_core_function():
|
| 167 |
-
x = Symbol("x")
|
| 168 |
-
for f in (Derivative, Derivative(x), Function, FunctionClass, Lambda,
|
| 169 |
-
WildFunction):
|
| 170 |
-
check(f)
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def test_core_undefinedfunctions():
|
| 174 |
-
f = Function("f")
|
| 175 |
-
check(f)
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def test_core_appliedundef():
|
| 179 |
-
x = Symbol("_long_unique_name_1")
|
| 180 |
-
f = Function("_long_unique_name_2")
|
| 181 |
-
check(f(x))
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
def test_core_interval():
|
| 185 |
-
for c in (Interval, Interval(0, 2)):
|
| 186 |
-
check(c)
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
def test_core_multidimensional():
|
| 190 |
-
for c in (vectorize, vectorize(0)):
|
| 191 |
-
check(c)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
def test_Singletons():
|
| 195 |
-
protocols = [0, 1, 2, 3, 4]
|
| 196 |
-
copiers = [copy.copy, copy.deepcopy]
|
| 197 |
-
copiers += [lambda x: pickle.loads(pickle.dumps(x, proto))
|
| 198 |
-
for proto in protocols]
|
| 199 |
-
if cloudpickle:
|
| 200 |
-
copiers += [lambda x: cloudpickle.loads(cloudpickle.dumps(x))]
|
| 201 |
-
|
| 202 |
-
for obj in (Integer(-1), Integer(0), Integer(1), Rational(1, 2), pi, E, I,
|
| 203 |
-
oo, -oo, zoo, nan, S.GoldenRatio, S.TribonacciConstant,
|
| 204 |
-
S.EulerGamma, S.Catalan, S.EmptySet, S.IdentityFunction):
|
| 205 |
-
for func in copiers:
|
| 206 |
-
assert func(obj) is obj
|
| 207 |
-
|
| 208 |
-
#================== combinatorics ===================
|
| 209 |
-
from sympy.combinatorics.free_groups import FreeGroup
|
| 210 |
-
|
| 211 |
-
def test_free_group():
|
| 212 |
-
check(FreeGroup("x, y, z"), check_attr=False)
|
| 213 |
-
|
| 214 |
-
#================== functions ===================
|
| 215 |
-
from sympy.functions import (Piecewise, lowergamma, acosh, chebyshevu,
|
| 216 |
-
chebyshevt, ln, chebyshevt_root, legendre, Heaviside, bernoulli, coth,
|
| 217 |
-
tanh, assoc_legendre, sign, arg, asin, DiracDelta, re, rf, Abs,
|
| 218 |
-
uppergamma, binomial, sinh, cos, cot, acos, acot, gamma, bell,
|
| 219 |
-
hermite, harmonic, LambertW, zeta, log, factorial, asinh, acoth, cosh,
|
| 220 |
-
dirichlet_eta, Eijk, loggamma, erf, ceiling, im, fibonacci,
|
| 221 |
-
tribonacci, conjugate, tan, chebyshevu_root, floor, atanh, sqrt, sin,
|
| 222 |
-
atan, ff, lucas, atan2, polygamma, exp)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
def test_functions():
|
| 226 |
-
one_var = (acosh, ln, Heaviside, factorial, bernoulli, coth, tanh,
|
| 227 |
-
sign, arg, asin, DiracDelta, re, Abs, sinh, cos, cot, acos, acot,
|
| 228 |
-
gamma, bell, harmonic, LambertW, zeta, log, factorial, asinh,
|
| 229 |
-
acoth, cosh, dirichlet_eta, loggamma, erf, ceiling, im, fibonacci,
|
| 230 |
-
tribonacci, conjugate, tan, floor, atanh, sin, atan, lucas, exp)
|
| 231 |
-
two_var = (rf, ff, lowergamma, chebyshevu, chebyshevt, binomial,
|
| 232 |
-
atan2, polygamma, hermite, legendre, uppergamma)
|
| 233 |
-
x, y, z = symbols("x,y,z")
|
| 234 |
-
others = (chebyshevt_root, chebyshevu_root, Eijk(x, y, z),
|
| 235 |
-
Piecewise( (0, x < -1), (x**2, x <= 1), (x**3, True)),
|
| 236 |
-
assoc_legendre)
|
| 237 |
-
for cls in one_var:
|
| 238 |
-
check(cls)
|
| 239 |
-
c = cls(x)
|
| 240 |
-
check(c)
|
| 241 |
-
for cls in two_var:
|
| 242 |
-
check(cls)
|
| 243 |
-
c = cls(x, y)
|
| 244 |
-
check(c)
|
| 245 |
-
for cls in others:
|
| 246 |
-
check(cls)
|
| 247 |
-
|
| 248 |
-
#================== geometry ====================
|
| 249 |
-
from sympy.geometry.entity import GeometryEntity
|
| 250 |
-
from sympy.geometry.point import Point
|
| 251 |
-
from sympy.geometry.ellipse import Circle, Ellipse
|
| 252 |
-
from sympy.geometry.line import Line, LinearEntity, Ray, Segment
|
| 253 |
-
from sympy.geometry.polygon import Polygon, RegularPolygon, Triangle
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
def test_geometry():
|
| 257 |
-
p1 = Point(1, 2)
|
| 258 |
-
p2 = Point(2, 3)
|
| 259 |
-
p3 = Point(0, 0)
|
| 260 |
-
p4 = Point(0, 1)
|
| 261 |
-
for c in (
|
| 262 |
-
GeometryEntity, GeometryEntity(), Point, p1, Circle, Circle(p1, 2),
|
| 263 |
-
Ellipse, Ellipse(p1, 3, 4), Line, Line(p1, p2), LinearEntity,
|
| 264 |
-
LinearEntity(p1, p2), Ray, Ray(p1, p2), Segment, Segment(p1, p2),
|
| 265 |
-
Polygon, Polygon(p1, p2, p3, p4), RegularPolygon,
|
| 266 |
-
RegularPolygon(p1, 4, 5), Triangle, Triangle(p1, p2, p3)):
|
| 267 |
-
check(c, check_attr=False)
|
| 268 |
-
|
| 269 |
-
#================== integrals ====================
|
| 270 |
-
from sympy.integrals.integrals import Integral
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
def test_integrals():
|
| 274 |
-
x = Symbol("x")
|
| 275 |
-
for c in (Integral, Integral(x)):
|
| 276 |
-
check(c)
|
| 277 |
-
|
| 278 |
-
#==================== logic =====================
|
| 279 |
-
from sympy.core.logic import Logic
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
def test_logic():
|
| 283 |
-
for c in (Logic, Logic(1)):
|
| 284 |
-
check(c)
|
| 285 |
-
|
| 286 |
-
#================== matrices ====================
|
| 287 |
-
from sympy.matrices import Matrix, SparseMatrix
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
def test_matrices():
|
| 291 |
-
for c in (Matrix, Matrix([1, 2, 3]), SparseMatrix, SparseMatrix([[1, 2], [3, 4]])):
|
| 292 |
-
check(c, deprecated=['_smat', '_mat'])
|
| 293 |
-
|
| 294 |
-
#================== ntheory =====================
|
| 295 |
-
from sympy.ntheory.generate import Sieve
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
def test_ntheory():
|
| 299 |
-
for c in (Sieve, Sieve()):
|
| 300 |
-
check(c)
|
| 301 |
-
|
| 302 |
-
#================== physics =====================
|
| 303 |
-
from sympy.physics.paulialgebra import Pauli
|
| 304 |
-
from sympy.physics.units import Unit
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
def test_physics():
|
| 308 |
-
for c in (Unit, meter, Pauli, Pauli(1)):
|
| 309 |
-
check(c)
|
| 310 |
-
|
| 311 |
-
#================== plotting ====================
|
| 312 |
-
# XXX: These tests are not complete, so XFAIL them
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
@XFAIL
|
| 316 |
-
def test_plotting():
|
| 317 |
-
from sympy.plotting.pygletplot.color_scheme import ColorGradient, ColorScheme
|
| 318 |
-
from sympy.plotting.pygletplot.managed_window import ManagedWindow
|
| 319 |
-
from sympy.plotting.plot import Plot, ScreenShot
|
| 320 |
-
from sympy.plotting.pygletplot.plot_axes import PlotAxes, PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate
|
| 321 |
-
from sympy.plotting.pygletplot.plot_camera import PlotCamera
|
| 322 |
-
from sympy.plotting.pygletplot.plot_controller import PlotController
|
| 323 |
-
from sympy.plotting.pygletplot.plot_curve import PlotCurve
|
| 324 |
-
from sympy.plotting.pygletplot.plot_interval import PlotInterval
|
| 325 |
-
from sympy.plotting.pygletplot.plot_mode import PlotMode
|
| 326 |
-
from sympy.plotting.pygletplot.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \
|
| 327 |
-
ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical
|
| 328 |
-
from sympy.plotting.pygletplot.plot_object import PlotObject
|
| 329 |
-
from sympy.plotting.pygletplot.plot_surface import PlotSurface
|
| 330 |
-
from sympy.plotting.pygletplot.plot_window import PlotWindow
|
| 331 |
-
for c in (
|
| 332 |
-
ColorGradient, ColorGradient(0.2, 0.4), ColorScheme, ManagedWindow,
|
| 333 |
-
ManagedWindow, Plot, ScreenShot, PlotAxes, PlotAxesBase,
|
| 334 |
-
PlotAxesFrame, PlotAxesOrdinate, PlotCamera, PlotController,
|
| 335 |
-
PlotCurve, PlotInterval, PlotMode, Cartesian2D, Cartesian3D,
|
| 336 |
-
Cylindrical, ParametricCurve2D, ParametricCurve3D,
|
| 337 |
-
ParametricSurface, Polar, Spherical, PlotObject, PlotSurface,
|
| 338 |
-
PlotWindow):
|
| 339 |
-
check(c)
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
@XFAIL
|
| 343 |
-
def test_plotting2():
|
| 344 |
-
#from sympy.plotting.color_scheme import ColorGradient
|
| 345 |
-
from sympy.plotting.pygletplot.color_scheme import ColorScheme
|
| 346 |
-
#from sympy.plotting.managed_window import ManagedWindow
|
| 347 |
-
from sympy.plotting.plot import Plot
|
| 348 |
-
#from sympy.plotting.plot import ScreenShot
|
| 349 |
-
from sympy.plotting.pygletplot.plot_axes import PlotAxes
|
| 350 |
-
#from sympy.plotting.plot_axes import PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate
|
| 351 |
-
#from sympy.plotting.plot_camera import PlotCamera
|
| 352 |
-
#from sympy.plotting.plot_controller import PlotController
|
| 353 |
-
#from sympy.plotting.plot_curve import PlotCurve
|
| 354 |
-
#from sympy.plotting.plot_interval import PlotInterval
|
| 355 |
-
#from sympy.plotting.plot_mode import PlotMode
|
| 356 |
-
#from sympy.plotting.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \
|
| 357 |
-
# ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical
|
| 358 |
-
#from sympy.plotting.plot_object import PlotObject
|
| 359 |
-
#from sympy.plotting.plot_surface import PlotSurface
|
| 360 |
-
# from sympy.plotting.plot_window import PlotWindow
|
| 361 |
-
check(ColorScheme("rainbow"))
|
| 362 |
-
check(Plot(1, visible=False))
|
| 363 |
-
check(PlotAxes())
|
| 364 |
-
|
| 365 |
-
#================== polys =======================
|
| 366 |
-
from sympy.polys.domains.integerring import ZZ
|
| 367 |
-
from sympy.polys.domains.rationalfield import QQ
|
| 368 |
-
from sympy.polys.orderings import lex
|
| 369 |
-
from sympy.polys.polytools import Poly
|
| 370 |
-
|
| 371 |
-
def test_pickling_polys_polytools():
|
| 372 |
-
from sympy.polys.polytools import PurePoly
|
| 373 |
-
# from sympy.polys.polytools import GroebnerBasis
|
| 374 |
-
x = Symbol('x')
|
| 375 |
-
|
| 376 |
-
for c in (Poly, Poly(x, x)):
|
| 377 |
-
check(c)
|
| 378 |
-
|
| 379 |
-
for c in (PurePoly, PurePoly(x)):
|
| 380 |
-
check(c)
|
| 381 |
-
|
| 382 |
-
# TODO: fix pickling of Options class (see GroebnerBasis._options)
|
| 383 |
-
# for c in (GroebnerBasis, GroebnerBasis([x**2 - 1], x, order=lex)):
|
| 384 |
-
# check(c)
|
| 385 |
-
|
| 386 |
-
def test_pickling_polys_polyclasses():
|
| 387 |
-
from sympy.polys.polyclasses import DMP, DMF, ANP
|
| 388 |
-
|
| 389 |
-
for c in (DMP, DMP([[ZZ(1)], [ZZ(2)], [ZZ(3)]], ZZ)):
|
| 390 |
-
check(c, deprecated=['rep'])
|
| 391 |
-
for c in (DMF, DMF(([ZZ(1), ZZ(2)], [ZZ(1), ZZ(3)]), ZZ)):
|
| 392 |
-
check(c)
|
| 393 |
-
for c in (ANP, ANP([QQ(1), QQ(2)], [QQ(1), QQ(2), QQ(3)], QQ)):
|
| 394 |
-
check(c)
|
| 395 |
-
|
| 396 |
-
@XFAIL
|
| 397 |
-
def test_pickling_polys_rings():
|
| 398 |
-
# NOTE: can't use protocols < 2 because we have to execute __new__ to
|
| 399 |
-
# make sure caching of rings works properly.
|
| 400 |
-
|
| 401 |
-
from sympy.polys.rings import PolyRing
|
| 402 |
-
|
| 403 |
-
ring = PolyRing("x,y,z", ZZ, lex)
|
| 404 |
-
|
| 405 |
-
for c in (PolyRing, ring):
|
| 406 |
-
check(c, exclude=[0, 1])
|
| 407 |
-
|
| 408 |
-
for c in (ring.dtype, ring.one):
|
| 409 |
-
check(c, exclude=[0, 1], check_attr=False) # TODO: Py3k
|
| 410 |
-
|
| 411 |
-
def test_pickling_polys_fields():
|
| 412 |
-
pass
|
| 413 |
-
# NOTE: can't use protocols < 2 because we have to execute __new__ to
|
| 414 |
-
# make sure caching of fields works properly.
|
| 415 |
-
|
| 416 |
-
# from sympy.polys.fields import FracField
|
| 417 |
-
|
| 418 |
-
# field = FracField("x,y,z", ZZ, lex)
|
| 419 |
-
|
| 420 |
-
# TODO: AssertionError: assert id(obj) not in self.memo
|
| 421 |
-
# for c in (FracField, field):
|
| 422 |
-
# check(c, exclude=[0, 1])
|
| 423 |
-
|
| 424 |
-
# TODO: AssertionError: assert id(obj) not in self.memo
|
| 425 |
-
# for c in (field.dtype, field.one):
|
| 426 |
-
# check(c, exclude=[0, 1])
|
| 427 |
-
|
| 428 |
-
def test_pickling_polys_elements():
|
| 429 |
-
from sympy.polys.domains.pythonrational import PythonRational
|
| 430 |
-
#from sympy.polys.domains.pythonfinitefield import PythonFiniteField
|
| 431 |
-
#from sympy.polys.domains.mpelements import MPContext
|
| 432 |
-
|
| 433 |
-
for c in (PythonRational, PythonRational(1, 7)):
|
| 434 |
-
check(c)
|
| 435 |
-
|
| 436 |
-
#gf = PythonFiniteField(17)
|
| 437 |
-
|
| 438 |
-
# TODO: fix pickling of ModularInteger
|
| 439 |
-
# for c in (gf.dtype, gf(5)):
|
| 440 |
-
# check(c)
|
| 441 |
-
|
| 442 |
-
#mp = MPContext()
|
| 443 |
-
|
| 444 |
-
# TODO: fix pickling of RealElement
|
| 445 |
-
# for c in (mp.mpf, mp.mpf(1.0)):
|
| 446 |
-
# check(c)
|
| 447 |
-
|
| 448 |
-
# TODO: fix pickling of ComplexElement
|
| 449 |
-
# for c in (mp.mpc, mp.mpc(1.0, -1.5)):
|
| 450 |
-
# check(c)
|
| 451 |
-
|
| 452 |
-
def test_pickling_polys_domains():
|
| 453 |
-
# from sympy.polys.domains.pythonfinitefield import PythonFiniteField
|
| 454 |
-
from sympy.polys.domains.pythonintegerring import PythonIntegerRing
|
| 455 |
-
from sympy.polys.domains.pythonrationalfield import PythonRationalField
|
| 456 |
-
|
| 457 |
-
# TODO: fix pickling of ModularInteger
|
| 458 |
-
# for c in (PythonFiniteField, PythonFiniteField(17)):
|
| 459 |
-
# check(c)
|
| 460 |
-
|
| 461 |
-
for c in (PythonIntegerRing, PythonIntegerRing()):
|
| 462 |
-
check(c, check_attr=False)
|
| 463 |
-
|
| 464 |
-
for c in (PythonRationalField, PythonRationalField()):
|
| 465 |
-
check(c, check_attr=False)
|
| 466 |
-
|
| 467 |
-
if _gmpy is not None:
|
| 468 |
-
# from sympy.polys.domains.gmpyfinitefield import GMPYFiniteField
|
| 469 |
-
from sympy.polys.domains.gmpyintegerring import GMPYIntegerRing
|
| 470 |
-
from sympy.polys.domains.gmpyrationalfield import GMPYRationalField
|
| 471 |
-
|
| 472 |
-
# TODO: fix pickling of ModularInteger
|
| 473 |
-
# for c in (GMPYFiniteField, GMPYFiniteField(17)):
|
| 474 |
-
# check(c)
|
| 475 |
-
|
| 476 |
-
for c in (GMPYIntegerRing, GMPYIntegerRing()):
|
| 477 |
-
check(c, check_attr=False)
|
| 478 |
-
|
| 479 |
-
for c in (GMPYRationalField, GMPYRationalField()):
|
| 480 |
-
check(c, check_attr=False)
|
| 481 |
-
|
| 482 |
-
#from sympy.polys.domains.realfield import RealField
|
| 483 |
-
#from sympy.polys.domains.complexfield import ComplexField
|
| 484 |
-
from sympy.polys.domains.algebraicfield import AlgebraicField
|
| 485 |
-
#from sympy.polys.domains.polynomialring import PolynomialRing
|
| 486 |
-
#from sympy.polys.domains.fractionfield import FractionField
|
| 487 |
-
from sympy.polys.domains.expressiondomain import ExpressionDomain
|
| 488 |
-
|
| 489 |
-
# TODO: fix pickling of RealElement
|
| 490 |
-
# for c in (RealField, RealField(100)):
|
| 491 |
-
# check(c)
|
| 492 |
-
|
| 493 |
-
# TODO: fix pickling of ComplexElement
|
| 494 |
-
# for c in (ComplexField, ComplexField(100)):
|
| 495 |
-
# check(c)
|
| 496 |
-
|
| 497 |
-
for c in (AlgebraicField, AlgebraicField(QQ, sqrt(3))):
|
| 498 |
-
check(c, check_attr=False)
|
| 499 |
-
|
| 500 |
-
# TODO: AssertionError
|
| 501 |
-
# for c in (PolynomialRing, PolynomialRing(ZZ, "x,y,z")):
|
| 502 |
-
# check(c)
|
| 503 |
-
|
| 504 |
-
# TODO: AttributeError: 'PolyElement' object has no attribute 'ring'
|
| 505 |
-
# for c in (FractionField, FractionField(ZZ, "x,y,z")):
|
| 506 |
-
# check(c)
|
| 507 |
-
|
| 508 |
-
for c in (ExpressionDomain, ExpressionDomain()):
|
| 509 |
-
check(c, check_attr=False)
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
def test_pickling_polys_orderings():
|
| 513 |
-
from sympy.polys.orderings import (LexOrder, GradedLexOrder,
|
| 514 |
-
ReversedGradedLexOrder, InverseOrder)
|
| 515 |
-
# from sympy.polys.orderings import ProductOrder
|
| 516 |
-
|
| 517 |
-
for c in (LexOrder, LexOrder()):
|
| 518 |
-
check(c)
|
| 519 |
-
|
| 520 |
-
for c in (GradedLexOrder, GradedLexOrder()):
|
| 521 |
-
check(c)
|
| 522 |
-
|
| 523 |
-
for c in (ReversedGradedLexOrder, ReversedGradedLexOrder()):
|
| 524 |
-
check(c)
|
| 525 |
-
|
| 526 |
-
# TODO: Argh, Python is so naive. No lambdas nor inner function support in
|
| 527 |
-
# pickling module. Maybe someone could figure out what to do with this.
|
| 528 |
-
#
|
| 529 |
-
# for c in (ProductOrder, ProductOrder((LexOrder(), lambda m: m[:2]),
|
| 530 |
-
# (GradedLexOrder(), lambda m: m[2:]))):
|
| 531 |
-
# check(c)
|
| 532 |
-
|
| 533 |
-
for c in (InverseOrder, InverseOrder(LexOrder())):
|
| 534 |
-
check(c)
|
| 535 |
-
|
| 536 |
-
def test_pickling_polys_monomials():
|
| 537 |
-
from sympy.polys.monomials import MonomialOps, Monomial
|
| 538 |
-
x, y, z = symbols("x,y,z")
|
| 539 |
-
|
| 540 |
-
for c in (MonomialOps, MonomialOps(3)):
|
| 541 |
-
check(c)
|
| 542 |
-
|
| 543 |
-
for c in (Monomial, Monomial((1, 2, 3), (x, y, z))):
|
| 544 |
-
check(c)
|
| 545 |
-
|
| 546 |
-
def test_pickling_polys_errors():
|
| 547 |
-
from sympy.polys.polyerrors import (HeuristicGCDFailed,
|
| 548 |
-
HomomorphismFailed, IsomorphismFailed, ExtraneousFactors,
|
| 549 |
-
EvaluationFailed, RefinementFailed, CoercionFailed, NotInvertible,
|
| 550 |
-
NotReversible, NotAlgebraic, DomainError, PolynomialError,
|
| 551 |
-
UnificationFailed, GeneratorsError, GeneratorsNeeded,
|
| 552 |
-
UnivariatePolynomialError, MultivariatePolynomialError, OptionError,
|
| 553 |
-
FlagError)
|
| 554 |
-
# from sympy.polys.polyerrors import (ExactQuotientFailed,
|
| 555 |
-
# OperationNotSupported, ComputationFailed, PolificationFailed)
|
| 556 |
-
|
| 557 |
-
# x = Symbol('x')
|
| 558 |
-
|
| 559 |
-
# TODO: TypeError: __init__() takes at least 3 arguments (1 given)
|
| 560 |
-
# for c in (ExactQuotientFailed, ExactQuotientFailed(x, 3*x, ZZ)):
|
| 561 |
-
# check(c)
|
| 562 |
-
|
| 563 |
-
# TODO: TypeError: can't pickle instancemethod objects
|
| 564 |
-
# for c in (OperationNotSupported, OperationNotSupported(Poly(x), Poly.gcd)):
|
| 565 |
-
# check(c)
|
| 566 |
-
|
| 567 |
-
for c in (HeuristicGCDFailed, HeuristicGCDFailed()):
|
| 568 |
-
check(c)
|
| 569 |
-
|
| 570 |
-
for c in (HomomorphismFailed, HomomorphismFailed()):
|
| 571 |
-
check(c)
|
| 572 |
-
|
| 573 |
-
for c in (IsomorphismFailed, IsomorphismFailed()):
|
| 574 |
-
check(c)
|
| 575 |
-
|
| 576 |
-
for c in (ExtraneousFactors, ExtraneousFactors()):
|
| 577 |
-
check(c)
|
| 578 |
-
|
| 579 |
-
for c in (EvaluationFailed, EvaluationFailed()):
|
| 580 |
-
check(c)
|
| 581 |
-
|
| 582 |
-
for c in (RefinementFailed, RefinementFailed()):
|
| 583 |
-
check(c)
|
| 584 |
-
|
| 585 |
-
for c in (CoercionFailed, CoercionFailed()):
|
| 586 |
-
check(c)
|
| 587 |
-
|
| 588 |
-
for c in (NotInvertible, NotInvertible()):
|
| 589 |
-
check(c)
|
| 590 |
-
|
| 591 |
-
for c in (NotReversible, NotReversible()):
|
| 592 |
-
check(c)
|
| 593 |
-
|
| 594 |
-
for c in (NotAlgebraic, NotAlgebraic()):
|
| 595 |
-
check(c)
|
| 596 |
-
|
| 597 |
-
for c in (DomainError, DomainError()):
|
| 598 |
-
check(c)
|
| 599 |
-
|
| 600 |
-
for c in (PolynomialError, PolynomialError()):
|
| 601 |
-
check(c)
|
| 602 |
-
|
| 603 |
-
for c in (UnificationFailed, UnificationFailed()):
|
| 604 |
-
check(c)
|
| 605 |
-
|
| 606 |
-
for c in (GeneratorsError, GeneratorsError()):
|
| 607 |
-
check(c)
|
| 608 |
-
|
| 609 |
-
for c in (GeneratorsNeeded, GeneratorsNeeded()):
|
| 610 |
-
check(c)
|
| 611 |
-
|
| 612 |
-
# TODO: PicklingError: Can't pickle <function <lambda> at 0x38578c0>: it's not found as __main__.<lambda>
|
| 613 |
-
# for c in (ComputationFailed, ComputationFailed(lambda t: t, 3, None)):
|
| 614 |
-
# check(c)
|
| 615 |
-
|
| 616 |
-
for c in (UnivariatePolynomialError, UnivariatePolynomialError()):
|
| 617 |
-
check(c)
|
| 618 |
-
|
| 619 |
-
for c in (MultivariatePolynomialError, MultivariatePolynomialError()):
|
| 620 |
-
check(c)
|
| 621 |
-
|
| 622 |
-
# TODO: TypeError: __init__() takes at least 3 arguments (1 given)
|
| 623 |
-
# for c in (PolificationFailed, PolificationFailed({}, x, x, False)):
|
| 624 |
-
# check(c)
|
| 625 |
-
|
| 626 |
-
for c in (OptionError, OptionError()):
|
| 627 |
-
check(c)
|
| 628 |
-
|
| 629 |
-
for c in (FlagError, FlagError()):
|
| 630 |
-
check(c)
|
| 631 |
-
|
| 632 |
-
#def test_pickling_polys_options():
|
| 633 |
-
#from sympy.polys.polyoptions import Options
|
| 634 |
-
|
| 635 |
-
# TODO: fix pickling of `symbols' flag
|
| 636 |
-
# for c in (Options, Options((), dict(domain='ZZ', polys=False))):
|
| 637 |
-
# check(c)
|
| 638 |
-
|
| 639 |
-
# TODO: def test_pickling_polys_rootisolation():
|
| 640 |
-
# RealInterval
|
| 641 |
-
# ComplexInterval
|
| 642 |
-
|
| 643 |
-
def test_pickling_polys_rootoftools():
|
| 644 |
-
from sympy.polys.rootoftools import CRootOf, RootSum
|
| 645 |
-
|
| 646 |
-
x = Symbol('x')
|
| 647 |
-
f = x**3 + x + 3
|
| 648 |
-
|
| 649 |
-
for c in (CRootOf, CRootOf(f, 0)):
|
| 650 |
-
check(c)
|
| 651 |
-
|
| 652 |
-
for c in (RootSum, RootSum(f, exp)):
|
| 653 |
-
check(c)
|
| 654 |
-
|
| 655 |
-
#================== printing ====================
|
| 656 |
-
from sympy.printing.latex import LatexPrinter
|
| 657 |
-
from sympy.printing.mathml import MathMLContentPrinter, MathMLPresentationPrinter
|
| 658 |
-
from sympy.printing.pretty.pretty import PrettyPrinter
|
| 659 |
-
from sympy.printing.pretty.stringpict import prettyForm, stringPict
|
| 660 |
-
from sympy.printing.printer import Printer
|
| 661 |
-
from sympy.printing.python import PythonPrinter
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
def test_printing():
|
| 665 |
-
for c in (LatexPrinter, LatexPrinter(), MathMLContentPrinter,
|
| 666 |
-
MathMLPresentationPrinter, PrettyPrinter, prettyForm, stringPict,
|
| 667 |
-
stringPict("a"), Printer, Printer(), PythonPrinter,
|
| 668 |
-
PythonPrinter()):
|
| 669 |
-
check(c)
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
@XFAIL
|
| 673 |
-
def test_printing1():
|
| 674 |
-
check(MathMLContentPrinter())
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
@XFAIL
|
| 678 |
-
def test_printing2():
|
| 679 |
-
check(MathMLPresentationPrinter())
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
@XFAIL
|
| 683 |
-
def test_printing3():
|
| 684 |
-
check(PrettyPrinter())
|
| 685 |
-
|
| 686 |
-
#================== series ======================
|
| 687 |
-
from sympy.series.limits import Limit
|
| 688 |
-
from sympy.series.order import Order
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
def test_series():
|
| 692 |
-
e = Symbol("e")
|
| 693 |
-
x = Symbol("x")
|
| 694 |
-
for c in (Limit, Limit(e, x, 1), Order, Order(e)):
|
| 695 |
-
check(c)
|
| 696 |
-
|
| 697 |
-
#================== concrete ==================
|
| 698 |
-
from sympy.concrete.products import Product
|
| 699 |
-
from sympy.concrete.summations import Sum
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
def test_concrete():
|
| 703 |
-
x = Symbol("x")
|
| 704 |
-
for c in (Product, Product(x, (x, 2, 4)), Sum, Sum(x, (x, 2, 4))):
|
| 705 |
-
check(c)
|
| 706 |
-
|
| 707 |
-
def test_deprecation_warning():
|
| 708 |
-
w = SymPyDeprecationWarning("message", deprecated_since_version='1.0', active_deprecations_target="active-deprecations")
|
| 709 |
-
check(w)
|
| 710 |
-
|
| 711 |
-
def test_issue_18438():
|
| 712 |
-
assert pickle.loads(pickle.dumps(S.Half)) == S.Half
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
#================= old pickles =================
|
| 716 |
-
def test_unpickle_from_older_versions():
|
| 717 |
-
data = (
|
| 718 |
-
b'\x80\x04\x95^\x00\x00\x00\x00\x00\x00\x00\x8c\x10sympy.core.power'
|
| 719 |
-
b'\x94\x8c\x03Pow\x94\x93\x94\x8c\x12sympy.core.numbers\x94\x8c'
|
| 720 |
-
b'\x07Integer\x94\x93\x94K\x02\x85\x94R\x94}\x94bh\x03\x8c\x04Half'
|
| 721 |
-
b'\x94\x93\x94)R\x94}\x94b\x86\x94R\x94}\x94b.'
|
| 722 |
-
)
|
| 723 |
-
assert pickle.loads(data) == sqrt(2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_source.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
from sympy.utilities.source import get_mod_func, get_class
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def test_get_mod_func():
|
| 5 |
-
assert get_mod_func(
|
| 6 |
-
'sympy.core.basic.Basic') == ('sympy.core.basic', 'Basic')
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def test_get_class():
|
| 10 |
-
_basic = get_class('sympy.core.basic.Basic')
|
| 11 |
-
assert _basic.__name__ == 'Basic'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|