MTerryJack commited on
Commit
140ec32
·
verified ·
1 Parent(s): a59738c

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. .venv/lib/python3.13/site-packages/sympy/testing/tests/__init__.py +0 -0
  2. .venv/lib/python3.13/site-packages/sympy/testing/tests/test_code_quality.py +510 -0
  3. .venv/lib/python3.13/site-packages/sympy/testing/tests/test_deprecated.py +5 -0
  4. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/__init__.py +122 -0
  5. .venv/lib/python3.13/site-packages/sympy/utilities/tests/__init__.py +0 -0
  6. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_autowrap.py +467 -0
  7. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_octave.py +589 -0
  8. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_rust.py +401 -0
  9. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_deprecated.py +13 -0
  10. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_exceptions.py +12 -0
  11. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_iterables.py +945 -0
  12. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_lambdify.py +2263 -0
  13. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_matchpy_connector.py +164 -0
  14. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_mathml.py +33 -0
  15. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_misc.py +148 -0
  16. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_pickling.py +723 -0
  17. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_source.py +11 -0
  18. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_timeutils.py +10 -0
  19. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_wester.py +3104 -0
  20. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_xxe.py +3 -0
.venv/lib/python3.13/site-packages/sympy/testing/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/testing/tests/test_code_quality.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from os import walk, sep, pardir
3
+ from os.path import split, join, abspath, exists, isfile
4
+ from glob import glob
5
+ import re
6
+ import random
7
+ import ast
8
+
9
+ from sympy.testing.pytest import raises
10
+ from sympy.testing.quality_unicode import _test_this_file_encoding
11
+
12
+ # System path separator (usually slash or backslash) to be
13
+ # used with excluded files, e.g.
14
+ # exclude = set([
15
+ # "%(sep)smpmath%(sep)s" % sepd,
16
+ # ])
17
+ sepd = {"sep": sep}
18
+
19
+ # path and sympy_path
20
+ SYMPY_PATH = abspath(join(split(__file__)[0], pardir, pardir)) # go to sympy/
21
+ assert exists(SYMPY_PATH)
22
+
23
+ TOP_PATH = abspath(join(SYMPY_PATH, pardir))
24
+ BIN_PATH = join(TOP_PATH, "bin")
25
+ EXAMPLES_PATH = join(TOP_PATH, "examples")
26
+
27
+ # Error messages
28
+ message_space = "File contains trailing whitespace: %s, line %s."
29
+ message_implicit = "File contains an implicit import: %s, line %s."
30
+ message_tabs = "File contains tabs instead of spaces: %s, line %s."
31
+ message_carriage = "File contains carriage returns at end of line: %s, line %s"
32
+ message_str_raise = "File contains string exception: %s, line %s"
33
+ message_gen_raise = "File contains generic exception: %s, line %s"
34
+ message_old_raise = "File contains old-style raise statement: %s, line %s, \"%s\""
35
+ message_eof = "File does not end with a newline: %s, line %s"
36
+ message_multi_eof = "File ends with more than 1 newline: %s, line %s"
37
+ message_test_suite_def = "Function should start with 'test_' or '_': %s, line %s"
38
+ message_duplicate_test = "This is a duplicate test function: %s, line %s"
39
+ message_self_assignments = "File contains assignments to self/cls: %s, line %s."
40
+ message_func_is = "File contains '.func is': %s, line %s."
41
+ message_bare_expr = "File contains bare expression: %s, line %s."
42
+
43
+ implicit_test_re = re.compile(r'^\s*(>>> )?(\.\.\. )?from .* import .*\*')
44
+ str_raise_re = re.compile(
45
+ r'^\s*(>>> )?(\.\.\. )?raise(\s+(\'|\")|\s*(\(\s*)+(\'|\"))')
46
+ gen_raise_re = re.compile(
47
+ r'^\s*(>>> )?(\.\.\. )?raise(\s+Exception|\s*(\(\s*)+Exception)')
48
+ old_raise_re = re.compile(r'^\s*(>>> )?(\.\.\. )?raise((\s*\(\s*)|\s+)\w+\s*,')
49
+ test_suite_def_re = re.compile(r'^def\s+(?!(_|test))[^(]*\(\s*\)\s*:$')
50
+ test_ok_def_re = re.compile(r'^def\s+test_.*:$')
51
+ test_file_re = re.compile(r'.*[/\\]test_.*\.py$')
52
+ func_is_re = re.compile(r'\.\s*func\s+is')
53
+
54
+
55
+ def tab_in_leading(s):
56
+ """Returns True if there are tabs in the leading whitespace of a line,
57
+ including the whitespace of docstring code samples."""
58
+ n = len(s) - len(s.lstrip())
59
+ if not s[n:n + 3] in ['...', '>>>']:
60
+ check = s[:n]
61
+ else:
62
+ smore = s[n + 3:]
63
+ check = s[:n] + smore[:len(smore) - len(smore.lstrip())]
64
+ return not (check.expandtabs() == check)
65
+
66
+
67
+ def find_self_assignments(s):
68
+ """Returns a list of "bad" assignments: if there are instances
69
+ of assigning to the first argument of the class method (except
70
+ for staticmethod's).
71
+ """
72
+ t = [n for n in ast.parse(s).body if isinstance(n, ast.ClassDef)]
73
+
74
+ bad = []
75
+ for c in t:
76
+ for n in c.body:
77
+ if not isinstance(n, ast.FunctionDef):
78
+ continue
79
+ if any(d.id == 'staticmethod'
80
+ for d in n.decorator_list if isinstance(d, ast.Name)):
81
+ continue
82
+ if n.name == '__new__':
83
+ continue
84
+ if not n.args.args:
85
+ continue
86
+ first_arg = n.args.args[0].arg
87
+
88
+ for m in ast.walk(n):
89
+ if isinstance(m, ast.Assign):
90
+ for a in m.targets:
91
+ if isinstance(a, ast.Name) and a.id == first_arg:
92
+ bad.append(m)
93
+ elif (isinstance(a, ast.Tuple) and
94
+ any(q.id == first_arg for q in a.elts
95
+ if isinstance(q, ast.Name))):
96
+ bad.append(m)
97
+
98
+ return bad
99
+
100
+
101
+ def check_directory_tree(base_path, file_check, exclusions=set(), pattern="*.py"):
102
+ """
103
+ Checks all files in the directory tree (with base_path as starting point)
104
+ with the file_check function provided, skipping files that contain
105
+ any of the strings in the set provided by exclusions.
106
+ """
107
+ if not base_path:
108
+ return
109
+ for root, dirs, files in walk(base_path):
110
+ check_files(glob(join(root, pattern)), file_check, exclusions)
111
+
112
+
113
+ def check_files(files, file_check, exclusions=set(), pattern=None):
114
+ """
115
+ Checks all files with the file_check function provided, skipping files
116
+ that contain any of the strings in the set provided by exclusions.
117
+ """
118
+ if not files:
119
+ return
120
+ for fname in files:
121
+ if not exists(fname) or not isfile(fname):
122
+ continue
123
+ if any(ex in fname for ex in exclusions):
124
+ continue
125
+ if pattern is None or re.match(pattern, fname):
126
+ file_check(fname)
127
+
128
+
129
+ class _Visit(ast.NodeVisitor):
130
+ """return the line number corresponding to the
131
+ line on which a bare expression appears if it is a binary op
132
+ or a comparison that is not in a with block.
133
+
134
+ EXAMPLES
135
+ ========
136
+
137
+ >>> import ast
138
+ >>> class _Visit(ast.NodeVisitor):
139
+ ... def visit_Expr(self, node):
140
+ ... if isinstance(node.value, (ast.BinOp, ast.Compare)):
141
+ ... print(node.lineno)
142
+ ... def visit_With(self, node):
143
+ ... pass # no checking there
144
+ ...
145
+ >>> code='''x = 1 # line 1
146
+ ... for i in range(3):
147
+ ... x == 2 # <-- 3
148
+ ... if x == 2:
149
+ ... x == 3 # <-- 5
150
+ ... x + 1 # <-- 6
151
+ ... x = 1
152
+ ... if x == 1:
153
+ ... print(1)
154
+ ... while x != 1:
155
+ ... x == 1 # <-- 11
156
+ ... with raises(TypeError):
157
+ ... c == 1
158
+ ... raise TypeError
159
+ ... assert x == 1
160
+ ... '''
161
+ >>> _Visit().visit(ast.parse(code))
162
+ 3
163
+ 5
164
+ 6
165
+ 11
166
+ """
167
+ def visit_Expr(self, node):
168
+ if isinstance(node.value, (ast.BinOp, ast.Compare)):
169
+ assert None, message_bare_expr % ('', node.lineno)
170
+ def visit_With(self, node):
171
+ pass
172
+
173
+
174
+ BareExpr = _Visit()
175
+
176
+
177
+ def line_with_bare_expr(code):
178
+ """return None or else 0-based line number of code on which
179
+ a bare expression appeared.
180
+ """
181
+ tree = ast.parse(code)
182
+ try:
183
+ BareExpr.visit(tree)
184
+ except AssertionError as msg:
185
+ assert msg.args
186
+ msg = msg.args[0]
187
+ assert msg.startswith(message_bare_expr.split(':', 1)[0])
188
+ return int(msg.rsplit(' ', 1)[1].rstrip('.')) # the line number
189
+
190
+
191
+ def test_files():
192
+ """
193
+ This test tests all files in SymPy and checks that:
194
+ o no lines contains a trailing whitespace
195
+ o no lines end with \r\n
196
+ o no line uses tabs instead of spaces
197
+ o that the file ends with a single newline
198
+ o there are no general or string exceptions
199
+ o there are no old style raise statements
200
+ o name of arg-less test suite functions start with _ or test_
201
+ o no duplicate function names that start with test_
202
+ o no assignments to self variable in class methods
203
+ o no lines contain ".func is" except in the test suite
204
+ o there is no do-nothing expression like `a == b` or `x + 1`
205
+ """
206
+
207
+ def test(fname):
208
+ with open(fname, encoding="utf8") as test_file:
209
+ test_this_file(fname, test_file)
210
+ with open(fname, encoding='utf8') as test_file:
211
+ _test_this_file_encoding(fname, test_file)
212
+
213
+ def test_this_file(fname, test_file):
214
+ idx = None
215
+ code = test_file.read()
216
+ test_file.seek(0) # restore reader to head
217
+ py = fname if sep not in fname else fname.rsplit(sep, 1)[-1]
218
+ if py.startswith('test_'):
219
+ idx = line_with_bare_expr(code)
220
+ if idx is not None:
221
+ assert False, message_bare_expr % (fname, idx + 1)
222
+
223
+ line = None # to flag the case where there were no lines in file
224
+ tests = 0
225
+ test_set = set()
226
+ for idx, line in enumerate(test_file):
227
+ if test_file_re.match(fname):
228
+ if test_suite_def_re.match(line):
229
+ assert False, message_test_suite_def % (fname, idx + 1)
230
+ if test_ok_def_re.match(line):
231
+ tests += 1
232
+ test_set.add(line[3:].split('(')[0].strip())
233
+ if len(test_set) != tests:
234
+ assert False, message_duplicate_test % (fname, idx + 1)
235
+ if line.endswith((" \n", "\t\n")):
236
+ assert False, message_space % (fname, idx + 1)
237
+ if line.endswith("\r\n"):
238
+ assert False, message_carriage % (fname, idx + 1)
239
+ if tab_in_leading(line):
240
+ assert False, message_tabs % (fname, idx + 1)
241
+ if str_raise_re.search(line):
242
+ assert False, message_str_raise % (fname, idx + 1)
243
+ if gen_raise_re.search(line):
244
+ assert False, message_gen_raise % (fname, idx + 1)
245
+ if (implicit_test_re.search(line) and
246
+ not list(filter(lambda ex: ex in fname, import_exclude))):
247
+ assert False, message_implicit % (fname, idx + 1)
248
+ if func_is_re.search(line) and not test_file_re.search(fname):
249
+ assert False, message_func_is % (fname, idx + 1)
250
+
251
+ result = old_raise_re.search(line)
252
+
253
+ if result is not None:
254
+ assert False, message_old_raise % (
255
+ fname, idx + 1, result.group(2))
256
+
257
+ if line is not None:
258
+ if line == '\n' and idx > 0:
259
+ assert False, message_multi_eof % (fname, idx + 1)
260
+ elif not line.endswith('\n'):
261
+ # eof newline check
262
+ assert False, message_eof % (fname, idx + 1)
263
+
264
+
265
+ # Files to test at top level
266
+ top_level_files = [join(TOP_PATH, file) for file in [
267
+ "isympy.py",
268
+ "build.py",
269
+ "setup.py",
270
+ ]]
271
+ # Files to exclude from all tests
272
+ exclude = {
273
+ "%(sep)ssympy%(sep)sparsing%(sep)sautolev%(sep)s_antlr%(sep)sautolevparser.py" % sepd,
274
+ "%(sep)ssympy%(sep)sparsing%(sep)sautolev%(sep)s_antlr%(sep)sautolevlexer.py" % sepd,
275
+ "%(sep)ssympy%(sep)sparsing%(sep)sautolev%(sep)s_antlr%(sep)sautolevlistener.py" % sepd,
276
+ "%(sep)ssympy%(sep)sparsing%(sep)slatex%(sep)s_antlr%(sep)slatexparser.py" % sepd,
277
+ "%(sep)ssympy%(sep)sparsing%(sep)slatex%(sep)s_antlr%(sep)slatexlexer.py" % sepd,
278
+ }
279
+ # Files to exclude from the implicit import test
280
+ import_exclude = {
281
+ # glob imports are allowed in top-level __init__.py:
282
+ "%(sep)ssympy%(sep)s__init__.py" % sepd,
283
+ # these __init__.py should be fixed:
284
+ # XXX: not really, they use useful import pattern (DRY)
285
+ "%(sep)svector%(sep)s__init__.py" % sepd,
286
+ "%(sep)smechanics%(sep)s__init__.py" % sepd,
287
+ "%(sep)squantum%(sep)s__init__.py" % sepd,
288
+ "%(sep)spolys%(sep)s__init__.py" % sepd,
289
+ "%(sep)spolys%(sep)sdomains%(sep)s__init__.py" % sepd,
290
+ # interactive SymPy executes ``from sympy import *``:
291
+ "%(sep)sinteractive%(sep)ssession.py" % sepd,
292
+ # isympy.py executes ``from sympy import *``:
293
+ "%(sep)sisympy.py" % sepd,
294
+ # these two are import timing tests:
295
+ "%(sep)sbin%(sep)ssympy_time.py" % sepd,
296
+ "%(sep)sbin%(sep)ssympy_time_cache.py" % sepd,
297
+ # Taken from Python stdlib:
298
+ "%(sep)sparsing%(sep)ssympy_tokenize.py" % sepd,
299
+ # this one should be fixed:
300
+ "%(sep)splotting%(sep)spygletplot%(sep)s" % sepd,
301
+ # False positive in the docstring
302
+ "%(sep)sbin%(sep)stest_external_imports.py" % sepd,
303
+ "%(sep)sbin%(sep)stest_submodule_imports.py" % sepd,
304
+ # These are deprecated stubs that can be removed at some point:
305
+ "%(sep)sutilities%(sep)sruntests.py" % sepd,
306
+ "%(sep)sutilities%(sep)spytest.py" % sepd,
307
+ "%(sep)sutilities%(sep)srandtest.py" % sepd,
308
+ "%(sep)sutilities%(sep)stmpfiles.py" % sepd,
309
+ "%(sep)sutilities%(sep)squality_unicode.py" % sepd,
310
+ }
311
+ check_files(top_level_files, test)
312
+ check_directory_tree(BIN_PATH, test, {"~", ".pyc", ".sh"}, "*")
313
+ check_directory_tree(SYMPY_PATH, test, exclude)
314
+ check_directory_tree(EXAMPLES_PATH, test, exclude)
315
+
316
+
317
+ def _with_space(c):
318
+ # return c with a random amount of leading space
319
+ return random.randint(0, 10)*' ' + c
320
+
321
+
322
+ def test_raise_statement_regular_expression():
323
+ candidates_ok = [
324
+ "some text # raise Exception, 'text'",
325
+ "raise ValueError('text') # raise Exception, 'text'",
326
+ "raise ValueError('text')",
327
+ "raise ValueError",
328
+ "raise ValueError('text')",
329
+ "raise ValueError('text') #,",
330
+ # Talking about an exception in a docstring
331
+ ''''"""This function will raise ValueError, except when it doesn't"""''',
332
+ "raise (ValueError('text')",
333
+ ]
334
+ str_candidates_fail = [
335
+ "raise 'exception'",
336
+ "raise 'Exception'",
337
+ 'raise "exception"',
338
+ 'raise "Exception"',
339
+ "raise 'ValueError'",
340
+ ]
341
+ gen_candidates_fail = [
342
+ "raise Exception('text') # raise Exception, 'text'",
343
+ "raise Exception('text')",
344
+ "raise Exception",
345
+ "raise Exception('text')",
346
+ "raise Exception('text') #,",
347
+ "raise Exception, 'text'",
348
+ "raise Exception, 'text' # raise Exception('text')",
349
+ "raise Exception, 'text' # raise Exception, 'text'",
350
+ ">>> raise Exception, 'text'",
351
+ ">>> raise Exception, 'text' # raise Exception('text')",
352
+ ">>> raise Exception, 'text' # raise Exception, 'text'",
353
+ ]
354
+ old_candidates_fail = [
355
+ "raise Exception, 'text'",
356
+ "raise Exception, 'text' # raise Exception('text')",
357
+ "raise Exception, 'text' # raise Exception, 'text'",
358
+ ">>> raise Exception, 'text'",
359
+ ">>> raise Exception, 'text' # raise Exception('text')",
360
+ ">>> raise Exception, 'text' # raise Exception, 'text'",
361
+ "raise ValueError, 'text'",
362
+ "raise ValueError, 'text' # raise Exception('text')",
363
+ "raise ValueError, 'text' # raise Exception, 'text'",
364
+ ">>> raise ValueError, 'text'",
365
+ ">>> raise ValueError, 'text' # raise Exception('text')",
366
+ ">>> raise ValueError, 'text' # raise Exception, 'text'",
367
+ "raise(ValueError,",
368
+ "raise (ValueError,",
369
+ "raise( ValueError,",
370
+ "raise ( ValueError,",
371
+ "raise(ValueError ,",
372
+ "raise (ValueError ,",
373
+ "raise( ValueError ,",
374
+ "raise ( ValueError ,",
375
+ ]
376
+
377
+ for c in candidates_ok:
378
+ assert str_raise_re.search(_with_space(c)) is None, c
379
+ assert gen_raise_re.search(_with_space(c)) is None, c
380
+ assert old_raise_re.search(_with_space(c)) is None, c
381
+ for c in str_candidates_fail:
382
+ assert str_raise_re.search(_with_space(c)) is not None, c
383
+ for c in gen_candidates_fail:
384
+ assert gen_raise_re.search(_with_space(c)) is not None, c
385
+ for c in old_candidates_fail:
386
+ assert old_raise_re.search(_with_space(c)) is not None, c
387
+
388
+
389
+ def test_implicit_imports_regular_expression():
390
+ candidates_ok = [
391
+ "from sympy import something",
392
+ ">>> from sympy import something",
393
+ "from sympy.somewhere import something",
394
+ ">>> from sympy.somewhere import something",
395
+ "import sympy",
396
+ ">>> import sympy",
397
+ "import sympy.something.something",
398
+ "... import sympy",
399
+ "... import sympy.something.something",
400
+ "... from sympy import something",
401
+ "... from sympy.somewhere import something",
402
+ ">> from sympy import *", # To allow 'fake' docstrings
403
+ "# from sympy import *",
404
+ "some text # from sympy import *",
405
+ ]
406
+ candidates_fail = [
407
+ "from sympy import *",
408
+ ">>> from sympy import *",
409
+ "from sympy.somewhere import *",
410
+ ">>> from sympy.somewhere import *",
411
+ "... from sympy import *",
412
+ "... from sympy.somewhere import *",
413
+ ]
414
+ for c in candidates_ok:
415
+ assert implicit_test_re.search(_with_space(c)) is None, c
416
+ for c in candidates_fail:
417
+ assert implicit_test_re.search(_with_space(c)) is not None, c
418
+
419
+
420
+ def test_test_suite_defs():
421
+ candidates_ok = [
422
+ " def foo():\n",
423
+ "def foo(arg):\n",
424
+ "def _foo():\n",
425
+ "def test_foo():\n",
426
+ ]
427
+ candidates_fail = [
428
+ "def foo():\n",
429
+ "def foo() :\n",
430
+ "def foo( ):\n",
431
+ "def foo():\n",
432
+ ]
433
+ for c in candidates_ok:
434
+ assert test_suite_def_re.search(c) is None, c
435
+ for c in candidates_fail:
436
+ assert test_suite_def_re.search(c) is not None, c
437
+
438
+
439
+ def test_test_duplicate_defs():
440
+ candidates_ok = [
441
+ "def foo():\ndef foo():\n",
442
+ "def test():\ndef test_():\n",
443
+ "def test_():\ndef test__():\n",
444
+ ]
445
+ candidates_fail = [
446
+ "def test_():\ndef test_ ():\n",
447
+ "def test_1():\ndef test_1():\n",
448
+ ]
449
+ ok = (None, 'check')
450
+ def check(file):
451
+ tests = 0
452
+ test_set = set()
453
+ for idx, line in enumerate(file.splitlines()):
454
+ if test_ok_def_re.match(line):
455
+ tests += 1
456
+ test_set.add(line[3:].split('(')[0].strip())
457
+ if len(test_set) != tests:
458
+ return False, message_duplicate_test % ('check', idx + 1)
459
+ return None, 'check'
460
+ for c in candidates_ok:
461
+ assert check(c) == ok
462
+ for c in candidates_fail:
463
+ assert check(c) != ok
464
+
465
+
466
+ def test_find_self_assignments():
467
+ candidates_ok = [
468
+ "class A(object):\n def foo(self, arg): arg = self\n",
469
+ "class A(object):\n def foo(self, arg): self.prop = arg\n",
470
+ "class A(object):\n def foo(self, arg): obj, obj2 = arg, self\n",
471
+ "class A(object):\n @classmethod\n def bar(cls, arg): arg = cls\n",
472
+ "class A(object):\n def foo(var, arg): arg = var\n",
473
+ ]
474
+ candidates_fail = [
475
+ "class A(object):\n def foo(self, arg): self = arg\n",
476
+ "class A(object):\n def foo(self, arg): obj, self = arg, arg\n",
477
+ "class A(object):\n def foo(self, arg):\n if arg: self = arg",
478
+ "class A(object):\n @classmethod\n def foo(cls, arg): cls = arg\n",
479
+ "class A(object):\n def foo(var, arg): var = arg\n",
480
+ ]
481
+
482
+ for c in candidates_ok:
483
+ assert find_self_assignments(c) == []
484
+ for c in candidates_fail:
485
+ assert find_self_assignments(c) != []
486
+
487
+
488
+ def test_test_unicode_encoding():
489
+ unicode_whitelist = ['foo']
490
+ unicode_strict_whitelist = ['bar']
491
+
492
+ fname = 'abc'
493
+ test_file = ['α']
494
+ raises(AssertionError, lambda: _test_this_file_encoding(
495
+ fname, test_file, unicode_whitelist, unicode_strict_whitelist))
496
+
497
+ fname = 'abc'
498
+ test_file = ['abc']
499
+ _test_this_file_encoding(
500
+ fname, test_file, unicode_whitelist, unicode_strict_whitelist)
501
+
502
+ fname = 'foo'
503
+ test_file = ['abc']
504
+ raises(AssertionError, lambda: _test_this_file_encoding(
505
+ fname, test_file, unicode_whitelist, unicode_strict_whitelist))
506
+
507
+ fname = 'bar'
508
+ test_file = ['abc']
509
+ _test_this_file_encoding(
510
+ fname, test_file, unicode_whitelist, unicode_strict_whitelist)
.venv/lib/python3.13/site-packages/sympy/testing/tests/test_deprecated.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from sympy.testing.pytest import warns_deprecated_sympy
2
+
3
+ def test_deprecated_testing_randtest():
4
+ with warns_deprecated_sympy():
5
+ import sympy.testing.randtest # noqa:F401
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/__init__.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module with some functions for MathML, like transforming MathML
2
+ content in MathML presentation.
3
+
4
+ To use this module, you will need lxml.
5
+ """
6
+
7
+ from pathlib import Path
8
+
9
+ from sympy.utilities.decorator import doctest_depends_on
10
+
11
+
12
+ __doctest_requires__ = {('apply_xsl', 'c2p'): ['lxml']}
13
+
14
+
15
+ def add_mathml_headers(s):
16
+ return """<math xmlns:mml="http://www.w3.org/1998/Math/MathML"
17
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
18
+ xsi:schemaLocation="http://www.w3.org/1998/Math/MathML
19
+ http://www.w3.org/Math/XMLSchema/mathml2/mathml2.xsd">""" + s + "</math>"
20
+
21
+
22
+ def _read_binary(pkgname, filename):
23
+ import sys
24
+
25
+ if sys.version_info >= (3, 10):
26
+ # files was added in Python 3.9 but only seems to work here in 3.10+
27
+ from importlib.resources import files
28
+ return files(pkgname).joinpath(filename).read_bytes()
29
+ else:
30
+ # read_binary was deprecated in Python 3.11
31
+ from importlib.resources import read_binary
32
+ return read_binary(pkgname, filename)
33
+
34
+
35
+ def _read_xsl(xsl):
36
+ # Previously these values were allowed:
37
+ if xsl == 'mathml/data/simple_mmlctop.xsl':
38
+ xsl = 'simple_mmlctop.xsl'
39
+ elif xsl == 'mathml/data/mmlctop.xsl':
40
+ xsl = 'mmlctop.xsl'
41
+ elif xsl == 'mathml/data/mmltex.xsl':
42
+ xsl = 'mmltex.xsl'
43
+
44
+ if xsl in ['simple_mmlctop.xsl', 'mmlctop.xsl', 'mmltex.xsl']:
45
+ xslbytes = _read_binary('sympy.utilities.mathml.data', xsl)
46
+ else:
47
+ xslbytes = Path(xsl).read_bytes()
48
+
49
+ return xslbytes
50
+
51
+
52
+ @doctest_depends_on(modules=('lxml',))
53
+ def apply_xsl(mml, xsl):
54
+ """Apply a xsl to a MathML string.
55
+
56
+ Parameters
57
+ ==========
58
+
59
+ mml
60
+ A string with MathML code.
61
+ xsl
62
+ A string giving the name of an xsl (xml stylesheet) file which can be
63
+ found in sympy/utilities/mathml/data. The following files are supplied
64
+ with SymPy:
65
+
66
+ - mmlctop.xsl
67
+ - mmltex.xsl
68
+ - simple_mmlctop.xsl
69
+
70
+ Alternatively, a full path to an xsl file can be given.
71
+
72
+ Examples
73
+ ========
74
+
75
+ >>> from sympy.utilities.mathml import apply_xsl
76
+ >>> xsl = 'simple_mmlctop.xsl'
77
+ >>> mml = '<apply> <plus/> <ci>a</ci> <ci>b</ci> </apply>'
78
+ >>> res = apply_xsl(mml,xsl)
79
+ >>> print(res)
80
+ <?xml version="1.0"?>
81
+ <mrow xmlns="http://www.w3.org/1998/Math/MathML">
82
+ <mi>a</mi>
83
+ <mo> + </mo>
84
+ <mi>b</mi>
85
+ </mrow>
86
+ """
87
+ from lxml import etree
88
+
89
+ parser = etree.XMLParser(resolve_entities=False)
90
+ ac = etree.XSLTAccessControl.DENY_ALL
91
+
92
+ s = etree.XML(_read_xsl(xsl), parser=parser)
93
+ transform = etree.XSLT(s, access_control=ac)
94
+ doc = etree.XML(mml, parser=parser)
95
+ result = transform(doc)
96
+ s = str(result)
97
+ return s
98
+
99
+
100
+ @doctest_depends_on(modules=('lxml',))
101
+ def c2p(mml, simple=False):
102
+ """Transforms a document in MathML content (like the one that sympy produces)
103
+ in one document in MathML presentation, more suitable for printing, and more
104
+ widely accepted
105
+
106
+ Examples
107
+ ========
108
+
109
+ >>> from sympy.utilities.mathml import c2p
110
+ >>> mml = '<apply> <exp/> <cn>2</cn> </apply>'
111
+ >>> c2p(mml,simple=True) != c2p(mml,simple=False)
112
+ True
113
+
114
+ """
115
+
116
+ if not mml.startswith('<math'):
117
+ mml = add_mathml_headers(mml)
118
+
119
+ if simple:
120
+ return apply_xsl(mml, 'mathml/data/simple_mmlctop.xsl')
121
+
122
+ return apply_xsl(mml, 'mathml/data/mmlctop.xsl')
.venv/lib/python3.13/site-packages/sympy/utilities/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_autowrap.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Tests that require installed backends go into
2
+ # sympy/test_external/test_autowrap
3
+
4
+ import os
5
+ import tempfile
6
+ import shutil
7
+ from io import StringIO
8
+ from pathlib import Path
9
+
10
+ from sympy.core import symbols, Eq
11
+ from sympy.utilities.autowrap import (autowrap, binary_function,
12
+ CythonCodeWrapper, UfuncifyCodeWrapper, CodeWrapper)
13
+ from sympy.utilities.codegen import (
14
+ CCodeGen, C99CodeGen, CodeGenArgumentListError, make_routine
15
+ )
16
+ from sympy.testing.pytest import raises
17
+ from sympy.testing.tmpfiles import TmpFileManager
18
+
19
+
20
+ def get_string(dump_fn, routines, prefix="file", **kwargs):
21
+ """Wrapper for dump_fn. dump_fn writes its results to a stream object and
22
+ this wrapper returns the contents of that stream as a string. This
23
+ auxiliary function is used by many tests below.
24
+
25
+ The header and the empty lines are not generator to facilitate the
26
+ testing of the output.
27
+ """
28
+ output = StringIO()
29
+ dump_fn(routines, output, prefix, **kwargs)
30
+ source = output.getvalue()
31
+ output.close()
32
+ return source
33
+
34
+
35
+ def test_cython_wrapper_scalar_function():
36
+ x, y, z = symbols('x,y,z')
37
+ expr = (x + y)*z
38
+ routine = make_routine("test", expr)
39
+ code_gen = CythonCodeWrapper(CCodeGen())
40
+ source = get_string(code_gen.dump_pyx, [routine])
41
+
42
+ expected = (
43
+ "cdef extern from 'file.h':\n"
44
+ " double test(double x, double y, double z)\n"
45
+ "\n"
46
+ "def test_c(double x, double y, double z):\n"
47
+ "\n"
48
+ " return test(x, y, z)")
49
+ assert source == expected
50
+
51
+
52
+ def test_cython_wrapper_outarg():
53
+ from sympy.core.relational import Equality
54
+ x, y, z = symbols('x,y,z')
55
+ code_gen = CythonCodeWrapper(C99CodeGen())
56
+
57
+ routine = make_routine("test", Equality(z, x + y))
58
+ source = get_string(code_gen.dump_pyx, [routine])
59
+ expected = (
60
+ "cdef extern from 'file.h':\n"
61
+ " void test(double x, double y, double *z)\n"
62
+ "\n"
63
+ "def test_c(double x, double y):\n"
64
+ "\n"
65
+ " cdef double z = 0\n"
66
+ " test(x, y, &z)\n"
67
+ " return z")
68
+ assert source == expected
69
+
70
+
71
+ def test_cython_wrapper_inoutarg():
72
+ from sympy.core.relational import Equality
73
+ x, y, z = symbols('x,y,z')
74
+ code_gen = CythonCodeWrapper(C99CodeGen())
75
+ routine = make_routine("test", Equality(z, x + y + z))
76
+ source = get_string(code_gen.dump_pyx, [routine])
77
+ expected = (
78
+ "cdef extern from 'file.h':\n"
79
+ " void test(double x, double y, double *z)\n"
80
+ "\n"
81
+ "def test_c(double x, double y, double z):\n"
82
+ "\n"
83
+ " test(x, y, &z)\n"
84
+ " return z")
85
+ assert source == expected
86
+
87
+
88
+ def test_cython_wrapper_compile_flags():
89
+ from sympy.core.relational import Equality
90
+ x, y, z = symbols('x,y,z')
91
+ routine = make_routine("test", Equality(z, x + y))
92
+
93
+ code_gen = CythonCodeWrapper(CCodeGen())
94
+
95
+ expected = """\
96
+ from setuptools import setup
97
+ from setuptools import Extension
98
+ from Cython.Build import cythonize
99
+ cy_opts = {'compiler_directives': {'language_level': '3'}}
100
+
101
+ ext_mods = [Extension(
102
+ 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
103
+ include_dirs=[],
104
+ library_dirs=[],
105
+ libraries=[],
106
+ extra_compile_args=['-std=c99'],
107
+ extra_link_args=[]
108
+ )]
109
+ setup(ext_modules=cythonize(ext_mods, **cy_opts))
110
+ """ % {'num': CodeWrapper._module_counter}
111
+
112
+ temp_dir = tempfile.mkdtemp()
113
+ TmpFileManager.tmp_folder(temp_dir)
114
+ setup_file_path = os.path.join(temp_dir, 'setup.py')
115
+
116
+ code_gen._prepare_files(routine, build_dir=temp_dir)
117
+ setup_text = Path(setup_file_path).read_text()
118
+ assert setup_text == expected
119
+
120
+ code_gen = CythonCodeWrapper(CCodeGen(),
121
+ include_dirs=['/usr/local/include', '/opt/booger/include'],
122
+ library_dirs=['/user/local/lib'],
123
+ libraries=['thelib', 'nilib'],
124
+ extra_compile_args=['-slow-math'],
125
+ extra_link_args=['-lswamp', '-ltrident'],
126
+ cythonize_options={'compiler_directives': {'boundscheck': False}}
127
+ )
128
+ expected = """\
129
+ from setuptools import setup
130
+ from setuptools import Extension
131
+ from Cython.Build import cythonize
132
+ cy_opts = {'compiler_directives': {'boundscheck': False}}
133
+
134
+ ext_mods = [Extension(
135
+ 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
136
+ include_dirs=['/usr/local/include', '/opt/booger/include'],
137
+ library_dirs=['/user/local/lib'],
138
+ libraries=['thelib', 'nilib'],
139
+ extra_compile_args=['-slow-math', '-std=c99'],
140
+ extra_link_args=['-lswamp', '-ltrident']
141
+ )]
142
+ setup(ext_modules=cythonize(ext_mods, **cy_opts))
143
+ """ % {'num': CodeWrapper._module_counter}
144
+
145
+ code_gen._prepare_files(routine, build_dir=temp_dir)
146
+ setup_text = Path(setup_file_path).read_text()
147
+ assert setup_text == expected
148
+
149
+ expected = """\
150
+ from setuptools import setup
151
+ from setuptools import Extension
152
+ from Cython.Build import cythonize
153
+ cy_opts = {'compiler_directives': {'boundscheck': False}}
154
+ import numpy as np
155
+
156
+ ext_mods = [Extension(
157
+ 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
158
+ include_dirs=['/usr/local/include', '/opt/booger/include', np.get_include()],
159
+ library_dirs=['/user/local/lib'],
160
+ libraries=['thelib', 'nilib'],
161
+ extra_compile_args=['-slow-math', '-std=c99'],
162
+ extra_link_args=['-lswamp', '-ltrident']
163
+ )]
164
+ setup(ext_modules=cythonize(ext_mods, **cy_opts))
165
+ """ % {'num': CodeWrapper._module_counter}
166
+
167
+ code_gen._need_numpy = True
168
+ code_gen._prepare_files(routine, build_dir=temp_dir)
169
+ setup_text = Path(setup_file_path).read_text()
170
+ assert setup_text == expected
171
+
172
+ TmpFileManager.cleanup()
173
+
174
+ def test_cython_wrapper_unique_dummyvars():
175
+ from sympy.core.relational import Equality
176
+ from sympy.core.symbol import Dummy
177
+ x, y, z = Dummy('x'), Dummy('y'), Dummy('z')
178
+ x_id, y_id, z_id = [str(d.dummy_index) for d in [x, y, z]]
179
+ expr = Equality(z, x + y)
180
+ routine = make_routine("test", expr)
181
+ code_gen = CythonCodeWrapper(CCodeGen())
182
+ source = get_string(code_gen.dump_pyx, [routine])
183
+ expected_template = (
184
+ "cdef extern from 'file.h':\n"
185
+ " void test(double x_{x_id}, double y_{y_id}, double *z_{z_id})\n"
186
+ "\n"
187
+ "def test_c(double x_{x_id}, double y_{y_id}):\n"
188
+ "\n"
189
+ " cdef double z_{z_id} = 0\n"
190
+ " test(x_{x_id}, y_{y_id}, &z_{z_id})\n"
191
+ " return z_{z_id}")
192
+ expected = expected_template.format(x_id=x_id, y_id=y_id, z_id=z_id)
193
+ assert source == expected
194
+
195
+ def test_autowrap_dummy():
196
+ x, y, z = symbols('x y z')
197
+
198
+ # Uses DummyWrapper to test that codegen works as expected
199
+
200
+ f = autowrap(x + y, backend='dummy')
201
+ assert f() == str(x + y)
202
+ assert f.args == "x, y"
203
+ assert f.returns == "nameless"
204
+ f = autowrap(Eq(z, x + y), backend='dummy')
205
+ assert f() == str(x + y)
206
+ assert f.args == "x, y"
207
+ assert f.returns == "z"
208
+ f = autowrap(Eq(z, x + y + z), backend='dummy')
209
+ assert f() == str(x + y + z)
210
+ assert f.args == "x, y, z"
211
+ assert f.returns == "z"
212
+
213
+
214
+ def test_autowrap_args():
215
+ x, y, z = symbols('x y z')
216
+
217
+ raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y),
218
+ backend='dummy', args=[x]))
219
+ f = autowrap(Eq(z, x + y), backend='dummy', args=[y, x])
220
+ assert f() == str(x + y)
221
+ assert f.args == "y, x"
222
+ assert f.returns == "z"
223
+
224
+ raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y + z),
225
+ backend='dummy', args=[x, y]))
226
+ f = autowrap(Eq(z, x + y + z), backend='dummy', args=[y, x, z])
227
+ assert f() == str(x + y + z)
228
+ assert f.args == "y, x, z"
229
+ assert f.returns == "z"
230
+
231
+ f = autowrap(Eq(z, x + y + z), backend='dummy', args=(y, x, z))
232
+ assert f() == str(x + y + z)
233
+ assert f.args == "y, x, z"
234
+ assert f.returns == "z"
235
+
236
+ def test_autowrap_store_files():
237
+ x, y = symbols('x y')
238
+ tmp = tempfile.mkdtemp()
239
+ TmpFileManager.tmp_folder(tmp)
240
+
241
+ f = autowrap(x + y, backend='dummy', tempdir=tmp)
242
+ assert f() == str(x + y)
243
+ assert os.access(tmp, os.F_OK)
244
+
245
+ TmpFileManager.cleanup()
246
+
247
+ def test_autowrap_store_files_issue_gh12939():
248
+ x, y = symbols('x y')
249
+ tmp = './tmp'
250
+ saved_cwd = os.getcwd()
251
+ temp_cwd = tempfile.mkdtemp()
252
+ try:
253
+ os.chdir(temp_cwd)
254
+ f = autowrap(x + y, backend='dummy', tempdir=tmp)
255
+ assert f() == str(x + y)
256
+ assert os.access(tmp, os.F_OK)
257
+ finally:
258
+ os.chdir(saved_cwd)
259
+ shutil.rmtree(temp_cwd)
260
+
261
+
262
+ def test_binary_function():
263
+ x, y = symbols('x y')
264
+ f = binary_function('f', x + y, backend='dummy')
265
+ assert f._imp_() == str(x + y)
266
+
267
+
268
+ def test_ufuncify_source():
269
+ x, y, z = symbols('x,y,z')
270
+ code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
271
+ routine = make_routine("test", x + y + z)
272
+ source = get_string(code_wrapper.dump_c, [routine])
273
+ expected = """\
274
+ #include "Python.h"
275
+ #include "math.h"
276
+ #include "numpy/ndarraytypes.h"
277
+ #include "numpy/ufuncobject.h"
278
+ #include "numpy/halffloat.h"
279
+ #include "file.h"
280
+
281
+ static PyMethodDef wrapper_module_%(num)sMethods[] = {
282
+ {NULL, NULL, 0, NULL}
283
+ };
284
+
285
+ #ifdef NPY_1_19_API_VERSION
286
+ static void test_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
287
+ #else
288
+ static void test_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
289
+ #endif
290
+ {
291
+ npy_intp i;
292
+ npy_intp n = dimensions[0];
293
+ char *in0 = args[0];
294
+ char *in1 = args[1];
295
+ char *in2 = args[2];
296
+ char *out0 = args[3];
297
+ npy_intp in0_step = steps[0];
298
+ npy_intp in1_step = steps[1];
299
+ npy_intp in2_step = steps[2];
300
+ npy_intp out0_step = steps[3];
301
+ for (i = 0; i < n; i++) {
302
+ *((double *)out0) = test(*(double *)in0, *(double *)in1, *(double *)in2);
303
+ in0 += in0_step;
304
+ in1 += in1_step;
305
+ in2 += in2_step;
306
+ out0 += out0_step;
307
+ }
308
+ }
309
+ PyUFuncGenericFunction test_funcs[1] = {&test_ufunc};
310
+ static char test_types[4] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
311
+ static void *test_data[1] = {NULL};
312
+
313
+ #if PY_VERSION_HEX >= 0x03000000
314
+ static struct PyModuleDef moduledef = {
315
+ PyModuleDef_HEAD_INIT,
316
+ "wrapper_module_%(num)s",
317
+ NULL,
318
+ -1,
319
+ wrapper_module_%(num)sMethods,
320
+ NULL,
321
+ NULL,
322
+ NULL,
323
+ NULL
324
+ };
325
+
326
+ PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
327
+ {
328
+ PyObject *m, *d;
329
+ PyObject *ufunc0;
330
+ m = PyModule_Create(&moduledef);
331
+ if (!m) {
332
+ return NULL;
333
+ }
334
+ import_array();
335
+ import_umath();
336
+ d = PyModule_GetDict(m);
337
+ ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
338
+ PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
339
+ PyDict_SetItemString(d, "test", ufunc0);
340
+ Py_DECREF(ufunc0);
341
+ return m;
342
+ }
343
+ #else
344
+ PyMODINIT_FUNC initwrapper_module_%(num)s(void)
345
+ {
346
+ PyObject *m, *d;
347
+ PyObject *ufunc0;
348
+ m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
349
+ if (m == NULL) {
350
+ return;
351
+ }
352
+ import_array();
353
+ import_umath();
354
+ d = PyModule_GetDict(m);
355
+ ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
356
+ PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
357
+ PyDict_SetItemString(d, "test", ufunc0);
358
+ Py_DECREF(ufunc0);
359
+ }
360
+ #endif""" % {'num': CodeWrapper._module_counter}
361
+ assert source == expected
362
+
363
+
364
+ def test_ufuncify_source_multioutput():
365
+ x, y, z = symbols('x,y,z')
366
+ var_symbols = (x, y, z)
367
+ expr = x + y**3 + 10*z**2
368
+ code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
369
+ routines = [make_routine("func{}".format(i), expr.diff(var_symbols[i]), var_symbols) for i in range(len(var_symbols))]
370
+ source = get_string(code_wrapper.dump_c, routines, funcname='multitest')
371
+ expected = """\
372
+ #include "Python.h"
373
+ #include "math.h"
374
+ #include "numpy/ndarraytypes.h"
375
+ #include "numpy/ufuncobject.h"
376
+ #include "numpy/halffloat.h"
377
+ #include "file.h"
378
+
379
+ static PyMethodDef wrapper_module_%(num)sMethods[] = {
380
+ {NULL, NULL, 0, NULL}
381
+ };
382
+
383
+ #ifdef NPY_1_19_API_VERSION
384
+ static void multitest_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
385
+ #else
386
+ static void multitest_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
387
+ #endif
388
+ {
389
+ npy_intp i;
390
+ npy_intp n = dimensions[0];
391
+ char *in0 = args[0];
392
+ char *in1 = args[1];
393
+ char *in2 = args[2];
394
+ char *out0 = args[3];
395
+ char *out1 = args[4];
396
+ char *out2 = args[5];
397
+ npy_intp in0_step = steps[0];
398
+ npy_intp in1_step = steps[1];
399
+ npy_intp in2_step = steps[2];
400
+ npy_intp out0_step = steps[3];
401
+ npy_intp out1_step = steps[4];
402
+ npy_intp out2_step = steps[5];
403
+ for (i = 0; i < n; i++) {
404
+ *((double *)out0) = func0(*(double *)in0, *(double *)in1, *(double *)in2);
405
+ *((double *)out1) = func1(*(double *)in0, *(double *)in1, *(double *)in2);
406
+ *((double *)out2) = func2(*(double *)in0, *(double *)in1, *(double *)in2);
407
+ in0 += in0_step;
408
+ in1 += in1_step;
409
+ in2 += in2_step;
410
+ out0 += out0_step;
411
+ out1 += out1_step;
412
+ out2 += out2_step;
413
+ }
414
+ }
415
+ PyUFuncGenericFunction multitest_funcs[1] = {&multitest_ufunc};
416
+ static char multitest_types[6] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
417
+ static void *multitest_data[1] = {NULL};
418
+
419
+ #if PY_VERSION_HEX >= 0x03000000
420
+ static struct PyModuleDef moduledef = {
421
+ PyModuleDef_HEAD_INIT,
422
+ "wrapper_module_%(num)s",
423
+ NULL,
424
+ -1,
425
+ wrapper_module_%(num)sMethods,
426
+ NULL,
427
+ NULL,
428
+ NULL,
429
+ NULL
430
+ };
431
+
432
+ PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
433
+ {
434
+ PyObject *m, *d;
435
+ PyObject *ufunc0;
436
+ m = PyModule_Create(&moduledef);
437
+ if (!m) {
438
+ return NULL;
439
+ }
440
+ import_array();
441
+ import_umath();
442
+ d = PyModule_GetDict(m);
443
+ ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
444
+ PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
445
+ PyDict_SetItemString(d, "multitest", ufunc0);
446
+ Py_DECREF(ufunc0);
447
+ return m;
448
+ }
449
+ #else
450
+ PyMODINIT_FUNC initwrapper_module_%(num)s(void)
451
+ {
452
+ PyObject *m, *d;
453
+ PyObject *ufunc0;
454
+ m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
455
+ if (m == NULL) {
456
+ return;
457
+ }
458
+ import_array();
459
+ import_umath();
460
+ d = PyModule_GetDict(m);
461
+ ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
462
+ PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
463
+ PyDict_SetItemString(d, "multitest", ufunc0);
464
+ Py_DECREF(ufunc0);
465
+ }
466
+ #endif""" % {'num': CodeWrapper._module_counter}
467
+ assert source == expected
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_octave.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import StringIO
2
+
3
+ from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function
4
+ from sympy.core.relational import Equality
5
+ from sympy.functions.elementary.piecewise import Piecewise
6
+ from sympy.matrices import Matrix, MatrixSymbol
7
+ from sympy.utilities.codegen import OctaveCodeGen, codegen, make_routine
8
+ from sympy.testing.pytest import raises
9
+ from sympy.testing.pytest import XFAIL
10
+ import sympy
11
+
12
+
13
+ x, y, z = symbols('x,y,z')
14
+
15
+
16
+ def test_empty_m_code():
17
+ code_gen = OctaveCodeGen()
18
+ output = StringIO()
19
+ code_gen.dump_m([], output, "file", header=False, empty=False)
20
+ source = output.getvalue()
21
+ assert source == ""
22
+
23
+
24
+ def test_m_simple_code():
25
+ name_expr = ("test", (x + y)*z)
26
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
27
+ assert result[0] == "test.m"
28
+ source = result[1]
29
+ expected = (
30
+ "function out1 = test(x, y, z)\n"
31
+ " out1 = z.*(x + y);\n"
32
+ "end\n"
33
+ )
34
+ assert source == expected
35
+
36
+
37
+ def test_m_simple_code_with_header():
38
+ name_expr = ("test", (x + y)*z)
39
+ result, = codegen(name_expr, "Octave", header=True, empty=False)
40
+ assert result[0] == "test.m"
41
+ source = result[1]
42
+ expected = (
43
+ "function out1 = test(x, y, z)\n"
44
+ " %TEST Autogenerated by SymPy\n"
45
+ " % Code generated with SymPy " + sympy.__version__ + "\n"
46
+ " %\n"
47
+ " % See http://www.sympy.org/ for more information.\n"
48
+ " %\n"
49
+ " % This file is part of 'project'\n"
50
+ " out1 = z.*(x + y);\n"
51
+ "end\n"
52
+ )
53
+ assert source == expected
54
+
55
+
56
+ def test_m_simple_code_nameout():
57
+ expr = Equality(z, (x + y))
58
+ name_expr = ("test", expr)
59
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
60
+ source = result[1]
61
+ expected = (
62
+ "function z = test(x, y)\n"
63
+ " z = x + y;\n"
64
+ "end\n"
65
+ )
66
+ assert source == expected
67
+
68
+
69
+ def test_m_numbersymbol():
70
+ name_expr = ("test", pi**Catalan)
71
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
72
+ source = result[1]
73
+ expected = (
74
+ "function out1 = test()\n"
75
+ " out1 = pi^%s;\n"
76
+ "end\n"
77
+ ) % Catalan.evalf(17)
78
+ assert source == expected
79
+
80
+
81
+ @XFAIL
82
+ def test_m_numbersymbol_no_inline():
83
+ # FIXME: how to pass inline=False to the OctaveCodePrinter?
84
+ name_expr = ("test", [pi**Catalan, EulerGamma])
85
+ result, = codegen(name_expr, "Octave", header=False,
86
+ empty=False, inline=False)
87
+ source = result[1]
88
+ expected = (
89
+ "function [out1, out2] = test()\n"
90
+ " Catalan = 0.915965594177219; % constant\n"
91
+ " EulerGamma = 0.5772156649015329; % constant\n"
92
+ " out1 = pi^Catalan;\n"
93
+ " out2 = EulerGamma;\n"
94
+ "end\n"
95
+ )
96
+ assert source == expected
97
+
98
+
99
+ def test_m_code_argument_order():
100
+ expr = x + y
101
+ routine = make_routine("test", expr, argument_sequence=[z, x, y], language="octave")
102
+ code_gen = OctaveCodeGen()
103
+ output = StringIO()
104
+ code_gen.dump_m([routine], output, "test", header=False, empty=False)
105
+ source = output.getvalue()
106
+ expected = (
107
+ "function out1 = test(z, x, y)\n"
108
+ " out1 = x + y;\n"
109
+ "end\n"
110
+ )
111
+ assert source == expected
112
+
113
+
114
+ def test_multiple_results_m():
115
+ # Here the output order is the input order
116
+ expr1 = (x + y)*z
117
+ expr2 = (x - y)*z
118
+ name_expr = ("test", [expr1, expr2])
119
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
120
+ source = result[1]
121
+ expected = (
122
+ "function [out1, out2] = test(x, y, z)\n"
123
+ " out1 = z.*(x + y);\n"
124
+ " out2 = z.*(x - y);\n"
125
+ "end\n"
126
+ )
127
+ assert source == expected
128
+
129
+
130
+ def test_results_named_unordered():
131
+ # Here output order is based on name_expr
132
+ A, B, C = symbols('A,B,C')
133
+ expr1 = Equality(C, (x + y)*z)
134
+ expr2 = Equality(A, (x - y)*z)
135
+ expr3 = Equality(B, 2*x)
136
+ name_expr = ("test", [expr1, expr2, expr3])
137
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
138
+ source = result[1]
139
+ expected = (
140
+ "function [C, A, B] = test(x, y, z)\n"
141
+ " C = z.*(x + y);\n"
142
+ " A = z.*(x - y);\n"
143
+ " B = 2*x;\n"
144
+ "end\n"
145
+ )
146
+ assert source == expected
147
+
148
+
149
+ def test_results_named_ordered():
150
+ A, B, C = symbols('A,B,C')
151
+ expr1 = Equality(C, (x + y)*z)
152
+ expr2 = Equality(A, (x - y)*z)
153
+ expr3 = Equality(B, 2*x)
154
+ name_expr = ("test", [expr1, expr2, expr3])
155
+ result = codegen(name_expr, "Octave", header=False, empty=False,
156
+ argument_sequence=(x, z, y))
157
+ assert result[0][0] == "test.m"
158
+ source = result[0][1]
159
+ expected = (
160
+ "function [C, A, B] = test(x, z, y)\n"
161
+ " C = z.*(x + y);\n"
162
+ " A = z.*(x - y);\n"
163
+ " B = 2*x;\n"
164
+ "end\n"
165
+ )
166
+ assert source == expected
167
+
168
+
169
+ def test_complicated_m_codegen():
170
+ from sympy.functions.elementary.trigonometric import (cos, sin, tan)
171
+ name_expr = ("testlong",
172
+ [ ((sin(x) + cos(y) + tan(z))**3).expand(),
173
+ cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
174
+ ])
175
+ result = codegen(name_expr, "Octave", header=False, empty=False)
176
+ assert result[0][0] == "testlong.m"
177
+ source = result[0][1]
178
+ expected = (
179
+ "function [out1, out2] = testlong(x, y, z)\n"
180
+ " out1 = sin(x).^3 + 3*sin(x).^2.*cos(y) + 3*sin(x).^2.*tan(z)"
181
+ " + 3*sin(x).*cos(y).^2 + 6*sin(x).*cos(y).*tan(z) + 3*sin(x).*tan(z).^2"
182
+ " + cos(y).^3 + 3*cos(y).^2.*tan(z) + 3*cos(y).*tan(z).^2 + tan(z).^3;\n"
183
+ " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n"
184
+ "end\n"
185
+ )
186
+ assert source == expected
187
+
188
+
189
+ def test_m_output_arg_mixed_unordered():
190
+ # named outputs are alphabetical, unnamed output appear in the given order
191
+ from sympy.functions.elementary.trigonometric import (cos, sin)
192
+ a = symbols("a")
193
+ name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
194
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
195
+ assert result[0] == "foo.m"
196
+ source = result[1]
197
+ expected = (
198
+ 'function [out1, y, out3, a] = foo(x)\n'
199
+ ' out1 = cos(2*x);\n'
200
+ ' y = sin(x);\n'
201
+ ' out3 = cos(x);\n'
202
+ ' a = sin(2*x);\n'
203
+ 'end\n'
204
+ )
205
+ assert source == expected
206
+
207
+
208
+ def test_m_piecewise_():
209
+ pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
210
+ name_expr = ("pwtest", pw)
211
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
212
+ source = result[1]
213
+ expected = (
214
+ "function out1 = pwtest(x)\n"
215
+ " out1 = ((x < -1).*(0) + (~(x < -1)).*( ...\n"
216
+ " (x <= 1).*(x.^2) + (~(x <= 1)).*( ...\n"
217
+ " (x > 1).*(2 - x) + (~(x > 1)).*(1))));\n"
218
+ "end\n"
219
+ )
220
+ assert source == expected
221
+
222
+
223
+ @XFAIL
224
+ def test_m_piecewise_no_inline():
225
+ # FIXME: how to pass inline=False to the OctaveCodePrinter?
226
+ pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
227
+ name_expr = ("pwtest", pw)
228
+ result, = codegen(name_expr, "Octave", header=False, empty=False,
229
+ inline=False)
230
+ source = result[1]
231
+ expected = (
232
+ "function out1 = pwtest(x)\n"
233
+ " if (x < -1)\n"
234
+ " out1 = 0;\n"
235
+ " elseif (x <= 1)\n"
236
+ " out1 = x.^2;\n"
237
+ " elseif (x > 1)\n"
238
+ " out1 = -x + 2;\n"
239
+ " else\n"
240
+ " out1 = 1;\n"
241
+ " end\n"
242
+ "end\n"
243
+ )
244
+ assert source == expected
245
+
246
+
247
+ def test_m_multifcns_per_file():
248
+ name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
249
+ result = codegen(name_expr, "Octave", header=False, empty=False)
250
+ assert result[0][0] == "foo.m"
251
+ source = result[0][1]
252
+ expected = (
253
+ "function [out1, out2] = foo(x, y)\n"
254
+ " out1 = 2*x;\n"
255
+ " out2 = 3*y;\n"
256
+ "end\n"
257
+ "function [out1, out2] = bar(y)\n"
258
+ " out1 = y.^2;\n"
259
+ " out2 = 4*y;\n"
260
+ "end\n"
261
+ )
262
+ assert source == expected
263
+
264
+
265
+ def test_m_multifcns_per_file_w_header():
266
+ name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
267
+ result = codegen(name_expr, "Octave", header=True, empty=False)
268
+ assert result[0][0] == "foo.m"
269
+ source = result[0][1]
270
+ expected = (
271
+ "function [out1, out2] = foo(x, y)\n"
272
+ " %FOO Autogenerated by SymPy\n"
273
+ " % Code generated with SymPy " + sympy.__version__ + "\n"
274
+ " %\n"
275
+ " % See http://www.sympy.org/ for more information.\n"
276
+ " %\n"
277
+ " % This file is part of 'project'\n"
278
+ " out1 = 2*x;\n"
279
+ " out2 = 3*y;\n"
280
+ "end\n"
281
+ "function [out1, out2] = bar(y)\n"
282
+ " out1 = y.^2;\n"
283
+ " out2 = 4*y;\n"
284
+ "end\n"
285
+ )
286
+ assert source == expected
287
+
288
+
289
+ def test_m_filename_match_first_fcn():
290
+ name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
291
+ raises(ValueError, lambda: codegen(name_expr,
292
+ "Octave", prefix="bar", header=False, empty=False))
293
+
294
+
295
+ def test_m_matrix_named():
296
+ e2 = Matrix([[x, 2*y, pi*z]])
297
+ name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2))
298
+ result = codegen(name_expr, "Octave", header=False, empty=False)
299
+ assert result[0][0] == "test.m"
300
+ source = result[0][1]
301
+ expected = (
302
+ "function myout1 = test(x, y, z)\n"
303
+ " myout1 = [x 2*y pi*z];\n"
304
+ "end\n"
305
+ )
306
+ assert source == expected
307
+
308
+
309
+ def test_m_matrix_named_matsym():
310
+ myout1 = MatrixSymbol('myout1', 1, 3)
311
+ e2 = Matrix([[x, 2*y, pi*z]])
312
+ name_expr = ("test", Equality(myout1, e2, evaluate=False))
313
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
314
+ source = result[1]
315
+ expected = (
316
+ "function myout1 = test(x, y, z)\n"
317
+ " myout1 = [x 2*y pi*z];\n"
318
+ "end\n"
319
+ )
320
+ assert source == expected
321
+
322
+
323
+ def test_m_matrix_output_autoname():
324
+ expr = Matrix([[x, x+y, 3]])
325
+ name_expr = ("test", expr)
326
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
327
+ source = result[1]
328
+ expected = (
329
+ "function out1 = test(x, y)\n"
330
+ " out1 = [x x + y 3];\n"
331
+ "end\n"
332
+ )
333
+ assert source == expected
334
+
335
+
336
+ def test_m_matrix_output_autoname_2():
337
+ e1 = (x + y)
338
+ e2 = Matrix([[2*x, 2*y, 2*z]])
339
+ e3 = Matrix([[x], [y], [z]])
340
+ e4 = Matrix([[x, y], [z, 16]])
341
+ name_expr = ("test", (e1, e2, e3, e4))
342
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
343
+ source = result[1]
344
+ expected = (
345
+ "function [out1, out2, out3, out4] = test(x, y, z)\n"
346
+ " out1 = x + y;\n"
347
+ " out2 = [2*x 2*y 2*z];\n"
348
+ " out3 = [x; y; z];\n"
349
+ " out4 = [x y; z 16];\n"
350
+ "end\n"
351
+ )
352
+ assert source == expected
353
+
354
+
355
+ def test_m_results_matrix_named_ordered():
356
+ B, C = symbols('B,C')
357
+ A = MatrixSymbol('A', 1, 3)
358
+ expr1 = Equality(C, (x + y)*z)
359
+ expr2 = Equality(A, Matrix([[1, 2, x]]))
360
+ expr3 = Equality(B, 2*x)
361
+ name_expr = ("test", [expr1, expr2, expr3])
362
+ result, = codegen(name_expr, "Octave", header=False, empty=False,
363
+ argument_sequence=(x, z, y))
364
+ source = result[1]
365
+ expected = (
366
+ "function [C, A, B] = test(x, z, y)\n"
367
+ " C = z.*(x + y);\n"
368
+ " A = [1 2 x];\n"
369
+ " B = 2*x;\n"
370
+ "end\n"
371
+ )
372
+ assert source == expected
373
+
374
+
375
+ def test_m_matrixsymbol_slice():
376
+ A = MatrixSymbol('A', 2, 3)
377
+ B = MatrixSymbol('B', 1, 3)
378
+ C = MatrixSymbol('C', 1, 3)
379
+ D = MatrixSymbol('D', 2, 1)
380
+ name_expr = ("test", [Equality(B, A[0, :]),
381
+ Equality(C, A[1, :]),
382
+ Equality(D, A[:, 2])])
383
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
384
+ source = result[1]
385
+ expected = (
386
+ "function [B, C, D] = test(A)\n"
387
+ " B = A(1, :);\n"
388
+ " C = A(2, :);\n"
389
+ " D = A(:, 3);\n"
390
+ "end\n"
391
+ )
392
+ assert source == expected
393
+
394
+
395
+ def test_m_matrixsymbol_slice2():
396
+ A = MatrixSymbol('A', 3, 4)
397
+ B = MatrixSymbol('B', 2, 2)
398
+ C = MatrixSymbol('C', 2, 2)
399
+ name_expr = ("test", [Equality(B, A[0:2, 0:2]),
400
+ Equality(C, A[0:2, 1:3])])
401
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
402
+ source = result[1]
403
+ expected = (
404
+ "function [B, C] = test(A)\n"
405
+ " B = A(1:2, 1:2);\n"
406
+ " C = A(1:2, 2:3);\n"
407
+ "end\n"
408
+ )
409
+ assert source == expected
410
+
411
+
412
+ def test_m_matrixsymbol_slice3():
413
+ A = MatrixSymbol('A', 8, 7)
414
+ B = MatrixSymbol('B', 2, 2)
415
+ C = MatrixSymbol('C', 4, 2)
416
+ name_expr = ("test", [Equality(B, A[6:, 1::3]),
417
+ Equality(C, A[::2, ::3])])
418
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
419
+ source = result[1]
420
+ expected = (
421
+ "function [B, C] = test(A)\n"
422
+ " B = A(7:end, 2:3:end);\n"
423
+ " C = A(1:2:end, 1:3:end);\n"
424
+ "end\n"
425
+ )
426
+ assert source == expected
427
+
428
+
429
+ def test_m_matrixsymbol_slice_autoname():
430
+ A = MatrixSymbol('A', 2, 3)
431
+ B = MatrixSymbol('B', 1, 3)
432
+ name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]])
433
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
434
+ source = result[1]
435
+ expected = (
436
+ "function [B, out2, out3, out4] = test(A)\n"
437
+ " B = A(1, :);\n"
438
+ " out2 = A(2, :);\n"
439
+ " out3 = A(:, 1);\n"
440
+ " out4 = A(:, 2);\n"
441
+ "end\n"
442
+ )
443
+ assert source == expected
444
+
445
+
446
+ def test_m_loops():
447
+ # Note: an Octave programmer would probably vectorize this across one or
448
+ # more dimensions. Also, size(A) would be used rather than passing in m
449
+ # and n. Perhaps users would expect us to vectorize automatically here?
450
+ # Or is it possible to represent such things using IndexedBase?
451
+ from sympy.tensor import IndexedBase, Idx
452
+ from sympy.core.symbol import symbols
453
+ n, m = symbols('n m', integer=True)
454
+ A = IndexedBase('A')
455
+ x = IndexedBase('x')
456
+ y = IndexedBase('y')
457
+ i = Idx('i', m)
458
+ j = Idx('j', n)
459
+ result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Octave",
460
+ header=False, empty=False)
461
+ source = result[1]
462
+ expected = (
463
+ 'function y = mat_vec_mult(A, m, n, x)\n'
464
+ ' for i = 1:m\n'
465
+ ' y(i) = 0;\n'
466
+ ' end\n'
467
+ ' for i = 1:m\n'
468
+ ' for j = 1:n\n'
469
+ ' y(i) = %(rhs)s + y(i);\n'
470
+ ' end\n'
471
+ ' end\n'
472
+ 'end\n'
473
+ )
474
+ assert (source == expected % {'rhs': 'A(%s, %s).*x(j)' % (i, j)} or
475
+ source == expected % {'rhs': 'x(j).*A(%s, %s)' % (i, j)})
476
+
477
+
478
+ def test_m_tensor_loops_multiple_contractions():
479
+ # see comments in previous test about vectorizing
480
+ from sympy.tensor import IndexedBase, Idx
481
+ from sympy.core.symbol import symbols
482
+ n, m, o, p = symbols('n m o p', integer=True)
483
+ A = IndexedBase('A')
484
+ B = IndexedBase('B')
485
+ y = IndexedBase('y')
486
+ i = Idx('i', m)
487
+ j = Idx('j', n)
488
+ k = Idx('k', o)
489
+ l = Idx('l', p)
490
+ result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])),
491
+ "Octave", header=False, empty=False)
492
+ source = result[1]
493
+ expected = (
494
+ 'function y = tensorthing(A, B, m, n, o, p)\n'
495
+ ' for i = 1:m\n'
496
+ ' y(i) = 0;\n'
497
+ ' end\n'
498
+ ' for i = 1:m\n'
499
+ ' for j = 1:n\n'
500
+ ' for k = 1:o\n'
501
+ ' for l = 1:p\n'
502
+ ' y(i) = A(i, j, k, l).*B(j, k, l) + y(i);\n'
503
+ ' end\n'
504
+ ' end\n'
505
+ ' end\n'
506
+ ' end\n'
507
+ 'end\n'
508
+ )
509
+ assert source == expected
510
+
511
+
512
+ def test_m_InOutArgument():
513
+ expr = Equality(x, x**2)
514
+ name_expr = ("mysqr", expr)
515
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
516
+ source = result[1]
517
+ expected = (
518
+ "function x = mysqr(x)\n"
519
+ " x = x.^2;\n"
520
+ "end\n"
521
+ )
522
+ assert source == expected
523
+
524
+
525
+ def test_m_InOutArgument_order():
526
+ # can specify the order as (x, y)
527
+ expr = Equality(x, x**2 + y)
528
+ name_expr = ("test", expr)
529
+ result, = codegen(name_expr, "Octave", header=False,
530
+ empty=False, argument_sequence=(x,y))
531
+ source = result[1]
532
+ expected = (
533
+ "function x = test(x, y)\n"
534
+ " x = x.^2 + y;\n"
535
+ "end\n"
536
+ )
537
+ assert source == expected
538
+ # make sure it gives (x, y) not (y, x)
539
+ expr = Equality(x, x**2 + y)
540
+ name_expr = ("test", expr)
541
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
542
+ source = result[1]
543
+ expected = (
544
+ "function x = test(x, y)\n"
545
+ " x = x.^2 + y;\n"
546
+ "end\n"
547
+ )
548
+ assert source == expected
549
+
550
+
551
+ def test_m_not_supported():
552
+ f = Function('f')
553
+ name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
554
+ result, = codegen(name_expr, "Octave", header=False, empty=False)
555
+ source = result[1]
556
+ expected = (
557
+ "function [out1, out2] = test(x)\n"
558
+ " % unsupported: Derivative(f(x), x)\n"
559
+ " % unsupported: zoo\n"
560
+ " out1 = Derivative(f(x), x);\n"
561
+ " out2 = zoo;\n"
562
+ "end\n"
563
+ )
564
+ assert source == expected
565
+
566
+
567
+ def test_global_vars_octave():
568
+ x, y, z, t = symbols("x y z t")
569
+ result = codegen(('f', x*y), "Octave", header=False, empty=False,
570
+ global_vars=(y,))
571
+ source = result[0][1]
572
+ expected = (
573
+ "function out1 = f(x)\n"
574
+ " global y\n"
575
+ " out1 = x.*y;\n"
576
+ "end\n"
577
+ )
578
+ assert source == expected
579
+
580
+ result = codegen(('f', x*y+z), "Octave", header=False, empty=False,
581
+ argument_sequence=(x, y), global_vars=(z, t))
582
+ source = result[0][1]
583
+ expected = (
584
+ "function out1 = f(x, y)\n"
585
+ " global t z\n"
586
+ " out1 = x.*y + z;\n"
587
+ "end\n"
588
+ )
589
+ assert source == expected
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_rust.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import StringIO
2
+
3
+ from sympy.core import S, symbols, pi, Catalan, EulerGamma, Function
4
+ from sympy.core.relational import Equality
5
+ from sympy.functions.elementary.piecewise import Piecewise
6
+ from sympy.utilities.codegen import RustCodeGen, codegen, make_routine
7
+ from sympy.testing.pytest import XFAIL
8
+ import sympy
9
+
10
+
11
+ x, y, z = symbols('x,y,z')
12
+
13
+
14
+ def test_empty_rust_code():
15
+ code_gen = RustCodeGen()
16
+ output = StringIO()
17
+ code_gen.dump_rs([], output, "file", header=False, empty=False)
18
+ source = output.getvalue()
19
+ assert source == ""
20
+
21
+
22
+ def test_simple_rust_code():
23
+ name_expr = ("test", (x + y)*z)
24
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
25
+ assert result[0] == "test.rs"
26
+ source = result[1]
27
+ expected = (
28
+ "fn test(x: f64, y: f64, z: f64) -> f64 {\n"
29
+ " let out1 = z*(x + y);\n"
30
+ " out1\n"
31
+ "}\n"
32
+ )
33
+ assert source == expected
34
+
35
+
36
+ def test_simple_code_with_header():
37
+ name_expr = ("test", (x + y)*z)
38
+ result, = codegen(name_expr, "Rust", header=True, empty=False)
39
+ assert result[0] == "test.rs"
40
+ source = result[1]
41
+ version_str = "Code generated with SymPy %s" % sympy.__version__
42
+ version_line = version_str.center(76).rstrip()
43
+ expected = (
44
+ "/*\n"
45
+ " *%(version_line)s\n"
46
+ " *\n"
47
+ " * See http://www.sympy.org/ for more information.\n"
48
+ " *\n"
49
+ " * This file is part of 'project'\n"
50
+ " */\n"
51
+ "fn test(x: f64, y: f64, z: f64) -> f64 {\n"
52
+ " let out1 = z*(x + y);\n"
53
+ " out1\n"
54
+ "}\n"
55
+ ) % {'version_line': version_line}
56
+ assert source == expected
57
+
58
+
59
+ def test_simple_code_nameout():
60
+ expr = Equality(z, (x + y))
61
+ name_expr = ("test", expr)
62
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
63
+ source = result[1]
64
+ expected = (
65
+ "fn test(x: f64, y: f64) -> f64 {\n"
66
+ " let z = x + y;\n"
67
+ " z\n"
68
+ "}\n"
69
+ )
70
+ assert source == expected
71
+
72
+
73
+ def test_numbersymbol():
74
+ name_expr = ("test", pi**Catalan)
75
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
76
+ source = result[1]
77
+ expected = (
78
+ "fn test() -> f64 {\n"
79
+ " const Catalan: f64 = %s;\n"
80
+ " let out1 = PI.powf(Catalan);\n"
81
+ " out1\n"
82
+ "}\n"
83
+ ) % Catalan.evalf(17)
84
+ assert source == expected
85
+
86
+
87
+ @XFAIL
88
+ def test_numbersymbol_inline():
89
+ # FIXME: how to pass inline to the RustCodePrinter?
90
+ name_expr = ("test", [pi**Catalan, EulerGamma])
91
+ result, = codegen(name_expr, "Rust", header=False,
92
+ empty=False, inline=True)
93
+ source = result[1]
94
+ expected = (
95
+ "fn test() -> (f64, f64) {\n"
96
+ " const Catalan: f64 = %s;\n"
97
+ " const EulerGamma: f64 = %s;\n"
98
+ " let out1 = PI.powf(Catalan);\n"
99
+ " let out2 = EulerGamma);\n"
100
+ " (out1, out2)\n"
101
+ "}\n"
102
+ ) % (Catalan.evalf(17), EulerGamma.evalf(17))
103
+ assert source == expected
104
+
105
+
106
+ def test_argument_order():
107
+ expr = x + y
108
+ routine = make_routine("test", expr, argument_sequence=[z, x, y], language="rust")
109
+ code_gen = RustCodeGen()
110
+ output = StringIO()
111
+ code_gen.dump_rs([routine], output, "test", header=False, empty=False)
112
+ source = output.getvalue()
113
+ expected = (
114
+ "fn test(z: f64, x: f64, y: f64) -> f64 {\n"
115
+ " let out1 = x + y;\n"
116
+ " out1\n"
117
+ "}\n"
118
+ )
119
+ assert source == expected
120
+
121
+
122
+ def test_multiple_results_rust():
123
+ # Here the output order is the input order
124
+ expr1 = (x + y)*z
125
+ expr2 = (x - y)*z
126
+ name_expr = ("test", [expr1, expr2])
127
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
128
+ source = result[1]
129
+ expected = (
130
+ "fn test(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
131
+ " let out1 = z*(x + y);\n"
132
+ " let out2 = z*(x - y);\n"
133
+ " (out1, out2)\n"
134
+ "}\n"
135
+ )
136
+ assert source == expected
137
+
138
+
139
+ def test_results_named_unordered():
140
+ # Here output order is based on name_expr
141
+ A, B, C = symbols('A,B,C')
142
+ expr1 = Equality(C, (x + y)*z)
143
+ expr2 = Equality(A, (x - y)*z)
144
+ expr3 = Equality(B, 2*x)
145
+ name_expr = ("test", [expr1, expr2, expr3])
146
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
147
+ source = result[1]
148
+ expected = (
149
+ "fn test(x: f64, y: f64, z: f64) -> (f64, f64, f64) {\n"
150
+ " let C = z*(x + y);\n"
151
+ " let A = z*(x - y);\n"
152
+ " let B = 2*x;\n"
153
+ " (C, A, B)\n"
154
+ "}\n"
155
+ )
156
+ assert source == expected
157
+
158
+
159
+ def test_results_named_ordered():
160
+ A, B, C = symbols('A,B,C')
161
+ expr1 = Equality(C, (x + y)*z)
162
+ expr2 = Equality(A, (x - y)*z)
163
+ expr3 = Equality(B, 2*x)
164
+ name_expr = ("test", [expr1, expr2, expr3])
165
+ result = codegen(name_expr, "Rust", header=False, empty=False,
166
+ argument_sequence=(x, z, y))
167
+ assert result[0][0] == "test.rs"
168
+ source = result[0][1]
169
+ expected = (
170
+ "fn test(x: f64, z: f64, y: f64) -> (f64, f64, f64) {\n"
171
+ " let C = z*(x + y);\n"
172
+ " let A = z*(x - y);\n"
173
+ " let B = 2*x;\n"
174
+ " (C, A, B)\n"
175
+ "}\n"
176
+ )
177
+ assert source == expected
178
+
179
+
180
+ def test_complicated_rs_codegen():
181
+ from sympy.functions.elementary.trigonometric import (cos, sin, tan)
182
+ name_expr = ("testlong",
183
+ [ ((sin(x) + cos(y) + tan(z))**3).expand(),
184
+ cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
185
+ ])
186
+ result = codegen(name_expr, "Rust", header=False, empty=False)
187
+ assert result[0][0] == "testlong.rs"
188
+ source = result[0][1]
189
+ expected = (
190
+ "fn testlong(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
191
+ " let out1 = x.sin().powi(3) + 3*x.sin().powi(2)*y.cos()"
192
+ " + 3*x.sin().powi(2)*z.tan() + 3*x.sin()*y.cos().powi(2)"
193
+ " + 6*x.sin()*y.cos()*z.tan() + 3*x.sin()*z.tan().powi(2)"
194
+ " + y.cos().powi(3) + 3*y.cos().powi(2)*z.tan()"
195
+ " + 3*y.cos()*z.tan().powi(2) + z.tan().powi(3);\n"
196
+ " let out2 = (x + y + z).cos().cos().cos().cos()"
197
+ ".cos().cos().cos().cos();\n"
198
+ " (out1, out2)\n"
199
+ "}\n"
200
+ )
201
+ assert source == expected
202
+
203
+
204
+ def test_output_arg_mixed_unordered():
205
+ # named outputs are alphabetical, unnamed output appear in the given order
206
+ from sympy.functions.elementary.trigonometric import (cos, sin)
207
+ a = symbols("a")
208
+ name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
209
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
210
+ assert result[0] == "foo.rs"
211
+ source = result[1]
212
+ expected = (
213
+ "fn foo(x: f64) -> (f64, f64, f64, f64) {\n"
214
+ " let out1 = (2*x).cos();\n"
215
+ " let y = x.sin();\n"
216
+ " let out3 = x.cos();\n"
217
+ " let a = (2*x).sin();\n"
218
+ " (out1, y, out3, a)\n"
219
+ "}\n"
220
+ )
221
+ assert source == expected
222
+
223
+
224
+ def test_piecewise_():
225
+ pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
226
+ name_expr = ("pwtest", pw)
227
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
228
+ source = result[1]
229
+ expected = (
230
+ "fn pwtest(x: f64) -> f64 {\n"
231
+ " let out1 = if (x < -1.0) {\n"
232
+ " 0\n"
233
+ " } else if (x <= 1.0) {\n"
234
+ " x.powi(2)\n"
235
+ " } else if (x > 1.0) {\n"
236
+ " 2 - x\n"
237
+ " } else {\n"
238
+ " 1\n"
239
+ " };\n"
240
+ " out1\n"
241
+ "}\n"
242
+ )
243
+ assert source == expected
244
+
245
+
246
+ @XFAIL
247
+ def test_piecewise_inline():
248
+ # FIXME: how to pass inline to the RustCodePrinter?
249
+ pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
250
+ name_expr = ("pwtest", pw)
251
+ result, = codegen(name_expr, "Rust", header=False, empty=False,
252
+ inline=True)
253
+ source = result[1]
254
+ expected = (
255
+ "fn pwtest(x: f64) -> f64 {\n"
256
+ " let out1 = if (x < -1) { 0 } else if (x <= 1) { x.powi(2) }"
257
+ " else if (x > 1) { -x + 2 } else { 1 };\n"
258
+ " out1\n"
259
+ "}\n"
260
+ )
261
+ assert source == expected
262
+
263
+
264
+ def test_multifcns_per_file():
265
+ name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
266
+ result = codegen(name_expr, "Rust", header=False, empty=False)
267
+ assert result[0][0] == "foo.rs"
268
+ source = result[0][1]
269
+ expected = (
270
+ "fn foo(x: f64, y: f64) -> (f64, f64) {\n"
271
+ " let out1 = 2*x;\n"
272
+ " let out2 = 3*y;\n"
273
+ " (out1, out2)\n"
274
+ "}\n"
275
+ "fn bar(y: f64) -> (f64, f64) {\n"
276
+ " let out1 = y.powi(2);\n"
277
+ " let out2 = 4*y;\n"
278
+ " (out1, out2)\n"
279
+ "}\n"
280
+ )
281
+ assert source == expected
282
+
283
+
284
+ def test_multifcns_per_file_w_header():
285
+ name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
286
+ result = codegen(name_expr, "Rust", header=True, empty=False)
287
+ assert result[0][0] == "foo.rs"
288
+ source = result[0][1]
289
+ version_str = "Code generated with SymPy %s" % sympy.__version__
290
+ version_line = version_str.center(76).rstrip()
291
+ expected = (
292
+ "/*\n"
293
+ " *%(version_line)s\n"
294
+ " *\n"
295
+ " * See http://www.sympy.org/ for more information.\n"
296
+ " *\n"
297
+ " * This file is part of 'project'\n"
298
+ " */\n"
299
+ "fn foo(x: f64, y: f64) -> (f64, f64) {\n"
300
+ " let out1 = 2*x;\n"
301
+ " let out2 = 3*y;\n"
302
+ " (out1, out2)\n"
303
+ "}\n"
304
+ "fn bar(y: f64) -> (f64, f64) {\n"
305
+ " let out1 = y.powi(2);\n"
306
+ " let out2 = 4*y;\n"
307
+ " (out1, out2)\n"
308
+ "}\n"
309
+ ) % {'version_line': version_line}
310
+ assert source == expected
311
+
312
+
313
+ def test_filename_match_prefix():
314
+ name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
315
+ result, = codegen(name_expr, "Rust", prefix="baz", header=False,
316
+ empty=False)
317
+ assert result[0] == "baz.rs"
318
+
319
+
320
+ def test_InOutArgument():
321
+ expr = Equality(x, x**2)
322
+ name_expr = ("mysqr", expr)
323
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
324
+ source = result[1]
325
+ expected = (
326
+ "fn mysqr(x: f64) -> f64 {\n"
327
+ " let x = x.powi(2);\n"
328
+ " x\n"
329
+ "}\n"
330
+ )
331
+ assert source == expected
332
+
333
+
334
+ def test_InOutArgument_order():
335
+ # can specify the order as (x, y)
336
+ expr = Equality(x, x**2 + y)
337
+ name_expr = ("test", expr)
338
+ result, = codegen(name_expr, "Rust", header=False,
339
+ empty=False, argument_sequence=(x,y))
340
+ source = result[1]
341
+ expected = (
342
+ "fn test(x: f64, y: f64) -> f64 {\n"
343
+ " let x = x.powi(2) + y;\n"
344
+ " x\n"
345
+ "}\n"
346
+ )
347
+ assert source == expected
348
+ # make sure it gives (x, y) not (y, x)
349
+ expr = Equality(x, x**2 + y)
350
+ name_expr = ("test", expr)
351
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
352
+ source = result[1]
353
+ expected = (
354
+ "fn test(x: f64, y: f64) -> f64 {\n"
355
+ " let x = x.powi(2) + y;\n"
356
+ " x\n"
357
+ "}\n"
358
+ )
359
+ assert source == expected
360
+
361
+
362
+ def test_not_supported():
363
+ f = Function('f')
364
+ name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
365
+ result, = codegen(name_expr, "Rust", header=False, empty=False)
366
+ source = result[1]
367
+ expected = (
368
+ "fn test(x: f64) -> (f64, f64) {\n"
369
+ " // unsupported: Derivative(f(x), x)\n"
370
+ " // unsupported: zoo\n"
371
+ " let out1 = Derivative(f(x), x);\n"
372
+ " let out2 = zoo;\n"
373
+ " (out1, out2)\n"
374
+ "}\n"
375
+ )
376
+ assert source == expected
377
+
378
+
379
+ def test_global_vars_rust():
380
+ x, y, z, t = symbols("x y z t")
381
+ result = codegen(('f', x*y), "Rust", header=False, empty=False,
382
+ global_vars=(y,))
383
+ source = result[0][1]
384
+ expected = (
385
+ "fn f(x: f64) -> f64 {\n"
386
+ " let out1 = x*y;\n"
387
+ " out1\n"
388
+ "}\n"
389
+ )
390
+ assert source == expected
391
+
392
+ result = codegen(('f', x*y+z), "Rust", header=False, empty=False,
393
+ argument_sequence=(x, y), global_vars=(z, t))
394
+ source = result[0][1]
395
+ expected = (
396
+ "fn f(x: f64, y: f64) -> f64 {\n"
397
+ " let out1 = x*y + z;\n"
398
+ " out1\n"
399
+ "}\n"
400
+ )
401
+ assert source == expected
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_deprecated.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.testing.pytest import warns_deprecated_sympy
2
+
3
+ # See https://github.com/sympy/sympy/pull/18095
4
+
5
+ def test_deprecated_utilities():
6
+ with warns_deprecated_sympy():
7
+ import sympy.utilities.pytest # noqa:F401
8
+ with warns_deprecated_sympy():
9
+ import sympy.utilities.runtests # noqa:F401
10
+ with warns_deprecated_sympy():
11
+ import sympy.utilities.randtest # noqa:F401
12
+ with warns_deprecated_sympy():
13
+ import sympy.utilities.tmpfiles # noqa:F401
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_exceptions.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.testing.pytest import raises
2
+ from sympy.utilities.exceptions import sympy_deprecation_warning
3
+
4
+ # Only test exceptions here because the other cases are tested in the
5
+ # warns_deprecated_sympy tests
6
+ def test_sympy_deprecation_warning():
7
+ raises(TypeError, lambda: sympy_deprecation_warning('test',
8
+ deprecated_since_version=1.10,
9
+ active_deprecations_target='active-deprecations'))
10
+
11
+ raises(ValueError, lambda: sympy_deprecation_warning('test',
12
+ deprecated_since_version="1.10", active_deprecations_target='(active-deprecations)='))
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_iterables.py ADDED
@@ -0,0 +1,945 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from textwrap import dedent
2
+ from itertools import islice, product
3
+
4
+ from sympy.core.basic import Basic
5
+ from sympy.core.numbers import Integer
6
+ from sympy.core.sorting import ordered
7
+ from sympy.core.symbol import (Dummy, symbols)
8
+ from sympy.functions.combinatorial.factorials import factorial
9
+ from sympy.matrices.dense import Matrix
10
+ from sympy.combinatorics import RGS_enum, RGS_unrank, Permutation
11
+ from sympy.utilities.iterables import (
12
+ _partition, _set_partitions, binary_partitions, bracelets, capture,
13
+ cartes, common_prefix, common_suffix, connected_components, dict_merge,
14
+ filter_symbols, flatten, generate_bell, generate_derangements,
15
+ generate_involutions, generate_oriented_forest, group, has_dups, ibin,
16
+ iproduct, kbins, minlex, multiset, multiset_combinations,
17
+ multiset_partitions, multiset_permutations, necklaces, numbered_symbols,
18
+ partitions, permutations, postfixes,
19
+ prefixes, reshape, rotate_left, rotate_right, runs, sift,
20
+ strongly_connected_components, subsets, take, topological_sort, unflatten,
21
+ uniq, variations, ordered_partitions, rotations, is_palindromic, iterable,
22
+ NotIterable, multiset_derangements, signed_permutations,
23
+ sequence_partitions, sequence_partitions_empty)
24
+ from sympy.utilities.enumerative import (
25
+ factoring_visitor, multiset_partitions_taocp )
26
+
27
+ from sympy.core.singleton import S
28
+ from sympy.testing.pytest import raises, warns_deprecated_sympy
29
+
30
+ w, x, y, z = symbols('w,x,y,z')
31
+
32
+
33
+ def test_deprecated_iterables():
34
+ from sympy.utilities.iterables import default_sort_key, ordered
35
+ with warns_deprecated_sympy():
36
+ assert list(ordered([y, x])) == [x, y]
37
+ with warns_deprecated_sympy():
38
+ assert sorted([y, x], key=default_sort_key) == [x, y]
39
+
40
+
41
+ def test_is_palindromic():
42
+ assert is_palindromic('')
43
+ assert is_palindromic('x')
44
+ assert is_palindromic('xx')
45
+ assert is_palindromic('xyx')
46
+ assert not is_palindromic('xy')
47
+ assert not is_palindromic('xyzx')
48
+ assert is_palindromic('xxyzzyx', 1)
49
+ assert not is_palindromic('xxyzzyx', 2)
50
+ assert is_palindromic('xxyzzyx', 2, -1)
51
+ assert is_palindromic('xxyzzyx', 2, 6)
52
+ assert is_palindromic('xxyzyx', 1)
53
+ assert not is_palindromic('xxyzyx', 2)
54
+ assert is_palindromic('xxyzyx', 2, 2 + 3)
55
+
56
+
57
+ def test_flatten():
58
+ assert flatten((1, (1,))) == [1, 1]
59
+ assert flatten((x, (x,))) == [x, x]
60
+
61
+ ls = [[(-2, -1), (1, 2)], [(0, 0)]]
62
+
63
+ assert flatten(ls, levels=0) == ls
64
+ assert flatten(ls, levels=1) == [(-2, -1), (1, 2), (0, 0)]
65
+ assert flatten(ls, levels=2) == [-2, -1, 1, 2, 0, 0]
66
+ assert flatten(ls, levels=3) == [-2, -1, 1, 2, 0, 0]
67
+
68
+ raises(ValueError, lambda: flatten(ls, levels=-1))
69
+
70
+ class MyOp(Basic):
71
+ pass
72
+
73
+ assert flatten([MyOp(x, y), z]) == [MyOp(x, y), z]
74
+ assert flatten([MyOp(x, y), z], cls=MyOp) == [x, y, z]
75
+
76
+ assert flatten({1, 11, 2}) == list({1, 11, 2})
77
+
78
+
79
+ def test_iproduct():
80
+ assert list(iproduct()) == [()]
81
+ assert list(iproduct([])) == []
82
+ assert list(iproduct([1,2,3])) == [(1,),(2,),(3,)]
83
+ assert sorted(iproduct([1, 2], [3, 4, 5])) == [
84
+ (1,3),(1,4),(1,5),(2,3),(2,4),(2,5)]
85
+ assert sorted(iproduct([0,1],[0,1],[0,1])) == [
86
+ (0,0,0),(0,0,1),(0,1,0),(0,1,1),(1,0,0),(1,0,1),(1,1,0),(1,1,1)]
87
+ assert iterable(iproduct(S.Integers)) is True
88
+ assert iterable(iproduct(S.Integers, S.Integers)) is True
89
+ assert (3,) in iproduct(S.Integers)
90
+ assert (4, 5) in iproduct(S.Integers, S.Integers)
91
+ assert (1, 2, 3) in iproduct(S.Integers, S.Integers, S.Integers)
92
+ triples = set(islice(iproduct(S.Integers, S.Integers, S.Integers), 1000))
93
+ for n1, n2, n3 in triples:
94
+ assert isinstance(n1, Integer)
95
+ assert isinstance(n2, Integer)
96
+ assert isinstance(n3, Integer)
97
+ for t in set(product(*([range(-2, 3)]*3))):
98
+ assert t in iproduct(S.Integers, S.Integers, S.Integers)
99
+
100
+
101
+ def test_group():
102
+ assert group([]) == []
103
+ assert group([], multiple=False) == []
104
+
105
+ assert group([1]) == [[1]]
106
+ assert group([1], multiple=False) == [(1, 1)]
107
+
108
+ assert group([1, 1]) == [[1, 1]]
109
+ assert group([1, 1], multiple=False) == [(1, 2)]
110
+
111
+ assert group([1, 1, 1]) == [[1, 1, 1]]
112
+ assert group([1, 1, 1], multiple=False) == [(1, 3)]
113
+
114
+ assert group([1, 2, 1]) == [[1], [2], [1]]
115
+ assert group([1, 2, 1], multiple=False) == [(1, 1), (2, 1), (1, 1)]
116
+
117
+ assert group([1, 1, 2, 2, 2, 1, 3, 3]) == [[1, 1], [2, 2, 2], [1], [3, 3]]
118
+ assert group([1, 1, 2, 2, 2, 1, 3, 3], multiple=False) == [(1, 2),
119
+ (2, 3), (1, 1), (3, 2)]
120
+
121
+
122
+ def test_subsets():
123
+ # combinations
124
+ assert list(subsets([1, 2, 3], 0)) == [()]
125
+ assert list(subsets([1, 2, 3], 1)) == [(1,), (2,), (3,)]
126
+ assert list(subsets([1, 2, 3], 2)) == [(1, 2), (1, 3), (2, 3)]
127
+ assert list(subsets([1, 2, 3], 3)) == [(1, 2, 3)]
128
+ l = list(range(4))
129
+ assert list(subsets(l, 0, repetition=True)) == [()]
130
+ assert list(subsets(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
131
+ assert list(subsets(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
132
+ (0, 3), (1, 1), (1, 2),
133
+ (1, 3), (2, 2), (2, 3),
134
+ (3, 3)]
135
+ assert list(subsets(l, 3, repetition=True)) == [(0, 0, 0), (0, 0, 1),
136
+ (0, 0, 2), (0, 0, 3),
137
+ (0, 1, 1), (0, 1, 2),
138
+ (0, 1, 3), (0, 2, 2),
139
+ (0, 2, 3), (0, 3, 3),
140
+ (1, 1, 1), (1, 1, 2),
141
+ (1, 1, 3), (1, 2, 2),
142
+ (1, 2, 3), (1, 3, 3),
143
+ (2, 2, 2), (2, 2, 3),
144
+ (2, 3, 3), (3, 3, 3)]
145
+ assert len(list(subsets(l, 4, repetition=True))) == 35
146
+
147
+ assert list(subsets(l[:2], 3, repetition=False)) == []
148
+ assert list(subsets(l[:2], 3, repetition=True)) == [(0, 0, 0),
149
+ (0, 0, 1),
150
+ (0, 1, 1),
151
+ (1, 1, 1)]
152
+ assert list(subsets([1, 2], repetition=True)) == \
153
+ [(), (1,), (2,), (1, 1), (1, 2), (2, 2)]
154
+ assert list(subsets([1, 2], repetition=False)) == \
155
+ [(), (1,), (2,), (1, 2)]
156
+ assert list(subsets([1, 2, 3], 2)) == \
157
+ [(1, 2), (1, 3), (2, 3)]
158
+ assert list(subsets([1, 2, 3], 2, repetition=True)) == \
159
+ [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
160
+
161
+
162
+ def test_variations():
163
+ # permutations
164
+ l = list(range(4))
165
+ assert list(variations(l, 0, repetition=False)) == [()]
166
+ assert list(variations(l, 1, repetition=False)) == [(0,), (1,), (2,), (3,)]
167
+ assert list(variations(l, 2, repetition=False)) == [(0, 1), (0, 2), (0, 3), (1, 0), (1, 2), (1, 3), (2, 0), (2, 1), (2, 3), (3, 0), (3, 1), (3, 2)]
168
+ assert list(variations(l, 3, repetition=False)) == [(0, 1, 2), (0, 1, 3), (0, 2, 1), (0, 2, 3), (0, 3, 1), (0, 3, 2), (1, 0, 2), (1, 0, 3), (1, 2, 0), (1, 2, 3), (1, 3, 0), (1, 3, 2), (2, 0, 1), (2, 0, 3), (2, 1, 0), (2, 1, 3), (2, 3, 0), (2, 3, 1), (3, 0, 1), (3, 0, 2), (3, 1, 0), (3, 1, 2), (3, 2, 0), (3, 2, 1)]
169
+ assert list(variations(l, 0, repetition=True)) == [()]
170
+ assert list(variations(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
171
+ assert list(variations(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
172
+ (0, 3), (1, 0), (1, 1),
173
+ (1, 2), (1, 3), (2, 0),
174
+ (2, 1), (2, 2), (2, 3),
175
+ (3, 0), (3, 1), (3, 2),
176
+ (3, 3)]
177
+ assert len(list(variations(l, 3, repetition=True))) == 64
178
+ assert len(list(variations(l, 4, repetition=True))) == 256
179
+ assert list(variations(l[:2], 3, repetition=False)) == []
180
+ assert list(variations(l[:2], 3, repetition=True)) == [
181
+ (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
182
+ (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)
183
+ ]
184
+
185
+
186
+ def test_cartes():
187
+ assert list(cartes([1, 2], [3, 4, 5])) == \
188
+ [(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5)]
189
+ assert list(cartes()) == [()]
190
+ assert list(cartes('a')) == [('a',)]
191
+ assert list(cartes('a', repeat=2)) == [('a', 'a')]
192
+ assert list(cartes(list(range(2)))) == [(0,), (1,)]
193
+
194
+
195
+ def test_filter_symbols():
196
+ s = numbered_symbols()
197
+ filtered = filter_symbols(s, symbols("x0 x2 x3"))
198
+ assert take(filtered, 3) == list(symbols("x1 x4 x5"))
199
+
200
+
201
+ def test_numbered_symbols():
202
+ s = numbered_symbols(cls=Dummy)
203
+ assert isinstance(next(s), Dummy)
204
+ assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \
205
+ symbols('C2')
206
+
207
+
208
+ def test_sift():
209
+ assert sift(list(range(5)), lambda _: _ % 2) == {1: [1, 3], 0: [0, 2, 4]}
210
+ assert sift([x, y], lambda _: _.has(x)) == {False: [y], True: [x]}
211
+ assert sift([S.One], lambda _: _.has(x)) == {False: [1]}
212
+ assert sift([0, 1, 2, 3], lambda x: x % 2, binary=True) == (
213
+ [1, 3], [0, 2])
214
+ assert sift([0, 1, 2, 3], lambda x: x % 3 == 1, binary=True) == (
215
+ [1], [0, 2, 3])
216
+ raises(ValueError, lambda:
217
+ sift([0, 1, 2, 3], lambda x: x % 3, binary=True))
218
+
219
+
220
+ def test_take():
221
+ X = numbered_symbols()
222
+
223
+ assert take(X, 5) == list(symbols('x0:5'))
224
+ assert take(X, 5) == list(symbols('x5:10'))
225
+
226
+ assert take([1, 2, 3, 4, 5], 5) == [1, 2, 3, 4, 5]
227
+
228
+
229
+ def test_dict_merge():
230
+ assert dict_merge({}, {1: x, y: z}) == {1: x, y: z}
231
+ assert dict_merge({1: x, y: z}, {}) == {1: x, y: z}
232
+
233
+ assert dict_merge({2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
234
+ assert dict_merge({1: x, y: z}, {2: z}) == {1: x, 2: z, y: z}
235
+
236
+ assert dict_merge({1: y, 2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
237
+ assert dict_merge({1: x, y: z}, {1: y, 2: z}) == {1: y, 2: z, y: z}
238
+
239
+
240
+ def test_prefixes():
241
+ assert list(prefixes([])) == []
242
+ assert list(prefixes([1])) == [[1]]
243
+ assert list(prefixes([1, 2])) == [[1], [1, 2]]
244
+
245
+ assert list(prefixes([1, 2, 3, 4, 5])) == \
246
+ [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
247
+
248
+
249
+ def test_postfixes():
250
+ assert list(postfixes([])) == []
251
+ assert list(postfixes([1])) == [[1]]
252
+ assert list(postfixes([1, 2])) == [[2], [1, 2]]
253
+
254
+ assert list(postfixes([1, 2, 3, 4, 5])) == \
255
+ [[5], [4, 5], [3, 4, 5], [2, 3, 4, 5], [1, 2, 3, 4, 5]]
256
+
257
+
258
+ def test_topological_sort():
259
+ V = [2, 3, 5, 7, 8, 9, 10, 11]
260
+ E = [(7, 11), (7, 8), (5, 11),
261
+ (3, 8), (3, 10), (11, 2),
262
+ (11, 9), (11, 10), (8, 9)]
263
+
264
+ assert topological_sort((V, E)) == [3, 5, 7, 8, 11, 2, 9, 10]
265
+ assert topological_sort((V, E), key=lambda v: -v) == \
266
+ [7, 5, 11, 3, 10, 8, 9, 2]
267
+
268
+ raises(ValueError, lambda: topological_sort((V, E + [(10, 7)])))
269
+
270
+
271
+ def test_strongly_connected_components():
272
+ assert strongly_connected_components(([], [])) == []
273
+ assert strongly_connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
274
+
275
+ V = [1, 2, 3]
276
+ E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
277
+ assert strongly_connected_components((V, E)) == [[1, 2, 3]]
278
+
279
+ V = [1, 2, 3, 4]
280
+ E = [(1, 2), (2, 3), (3, 2), (3, 4)]
281
+ assert strongly_connected_components((V, E)) == [[4], [2, 3], [1]]
282
+
283
+ V = [1, 2, 3, 4]
284
+ E = [(1, 2), (2, 1), (3, 4), (4, 3)]
285
+ assert strongly_connected_components((V, E)) == [[1, 2], [3, 4]]
286
+
287
+
288
+ def test_connected_components():
289
+ assert connected_components(([], [])) == []
290
+ assert connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
291
+
292
+ V = [1, 2, 3]
293
+ E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
294
+ assert connected_components((V, E)) == [[1, 2, 3]]
295
+
296
+ V = [1, 2, 3, 4]
297
+ E = [(1, 2), (2, 3), (3, 2), (3, 4)]
298
+ assert connected_components((V, E)) == [[1, 2, 3, 4]]
299
+
300
+ V = [1, 2, 3, 4]
301
+ E = [(1, 2), (3, 4)]
302
+ assert connected_components((V, E)) == [[1, 2], [3, 4]]
303
+
304
+
305
+ def test_rotate():
306
+ A = [0, 1, 2, 3, 4]
307
+
308
+ assert rotate_left(A, 2) == [2, 3, 4, 0, 1]
309
+ assert rotate_right(A, 1) == [4, 0, 1, 2, 3]
310
+ A = []
311
+ B = rotate_right(A, 1)
312
+ assert B == []
313
+ B.append(1)
314
+ assert A == []
315
+ B = rotate_left(A, 1)
316
+ assert B == []
317
+ B.append(1)
318
+ assert A == []
319
+
320
+
321
+ def test_multiset_partitions():
322
+ A = [0, 1, 2, 3, 4]
323
+
324
+ assert list(multiset_partitions(A, 5)) == [[[0], [1], [2], [3], [4]]]
325
+ assert len(list(multiset_partitions(A, 4))) == 10
326
+ assert len(list(multiset_partitions(A, 3))) == 25
327
+
328
+ assert list(multiset_partitions([1, 1, 1, 2, 2], 2)) == [
329
+ [[1, 1, 1, 2], [2]], [[1, 1, 1], [2, 2]], [[1, 1, 2, 2], [1]],
330
+ [[1, 1, 2], [1, 2]], [[1, 1], [1, 2, 2]]]
331
+
332
+ assert list(multiset_partitions([1, 1, 2, 2], 2)) == [
333
+ [[1, 1, 2], [2]], [[1, 1], [2, 2]], [[1, 2, 2], [1]],
334
+ [[1, 2], [1, 2]]]
335
+
336
+ assert list(multiset_partitions([1, 2, 3, 4], 2)) == [
337
+ [[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]],
338
+ [[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]],
339
+ [[1], [2, 3, 4]]]
340
+
341
+ assert list(multiset_partitions([1, 2, 2], 2)) == [
342
+ [[1, 2], [2]], [[1], [2, 2]]]
343
+
344
+ assert list(multiset_partitions(3)) == [
345
+ [[0, 1, 2]], [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]],
346
+ [[0], [1], [2]]]
347
+ assert list(multiset_partitions(3, 2)) == [
348
+ [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]]]
349
+ assert list(multiset_partitions([1] * 3, 2)) == [[[1], [1, 1]]]
350
+ assert list(multiset_partitions([1] * 3)) == [
351
+ [[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]]
352
+ a = [3, 2, 1]
353
+ assert list(multiset_partitions(a)) == \
354
+ list(multiset_partitions(sorted(a)))
355
+ assert list(multiset_partitions(a, 5)) == []
356
+ assert list(multiset_partitions(a, 1)) == [[[1, 2, 3]]]
357
+ assert list(multiset_partitions(a + [4], 5)) == []
358
+ assert list(multiset_partitions(a + [4], 1)) == [[[1, 2, 3, 4]]]
359
+ assert list(multiset_partitions(2, 5)) == []
360
+ assert list(multiset_partitions(2, 1)) == [[[0, 1]]]
361
+ assert list(multiset_partitions('a')) == [[['a']]]
362
+ assert list(multiset_partitions('a', 2)) == []
363
+ assert list(multiset_partitions('ab')) == [[['a', 'b']], [['a'], ['b']]]
364
+ assert list(multiset_partitions('ab', 1)) == [[['a', 'b']]]
365
+ assert list(multiset_partitions('aaa', 1)) == [['aaa']]
366
+ assert list(multiset_partitions([1, 1], 1)) == [[[1, 1]]]
367
+ ans = [('mpsyy',), ('mpsy', 'y'), ('mps', 'yy'), ('mps', 'y', 'y'),
368
+ ('mpyy', 's'), ('mpy', 'sy'), ('mpy', 's', 'y'), ('mp', 'syy'),
369
+ ('mp', 'sy', 'y'), ('mp', 's', 'yy'), ('mp', 's', 'y', 'y'),
370
+ ('msyy', 'p'), ('msy', 'py'), ('msy', 'p', 'y'), ('ms', 'pyy'),
371
+ ('ms', 'py', 'y'), ('ms', 'p', 'yy'), ('ms', 'p', 'y', 'y'),
372
+ ('myy', 'ps'), ('myy', 'p', 's'), ('my', 'psy'), ('my', 'ps', 'y'),
373
+ ('my', 'py', 's'), ('my', 'p', 'sy'), ('my', 'p', 's', 'y'),
374
+ ('m', 'psyy'), ('m', 'psy', 'y'), ('m', 'ps', 'yy'),
375
+ ('m', 'ps', 'y', 'y'), ('m', 'pyy', 's'), ('m', 'py', 'sy'),
376
+ ('m', 'py', 's', 'y'), ('m', 'p', 'syy'),
377
+ ('m', 'p', 'sy', 'y'), ('m', 'p', 's', 'yy'),
378
+ ('m', 'p', 's', 'y', 'y')]
379
+ assert [tuple("".join(part) for part in p)
380
+ for p in multiset_partitions('sympy')] == ans
381
+ factorings = [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3],
382
+ [6, 2, 2], [2, 2, 2, 3]]
383
+ assert [factoring_visitor(p, [2,3]) for
384
+ p in multiset_partitions_taocp([3, 1])] == factorings
385
+
386
+
387
+ def test_multiset_combinations():
388
+ ans = ['iii', 'iim', 'iip', 'iis', 'imp', 'ims', 'ipp', 'ips',
389
+ 'iss', 'mpp', 'mps', 'mss', 'pps', 'pss', 'sss']
390
+ assert [''.join(i) for i in
391
+ list(multiset_combinations('mississippi', 3))] == ans
392
+ M = multiset('mississippi')
393
+ assert [''.join(i) for i in
394
+ list(multiset_combinations(M, 3))] == ans
395
+ assert [''.join(i) for i in multiset_combinations(M, 30)] == []
396
+ assert list(multiset_combinations([[1], [2, 3]], 2)) == [[[1], [2, 3]]]
397
+ assert len(list(multiset_combinations('a', 3))) == 0
398
+ assert len(list(multiset_combinations('a', 0))) == 1
399
+ assert list(multiset_combinations('abc', 1)) == [['a'], ['b'], ['c']]
400
+ raises(ValueError, lambda: list(multiset_combinations({0: 3, 1: -1}, 2)))
401
+
402
+
403
+ def test_multiset_permutations():
404
+ ans = ['abby', 'abyb', 'aybb', 'baby', 'bayb', 'bbay', 'bbya', 'byab',
405
+ 'byba', 'yabb', 'ybab', 'ybba']
406
+ assert [''.join(i) for i in multiset_permutations('baby')] == ans
407
+ assert [''.join(i) for i in multiset_permutations(multiset('baby'))] == ans
408
+ assert list(multiset_permutations([0, 0, 0], 2)) == [[0, 0]]
409
+ assert list(multiset_permutations([0, 2, 1], 2)) == [
410
+ [0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]
411
+ assert len(list(multiset_permutations('a', 0))) == 1
412
+ assert len(list(multiset_permutations('a', 3))) == 0
413
+ for nul in ([], {}, ''):
414
+ assert list(multiset_permutations(nul)) == [[]]
415
+ assert list(multiset_permutations(nul, 0)) == [[]]
416
+ # impossible requests give no result
417
+ assert list(multiset_permutations(nul, 1)) == []
418
+ assert list(multiset_permutations(nul, -1)) == []
419
+
420
+ def test():
421
+ for i in range(1, 7):
422
+ print(i)
423
+ for p in multiset_permutations([0, 0, 1, 0, 1], i):
424
+ print(p)
425
+ assert capture(lambda: test()) == dedent('''\
426
+ 1
427
+ [0]
428
+ [1]
429
+ 2
430
+ [0, 0]
431
+ [0, 1]
432
+ [1, 0]
433
+ [1, 1]
434
+ 3
435
+ [0, 0, 0]
436
+ [0, 0, 1]
437
+ [0, 1, 0]
438
+ [0, 1, 1]
439
+ [1, 0, 0]
440
+ [1, 0, 1]
441
+ [1, 1, 0]
442
+ 4
443
+ [0, 0, 0, 1]
444
+ [0, 0, 1, 0]
445
+ [0, 0, 1, 1]
446
+ [0, 1, 0, 0]
447
+ [0, 1, 0, 1]
448
+ [0, 1, 1, 0]
449
+ [1, 0, 0, 0]
450
+ [1, 0, 0, 1]
451
+ [1, 0, 1, 0]
452
+ [1, 1, 0, 0]
453
+ 5
454
+ [0, 0, 0, 1, 1]
455
+ [0, 0, 1, 0, 1]
456
+ [0, 0, 1, 1, 0]
457
+ [0, 1, 0, 0, 1]
458
+ [0, 1, 0, 1, 0]
459
+ [0, 1, 1, 0, 0]
460
+ [1, 0, 0, 0, 1]
461
+ [1, 0, 0, 1, 0]
462
+ [1, 0, 1, 0, 0]
463
+ [1, 1, 0, 0, 0]
464
+ 6\n''')
465
+ raises(ValueError, lambda: list(multiset_permutations({0: 3, 1: -1})))
466
+
467
+
468
+ def test_partitions():
469
+ ans = [[{}], [(0, {})]]
470
+ for i in range(2):
471
+ assert list(partitions(0, size=i)) == ans[i]
472
+ assert list(partitions(1, 0, size=i)) == ans[i]
473
+ assert list(partitions(6, 2, 2, size=i)) == ans[i]
474
+ assert list(partitions(6, 2, None, size=i)) != ans[i]
475
+ assert list(partitions(6, None, 2, size=i)) != ans[i]
476
+ assert list(partitions(6, 2, 0, size=i)) == ans[i]
477
+
478
+ assert list(partitions(6, k=2)) == [
479
+ {2: 3}, {1: 2, 2: 2}, {1: 4, 2: 1}, {1: 6}]
480
+
481
+ assert list(partitions(6, k=3)) == [
482
+ {3: 2}, {1: 1, 2: 1, 3: 1}, {1: 3, 3: 1}, {2: 3}, {1: 2, 2: 2},
483
+ {1: 4, 2: 1}, {1: 6}]
484
+
485
+ assert list(partitions(8, k=4, m=3)) == [
486
+ {4: 2}, {1: 1, 3: 1, 4: 1}, {2: 2, 4: 1}, {2: 1, 3: 2}] == [
487
+ i for i in partitions(8, k=4, m=3) if all(k <= 4 for k in i)
488
+ and sum(i.values()) <=3]
489
+
490
+ assert list(partitions(S(3), m=2)) == [
491
+ {3: 1}, {1: 1, 2: 1}]
492
+
493
+ assert list(partitions(4, k=3)) == [
494
+ {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}] == [
495
+ i for i in partitions(4) if all(k <= 3 for k in i)]
496
+
497
+
498
+ # Consistency check on output of _partitions and RGS_unrank.
499
+ # This provides a sanity test on both routines. Also verifies that
500
+ # the total number of partitions is the same in each case.
501
+ # (from pkrathmann2)
502
+
503
+ for n in range(2, 6):
504
+ i = 0
505
+ for m, q in _set_partitions(n):
506
+ assert q == RGS_unrank(i, n)
507
+ i += 1
508
+ assert i == RGS_enum(n)
509
+
510
+
511
+ def test_binary_partitions():
512
+ assert [i[:] for i in binary_partitions(10)] == [[8, 2], [8, 1, 1],
513
+ [4, 4, 2], [4, 4, 1, 1], [4, 2, 2, 2], [4, 2, 2, 1, 1],
514
+ [4, 2, 1, 1, 1, 1], [4, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2],
515
+ [2, 2, 2, 2, 1, 1], [2, 2, 2, 1, 1, 1, 1], [2, 2, 1, 1, 1, 1, 1, 1],
516
+ [2, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
517
+
518
+ assert len([j[:] for j in binary_partitions(16)]) == 36
519
+
520
+
521
+ def test_bell_perm():
522
+ assert [len(set(generate_bell(i))) for i in range(1, 7)] == [
523
+ factorial(i) for i in range(1, 7)]
524
+ assert list(generate_bell(3)) == [
525
+ (0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)]
526
+ # generate_bell and trotterjohnson are advertised to return the same
527
+ # permutations; this is not technically necessary so this test could
528
+ # be removed
529
+ for n in range(1, 5):
530
+ p = Permutation(range(n))
531
+ b = generate_bell(n)
532
+ for bi in b:
533
+ assert bi == tuple(p.array_form)
534
+ p = p.next_trotterjohnson()
535
+ raises(ValueError, lambda: list(generate_bell(0))) # XXX is this consistent with other permutation algorithms?
536
+
537
+
538
+ def test_involutions():
539
+ lengths = [1, 2, 4, 10, 26, 76]
540
+ for n, N in enumerate(lengths):
541
+ i = list(generate_involutions(n + 1))
542
+ assert len(i) == N
543
+ assert len({Permutation(j)**2 for j in i}) == 1
544
+
545
+
546
+ def test_derangements():
547
+ assert len(list(generate_derangements(list(range(6))))) == 265
548
+ assert ''.join(''.join(i) for i in generate_derangements('abcde')) == (
549
+ 'badecbaecdbcaedbcdeabceadbdaecbdeacbdecabeacdbedacbedcacabedcadebcaebd'
550
+ 'cdaebcdbeacdeabcdebaceabdcebadcedabcedbadabecdaebcdaecbdcaebdcbeadceab'
551
+ 'dcebadeabcdeacbdebacdebcaeabcdeadbceadcbecabdecbadecdabecdbaedabcedacb'
552
+ 'edbacedbca')
553
+ assert list(generate_derangements([0, 1, 2, 3])) == [
554
+ [1, 0, 3, 2], [1, 2, 3, 0], [1, 3, 0, 2], [2, 0, 3, 1],
555
+ [2, 3, 0, 1], [2, 3, 1, 0], [3, 0, 1, 2], [3, 2, 0, 1], [3, 2, 1, 0]]
556
+ assert list(generate_derangements([0, 1, 2, 2])) == [
557
+ [2, 2, 0, 1], [2, 2, 1, 0]]
558
+ assert list(generate_derangements('ba')) == [list('ab')]
559
+ # multiset_derangements
560
+ D = multiset_derangements
561
+ assert list(D('abb')) == []
562
+ assert [''.join(i) for i in D('ab')] == ['ba']
563
+ assert [''.join(i) for i in D('abc')] == ['bca', 'cab']
564
+ assert [''.join(i) for i in D('aabb')] == ['bbaa']
565
+ assert [''.join(i) for i in D('aabbcccc')] == [
566
+ 'ccccaabb', 'ccccabab', 'ccccabba', 'ccccbaab', 'ccccbaba',
567
+ 'ccccbbaa']
568
+ assert [''.join(i) for i in D('aabbccc')] == [
569
+ 'cccabba', 'cccabab', 'cccaabb', 'ccacbba', 'ccacbab',
570
+ 'ccacabb', 'cbccbaa', 'cbccaba', 'cbccaab', 'bcccbaa',
571
+ 'bcccaba', 'bcccaab']
572
+ assert [''.join(i) for i in D('books')] == ['kbsoo', 'ksboo',
573
+ 'sbkoo', 'skboo', 'oksbo', 'oskbo', 'okbso', 'obkso', 'oskob',
574
+ 'oksob', 'osbok', 'obsok']
575
+ assert list(generate_derangements([[3], [2], [2], [1]])) == [
576
+ [[2], [1], [3], [2]], [[2], [3], [1], [2]]]
577
+
578
+
579
+ def test_necklaces():
580
+ def count(n, k, f):
581
+ return len(list(necklaces(n, k, f)))
582
+ m = []
583
+ for i in range(1, 8):
584
+ m.append((
585
+ i, count(i, 2, 0), count(i, 2, 1), count(i, 3, 1)))
586
+ assert Matrix(m) == Matrix([
587
+ [1, 2, 2, 3],
588
+ [2, 3, 3, 6],
589
+ [3, 4, 4, 10],
590
+ [4, 6, 6, 21],
591
+ [5, 8, 8, 39],
592
+ [6, 14, 13, 92],
593
+ [7, 20, 18, 198]])
594
+
595
+
596
+ def test_bracelets():
597
+ bc = list(bracelets(2, 4))
598
+ assert Matrix(bc) == Matrix([
599
+ [0, 0],
600
+ [0, 1],
601
+ [0, 2],
602
+ [0, 3],
603
+ [1, 1],
604
+ [1, 2],
605
+ [1, 3],
606
+ [2, 2],
607
+ [2, 3],
608
+ [3, 3]
609
+ ])
610
+ bc = list(bracelets(4, 2))
611
+ assert Matrix(bc) == Matrix([
612
+ [0, 0, 0, 0],
613
+ [0, 0, 0, 1],
614
+ [0, 0, 1, 1],
615
+ [0, 1, 0, 1],
616
+ [0, 1, 1, 1],
617
+ [1, 1, 1, 1]
618
+ ])
619
+
620
+
621
+ def test_generate_oriented_forest():
622
+ assert list(generate_oriented_forest(5)) == [[0, 1, 2, 3, 4],
623
+ [0, 1, 2, 3, 3], [0, 1, 2, 3, 2], [0, 1, 2, 3, 1], [0, 1, 2, 3, 0],
624
+ [0, 1, 2, 2, 2], [0, 1, 2, 2, 1], [0, 1, 2, 2, 0], [0, 1, 2, 1, 2],
625
+ [0, 1, 2, 1, 1], [0, 1, 2, 1, 0], [0, 1, 2, 0, 1], [0, 1, 2, 0, 0],
626
+ [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 1], [0, 1, 1, 0, 0],
627
+ [0, 1, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]]
628
+ assert len(list(generate_oriented_forest(10))) == 1842
629
+
630
+
631
+ def test_unflatten():
632
+ r = list(range(10))
633
+ assert unflatten(r) == list(zip(r[::2], r[1::2]))
634
+ assert unflatten(r, 5) == [tuple(r[:5]), tuple(r[5:])]
635
+ raises(ValueError, lambda: unflatten(list(range(10)), 3))
636
+ raises(ValueError, lambda: unflatten(list(range(10)), -2))
637
+
638
+
639
+ def test_common_prefix_suffix():
640
+ assert common_prefix([], [1]) == []
641
+ assert common_prefix(list(range(3))) == [0, 1, 2]
642
+ assert common_prefix(list(range(3)), list(range(4))) == [0, 1, 2]
643
+ assert common_prefix([1, 2, 3], [1, 2, 5]) == [1, 2]
644
+ assert common_prefix([1, 2, 3], [1, 3, 5]) == [1]
645
+
646
+ assert common_suffix([], [1]) == []
647
+ assert common_suffix(list(range(3))) == [0, 1, 2]
648
+ assert common_suffix(list(range(3)), list(range(3))) == [0, 1, 2]
649
+ assert common_suffix(list(range(3)), list(range(4))) == []
650
+ assert common_suffix([1, 2, 3], [9, 2, 3]) == [2, 3]
651
+ assert common_suffix([1, 2, 3], [9, 7, 3]) == [3]
652
+
653
+
654
+ def test_minlex():
655
+ assert minlex([1, 2, 0]) == (0, 1, 2)
656
+ assert minlex((1, 2, 0)) == (0, 1, 2)
657
+ assert minlex((1, 0, 2)) == (0, 2, 1)
658
+ assert minlex((1, 0, 2), directed=False) == (0, 1, 2)
659
+ assert minlex('aba') == 'aab'
660
+ assert minlex(('bb', 'aaa', 'c', 'a'), key=len) == ('c', 'a', 'bb', 'aaa')
661
+
662
+
663
+ def test_ordered():
664
+ assert list(ordered((x, y), hash, default=False)) in [[x, y], [y, x]]
665
+ assert list(ordered((x, y), hash, default=False)) == \
666
+ list(ordered((y, x), hash, default=False))
667
+ assert list(ordered((x, y))) == [x, y]
668
+
669
+ seq, keys = [[[1, 2, 1], [0, 3, 1], [1, 1, 3], [2], [1]],
670
+ (lambda x: len(x), lambda x: sum(x))]
671
+ assert list(ordered(seq, keys, default=False, warn=False)) == \
672
+ [[1], [2], [1, 2, 1], [0, 3, 1], [1, 1, 3]]
673
+ raises(ValueError, lambda:
674
+ list(ordered(seq, keys, default=False, warn=True)))
675
+
676
+
677
+ def test_runs():
678
+ assert runs([]) == []
679
+ assert runs([1]) == [[1]]
680
+ assert runs([1, 1]) == [[1], [1]]
681
+ assert runs([1, 1, 2]) == [[1], [1, 2]]
682
+ assert runs([1, 2, 1]) == [[1, 2], [1]]
683
+ assert runs([2, 1, 1]) == [[2], [1], [1]]
684
+ from operator import lt
685
+ assert runs([2, 1, 1], lt) == [[2, 1], [1]]
686
+
687
+
688
+ def test_reshape():
689
+ seq = list(range(1, 9))
690
+ assert reshape(seq, [4]) == \
691
+ [[1, 2, 3, 4], [5, 6, 7, 8]]
692
+ assert reshape(seq, (4,)) == \
693
+ [(1, 2, 3, 4), (5, 6, 7, 8)]
694
+ assert reshape(seq, (2, 2)) == \
695
+ [(1, 2, 3, 4), (5, 6, 7, 8)]
696
+ assert reshape(seq, (2, [2])) == \
697
+ [(1, 2, [3, 4]), (5, 6, [7, 8])]
698
+ assert reshape(seq, ((2,), [2])) == \
699
+ [((1, 2), [3, 4]), ((5, 6), [7, 8])]
700
+ assert reshape(seq, (1, [2], 1)) == \
701
+ [(1, [2, 3], 4), (5, [6, 7], 8)]
702
+ assert reshape(tuple(seq), ([[1], 1, (2,)],)) == \
703
+ (([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],))
704
+ assert reshape(tuple(seq), ([1], 1, (2,))) == \
705
+ (([1], 2, (3, 4)), ([5], 6, (7, 8)))
706
+ assert reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)]) == \
707
+ [[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]]
708
+ raises(ValueError, lambda: reshape([0, 1], [-1]))
709
+ raises(ValueError, lambda: reshape([0, 1], [3]))
710
+
711
+
712
+ def test_uniq():
713
+ assert list(uniq(p for p in partitions(4))) == \
714
+ [{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}]
715
+ assert list(uniq(x % 2 for x in range(5))) == [0, 1]
716
+ assert list(uniq('a')) == ['a']
717
+ assert list(uniq('ababc')) == list('abc')
718
+ assert list(uniq([[1], [2, 1], [1]])) == [[1], [2, 1]]
719
+ assert list(uniq(permutations(i for i in [[1], 2, 2]))) == \
720
+ [([1], 2, 2), (2, [1], 2), (2, 2, [1])]
721
+ assert list(uniq([2, 3, 2, 4, [2], [1], [2], [3], [1]])) == \
722
+ [2, 3, 4, [2], [1], [3]]
723
+ f = [1]
724
+ raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
725
+ f = [[1]]
726
+ raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
727
+
728
+
729
+ def test_kbins():
730
+ assert len(list(kbins('1123', 2, ordered=1))) == 24
731
+ assert len(list(kbins('1123', 2, ordered=11))) == 36
732
+ assert len(list(kbins('1123', 2, ordered=10))) == 10
733
+ assert len(list(kbins('1123', 2, ordered=0))) == 5
734
+ assert len(list(kbins('1123', 2, ordered=None))) == 3
735
+
736
+ def test1():
737
+ for orderedval in [None, 0, 1, 10, 11]:
738
+ print('ordered =', orderedval)
739
+ for p in kbins([0, 0, 1], 2, ordered=orderedval):
740
+ print(' ', p)
741
+ assert capture(lambda : test1()) == dedent('''\
742
+ ordered = None
743
+ [[0], [0, 1]]
744
+ [[0, 0], [1]]
745
+ ordered = 0
746
+ [[0, 0], [1]]
747
+ [[0, 1], [0]]
748
+ ordered = 1
749
+ [[0], [0, 1]]
750
+ [[0], [1, 0]]
751
+ [[1], [0, 0]]
752
+ ordered = 10
753
+ [[0, 0], [1]]
754
+ [[1], [0, 0]]
755
+ [[0, 1], [0]]
756
+ [[0], [0, 1]]
757
+ ordered = 11
758
+ [[0], [0, 1]]
759
+ [[0, 0], [1]]
760
+ [[0], [1, 0]]
761
+ [[0, 1], [0]]
762
+ [[1], [0, 0]]
763
+ [[1, 0], [0]]\n''')
764
+
765
+ def test2():
766
+ for orderedval in [None, 0, 1, 10, 11]:
767
+ print('ordered =', orderedval)
768
+ for p in kbins(list(range(3)), 2, ordered=orderedval):
769
+ print(' ', p)
770
+ assert capture(lambda : test2()) == dedent('''\
771
+ ordered = None
772
+ [[0], [1, 2]]
773
+ [[0, 1], [2]]
774
+ ordered = 0
775
+ [[0, 1], [2]]
776
+ [[0, 2], [1]]
777
+ [[0], [1, 2]]
778
+ ordered = 1
779
+ [[0], [1, 2]]
780
+ [[0], [2, 1]]
781
+ [[1], [0, 2]]
782
+ [[1], [2, 0]]
783
+ [[2], [0, 1]]
784
+ [[2], [1, 0]]
785
+ ordered = 10
786
+ [[0, 1], [2]]
787
+ [[2], [0, 1]]
788
+ [[0, 2], [1]]
789
+ [[1], [0, 2]]
790
+ [[0], [1, 2]]
791
+ [[1, 2], [0]]
792
+ ordered = 11
793
+ [[0], [1, 2]]
794
+ [[0, 1], [2]]
795
+ [[0], [2, 1]]
796
+ [[0, 2], [1]]
797
+ [[1], [0, 2]]
798
+ [[1, 0], [2]]
799
+ [[1], [2, 0]]
800
+ [[1, 2], [0]]
801
+ [[2], [0, 1]]
802
+ [[2, 0], [1]]
803
+ [[2], [1, 0]]
804
+ [[2, 1], [0]]\n''')
805
+
806
+
807
+ def test_has_dups():
808
+ assert has_dups(set()) is False
809
+ assert has_dups(list(range(3))) is False
810
+ assert has_dups([1, 2, 1]) is True
811
+ assert has_dups([[1], [1]]) is True
812
+ assert has_dups([[1], [2]]) is False
813
+
814
+
815
+ def test__partition():
816
+ assert _partition('abcde', [1, 0, 1, 2, 0]) == [
817
+ ['b', 'e'], ['a', 'c'], ['d']]
818
+ assert _partition('abcde', [1, 0, 1, 2, 0], 3) == [
819
+ ['b', 'e'], ['a', 'c'], ['d']]
820
+ output = (3, [1, 0, 1, 2, 0])
821
+ assert _partition('abcde', *output) == [['b', 'e'], ['a', 'c'], ['d']]
822
+
823
+
824
+ def test_ordered_partitions():
825
+ from sympy.functions.combinatorial.numbers import nT
826
+ f = ordered_partitions
827
+ assert list(f(0, 1)) == [[]]
828
+ assert list(f(1, 0)) == [[]]
829
+ for i in range(1, 7):
830
+ for j in [None] + list(range(1, i)):
831
+ assert (
832
+ sum(1 for p in f(i, j, 1)) ==
833
+ sum(1 for p in f(i, j, 0)) ==
834
+ nT(i, j))
835
+
836
+
837
+ def test_rotations():
838
+ assert list(rotations('ab')) == [['a', 'b'], ['b', 'a']]
839
+ assert list(rotations(range(3))) == [[0, 1, 2], [1, 2, 0], [2, 0, 1]]
840
+ assert list(rotations(range(3), dir=-1)) == [[0, 1, 2], [2, 0, 1], [1, 2, 0]]
841
+
842
+
843
+ def test_ibin():
844
+ assert ibin(3) == [1, 1]
845
+ assert ibin(3, 3) == [0, 1, 1]
846
+ assert ibin(3, str=True) == '11'
847
+ assert ibin(3, 3, str=True) == '011'
848
+ assert list(ibin(2, 'all')) == [(0, 0), (0, 1), (1, 0), (1, 1)]
849
+ assert list(ibin(2, '', str=True)) == ['00', '01', '10', '11']
850
+ raises(ValueError, lambda: ibin(-.5))
851
+ raises(ValueError, lambda: ibin(2, 1))
852
+
853
+
854
+ def test_iterable():
855
+ assert iterable(0) is False
856
+ assert iterable(1) is False
857
+ assert iterable(None) is False
858
+
859
+ class Test1(NotIterable):
860
+ pass
861
+
862
+ assert iterable(Test1()) is False
863
+
864
+ class Test2(NotIterable):
865
+ _iterable = True
866
+
867
+ assert iterable(Test2()) is True
868
+
869
+ class Test3:
870
+ pass
871
+
872
+ assert iterable(Test3()) is False
873
+
874
+ class Test4:
875
+ _iterable = True
876
+
877
+ assert iterable(Test4()) is True
878
+
879
+ class Test5:
880
+ def __iter__(self):
881
+ yield 1
882
+
883
+ assert iterable(Test5()) is True
884
+
885
+ class Test6(Test5):
886
+ _iterable = False
887
+
888
+ assert iterable(Test6()) is False
889
+
890
+
891
+ def test_sequence_partitions():
892
+ assert list(sequence_partitions([1], 1)) == [[[1]]]
893
+ assert list(sequence_partitions([1, 2], 1)) == [[[1, 2]]]
894
+ assert list(sequence_partitions([1, 2], 2)) == [[[1], [2]]]
895
+ assert list(sequence_partitions([1, 2, 3], 1)) == [[[1, 2, 3]]]
896
+ assert list(sequence_partitions([1, 2, 3], 2)) == \
897
+ [[[1], [2, 3]], [[1, 2], [3]]]
898
+ assert list(sequence_partitions([1, 2, 3], 3)) == [[[1], [2], [3]]]
899
+
900
+ # Exceptional cases
901
+ assert list(sequence_partitions([], 0)) == []
902
+ assert list(sequence_partitions([], 1)) == []
903
+ assert list(sequence_partitions([1, 2], 0)) == []
904
+ assert list(sequence_partitions([1, 2], 3)) == []
905
+
906
+
907
+ def test_sequence_partitions_empty():
908
+ assert list(sequence_partitions_empty([], 1)) == [[[]]]
909
+ assert list(sequence_partitions_empty([], 2)) == [[[], []]]
910
+ assert list(sequence_partitions_empty([], 3)) == [[[], [], []]]
911
+ assert list(sequence_partitions_empty([1], 1)) == [[[1]]]
912
+ assert list(sequence_partitions_empty([1], 2)) == [[[], [1]], [[1], []]]
913
+ assert list(sequence_partitions_empty([1], 3)) == \
914
+ [[[], [], [1]], [[], [1], []], [[1], [], []]]
915
+ assert list(sequence_partitions_empty([1, 2], 1)) == [[[1, 2]]]
916
+ assert list(sequence_partitions_empty([1, 2], 2)) == \
917
+ [[[], [1, 2]], [[1], [2]], [[1, 2], []]]
918
+ assert list(sequence_partitions_empty([1, 2], 3)) == [
919
+ [[], [], [1, 2]], [[], [1], [2]], [[], [1, 2], []],
920
+ [[1], [], [2]], [[1], [2], []], [[1, 2], [], []]
921
+ ]
922
+ assert list(sequence_partitions_empty([1, 2, 3], 1)) == [[[1, 2, 3]]]
923
+ assert list(sequence_partitions_empty([1, 2, 3], 2)) == \
924
+ [[[], [1, 2, 3]], [[1], [2, 3]], [[1, 2], [3]], [[1, 2, 3], []]]
925
+ assert list(sequence_partitions_empty([1, 2, 3], 3)) == [
926
+ [[], [], [1, 2, 3]], [[], [1], [2, 3]],
927
+ [[], [1, 2], [3]], [[], [1, 2, 3], []],
928
+ [[1], [], [2, 3]], [[1], [2], [3]],
929
+ [[1], [2, 3], []], [[1, 2], [], [3]],
930
+ [[1, 2], [3], []], [[1, 2, 3], [], []]
931
+ ]
932
+
933
+ # Exceptional cases
934
+ assert list(sequence_partitions([], 0)) == []
935
+ assert list(sequence_partitions([1], 0)) == []
936
+ assert list(sequence_partitions([1, 2], 0)) == []
937
+
938
+
939
+ def test_signed_permutations():
940
+ ans = [(0, 1, 1), (0, -1, 1), (0, 1, -1), (0, -1, -1),
941
+ (1, 0, 1), (-1, 0, 1), (1, 0, -1), (-1, 0, -1),
942
+ (1, 1, 0), (-1, 1, 0), (1, -1, 0), (-1, -1, 0)]
943
+ assert list(signed_permutations((0, 1, 1))) == ans
944
+ assert list(signed_permutations((1, 0, 1))) == ans
945
+ assert list(signed_permutations((1, 1, 0))) == ans
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_lambdify.py ADDED
@@ -0,0 +1,2263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import product
2
+ import math
3
+ import inspect
4
+ import linecache
5
+ import gc
6
+
7
+ import mpmath
8
+ import cmath
9
+
10
+ from sympy.testing.pytest import raises, warns_deprecated_sympy
11
+ from sympy.concrete.summations import Sum
12
+ from sympy.core.function import (Function, Lambda, diff)
13
+ from sympy.core.numbers import (E, Float, I, Rational, all_close, oo, pi)
14
+ from sympy.core.relational import Eq
15
+ from sympy.core.singleton import S
16
+ from sympy.core.symbol import (Dummy, symbols)
17
+ from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial)
18
+ from sympy.functions.combinatorial.numbers import bernoulli, harmonic
19
+ from sympy.functions.elementary.complexes import Abs, sign
20
+ from sympy.functions.elementary.exponential import exp, log
21
+ from sympy.functions.elementary.hyperbolic import asinh,acosh,atanh
22
+ from sympy.functions.elementary.integers import floor
23
+ from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt)
24
+ from sympy.functions.elementary.piecewise import Piecewise
25
+ from sympy.functions.elementary.trigonometric import (asin, acos, atan, cos, cot, sin,
26
+ sinc, tan)
27
+ from sympy.functions import sinh,cosh,tanh
28
+ from sympy.functions.special.bessel import (besseli, besselj, besselk, bessely, jn, yn)
29
+ from sympy.functions.special.beta_functions import (beta, betainc, betainc_regularized)
30
+ from sympy.functions.special.delta_functions import (Heaviside)
31
+ from sympy.functions.special.error_functions import (Ei, erf, erfc, fresnelc, fresnels, Si, Ci)
32
+ from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma, polygamma)
33
+ from sympy.functions.special.zeta_functions import zeta
34
+ from sympy.integrals.integrals import Integral
35
+ from sympy.logic.boolalg import (And, false, ITE, Not, Or, true)
36
+ from sympy.matrices.expressions.dotproduct import DotProduct
37
+ from sympy.simplify.cse_main import cse
38
+ from sympy.tensor.array import derive_by_array, Array
39
+ from sympy.tensor.array.expressions import ArraySymbol
40
+ from sympy.tensor.indexed import IndexedBase, Idx
41
+ from sympy.utilities.lambdify import lambdify
42
+ from sympy.utilities.iterables import numbered_symbols
43
+ from sympy.vector import CoordSys3D
44
+ from sympy.core.expr import UnevaluatedExpr
45
+ from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, log10, hypot, isnan, isinf
46
+ from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, amin, amax, minimum, maximum
47
+ from sympy.codegen.scipy_nodes import cosm1, powm1
48
+ from sympy.functions.elementary.complexes import re, im, arg
49
+ from sympy.functions.special.polynomials import \
50
+ chebyshevt, chebyshevu, legendre, hermite, laguerre, gegenbauer, \
51
+ assoc_legendre, assoc_laguerre, jacobi
52
+ from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix
53
+ from sympy.printing.codeprinter import PrintMethodNotImplementedError
54
+ from sympy.printing.lambdarepr import LambdaPrinter
55
+ from sympy.printing.numpy import NumPyPrinter
56
+ from sympy.utilities.lambdify import implemented_function, lambdastr
57
+ from sympy.testing.pytest import skip
58
+ from sympy.utilities.decorator import conserve_mpmath_dps
59
+ from sympy.utilities.exceptions import ignore_warnings
60
+ from sympy.external import import_module
61
+ from sympy.functions.special.gamma_functions import uppergamma, lowergamma
62
+
63
+
64
+ import sympy
65
+
66
+
67
+ MutableDenseMatrix = Matrix
68
+
69
+ numpy = import_module('numpy')
70
+ scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']})
71
+ numexpr = import_module('numexpr')
72
+ tensorflow = import_module('tensorflow')
73
+ cupy = import_module('cupy')
74
+ jax = import_module('jax')
75
+ numba = import_module('numba')
76
+
77
+ if tensorflow:
78
+ # Hide Tensorflow warnings
79
+ import os
80
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
81
+
82
+ w, x, y, z = symbols('w,x,y,z')
83
+
84
+ #================== Test different arguments =======================
85
+
86
+
87
+ def test_no_args():
88
+ f = lambdify([], 1)
89
+ raises(TypeError, lambda: f(-1))
90
+ assert f() == 1
91
+
92
+
93
+ def test_single_arg():
94
+ f = lambdify(x, 2*x)
95
+ assert f(1) == 2
96
+
97
+
98
+ def test_list_args():
99
+ f = lambdify([x, y], x + y)
100
+ assert f(1, 2) == 3
101
+
102
+
103
+ def test_nested_args():
104
+ f1 = lambdify([[w, x]], [w, x])
105
+ assert f1([91, 2]) == [91, 2]
106
+ raises(TypeError, lambda: f1(1, 2))
107
+
108
+ f2 = lambdify([(w, x), (y, z)], [w, x, y, z])
109
+ assert f2((18, 12), (73, 4)) == [18, 12, 73, 4]
110
+ raises(TypeError, lambda: f2(3, 4))
111
+
112
+ f3 = lambdify([w, [[[x]], y], z], [w, x, y, z])
113
+ assert f3(10, [[[52]], 31], 44) == [10, 52, 31, 44]
114
+
115
+
116
+ def test_str_args():
117
+ f = lambdify('x,y,z', 'z,y,x')
118
+ assert f(3, 2, 1) == (1, 2, 3)
119
+ assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
120
+ # make sure correct number of args required
121
+ raises(TypeError, lambda: f(0))
122
+
123
+
124
+ def test_own_namespace_1():
125
+ myfunc = lambda x: 1
126
+ f = lambdify(x, sin(x), {"sin": myfunc})
127
+ assert f(0.1) == 1
128
+ assert f(100) == 1
129
+
130
+
131
+ def test_own_namespace_2():
132
+ def myfunc(x):
133
+ return 1
134
+ f = lambdify(x, sin(x), {'sin': myfunc})
135
+ assert f(0.1) == 1
136
+ assert f(100) == 1
137
+
138
+
139
+ def test_own_module():
140
+ f = lambdify(x, sin(x), math)
141
+ assert f(0) == 0.0
142
+
143
+ p, q, r = symbols("p q r", real=True)
144
+ ae = abs(exp(p+UnevaluatedExpr(q+r)))
145
+ f = lambdify([p, q, r], [ae, ae], modules=math)
146
+ results = f(1.0, 1e18, -1e18)
147
+ refvals = [math.exp(1.0)]*2
148
+ for res, ref in zip(results, refvals):
149
+ assert abs((res-ref)/ref) < 1e-15
150
+
151
+
152
+ def test_bad_args():
153
+ # no vargs given
154
+ raises(TypeError, lambda: lambdify(1))
155
+ # same with vector exprs
156
+ raises(TypeError, lambda: lambdify([1, 2]))
157
+
158
+
159
+ def test_atoms():
160
+ # Non-Symbol atoms should not be pulled out from the expression namespace
161
+ f = lambdify(x, pi + x, {"pi": 3.14})
162
+ assert f(0) == 3.14
163
+ f = lambdify(x, I + x, {"I": 1j})
164
+ assert f(1) == 1 + 1j
165
+
166
+ #================== Test different modules =========================
167
+
168
+ # high precision output of sin(0.2*pi) is used to detect if precision is lost unwanted
169
+
170
+
171
+ @conserve_mpmath_dps
172
+ def test_sympy_lambda():
173
+ mpmath.mp.dps = 50
174
+ sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
175
+ f = lambdify(x, sin(x), "sympy")
176
+ assert f(x) == sin(x)
177
+ prec = 1e-15
178
+ assert -prec < f(Rational(1, 5)).evalf() - Float(str(sin02)) < prec
179
+ # arctan is in numpy module and should not be available
180
+ # The arctan below gives NameError. What is this supposed to test?
181
+ # raises(NameError, lambda: lambdify(x, arctan(x), "sympy"))
182
+
183
+
184
+ @conserve_mpmath_dps
185
+ def test_math_lambda():
186
+ mpmath.mp.dps = 50
187
+ sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
188
+ f = lambdify(x, sin(x), "math")
189
+ prec = 1e-15
190
+ assert -prec < f(0.2) - sin02 < prec
191
+ raises(TypeError, lambda: f(x))
192
+ # if this succeeds, it can't be a Python math function
193
+
194
+
195
+ @conserve_mpmath_dps
196
+ def test_mpmath_lambda():
197
+ mpmath.mp.dps = 50
198
+ sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
199
+ f = lambdify(x, sin(x), "mpmath")
200
+ prec = 1e-49 # mpmath precision is around 50 decimal places
201
+ assert -prec < f(mpmath.mpf("0.2")) - sin02 < prec
202
+ raises(TypeError, lambda: f(x))
203
+ # if this succeeds, it can't be a mpmath function
204
+
205
+ ref2 = (mpmath.mpf("1e-30")
206
+ - mpmath.mpf("1e-45")/2
207
+ + 5*mpmath.mpf("1e-60")/6
208
+ - 3*mpmath.mpf("1e-75")/4
209
+ + 33*mpmath.mpf("1e-90")/40
210
+ )
211
+ f2a = lambdify((x, y), x**y - 1, "mpmath")
212
+ f2b = lambdify((x, y), powm1(x, y), "mpmath")
213
+ f2c = lambdify((x,), expm1(x*log1p(x)), "mpmath")
214
+ ans2a = f2a(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15"))
215
+ ans2b = f2b(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15"))
216
+ ans2c = f2c(mpmath.mpf("1e-15"))
217
+ assert abs(ans2a - ref2) < 1e-51
218
+ assert abs(ans2b - ref2) < 1e-67
219
+ assert abs(ans2c - ref2) < 1e-80
220
+
221
+
222
+ @conserve_mpmath_dps
223
+ def test_number_precision():
224
+ mpmath.mp.dps = 50
225
+ sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
226
+ f = lambdify(x, sin02, "mpmath")
227
+ prec = 1e-49 # mpmath precision is around 50 decimal places
228
+ assert -prec < f(0) - sin02 < prec
229
+
230
+ @conserve_mpmath_dps
231
+ def test_mpmath_precision():
232
+ mpmath.mp.dps = 100
233
+ assert str(lambdify((), pi.evalf(100), 'mpmath')()) == str(pi.evalf(100))
234
+
235
+ #================== Test Translations ==============================
236
+ # We can only check if all translated functions are valid. It has to be checked
237
+ # by hand if they are complete.
238
+
239
+
240
+ def test_math_transl():
241
+ from sympy.utilities.lambdify import MATH_TRANSLATIONS
242
+ for sym, mat in MATH_TRANSLATIONS.items():
243
+ assert sym in sympy.__dict__
244
+ assert mat in math.__dict__
245
+
246
+
247
+ def test_mpmath_transl():
248
+ from sympy.utilities.lambdify import MPMATH_TRANSLATIONS
249
+ for sym, mat in MPMATH_TRANSLATIONS.items():
250
+ assert sym in sympy.__dict__ or sym == 'Matrix'
251
+ assert mat in mpmath.__dict__
252
+
253
+
254
+ def test_numpy_transl():
255
+ if not numpy:
256
+ skip("numpy not installed.")
257
+
258
+ from sympy.utilities.lambdify import NUMPY_TRANSLATIONS
259
+ for sym, nump in NUMPY_TRANSLATIONS.items():
260
+ assert sym in sympy.__dict__
261
+ assert nump in numpy.__dict__
262
+
263
+
264
+ def test_scipy_transl():
265
+ if not scipy:
266
+ skip("scipy not installed.")
267
+
268
+ from sympy.utilities.lambdify import SCIPY_TRANSLATIONS
269
+ for sym, scip in SCIPY_TRANSLATIONS.items():
270
+ assert sym in sympy.__dict__
271
+ assert scip in scipy.__dict__ or scip in scipy.special.__dict__
272
+
273
+
274
+ def test_numpy_translation_abs():
275
+ if not numpy:
276
+ skip("numpy not installed.")
277
+
278
+ f = lambdify(x, Abs(x), "numpy")
279
+ assert f(-1) == 1
280
+ assert f(1) == 1
281
+
282
+
283
+ def test_numexpr_printer():
284
+ if not numexpr:
285
+ skip("numexpr not installed.")
286
+
287
+ # if translation/printing is done incorrectly then evaluating
288
+ # a lambdified numexpr expression will throw an exception
289
+ from sympy.printing.lambdarepr import NumExprPrinter
290
+
291
+ blacklist = ('where', 'complex', 'contains')
292
+ arg_tuple = (x, y, z) # some functions take more than one argument
293
+ for sym in NumExprPrinter._numexpr_functions.keys():
294
+ if sym in blacklist:
295
+ continue
296
+ ssym = S(sym)
297
+ if hasattr(ssym, '_nargs'):
298
+ nargs = ssym._nargs[0]
299
+ else:
300
+ nargs = 1
301
+ args = arg_tuple[:nargs]
302
+ f = lambdify(args, ssym(*args), modules='numexpr')
303
+ assert f(*(1, )*nargs) is not None
304
+
305
+
306
+ def test_cmath_sqrt():
307
+ f = lambdify(x, sqrt(x), "cmath")
308
+ assert f(0) == 0
309
+ assert f(1) == 1
310
+ assert f(4) == 2
311
+ assert abs(f(2) - 1.414) < 0.001
312
+ assert f(-1) == 1j
313
+ assert f(-4) == 2j
314
+
315
+
316
+ def test_cmath_log():
317
+ f = lambdify(x, log(x), "cmath")
318
+ assert abs(f(1) - 0) < 1e-15
319
+ assert abs(f(cmath.e) - 1) < 1e-15
320
+ assert abs(f(-1) - cmath.log(-1)) < 1e-15
321
+
322
+
323
+ def test_cmath_sinh():
324
+ f = lambdify(x, sinh(x), "cmath")
325
+ assert abs(f(0) - cmath.sinh(0)) < 1e-15
326
+ assert abs(f(pi) - cmath.sinh(pi)) < 1e-15
327
+ assert abs(f(-pi) - cmath.sinh(-pi)) < 1e-15
328
+ assert abs(f(1j) - cmath.sinh(1j)) < 1e-15
329
+
330
+
331
+ def test_cmath_cosh():
332
+ f = lambdify(x, cosh(x), "cmath")
333
+ assert abs(f(0) - cmath.cosh(0)) < 1e-15
334
+ assert abs(f(pi) - cmath.cosh(pi)) < 1e-15
335
+ assert abs(f(-pi) - cmath.cosh(-pi)) < 1e-15
336
+ assert abs(f(1j) - cmath.cosh(1j)) < 1e-15
337
+
338
+
339
+ def test_cmath_tanh():
340
+ f = lambdify(x, tanh(x), "cmath")
341
+ assert abs(f(0) - cmath.tanh(0)) < 1e-15
342
+ assert abs(f(pi) - cmath.tanh(pi)) < 1e-15
343
+ assert abs(f(-pi) - cmath.tanh(-pi)) < 1e-15
344
+ assert abs(f(1j) - cmath.tanh(1j)) < 1e-15
345
+
346
+
347
+ def test_cmath_sin():
348
+ f = lambdify(x, sin(x), "cmath")
349
+ assert abs(f(0) - cmath.sin(0)) < 1e-15
350
+ assert abs(f(pi) - cmath.sin(pi)) < 1e-15
351
+ assert abs(f(-pi) - cmath.sin(-pi)) < 1e-15
352
+ assert abs(f(1j) - cmath.sin(1j)) < 1e-15
353
+
354
+
355
+ def test_cmath_cos():
356
+ f = lambdify(x, cos(x), "cmath")
357
+ assert abs(f(0) - cmath.cos(0)) < 1e-15
358
+ assert abs(f(pi) - cmath.cos(pi)) < 1e-15
359
+ assert abs(f(-pi) - cmath.cos(-pi)) < 1e-15
360
+ assert abs(f(1j) - cmath.cos(1j)) < 1e-15
361
+
362
+
363
+ def test_cmath_tan():
364
+ f = lambdify(x, tan(x), "cmath")
365
+ assert abs(f(0) - cmath.tan(0)) < 1e-15
366
+ assert abs(f(1j) - cmath.tan(1j)) < 1e-15
367
+
368
+
369
+ def test_cmath_asin():
370
+ f = lambdify(x, asin(x), "cmath")
371
+ assert abs(f(0) - cmath.asin(0)) < 1e-15
372
+ assert abs(f(1) - cmath.asin(1)) < 1e-15
373
+ assert abs(f(-1) - cmath.asin(-1)) < 1e-15
374
+ assert abs(f(2) - cmath.asin(2)) < 1e-15
375
+ assert abs(f(1j) - cmath.asin(1j)) < 1e-15
376
+
377
+
378
+ def test_cmath_acos():
379
+ f = lambdify(x, acos(x), "cmath")
380
+ assert abs(f(1) - cmath.acos(1)) < 1e-15
381
+ assert abs(f(-1) - cmath.acos(-1)) < 1e-15
382
+ assert abs(f(2) - cmath.acos(2)) < 1e-15
383
+ assert abs(f(1j) - cmath.acos(1j)) < 1e-15
384
+
385
+
386
+ def test_cmath_atan():
387
+ f = lambdify(x, atan(x), "cmath")
388
+ assert abs(f(0) - cmath.atan(0)) < 1e-15
389
+ assert abs(f(1) - cmath.atan(1)) < 1e-15
390
+ assert abs(f(-1) - cmath.atan(-1)) < 1e-15
391
+ assert abs(f(2) - cmath.atan(2)) < 1e-15
392
+ assert abs(f(2j) - cmath.atan(2j)) < 1e-15
393
+
394
+
395
+ def test_cmath_asinh():
396
+ f = lambdify(x, asinh(x), "cmath")
397
+ assert abs(f(0) - cmath.asinh(0)) < 1e-15
398
+ assert abs(f(1) - cmath.asinh(1)) < 1e-15
399
+ assert abs(f(-1) - cmath.asinh(-1)) < 1e-15
400
+ assert abs(f(2) - cmath.asinh(2)) < 1e-15
401
+ assert abs(f(2j) - cmath.asinh(2j)) < 1e-15
402
+
403
+
404
+ def test_cmath_acosh():
405
+ f = lambdify(x, acosh(x), "cmath")
406
+ assert abs(f(1) - cmath.acosh(1)) < 1e-15
407
+ assert abs(f(2) - cmath.acosh(2)) < 1e-15
408
+ assert abs(f(-1) - cmath.acosh(-1)) < 1e-15
409
+ assert abs(f(2j) - cmath.acosh(2j)) < 1e-15
410
+
411
+
412
+ def test_cmath_atanh():
413
+ f = lambdify(x, atanh(x), "cmath")
414
+ assert abs(f(0) - cmath.atanh(0)) < 1e-15
415
+ assert abs(f(0.5) - cmath.atanh(0.5)) < 1e-15
416
+ assert abs(f(-0.5) - cmath.atanh(-0.5)) < 1e-15
417
+ assert abs(f(2) - cmath.atanh(2)) < 1e-15
418
+ assert abs(f(-2) - cmath.atanh(-2)) < 1e-15
419
+ assert abs(f(2j) - cmath.atanh(2j)) < 1e-15
420
+
421
+
422
+ def test_cmath_complex_identities():
423
+ # Define symbol
424
+ z = symbols('z')
425
+
426
+ # Trigonometric identity using re(z) and im(z)
427
+ expr = cos(z) - cos(re(z)) * cosh(im(z)) + I * sin(re(z)) * sinh(im(z))
428
+ func = lambdify([z], expr, modules=["cmath", "math"])
429
+ hpi = math.pi / 2
430
+ assert abs(func(hpi + 1j * hpi)) < 4e-16
431
+
432
+ # Euler's Formula: e^(i*z) = cos(z) + i*sin(z)
433
+ func = lambdify([z], exp(I * z) - (cos(z) + I * sin(z)), modules=["cmath", "math"])
434
+ assert abs(func(hpi)) < 4e-16
435
+
436
+ # Exponential Identity: e^z = e^(Re(z)) * (cos(Im(z)) + i*sin(Im(z)))
437
+ func_exp = lambdify([z], exp(z) - exp(re(z)) * (cos(im(z)) + I * sin(im(z))),
438
+ modules=["cmath", "math"])
439
+ assert abs(func_exp(hpi + 1j * hpi)) < 4e-16
440
+
441
+ # Complex Cosine Identity: cos(z) = cos(Re(z)) * cosh(Im(z)) - i*sin(Re(z)) * sinh(Im(z))
442
+ func_cos = lambdify([z], cos(z) - (cos(re(z)) * cosh(im(z)) - I * sin(re(z)) * sinh(im(z))),
443
+ modules=["cmath", "math"])
444
+ assert abs(func_cos(hpi + 1j * hpi)) < 4e-16
445
+
446
+ # Complex Sine Identity: sin(z) = sin(Re(z)) * cosh(Im(z)) + i*cos(Re(z)) * sinh(Im(z))
447
+ func_sin = lambdify([z], sin(z) - (sin(re(z)) * cosh(im(z)) + I * cos(re(z)) * sinh(im(z))),
448
+ modules=["cmath", "math"])
449
+ assert abs(func_sin(hpi + 1j * hpi)) < 4e-16
450
+
451
+ # Complex Hyperbolic Cosine Identity: cosh(z) = cosh(Re(z)) * cos(Im(z)) + i*sinh(Re(z)) * sin(Im(z))
452
+ func_cosh_1 = lambdify([z], cosh(z) - (cosh(re(z)) * cos(im(z)) + I * sinh(re(z)) * sin(im(z))),
453
+ modules=["cmath", "math"])
454
+ assert abs(func_cosh_1(hpi + 1j * hpi)) < 4e-16
455
+
456
+ # Complex Hyperbolic Sine Identity: sinh(z) = sinh(Re(z)) * cos(Im(z)) + i*cosh(Re(z)) * sin(Im(z))
457
+ func_sinh = lambdify([z], sinh(z) - (sinh(re(z)) * cos(im(z)) + I * cosh(re(z)) * sin(im(z))),
458
+ modules=["cmath", "math"])
459
+ assert abs(func_sinh(hpi + 1j * hpi)) < 4e-16
460
+
461
+ # cosh(z) = (e^z + e^(-z)) / 2
462
+ func_cosh_2 = lambdify([z], cosh(z) - (exp(z) + exp(-z)) / 2, modules=["cmath", "math"])
463
+ assert abs(func_cosh_2(hpi)) < 4e-16
464
+
465
+ # Additional expressions testing log and exp with real and imaginary parts
466
+ expr1 = log(re(z)) + log(im(z)) - log(re(z) * im(z))
467
+ expr2 = exp(re(z)) * exp(im(z) * I) - exp(z)
468
+ expr3 = log(exp(re(z))) - re(z)
469
+ expr4 = exp(log(re(z))) - re(z)
470
+ expr5 = log(exp(re(z) + im(z))) - (re(z) + im(z))
471
+ expr6 = exp(log(re(z) + im(z))) - (re(z) + im(z))
472
+ func1 = lambdify([z], expr1, modules=["cmath", "math"])
473
+ func2 = lambdify([z], expr2, modules=["cmath", "math"])
474
+ func3 = lambdify([z], expr3, modules=["cmath", "math"])
475
+ func4 = lambdify([z], expr4, modules=["cmath", "math"])
476
+ func5 = lambdify([z], expr5, modules=["cmath", "math"])
477
+ func6 = lambdify([z], expr6, modules=["cmath", "math"])
478
+ test_value = 3 + 4j
479
+ assert abs(func1(test_value)) < 4e-16
480
+ assert abs(func2(test_value)) < 4e-16
481
+ assert abs(func3(test_value)) < 4e-16
482
+ assert abs(func4(test_value)) < 4e-16
483
+ assert abs(func5(test_value)) < 4e-16
484
+ assert abs(func6(test_value)) < 4e-16
485
+
486
+
487
+ def test_issue_9334():
488
+ if not numexpr:
489
+ skip("numexpr not installed.")
490
+ if not numpy:
491
+ skip("numpy not installed.")
492
+ expr = S('b*a - sqrt(a**2)')
493
+ a, b = sorted(expr.free_symbols, key=lambda s: s.name)
494
+ func_numexpr = lambdify((a,b), expr, modules=[numexpr], dummify=False)
495
+ foo, bar = numpy.random.random((2, 4))
496
+ func_numexpr(foo, bar)
497
+
498
+
499
+ def test_issue_12984():
500
+ if not numexpr:
501
+ skip("numexpr not installed.")
502
+ func_numexpr = lambdify((x,y,z), Piecewise((y, x >= 0), (z, x > -1)), numexpr)
503
+ with ignore_warnings(RuntimeWarning):
504
+ assert func_numexpr(1, 24, 42) == 24
505
+ assert str(func_numexpr(-1, 24, 42)) == 'nan'
506
+
507
+
508
+ def test_empty_modules():
509
+ x, y = symbols('x y')
510
+ expr = -(x % y)
511
+
512
+ no_modules = lambdify([x, y], expr)
513
+ empty_modules = lambdify([x, y], expr, modules=[])
514
+ assert no_modules(3, 7) == empty_modules(3, 7)
515
+ assert no_modules(3, 7) == -3
516
+
517
+
518
+ def test_exponentiation():
519
+ f = lambdify(x, x**2)
520
+ assert f(-1) == 1
521
+ assert f(0) == 0
522
+ assert f(1) == 1
523
+ assert f(-2) == 4
524
+ assert f(2) == 4
525
+ assert f(2.5) == 6.25
526
+
527
+
528
+ def test_sqrt():
529
+ f = lambdify(x, sqrt(x))
530
+ assert f(0) == 0.0
531
+ assert f(1) == 1.0
532
+ assert f(4) == 2.0
533
+ assert abs(f(2) - 1.414) < 0.001
534
+ assert f(6.25) == 2.5
535
+
536
+
537
+ def test_trig():
538
+ f = lambdify([x], [cos(x), sin(x)], 'math')
539
+ d = f(pi)
540
+ prec = 1e-11
541
+ assert -prec < d[0] + 1 < prec
542
+ assert -prec < d[1] < prec
543
+ d = f(3.14159)
544
+ prec = 1e-5
545
+ assert -prec < d[0] + 1 < prec
546
+ assert -prec < d[1] < prec
547
+
548
+
549
+ def test_integral():
550
+ if numpy and not scipy:
551
+ skip("scipy not installed.")
552
+ f = Lambda(x, exp(-x**2))
553
+ l = lambdify(y, Integral(f(x), (x, y, oo)))
554
+ d = l(-oo)
555
+ assert 1.77245385 < d < 1.772453851
556
+
557
+
558
+ def test_double_integral():
559
+ if numpy and not scipy:
560
+ skip("scipy not installed.")
561
+ # example from http://mpmath.org/doc/current/calculus/integration.html
562
+ i = Integral(1/(1 - x**2*y**2), (x, 0, 1), (y, 0, z))
563
+ l = lambdify([z], i)
564
+ d = l(1)
565
+ assert 1.23370055 < d < 1.233700551
566
+
567
+ def test_spherical_bessel():
568
+ if numpy and not scipy:
569
+ skip("scipy not installed.")
570
+ test_point = 4.2 #randomly selected
571
+ x = symbols("x")
572
+ jtest = jn(2, x)
573
+ assert abs(lambdify(x,jtest)(test_point) -
574
+ jtest.subs(x,test_point).evalf()) < 1e-8
575
+ ytest = yn(2, x)
576
+ assert abs(lambdify(x,ytest)(test_point) -
577
+ ytest.subs(x,test_point).evalf()) < 1e-8
578
+
579
+
580
+ #================== Test vectors ===================================
581
+
582
+
583
+ def test_vector_simple():
584
+ f = lambdify((x, y, z), (z, y, x))
585
+ assert f(3, 2, 1) == (1, 2, 3)
586
+ assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
587
+ # make sure correct number of args required
588
+ raises(TypeError, lambda: f(0))
589
+
590
+
591
+ def test_vector_discontinuous():
592
+ f = lambdify(x, (-1/x, 1/x))
593
+ raises(ZeroDivisionError, lambda: f(0))
594
+ assert f(1) == (-1.0, 1.0)
595
+ assert f(2) == (-0.5, 0.5)
596
+ assert f(-2) == (0.5, -0.5)
597
+
598
+
599
+ def test_trig_symbolic():
600
+ f = lambdify([x], [cos(x), sin(x)], 'math')
601
+ d = f(pi)
602
+ assert abs(d[0] + 1) < 0.0001
603
+ assert abs(d[1] - 0) < 0.0001
604
+
605
+
606
+ def test_trig_float():
607
+ f = lambdify([x], [cos(x), sin(x)])
608
+ d = f(3.14159)
609
+ assert abs(d[0] + 1) < 0.0001
610
+ assert abs(d[1] - 0) < 0.0001
611
+
612
+
613
+ def test_docs():
614
+ f = lambdify(x, x**2)
615
+ assert f(2) == 4
616
+ f = lambdify([x, y, z], [z, y, x])
617
+ assert f(1, 2, 3) == [3, 2, 1]
618
+ f = lambdify(x, sqrt(x))
619
+ assert f(4) == 2.0
620
+ f = lambdify((x, y), sin(x*y)**2)
621
+ assert f(0, 5) == 0
622
+
623
+
624
+ def test_math():
625
+ f = lambdify((x, y), sin(x), modules="math")
626
+ assert f(0, 5) == 0
627
+
628
+
629
+ def test_sin():
630
+ f = lambdify(x, sin(x)**2)
631
+ assert isinstance(f(2), float)
632
+ f = lambdify(x, sin(x)**2, modules="math")
633
+ assert isinstance(f(2), float)
634
+
635
+
636
+ def test_matrix():
637
+ A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
638
+ sol = Matrix([[1, 2], [sin(3) + 4, 1]])
639
+ f = lambdify((x, y, z), A, modules="sympy")
640
+ assert f(1, 2, 3) == sol
641
+ f = lambdify((x, y, z), (A, [A]), modules="sympy")
642
+ assert f(1, 2, 3) == (sol, [sol])
643
+ J = Matrix((x, x + y)).jacobian((x, y))
644
+ v = Matrix((x, y))
645
+ sol = Matrix([[1, 0], [1, 1]])
646
+ assert lambdify(v, J, modules='sympy')(1, 2) == sol
647
+ assert lambdify(v.T, J, modules='sympy')(1, 2) == sol
648
+
649
+
650
+ def test_numpy_matrix():
651
+ if not numpy:
652
+ skip("numpy not installed.")
653
+ A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
654
+ sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
655
+ #Lambdify array first, to ensure return to array as default
656
+ f = lambdify((x, y, z), A, ['numpy'])
657
+ numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
658
+ #Check that the types are arrays and matrices
659
+ assert isinstance(f(1, 2, 3), numpy.ndarray)
660
+
661
+ # gh-15071
662
+ class dot(Function):
663
+ pass
664
+ x_dot_mtx = dot(x, Matrix([[2], [1], [0]]))
665
+ f_dot1 = lambdify(x, x_dot_mtx)
666
+ inp = numpy.zeros((17, 3))
667
+ assert numpy.all(f_dot1(inp) == 0)
668
+
669
+ strict_kw = {"allow_unknown_functions": False, "inline": True, "fully_qualified_modules": False}
670
+ p2 = NumPyPrinter(dict(user_functions={'dot': 'dot'}, **strict_kw))
671
+ f_dot2 = lambdify(x, x_dot_mtx, printer=p2)
672
+ assert numpy.all(f_dot2(inp) == 0)
673
+
674
+ p3 = NumPyPrinter(strict_kw)
675
+ # The line below should probably fail upon construction (before calling with "(inp)"):
676
+ raises(Exception, lambda: lambdify(x, x_dot_mtx, printer=p3)(inp))
677
+
678
+
679
+ def test_numpy_transpose():
680
+ if not numpy:
681
+ skip("numpy not installed.")
682
+ A = Matrix([[1, x], [0, 1]])
683
+ f = lambdify((x), A.T, modules="numpy")
684
+ numpy.testing.assert_array_equal(f(2), numpy.array([[1, 0], [2, 1]]))
685
+
686
+
687
+ def test_numpy_dotproduct():
688
+ if not numpy:
689
+ skip("numpy not installed")
690
+ A = Matrix([x, y, z])
691
+ f1 = lambdify([x, y, z], DotProduct(A, A), modules='numpy')
692
+ f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy')
693
+ f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='numpy')
694
+ f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy')
695
+
696
+ assert f1(1, 2, 3) == \
697
+ f2(1, 2, 3) == \
698
+ f3(1, 2, 3) == \
699
+ f4(1, 2, 3) == \
700
+ numpy.array([14])
701
+
702
+
703
+ def test_numpy_inverse():
704
+ if not numpy:
705
+ skip("numpy not installed.")
706
+ A = Matrix([[1, x], [0, 1]])
707
+ f = lambdify((x), A**-1, modules="numpy")
708
+ numpy.testing.assert_array_equal(f(2), numpy.array([[1, -2], [0, 1]]))
709
+
710
+
711
+ def test_numpy_old_matrix():
712
+ if not numpy:
713
+ skip("numpy not installed.")
714
+ A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
715
+ sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
716
+ f = lambdify((x, y, z), A, [{'ImmutableDenseMatrix': numpy.matrix}, 'numpy'])
717
+ with ignore_warnings(PendingDeprecationWarning):
718
+ numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
719
+ assert isinstance(f(1, 2, 3), numpy.matrix)
720
+
721
+
722
+ def test_scipy_sparse_matrix():
723
+ if not scipy:
724
+ skip("scipy not installed.")
725
+ A = SparseMatrix([[x, 0], [0, y]])
726
+ f = lambdify((x, y), A, modules="scipy")
727
+ B = f(1, 2)
728
+ assert isinstance(B, scipy.sparse.coo_matrix)
729
+
730
+
731
+ def test_python_div_zero_issue_11306():
732
+ if not numpy:
733
+ skip("numpy not installed.")
734
+ p = Piecewise((1 / x, y < -1), (x, y < 1), (1 / x, True))
735
+ f = lambdify([x, y], p, modules='numpy')
736
+ with numpy.errstate(divide='ignore'):
737
+ assert float(f(numpy.array(0), numpy.array(0.5))) == 0
738
+ assert float(f(numpy.array(0), numpy.array(1))) == float('inf')
739
+
740
+
741
+ def test_issue9474():
742
+ mods = [None, 'math']
743
+ if numpy:
744
+ mods.append('numpy')
745
+ if mpmath:
746
+ mods.append('mpmath')
747
+ for mod in mods:
748
+ f = lambdify(x, S.One/x, modules=mod)
749
+ assert f(2) == 0.5
750
+ f = lambdify(x, floor(S.One/x), modules=mod)
751
+ assert f(2) == 0
752
+
753
+ for absfunc, modules in product([Abs, abs], mods):
754
+ f = lambdify(x, absfunc(x), modules=modules)
755
+ assert f(-1) == 1
756
+ assert f(1) == 1
757
+ assert f(3+4j) == 5
758
+
759
+
760
+ def test_issue_9871():
761
+ if not numexpr:
762
+ skip("numexpr not installed.")
763
+ if not numpy:
764
+ skip("numpy not installed.")
765
+
766
+ r = sqrt(x**2 + y**2)
767
+ expr = diff(1/r, x)
768
+
769
+ xn = yn = numpy.linspace(1, 10, 16)
770
+ # expr(xn, xn) = -xn/(sqrt(2)*xn)^3
771
+ fv_exact = -numpy.sqrt(2.)**-3 * xn**-2
772
+
773
+ fv_numpy = lambdify((x, y), expr, modules='numpy')(xn, yn)
774
+ fv_numexpr = lambdify((x, y), expr, modules='numexpr')(xn, yn)
775
+ numpy.testing.assert_allclose(fv_numpy, fv_exact, rtol=1e-10)
776
+ numpy.testing.assert_allclose(fv_numexpr, fv_exact, rtol=1e-10)
777
+
778
+
779
+ def test_numpy_piecewise():
780
+ if not numpy:
781
+ skip("numpy not installed.")
782
+ pieces = Piecewise((x, x < 3), (x**2, x > 5), (0, True))
783
+ f = lambdify(x, pieces, modules="numpy")
784
+ numpy.testing.assert_array_equal(f(numpy.arange(10)),
785
+ numpy.array([0, 1, 2, 0, 0, 0, 36, 49, 64, 81]))
786
+ # If we evaluate somewhere all conditions are False, we should get back NaN
787
+ nodef_func = lambdify(x, Piecewise((x, x > 0), (-x, x < 0)))
788
+ numpy.testing.assert_array_equal(nodef_func(numpy.array([-1, 0, 1])),
789
+ numpy.array([1, numpy.nan, 1]))
790
+
791
+
792
+ def test_numpy_logical_ops():
793
+ if not numpy:
794
+ skip("numpy not installed.")
795
+ and_func = lambdify((x, y), And(x, y), modules="numpy")
796
+ and_func_3 = lambdify((x, y, z), And(x, y, z), modules="numpy")
797
+ or_func = lambdify((x, y), Or(x, y), modules="numpy")
798
+ or_func_3 = lambdify((x, y, z), Or(x, y, z), modules="numpy")
799
+ not_func = lambdify((x), Not(x), modules="numpy")
800
+ arr1 = numpy.array([True, True])
801
+ arr2 = numpy.array([False, True])
802
+ arr3 = numpy.array([True, False])
803
+ numpy.testing.assert_array_equal(and_func(arr1, arr2), numpy.array([False, True]))
804
+ numpy.testing.assert_array_equal(and_func_3(arr1, arr2, arr3), numpy.array([False, False]))
805
+ numpy.testing.assert_array_equal(or_func(arr1, arr2), numpy.array([True, True]))
806
+ numpy.testing.assert_array_equal(or_func_3(arr1, arr2, arr3), numpy.array([True, True]))
807
+ numpy.testing.assert_array_equal(not_func(arr2), numpy.array([True, False]))
808
+
809
+
810
+ def test_numpy_matmul():
811
+ if not numpy:
812
+ skip("numpy not installed.")
813
+ xmat = Matrix([[x, y], [z, 1+z]])
814
+ ymat = Matrix([[x**2], [Abs(x)]])
815
+ mat_func = lambdify((x, y, z), xmat*ymat, modules="numpy")
816
+ numpy.testing.assert_array_equal(mat_func(0.5, 3, 4), numpy.array([[1.625], [3.5]]))
817
+ numpy.testing.assert_array_equal(mat_func(-0.5, 3, 4), numpy.array([[1.375], [3.5]]))
818
+ # Multiple matrices chained together in multiplication
819
+ f = lambdify((x, y, z), xmat*xmat*xmat, modules="numpy")
820
+ numpy.testing.assert_array_equal(f(0.5, 3, 4), numpy.array([[72.125, 119.25],
821
+ [159, 251]]))
822
+
823
+
824
+ def test_numpy_numexpr():
825
+ if not numpy:
826
+ skip("numpy not installed.")
827
+ if not numexpr:
828
+ skip("numexpr not installed.")
829
+ a, b, c = numpy.random.randn(3, 128, 128)
830
+ # ensure that numpy and numexpr return same value for complicated expression
831
+ expr = sin(x) + cos(y) + tan(z)**2 + Abs(z-y)*acos(sin(y*z)) + \
832
+ Abs(y-z)*acosh(2+exp(y-x))- sqrt(x**2+I*y**2)
833
+ npfunc = lambdify((x, y, z), expr, modules='numpy')
834
+ nefunc = lambdify((x, y, z), expr, modules='numexpr')
835
+ assert numpy.allclose(npfunc(a, b, c), nefunc(a, b, c))
836
+
837
+
838
+ def test_numexpr_userfunctions():
839
+ if not numpy:
840
+ skip("numpy not installed.")
841
+ if not numexpr:
842
+ skip("numexpr not installed.")
843
+ a, b = numpy.random.randn(2, 10)
844
+ uf = type('uf', (Function, ),
845
+ {'eval' : classmethod(lambda x, y : y**2+1)})
846
+ func = lambdify(x, 1-uf(x), modules='numexpr')
847
+ assert numpy.allclose(func(a), -(a**2))
848
+
849
+ uf = implemented_function(Function('uf'), lambda x, y : 2*x*y+1)
850
+ func = lambdify((x, y), uf(x, y), modules='numexpr')
851
+ assert numpy.allclose(func(a, b), 2*a*b+1)
852
+
853
+
854
+ def test_tensorflow_basic_math():
855
+ if not tensorflow:
856
+ skip("tensorflow not installed.")
857
+ expr = Max(sin(x), Abs(1/(x+2)))
858
+ func = lambdify(x, expr, modules="tensorflow")
859
+
860
+ with tensorflow.compat.v1.Session() as s:
861
+ a = tensorflow.constant(0, dtype=tensorflow.float32)
862
+ assert func(a).eval(session=s) == 0.5
863
+
864
+
865
+ def test_tensorflow_placeholders():
866
+ if not tensorflow:
867
+ skip("tensorflow not installed.")
868
+ expr = Max(sin(x), Abs(1/(x+2)))
869
+ func = lambdify(x, expr, modules="tensorflow")
870
+
871
+ with tensorflow.compat.v1.Session() as s:
872
+ a = tensorflow.compat.v1.placeholder(dtype=tensorflow.float32)
873
+ assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5
874
+
875
+
876
+ def test_tensorflow_variables():
877
+ if not tensorflow:
878
+ skip("tensorflow not installed.")
879
+ expr = Max(sin(x), Abs(1/(x+2)))
880
+ func = lambdify(x, expr, modules="tensorflow")
881
+
882
+ with tensorflow.compat.v1.Session() as s:
883
+ a = tensorflow.Variable(0, dtype=tensorflow.float32)
884
+ s.run(a.initializer)
885
+ assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5
886
+
887
+
888
+ def test_tensorflow_logical_operations():
889
+ if not tensorflow:
890
+ skip("tensorflow not installed.")
891
+ expr = Not(And(Or(x, y), y))
892
+ func = lambdify([x, y], expr, modules="tensorflow")
893
+
894
+ with tensorflow.compat.v1.Session() as s:
895
+ assert func(False, True).eval(session=s) == False
896
+
897
+
898
+ def test_tensorflow_piecewise():
899
+ if not tensorflow:
900
+ skip("tensorflow not installed.")
901
+ expr = Piecewise((0, Eq(x,0)), (-1, x < 0), (1, x > 0))
902
+ func = lambdify(x, expr, modules="tensorflow")
903
+
904
+ with tensorflow.compat.v1.Session() as s:
905
+ assert func(-1).eval(session=s) == -1
906
+ assert func(0).eval(session=s) == 0
907
+ assert func(1).eval(session=s) == 1
908
+
909
+
910
+ def test_tensorflow_multi_max():
911
+ if not tensorflow:
912
+ skip("tensorflow not installed.")
913
+ expr = Max(x, -x, x**2)
914
+ func = lambdify(x, expr, modules="tensorflow")
915
+
916
+ with tensorflow.compat.v1.Session() as s:
917
+ assert func(-2).eval(session=s) == 4
918
+
919
+
920
+ def test_tensorflow_multi_min():
921
+ if not tensorflow:
922
+ skip("tensorflow not installed.")
923
+ expr = Min(x, -x, x**2)
924
+ func = lambdify(x, expr, modules="tensorflow")
925
+
926
+ with tensorflow.compat.v1.Session() as s:
927
+ assert func(-2).eval(session=s) == -2
928
+
929
+
930
+ def test_tensorflow_relational():
931
+ if not tensorflow:
932
+ skip("tensorflow not installed.")
933
+ expr = x >= 0
934
+ func = lambdify(x, expr, modules="tensorflow")
935
+
936
+ with tensorflow.compat.v1.Session() as s:
937
+ assert func(1).eval(session=s) == True
938
+
939
+
940
+ def test_tensorflow_complexes():
941
+ if not tensorflow:
942
+ skip("tensorflow not installed")
943
+
944
+ func1 = lambdify(x, re(x), modules="tensorflow")
945
+ func2 = lambdify(x, im(x), modules="tensorflow")
946
+ func3 = lambdify(x, Abs(x), modules="tensorflow")
947
+ func4 = lambdify(x, arg(x), modules="tensorflow")
948
+
949
+ with tensorflow.compat.v1.Session() as s:
950
+ # For versions before
951
+ # https://github.com/tensorflow/tensorflow/issues/30029
952
+ # resolved, using Python numeric types may not work
953
+ a = tensorflow.constant(1+2j)
954
+ assert func1(a).eval(session=s) == 1
955
+ assert func2(a).eval(session=s) == 2
956
+
957
+ tensorflow_result = func3(a).eval(session=s)
958
+ sympy_result = Abs(1 + 2j).evalf()
959
+ assert abs(tensorflow_result-sympy_result) < 10**-6
960
+
961
+ tensorflow_result = func4(a).eval(session=s)
962
+ sympy_result = arg(1 + 2j).evalf()
963
+ assert abs(tensorflow_result-sympy_result) < 10**-6
964
+
965
+
966
+ def test_tensorflow_array_arg():
967
+ # Test for issue 14655 (tensorflow part)
968
+ if not tensorflow:
969
+ skip("tensorflow not installed.")
970
+
971
+ f = lambdify([[x, y]], x*x + y, 'tensorflow')
972
+
973
+ with tensorflow.compat.v1.Session() as s:
974
+ fcall = f(tensorflow.constant([2.0, 1.0]))
975
+ assert fcall.eval(session=s) == 5.0
976
+
977
+
978
+ #================== Test symbolic ==================================
979
+
980
+
981
+ def test_sym_single_arg():
982
+ f = lambdify(x, x * y)
983
+ assert f(z) == z * y
984
+
985
+
986
+ def test_sym_list_args():
987
+ f = lambdify([x, y], x + y + z)
988
+ assert f(1, 2) == 3 + z
989
+
990
+
991
+ def test_sym_integral():
992
+ f = Lambda(x, exp(-x**2))
993
+ l = lambdify(x, Integral(f(x), (x, -oo, oo)), modules="sympy")
994
+ assert l(y) == Integral(exp(-y**2), (y, -oo, oo))
995
+ assert l(y).doit() == sqrt(pi)
996
+
997
+
998
+ def test_namespace_order():
999
+ # lambdify had a bug, such that module dictionaries or cached module
1000
+ # dictionaries would pull earlier namespaces into themselves.
1001
+ # Because the module dictionaries form the namespace of the
1002
+ # generated lambda, this meant that the behavior of a previously
1003
+ # generated lambda function could change as a result of later calls
1004
+ # to lambdify.
1005
+ n1 = {'f': lambda x: 'first f'}
1006
+ n2 = {'f': lambda x: 'second f',
1007
+ 'g': lambda x: 'function g'}
1008
+ f = sympy.Function('f')
1009
+ g = sympy.Function('g')
1010
+ if1 = lambdify(x, f(x), modules=(n1, "sympy"))
1011
+ assert if1(1) == 'first f'
1012
+ if2 = lambdify(x, g(x), modules=(n2, "sympy"))
1013
+ # previously gave 'second f'
1014
+ assert if1(1) == 'first f'
1015
+
1016
+ assert if2(1) == 'function g'
1017
+
1018
+
1019
+ def test_imps():
1020
+ # Here we check if the default returned functions are anonymous - in
1021
+ # the sense that we can have more than one function with the same name
1022
+ f = implemented_function('f', lambda x: 2*x)
1023
+ g = implemented_function('f', lambda x: math.sqrt(x))
1024
+ l1 = lambdify(x, f(x))
1025
+ l2 = lambdify(x, g(x))
1026
+ assert str(f(x)) == str(g(x))
1027
+ assert l1(3) == 6
1028
+ assert l2(3) == math.sqrt(3)
1029
+ # check that we can pass in a Function as input
1030
+ func = sympy.Function('myfunc')
1031
+ assert not hasattr(func, '_imp_')
1032
+ my_f = implemented_function(func, lambda x: 2*x)
1033
+ assert hasattr(my_f, '_imp_')
1034
+ # Error for functions with same name and different implementation
1035
+ f2 = implemented_function("f", lambda x: x + 101)
1036
+ raises(ValueError, lambda: lambdify(x, f(f2(x))))
1037
+
1038
+
1039
+ def test_imps_errors():
1040
+ # Test errors that implemented functions can return, and still be able to
1041
+ # form expressions.
1042
+ # See: https://github.com/sympy/sympy/issues/10810
1043
+ #
1044
+ # XXX: Removed AttributeError here. This test was added due to issue 10810
1045
+ # but that issue was about ValueError. It doesn't seem reasonable to
1046
+ # "support" catching AttributeError in the same context...
1047
+ for val, error_class in product((0, 0., 2, 2.0), (TypeError, ValueError)):
1048
+
1049
+ def myfunc(a):
1050
+ if a == 0:
1051
+ raise error_class
1052
+ return 1
1053
+
1054
+ f = implemented_function('f', myfunc)
1055
+ expr = f(val)
1056
+ assert expr == f(val)
1057
+
1058
+
1059
+ def test_imps_wrong_args():
1060
+ raises(ValueError, lambda: implemented_function(sin, lambda x: x))
1061
+
1062
+
1063
+ def test_lambdify_imps():
1064
+ # Test lambdify with implemented functions
1065
+ # first test basic (sympy) lambdify
1066
+ f = sympy.cos
1067
+ assert lambdify(x, f(x))(0) == 1
1068
+ assert lambdify(x, 1 + f(x))(0) == 2
1069
+ assert lambdify((x, y), y + f(x))(0, 1) == 2
1070
+ # make an implemented function and test
1071
+ f = implemented_function("f", lambda x: x + 100)
1072
+ assert lambdify(x, f(x))(0) == 100
1073
+ assert lambdify(x, 1 + f(x))(0) == 101
1074
+ assert lambdify((x, y), y + f(x))(0, 1) == 101
1075
+ # Can also handle tuples, lists, dicts as expressions
1076
+ lam = lambdify(x, (f(x), x))
1077
+ assert lam(3) == (103, 3)
1078
+ lam = lambdify(x, [f(x), x])
1079
+ assert lam(3) == [103, 3]
1080
+ lam = lambdify(x, [f(x), (f(x), x)])
1081
+ assert lam(3) == [103, (103, 3)]
1082
+ lam = lambdify(x, {f(x): x})
1083
+ assert lam(3) == {103: 3}
1084
+ lam = lambdify(x, {f(x): x})
1085
+ assert lam(3) == {103: 3}
1086
+ lam = lambdify(x, {x: f(x)})
1087
+ assert lam(3) == {3: 103}
1088
+ # Check that imp preferred to other namespaces by default
1089
+ d = {'f': lambda x: x + 99}
1090
+ lam = lambdify(x, f(x), d)
1091
+ assert lam(3) == 103
1092
+ # Unless flag passed
1093
+ lam = lambdify(x, f(x), d, use_imps=False)
1094
+ assert lam(3) == 102
1095
+
1096
+
1097
+ def test_dummification():
1098
+ t = symbols('t')
1099
+ F = Function('F')
1100
+ G = Function('G')
1101
+ #"\alpha" is not a valid Python variable name
1102
+ #lambdify should sub in a dummy for it, and return
1103
+ #without a syntax error
1104
+ alpha = symbols(r'\alpha')
1105
+ some_expr = 2 * F(t)**2 / G(t)
1106
+ lam = lambdify((F(t), G(t)), some_expr)
1107
+ assert lam(3, 9) == 2
1108
+ lam = lambdify(sin(t), 2 * sin(t)**2)
1109
+ assert lam(F(t)) == 2 * F(t)**2
1110
+ #Test that \alpha was properly dummified
1111
+ lam = lambdify((alpha, t), 2*alpha + t)
1112
+ assert lam(2, 1) == 5
1113
+ raises(SyntaxError, lambda: lambdify(F(t) * G(t), F(t) * G(t) + 5))
1114
+ raises(SyntaxError, lambda: lambdify(2 * F(t), 2 * F(t) + 5))
1115
+ raises(SyntaxError, lambda: lambdify(2 * F(t), 4 * F(t) + 5))
1116
+
1117
+
1118
+ def test_lambdify__arguments_with_invalid_python_identifiers():
1119
+ # see sympy/sympy#26690
1120
+ N = CoordSys3D('N')
1121
+ xn, yn, zn = N.base_scalars()
1122
+ expr = xn + yn
1123
+ f = lambdify([xn, yn], expr)
1124
+ res = f(0.2, 0.3)
1125
+ ref = 0.2 + 0.3
1126
+ assert abs(res-ref) < 1e-15
1127
+
1128
+
1129
+ def test_curly_matrix_symbol():
1130
+ # Issue #15009
1131
+ curlyv = sympy.MatrixSymbol("{v}", 2, 1)
1132
+ lam = lambdify(curlyv, curlyv)
1133
+ assert lam(1)==1
1134
+ lam = lambdify(curlyv, curlyv, dummify=True)
1135
+ assert lam(1)==1
1136
+
1137
+
1138
+ def test_python_keywords():
1139
+ # Test for issue 7452. The automatic dummification should ensure use of
1140
+ # Python reserved keywords as symbol names will create valid lambda
1141
+ # functions. This is an additional regression test.
1142
+ python_if = symbols('if')
1143
+ expr = python_if / 2
1144
+ f = lambdify(python_if, expr)
1145
+ assert f(4.0) == 2.0
1146
+
1147
+
1148
+ def test_lambdify_docstring():
1149
+ func = lambdify((w, x, y, z), w + x + y + z)
1150
+ ref = (
1151
+ "Created with lambdify. Signature:\n\n"
1152
+ "func(w, x, y, z)\n\n"
1153
+ "Expression:\n\n"
1154
+ "w + x + y + z"
1155
+ ).splitlines()
1156
+ assert func.__doc__.splitlines()[:len(ref)] == ref
1157
+ syms = symbols('a1:26')
1158
+ func = lambdify(syms, sum(syms))
1159
+ ref = (
1160
+ "Created with lambdify. Signature:\n\n"
1161
+ "func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n"
1162
+ " a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n"
1163
+ "Expression:\n\n"
1164
+ "a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +..."
1165
+ ).splitlines()
1166
+ assert func.__doc__.splitlines()[:len(ref)] == ref
1167
+
1168
+
1169
+ def test_lambdify_linecache():
1170
+ func = lambdify(x, x + 1)
1171
+ source = 'def _lambdifygenerated(x):\n return x + 1\n'
1172
+ assert inspect.getsource(func) == source
1173
+ filename = inspect.getsourcefile(func)
1174
+ assert filename.startswith('<lambdifygenerated-')
1175
+ assert filename in linecache.cache
1176
+ assert linecache.cache[filename] == (len(source), None, source.splitlines(True), filename)
1177
+ del func
1178
+ gc.collect()
1179
+ assert filename not in linecache.cache
1180
+
1181
+ #================== Test special printers ==========================
1182
+
1183
+
1184
+ def test_special_printers():
1185
+ from sympy.printing.lambdarepr import IntervalPrinter
1186
+
1187
+ def intervalrepr(expr):
1188
+ return IntervalPrinter().doprint(expr)
1189
+
1190
+ expr = sqrt(sqrt(2) + sqrt(3)) + S.Half
1191
+
1192
+ func0 = lambdify((), expr, modules="mpmath", printer=intervalrepr)
1193
+ func1 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter)
1194
+ func2 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter())
1195
+
1196
+ mpi = type(mpmath.mpi(1, 2))
1197
+
1198
+ assert isinstance(func0(), mpi)
1199
+ assert isinstance(func1(), mpi)
1200
+ assert isinstance(func2(), mpi)
1201
+
1202
+ # To check Is lambdify loggamma works for mpmath or not
1203
+ exp1 = lambdify(x, loggamma(x), 'mpmath')(5)
1204
+ exp2 = lambdify(x, loggamma(x), 'mpmath')(1.8)
1205
+ exp3 = lambdify(x, loggamma(x), 'mpmath')(15)
1206
+ exp_ls = [exp1, exp2, exp3]
1207
+
1208
+ sol1 = mpmath.loggamma(5)
1209
+ sol2 = mpmath.loggamma(1.8)
1210
+ sol3 = mpmath.loggamma(15)
1211
+ sol_ls = [sol1, sol2, sol3]
1212
+
1213
+ assert exp_ls == sol_ls
1214
+
1215
+
1216
+ def test_true_false():
1217
+ # We want exact is comparison here, not just ==
1218
+ assert lambdify([], true)() is True
1219
+ assert lambdify([], false)() is False
1220
+
1221
+
1222
+ def test_issue_2790():
1223
+ assert lambdify((x, (y, z)), x + y)(1, (2, 4)) == 3
1224
+ assert lambdify((x, (y, (w, z))), w + x + y + z)(1, (2, (3, 4))) == 10
1225
+ assert lambdify(x, x + 1, dummify=False)(1) == 2
1226
+
1227
+
1228
+ def test_issue_12092():
1229
+ f = implemented_function('f', lambda x: x**2)
1230
+ assert f(f(2)).evalf() == Float(16)
1231
+
1232
+
1233
+ def test_issue_14911():
1234
+ class Variable(sympy.Symbol):
1235
+ def _sympystr(self, printer):
1236
+ return printer.doprint(self.name)
1237
+
1238
+ _lambdacode = _sympystr
1239
+ _numpycode = _sympystr
1240
+
1241
+ x = Variable('x')
1242
+ y = 2 * x
1243
+ code = LambdaPrinter().doprint(y)
1244
+ assert code.replace(' ', '') == '2*x'
1245
+
1246
+
1247
+ def test_ITE():
1248
+ assert lambdify((x, y, z), ITE(x, y, z))(True, 5, 3) == 5
1249
+ assert lambdify((x, y, z), ITE(x, y, z))(False, 5, 3) == 3
1250
+
1251
+
1252
+ def test_Min_Max():
1253
+ # see gh-10375
1254
+ assert lambdify((x, y, z), Min(x, y, z))(1, 2, 3) == 1
1255
+ assert lambdify((x, y, z), Max(x, y, z))(1, 2, 3) == 3
1256
+
1257
+
1258
+ def test_amin_amax_minimum_maximum():
1259
+ if not numpy:
1260
+ skip("numpy not installed")
1261
+
1262
+ a234 = numpy.array([2, 3, 4])
1263
+ a152 = numpy.array([1, 5, 2])
1264
+
1265
+ a254 = numpy.array([2, 5, 4])
1266
+ a132 = numpy.array([1, 3, 2])
1267
+ # 2 args
1268
+ assert numpy.all(lambdify((x, y), maximum(x, y))(a234, a152) == a254)
1269
+ assert numpy.all(lambdify((x, y), minimum(x, y))(a234, a152) == a132)
1270
+
1271
+ # 3 args
1272
+ assert numpy.all(lambdify((x, y, z), maximum(x, y, z))(a234, a152, a234) == a254)
1273
+ assert numpy.all(lambdify((x, y, z), minimum(x, y, z))(a234, a152, a234) == a132)
1274
+
1275
+ # 1 arg
1276
+ assert numpy.all(lambdify((x,), maximum(x))(a234) == a234)
1277
+ assert numpy.all(lambdify((x,), minimum(x))(a234) == a234)
1278
+
1279
+ # 4 args, mixed length
1280
+ assert numpy.all(lambdify((x, y, z, w), maximum(x, y, z, w))(a234, a152, a234, 3) == [3, 5, 4])
1281
+ assert numpy.all(lambdify((x, y, z, w), minimum(x, y, z, w))(a234, a152, a234, 2) == [1, 2, 2])
1282
+
1283
+ # amin & amax
1284
+ assert lambdify((x, y), [amin(x), amax(y)])(a234, a152) == [2, 5]
1285
+ A = numpy.array([
1286
+ [0, 4, 8],
1287
+ [1, 5, 9],
1288
+ [2, 6, 10],
1289
+ ])
1290
+ min_, max_ = lambdify((x,), [amin(x, axis=0), amax(x, axis=1)])(A)
1291
+ assert numpy.all(min_ == numpy.amin(A, axis=0))
1292
+ assert numpy.all(max_ == numpy.amax(A, axis=1))
1293
+
1294
+ # see gh-25659
1295
+ assert numpy.all(lambdify((x, y), Max(x, y))([1, 2, 3], [3, 2, 1]) == [3, 2, 3])
1296
+ assert numpy.all(lambdify((x), Min(2, x))([1, 2, 3]) == [1, 2, 2])
1297
+
1298
+
1299
+
1300
+ def test_Indexed():
1301
+ # Issue #10934
1302
+ if not numpy:
1303
+ skip("numpy not installed")
1304
+
1305
+ a = IndexedBase('a')
1306
+ i, j = symbols('i j')
1307
+ b = numpy.array([[1, 2], [3, 4]])
1308
+ assert lambdify(a, Sum(a[x, y], (x, 0, 1), (y, 0, 1)))(b) == 10
1309
+
1310
+ def test_Sum():
1311
+ e = Sum(z, (y, 0, x), (x, 0, 10))
1312
+ ref = 66*z
1313
+ assert e.doit() == ref
1314
+ assert lambdify([z], e)(7) == ref.subs(z, 7)
1315
+
1316
+ def test_Idx():
1317
+ # Issue 26888
1318
+ a = IndexedBase('a')
1319
+ i = Idx('i')
1320
+ b = [1,2,3]
1321
+ assert lambdify([a, i], a[i])(b, 2) == 3
1322
+
1323
+
1324
+ def test_issue_12173():
1325
+ #test for issue 12173
1326
+ expr1 = lambdify((x, y), uppergamma(x, y),"mpmath")(1, 2)
1327
+ expr2 = lambdify((x, y), lowergamma(x, y),"mpmath")(1, 2)
1328
+ assert expr1 == uppergamma(1, 2).evalf()
1329
+ assert expr2 == lowergamma(1, 2).evalf()
1330
+
1331
+
1332
+ def test_issue_13642():
1333
+ if not numpy:
1334
+ skip("numpy not installed")
1335
+ f = lambdify(x, sinc(x))
1336
+ assert Abs(f(1) - sinc(1)).n() < 1e-15
1337
+
1338
+
1339
+ def test_sinc_mpmath():
1340
+ f = lambdify(x, sinc(x), "mpmath")
1341
+ assert Abs(f(1) - sinc(1)).n() < 1e-15
1342
+
1343
+
1344
+ def test_lambdify_dummy_arg():
1345
+ d1 = Dummy()
1346
+ f1 = lambdify(d1, d1 + 1, dummify=False)
1347
+ assert f1(2) == 3
1348
+ f1b = lambdify(d1, d1 + 1)
1349
+ assert f1b(2) == 3
1350
+ d2 = Dummy('x')
1351
+ f2 = lambdify(d2, d2 + 1)
1352
+ assert f2(2) == 3
1353
+ f3 = lambdify([[d2]], d2 + 1)
1354
+ assert f3([2]) == 3
1355
+
1356
+
1357
+ def test_lambdify_mixed_symbol_dummy_args():
1358
+ d = Dummy()
1359
+ # Contrived example of name clash
1360
+ dsym = symbols(str(d))
1361
+ f = lambdify([d, dsym], d - dsym)
1362
+ assert f(4, 1) == 3
1363
+
1364
+
1365
+ def test_numpy_array_arg():
1366
+ # Test for issue 14655 (numpy part)
1367
+ if not numpy:
1368
+ skip("numpy not installed")
1369
+
1370
+ f = lambdify([[x, y]], x*x + y, 'numpy')
1371
+
1372
+ assert f(numpy.array([2.0, 1.0])) == 5
1373
+
1374
+
1375
+ def test_scipy_fns():
1376
+ if not scipy:
1377
+ skip("scipy not installed")
1378
+
1379
+ single_arg_sympy_fns = [Ei, erf, erfc, factorial, gamma, loggamma, digamma, Si, Ci]
1380
+ single_arg_scipy_fns = [scipy.special.expi, scipy.special.erf, scipy.special.erfc,
1381
+ scipy.special.factorial, scipy.special.gamma, scipy.special.gammaln,
1382
+ scipy.special.psi, scipy.special.sici, scipy.special.sici]
1383
+ numpy.random.seed(0)
1384
+ for (sympy_fn, scipy_fn) in zip(single_arg_sympy_fns, single_arg_scipy_fns):
1385
+ f = lambdify(x, sympy_fn(x), modules="scipy")
1386
+ for i in range(20):
1387
+ tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
1388
+ # SciPy thinks that factorial(z) is 0 when re(z) < 0 and
1389
+ # does not support complex numbers.
1390
+ # SymPy does not think so.
1391
+ if sympy_fn == factorial:
1392
+ tv = numpy.abs(tv)
1393
+ # SciPy supports gammaln for real arguments only,
1394
+ # and there is also a branch cut along the negative real axis
1395
+ if sympy_fn == loggamma:
1396
+ tv = numpy.abs(tv)
1397
+ # SymPy's digamma evaluates as polygamma(0, z)
1398
+ # which SciPy supports for real arguments only
1399
+ if sympy_fn == digamma:
1400
+ tv = numpy.real(tv)
1401
+ sympy_result = sympy_fn(tv).evalf()
1402
+ scipy_result = scipy_fn(tv)
1403
+ # SciPy's sici returns a tuple with both Si and Ci present in it
1404
+ # which needs to be unpacked
1405
+ if sympy_fn == Si:
1406
+ scipy_result = scipy_fn(tv)[0]
1407
+ if sympy_fn == Ci:
1408
+ scipy_result = scipy_fn(tv)[1]
1409
+ assert abs(f(tv) - sympy_result) < 1e-13*(1 + abs(sympy_result))
1410
+ assert abs(f(tv) - scipy_result) < 1e-13*(1 + abs(sympy_result))
1411
+
1412
+ double_arg_sympy_fns = [RisingFactorial, besselj, bessely, besseli,
1413
+ besselk, polygamma]
1414
+ double_arg_scipy_fns = [scipy.special.poch, scipy.special.jv,
1415
+ scipy.special.yv, scipy.special.iv, scipy.special.kv, scipy.special.polygamma]
1416
+ for (sympy_fn, scipy_fn) in zip(double_arg_sympy_fns, double_arg_scipy_fns):
1417
+ f = lambdify((x, y), sympy_fn(x, y), modules="scipy")
1418
+ for i in range(20):
1419
+ # SciPy supports only real orders of Bessel functions
1420
+ tv1 = numpy.random.uniform(-10, 10)
1421
+ tv2 = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
1422
+ # SciPy requires a real valued 2nd argument for: poch, polygamma
1423
+ if sympy_fn in (RisingFactorial, polygamma):
1424
+ tv2 = numpy.real(tv2)
1425
+ if sympy_fn == polygamma:
1426
+ tv1 = abs(int(tv1)) # first argument to polygamma must be a non-negative integer.
1427
+ sympy_result = sympy_fn(tv1, tv2).evalf()
1428
+ assert abs(f(tv1, tv2) - sympy_result) < 1e-13*(1 + abs(sympy_result))
1429
+ assert abs(f(tv1, tv2) - scipy_fn(tv1, tv2)) < 1e-13*(1 + abs(sympy_result))
1430
+
1431
+
1432
+ def test_scipy_polys():
1433
+ if not scipy:
1434
+ skip("scipy not installed")
1435
+ numpy.random.seed(0)
1436
+
1437
+ params = symbols('n k a b')
1438
+ # list polynomials with the number of parameters
1439
+ polys = [
1440
+ (chebyshevt, 1),
1441
+ (chebyshevu, 1),
1442
+ (legendre, 1),
1443
+ (hermite, 1),
1444
+ (laguerre, 1),
1445
+ (gegenbauer, 2),
1446
+ (assoc_legendre, 2),
1447
+ (assoc_laguerre, 2),
1448
+ (jacobi, 3)
1449
+ ]
1450
+
1451
+ msg = \
1452
+ "The random test of the function {func} with the arguments " \
1453
+ "{args} had failed because the SymPy result {sympy_result} " \
1454
+ "and SciPy result {scipy_result} had failed to converge " \
1455
+ "within the tolerance {tol} " \
1456
+ "(Actual absolute difference : {diff})"
1457
+
1458
+ for sympy_fn, num_params in polys:
1459
+ args = params[:num_params] + (x,)
1460
+ f = lambdify(args, sympy_fn(*args))
1461
+ for _ in range(10):
1462
+ tn = numpy.random.randint(3, 10)
1463
+ tparams = tuple(numpy.random.uniform(0, 5, size=num_params-1))
1464
+ tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
1465
+ # SciPy supports hermite for real arguments only
1466
+ if sympy_fn == hermite:
1467
+ tv = numpy.real(tv)
1468
+ # assoc_legendre needs x in (-1, 1) and integer param at most n
1469
+ if sympy_fn == assoc_legendre:
1470
+ tv = numpy.random.uniform(-1, 1)
1471
+ tparams = tuple(numpy.random.randint(1, tn, size=1))
1472
+
1473
+ vals = (tn,) + tparams + (tv,)
1474
+ scipy_result = f(*vals)
1475
+ sympy_result = sympy_fn(*vals).evalf()
1476
+ atol = 1e-9*(1 + abs(sympy_result))
1477
+ diff = abs(scipy_result - sympy_result)
1478
+ try:
1479
+ assert diff < atol
1480
+ except TypeError:
1481
+ raise AssertionError(
1482
+ msg.format(
1483
+ func=repr(sympy_fn),
1484
+ args=repr(vals),
1485
+ sympy_result=repr(sympy_result),
1486
+ scipy_result=repr(scipy_result),
1487
+ diff=diff,
1488
+ tol=atol)
1489
+ )
1490
+
1491
+
1492
+ def test_lambdify_inspect():
1493
+ f = lambdify(x, x**2)
1494
+ # Test that inspect.getsource works but don't hard-code implementation
1495
+ # details
1496
+ assert 'x**2' in inspect.getsource(f)
1497
+
1498
+
1499
+ def test_issue_14941():
1500
+ x, y = Dummy(), Dummy()
1501
+
1502
+ # test dict
1503
+ f1 = lambdify([x, y], {x: 3, y: 3}, 'sympy')
1504
+ assert f1(2, 3) == {2: 3, 3: 3}
1505
+
1506
+ # test tuple
1507
+ f2 = lambdify([x, y], (y, x), 'sympy')
1508
+ assert f2(2, 3) == (3, 2)
1509
+ f2b = lambdify([], (1,)) # gh-23224
1510
+ assert f2b() == (1,)
1511
+
1512
+ # test list
1513
+ f3 = lambdify([x, y], [y, x], 'sympy')
1514
+ assert f3(2, 3) == [3, 2]
1515
+
1516
+
1517
+ def test_lambdify_Derivative_arg_issue_16468():
1518
+ f = Function('f')(x)
1519
+ fx = f.diff()
1520
+ assert lambdify((f, fx), f + fx)(10, 5) == 15
1521
+ assert eval(lambdastr((f, fx), f/fx))(10, 5) == 2
1522
+ raises(Exception, lambda:
1523
+ eval(lambdastr((f, fx), f/fx, dummify=False)))
1524
+ assert eval(lambdastr((f, fx), f/fx, dummify=True))(10, 5) == 2
1525
+ assert eval(lambdastr((fx, f), f/fx, dummify=True))(S(10), 5) == S.Half
1526
+ assert lambdify(fx, 1 + fx)(41) == 42
1527
+ assert eval(lambdastr(fx, 1 + fx, dummify=True))(41) == 42
1528
+
1529
+
1530
+ def test_lambdify_Derivative_zeta():
1531
+ # This is related to gh-11802 (and to lesser extent gh-26663)
1532
+ expr1 = zeta(x).diff(x, evaluate=False)
1533
+ f1 = lambdify(x, expr1, modules=['mpmath'])
1534
+ ans1 = f1(2)
1535
+ ref1 = (zeta(2+1e-8).evalf()-zeta(2).evalf())/1e-8
1536
+ assert abs(ans1 - ref1)/abs(ref1) < 1e-7
1537
+
1538
+ expr2 = zeta(x**2).diff(x)
1539
+ f2 = lambdify(x, expr2, modules=['mpmath'])
1540
+ ans2 = f2(2**0.5)
1541
+ ref2 = 2*2**0.5*ref1
1542
+ assert abs(ans2-ref2)/abs(ref2) < 1e-7
1543
+
1544
+
1545
+ def test_lambdify_Derivative_custom_printer():
1546
+ func1 = Function('func1')
1547
+ func2 = Function('func2')
1548
+
1549
+ class MyPrinter(NumPyPrinter):
1550
+
1551
+ def _print_Derivative_func1(self, args, seq_orders):
1552
+ arg, = args
1553
+ order, = seq_orders
1554
+ return '42'
1555
+
1556
+ expr1 = func1(x).diff(x)
1557
+ raises(PrintMethodNotImplementedError, lambda: lambdify([x], expr1))
1558
+ f1 = lambdify([x], expr1, printer=MyPrinter)
1559
+ assert f1(7) == 42
1560
+
1561
+ expr2 = func2(x).diff(x)
1562
+ raises(PrintMethodNotImplementedError, lambda: lambdify([x], expr2, printer=MyPrinter))
1563
+
1564
+
1565
+ def test_lambdify_derivative_and_functions_as_arguments():
1566
+ # see: https://github.com/sympy/sympy/issues/26663#issuecomment-2157179517
1567
+ t, a, b = symbols('t, a, b')
1568
+ f = Function('f')(t)
1569
+ args = f.diff(t, 2), f.diff(t), f, a, b
1570
+ expr1 = a*f.diff(t, 2) + b*f.diff(t) + a*b*f + a**2
1571
+ num_args = 2.0, 3.0, 4.0, 5.0, 6.0
1572
+ ref1 = 5*2 + 6*3 + 5*6*4 + 5**2
1573
+
1574
+ expr2 = a*f.diff(t, 2) + b*f.diff(t) - a*b*f + b**2 - a**2
1575
+ ref2 = 5*2 + 6*3 - 5*6*4 + 6**2 - 5**2
1576
+
1577
+ for dummify, _cse in product([False, None, True], [False, True]):
1578
+ func1 = lambdify(args, expr1, cse=_cse, dummify=dummify)
1579
+ res1 = func1(*num_args)
1580
+ assert abs(res1 - ref1) < 1e-12
1581
+
1582
+ func12 = lambdify(args, [expr1, expr2], cse=_cse, dummify=dummify)
1583
+ res12 = func12(*num_args)
1584
+ assert len(res12) == 2
1585
+ assert abs(res12[0] - ref1) < 1e-12
1586
+ assert abs(res12[1] - ref2) < 1e-12
1587
+
1588
+
1589
+ def test_imag_real():
1590
+ f_re = lambdify([z], sympy.re(z))
1591
+ val = 3+2j
1592
+ assert f_re(val) == val.real
1593
+
1594
+ f_im = lambdify([z], sympy.im(z)) # see #15400
1595
+ assert f_im(val) == val.imag
1596
+
1597
+
1598
+ def test_MatrixSymbol_issue_15578():
1599
+ if not numpy:
1600
+ skip("numpy not installed")
1601
+ A = MatrixSymbol('A', 2, 2)
1602
+ A0 = numpy.array([[1, 2], [3, 4]])
1603
+ f = lambdify(A, A**(-1))
1604
+ assert numpy.allclose(f(A0), numpy.array([[-2., 1.], [1.5, -0.5]]))
1605
+ g = lambdify(A, A**3)
1606
+ assert numpy.allclose(g(A0), numpy.array([[37, 54], [81, 118]]))
1607
+
1608
+
1609
+ def test_issue_15654():
1610
+ if not scipy:
1611
+ skip("scipy not installed")
1612
+ from sympy.abc import n, l, r, Z
1613
+ from sympy.physics import hydrogen
1614
+ nv, lv, rv, Zv = 1, 0, 3, 1
1615
+ sympy_value = hydrogen.R_nl(nv, lv, rv, Zv).evalf()
1616
+ f = lambdify((n, l, r, Z), hydrogen.R_nl(n, l, r, Z))
1617
+ scipy_value = f(nv, lv, rv, Zv)
1618
+ assert abs(sympy_value - scipy_value) < 1e-15
1619
+
1620
+
1621
+ def test_issue_15827():
1622
+ if not numpy:
1623
+ skip("numpy not installed")
1624
+ A = MatrixSymbol("A", 3, 3)
1625
+ B = MatrixSymbol("B", 2, 3)
1626
+ C = MatrixSymbol("C", 3, 4)
1627
+ D = MatrixSymbol("D", 4, 5)
1628
+ k=symbols("k")
1629
+ f = lambdify(A, (2*k)*A)
1630
+ g = lambdify(A, (2+k)*A)
1631
+ h = lambdify(A, 2*A)
1632
+ i = lambdify((B, C, D), 2*B*C*D)
1633
+ assert numpy.array_equal(f(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
1634
+ numpy.array([[2*k, 4*k, 6*k], [2*k, 4*k, 6*k], [2*k, 4*k, 6*k]], dtype=object))
1635
+
1636
+ assert numpy.array_equal(g(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
1637
+ numpy.array([[k + 2, 2*k + 4, 3*k + 6], [k + 2, 2*k + 4, 3*k + 6], \
1638
+ [k + 2, 2*k + 4, 3*k + 6]], dtype=object))
1639
+
1640
+ assert numpy.array_equal(h(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
1641
+ numpy.array([[2, 4, 6], [2, 4, 6], [2, 4, 6]]))
1642
+
1643
+ assert numpy.array_equal(i(numpy.array([[1, 2, 3], [1, 2, 3]]), numpy.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]), \
1644
+ numpy.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])), numpy.array([[ 120, 240, 360, 480, 600], \
1645
+ [ 120, 240, 360, 480, 600]]))
1646
+
1647
+
1648
+ def test_issue_16930():
1649
+ if not scipy:
1650
+ skip("scipy not installed")
1651
+
1652
+ x = symbols("x")
1653
+ f = lambda x: S.GoldenRatio * x**2
1654
+ f_ = lambdify(x, f(x), modules='scipy')
1655
+ assert f_(1) == scipy.constants.golden_ratio
1656
+
1657
+ def test_issue_17898():
1658
+ if not scipy:
1659
+ skip("scipy not installed")
1660
+ x = symbols("x")
1661
+ f_ = lambdify([x], sympy.LambertW(x,-1), modules='scipy')
1662
+ assert f_(0.1) == mpmath.lambertw(0.1, -1)
1663
+
1664
+ def test_issue_13167_21411():
1665
+ if not numpy:
1666
+ skip("numpy not installed")
1667
+ f1 = lambdify(x, sympy.Heaviside(x))
1668
+ f2 = lambdify(x, sympy.Heaviside(x, 1))
1669
+ res1 = f1([-1, 0, 1])
1670
+ res2 = f2([-1, 0, 1])
1671
+ assert Abs(res1[0]).n() < 1e-15 # First functionality: only one argument passed
1672
+ assert Abs(res1[1] - 1/2).n() < 1e-15
1673
+ assert Abs(res1[2] - 1).n() < 1e-15
1674
+ assert Abs(res2[0]).n() < 1e-15 # Second functionality: two arguments passed
1675
+ assert Abs(res2[1] - 1).n() < 1e-15
1676
+ assert Abs(res2[2] - 1).n() < 1e-15
1677
+
1678
+ def test_single_e():
1679
+ f = lambdify(x, E)
1680
+ assert f(23) == exp(1.0)
1681
+
1682
+ def test_issue_16536():
1683
+ if not scipy:
1684
+ skip("scipy not installed")
1685
+
1686
+ a = symbols('a')
1687
+ f1 = lowergamma(a, x)
1688
+ F = lambdify((a, x), f1, modules='scipy')
1689
+ assert abs(lowergamma(1, 3) - F(1, 3)) <= 1e-10
1690
+
1691
+ f2 = uppergamma(a, x)
1692
+ F = lambdify((a, x), f2, modules='scipy')
1693
+ assert abs(uppergamma(1, 3) - F(1, 3)) <= 1e-10
1694
+
1695
+
1696
+ def test_issue_22726():
1697
+ if not numpy:
1698
+ skip("numpy not installed")
1699
+
1700
+ x1, x2 = symbols('x1 x2')
1701
+ f = Max(S.Zero, Min(x1, x2))
1702
+ g = derive_by_array(f, (x1, x2))
1703
+ G = lambdify((x1, x2), g, modules='numpy')
1704
+ point = {x1: 1, x2: 2}
1705
+ assert (abs(g.subs(point) - G(*point.values())) <= 1e-10).all()
1706
+
1707
+
1708
+ def test_issue_22739():
1709
+ if not numpy:
1710
+ skip("numpy not installed")
1711
+
1712
+ x1, x2 = symbols('x1 x2')
1713
+ f = Heaviside(Min(x1, x2))
1714
+ F = lambdify((x1, x2), f, modules='numpy')
1715
+ point = {x1: 1, x2: 2}
1716
+ assert abs(f.subs(point) - F(*point.values())) <= 1e-10
1717
+
1718
+
1719
+ def test_issue_22992():
1720
+ if not numpy:
1721
+ skip("numpy not installed")
1722
+
1723
+ a, t = symbols('a t')
1724
+ expr = a*(log(cot(t/2)) - cos(t))
1725
+ F = lambdify([a, t], expr, 'numpy')
1726
+
1727
+ point = {a: 10, t: 2}
1728
+
1729
+ assert abs(expr.subs(point) - F(*point.values())) <= 1e-10
1730
+
1731
+ # Standard math
1732
+ F = lambdify([a, t], expr)
1733
+
1734
+ assert abs(expr.subs(point) - F(*point.values())) <= 1e-10
1735
+
1736
+
1737
+ def test_issue_19764():
1738
+ if not numpy:
1739
+ skip("numpy not installed")
1740
+
1741
+ expr = Array([x, x**2])
1742
+ f = lambdify(x, expr, 'numpy')
1743
+
1744
+ assert f(1).__class__ == numpy.ndarray
1745
+
1746
+ def test_issue_20070():
1747
+ if not numba:
1748
+ skip("numba not installed")
1749
+
1750
+ f = lambdify(x, sin(x), 'numpy')
1751
+ assert numba.jit(f, nopython=True)(1)==0.8414709848078965
1752
+
1753
+
1754
+ def test_fresnel_integrals_scipy():
1755
+ if not scipy:
1756
+ skip("scipy not installed")
1757
+
1758
+ f1 = fresnelc(x)
1759
+ f2 = fresnels(x)
1760
+ F1 = lambdify(x, f1, modules='scipy')
1761
+ F2 = lambdify(x, f2, modules='scipy')
1762
+
1763
+ assert abs(fresnelc(1.3) - F1(1.3)) <= 1e-10
1764
+ assert abs(fresnels(1.3) - F2(1.3)) <= 1e-10
1765
+
1766
+
1767
+ def test_beta_scipy():
1768
+ if not scipy:
1769
+ skip("scipy not installed")
1770
+
1771
+ f = beta(x, y)
1772
+ F = lambdify((x, y), f, modules='scipy')
1773
+
1774
+ assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10
1775
+
1776
+
1777
+ def test_beta_math():
1778
+ f = beta(x, y)
1779
+ F = lambdify((x, y), f, modules='math')
1780
+
1781
+ assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10
1782
+
1783
+
1784
+ def test_betainc_scipy():
1785
+ if not scipy:
1786
+ skip("scipy not installed")
1787
+
1788
+ f = betainc(w, x, y, z)
1789
+ F = lambdify((w, x, y, z), f, modules='scipy')
1790
+
1791
+ assert abs(betainc(1.4, 3.1, 0.1, 0.5) - F(1.4, 3.1, 0.1, 0.5)) <= 1e-10
1792
+
1793
+
1794
+ def test_betainc_regularized_scipy():
1795
+ if not scipy:
1796
+ skip("scipy not installed")
1797
+
1798
+ f = betainc_regularized(w, x, y, z)
1799
+ F = lambdify((w, x, y, z), f, modules='scipy')
1800
+
1801
+ assert abs(betainc_regularized(0.2, 3.5, 0.1, 1) - F(0.2, 3.5, 0.1, 1)) <= 1e-10
1802
+
1803
+
1804
+ def test_numpy_special_math():
1805
+ if not numpy:
1806
+ skip("numpy not installed")
1807
+
1808
+ funcs = [expm1, log1p, exp2, log2, log10, hypot, logaddexp, logaddexp2]
1809
+ for func in funcs:
1810
+ if 2 in func.nargs:
1811
+ expr = func(x, y)
1812
+ args = (x, y)
1813
+ num_args = (0.3, 0.4)
1814
+ elif 1 in func.nargs:
1815
+ expr = func(x)
1816
+ args = (x,)
1817
+ num_args = (0.3,)
1818
+ else:
1819
+ raise NotImplementedError("Need to handle other than unary & binary functions in test")
1820
+ f = lambdify(args, expr)
1821
+ result = f(*num_args)
1822
+ reference = expr.subs(dict(zip(args, num_args))).evalf()
1823
+ assert numpy.allclose(result, float(reference))
1824
+
1825
+ lae2 = lambdify((x, y), logaddexp2(log2(x), log2(y)))
1826
+ assert abs(2.0**lae2(1e-50, 2.5e-50) - 3.5e-50) < 1e-62 # from NumPy's docstring
1827
+
1828
+
1829
+ def test_scipy_special_math():
1830
+ if not scipy:
1831
+ skip("scipy not installed")
1832
+
1833
+ cm1 = lambdify((x,), cosm1(x), modules='scipy')
1834
+ assert abs(cm1(1e-20) + 5e-41) < 1e-200
1835
+
1836
+ have_scipy_1_10plus = tuple(map(int, scipy.version.version.split('.')[:2])) >= (1, 10)
1837
+
1838
+ if have_scipy_1_10plus:
1839
+ cm2 = lambdify((x, y), powm1(x, y), modules='scipy')
1840
+ assert abs(cm2(1.2, 1e-9) - 1.82321557e-10) < 1e-17
1841
+
1842
+
1843
+ def test_scipy_bernoulli():
1844
+ if not scipy:
1845
+ skip("scipy not installed")
1846
+
1847
+ bern = lambdify((x,), bernoulli(x), modules='scipy')
1848
+ assert bern(1) == 0.5
1849
+
1850
+
1851
+ def test_scipy_harmonic():
1852
+ if not scipy:
1853
+ skip("scipy not installed")
1854
+
1855
+ hn = lambdify((x,), harmonic(x), modules='scipy')
1856
+ assert hn(2) == 1.5
1857
+ hnm = lambdify((x, y), harmonic(x, y), modules='scipy')
1858
+ assert hnm(2, 2) == 1.25
1859
+
1860
+
1861
+ def test_cupy_array_arg():
1862
+ if not cupy:
1863
+ skip("CuPy not installed")
1864
+
1865
+ f = lambdify([[x, y]], x*x + y, 'cupy')
1866
+ result = f(cupy.array([2.0, 1.0]))
1867
+ assert result == 5
1868
+ assert "cupy" in str(type(result))
1869
+
1870
+
1871
+ def test_cupy_array_arg_using_numpy():
1872
+ # numpy functions can be run on cupy arrays
1873
+ # unclear if we can "officially" support this,
1874
+ # depends on numpy __array_function__ support
1875
+ if not cupy:
1876
+ skip("CuPy not installed")
1877
+
1878
+ f = lambdify([[x, y]], x*x + y, 'numpy')
1879
+ result = f(cupy.array([2.0, 1.0]))
1880
+ assert result == 5
1881
+ assert "cupy" in str(type(result))
1882
+
1883
+ def test_cupy_dotproduct():
1884
+ if not cupy:
1885
+ skip("CuPy not installed")
1886
+
1887
+ A = Matrix([x, y, z])
1888
+ f1 = lambdify([x, y, z], DotProduct(A, A), modules='cupy')
1889
+ f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy')
1890
+ f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='cupy')
1891
+ f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy')
1892
+
1893
+ assert f1(1, 2, 3) == \
1894
+ f2(1, 2, 3) == \
1895
+ f3(1, 2, 3) == \
1896
+ f4(1, 2, 3) == \
1897
+ cupy.array([14])
1898
+
1899
+
1900
+ def test_jax_array_arg():
1901
+ if not jax:
1902
+ skip("JAX not installed")
1903
+
1904
+ f = lambdify([[x, y]], x*x + y, 'jax')
1905
+ result = f(jax.numpy.array([2.0, 1.0]))
1906
+ assert result == 5
1907
+ assert "jax" in str(type(result))
1908
+
1909
+
1910
+ def test_jax_array_arg_using_numpy():
1911
+ if not jax:
1912
+ skip("JAX not installed")
1913
+
1914
+ f = lambdify([[x, y]], x*x + y, 'numpy')
1915
+ result = f(jax.numpy.array([2.0, 1.0]))
1916
+ assert result == 5
1917
+ assert "jax" in str(type(result))
1918
+
1919
+
1920
+ def test_jax_dotproduct():
1921
+ if not jax:
1922
+ skip("JAX not installed")
1923
+
1924
+ A = Matrix([x, y, z])
1925
+ f1 = lambdify([x, y, z], DotProduct(A, A), modules='jax')
1926
+ f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax')
1927
+ f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='jax')
1928
+ f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax')
1929
+
1930
+ assert f1(1, 2, 3) == \
1931
+ f2(1, 2, 3) == \
1932
+ f3(1, 2, 3) == \
1933
+ f4(1, 2, 3) == \
1934
+ jax.numpy.array([14])
1935
+
1936
+
1937
+ def test_lambdify_cse():
1938
+ def no_op_cse(exprs):
1939
+ return (), exprs
1940
+
1941
+ def dummy_cse(exprs):
1942
+ from sympy.simplify.cse_main import cse
1943
+ return cse(exprs, symbols=numbered_symbols(cls=Dummy))
1944
+
1945
+ def minmem(exprs):
1946
+ from sympy.simplify.cse_main import cse_release_variables, cse
1947
+ return cse(exprs, postprocess=cse_release_variables)
1948
+
1949
+ class Case:
1950
+ def __init__(self, *, args, exprs, num_args, requires_numpy=False):
1951
+ self.args = args
1952
+ self.exprs = exprs
1953
+ self.num_args = num_args
1954
+ subs_dict = dict(zip(self.args, self.num_args))
1955
+ self.ref = [e.subs(subs_dict).evalf() for e in exprs]
1956
+ self.requires_numpy = requires_numpy
1957
+
1958
+ def lambdify(self, *, cse):
1959
+ return lambdify(self.args, self.exprs, cse=cse)
1960
+
1961
+ def assertAllClose(self, result, *, abstol=1e-15, reltol=1e-15):
1962
+ if self.requires_numpy:
1963
+ assert all(numpy.allclose(result[i], numpy.asarray(r, dtype=float),
1964
+ rtol=reltol, atol=abstol)
1965
+ for i, r in enumerate(self.ref))
1966
+ return
1967
+
1968
+ for i, r in enumerate(self.ref):
1969
+ abs_err = abs(result[i] - r)
1970
+ if r == 0:
1971
+ assert abs_err < abstol
1972
+ else:
1973
+ assert abs_err/abs(r) < reltol
1974
+
1975
+ cases = [
1976
+ Case(
1977
+ args=(x, y, z),
1978
+ exprs=[
1979
+ x + y + z,
1980
+ x + y - z,
1981
+ 2*x + 2*y - z,
1982
+ (x+y)**2 + (y+z)**2,
1983
+ ],
1984
+ num_args=(2., 3., 4.)
1985
+ ),
1986
+ Case(
1987
+ args=(x, y, z),
1988
+ exprs=[
1989
+ x + sympy.Heaviside(x),
1990
+ y + sympy.Heaviside(x),
1991
+ z + sympy.Heaviside(x, 1),
1992
+ z/sympy.Heaviside(x, 1)
1993
+ ],
1994
+ num_args=(0., 3., 4.)
1995
+ ),
1996
+ Case(
1997
+ args=(x, y, z),
1998
+ exprs=[
1999
+ x + sinc(y),
2000
+ y + sinc(y),
2001
+ z - sinc(y)
2002
+ ],
2003
+ num_args=(0.1, 0.2, 0.3)
2004
+ ),
2005
+ Case(
2006
+ args=(x, y, z),
2007
+ exprs=[
2008
+ Matrix([[x, x*y], [sin(z) + 4, x**z]]),
2009
+ x*y+sin(z)-x**z,
2010
+ Matrix([x*x, sin(z), x**z])
2011
+ ],
2012
+ num_args=(1.,2.,3.),
2013
+ requires_numpy=True
2014
+ ),
2015
+ Case(
2016
+ args=(x, y),
2017
+ exprs=[(x + y - 1)**2, x, x + y,
2018
+ (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)],
2019
+ num_args=(1,2)
2020
+ )
2021
+ ]
2022
+ for case in cases:
2023
+ if not numpy and case.requires_numpy:
2024
+ continue
2025
+ for _cse in [False, True, minmem, no_op_cse, dummy_cse]:
2026
+ f = case.lambdify(cse=_cse)
2027
+ result = f(*case.num_args)
2028
+ case.assertAllClose(result)
2029
+
2030
+ def test_issue_25288():
2031
+ syms = numbered_symbols(cls=Dummy)
2032
+ ok = lambdify(x, [x**2, sin(x**2)], cse=lambda e: cse(e, symbols=syms))(2)
2033
+ assert ok
2034
+
2035
+
2036
+ def test_deprecated_set():
2037
+ with warns_deprecated_sympy():
2038
+ lambdify({x, y}, x + y)
2039
+
2040
+ def test_issue_13881():
2041
+ if not numpy:
2042
+ skip("numpy not installed.")
2043
+
2044
+ X = MatrixSymbol('X', 3, 1)
2045
+
2046
+ f = lambdify(X, X.T*X, 'numpy')
2047
+ assert f(numpy.array([1, 2, 3])) == 14
2048
+ assert f(numpy.array([3, 2, 1])) == 14
2049
+
2050
+ f = lambdify(X, X*X.T, 'numpy')
2051
+ assert f(numpy.array([1, 2, 3])) == 14
2052
+ assert f(numpy.array([3, 2, 1])) == 14
2053
+
2054
+ f = lambdify(X, (X*X.T)*X, 'numpy')
2055
+ arr1 = numpy.array([[1], [2], [3]])
2056
+ arr2 = numpy.array([[14],[28],[42]])
2057
+
2058
+ assert numpy.array_equal(f(arr1), arr2)
2059
+
2060
+
2061
+ def test_23536_lambdify_cse_dummy():
2062
+
2063
+ f = Function('x')(y)
2064
+ g = Function('w')(y)
2065
+ expr = z + (f**4 + g**5)*(f**3 + (g*f)**3)
2066
+ expr = expr.expand()
2067
+ eval_expr = lambdify(((f, g), z), expr, cse=True)
2068
+ ans = eval_expr((1.0, 2.0), 3.0) # shouldn't raise NameError
2069
+ assert ans == 300.0 # not a list and value is 300
2070
+
2071
+
2072
+ class LambdifyDocstringTestCase:
2073
+ SIGNATURE = None
2074
+ EXPR = None
2075
+ SRC = None
2076
+
2077
+ def __init__(self, docstring_limit, expected_redacted):
2078
+ self.docstring_limit = docstring_limit
2079
+ self.expected_redacted = expected_redacted
2080
+
2081
+ @property
2082
+ def expected_expr(self):
2083
+ expr_redacted_msg = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
2084
+ return self.EXPR if not self.expected_redacted else expr_redacted_msg
2085
+
2086
+ @property
2087
+ def expected_src(self):
2088
+ src_redacted_msg = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
2089
+ return self.SRC if not self.expected_redacted else src_redacted_msg
2090
+
2091
+ @property
2092
+ def expected_docstring(self):
2093
+ expected_docstring = (
2094
+ f'Created with lambdify. Signature:\n\n'
2095
+ f'func({self.SIGNATURE})\n\n'
2096
+ f'Expression:\n\n'
2097
+ f'{self.expected_expr}\n\n'
2098
+ f'Source code:\n\n'
2099
+ f'{self.expected_src}\n\n'
2100
+ f'Imported modules:\n\n'
2101
+ )
2102
+ return expected_docstring
2103
+
2104
+ def __len__(self):
2105
+ return len(self.expected_docstring)
2106
+
2107
+ def __repr__(self):
2108
+ return (
2109
+ f'{self.__class__.__name__}('
2110
+ f'docstring_limit={self.docstring_limit}, '
2111
+ f'expected_redacted={self.expected_redacted})'
2112
+ )
2113
+
2114
+
2115
+ def test_lambdify_docstring_size_limit_simple_symbol():
2116
+
2117
+ class SimpleSymbolTestCase(LambdifyDocstringTestCase):
2118
+ SIGNATURE = 'x'
2119
+ EXPR = 'x'
2120
+ SRC = (
2121
+ 'def _lambdifygenerated(x):\n'
2122
+ ' return x\n'
2123
+ )
2124
+
2125
+ x = symbols('x')
2126
+
2127
+ test_cases = (
2128
+ SimpleSymbolTestCase(docstring_limit=None, expected_redacted=False),
2129
+ SimpleSymbolTestCase(docstring_limit=100, expected_redacted=False),
2130
+ SimpleSymbolTestCase(docstring_limit=1, expected_redacted=False),
2131
+ SimpleSymbolTestCase(docstring_limit=0, expected_redacted=True),
2132
+ SimpleSymbolTestCase(docstring_limit=-1, expected_redacted=True),
2133
+ )
2134
+ for test_case in test_cases:
2135
+ lambdified_expr = lambdify(
2136
+ [x],
2137
+ x,
2138
+ 'sympy',
2139
+ docstring_limit=test_case.docstring_limit,
2140
+ )
2141
+ assert lambdified_expr.__doc__ == test_case.expected_docstring
2142
+
2143
+
2144
+ def test_lambdify_docstring_size_limit_nested_expr():
2145
+
2146
+ class ExprListTestCase(LambdifyDocstringTestCase):
2147
+ SIGNATURE = 'x, y, z'
2148
+ EXPR = (
2149
+ '[x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z '
2150
+ '+ 3*x*z**2 +...'
2151
+ )
2152
+ SRC = (
2153
+ 'def _lambdifygenerated(x, y, z):\n'
2154
+ ' return [x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 '
2155
+ '+ 6*x*y*z + 3*x*z**2 + y**3 + 3*y**2*z + 3*y*z**2 + z**3]\n'
2156
+ )
2157
+
2158
+ x, y, z = symbols('x, y, z')
2159
+ expr = [x, [y], z, ((x + y + z)**3).expand()]
2160
+
2161
+ test_cases = (
2162
+ ExprListTestCase(docstring_limit=None, expected_redacted=False),
2163
+ ExprListTestCase(docstring_limit=200, expected_redacted=False),
2164
+ ExprListTestCase(docstring_limit=50, expected_redacted=True),
2165
+ ExprListTestCase(docstring_limit=0, expected_redacted=True),
2166
+ ExprListTestCase(docstring_limit=-1, expected_redacted=True),
2167
+ )
2168
+ for test_case in test_cases:
2169
+ lambdified_expr = lambdify(
2170
+ [x, y, z],
2171
+ expr,
2172
+ 'sympy',
2173
+ docstring_limit=test_case.docstring_limit,
2174
+ )
2175
+ assert lambdified_expr.__doc__ == test_case.expected_docstring
2176
+
2177
+
2178
+ def test_lambdify_docstring_size_limit_matrix():
2179
+
2180
+ class MatrixTestCase(LambdifyDocstringTestCase):
2181
+ SIGNATURE = 'x, y, z'
2182
+ EXPR = (
2183
+ 'Matrix([[0, x], [x + y + z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 '
2184
+ '+ 6*x*y*z...'
2185
+ )
2186
+ SRC = (
2187
+ 'def _lambdifygenerated(x, y, z):\n'
2188
+ ' return ImmutableDenseMatrix([[0, x], [x + y + z, x**3 '
2189
+ '+ 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z + 3*x*z**2 + y**3 '
2190
+ '+ 3*y**2*z + 3*y*z**2 + z**3]])\n'
2191
+ )
2192
+
2193
+ x, y, z = symbols('x, y, z')
2194
+ expr = Matrix([[S.Zero, x], [x + y + z, ((x + y + z)**3).expand()]])
2195
+
2196
+ test_cases = (
2197
+ MatrixTestCase(docstring_limit=None, expected_redacted=False),
2198
+ MatrixTestCase(docstring_limit=200, expected_redacted=False),
2199
+ MatrixTestCase(docstring_limit=50, expected_redacted=True),
2200
+ MatrixTestCase(docstring_limit=0, expected_redacted=True),
2201
+ MatrixTestCase(docstring_limit=-1, expected_redacted=True),
2202
+ )
2203
+ for test_case in test_cases:
2204
+ lambdified_expr = lambdify(
2205
+ [x, y, z],
2206
+ expr,
2207
+ 'sympy',
2208
+ docstring_limit=test_case.docstring_limit,
2209
+ )
2210
+ assert lambdified_expr.__doc__ == test_case.expected_docstring
2211
+
2212
+
2213
+ def test_lambdify_empty_tuple():
2214
+ a = symbols("a")
2215
+ expr = ((), (a,))
2216
+ f = lambdify(a, expr)
2217
+ result = f(1)
2218
+ assert result == ((), (1,)), "Lambdify did not handle the empty tuple correctly."
2219
+
2220
+
2221
+ def test_assoc_legendre_numerical_evaluation():
2222
+
2223
+ tol = 1e-10
2224
+
2225
+ sympy_result_integer = assoc_legendre(1, 1/2, 0.1).evalf()
2226
+ sympy_result_complex = assoc_legendre(2, 1, 3).evalf()
2227
+ mpmath_result_integer = -0.474572528387641
2228
+ mpmath_result_complex = -25.45584412271571*I
2229
+
2230
+ assert all_close(sympy_result_integer, mpmath_result_integer, tol)
2231
+ assert all_close(sympy_result_complex, mpmath_result_complex, tol)
2232
+
2233
+
2234
+ def test_Piecewise():
2235
+
2236
+ modules = [math]
2237
+ if numpy:
2238
+ modules.append('numpy')
2239
+
2240
+ for mod in modules:
2241
+ # test isinf
2242
+ f = lambdify(x, Piecewise((7.0, isinf(x)), (3.0, True)), mod)
2243
+ assert f(+float('inf')) == +7.0
2244
+ assert f(-float('inf')) == +7.0
2245
+ assert f(42.) == 3.0
2246
+
2247
+ f2 = lambdify(x, Piecewise((7.0*sign(x), isinf(x)), (3.0, True)), mod)
2248
+ assert f2(+float('inf')) == +7.0
2249
+ assert f2(-float('inf')) == -7.0
2250
+ assert f2(42.) == 3.0
2251
+
2252
+ # test isnan (gh-26784)
2253
+ g = lambdify(x, Piecewise((7.0, isnan(x)), (3.0, True)), mod)
2254
+ assert g(float('nan')) == 7.0
2255
+ assert g(42.) == 3.0
2256
+
2257
+
2258
+ def test_array_symbol():
2259
+ if not numpy:
2260
+ skip("numpy not installed.")
2261
+ a = ArraySymbol('a', (3,))
2262
+ f = lambdify((a), a)
2263
+ assert numpy.all(f(numpy.array([1,2,3])) == numpy.array([1,2,3]))
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_matchpy_connector.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+
3
+ from sympy.core.relational import (Eq, Ne)
4
+ from sympy.core.singleton import S
5
+ from sympy.core.symbol import symbols
6
+ from sympy.functions.elementary.miscellaneous import sqrt
7
+ from sympy.functions.elementary.trigonometric import (cos, sin)
8
+ from sympy.external import import_module
9
+ from sympy.testing.pytest import skip
10
+ from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar, Replacer
11
+
12
+ matchpy = import_module("matchpy")
13
+
14
+ x, y, z = symbols("x y z")
15
+
16
+
17
+ def _get_first_match(expr, pattern):
18
+ from matchpy import ManyToOneMatcher, Pattern
19
+
20
+ matcher = ManyToOneMatcher()
21
+ matcher.add(Pattern(pattern))
22
+ return next(iter(matcher.match(expr)))
23
+
24
+
25
+ def test_matchpy_connector():
26
+ if matchpy is None:
27
+ skip("matchpy not installed")
28
+
29
+ from multiset import Multiset
30
+ from matchpy import Pattern, Substitution
31
+
32
+ w_ = WildDot("w_")
33
+ w__ = WildPlus("w__")
34
+ w___ = WildStar("w___")
35
+
36
+ expr = x + y
37
+ pattern = x + w_
38
+ p, subst = _get_first_match(expr, pattern)
39
+ assert p == Pattern(pattern)
40
+ assert subst == Substitution({'w_': y})
41
+
42
+ expr = x + y + z
43
+ pattern = x + w__
44
+ p, subst = _get_first_match(expr, pattern)
45
+ assert p == Pattern(pattern)
46
+ assert subst == Substitution({'w__': Multiset([y, z])})
47
+
48
+ expr = x + y + z
49
+ pattern = x + y + z + w___
50
+ p, subst = _get_first_match(expr, pattern)
51
+ assert p == Pattern(pattern)
52
+ assert subst == Substitution({'w___': Multiset()})
53
+
54
+
55
+ def test_matchpy_optional():
56
+ if matchpy is None:
57
+ skip("matchpy not installed")
58
+
59
+ from matchpy import Pattern, Substitution
60
+ from matchpy import ManyToOneReplacer, ReplacementRule
61
+
62
+ p = WildDot("p", optional=1)
63
+ q = WildDot("q", optional=0)
64
+
65
+ pattern = p*x + q
66
+
67
+ expr1 = 2*x
68
+ pa, subst = _get_first_match(expr1, pattern)
69
+ assert pa == Pattern(pattern)
70
+ assert subst == Substitution({'p': 2, 'q': 0})
71
+
72
+ expr2 = x + 3
73
+ pa, subst = _get_first_match(expr2, pattern)
74
+ assert pa == Pattern(pattern)
75
+ assert subst == Substitution({'p': 1, 'q': 3})
76
+
77
+ expr3 = x
78
+ pa, subst = _get_first_match(expr3, pattern)
79
+ assert pa == Pattern(pattern)
80
+ assert subst == Substitution({'p': 1, 'q': 0})
81
+
82
+ expr4 = x*y + z
83
+ pa, subst = _get_first_match(expr4, pattern)
84
+ assert pa == Pattern(pattern)
85
+ assert subst == Substitution({'p': y, 'q': z})
86
+
87
+ replacer = ManyToOneReplacer()
88
+ replacer.add(ReplacementRule(Pattern(pattern), lambda p, q: sin(p)*cos(q)))
89
+ assert replacer.replace(expr1) == sin(2)*cos(0)
90
+ assert replacer.replace(expr2) == sin(1)*cos(3)
91
+ assert replacer.replace(expr3) == sin(1)*cos(0)
92
+ assert replacer.replace(expr4) == sin(y)*cos(z)
93
+
94
+
95
+ def test_replacer():
96
+ if matchpy is None:
97
+ skip("matchpy not installed")
98
+
99
+ for info in [True, False]:
100
+ for lambdify in [True, False]:
101
+ _perform_test_replacer(info, lambdify)
102
+
103
+
104
+ def _perform_test_replacer(info, lambdify):
105
+
106
+ x1_ = WildDot("x1_")
107
+ x2_ = WildDot("x2_")
108
+
109
+ a_ = WildDot("a_", optional=S.One)
110
+ b_ = WildDot("b_", optional=S.One)
111
+ c_ = WildDot("c_", optional=S.Zero)
112
+
113
+ replacer = Replacer(common_constraints=[
114
+ matchpy.CustomConstraint(lambda a_: not a_.has(x)),
115
+ matchpy.CustomConstraint(lambda b_: not b_.has(x)),
116
+ matchpy.CustomConstraint(lambda c_: not c_.has(x)),
117
+ ], lambdify=lambdify, info=info)
118
+
119
+ # Rewrite the equation into implicit form, unless it's already solved:
120
+ replacer.add(Eq(x1_, x2_), Eq(x1_ - x2_, 0), conditions_nonfalse=[Ne(x2_, 0), Ne(x1_, 0), Ne(x1_, x), Ne(x2_, x)], info=1)
121
+
122
+ # Simple equation solver for real numbers:
123
+ replacer.add(Eq(a_*x + b_, 0), Eq(x, -b_/a_), info=2)
124
+ disc = b_**2 - 4*a_*c_
125
+ replacer.add(
126
+ Eq(a_*x**2 + b_*x + c_, 0),
127
+ Eq(x, (-b_ - sqrt(disc))/(2*a_)) | Eq(x, (-b_ + sqrt(disc))/(2*a_)),
128
+ conditions_nonfalse=[disc >= 0],
129
+ info=3
130
+ )
131
+ replacer.add(
132
+ Eq(a_*x**2 + c_, 0),
133
+ Eq(x, sqrt(-c_/a_)) | Eq(x, -sqrt(-c_/a_)),
134
+ conditions_nonfalse=[-c_*a_ > 0],
135
+ info=4
136
+ )
137
+
138
+ g = lambda expr, infos: (expr, infos) if info else expr
139
+
140
+ assert replacer.replace(Eq(3*x, y)) == g(Eq(x, y/3), [1, 2])
141
+ assert replacer.replace(Eq(x**2 + 1, 0)) == g(Eq(x**2 + 1, 0), [])
142
+ assert replacer.replace(Eq(x**2, 4)) == g((Eq(x, 2) | Eq(x, -2)), [1, 4])
143
+ assert replacer.replace(Eq(x**2 + 4*y*x + 4*y**2, 0)) == g(Eq(x, -2*y), [3])
144
+
145
+
146
+ def test_matchpy_object_pickle():
147
+ if matchpy is None:
148
+ return
149
+
150
+ a1 = WildDot("a")
151
+ a2 = pickle.loads(pickle.dumps(a1))
152
+ assert a1 == a2
153
+
154
+ a1 = WildDot("a", S(1))
155
+ a2 = pickle.loads(pickle.dumps(a1))
156
+ assert a1 == a2
157
+
158
+ a1 = WildPlus("a", S(1))
159
+ a2 = pickle.loads(pickle.dumps(a1))
160
+ assert a1 == a2
161
+
162
+ a1 = WildStar("a", S(1))
163
+ a2 = pickle.loads(pickle.dumps(a1))
164
+ assert a1 == a2
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_mathml.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from textwrap import dedent
3
+ from sympy.external import import_module
4
+ from sympy.testing.pytest import skip
5
+ from sympy.utilities.mathml import apply_xsl
6
+
7
+
8
+
9
+ lxml = import_module('lxml')
10
+
11
+ path = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_xxe.py"))
12
+
13
+
14
+ def test_xxe():
15
+ assert os.path.isfile(path)
16
+ if not lxml:
17
+ skip("lxml not installed.")
18
+
19
+ mml = dedent(
20
+ rf"""
21
+ <!--?xml version="1.0" ?-->
22
+ <!DOCTYPE replace [<!ENTITY ent SYSTEM "file://{path}"> ]>
23
+ <userInfo>
24
+ <firstName>John</firstName>
25
+ <lastName>&ent;</lastName>
26
+ </userInfo>
27
+ """
28
+ )
29
+ xsl = 'mathml/data/simple_mmlctop.xsl'
30
+
31
+ res = apply_xsl(mml, xsl)
32
+ assert res == \
33
+ '<?xml version="1.0"?>\n<userInfo>\n<firstName>John</firstName>\n<lastName/>\n</userInfo>\n'
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_misc.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from textwrap import dedent
2
+ import sys
3
+ from subprocess import Popen, PIPE
4
+ import os
5
+
6
+ from sympy.core.singleton import S
7
+ from sympy.testing.pytest import (raises, warns_deprecated_sympy,
8
+ skip_under_pyodide)
9
+ from sympy.utilities.misc import (translate, replace, ordinal, rawlines,
10
+ strlines, as_int, find_executable)
11
+
12
+
13
+ def test_translate():
14
+ abc = 'abc'
15
+ assert translate(abc, None, 'a') == 'bc'
16
+ assert translate(abc, None, '') == 'abc'
17
+ assert translate(abc, {'a': 'x'}, 'c') == 'xb'
18
+ assert translate(abc, {'a': 'bc'}, 'c') == 'bcb'
19
+ assert translate(abc, {'ab': 'x'}, 'c') == 'x'
20
+ assert translate(abc, {'ab': ''}, 'c') == ''
21
+ assert translate(abc, {'bc': 'x'}, 'c') == 'ab'
22
+ assert translate(abc, {'abc': 'x', 'a': 'y'}) == 'x'
23
+ u = chr(4096)
24
+ assert translate(abc, 'a', 'x', u) == 'xbc'
25
+ assert (u in translate(abc, 'a', u, u)) is True
26
+
27
+
28
+ def test_replace():
29
+ assert replace('abc', ('a', 'b')) == 'bbc'
30
+ assert replace('abc', {'a': 'Aa'}) == 'Aabc'
31
+ assert replace('abc', ('a', 'b'), ('c', 'C')) == 'bbC'
32
+
33
+
34
+ def test_ordinal():
35
+ assert ordinal(-1) == '-1st'
36
+ assert ordinal(0) == '0th'
37
+ assert ordinal(1) == '1st'
38
+ assert ordinal(2) == '2nd'
39
+ assert ordinal(3) == '3rd'
40
+ assert all(ordinal(i).endswith('th') for i in range(4, 21))
41
+ assert ordinal(100) == '100th'
42
+ assert ordinal(101) == '101st'
43
+ assert ordinal(102) == '102nd'
44
+ assert ordinal(103) == '103rd'
45
+ assert ordinal(104) == '104th'
46
+ assert ordinal(200) == '200th'
47
+ assert all(ordinal(i) == str(i) + 'th' for i in range(-220, -203))
48
+
49
+
50
+ def test_rawlines():
51
+ assert rawlines('a a\na') == "dedent('''\\\n a a\n a''')"
52
+ assert rawlines('a a') == "'a a'"
53
+ assert rawlines(strlines('\\le"ft')) == (
54
+ '(\n'
55
+ " '(\\n'\n"
56
+ ' \'r\\\'\\\\le"ft\\\'\\n\'\n'
57
+ " ')'\n"
58
+ ')')
59
+
60
+
61
+ def test_strlines():
62
+ q = 'this quote (") is in the middle'
63
+ # the following assert rhs was prepared with
64
+ # print(rawlines(strlines(q, 10)))
65
+ assert strlines(q, 10) == dedent('''\
66
+ (
67
+ 'this quo'
68
+ 'te (") i'
69
+ 's in the'
70
+ ' middle'
71
+ )''')
72
+ assert q == (
73
+ 'this quo'
74
+ 'te (") i'
75
+ 's in the'
76
+ ' middle'
77
+ )
78
+ q = "this quote (') is in the middle"
79
+ assert strlines(q, 20) == dedent('''\
80
+ (
81
+ "this quote (') is "
82
+ "in the middle"
83
+ )''')
84
+ assert strlines('\\left') == (
85
+ '(\n'
86
+ "r'\\left'\n"
87
+ ')')
88
+ assert strlines('\\left', short=True) == r"r'\left'"
89
+ assert strlines('\\le"ft') == (
90
+ '(\n'
91
+ 'r\'\\le"ft\'\n'
92
+ ')')
93
+ q = 'this\nother line'
94
+ assert strlines(q) == rawlines(q)
95
+
96
+
97
+ def test_translate_args():
98
+ try:
99
+ translate(None, None, None, 'not_none')
100
+ except ValueError:
101
+ pass # Exception raised successfully
102
+ else:
103
+ assert False
104
+
105
+ assert translate('s', None, None, None) == 's'
106
+
107
+ try:
108
+ translate('s', 'a', 'bc')
109
+ except ValueError:
110
+ pass # Exception raised successfully
111
+ else:
112
+ assert False
113
+
114
+
115
+ @skip_under_pyodide("Cannot create subprocess under pyodide.")
116
+ def test_debug_output():
117
+ env = os.environ.copy()
118
+ env['SYMPY_DEBUG'] = 'True'
119
+ cmd = 'from sympy import *; x = Symbol("x"); print(integrate((1-cos(x))/x, x))'
120
+ cmdline = [sys.executable, '-c', cmd]
121
+ proc = Popen(cmdline, env=env, stdout=PIPE, stderr=PIPE)
122
+ out, err = proc.communicate()
123
+ out = out.decode('ascii') # utf-8?
124
+ err = err.decode('ascii')
125
+ expected = 'substituted: -x*(1 - cos(x)), u: 1/x, u_var: _u'
126
+ assert expected in err, err
127
+
128
+
129
+ def test_as_int():
130
+ raises(ValueError, lambda : as_int(True))
131
+ raises(ValueError, lambda : as_int(1.1))
132
+ raises(ValueError, lambda : as_int([]))
133
+ raises(ValueError, lambda : as_int(S.NaN))
134
+ raises(ValueError, lambda : as_int(S.Infinity))
135
+ raises(ValueError, lambda : as_int(S.NegativeInfinity))
136
+ raises(ValueError, lambda : as_int(S.ComplexInfinity))
137
+ # for the following, limited precision makes int(arg) == arg
138
+ # but the int value is not necessarily what a user might have
139
+ # expected; Q.prime is more nuanced in its response for
140
+ # expressions which might be complex representations of an
141
+ # integer. This is not -- by design -- as_ints role.
142
+ raises(ValueError, lambda : as_int(1e23))
143
+ raises(ValueError, lambda : as_int(S('1.'+'0'*20+'1')))
144
+ assert as_int(True, strict=False) == 1
145
+
146
+ def test_deprecated_find_executable():
147
+ with warns_deprecated_sympy():
148
+ find_executable('python')
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_pickling.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import copy
3
+ import pickle
4
+
5
+ from sympy.physics.units import meter
6
+
7
+ from sympy.testing.pytest import XFAIL, raises, ignore_warnings
8
+
9
+ from sympy.core.basic import Atom, Basic
10
+ from sympy.core.singleton import SingletonRegistry
11
+ from sympy.core.symbol import Str, Dummy, Symbol, Wild
12
+ from sympy.core.numbers import (E, I, pi, oo, zoo, nan, Integer,
13
+ Rational, Float, AlgebraicNumber)
14
+ from sympy.core.relational import (Equality, GreaterThan, LessThan, Relational,
15
+ StrictGreaterThan, StrictLessThan, Unequality)
16
+ from sympy.core.add import Add
17
+ from sympy.core.mul import Mul
18
+ from sympy.core.power import Pow
19
+ from sympy.core.function import Derivative, Function, FunctionClass, Lambda, \
20
+ WildFunction
21
+ from sympy.sets.sets import Interval
22
+ from sympy.core.multidimensional import vectorize
23
+
24
+ from sympy.external.gmpy import gmpy as _gmpy
25
+ from sympy.utilities.exceptions import SymPyDeprecationWarning
26
+
27
+ from sympy.core.singleton import S
28
+ from sympy.core.symbol import symbols
29
+
30
+ from sympy.external import import_module
31
+ cloudpickle = import_module('cloudpickle')
32
+
33
+
34
+ not_equal_attrs = {
35
+ '_assumptions', # This is a local cache that isn't automatically filled on creation
36
+ '_mhash', # Cached after __hash__ is called but set to None after creation
37
+ }
38
+
39
+
40
+ deprecated_attrs = {
41
+ 'is_EmptySet', # Deprecated from SymPy 1.5. This can be removed when is_EmptySet is removed.
42
+ 'expr_free_symbols', # Deprecated from SymPy 1.9. This can be removed when exr_free_symbols is removed.
43
+ }
44
+
45
+ dont_check_attrs = {
46
+ '_sage_', # Fails because Sage is not installed
47
+ }
48
+
49
+
50
+ def check(a, exclude=[], check_attr=True, deprecated=()):
51
+ """ Check that pickling and copying round-trips.
52
+ """
53
+ # Pickling with protocols 0 and 1 is disabled for Basic instances:
54
+ if isinstance(a, Basic):
55
+ for protocol in [0, 1]:
56
+ raises(NotImplementedError, lambda: pickle.dumps(a, protocol))
57
+
58
+ protocols = [2, copy.copy, copy.deepcopy, 3, 4]
59
+ if cloudpickle:
60
+ protocols.extend([cloudpickle])
61
+
62
+ for protocol in protocols:
63
+ if protocol in exclude:
64
+ continue
65
+
66
+ if callable(protocol):
67
+ if isinstance(a, type):
68
+ # Classes can't be copied, but that's okay.
69
+ continue
70
+ b = protocol(a)
71
+ elif inspect.ismodule(protocol):
72
+ b = protocol.loads(protocol.dumps(a))
73
+ else:
74
+ b = pickle.loads(pickle.dumps(a, protocol))
75
+
76
+ d1 = dir(a)
77
+ d2 = dir(b)
78
+ assert set(d1) == set(d2)
79
+
80
+ if not check_attr:
81
+ continue
82
+
83
+ def c(a, b, d):
84
+ for i in d:
85
+ if i in dont_check_attrs:
86
+ continue
87
+ elif i in not_equal_attrs:
88
+ if hasattr(a, i):
89
+ assert hasattr(b, i), i
90
+ elif i in deprecated_attrs or i in deprecated:
91
+ with ignore_warnings(SymPyDeprecationWarning):
92
+ assert getattr(a, i) == getattr(b, i), i
93
+ elif not hasattr(a, i):
94
+ continue
95
+ else:
96
+ attr = getattr(a, i)
97
+ if not hasattr(attr, "__call__"):
98
+ assert hasattr(b, i), i
99
+ assert getattr(b, i) == attr, "%s != %s, protocol: %s" % (getattr(b, i), attr, protocol)
100
+
101
+ c(a, b, d1)
102
+ c(b, a, d2)
103
+
104
+
105
+
106
+ #================== core =========================
107
+
108
+
109
+ def test_core_basic():
110
+ for c in (Atom, Atom(), Basic, Basic(), SingletonRegistry, S):
111
+ check(c)
112
+
113
+ def test_core_Str():
114
+ check(Str('x'))
115
+
116
+ def test_core_symbol():
117
+ # make the Symbol a unique name that doesn't class with any other
118
+ # testing variable in this file since after this test the symbol
119
+ # having the same name will be cached as noncommutative
120
+ for c in (Dummy, Dummy("x", commutative=False), Symbol,
121
+ Symbol("_issue_3130", commutative=False), Wild, Wild("x")):
122
+ check(c)
123
+
124
+
125
+ def test_core_numbers():
126
+ for c in (Integer(2), Rational(2, 3), Float("1.2")):
127
+ check(c)
128
+ for c in (AlgebraicNumber, AlgebraicNumber(sqrt(3))):
129
+ check(c, check_attr=False)
130
+
131
+
132
+ def test_core_float_copy():
133
+ # See gh-7457
134
+ y = Symbol("x") + 1.0
135
+ check(y) # does not raise TypeError ("argument is not an mpz")
136
+
137
+
138
+ def test_core_relational():
139
+ x = Symbol("x")
140
+ y = Symbol("y")
141
+ for c in (Equality, Equality(x, y), GreaterThan, GreaterThan(x, y),
142
+ LessThan, LessThan(x, y), Relational, Relational(x, y),
143
+ StrictGreaterThan, StrictGreaterThan(x, y), StrictLessThan,
144
+ StrictLessThan(x, y), Unequality, Unequality(x, y)):
145
+ check(c)
146
+
147
+
148
+ def test_core_add():
149
+ x = Symbol("x")
150
+ for c in (Add, Add(x, 4)):
151
+ check(c)
152
+
153
+
154
+ def test_core_mul():
155
+ x = Symbol("x")
156
+ for c in (Mul, Mul(x, 4)):
157
+ check(c)
158
+
159
+
160
+ def test_core_power():
161
+ x = Symbol("x")
162
+ for c in (Pow, Pow(x, 4)):
163
+ check(c)
164
+
165
+
166
+ def test_core_function():
167
+ x = Symbol("x")
168
+ for f in (Derivative, Derivative(x), Function, FunctionClass, Lambda,
169
+ WildFunction):
170
+ check(f)
171
+
172
+
173
+ def test_core_undefinedfunctions():
174
+ f = Function("f")
175
+ check(f)
176
+
177
+
178
+ def test_core_appliedundef():
179
+ x = Symbol("_long_unique_name_1")
180
+ f = Function("_long_unique_name_2")
181
+ check(f(x))
182
+
183
+
184
+ def test_core_interval():
185
+ for c in (Interval, Interval(0, 2)):
186
+ check(c)
187
+
188
+
189
+ def test_core_multidimensional():
190
+ for c in (vectorize, vectorize(0)):
191
+ check(c)
192
+
193
+
194
+ def test_Singletons():
195
+ protocols = [0, 1, 2, 3, 4]
196
+ copiers = [copy.copy, copy.deepcopy]
197
+ copiers += [lambda x: pickle.loads(pickle.dumps(x, proto))
198
+ for proto in protocols]
199
+ if cloudpickle:
200
+ copiers += [lambda x: cloudpickle.loads(cloudpickle.dumps(x))]
201
+
202
+ for obj in (Integer(-1), Integer(0), Integer(1), Rational(1, 2), pi, E, I,
203
+ oo, -oo, zoo, nan, S.GoldenRatio, S.TribonacciConstant,
204
+ S.EulerGamma, S.Catalan, S.EmptySet, S.IdentityFunction):
205
+ for func in copiers:
206
+ assert func(obj) is obj
207
+
208
+ #================== combinatorics ===================
209
+ from sympy.combinatorics.free_groups import FreeGroup
210
+
211
+ def test_free_group():
212
+ check(FreeGroup("x, y, z"), check_attr=False)
213
+
214
+ #================== functions ===================
215
+ from sympy.functions import (Piecewise, lowergamma, acosh, chebyshevu,
216
+ chebyshevt, ln, chebyshevt_root, legendre, Heaviside, bernoulli, coth,
217
+ tanh, assoc_legendre, sign, arg, asin, DiracDelta, re, rf, Abs,
218
+ uppergamma, binomial, sinh, cos, cot, acos, acot, gamma, bell,
219
+ hermite, harmonic, LambertW, zeta, log, factorial, asinh, acoth, cosh,
220
+ dirichlet_eta, Eijk, loggamma, erf, ceiling, im, fibonacci,
221
+ tribonacci, conjugate, tan, chebyshevu_root, floor, atanh, sqrt, sin,
222
+ atan, ff, lucas, atan2, polygamma, exp)
223
+
224
+
225
+ def test_functions():
226
+ one_var = (acosh, ln, Heaviside, factorial, bernoulli, coth, tanh,
227
+ sign, arg, asin, DiracDelta, re, Abs, sinh, cos, cot, acos, acot,
228
+ gamma, bell, harmonic, LambertW, zeta, log, factorial, asinh,
229
+ acoth, cosh, dirichlet_eta, loggamma, erf, ceiling, im, fibonacci,
230
+ tribonacci, conjugate, tan, floor, atanh, sin, atan, lucas, exp)
231
+ two_var = (rf, ff, lowergamma, chebyshevu, chebyshevt, binomial,
232
+ atan2, polygamma, hermite, legendre, uppergamma)
233
+ x, y, z = symbols("x,y,z")
234
+ others = (chebyshevt_root, chebyshevu_root, Eijk(x, y, z),
235
+ Piecewise( (0, x < -1), (x**2, x <= 1), (x**3, True)),
236
+ assoc_legendre)
237
+ for cls in one_var:
238
+ check(cls)
239
+ c = cls(x)
240
+ check(c)
241
+ for cls in two_var:
242
+ check(cls)
243
+ c = cls(x, y)
244
+ check(c)
245
+ for cls in others:
246
+ check(cls)
247
+
248
+ #================== geometry ====================
249
+ from sympy.geometry.entity import GeometryEntity
250
+ from sympy.geometry.point import Point
251
+ from sympy.geometry.ellipse import Circle, Ellipse
252
+ from sympy.geometry.line import Line, LinearEntity, Ray, Segment
253
+ from sympy.geometry.polygon import Polygon, RegularPolygon, Triangle
254
+
255
+
256
+ def test_geometry():
257
+ p1 = Point(1, 2)
258
+ p2 = Point(2, 3)
259
+ p3 = Point(0, 0)
260
+ p4 = Point(0, 1)
261
+ for c in (
262
+ GeometryEntity, GeometryEntity(), Point, p1, Circle, Circle(p1, 2),
263
+ Ellipse, Ellipse(p1, 3, 4), Line, Line(p1, p2), LinearEntity,
264
+ LinearEntity(p1, p2), Ray, Ray(p1, p2), Segment, Segment(p1, p2),
265
+ Polygon, Polygon(p1, p2, p3, p4), RegularPolygon,
266
+ RegularPolygon(p1, 4, 5), Triangle, Triangle(p1, p2, p3)):
267
+ check(c, check_attr=False)
268
+
269
+ #================== integrals ====================
270
+ from sympy.integrals.integrals import Integral
271
+
272
+
273
+ def test_integrals():
274
+ x = Symbol("x")
275
+ for c in (Integral, Integral(x)):
276
+ check(c)
277
+
278
+ #==================== logic =====================
279
+ from sympy.core.logic import Logic
280
+
281
+
282
+ def test_logic():
283
+ for c in (Logic, Logic(1)):
284
+ check(c)
285
+
286
+ #================== matrices ====================
287
+ from sympy.matrices import Matrix, SparseMatrix
288
+
289
+
290
+ def test_matrices():
291
+ for c in (Matrix, Matrix([1, 2, 3]), SparseMatrix, SparseMatrix([[1, 2], [3, 4]])):
292
+ check(c, deprecated=['_smat', '_mat'])
293
+
294
+ #================== ntheory =====================
295
+ from sympy.ntheory.generate import Sieve
296
+
297
+
298
+ def test_ntheory():
299
+ for c in (Sieve, Sieve()):
300
+ check(c)
301
+
302
+ #================== physics =====================
303
+ from sympy.physics.paulialgebra import Pauli
304
+ from sympy.physics.units import Unit
305
+
306
+
307
+ def test_physics():
308
+ for c in (Unit, meter, Pauli, Pauli(1)):
309
+ check(c)
310
+
311
+ #================== plotting ====================
312
+ # XXX: These tests are not complete, so XFAIL them
313
+
314
+
315
+ @XFAIL
316
+ def test_plotting():
317
+ from sympy.plotting.pygletplot.color_scheme import ColorGradient, ColorScheme
318
+ from sympy.plotting.pygletplot.managed_window import ManagedWindow
319
+ from sympy.plotting.plot import Plot, ScreenShot
320
+ from sympy.plotting.pygletplot.plot_axes import PlotAxes, PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate
321
+ from sympy.plotting.pygletplot.plot_camera import PlotCamera
322
+ from sympy.plotting.pygletplot.plot_controller import PlotController
323
+ from sympy.plotting.pygletplot.plot_curve import PlotCurve
324
+ from sympy.plotting.pygletplot.plot_interval import PlotInterval
325
+ from sympy.plotting.pygletplot.plot_mode import PlotMode
326
+ from sympy.plotting.pygletplot.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \
327
+ ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical
328
+ from sympy.plotting.pygletplot.plot_object import PlotObject
329
+ from sympy.plotting.pygletplot.plot_surface import PlotSurface
330
+ from sympy.plotting.pygletplot.plot_window import PlotWindow
331
+ for c in (
332
+ ColorGradient, ColorGradient(0.2, 0.4), ColorScheme, ManagedWindow,
333
+ ManagedWindow, Plot, ScreenShot, PlotAxes, PlotAxesBase,
334
+ PlotAxesFrame, PlotAxesOrdinate, PlotCamera, PlotController,
335
+ PlotCurve, PlotInterval, PlotMode, Cartesian2D, Cartesian3D,
336
+ Cylindrical, ParametricCurve2D, ParametricCurve3D,
337
+ ParametricSurface, Polar, Spherical, PlotObject, PlotSurface,
338
+ PlotWindow):
339
+ check(c)
340
+
341
+
342
+ @XFAIL
343
+ def test_plotting2():
344
+ #from sympy.plotting.color_scheme import ColorGradient
345
+ from sympy.plotting.pygletplot.color_scheme import ColorScheme
346
+ #from sympy.plotting.managed_window import ManagedWindow
347
+ from sympy.plotting.plot import Plot
348
+ #from sympy.plotting.plot import ScreenShot
349
+ from sympy.plotting.pygletplot.plot_axes import PlotAxes
350
+ #from sympy.plotting.plot_axes import PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate
351
+ #from sympy.plotting.plot_camera import PlotCamera
352
+ #from sympy.plotting.plot_controller import PlotController
353
+ #from sympy.plotting.plot_curve import PlotCurve
354
+ #from sympy.plotting.plot_interval import PlotInterval
355
+ #from sympy.plotting.plot_mode import PlotMode
356
+ #from sympy.plotting.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \
357
+ # ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical
358
+ #from sympy.plotting.plot_object import PlotObject
359
+ #from sympy.plotting.plot_surface import PlotSurface
360
+ # from sympy.plotting.plot_window import PlotWindow
361
+ check(ColorScheme("rainbow"))
362
+ check(Plot(1, visible=False))
363
+ check(PlotAxes())
364
+
365
+ #================== polys =======================
366
+ from sympy.polys.domains.integerring import ZZ
367
+ from sympy.polys.domains.rationalfield import QQ
368
+ from sympy.polys.orderings import lex
369
+ from sympy.polys.polytools import Poly
370
+
371
+ def test_pickling_polys_polytools():
372
+ from sympy.polys.polytools import PurePoly
373
+ # from sympy.polys.polytools import GroebnerBasis
374
+ x = Symbol('x')
375
+
376
+ for c in (Poly, Poly(x, x)):
377
+ check(c)
378
+
379
+ for c in (PurePoly, PurePoly(x)):
380
+ check(c)
381
+
382
+ # TODO: fix pickling of Options class (see GroebnerBasis._options)
383
+ # for c in (GroebnerBasis, GroebnerBasis([x**2 - 1], x, order=lex)):
384
+ # check(c)
385
+
386
+ def test_pickling_polys_polyclasses():
387
+ from sympy.polys.polyclasses import DMP, DMF, ANP
388
+
389
+ for c in (DMP, DMP([[ZZ(1)], [ZZ(2)], [ZZ(3)]], ZZ)):
390
+ check(c, deprecated=['rep'])
391
+ for c in (DMF, DMF(([ZZ(1), ZZ(2)], [ZZ(1), ZZ(3)]), ZZ)):
392
+ check(c)
393
+ for c in (ANP, ANP([QQ(1), QQ(2)], [QQ(1), QQ(2), QQ(3)], QQ)):
394
+ check(c)
395
+
396
+ @XFAIL
397
+ def test_pickling_polys_rings():
398
+ # NOTE: can't use protocols < 2 because we have to execute __new__ to
399
+ # make sure caching of rings works properly.
400
+
401
+ from sympy.polys.rings import PolyRing
402
+
403
+ ring = PolyRing("x,y,z", ZZ, lex)
404
+
405
+ for c in (PolyRing, ring):
406
+ check(c, exclude=[0, 1])
407
+
408
+ for c in (ring.dtype, ring.one):
409
+ check(c, exclude=[0, 1], check_attr=False) # TODO: Py3k
410
+
411
+ def test_pickling_polys_fields():
412
+ pass
413
+ # NOTE: can't use protocols < 2 because we have to execute __new__ to
414
+ # make sure caching of fields works properly.
415
+
416
+ # from sympy.polys.fields import FracField
417
+
418
+ # field = FracField("x,y,z", ZZ, lex)
419
+
420
+ # TODO: AssertionError: assert id(obj) not in self.memo
421
+ # for c in (FracField, field):
422
+ # check(c, exclude=[0, 1])
423
+
424
+ # TODO: AssertionError: assert id(obj) not in self.memo
425
+ # for c in (field.dtype, field.one):
426
+ # check(c, exclude=[0, 1])
427
+
428
+ def test_pickling_polys_elements():
429
+ from sympy.polys.domains.pythonrational import PythonRational
430
+ #from sympy.polys.domains.pythonfinitefield import PythonFiniteField
431
+ #from sympy.polys.domains.mpelements import MPContext
432
+
433
+ for c in (PythonRational, PythonRational(1, 7)):
434
+ check(c)
435
+
436
+ #gf = PythonFiniteField(17)
437
+
438
+ # TODO: fix pickling of ModularInteger
439
+ # for c in (gf.dtype, gf(5)):
440
+ # check(c)
441
+
442
+ #mp = MPContext()
443
+
444
+ # TODO: fix pickling of RealElement
445
+ # for c in (mp.mpf, mp.mpf(1.0)):
446
+ # check(c)
447
+
448
+ # TODO: fix pickling of ComplexElement
449
+ # for c in (mp.mpc, mp.mpc(1.0, -1.5)):
450
+ # check(c)
451
+
452
+ def test_pickling_polys_domains():
453
+ # from sympy.polys.domains.pythonfinitefield import PythonFiniteField
454
+ from sympy.polys.domains.pythonintegerring import PythonIntegerRing
455
+ from sympy.polys.domains.pythonrationalfield import PythonRationalField
456
+
457
+ # TODO: fix pickling of ModularInteger
458
+ # for c in (PythonFiniteField, PythonFiniteField(17)):
459
+ # check(c)
460
+
461
+ for c in (PythonIntegerRing, PythonIntegerRing()):
462
+ check(c, check_attr=False)
463
+
464
+ for c in (PythonRationalField, PythonRationalField()):
465
+ check(c, check_attr=False)
466
+
467
+ if _gmpy is not None:
468
+ # from sympy.polys.domains.gmpyfinitefield import GMPYFiniteField
469
+ from sympy.polys.domains.gmpyintegerring import GMPYIntegerRing
470
+ from sympy.polys.domains.gmpyrationalfield import GMPYRationalField
471
+
472
+ # TODO: fix pickling of ModularInteger
473
+ # for c in (GMPYFiniteField, GMPYFiniteField(17)):
474
+ # check(c)
475
+
476
+ for c in (GMPYIntegerRing, GMPYIntegerRing()):
477
+ check(c, check_attr=False)
478
+
479
+ for c in (GMPYRationalField, GMPYRationalField()):
480
+ check(c, check_attr=False)
481
+
482
+ #from sympy.polys.domains.realfield import RealField
483
+ #from sympy.polys.domains.complexfield import ComplexField
484
+ from sympy.polys.domains.algebraicfield import AlgebraicField
485
+ #from sympy.polys.domains.polynomialring import PolynomialRing
486
+ #from sympy.polys.domains.fractionfield import FractionField
487
+ from sympy.polys.domains.expressiondomain import ExpressionDomain
488
+
489
+ # TODO: fix pickling of RealElement
490
+ # for c in (RealField, RealField(100)):
491
+ # check(c)
492
+
493
+ # TODO: fix pickling of ComplexElement
494
+ # for c in (ComplexField, ComplexField(100)):
495
+ # check(c)
496
+
497
+ for c in (AlgebraicField, AlgebraicField(QQ, sqrt(3))):
498
+ check(c, check_attr=False)
499
+
500
+ # TODO: AssertionError
501
+ # for c in (PolynomialRing, PolynomialRing(ZZ, "x,y,z")):
502
+ # check(c)
503
+
504
+ # TODO: AttributeError: 'PolyElement' object has no attribute 'ring'
505
+ # for c in (FractionField, FractionField(ZZ, "x,y,z")):
506
+ # check(c)
507
+
508
+ for c in (ExpressionDomain, ExpressionDomain()):
509
+ check(c, check_attr=False)
510
+
511
+
512
+ def test_pickling_polys_orderings():
513
+ from sympy.polys.orderings import (LexOrder, GradedLexOrder,
514
+ ReversedGradedLexOrder, InverseOrder)
515
+ # from sympy.polys.orderings import ProductOrder
516
+
517
+ for c in (LexOrder, LexOrder()):
518
+ check(c)
519
+
520
+ for c in (GradedLexOrder, GradedLexOrder()):
521
+ check(c)
522
+
523
+ for c in (ReversedGradedLexOrder, ReversedGradedLexOrder()):
524
+ check(c)
525
+
526
+ # TODO: Argh, Python is so naive. No lambdas nor inner function support in
527
+ # pickling module. Maybe someone could figure out what to do with this.
528
+ #
529
+ # for c in (ProductOrder, ProductOrder((LexOrder(), lambda m: m[:2]),
530
+ # (GradedLexOrder(), lambda m: m[2:]))):
531
+ # check(c)
532
+
533
+ for c in (InverseOrder, InverseOrder(LexOrder())):
534
+ check(c)
535
+
536
+ def test_pickling_polys_monomials():
537
+ from sympy.polys.monomials import MonomialOps, Monomial
538
+ x, y, z = symbols("x,y,z")
539
+
540
+ for c in (MonomialOps, MonomialOps(3)):
541
+ check(c)
542
+
543
+ for c in (Monomial, Monomial((1, 2, 3), (x, y, z))):
544
+ check(c)
545
+
546
+ def test_pickling_polys_errors():
547
+ from sympy.polys.polyerrors import (HeuristicGCDFailed,
548
+ HomomorphismFailed, IsomorphismFailed, ExtraneousFactors,
549
+ EvaluationFailed, RefinementFailed, CoercionFailed, NotInvertible,
550
+ NotReversible, NotAlgebraic, DomainError, PolynomialError,
551
+ UnificationFailed, GeneratorsError, GeneratorsNeeded,
552
+ UnivariatePolynomialError, MultivariatePolynomialError, OptionError,
553
+ FlagError)
554
+ # from sympy.polys.polyerrors import (ExactQuotientFailed,
555
+ # OperationNotSupported, ComputationFailed, PolificationFailed)
556
+
557
+ # x = Symbol('x')
558
+
559
+ # TODO: TypeError: __init__() takes at least 3 arguments (1 given)
560
+ # for c in (ExactQuotientFailed, ExactQuotientFailed(x, 3*x, ZZ)):
561
+ # check(c)
562
+
563
+ # TODO: TypeError: can't pickle instancemethod objects
564
+ # for c in (OperationNotSupported, OperationNotSupported(Poly(x), Poly.gcd)):
565
+ # check(c)
566
+
567
+ for c in (HeuristicGCDFailed, HeuristicGCDFailed()):
568
+ check(c)
569
+
570
+ for c in (HomomorphismFailed, HomomorphismFailed()):
571
+ check(c)
572
+
573
+ for c in (IsomorphismFailed, IsomorphismFailed()):
574
+ check(c)
575
+
576
+ for c in (ExtraneousFactors, ExtraneousFactors()):
577
+ check(c)
578
+
579
+ for c in (EvaluationFailed, EvaluationFailed()):
580
+ check(c)
581
+
582
+ for c in (RefinementFailed, RefinementFailed()):
583
+ check(c)
584
+
585
+ for c in (CoercionFailed, CoercionFailed()):
586
+ check(c)
587
+
588
+ for c in (NotInvertible, NotInvertible()):
589
+ check(c)
590
+
591
+ for c in (NotReversible, NotReversible()):
592
+ check(c)
593
+
594
+ for c in (NotAlgebraic, NotAlgebraic()):
595
+ check(c)
596
+
597
+ for c in (DomainError, DomainError()):
598
+ check(c)
599
+
600
+ for c in (PolynomialError, PolynomialError()):
601
+ check(c)
602
+
603
+ for c in (UnificationFailed, UnificationFailed()):
604
+ check(c)
605
+
606
+ for c in (GeneratorsError, GeneratorsError()):
607
+ check(c)
608
+
609
+ for c in (GeneratorsNeeded, GeneratorsNeeded()):
610
+ check(c)
611
+
612
+ # TODO: PicklingError: Can't pickle <function <lambda> at 0x38578c0>: it's not found as __main__.<lambda>
613
+ # for c in (ComputationFailed, ComputationFailed(lambda t: t, 3, None)):
614
+ # check(c)
615
+
616
+ for c in (UnivariatePolynomialError, UnivariatePolynomialError()):
617
+ check(c)
618
+
619
+ for c in (MultivariatePolynomialError, MultivariatePolynomialError()):
620
+ check(c)
621
+
622
+ # TODO: TypeError: __init__() takes at least 3 arguments (1 given)
623
+ # for c in (PolificationFailed, PolificationFailed({}, x, x, False)):
624
+ # check(c)
625
+
626
+ for c in (OptionError, OptionError()):
627
+ check(c)
628
+
629
+ for c in (FlagError, FlagError()):
630
+ check(c)
631
+
632
+ #def test_pickling_polys_options():
633
+ #from sympy.polys.polyoptions import Options
634
+
635
+ # TODO: fix pickling of `symbols' flag
636
+ # for c in (Options, Options((), dict(domain='ZZ', polys=False))):
637
+ # check(c)
638
+
639
+ # TODO: def test_pickling_polys_rootisolation():
640
+ # RealInterval
641
+ # ComplexInterval
642
+
643
+ def test_pickling_polys_rootoftools():
644
+ from sympy.polys.rootoftools import CRootOf, RootSum
645
+
646
+ x = Symbol('x')
647
+ f = x**3 + x + 3
648
+
649
+ for c in (CRootOf, CRootOf(f, 0)):
650
+ check(c)
651
+
652
+ for c in (RootSum, RootSum(f, exp)):
653
+ check(c)
654
+
655
+ #================== printing ====================
656
+ from sympy.printing.latex import LatexPrinter
657
+ from sympy.printing.mathml import MathMLContentPrinter, MathMLPresentationPrinter
658
+ from sympy.printing.pretty.pretty import PrettyPrinter
659
+ from sympy.printing.pretty.stringpict import prettyForm, stringPict
660
+ from sympy.printing.printer import Printer
661
+ from sympy.printing.python import PythonPrinter
662
+
663
+
664
+ def test_printing():
665
+ for c in (LatexPrinter, LatexPrinter(), MathMLContentPrinter,
666
+ MathMLPresentationPrinter, PrettyPrinter, prettyForm, stringPict,
667
+ stringPict("a"), Printer, Printer(), PythonPrinter,
668
+ PythonPrinter()):
669
+ check(c)
670
+
671
+
672
+ @XFAIL
673
+ def test_printing1():
674
+ check(MathMLContentPrinter())
675
+
676
+
677
+ @XFAIL
678
+ def test_printing2():
679
+ check(MathMLPresentationPrinter())
680
+
681
+
682
+ @XFAIL
683
+ def test_printing3():
684
+ check(PrettyPrinter())
685
+
686
+ #================== series ======================
687
+ from sympy.series.limits import Limit
688
+ from sympy.series.order import Order
689
+
690
+
691
+ def test_series():
692
+ e = Symbol("e")
693
+ x = Symbol("x")
694
+ for c in (Limit, Limit(e, x, 1), Order, Order(e)):
695
+ check(c)
696
+
697
+ #================== concrete ==================
698
+ from sympy.concrete.products import Product
699
+ from sympy.concrete.summations import Sum
700
+
701
+
702
+ def test_concrete():
703
+ x = Symbol("x")
704
+ for c in (Product, Product(x, (x, 2, 4)), Sum, Sum(x, (x, 2, 4))):
705
+ check(c)
706
+
707
+ def test_deprecation_warning():
708
+ w = SymPyDeprecationWarning("message", deprecated_since_version='1.0', active_deprecations_target="active-deprecations")
709
+ check(w)
710
+
711
+ def test_issue_18438():
712
+ assert pickle.loads(pickle.dumps(S.Half)) == S.Half
713
+
714
+
715
+ #================= old pickles =================
716
+ def test_unpickle_from_older_versions():
717
+ data = (
718
+ b'\x80\x04\x95^\x00\x00\x00\x00\x00\x00\x00\x8c\x10sympy.core.power'
719
+ b'\x94\x8c\x03Pow\x94\x93\x94\x8c\x12sympy.core.numbers\x94\x8c'
720
+ b'\x07Integer\x94\x93\x94K\x02\x85\x94R\x94}\x94bh\x03\x8c\x04Half'
721
+ b'\x94\x93\x94)R\x94}\x94b\x86\x94R\x94}\x94b.'
722
+ )
723
+ assert pickle.loads(data) == sqrt(2)
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_source.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.utilities.source import get_mod_func, get_class
2
+
3
+
4
+ def test_get_mod_func():
5
+ assert get_mod_func(
6
+ 'sympy.core.basic.Basic') == ('sympy.core.basic', 'Basic')
7
+
8
+
9
+ def test_get_class():
10
+ _basic = get_class('sympy.core.basic.Basic')
11
+ assert _basic.__name__ == 'Basic'
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_timeutils.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for simple tools for timing functions' execution. """
2
+
3
+ from sympy.utilities.timeutils import timed
4
+
5
+ def test_timed():
6
+ result = timed(lambda: 1 + 1, limit=100000)
7
+ assert result[0] == 100000 and result[3] == "ns", str(result)
8
+
9
+ result = timed("1 + 1", limit=100000)
10
+ assert result[0] == 100000 and result[3] == "ns"
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_wester.py ADDED
@@ -0,0 +1,3104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Tests from Michael Wester's 1999 paper "Review of CAS mathematical
2
+ capabilities".
3
+
4
+ http://www.math.unm.edu/~wester/cas/book/Wester.pdf
5
+ See also http://math.unm.edu/~wester/cas_review.html for detailed output of
6
+ each tested system.
7
+ """
8
+
9
+ from sympy.assumptions.ask import Q, ask
10
+ from sympy.assumptions.refine import refine
11
+ from sympy.concrete.products import product
12
+ from sympy.core import EulerGamma
13
+ from sympy.core.evalf import N
14
+ from sympy.core.function import (Derivative, Function, Lambda, Subs,
15
+ diff, expand, expand_func)
16
+ from sympy.core.mul import Mul
17
+ from sympy.core.intfunc import igcd
18
+ from sympy.core.numbers import (AlgebraicNumber, E, I, Rational,
19
+ nan, oo, pi, zoo)
20
+ from sympy.core.relational import Eq, Lt
21
+ from sympy.core.singleton import S
22
+ from sympy.core.symbol import Dummy, Symbol, symbols
23
+ from sympy.functions.combinatorial.factorials import (rf, binomial,
24
+ factorial, factorial2)
25
+ from sympy.functions.combinatorial.numbers import bernoulli, fibonacci, totient, partition
26
+ from sympy.functions.elementary.complexes import (conjugate, im, re,
27
+ sign)
28
+ from sympy.functions.elementary.exponential import LambertW, exp, log
29
+ from sympy.functions.elementary.hyperbolic import (asinh, cosh, sinh,
30
+ tanh)
31
+ from sympy.functions.elementary.integers import ceiling, floor
32
+ from sympy.functions.elementary.miscellaneous import Max, Min, sqrt
33
+ from sympy.functions.elementary.piecewise import Piecewise
34
+ from sympy.functions.elementary.trigonometric import (acos, acot, asin,
35
+ atan, cos, cot, csc, sec, sin, tan)
36
+ from sympy.functions.special.bessel import besselj
37
+ from sympy.functions.special.delta_functions import DiracDelta
38
+ from sympy.functions.special.elliptic_integrals import (elliptic_e,
39
+ elliptic_f)
40
+ from sympy.functions.special.gamma_functions import gamma, polygamma
41
+ from sympy.functions.special.hyper import hyper
42
+ from sympy.functions.special.polynomials import (assoc_legendre,
43
+ chebyshevt)
44
+ from sympy.functions.special.zeta_functions import polylog
45
+ from sympy.geometry.util import idiff
46
+ from sympy.logic.boolalg import And
47
+ from sympy.matrices.dense import hessian, wronskian
48
+ from sympy.matrices.expressions.matmul import MatMul
49
+ from sympy.ntheory.continued_fraction import (
50
+ continued_fraction_convergents as cf_c,
51
+ continued_fraction_iterator as cf_i, continued_fraction_periodic as
52
+ cf_p, continued_fraction_reduce as cf_r)
53
+ from sympy.ntheory.factor_ import factorint
54
+ from sympy.ntheory.generate import primerange
55
+ from sympy.polys.domains.integerring import ZZ
56
+ from sympy.polys.orthopolys import legendre_poly
57
+ from sympy.polys.partfrac import apart
58
+ from sympy.polys.polytools import Poly, factor, gcd, resultant
59
+ from sympy.series.limits import limit
60
+ from sympy.series.order import O
61
+ from sympy.series.residues import residue
62
+ from sympy.series.series import series
63
+ from sympy.sets.fancysets import ImageSet
64
+ from sympy.sets.sets import FiniteSet, Intersection, Interval, Union
65
+ from sympy.simplify.combsimp import combsimp
66
+ from sympy.simplify.hyperexpand import hyperexpand
67
+ from sympy.simplify.powsimp import powdenest, powsimp
68
+ from sympy.simplify.radsimp import radsimp
69
+ from sympy.simplify.simplify import logcombine, simplify
70
+ from sympy.simplify.sqrtdenest import sqrtdenest
71
+ from sympy.simplify.trigsimp import trigsimp
72
+ from sympy.solvers.solvers import solve
73
+
74
+ import mpmath
75
+ from sympy.functions.combinatorial.numbers import stirling
76
+ from sympy.functions.special.delta_functions import Heaviside
77
+ from sympy.functions.special.error_functions import Ci, Si, erf
78
+ from sympy.functions.special.zeta_functions import zeta
79
+ from sympy.testing.pytest import (XFAIL, slow, SKIP, tooslow, raises)
80
+ from sympy.utilities.iterables import partitions
81
+ from mpmath import mpi, mpc
82
+ from sympy.matrices import Matrix, GramSchmidt, eye
83
+ from sympy.matrices.expressions.blockmatrix import BlockMatrix, block_collapse
84
+ from sympy.matrices.expressions import MatrixSymbol, ZeroMatrix
85
+ from sympy.physics.quantum import Commutator
86
+ from sympy.polys.rings import PolyRing
87
+ from sympy.polys.fields import FracField
88
+ from sympy.polys.solvers import solve_lin_sys
89
+ from sympy.concrete import Sum
90
+ from sympy.concrete.products import Product
91
+ from sympy.integrals import integrate
92
+ from sympy.integrals.transforms import laplace_transform,\
93
+ inverse_laplace_transform, LaplaceTransform, fourier_transform,\
94
+ mellin_transform, laplace_correspondence, laplace_initial_conds
95
+ from sympy.solvers.recurr import rsolve
96
+ from sympy.solvers.solveset import solveset, solveset_real, linsolve
97
+ from sympy.solvers.ode import dsolve
98
+ from sympy.core.relational import Equality
99
+ from itertools import islice, takewhile
100
+ from sympy.series.formal import fps
101
+ from sympy.series.fourier import fourier_series
102
+ from sympy.calculus.util import minimum
103
+
104
+
105
+ EmptySet = S.EmptySet
106
+ R = Rational
107
+ x, y, z = symbols('x y z')
108
+ i, j, k, l, m, n = symbols('i j k l m n', integer=True)
109
+ f = Function('f')
110
+ g = Function('g')
111
+
112
+ # A. Boolean Logic and Quantifier Elimination
113
+ # Not implemented.
114
+
115
+ # B. Set Theory
116
+
117
+
118
+ def test_B1():
119
+ assert (FiniteSet(i, j, j, k, k, k) | FiniteSet(l, k, j) |
120
+ FiniteSet(j, m, j)) == FiniteSet(i, j, k, l, m)
121
+
122
+
123
+ def test_B2():
124
+ assert (FiniteSet(i, j, j, k, k, k) & FiniteSet(l, k, j) &
125
+ FiniteSet(j, m, j)) == Intersection({j, m}, {i, j, k}, {j, k, l})
126
+ # Previous output below. Not sure why that should be the expected output.
127
+ # There should probably be a way to rewrite Intersections that way but I
128
+ # don't see why an Intersection should evaluate like that:
129
+ #
130
+ # == Union({j}, Intersection({m}, Union({j, k}, Intersection({i}, {l}))))
131
+
132
+
133
+ def test_B3():
134
+ assert (FiniteSet(i, j, k, l, m) - FiniteSet(j) ==
135
+ FiniteSet(i, k, l, m))
136
+
137
+
138
+ def test_B4():
139
+ assert (FiniteSet(*(FiniteSet(i, j)*FiniteSet(k, l))) ==
140
+ FiniteSet((i, k), (i, l), (j, k), (j, l)))
141
+
142
+
143
+ # C. Numbers
144
+
145
+
146
+ def test_C1():
147
+ assert (factorial(50) ==
148
+ 30414093201713378043612608166064768844377641568960512000000000000)
149
+
150
+
151
+ def test_C2():
152
+ assert (factorint(factorial(50)) == {2: 47, 3: 22, 5: 12, 7: 8,
153
+ 11: 4, 13: 3, 17: 2, 19: 2, 23: 2, 29: 1, 31: 1, 37: 1,
154
+ 41: 1, 43: 1, 47: 1})
155
+
156
+
157
+ def test_C3():
158
+ assert (factorial2(10), factorial2(9)) == (3840, 945)
159
+
160
+
161
+ # Base conversions; not really implemented by SymPy
162
+ # Whatever. Take credit!
163
+ def test_C4():
164
+ assert 0xABC == 2748
165
+
166
+
167
+ def test_C5():
168
+ assert 123 == int('234', 7)
169
+
170
+
171
+ def test_C6():
172
+ assert int('677', 8) == int('1BF', 16) == 447
173
+
174
+
175
+ def test_C7():
176
+ assert log(32768, 8) == 5
177
+
178
+
179
+ def test_C8():
180
+ # Modular multiplicative inverse. Would be nice if divmod could do this.
181
+ assert ZZ.invert(5, 7) == 3
182
+ assert ZZ.invert(5, 6) == 5
183
+
184
+
185
+ def test_C9():
186
+ assert igcd(igcd(1776, 1554), 5698) == 74
187
+
188
+
189
+ def test_C10():
190
+ x = 0
191
+ for n in range(2, 11):
192
+ x += R(1, n)
193
+ assert x == R(4861, 2520)
194
+
195
+
196
+ def test_C11():
197
+ assert R(1, 7) == S('0.[142857]')
198
+
199
+
200
+ def test_C12():
201
+ assert R(7, 11) * R(22, 7) == 2
202
+
203
+
204
+ def test_C13():
205
+ test = R(10, 7) * (1 + R(29, 1000)) ** R(1, 3)
206
+ good = 3 ** R(1, 3)
207
+ assert test == good
208
+
209
+
210
+ def test_C14():
211
+ assert sqrtdenest(sqrt(2*sqrt(3) + 4)) == 1 + sqrt(3)
212
+
213
+
214
+ def test_C15():
215
+ test = sqrtdenest(sqrt(14 + 3*sqrt(3 + 2*sqrt(5 - 12*sqrt(3 - 2*sqrt(2))))))
216
+ good = sqrt(2) + 3
217
+ assert test == good
218
+
219
+
220
+ def test_C16():
221
+ test = sqrtdenest(sqrt(10 + 2*sqrt(6) + 2*sqrt(10) + 2*sqrt(15)))
222
+ good = sqrt(2) + sqrt(3) + sqrt(5)
223
+ assert test == good
224
+
225
+
226
+ def test_C17():
227
+ test = radsimp((sqrt(3) + sqrt(2)) / (sqrt(3) - sqrt(2)))
228
+ good = 5 + 2*sqrt(6)
229
+ assert test == good
230
+
231
+
232
+ def test_C18():
233
+ assert simplify((sqrt(-2 + sqrt(-5)) * sqrt(-2 - sqrt(-5))).expand(complex=True)) == 3
234
+
235
+
236
+ @XFAIL
237
+ def test_C19():
238
+ assert radsimp(simplify((90 + 34*sqrt(7)) ** R(1, 3))) == 3 + sqrt(7)
239
+
240
+
241
+ def test_C20():
242
+ inside = (135 + 78*sqrt(3))
243
+ test = AlgebraicNumber((inside**R(2, 3) + 3) * sqrt(3) / inside**R(1, 3))
244
+ assert simplify(test) == AlgebraicNumber(12)
245
+
246
+
247
+ def test_C21():
248
+ assert simplify(AlgebraicNumber((41 + 29*sqrt(2)) ** R(1, 5))) == \
249
+ AlgebraicNumber(1 + sqrt(2))
250
+
251
+
252
+ @XFAIL
253
+ def test_C22():
254
+ test = simplify(((6 - 4*sqrt(2))*log(3 - 2*sqrt(2)) + (3 - 2*sqrt(2))*log(17
255
+ - 12*sqrt(2)) + 32 - 24*sqrt(2)) / (48*sqrt(2) - 72))
256
+ good = sqrt(2)/3 - log(sqrt(2) - 1)/3
257
+ assert test == good
258
+
259
+
260
+ def test_C23():
261
+ assert 2 * oo - 3 is oo
262
+
263
+
264
+ @XFAIL
265
+ def test_C24():
266
+ raise NotImplementedError("2**aleph_null == aleph_1")
267
+
268
+ # D. Numerical Analysis
269
+
270
+
271
+ def test_D1():
272
+ assert 0.0 / sqrt(2) == 0
273
+
274
+
275
+ def test_D2():
276
+ assert str(exp(-1000000).evalf()) == '3.29683147808856e-434295'
277
+
278
+
279
+ def test_D3():
280
+ assert exp(pi*sqrt(163)).evalf(50).num.ae(262537412640768744)
281
+
282
+
283
+ def test_D4():
284
+ assert floor(R(-5, 3)) == -2
285
+ assert ceiling(R(-5, 3)) == -1
286
+
287
+
288
+ @XFAIL
289
+ def test_D5():
290
+ raise NotImplementedError("cubic_spline([1, 2, 4, 5], [1, 4, 2, 3], x)(3) == 27/8")
291
+
292
+
293
+ @XFAIL
294
+ def test_D6():
295
+ raise NotImplementedError("translate sum(a[i]*x**i, (i,1,n)) to FORTRAN")
296
+
297
+
298
+ @XFAIL
299
+ def test_D7():
300
+ raise NotImplementedError("translate sum(a[i]*x**i, (i,1,n)) to C")
301
+
302
+
303
+ @XFAIL
304
+ def test_D8():
305
+ # One way is to cheat by converting the sum to a string,
306
+ # and replacing the '[' and ']' with ''.
307
+ # E.g., horner(S(str(_).replace('[','').replace(']','')))
308
+ raise NotImplementedError("apply Horner's rule to sum(a[i]*x**i, (i,1,5))")
309
+
310
+
311
+ @XFAIL
312
+ def test_D9():
313
+ raise NotImplementedError("translate D8 to FORTRAN")
314
+
315
+
316
+ @XFAIL
317
+ def test_D10():
318
+ raise NotImplementedError("translate D8 to C")
319
+
320
+
321
+ @XFAIL
322
+ def test_D11():
323
+ #Is there a way to use count_ops?
324
+ raise NotImplementedError("flops(sum(product(f[i][k], (i,1,k)), (k,1,n)))")
325
+
326
+
327
+ @XFAIL
328
+ def test_D12():
329
+ assert (mpi(-4, 2) * x + mpi(1, 3)) ** 2 == mpi(-8, 16)*x**2 + mpi(-24, 12)*x + mpi(1, 9)
330
+
331
+
332
+ @XFAIL
333
+ def test_D13():
334
+ raise NotImplementedError("discretize a PDE: diff(f(x,t),t) == diff(diff(f(x,t),x),x)")
335
+
336
+ # E. Statistics
337
+ # See scipy; all of this is numerical.
338
+
339
+ # F. Combinatorial Theory.
340
+
341
+
342
+ def test_F1():
343
+ assert rf(x, 3) == x*(1 + x)*(2 + x)
344
+
345
+
346
+ def test_F2():
347
+ assert expand_func(binomial(n, 3)) == n*(n - 1)*(n - 2)/6
348
+
349
+
350
+ @XFAIL
351
+ def test_F3():
352
+ assert combsimp(2**n * factorial(n) * factorial2(2*n - 1)) == factorial(2*n)
353
+
354
+
355
+ @XFAIL
356
+ def test_F4():
357
+ assert combsimp(2**n * factorial(n) * product(2*k - 1, (k, 1, n))) == factorial(2*n)
358
+
359
+
360
+ @XFAIL
361
+ def test_F5():
362
+ assert gamma(n + R(1, 2)) / sqrt(pi) / factorial(n) == factorial(2*n)/2**(2*n)/factorial(n)**2
363
+
364
+
365
+ def test_F6():
366
+ partTest = [p.copy() for p in partitions(4)]
367
+ partDesired = [{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2:1}, {1: 4}]
368
+ assert partTest == partDesired
369
+
370
+
371
+ def test_F7():
372
+ assert partition(4) == 5
373
+
374
+
375
+ def test_F8():
376
+ assert stirling(5, 2, signed=True) == -50 # if signed, then kind=1
377
+
378
+
379
+ def test_F9():
380
+ assert totient(1776) == 576
381
+
382
+ # G. Number Theory
383
+
384
+
385
+ def test_G1():
386
+ assert list(primerange(999983, 1000004)) == [999983, 1000003]
387
+
388
+
389
+ @XFAIL
390
+ def test_G2():
391
+ raise NotImplementedError("find the primitive root of 191 == 19")
392
+
393
+
394
+ @XFAIL
395
+ def test_G3():
396
+ raise NotImplementedError("(a+b)**p mod p == a**p + b**p mod p; p prime")
397
+
398
+ # ... G14 Modular equations are not implemented.
399
+
400
+ def test_G15():
401
+ assert Rational(sqrt(3).evalf()).limit_denominator(15) == R(26, 15)
402
+ assert list(takewhile(lambda x: x.q <= 15, cf_c(cf_i(sqrt(3)))))[-1] == \
403
+ R(26, 15)
404
+
405
+
406
+ def test_G16():
407
+ assert list(islice(cf_i(pi),10)) == [3, 7, 15, 1, 292, 1, 1, 1, 2, 1]
408
+
409
+
410
+ def test_G17():
411
+ assert cf_p(0, 1, 23) == [4, [1, 3, 1, 8]]
412
+
413
+
414
+ def test_G18():
415
+ assert cf_p(1, 2, 5) == [[1]]
416
+ assert cf_r([[1]]).expand() == S.Half + sqrt(5)/2
417
+
418
+
419
+ @XFAIL
420
+ def test_G19():
421
+ s = symbols('s', integer=True, positive=True)
422
+ it = cf_i((exp(1/s) - 1)/(exp(1/s) + 1))
423
+ assert list(islice(it, 5)) == [0, 2*s, 6*s, 10*s, 14*s]
424
+
425
+
426
+ def test_G20():
427
+ s = symbols('s', integer=True, positive=True)
428
+ # Wester erroneously has this as -s + sqrt(s**2 + 1)
429
+ assert cf_r([[2*s]]) == s + sqrt(s**2 + 1)
430
+
431
+
432
+ @XFAIL
433
+ def test_G20b():
434
+ s = symbols('s', integer=True, positive=True)
435
+ assert cf_p(s, 1, s**2 + 1) == [[2*s]]
436
+
437
+
438
+ # H. Algebra
439
+
440
+
441
+ def test_H1():
442
+ assert simplify(2*2**n) == simplify(2**(n + 1))
443
+ assert powdenest(2*2**n) == simplify(2**(n + 1))
444
+
445
+
446
+ def test_H2():
447
+ assert powsimp(4 * 2**n) == 2**(n + 2)
448
+
449
+
450
+ def test_H3():
451
+ assert (-1)**(n*(n + 1)) == 1
452
+
453
+
454
+ def test_H4():
455
+ expr = factor(6*x - 10)
456
+ assert type(expr) is Mul
457
+ assert expr.args[0] == 2
458
+ assert expr.args[1] == 3*x - 5
459
+
460
+ p1 = 64*x**34 - 21*x**47 - 126*x**8 - 46*x**5 - 16*x**60 - 81
461
+ p2 = 72*x**60 - 25*x**25 - 19*x**23 - 22*x**39 - 83*x**52 + 54*x**10 + 81
462
+ q = 34*x**19 - 25*x**16 + 70*x**7 + 20*x**3 - 91*x - 86
463
+
464
+
465
+ def test_H5():
466
+ assert gcd(p1, p2, x) == 1
467
+
468
+
469
+ def test_H6():
470
+ assert gcd(expand(p1 * q), expand(p2 * q)) == q
471
+
472
+
473
+ def test_H7():
474
+ p1 = 24*x*y**19*z**8 - 47*x**17*y**5*z**8 + 6*x**15*y**9*z**2 - 3*x**22 + 5
475
+ p2 = 34*x**5*y**8*z**13 + 20*x**7*y**7*z**7 + 12*x**9*y**16*z**4 + 80*y**14*z
476
+ assert gcd(p1, p2, x, y, z) == 1
477
+
478
+
479
+ def test_H8():
480
+ p1 = 24*x*y**19*z**8 - 47*x**17*y**5*z**8 + 6*x**15*y**9*z**2 - 3*x**22 + 5
481
+ p2 = 34*x**5*y**8*z**13 + 20*x**7*y**7*z**7 + 12*x**9*y**16*z**4 + 80*y**14*z
482
+ q = 11*x**12*y**7*z**13 - 23*x**2*y**8*z**10 + 47*x**17*y**5*z**8
483
+ assert gcd(p1 * q, p2 * q, x, y, z) == q
484
+
485
+
486
+ def test_H9():
487
+ x = Symbol('x', zero=False)
488
+ p1 = 2*x**(n + 4) - x**(n + 2)
489
+ p2 = 4*x**(n + 1) + 3*x**n
490
+ assert gcd(p1, p2) == x**n
491
+
492
+
493
+ def test_H10():
494
+ p1 = 3*x**4 + 3*x**3 + x**2 - x - 2
495
+ p2 = x**3 - 3*x**2 + x + 5
496
+ assert resultant(p1, p2, x) == 0
497
+
498
+
499
+ def test_H11():
500
+ assert resultant(p1 * q, p2 * q, x) == 0
501
+
502
+
503
+ def test_H12():
504
+ num = x**2 - 4
505
+ den = x**2 + 4*x + 4
506
+ assert simplify(num/den) == (x - 2)/(x + 2)
507
+
508
+
509
+ @XFAIL
510
+ def test_H13():
511
+ assert simplify((exp(x) - 1) / (exp(x/2) + 1)) == exp(x/2) - 1
512
+
513
+
514
+ def test_H14():
515
+ p = (x + 1) ** 20
516
+ ep = expand(p)
517
+ assert ep == (1 + 20*x + 190*x**2 + 1140*x**3 + 4845*x**4 + 15504*x**5
518
+ + 38760*x**6 + 77520*x**7 + 125970*x**8 + 167960*x**9 + 184756*x**10
519
+ + 167960*x**11 + 125970*x**12 + 77520*x**13 + 38760*x**14 + 15504*x**15
520
+ + 4845*x**16 + 1140*x**17 + 190*x**18 + 20*x**19 + x**20)
521
+ dep = diff(ep, x)
522
+ assert dep == (20 + 380*x + 3420*x**2 + 19380*x**3 + 77520*x**4
523
+ + 232560*x**5 + 542640*x**6 + 1007760*x**7 + 1511640*x**8 + 1847560*x**9
524
+ + 1847560*x**10 + 1511640*x**11 + 1007760*x**12 + 542640*x**13
525
+ + 232560*x**14 + 77520*x**15 + 19380*x**16 + 3420*x**17 + 380*x**18
526
+ + 20*x**19)
527
+ assert factor(dep) == 20*(1 + x)**19
528
+
529
+
530
+ def test_H15():
531
+ assert simplify(Mul(*[x - r for r in solveset(x**3 + x**2 - 7)])) == x**3 + x**2 - 7
532
+
533
+
534
+ def test_H16():
535
+ assert factor(x**100 - 1) == ((x - 1)*(x + 1)*(x**2 + 1)*(x**4 - x**3
536
+ + x**2 - x + 1)*(x**4 + x**3 + x**2 + x + 1)*(x**8 - x**6 + x**4
537
+ - x**2 + 1)*(x**20 - x**15 + x**10 - x**5 + 1)*(x**20 + x**15 + x**10
538
+ + x**5 + 1)*(x**40 - x**30 + x**20 - x**10 + 1))
539
+
540
+
541
+ def test_H17():
542
+ assert simplify(factor(expand(p1 * p2)) - p1*p2) == 0
543
+
544
+
545
+ @XFAIL
546
+ def test_H18():
547
+ # Factor over complex rationals.
548
+ test = factor(4*x**4 + 8*x**3 + 77*x**2 + 18*x + 153)
549
+ good = (2*x + 3*I)*(2*x - 3*I)*(x + 1 - 4*I)*(x + 1 + 4*I)
550
+ assert test == good
551
+
552
+
553
+ def test_H19():
554
+ a = symbols('a')
555
+ # The idea is to let a**2 == 2, then solve 1/(a-1). Answer is a+1")
556
+ assert Poly(a - 1).invert(Poly(a**2 - 2)) == a + 1
557
+
558
+
559
+ @XFAIL
560
+ def test_H20():
561
+ raise NotImplementedError("let a**2==2; (x**3 + (a-2)*x**2 - "
562
+ + "(2*a+3)*x - 3*a) / (x**2-2) = (x**2 - 2*x - 3) / (x-a)")
563
+
564
+
565
+ @XFAIL
566
+ def test_H21():
567
+ raise NotImplementedError("evaluate (b+c)**4 assuming b**3==2, c**2==3. \
568
+ Answer is 2*b + 8*c + 18*b**2 + 12*b*c + 9")
569
+
570
+
571
+ def test_H22():
572
+ assert factor(x**4 - 3*x**2 + 1, modulus=5) == (x - 2)**2 * (x + 2)**2
573
+
574
+
575
+ def test_H23():
576
+ f = x**11 + x + 1
577
+ g = (x**2 + x + 1) * (x**9 - x**8 + x**6 - x**5 + x**3 - x**2 + 1)
578
+ assert factor(f, modulus=65537) == g
579
+
580
+
581
+ def test_H24():
582
+ phi = AlgebraicNumber(S.GoldenRatio.expand(func=True), alias='phi')
583
+ assert factor(x**4 - 3*x**2 + 1, extension=phi) == \
584
+ (x - phi)*(x + 1 - phi)*(x - 1 + phi)*(x + phi)
585
+
586
+
587
+ def test_H25():
588
+ e = (x - 2*y**2 + 3*z**3) ** 20
589
+ assert factor(expand(e)) == e
590
+
591
+
592
+ def test_H26():
593
+ g = expand((sin(x) - 2*cos(y)**2 + 3*tan(z)**3)**20)
594
+ assert factor(g, expand=False) == (-sin(x) + 2*cos(y)**2 - 3*tan(z)**3)**20
595
+
596
+
597
+ def test_H27():
598
+ f = 24*x*y**19*z**8 - 47*x**17*y**5*z**8 + 6*x**15*y**9*z**2 - 3*x**22 + 5
599
+ g = 34*x**5*y**8*z**13 + 20*x**7*y**7*z**7 + 12*x**9*y**16*z**4 + 80*y**14*z
600
+ h = -2*z*y**7 \
601
+ *(6*x**9*y**9*z**3 + 10*x**7*z**6 + 17*y*x**5*z**12 + 40*y**7) \
602
+ *(3*x**22 + 47*x**17*y**5*z**8 - 6*x**15*y**9*z**2 - 24*x*y**19*z**8 - 5)
603
+ assert factor(expand(f*g)) == h
604
+
605
+
606
+ @XFAIL
607
+ def test_H28():
608
+ raise NotImplementedError("expand ((1 - c**2)**5 * (1 - s**2)**5 * "
609
+ + "(c**2 + s**2)**10) with c**2 + s**2 = 1. Answer is c**10*s**10.")
610
+
611
+
612
+ @XFAIL
613
+ def test_H29():
614
+ assert factor(4*x**2 - 21*x*y + 20*y**2, modulus=3) == (x + y)*(x - y)
615
+
616
+
617
+ def test_H30():
618
+ test = factor(x**3 + y**3, extension=sqrt(-3))
619
+ answer = (x + y)*(x + y*(-R(1, 2) - sqrt(3)/2*I))*(x + y*(-R(1, 2) + sqrt(3)/2*I))
620
+ assert answer == test
621
+
622
+
623
+ def test_H31():
624
+ f = (x**2 + 2*x + 3)/(x**3 + 4*x**2 + 5*x + 2)
625
+ g = 2 / (x + 1)**2 - 2 / (x + 1) + 3 / (x + 2)
626
+ assert apart(f) == g
627
+
628
+
629
+ @XFAIL
630
+ def test_H32(): # issue 6558
631
+ raise NotImplementedError("[A*B*C - (A*B*C)**(-1)]*A*C*B (product \
632
+ of a non-commuting product and its inverse)")
633
+
634
+
635
+ def test_H33():
636
+ A, B, C = symbols('A, B, C', commutative=False)
637
+ assert (Commutator(A, Commutator(B, C))
638
+ + Commutator(B, Commutator(C, A))
639
+ + Commutator(C, Commutator(A, B))).doit().expand() == 0
640
+
641
+
642
+ # I. Trigonometry
643
+
644
+ def test_I1():
645
+ assert tan(pi*R(7, 10)) == -sqrt(1 + 2/sqrt(5))
646
+
647
+
648
+ @XFAIL
649
+ def test_I2():
650
+ assert sqrt((1 + cos(6))/2) == -cos(3)
651
+
652
+
653
+ def test_I3():
654
+ assert cos(n*pi) + sin((4*n - 1)*pi/2) == (-1)**n - 1
655
+
656
+
657
+ def test_I4():
658
+ assert refine(cos(pi*cos(n*pi)) + sin(pi/2*cos(n*pi)), Q.integer(n)) == (-1)**n - 1
659
+
660
+
661
+ @XFAIL
662
+ def test_I5():
663
+ assert sin((n**5/5 + n**4/2 + n**3/3 - n/30) * pi) == 0
664
+
665
+
666
+ @XFAIL
667
+ def test_I6():
668
+ raise NotImplementedError("assuming -3*pi<x<-5*pi/2, abs(cos(x)) == -cos(x), abs(sin(x)) == -sin(x)")
669
+
670
+
671
+ @XFAIL
672
+ def test_I7():
673
+ assert cos(3*x)/cos(x) == cos(x)**2 - 3*sin(x)**2
674
+
675
+
676
+ @XFAIL
677
+ def test_I8():
678
+ assert cos(3*x)/cos(x) == 2*cos(2*x) - 1
679
+
680
+
681
+ @XFAIL
682
+ def test_I9():
683
+ # Supposed to do this with rewrite rules.
684
+ assert cos(3*x)/cos(x) == cos(x)**2 - 3*sin(x)**2
685
+
686
+
687
+ def test_I10():
688
+ assert trigsimp((tan(x)**2 + 1 - cos(x)**-2) / (sin(x)**2 + cos(x)**2 - 1)) is nan
689
+
690
+
691
+ @SKIP("hangs")
692
+ @XFAIL
693
+ def test_I11():
694
+ assert limit((tan(x)**2 + 1 - cos(x)**-2) / (sin(x)**2 + cos(x)**2 - 1), x, 0) != 0
695
+
696
+
697
+ @XFAIL
698
+ def test_I12():
699
+ # This should fail or return nan or something.
700
+ res = diff((tan(x)**2 + 1 - cos(x)**-2) / (sin(x)**2 + cos(x)**2 - 1), x)
701
+ assert res is nan # trigsimp(res) gives nan
702
+
703
+ # J. Special functions.
704
+
705
+
706
+ def test_J1():
707
+ assert bernoulli(16) == R(-3617, 510)
708
+
709
+
710
+ def test_J2():
711
+ assert diff(elliptic_e(x, y**2), y) == (elliptic_e(x, y**2) - elliptic_f(x, y**2))/y
712
+
713
+
714
+ @XFAIL
715
+ def test_J3():
716
+ raise NotImplementedError("Jacobi elliptic functions: diff(dn(u,k), u) == -k**2*sn(u,k)*cn(u,k)")
717
+
718
+
719
+ def test_J4():
720
+ assert gamma(R(-1, 2)) == -2*sqrt(pi)
721
+
722
+
723
+ def test_J5():
724
+ assert polygamma(0, R(1, 3)) == -log(3) - sqrt(3)*pi/6 - EulerGamma - log(sqrt(3))
725
+
726
+
727
+ def test_J6():
728
+ assert mpmath.besselj(2, 1 + 1j).ae(mpc('0.04157988694396212', '0.24739764151330632'))
729
+
730
+
731
+ def test_J7():
732
+ assert simplify(besselj(R(-5,2), pi/2)) == 12/(pi**2)
733
+
734
+
735
+ def test_J8():
736
+ p = besselj(R(3,2), z)
737
+ q = (sin(z)/z - cos(z))/sqrt(pi*z/2)
738
+ assert simplify(expand_func(p) -q) == 0
739
+
740
+
741
+ def test_J9():
742
+ assert besselj(0, z).diff(z) == - besselj(1, z)
743
+
744
+
745
+ def test_J10():
746
+ mu, nu = symbols('mu, nu', integer=True)
747
+ assert assoc_legendre(nu, mu, 0) == 2**mu*sqrt(pi)/gamma((nu - mu)/2 + 1)/gamma((-nu - mu + 1)/2)
748
+
749
+
750
+ def test_J11():
751
+ assert simplify(assoc_legendre(3, 1, x)) == simplify(-R(3, 2)*sqrt(1 - x**2)*(5*x**2 - 1))
752
+
753
+
754
+ @slow
755
+ def test_J12():
756
+ assert simplify(chebyshevt(1008, x) - 2*x*chebyshevt(1007, x) + chebyshevt(1006, x)) == 0
757
+
758
+
759
+ def test_J13():
760
+ a = symbols('a', integer=True, negative=False)
761
+ assert chebyshevt(a, -1) == (-1)**a
762
+
763
+
764
+ def test_J14():
765
+ p = hyper([S.Half, S.Half], [R(3, 2)], z**2)
766
+ assert hyperexpand(p) == asin(z)/z
767
+
768
+
769
+ @XFAIL
770
+ def test_J15():
771
+ raise NotImplementedError("F((n+2)/2,-(n-2)/2,R(3,2),sin(z)**2) == sin(n*z)/(n*sin(z)*cos(z)); F(.) is hypergeometric function")
772
+
773
+
774
+ @XFAIL
775
+ def test_J16():
776
+ raise NotImplementedError("diff(zeta(x), x) @ x=0 == -log(2*pi)/2")
777
+
778
+
779
+ def test_J17():
780
+ assert integrate(f((x + 2)/5)*DiracDelta((x - 2)/3) - g(x)*diff(DiracDelta(x - 1), x), (x, 0, 3)) == 3*f(R(4, 5)) + Subs(Derivative(g(x), x), x, 1)
781
+
782
+
783
+ @XFAIL
784
+ def test_J18():
785
+ raise NotImplementedError("define an antisymmetric function")
786
+
787
+
788
+ # K. The Complex Domain
789
+
790
+ def test_K1():
791
+ z1, z2 = symbols('z1, z2', complex=True)
792
+ assert re(z1 + I*z2) == -im(z2) + re(z1)
793
+ assert im(z1 + I*z2) == im(z1) + re(z2)
794
+
795
+
796
+ def test_K2():
797
+ assert abs(3 - sqrt(7) + I*sqrt(6*sqrt(7) - 15)) == 1
798
+
799
+
800
+ @XFAIL
801
+ def test_K3():
802
+ a, b = symbols('a, b', real=True)
803
+ assert simplify(abs(1/(a + I/a + I*b))) == 1/sqrt(a**2 + (I/a + b)**2)
804
+
805
+
806
+ def test_K4():
807
+ assert log(3 + 4*I).expand(complex=True) == log(5) + I*atan(R(4, 3))
808
+
809
+
810
+ def test_K5():
811
+ x, y = symbols('x, y', real=True)
812
+ assert tan(x + I*y).expand(complex=True) == (sin(2*x)/(cos(2*x) +
813
+ cosh(2*y)) + I*sinh(2*y)/(cos(2*x) + cosh(2*y)))
814
+
815
+
816
+ def test_K6():
817
+ assert sqrt(x*y*abs(z)**2)/(sqrt(x)*abs(z)) == sqrt(x*y)/sqrt(x)
818
+ assert sqrt(x*y*abs(z)**2)/(sqrt(x)*abs(z)) != sqrt(y)
819
+
820
+
821
+ def test_K7():
822
+ y = symbols('y', real=True, negative=False)
823
+ expr = sqrt(x*y*abs(z)**2)/(sqrt(x)*abs(z))
824
+ sexpr = simplify(expr)
825
+ assert sexpr == sqrt(y)
826
+
827
+
828
+ def test_K8():
829
+ z = symbols('z', complex=True)
830
+ assert simplify(sqrt(1/z) - 1/sqrt(z)) != 0 # Passes
831
+ z = symbols('z', complex=True, negative=False)
832
+ assert simplify(sqrt(1/z) - 1/sqrt(z)) == 0 # Fails
833
+
834
+
835
+ def test_K9():
836
+ z = symbols('z', positive=True)
837
+ assert simplify(sqrt(1/z) - 1/sqrt(z)) == 0
838
+
839
+
840
+ def test_K10():
841
+ z = symbols('z', negative=True)
842
+ assert simplify(sqrt(1/z) + 1/sqrt(z)) == 0
843
+
844
+ # This goes up to K25
845
+
846
+ # L. Determining Zero Equivalence
847
+
848
+
849
+ def test_L1():
850
+ assert sqrt(997) - (997**3)**R(1, 6) == 0
851
+
852
+
853
+ def test_L2():
854
+ assert sqrt(999983) - (999983**3)**R(1, 6) == 0
855
+
856
+
857
+ def test_L3():
858
+ assert simplify((2**R(1, 3) + 4**R(1, 3))**3 - 6*(2**R(1, 3) + 4**R(1, 3)) - 6) == 0
859
+
860
+
861
+ def test_L4():
862
+ assert trigsimp(cos(x)**3 + cos(x)*sin(x)**2 - cos(x)) == 0
863
+
864
+
865
+ @XFAIL
866
+ def test_L5():
867
+ assert log(tan(R(1, 2)*x + pi/4)) - asinh(tan(x)) == 0
868
+
869
+
870
+ def test_L6():
871
+ assert (log(tan(x/2 + pi/4)) - asinh(tan(x))).diff(x).subs({x: 0}) == 0
872
+
873
+
874
+ @XFAIL
875
+ def test_L7():
876
+ assert simplify(log((2*sqrt(x) + 1)/(sqrt(4*x + 4*sqrt(x) + 1)))) == 0
877
+
878
+
879
+ @XFAIL
880
+ def test_L8():
881
+ assert simplify((4*x + 4*sqrt(x) + 1)**(sqrt(x)/(2*sqrt(x) + 1)) \
882
+ *(2*sqrt(x) + 1)**(1/(2*sqrt(x) + 1)) - 2*sqrt(x) - 1) == 0
883
+
884
+
885
+ @XFAIL
886
+ def test_L9():
887
+ z = symbols('z', complex=True)
888
+ assert simplify(2**(1 - z)*gamma(z)*zeta(z)*cos(z*pi/2) - pi**2*zeta(1 - z)) == 0
889
+
890
+ # M. Equations
891
+
892
+
893
+ @XFAIL
894
+ def test_M1():
895
+ assert Equality(x, 2)/2 + Equality(1, 1) == Equality(x/2 + 1, 2)
896
+
897
+
898
+ def test_M2():
899
+ # The roots of this equation should all be real. Note that this
900
+ # doesn't test that they are correct.
901
+ sol = solveset(3*x**3 - 18*x**2 + 33*x - 19, x)
902
+ assert all(s.expand(complex=True).is_real for s in sol)
903
+
904
+
905
+ @XFAIL
906
+ def test_M5():
907
+ assert solveset(x**6 - 9*x**4 - 4*x**3 + 27*x**2 - 36*x - 23, x) == FiniteSet(2**(1/3) + sqrt(3), 2**(1/3) - sqrt(3), +sqrt(3) - 1/2**(2/3) + I*sqrt(3)/2**(2/3), +sqrt(3) - 1/2**(2/3) - I*sqrt(3)/2**(2/3), -sqrt(3) - 1/2**(2/3) + I*sqrt(3)/2**(2/3), -sqrt(3) - 1/2**(2/3) - I*sqrt(3)/2**(2/3))
908
+
909
+
910
+ def test_M6():
911
+ assert set(solveset(x**7 - 1, x)) == \
912
+ {cos(n*pi*R(2, 7)) + I*sin(n*pi*R(2, 7)) for n in range(0, 7)}
913
+ # The paper asks for exp terms, but sin's and cos's may be acceptable;
914
+ # if the results are simplified, exp terms appear for all but
915
+ # -sin(pi/14) - I*cos(pi/14) and -sin(pi/14) + I*cos(pi/14) which
916
+ # will simplify if you apply the transformation foo.rewrite(exp).expand()
917
+
918
+
919
+ def test_M7():
920
+ # TODO: Replace solve with solveset, as of now test fails for solveset
921
+ assert set(solve(x**8 - 8*x**7 + 34*x**6 - 92*x**5 + 175*x**4 - 236*x**3 +
922
+ 226*x**2 - 140*x + 46, x)) == {
923
+ 1 - sqrt(2)*I*sqrt(-sqrt(-3 + 4*sqrt(3)) + 3)/2,
924
+ 1 - sqrt(2)*sqrt(-3 + I*sqrt(3 + 4*sqrt(3)))/2,
925
+ 1 - sqrt(2)*I*sqrt(sqrt(-3 + 4*sqrt(3)) + 3)/2,
926
+ 1 - sqrt(2)*sqrt(-3 - I*sqrt(3 + 4*sqrt(3)))/2,
927
+ 1 + sqrt(2)*I*sqrt(sqrt(-3 + 4*sqrt(3)) + 3)/2,
928
+ 1 + sqrt(2)*sqrt(-3 - I*sqrt(3 + 4*sqrt(3)))/2,
929
+ 1 + sqrt(2)*sqrt(-3 + I*sqrt(3 + 4*sqrt(3)))/2,
930
+ 1 + sqrt(2)*I*sqrt(-sqrt(-3 + 4*sqrt(3)) + 3)/2,
931
+ }
932
+
933
+
934
+ @XFAIL # There are an infinite number of solutions.
935
+ def test_M8():
936
+ x = Symbol('x')
937
+ z = symbols('z', complex=True)
938
+ assert solveset(exp(2*x) + 2*exp(x) + 1 - z, x, S.Reals) == \
939
+ FiniteSet(log(1 + z - 2*sqrt(z))/2, log(1 + z + 2*sqrt(z))/2)
940
+ # This one could be simplified better (the 1/2 could be pulled into the log
941
+ # as a sqrt, and the function inside the log can be factored as a square,
942
+ # giving [log(sqrt(z) - 1), log(sqrt(z) + 1)]). Also, there should be an
943
+ # infinite number of solutions.
944
+ # x = {log(sqrt(z) - 1), log(sqrt(z) + 1) + i pi} [+ n 2 pi i, + n 2 pi i]
945
+ # where n is an arbitrary integer. See url of detailed output above.
946
+
947
+
948
+ @XFAIL
949
+ def test_M9():
950
+ # x = symbols('x')
951
+ # solutions are 1/2*(1 +/- sqrt(9 + 8*I*pi*n)) for integer n
952
+ raise NotImplementedError("solveset(exp(2-x**2)-exp(-x),x) has complex solutions.")
953
+
954
+
955
+ def test_M10():
956
+ # TODO: Replace solve with solveset when it gives Lambert solution
957
+ assert solve(exp(x) - x, x) == [-LambertW(-1)]
958
+
959
+
960
+ @XFAIL
961
+ def test_M11():
962
+ assert solveset(x**x - x, x) == FiniteSet(-1, 1)
963
+
964
+
965
+ def test_M12():
966
+ # TODO: x = [-1, 2*(+/-asinh(1)*I + n*pi}, 3*(pi/6 + n*pi/3)]
967
+ # TODO: Replace solve with solveset, as of now test fails for solveset
968
+ assert solve((x + 1)*(sin(x)**2 + 1)**2*cos(3*x)**3, x) == [
969
+ -1, pi/6, pi/2,
970
+ - I*log(1 + sqrt(2)), I*log(1 + sqrt(2)),
971
+ pi - I*log(1 + sqrt(2)), pi + I*log(1 + sqrt(2)),
972
+ ]
973
+
974
+
975
+ @XFAIL
976
+ def test_M13():
977
+ n = Dummy('n')
978
+ assert solveset_real(sin(x) - cos(x), x) == ImageSet(Lambda(n, n*pi - pi*R(7, 4)), S.Integers)
979
+
980
+
981
+ @XFAIL
982
+ def test_M14():
983
+ n = Dummy('n')
984
+ assert solveset_real(tan(x) - 1, x) == ImageSet(Lambda(n, n*pi + pi/4), S.Integers)
985
+
986
+
987
+ def test_M15():
988
+ n = Dummy('n')
989
+ got = solveset(sin(x) - S.Half)
990
+ assert any(got.dummy_eq(i) for i in (
991
+ Union(ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers),
992
+ ImageSet(Lambda(n, 2*n*pi + pi*R(5, 6)), S.Integers)),
993
+ Union(ImageSet(Lambda(n, 2*n*pi + pi*R(5, 6)), S.Integers),
994
+ ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers))))
995
+
996
+
997
+ @XFAIL
998
+ def test_M16():
999
+ n = Dummy('n')
1000
+ assert solveset(sin(x) - tan(x), x) == ImageSet(Lambda(n, n*pi), S.Integers)
1001
+
1002
+
1003
+ @XFAIL
1004
+ def test_M17():
1005
+ assert solveset_real(asin(x) - atan(x), x) == FiniteSet(0)
1006
+
1007
+
1008
+ @XFAIL
1009
+ def test_M18():
1010
+ assert solveset_real(acos(x) - atan(x), x) == FiniteSet(sqrt((sqrt(5) - 1)/2))
1011
+
1012
+
1013
+ def test_M19():
1014
+ # TODO: Replace solve with solveset, as of now test fails for solveset
1015
+ assert solve((x - 2)/x**R(1, 3), x) == [2]
1016
+
1017
+
1018
+ def test_M20():
1019
+ assert solveset(sqrt(x**2 + 1) - x + 2, x) == EmptySet
1020
+
1021
+
1022
+ def test_M21():
1023
+ assert solveset(x + sqrt(x) - 2) == FiniteSet(1)
1024
+
1025
+
1026
+ def test_M22():
1027
+ assert solveset(2*sqrt(x) + 3*x**R(1, 4) - 2) == FiniteSet(R(1, 16))
1028
+
1029
+
1030
+ def test_M23():
1031
+ x = symbols('x', complex=True)
1032
+ # TODO: Replace solve with solveset, as of now test fails for solveset
1033
+ assert solve(x - 1/sqrt(1 + x**2)) == [
1034
+ -I*sqrt(S.Half + sqrt(5)/2), sqrt(Rational(-1, 2) + sqrt(5)/2)]
1035
+
1036
+
1037
+ def test_M24():
1038
+ # TODO: Replace solve with solveset, as of now test fails for solveset
1039
+ solution = solve(1 - binomial(m, 2)*2**k, k)
1040
+ answer = log(2/(m*(m - 1)), 2)
1041
+ assert solution[0].expand() == answer.expand()
1042
+
1043
+
1044
+ def test_M25():
1045
+ a, b, c, d = symbols(':d', positive=True)
1046
+ x = symbols('x')
1047
+ # TODO: Replace solve with solveset, as of now test fails for solveset
1048
+ assert solve(a*b**x - c*d**x, x)[0].expand() == (log(c/a)/log(b/d)).expand()
1049
+
1050
+
1051
+ def test_M26():
1052
+ # TODO: Replace solve with solveset, as of now test fails for solveset
1053
+ assert solve(sqrt(log(x)) - log(sqrt(x))) == [1, exp(4)]
1054
+
1055
+
1056
+ def test_M27():
1057
+ x = symbols('x', real=True)
1058
+ b = symbols('b', real=True)
1059
+ # TODO: Replace solve with solveset which gives both [+/- current answer]
1060
+ # note that there is a typo in this test in the wester.pdf; there is no
1061
+ # real solution for the equation as it appears in wester.pdf
1062
+ assert solve(log(acos(asin(x**R(2, 3) - b)) - 1) + 2, x
1063
+ ) == [(b + sin(cos(exp(-2) + 1)))**R(3, 2)]
1064
+
1065
+
1066
+ @XFAIL
1067
+ def test_M28():
1068
+ assert solveset_real(5*x + exp((x - 5)/2) - 8*x**3, x, assume=Q.real(x)) == [-0.784966, -0.016291, 0.802557]
1069
+
1070
+
1071
+ def test_M29():
1072
+ x = symbols('x')
1073
+ assert solveset(abs(x - 1) - 2, domain=S.Reals) == FiniteSet(-1, 3)
1074
+
1075
+
1076
+ def test_M30():
1077
+ # TODO: Replace solve with solveset, as of now
1078
+ # solveset doesn't supports assumptions
1079
+ # assert solve(abs(2*x + 5) - abs(x - 2),x, assume=Q.real(x)) == [-1, -7]
1080
+ assert solveset_real(abs(2*x + 5) - abs(x - 2), x) == FiniteSet(-1, -7)
1081
+
1082
+
1083
+ def test_M31():
1084
+ # TODO: Replace solve with solveset, as of now
1085
+ # solveset doesn't supports assumptions
1086
+ # assert solve(1 - abs(x) - max(-x - 2, x - 2),x, assume=Q.real(x)) == [-3/2, 3/2]
1087
+ assert solveset_real(1 - abs(x) - Max(-x - 2, x - 2), x) == FiniteSet(R(-3, 2), R(3, 2))
1088
+
1089
+
1090
+ @XFAIL
1091
+ def test_M32():
1092
+ # TODO: Replace solve with solveset, as of now
1093
+ # solveset doesn't supports assumptions
1094
+ assert solveset_real(Max(2 - x**2, x)- Max(-x, (x**3)/9), x) == FiniteSet(-1, 3)
1095
+
1096
+
1097
+ @XFAIL
1098
+ def test_M33():
1099
+ # TODO: Replace solve with solveset, as of now
1100
+ # solveset doesn't supports assumptions
1101
+
1102
+ # Second answer can be written in another form. The second answer is the root of x**3 + 9*x**2 - 18 = 0 in the interval (-2, -1).
1103
+ assert solveset_real(Max(2 - x**2, x) - x**3/9, x) == FiniteSet(-3, -1.554894, 3)
1104
+
1105
+
1106
+ @XFAIL
1107
+ def test_M34():
1108
+ z = symbols('z', complex=True)
1109
+ assert solveset((1 + I) * z + (2 - I) * conjugate(z) + 3*I, z) == FiniteSet(2 + 3*I)
1110
+
1111
+
1112
+ def test_M35():
1113
+ x, y = symbols('x y', real=True)
1114
+ assert linsolve((3*x - 2*y - I*y + 3*I).as_real_imag(), y, x) == FiniteSet((3, 2))
1115
+
1116
+
1117
+ def test_M36():
1118
+ # TODO: Replace solve with solveset, as of now
1119
+ # solveset doesn't supports solving for function
1120
+ # assert solve(f**2 + f - 2, x) == [Eq(f(x), 1), Eq(f(x), -2)]
1121
+ assert solveset(f(x)**2 + f(x) - 2, f(x)) == FiniteSet(-2, 1)
1122
+
1123
+
1124
+ def test_M37():
1125
+ assert linsolve([x + y + z - 6, 2*x + y + 2*z - 10, x + 3*y + z - 10 ], x, y, z) == \
1126
+ FiniteSet((-z + 4, 2, z))
1127
+
1128
+
1129
+ def test_M38():
1130
+ a, b, c = symbols('a, b, c')
1131
+ domain = FracField([a, b, c], ZZ).to_domain()
1132
+ ring = PolyRing('k1:50', domain)
1133
+ (k1, k2, k3, k4, k5, k6, k7, k8, k9, k10,
1134
+ k11, k12, k13, k14, k15, k16, k17, k18, k19, k20,
1135
+ k21, k22, k23, k24, k25, k26, k27, k28, k29, k30,
1136
+ k31, k32, k33, k34, k35, k36, k37, k38, k39, k40,
1137
+ k41, k42, k43, k44, k45, k46, k47, k48, k49) = ring.gens
1138
+
1139
+ system = [
1140
+ -b*k8/a + c*k8/a, -b*k11/a + c*k11/a, -b*k10/a + c*k10/a + k2, -k3 - b*k9/a + c*k9/a,
1141
+ -b*k14/a + c*k14/a, -b*k15/a + c*k15/a, -b*k18/a + c*k18/a - k2, -b*k17/a + c*k17/a,
1142
+ -b*k16/a + c*k16/a + k4, -b*k13/a + c*k13/a - b*k21/a + c*k21/a + b*k5/a - c*k5/a,
1143
+ b*k44/a - c*k44/a, -b*k45/a + c*k45/a, -b*k20/a + c*k20/a, -b*k44/a + c*k44/a,
1144
+ b*k46/a - c*k46/a, b**2*k47/a**2 - 2*b*c*k47/a**2 + c**2*k47/a**2, k3, -k4,
1145
+ -b*k12/a + c*k12/a - a*k6/b + c*k6/b, -b*k19/a + c*k19/a + a*k7/c - b*k7/c,
1146
+ b*k45/a - c*k45/a, -b*k46/a + c*k46/a, -k48 + c*k48/a + c*k48/b - c**2*k48/(a*b),
1147
+ -k49 + b*k49/a + b*k49/c - b**2*k49/(a*c), a*k1/b - c*k1/b, a*k4/b - c*k4/b,
1148
+ a*k3/b - c*k3/b + k9, -k10 + a*k2/b - c*k2/b, a*k7/b - c*k7/b, -k9, k11,
1149
+ b*k12/a - c*k12/a + a*k6/b - c*k6/b, a*k15/b - c*k15/b, k10 + a*k18/b - c*k18/b,
1150
+ -k11 + a*k17/b - c*k17/b, a*k16/b - c*k16/b, -a*k13/b + c*k13/b + a*k21/b - c*k21/b + a*k5/b - c*k5/b,
1151
+ -a*k44/b + c*k44/b, a*k45/b - c*k45/b, a*k14/c - b*k14/c + a*k20/b - c*k20/b,
1152
+ a*k44/b - c*k44/b, -a*k46/b + c*k46/b, -k47 + c*k47/a + c*k47/b - c**2*k47/(a*b),
1153
+ a*k19/b - c*k19/b, -a*k45/b + c*k45/b, a*k46/b - c*k46/b, a**2*k48/b**2 - 2*a*c*k48/b**2 + c**2*k48/b**2,
1154
+ -k49 + a*k49/b + a*k49/c - a**2*k49/(b*c), k16, -k17, -a*k1/c + b*k1/c,
1155
+ -k16 - a*k4/c + b*k4/c, -a*k3/c + b*k3/c, k18 - a*k2/c + b*k2/c, b*k19/a - c*k19/a - a*k7/c + b*k7/c,
1156
+ -a*k6/c + b*k6/c, -a*k8/c + b*k8/c, -a*k11/c + b*k11/c + k17, -a*k10/c + b*k10/c - k18,
1157
+ -a*k9/c + b*k9/c, -a*k14/c + b*k14/c - a*k20/b + c*k20/b, -a*k13/c + b*k13/c + a*k21/c - b*k21/c - a*k5/c + b*k5/c,
1158
+ a*k44/c - b*k44/c, -a*k45/c + b*k45/c, -a*k44/c + b*k44/c, a*k46/c - b*k46/c,
1159
+ -k47 + b*k47/a + b*k47/c - b**2*k47/(a*c), -a*k12/c + b*k12/c, a*k45/c - b*k45/c,
1160
+ -a*k46/c + b*k46/c, -k48 + a*k48/b + a*k48/c - a**2*k48/(b*c),
1161
+ a**2*k49/c**2 - 2*a*b*k49/c**2 + b**2*k49/c**2, k8, k11, -k15, k10 - k18,
1162
+ -k17, k9, -k16, -k29, k14 - k32, -k21 + k23 - k31, -k24 - k30, -k35, k44,
1163
+ -k45, k36, k13 - k23 + k39, -k20 + k38, k25 + k37, b*k26/a - c*k26/a - k34 + k42,
1164
+ -2*k44, k45, k46, b*k47/a - c*k47/a, k41, k44, -k46, -b*k47/a + c*k47/a,
1165
+ k12 + k24, -k19 - k25, -a*k27/b + c*k27/b - k33, k45, -k46, -a*k48/b + c*k48/b,
1166
+ a*k28/c - b*k28/c + k40, -k45, k46, a*k48/b - c*k48/b, a*k49/c - b*k49/c,
1167
+ -a*k49/c + b*k49/c, -k1, -k4, -k3, k15, k18 - k2, k17, k16, k22, k25 - k7,
1168
+ k24 + k30, k21 + k23 - k31, k28, -k44, k45, -k30 - k6, k20 + k32, k27 + b*k33/a - c*k33/a,
1169
+ k44, -k46, -b*k47/a + c*k47/a, -k36, k31 - k39 - k5, -k32 - k38, k19 - k37,
1170
+ k26 - a*k34/b + c*k34/b - k42, k44, -2*k45, k46, a*k48/b - c*k48/b,
1171
+ a*k35/c - b*k35/c - k41, -k44, k46, b*k47/a - c*k47/a, -a*k49/c + b*k49/c,
1172
+ -k40, k45, -k46, -a*k48/b + c*k48/b, a*k49/c - b*k49/c, k1, k4, k3, -k8,
1173
+ -k11, -k10 + k2, -k9, k37 + k7, -k14 - k38, -k22, -k25 - k37, -k24 + k6,
1174
+ -k13 - k23 + k39, -k28 + b*k40/a - c*k40/a, k44, -k45, -k27, -k44, k46,
1175
+ b*k47/a - c*k47/a, k29, k32 + k38, k31 - k39 + k5, -k12 + k30, k35 - a*k41/b + c*k41/b,
1176
+ -k44, k45, -k26 + k34 + a*k42/c - b*k42/c, k44, k45, -2*k46, -b*k47/a + c*k47/a,
1177
+ -a*k48/b + c*k48/b, a*k49/c - b*k49/c, k33, -k45, k46, a*k48/b - c*k48/b,
1178
+ -a*k49/c + b*k49/c
1179
+ ]
1180
+ solution = {
1181
+ k49: 0, k48: 0, k47: 0, k46: 0, k45: 0, k44: 0, k41: 0, k40: 0,
1182
+ k38: 0, k37: 0, k36: 0, k35: 0, k33: 0, k32: 0, k30: 0, k29: 0,
1183
+ k28: 0, k27: 0, k25: 0, k24: 0, k22: 0, k21: 0, k20: 0, k19: 0,
1184
+ k18: 0, k17: 0, k16: 0, k15: 0, k14: 0, k13: 0, k12: 0, k11: 0,
1185
+ k10: 0, k9: 0, k8: 0, k7: 0, k6: 0, k5: 0, k4: 0, k3: 0,
1186
+ k2: 0, k1: 0,
1187
+ k34: b/c*k42, k31: k39, k26: a/c*k42, k23: k39
1188
+ }
1189
+ assert solve_lin_sys(system, ring) == solution
1190
+
1191
+
1192
+ def test_M39():
1193
+ x, y, z = symbols('x y z', complex=True)
1194
+ # TODO: Replace solve with solveset, as of now
1195
+ # solveset doesn't supports non-linear multivariate
1196
+ assert solve([x**2*y + 3*y*z - 4, -3*x**2*z + 2*y**2 + 1, 2*y*z**2 - z**2 - 1 ]) ==\
1197
+ [{y: 1, z: 1, x: -1}, {y: 1, z: 1, x: 1},\
1198
+ {y: sqrt(2)*I, z: R(1,3) - sqrt(2)*I/3, x: -sqrt(-1 - sqrt(2)*I)},\
1199
+ {y: sqrt(2)*I, z: R(1,3) - sqrt(2)*I/3, x: sqrt(-1 - sqrt(2)*I)},\
1200
+ {y: -sqrt(2)*I, z: R(1,3) + sqrt(2)*I/3, x: -sqrt(-1 + sqrt(2)*I)},\
1201
+ {y: -sqrt(2)*I, z: R(1,3) + sqrt(2)*I/3, x: sqrt(-1 + sqrt(2)*I)}]
1202
+
1203
+ # N. Inequalities
1204
+
1205
+
1206
+ def test_N1():
1207
+ assert ask(E**pi > pi**E)
1208
+
1209
+
1210
+ @XFAIL
1211
+ def test_N2():
1212
+ x = symbols('x', real=True)
1213
+ assert ask(x**4 - x + 1 > 0) is True
1214
+ assert ask(x**4 - x + 1 > 1) is False
1215
+
1216
+
1217
+ @XFAIL
1218
+ def test_N3():
1219
+ x = symbols('x', real=True)
1220
+ assert ask(And(Lt(-1, x), Lt(x, 1)), abs(x) < 1 )
1221
+
1222
+ @XFAIL
1223
+ def test_N4():
1224
+ x, y = symbols('x y', real=True)
1225
+ assert ask(2*x**2 > 2*y**2, (x > y) & (y > 0)) is True
1226
+
1227
+
1228
+ @XFAIL
1229
+ def test_N5():
1230
+ x, y, k = symbols('x y k', real=True)
1231
+ assert ask(k*x**2 > k*y**2, (x > y) & (y > 0) & (k > 0)) is True
1232
+
1233
+
1234
+ @slow
1235
+ @XFAIL
1236
+ def test_N6():
1237
+ x, y, k, n = symbols('x y k n', real=True)
1238
+ assert ask(k*x**n > k*y**n, (x > y) & (y > 0) & (k > 0) & (n > 0)) is True
1239
+
1240
+
1241
+ @XFAIL
1242
+ def test_N7():
1243
+ x, y = symbols('x y', real=True)
1244
+ assert ask(y > 0, (x > 1) & (y >= x - 1)) is True
1245
+
1246
+
1247
+ @XFAIL
1248
+ @slow
1249
+ def test_N8():
1250
+ x, y, z = symbols('x y z', real=True)
1251
+ assert ask(Eq(x, y) & Eq(y, z),
1252
+ (x >= y) & (y >= z) & (z >= x))
1253
+
1254
+
1255
+ def test_N9():
1256
+ x = Symbol('x')
1257
+ assert solveset(abs(x - 1) > 2, domain=S.Reals) == Union(Interval(-oo, -1, False, True),
1258
+ Interval(3, oo, True))
1259
+
1260
+
1261
+ def test_N10():
1262
+ x = Symbol('x')
1263
+ p = (x - 1)*(x - 2)*(x - 3)*(x - 4)*(x - 5)
1264
+ assert solveset(expand(p) < 0, domain=S.Reals) == Union(Interval(-oo, 1, True, True),
1265
+ Interval(2, 3, True, True),
1266
+ Interval(4, 5, True, True))
1267
+
1268
+
1269
+ def test_N11():
1270
+ x = Symbol('x')
1271
+ assert solveset(6/(x - 3) <= 3, domain=S.Reals) == Union(Interval(-oo, 3, True, True), Interval(5, oo))
1272
+
1273
+
1274
+ def test_N12():
1275
+ x = Symbol('x')
1276
+ assert solveset(sqrt(x) < 2, domain=S.Reals) == Interval(0, 4, False, True)
1277
+
1278
+
1279
+ def test_N13():
1280
+ x = Symbol('x')
1281
+ assert solveset(sin(x) < 2, domain=S.Reals) == S.Reals
1282
+
1283
+
1284
+ @XFAIL
1285
+ def test_N14():
1286
+ x = Symbol('x')
1287
+ # Gives 'Union(Interval(Integer(0), Mul(Rational(1, 2), pi), false, true),
1288
+ # Interval(Mul(Rational(1, 2), pi), Mul(Integer(2), pi), true, false))'
1289
+ # which is not the correct answer, but the provided also seems wrong.
1290
+ assert solveset(sin(x) < 1, x, domain=S.Reals) == Union(Interval(-oo, pi/2, True, True),
1291
+ Interval(pi/2, oo, True, True))
1292
+
1293
+
1294
+ def test_N15():
1295
+ r, t = symbols('r t')
1296
+ # raises NotImplementedError: only univariate inequalities are supported
1297
+ solveset(abs(2*r*(cos(t) - 1) + 1) <= 1, r, S.Reals)
1298
+
1299
+
1300
+ def test_N16():
1301
+ r, t = symbols('r t')
1302
+ solveset((r**2)*((cos(t) - 4)**2)*sin(t)**2 < 9, r, S.Reals)
1303
+
1304
+
1305
+ @XFAIL
1306
+ def test_N17():
1307
+ # currently only univariate inequalities are supported
1308
+ assert solveset((x + y > 0, x - y < 0), (x, y)) == (abs(x) < y)
1309
+
1310
+
1311
+ def test_O1():
1312
+ M = Matrix((1 + I, -2, 3*I))
1313
+ assert sqrt(expand(M.dot(M.H))) == sqrt(15)
1314
+
1315
+
1316
+ def test_O2():
1317
+ assert Matrix((2, 2, -3)).cross(Matrix((1, 3, 1))) == Matrix([[11],
1318
+ [-5],
1319
+ [4]])
1320
+
1321
+ # The vector module has no way of representing vectors symbolically (without
1322
+ # respect to a basis)
1323
+ @XFAIL
1324
+ def test_O3():
1325
+ # assert (va ^ vb) | (vc ^ vd) == -(va | vc)*(vb | vd) + (va | vd)*(vb | vc)
1326
+ raise NotImplementedError("""The vector module has no way of representing
1327
+ vectors symbolically (without respect to a basis)""")
1328
+
1329
+ def test_O4():
1330
+ from sympy.vector import CoordSys3D, Del
1331
+ N = CoordSys3D("N")
1332
+ delop = Del()
1333
+ i, j, k = N.base_vectors()
1334
+ x, y, z = N.base_scalars()
1335
+ F = i*(x*y*z) + j*((x*y*z)**2) + k*((y**2)*(z**3))
1336
+ assert delop.cross(F).doit() == (-2*x**2*y**2*z + 2*y*z**3)*i + x*y*j + (2*x*y**2*z**2 - x*z)*k
1337
+
1338
+ @XFAIL
1339
+ def test_O5():
1340
+ #assert grad|(f^g)-g|(grad^f)+f|(grad^g) == 0
1341
+ raise NotImplementedError("""The vector module has no way of representing
1342
+ vectors symbolically (without respect to a basis)""")
1343
+
1344
+ #testO8-O9 MISSING!!
1345
+
1346
+
1347
+ def test_O10():
1348
+ L = [Matrix([2, 3, 5]), Matrix([3, 6, 2]), Matrix([8, 3, 6])]
1349
+ assert GramSchmidt(L) == [Matrix([
1350
+ [2],
1351
+ [3],
1352
+ [5]]),
1353
+ Matrix([
1354
+ [R(23, 19)],
1355
+ [R(63, 19)],
1356
+ [R(-47, 19)]]),
1357
+ Matrix([
1358
+ [R(1692, 353)],
1359
+ [R(-1551, 706)],
1360
+ [R(-423, 706)]])]
1361
+
1362
+
1363
+ def test_P1():
1364
+ assert Matrix(3, 3, lambda i, j: j - i).diagonal(-1) == Matrix(
1365
+ 1, 2, [-1, -1])
1366
+
1367
+
1368
+ def test_P2():
1369
+ M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
1370
+ M.row_del(1)
1371
+ M.col_del(2)
1372
+ assert M == Matrix([[1, 2],
1373
+ [7, 8]])
1374
+
1375
+
1376
+ def test_P3():
1377
+ A = Matrix([
1378
+ [11, 12, 13, 14],
1379
+ [21, 22, 23, 24],
1380
+ [31, 32, 33, 34],
1381
+ [41, 42, 43, 44]])
1382
+
1383
+ A11 = A[0:3, 1:4]
1384
+ A12 = A[(0, 1, 3), (2, 0, 3)]
1385
+ A21 = A
1386
+ A221 = -A[0:2, 2:4]
1387
+ A222 = -A[(3, 0), (2, 1)]
1388
+ A22 = BlockMatrix([[A221, A222]]).T
1389
+ rows = [[-A11, A12], [A21, A22]]
1390
+ raises(ValueError, lambda: BlockMatrix(rows))
1391
+ B = Matrix(rows)
1392
+ assert B == Matrix([
1393
+ [-12, -13, -14, 13, 11, 14],
1394
+ [-22, -23, -24, 23, 21, 24],
1395
+ [-32, -33, -34, 43, 41, 44],
1396
+ [11, 12, 13, 14, -13, -23],
1397
+ [21, 22, 23, 24, -14, -24],
1398
+ [31, 32, 33, 34, -43, -13],
1399
+ [41, 42, 43, 44, -42, -12]])
1400
+
1401
+
1402
+ @XFAIL
1403
+ def test_P4():
1404
+ raise NotImplementedError("Block matrix diagonalization not supported")
1405
+
1406
+
1407
+ def test_P5():
1408
+ M = Matrix([[7, 11],
1409
+ [3, 8]])
1410
+ assert M % 2 == Matrix([[1, 1],
1411
+ [1, 0]])
1412
+
1413
+
1414
+ def test_P6():
1415
+ M = Matrix([[cos(x), sin(x)],
1416
+ [-sin(x), cos(x)]])
1417
+ assert M.diff(x, 2) == Matrix([[-cos(x), -sin(x)],
1418
+ [sin(x), -cos(x)]])
1419
+
1420
+
1421
+ def test_P7():
1422
+ M = Matrix([[x, y]])*(
1423
+ z*Matrix([[1, 3, 5],
1424
+ [2, 4, 6]]) + Matrix([[7, -9, 11],
1425
+ [-8, 10, -12]]))
1426
+ assert M == Matrix([[x*(z + 7) + y*(2*z - 8), x*(3*z - 9) + y*(4*z + 10),
1427
+ x*(5*z + 11) + y*(6*z - 12)]])
1428
+
1429
+
1430
+ def test_P8():
1431
+ M = Matrix([[1, -2*I],
1432
+ [-3*I, 4]])
1433
+ assert M.norm(ord=S.Infinity) == 7
1434
+
1435
+
1436
+ def test_P9():
1437
+ a, b, c = symbols('a b c', nonzero=True)
1438
+ M = Matrix([[a/(b*c), 1/c, 1/b],
1439
+ [1/c, b/(a*c), 1/a],
1440
+ [1/b, 1/a, c/(a*b)]])
1441
+ assert factor(M.norm('fro')) == (a**2 + b**2 + c**2)/(abs(a)*abs(b)*abs(c))
1442
+
1443
+
1444
+ @XFAIL
1445
+ def test_P10():
1446
+ M = Matrix([[1, 2 + 3*I],
1447
+ [f(4 - 5*I), 6]])
1448
+ # conjugate(f(4 - 5*i)) is not simplified to f(4+5*I)
1449
+ assert M.H == Matrix([[1, f(4 + 5*I)],
1450
+ [2 + 3*I, 6]])
1451
+
1452
+
1453
+ @XFAIL
1454
+ def test_P11():
1455
+ # raises NotImplementedError("Matrix([[x,y],[1,x*y]]).inv()
1456
+ # not simplifying to extract common factor")
1457
+ assert Matrix([[x, y],
1458
+ [1, x*y]]).inv() == (1/(x**2 - 1))*Matrix([[x, -1],
1459
+ [-1/y, x/y]])
1460
+
1461
+
1462
+ def test_P11_workaround():
1463
+ # This test was changed to inverse method ADJ because it depended on the
1464
+ # specific form of inverse returned from the 'GE' method which has changed.
1465
+ M = Matrix([[x, y], [1, x*y]]).inv('ADJ')
1466
+ c = gcd(tuple(M))
1467
+ assert MatMul(c, M/c, evaluate=False) == MatMul(c, Matrix([
1468
+ [x*y, -y],
1469
+ [ -1, x]]), evaluate=False)
1470
+
1471
+
1472
+ def test_P12():
1473
+ A11 = MatrixSymbol('A11', n, n)
1474
+ A12 = MatrixSymbol('A12', n, n)
1475
+ A22 = MatrixSymbol('A22', n, n)
1476
+ B = BlockMatrix([[A11, A12],
1477
+ [ZeroMatrix(n, n), A22]])
1478
+ assert block_collapse(B.I) == BlockMatrix([[A11.I, (-1)*A11.I*A12*A22.I],
1479
+ [ZeroMatrix(n, n), A22.I]])
1480
+
1481
+
1482
+ def test_P13():
1483
+ M = Matrix([[1, x - 2, x - 3],
1484
+ [x - 1, x**2 - 3*x + 6, x**2 - 3*x - 2],
1485
+ [x - 2, x**2 - 8, 2*(x**2) - 12*x + 14]])
1486
+ L, U, _ = M.LUdecomposition()
1487
+ assert simplify(L) == Matrix([[1, 0, 0],
1488
+ [x - 1, 1, 0],
1489
+ [x - 2, x - 3, 1]])
1490
+ assert simplify(U) == Matrix([[1, x - 2, x - 3],
1491
+ [0, 4, x - 5],
1492
+ [0, 0, x - 7]])
1493
+
1494
+
1495
+ def test_P14():
1496
+ M = Matrix([[1, 2, 3, 1, 3],
1497
+ [3, 2, 1, 1, 7],
1498
+ [0, 2, 4, 1, 1],
1499
+ [1, 1, 1, 1, 4]])
1500
+ R, _ = M.rref()
1501
+ assert R == Matrix([[1, 0, -1, 0, 2],
1502
+ [0, 1, 2, 0, -1],
1503
+ [0, 0, 0, 1, 3],
1504
+ [0, 0, 0, 0, 0]])
1505
+
1506
+
1507
+ def test_P15():
1508
+ M = Matrix([[-1, 3, 7, -5],
1509
+ [4, -2, 1, 3],
1510
+ [2, 4, 15, -7]])
1511
+ assert M.rank() == 2
1512
+
1513
+
1514
+ def test_P16():
1515
+ M = Matrix([[2*sqrt(2), 8],
1516
+ [6*sqrt(6), 24*sqrt(3)]])
1517
+ assert M.rank() == 1
1518
+
1519
+
1520
+ def test_P17():
1521
+ t = symbols('t', real=True)
1522
+ M=Matrix([
1523
+ [sin(2*t), cos(2*t)],
1524
+ [2*(1 - (cos(t)**2))*cos(t), (1 - 2*(sin(t)**2))*sin(t)]])
1525
+ assert M.rank() == 1
1526
+
1527
+
1528
+ def test_P18():
1529
+ M = Matrix([[1, 0, -2, 0],
1530
+ [-2, 1, 0, 3],
1531
+ [-1, 2, -6, 6]])
1532
+ assert M.nullspace() == [Matrix([[2],
1533
+ [4],
1534
+ [1],
1535
+ [0]]),
1536
+ Matrix([[0],
1537
+ [-3],
1538
+ [0],
1539
+ [1]])]
1540
+
1541
+
1542
+ def test_P19():
1543
+ w = symbols('w')
1544
+ M = Matrix([[1, 1, 1, 1],
1545
+ [w, x, y, z],
1546
+ [w**2, x**2, y**2, z**2],
1547
+ [w**3, x**3, y**3, z**3]])
1548
+ assert M.det() == (w**3*x**2*y - w**3*x**2*z - w**3*x*y**2 + w**3*x*z**2
1549
+ + w**3*y**2*z - w**3*y*z**2 - w**2*x**3*y + w**2*x**3*z
1550
+ + w**2*x*y**3 - w**2*x*z**3 - w**2*y**3*z + w**2*y*z**3
1551
+ + w*x**3*y**2 - w*x**3*z**2 - w*x**2*y**3 + w*x**2*z**3
1552
+ + w*y**3*z**2 - w*y**2*z**3 - x**3*y**2*z + x**3*y*z**2
1553
+ + x**2*y**3*z - x**2*y*z**3 - x*y**3*z**2 + x*y**2*z**3
1554
+ )
1555
+
1556
+
1557
+ @XFAIL
1558
+ def test_P20():
1559
+ raise NotImplementedError("Matrix minimal polynomial not supported")
1560
+
1561
+
1562
+ def test_P21():
1563
+ M = Matrix([[5, -3, -7],
1564
+ [-2, 1, 2],
1565
+ [2, -3, -4]])
1566
+ assert M.charpoly(x).as_expr() == x**3 - 2*x**2 - 5*x + 6
1567
+
1568
+
1569
+ def test_P22():
1570
+ d = 100
1571
+ M = (2 - x)*eye(d)
1572
+ assert M.eigenvals() == {-x + 2: d}
1573
+
1574
+
1575
+ def test_P23():
1576
+ M = Matrix([
1577
+ [2, 1, 0, 0, 0],
1578
+ [1, 2, 1, 0, 0],
1579
+ [0, 1, 2, 1, 0],
1580
+ [0, 0, 1, 2, 1],
1581
+ [0, 0, 0, 1, 2]])
1582
+ assert M.eigenvals() == {
1583
+ S('1'): 1,
1584
+ S('2'): 1,
1585
+ S('3'): 1,
1586
+ S('sqrt(3) + 2'): 1,
1587
+ S('-sqrt(3) + 2'): 1}
1588
+
1589
+
1590
+ def test_P24():
1591
+ M = Matrix([[611, 196, -192, 407, -8, -52, -49, 29],
1592
+ [196, 899, 113, -192, -71, -43, -8, -44],
1593
+ [-192, 113, 899, 196, 61, 49, 8, 52],
1594
+ [ 407, -192, 196, 611, 8, 44, 59, -23],
1595
+ [ -8, -71, 61, 8, 411, -599, 208, 208],
1596
+ [ -52, -43, 49, 44, -599, 411, 208, 208],
1597
+ [ -49, -8, 8, 59, 208, 208, 99, -911],
1598
+ [ 29, -44, 52, -23, 208, 208, -911, 99]])
1599
+ assert M.eigenvals() == {
1600
+ S('0'): 1,
1601
+ S('10*sqrt(10405)'): 1,
1602
+ S('100*sqrt(26) + 510'): 1,
1603
+ S('1000'): 2,
1604
+ S('-100*sqrt(26) + 510'): 1,
1605
+ S('-10*sqrt(10405)'): 1,
1606
+ S('1020'): 1}
1607
+
1608
+
1609
+ def test_P25():
1610
+ MF = N(Matrix([[ 611, 196, -192, 407, -8, -52, -49, 29],
1611
+ [ 196, 899, 113, -192, -71, -43, -8, -44],
1612
+ [-192, 113, 899, 196, 61, 49, 8, 52],
1613
+ [ 407, -192, 196, 611, 8, 44, 59, -23],
1614
+ [ -8, -71, 61, 8, 411, -599, 208, 208],
1615
+ [ -52, -43, 49, 44, -599, 411, 208, 208],
1616
+ [ -49, -8, 8, 59, 208, 208, 99, -911],
1617
+ [ 29, -44, 52, -23, 208, 208, -911, 99]]))
1618
+
1619
+ ev_1 = sorted(MF.eigenvals(multiple=True))
1620
+ ev_2 = sorted(
1621
+ [-1020.0490184299969, 0.0, 0.09804864072151699, 1000.0, 1000.0,
1622
+ 1019.9019513592784, 1020.0, 1020.0490184299969])
1623
+
1624
+ for x, y in zip(ev_1, ev_2):
1625
+ assert abs(x - y) < 1e-12
1626
+
1627
+
1628
+ def test_P26():
1629
+ a0, a1, a2, a3, a4 = symbols('a0 a1 a2 a3 a4')
1630
+ M = Matrix([[-a4, -a3, -a2, -a1, -a0, 0, 0, 0, 0],
1631
+ [ 1, 0, 0, 0, 0, 0, 0, 0, 0],
1632
+ [ 0, 1, 0, 0, 0, 0, 0, 0, 0],
1633
+ [ 0, 0, 1, 0, 0, 0, 0, 0, 0],
1634
+ [ 0, 0, 0, 1, 0, 0, 0, 0, 0],
1635
+ [ 0, 0, 0, 0, 0, -1, -1, 0, 0],
1636
+ [ 0, 0, 0, 0, 0, 1, 0, 0, 0],
1637
+ [ 0, 0, 0, 0, 0, 0, 1, -1, -1],
1638
+ [ 0, 0, 0, 0, 0, 0, 0, 1, 0]])
1639
+ assert M.eigenvals(error_when_incomplete=False) == {
1640
+ S('-1/2 - sqrt(3)*I/2'): 2,
1641
+ S('-1/2 + sqrt(3)*I/2'): 2}
1642
+
1643
+
1644
+ def test_P27():
1645
+ a = symbols('a')
1646
+ M = Matrix([[a, 0, 0, 0, 0],
1647
+ [0, 0, 0, 0, 1],
1648
+ [0, 0, a, 0, 0],
1649
+ [0, 0, 0, a, 0],
1650
+ [0, -2, 0, 0, 2]])
1651
+
1652
+ assert M.eigenvects() == [
1653
+ (a, 3, [
1654
+ Matrix([1, 0, 0, 0, 0]),
1655
+ Matrix([0, 0, 1, 0, 0]),
1656
+ Matrix([0, 0, 0, 1, 0])
1657
+ ]),
1658
+ (1 - I, 1, [
1659
+ Matrix([0, (1 + I)/2, 0, 0, 1])
1660
+ ]),
1661
+ (1 + I, 1, [
1662
+ Matrix([0, (1 - I)/2, 0, 0, 1])
1663
+ ]),
1664
+ ]
1665
+
1666
+
1667
+ @XFAIL
1668
+ def test_P28():
1669
+ raise NotImplementedError("Generalized eigenvectors not supported \
1670
+ https://github.com/sympy/sympy/issues/5293")
1671
+
1672
+
1673
+ @XFAIL
1674
+ def test_P29():
1675
+ raise NotImplementedError("Generalized eigenvectors not supported \
1676
+ https://github.com/sympy/sympy/issues/5293")
1677
+
1678
+
1679
+ def test_P30():
1680
+ M = Matrix([[1, 0, 0, 1, -1],
1681
+ [0, 1, -2, 3, -3],
1682
+ [0, 0, -1, 2, -2],
1683
+ [1, -1, 1, 0, 1],
1684
+ [1, -1, 1, -1, 2]])
1685
+ _, J = M.jordan_form()
1686
+ assert J == Matrix([[-1, 0, 0, 0, 0],
1687
+ [0, 1, 1, 0, 0],
1688
+ [0, 0, 1, 0, 0],
1689
+ [0, 0, 0, 1, 1],
1690
+ [0, 0, 0, 0, 1]])
1691
+
1692
+
1693
+ @XFAIL
1694
+ def test_P31():
1695
+ raise NotImplementedError("Smith normal form not implemented")
1696
+
1697
+
1698
+ def test_P32():
1699
+ M = Matrix([[1, -2],
1700
+ [2, 1]])
1701
+ assert exp(M).rewrite(cos).simplify() == Matrix([[E*cos(2), -E*sin(2)],
1702
+ [E*sin(2), E*cos(2)]])
1703
+
1704
+
1705
+ def test_P33():
1706
+ w, t = symbols('w t')
1707
+ M = Matrix([[0, 1, 0, 0],
1708
+ [0, 0, 0, 2*w],
1709
+ [0, 0, 0, 1],
1710
+ [0, -2*w, 3*w**2, 0]])
1711
+ assert exp(M*t).rewrite(cos).expand() == Matrix([
1712
+ [1, -3*t + 4*sin(t*w)/w, 6*t*w - 6*sin(t*w), -2*cos(t*w)/w + 2/w],
1713
+ [0, 4*cos(t*w) - 3, -6*w*cos(t*w) + 6*w, 2*sin(t*w)],
1714
+ [0, 2*cos(t*w)/w - 2/w, -3*cos(t*w) + 4, sin(t*w)/w],
1715
+ [0, -2*sin(t*w), 3*w*sin(t*w), cos(t*w)]])
1716
+
1717
+
1718
+ @XFAIL
1719
+ def test_P34():
1720
+ a, b, c = symbols('a b c', real=True)
1721
+ M = Matrix([[a, 1, 0, 0, 0, 0],
1722
+ [0, a, 0, 0, 0, 0],
1723
+ [0, 0, b, 0, 0, 0],
1724
+ [0, 0, 0, c, 1, 0],
1725
+ [0, 0, 0, 0, c, 1],
1726
+ [0, 0, 0, 0, 0, c]])
1727
+ # raises exception, sin(M) not supported. exp(M*I) also not supported
1728
+ # https://github.com/sympy/sympy/issues/6218
1729
+ assert sin(M) == Matrix([[sin(a), cos(a), 0, 0, 0, 0],
1730
+ [0, sin(a), 0, 0, 0, 0],
1731
+ [0, 0, sin(b), 0, 0, 0],
1732
+ [0, 0, 0, sin(c), cos(c), -sin(c)/2],
1733
+ [0, 0, 0, 0, sin(c), cos(c)],
1734
+ [0, 0, 0, 0, 0, sin(c)]])
1735
+
1736
+
1737
+ @XFAIL
1738
+ def test_P35():
1739
+ M = pi/2*Matrix([[2, 1, 1],
1740
+ [2, 3, 2],
1741
+ [1, 1, 2]])
1742
+ # raises exception, sin(M) not supported. exp(M*I) also not supported
1743
+ # https://github.com/sympy/sympy/issues/6218
1744
+ assert sin(M) == eye(3)
1745
+
1746
+
1747
+ @XFAIL
1748
+ def test_P36():
1749
+ M = Matrix([[10, 7],
1750
+ [7, 17]])
1751
+ assert sqrt(M) == Matrix([[3, 1],
1752
+ [1, 4]])
1753
+
1754
+
1755
+ def test_P37():
1756
+ M = Matrix([[1, 1, 0],
1757
+ [0, 1, 0],
1758
+ [0, 0, 1]])
1759
+ assert M**S.Half == Matrix([[1, R(1, 2), 0],
1760
+ [0, 1, 0],
1761
+ [0, 0, 1]])
1762
+
1763
+
1764
+ @XFAIL
1765
+ def test_P38():
1766
+ M=Matrix([[0, 1, 0],
1767
+ [0, 0, 0],
1768
+ [0, 0, 0]])
1769
+
1770
+ with raises(AssertionError):
1771
+ # raises ValueError: Matrix det == 0; not invertible
1772
+ M**S.Half
1773
+ # if it doesn't raise then this assertion will be
1774
+ # raised and the test will be flagged as not XFAILing
1775
+ assert None
1776
+
1777
+ @XFAIL
1778
+ def test_P39():
1779
+ """
1780
+ M=Matrix([
1781
+ [1, 1],
1782
+ [2, 2],
1783
+ [3, 3]])
1784
+ M.SVD()
1785
+ """
1786
+ raise NotImplementedError("Singular value decomposition not implemented")
1787
+
1788
+
1789
+ def test_P40():
1790
+ r, t = symbols('r t', real=True)
1791
+ M = Matrix([r*cos(t), r*sin(t)])
1792
+ assert M.jacobian(Matrix([r, t])) == Matrix([[cos(t), -r*sin(t)],
1793
+ [sin(t), r*cos(t)]])
1794
+
1795
+
1796
+ def test_P41():
1797
+ r, t = symbols('r t', real=True)
1798
+ assert hessian(r**2*sin(t),(r,t)) == Matrix([[ 2*sin(t), 2*r*cos(t)],
1799
+ [2*r*cos(t), -r**2*sin(t)]])
1800
+
1801
+
1802
+ def test_P42():
1803
+ assert wronskian([cos(x), sin(x)], x).simplify() == 1
1804
+
1805
+
1806
+ def test_P43():
1807
+ def __my_jacobian(M, Y):
1808
+ return Matrix([M.diff(v).T for v in Y]).T
1809
+ r, t = symbols('r t', real=True)
1810
+ M = Matrix([r*cos(t), r*sin(t)])
1811
+ assert __my_jacobian(M,[r,t]) == Matrix([[cos(t), -r*sin(t)],
1812
+ [sin(t), r*cos(t)]])
1813
+
1814
+
1815
+ def test_P44():
1816
+ def __my_hessian(f, Y):
1817
+ V = Matrix([diff(f, v) for v in Y])
1818
+ return Matrix([V.T.diff(v) for v in Y])
1819
+ r, t = symbols('r t', real=True)
1820
+ assert __my_hessian(r**2*sin(t), (r, t)) == Matrix([
1821
+ [ 2*sin(t), 2*r*cos(t)],
1822
+ [2*r*cos(t), -r**2*sin(t)]])
1823
+
1824
+
1825
+ def test_P45():
1826
+ def __my_wronskian(Y, v):
1827
+ M = Matrix([Matrix(Y).T.diff(x, n) for n in range(0, len(Y))])
1828
+ return M.det()
1829
+ assert __my_wronskian([cos(x), sin(x)], x).simplify() == 1
1830
+
1831
+ # Q1-Q6 Tensor tests missing
1832
+
1833
+
1834
+ @XFAIL
1835
+ def test_R1():
1836
+ i, j, n = symbols('i j n', integer=True, positive=True)
1837
+ xn = MatrixSymbol('xn', n, 1)
1838
+ Sm = Sum((xn[i, 0] - Sum(xn[j, 0], (j, 0, n - 1))/n)**2, (i, 0, n - 1))
1839
+ # sum does not calculate
1840
+ # Unknown result
1841
+ Sm.doit()
1842
+ raise NotImplementedError('Unknown result')
1843
+
1844
+ @XFAIL
1845
+ def test_R2():
1846
+ m, b = symbols('m b')
1847
+ i, n = symbols('i n', integer=True, positive=True)
1848
+ xn = MatrixSymbol('xn', n, 1)
1849
+ yn = MatrixSymbol('yn', n, 1)
1850
+ f = Sum((yn[i, 0] - m*xn[i, 0] - b)**2, (i, 0, n - 1))
1851
+ f1 = diff(f, m)
1852
+ f2 = diff(f, b)
1853
+ # raises TypeError: solveset() takes at most 2 arguments (3 given)
1854
+ solveset((f1, f2), (m, b), domain=S.Reals)
1855
+
1856
+
1857
+ @XFAIL
1858
+ def test_R3():
1859
+ n, k = symbols('n k', integer=True, positive=True)
1860
+ sk = ((-1)**k) * (binomial(2*n, k))**2
1861
+ Sm = Sum(sk, (k, 1, oo))
1862
+ T = Sm.doit()
1863
+ T2 = T.combsimp()
1864
+ # returns -((-1)**n*factorial(2*n)
1865
+ # - (factorial(n))**2)*exp_polar(-I*pi)/(factorial(n))**2
1866
+ assert T2 == (-1)**n*binomial(2*n, n)
1867
+
1868
+
1869
+ @XFAIL
1870
+ def test_R4():
1871
+ # Macsyma indefinite sum test case:
1872
+ #(c15) /* Check whether the full Gosper algorithm is implemented
1873
+ # => 1/2^(n + 1) binomial(n, k - 1) */
1874
+ #closedform(indefsum(binomial(n, k)/2^n - binomial(n + 1, k)/2^(n + 1), k));
1875
+ #Time= 2690 msecs
1876
+ # (- n + k - 1) binomial(n + 1, k)
1877
+ #(d15) - --------------------------------
1878
+ # n
1879
+ # 2 2 (n + 1)
1880
+ #
1881
+ #(c16) factcomb(makefact(%));
1882
+ #Time= 220 msecs
1883
+ # n!
1884
+ #(d16) ----------------
1885
+ # n
1886
+ # 2 k! 2 (n - k)!
1887
+ # Might be possible after fixing https://github.com/sympy/sympy/pull/1879
1888
+ raise NotImplementedError("Indefinite sum not supported")
1889
+
1890
+
1891
+ @XFAIL
1892
+ def test_R5():
1893
+ a, b, c, n, k = symbols('a b c n k', integer=True, positive=True)
1894
+ sk = ((-1)**k)*(binomial(a + b, a + k)
1895
+ *binomial(b + c, b + k)*binomial(c + a, c + k))
1896
+ Sm = Sum(sk, (k, 1, oo))
1897
+ T = Sm.doit() # hypergeometric series not calculated
1898
+ assert T == factorial(a+b+c)/(factorial(a)*factorial(b)*factorial(c))
1899
+
1900
+
1901
+ def test_R6():
1902
+ n, k = symbols('n k', integer=True, positive=True)
1903
+ gn = MatrixSymbol('gn', n + 2, 1)
1904
+ Sm = Sum(gn[k, 0] - gn[k - 1, 0], (k, 1, n + 1))
1905
+ assert Sm.doit() == -gn[0, 0] + gn[n + 1, 0]
1906
+
1907
+
1908
+ def test_R7():
1909
+ n, k = symbols('n k', integer=True, positive=True)
1910
+ T = Sum(k**3,(k,1,n)).doit()
1911
+ assert T.factor() == n**2*(n + 1)**2/4
1912
+
1913
+ @XFAIL
1914
+ def test_R8():
1915
+ n, k = symbols('n k', integer=True, positive=True)
1916
+ Sm = Sum(k**2*binomial(n, k), (k, 1, n))
1917
+ T = Sm.doit() #returns Piecewise function
1918
+ assert T.combsimp() == n*(n + 1)*2**(n - 2)
1919
+
1920
+
1921
+ def test_R9():
1922
+ n, k = symbols('n k', integer=True, positive=True)
1923
+ Sm = Sum(binomial(n, k - 1)/k, (k, 1, n + 1))
1924
+ assert Sm.doit().simplify() == (2**(n + 1) - 1)/(n + 1)
1925
+
1926
+
1927
+ @XFAIL
1928
+ def test_R10():
1929
+ n, m, r, k = symbols('n m r k', integer=True, positive=True)
1930
+ Sm = Sum(binomial(n, k)*binomial(m, r - k), (k, 0, r))
1931
+ T = Sm.doit()
1932
+ T2 = T.combsimp().rewrite(factorial)
1933
+ assert T2 == factorial(m + n)/(factorial(r)*factorial(m + n - r))
1934
+ assert T2 == binomial(m + n, r).rewrite(factorial)
1935
+ # rewrite(binomial) is not working.
1936
+ # https://github.com/sympy/sympy/issues/7135
1937
+ T3 = T2.rewrite(binomial)
1938
+ assert T3 == binomial(m + n, r)
1939
+
1940
+
1941
+ @XFAIL
1942
+ def test_R11():
1943
+ n, k = symbols('n k', integer=True, positive=True)
1944
+ sk = binomial(n, k)*fibonacci(k)
1945
+ Sm = Sum(sk, (k, 0, n))
1946
+ T = Sm.doit()
1947
+ # Fibonacci simplification not implemented
1948
+ # https://github.com/sympy/sympy/issues/7134
1949
+ assert T == fibonacci(2*n)
1950
+
1951
+
1952
+ @XFAIL
1953
+ def test_R12():
1954
+ n, k = symbols('n k', integer=True, positive=True)
1955
+ Sm = Sum(fibonacci(k)**2, (k, 0, n))
1956
+ T = Sm.doit()
1957
+ assert T == fibonacci(n)*fibonacci(n + 1)
1958
+
1959
+
1960
+ @XFAIL
1961
+ def test_R13():
1962
+ n, k = symbols('n k', integer=True, positive=True)
1963
+ Sm = Sum(sin(k*x), (k, 1, n))
1964
+ T = Sm.doit() # Sum is not calculated
1965
+ assert T.simplify() == cot(x/2)/2 - cos(x*(2*n + 1)/2)/(2*sin(x/2))
1966
+
1967
+
1968
+ @XFAIL
1969
+ def test_R14():
1970
+ n, k = symbols('n k', integer=True, positive=True)
1971
+ Sm = Sum(sin((2*k - 1)*x), (k, 1, n))
1972
+ T = Sm.doit() # Sum is not calculated
1973
+ assert T.simplify() == sin(n*x)**2/sin(x)
1974
+
1975
+
1976
+ @XFAIL
1977
+ def test_R15():
1978
+ n, k = symbols('n k', integer=True, positive=True)
1979
+ Sm = Sum(binomial(n - k, k), (k, 0, floor(n/2)))
1980
+ T = Sm.doit() # Sum is not calculated
1981
+ assert T.simplify() == fibonacci(n + 1)
1982
+
1983
+
1984
+ def test_R16():
1985
+ k = symbols('k', integer=True, positive=True)
1986
+ Sm = Sum(1/k**2 + 1/k**3, (k, 1, oo))
1987
+ assert Sm.doit() == zeta(3) + pi**2/6
1988
+
1989
+
1990
+ def test_R17():
1991
+ k = symbols('k', integer=True, positive=True)
1992
+ assert abs(float(Sum(1/k**2 + 1/k**3, (k, 1, oo)))
1993
+ - 2.8469909700078206) < 1e-15
1994
+
1995
+
1996
+ def test_R18():
1997
+ k = symbols('k', integer=True, positive=True)
1998
+ Sm = Sum(1/(2**k*k**2), (k, 1, oo))
1999
+ T = Sm.doit()
2000
+ assert T.simplify() == -log(2)**2/2 + pi**2/12
2001
+
2002
+
2003
+ @slow
2004
+ @XFAIL
2005
+ def test_R19():
2006
+ k = symbols('k', integer=True, positive=True)
2007
+ Sm = Sum(1/((3*k + 1)*(3*k + 2)*(3*k + 3)), (k, 0, oo))
2008
+ T = Sm.doit()
2009
+ # assert fails, T not simplified
2010
+ assert T.simplify() == -log(3)/4 + sqrt(3)*pi/12
2011
+
2012
+
2013
+ @XFAIL
2014
+ def test_R20():
2015
+ n, k = symbols('n k', integer=True, positive=True)
2016
+ Sm = Sum(binomial(n, 4*k), (k, 0, oo))
2017
+ T = Sm.doit()
2018
+ # assert fails, T not simplified
2019
+ assert T.simplify() == 2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2
2020
+
2021
+
2022
+ @XFAIL
2023
+ def test_R21():
2024
+ k = symbols('k', integer=True, positive=True)
2025
+ Sm = Sum(1/(sqrt(k*(k + 1)) * (sqrt(k) + sqrt(k + 1))), (k, 1, oo))
2026
+ T = Sm.doit() # Sum not calculated
2027
+ assert T.simplify() == 1
2028
+
2029
+
2030
+ # test_R22 answer not available in Wester samples
2031
+ # Sum(Sum(binomial(n, k)*binomial(n - k, n - 2*k)*x**n*y**(n - 2*k),
2032
+ # (k, 0, floor(n/2))), (n, 0, oo)) with abs(x*y)<1?
2033
+
2034
+
2035
+ @XFAIL
2036
+ def test_R23():
2037
+ n, k = symbols('n k', integer=True, positive=True)
2038
+ Sm = Sum(Sum((factorial(n)/(factorial(k)**2*factorial(n - 2*k)))*
2039
+ (x/y)**k*(x*y)**(n - k), (n, 2*k, oo)), (k, 0, oo))
2040
+ # Missing how to express constraint abs(x*y)<1?
2041
+ T = Sm.doit() # Sum not calculated
2042
+ assert T == -1/sqrt(x**2*y**2 - 4*x**2 - 2*x*y + 1)
2043
+
2044
+
2045
+ def test_R24():
2046
+ m, k = symbols('m k', integer=True, positive=True)
2047
+ Sm = Sum(Product(k/(2*k - 1), (k, 1, m)), (m, 2, oo))
2048
+ assert Sm.doit() == pi/2
2049
+
2050
+
2051
+ def test_S1():
2052
+ k = symbols('k', integer=True, positive=True)
2053
+ Pr = Product(gamma(k/3), (k, 1, 8))
2054
+ assert Pr.doit().simplify() == 640*sqrt(3)*pi**3/6561
2055
+
2056
+
2057
+ def test_S2():
2058
+ n, k = symbols('n k', integer=True, positive=True)
2059
+ assert Product(k, (k, 1, n)).doit() == factorial(n)
2060
+
2061
+
2062
+ def test_S3():
2063
+ n, k = symbols('n k', integer=True, positive=True)
2064
+ assert Product(x**k, (k, 1, n)).doit().simplify() == x**(n*(n + 1)/2)
2065
+
2066
+
2067
+ def test_S4():
2068
+ n, k = symbols('n k', integer=True, positive=True)
2069
+ assert Product(1 + 1/k, (k, 1, n -1)).doit().simplify() == n
2070
+
2071
+
2072
+ def test_S5():
2073
+ n, k = symbols('n k', integer=True, positive=True)
2074
+ assert (Product((2*k - 1)/(2*k), (k, 1, n)).doit().gammasimp() ==
2075
+ gamma(n + S.Half)/(sqrt(pi)*gamma(n + 1)))
2076
+
2077
+
2078
+ @XFAIL
2079
+ def test_S6():
2080
+ n, k = symbols('n k', integer=True, positive=True)
2081
+ # Product does not evaluate
2082
+ assert (Product(x**2 -2*x*cos(k*pi/n) + 1, (k, 1, n - 1)).doit().simplify()
2083
+ == (x**(2*n) - 1)/(x**2 - 1))
2084
+
2085
+
2086
+ @XFAIL
2087
+ def test_S7():
2088
+ k = symbols('k', integer=True, positive=True)
2089
+ Pr = Product((k**3 - 1)/(k**3 + 1), (k, 2, oo))
2090
+ T = Pr.doit() # Product does not evaluate
2091
+ assert T.simplify() == R(2, 3)
2092
+
2093
+
2094
+ @XFAIL
2095
+ def test_S8():
2096
+ k = symbols('k', integer=True, positive=True)
2097
+ Pr = Product(1 - 1/(2*k)**2, (k, 1, oo))
2098
+ T = Pr.doit()
2099
+ # Product does not evaluate
2100
+ assert T.simplify() == 2/pi
2101
+
2102
+
2103
+ @XFAIL
2104
+ def test_S9():
2105
+ k = symbols('k', integer=True, positive=True)
2106
+ Pr = Product(1 + (-1)**(k + 1)/(2*k - 1), (k, 1, oo))
2107
+ T = Pr.doit()
2108
+ # Product produces 0
2109
+ # https://github.com/sympy/sympy/issues/7133
2110
+ assert T.simplify() == sqrt(2)
2111
+
2112
+
2113
+ @XFAIL
2114
+ def test_S10():
2115
+ k = symbols('k', integer=True, positive=True)
2116
+ Pr = Product((k*(k + 1) + 1 + I)/(k*(k + 1) + 1 - I), (k, 0, oo))
2117
+ T = Pr.doit()
2118
+ # Product does not evaluate
2119
+ assert T.simplify() == -1
2120
+
2121
+
2122
+ def test_T1():
2123
+ assert limit((1 + 1/n)**n, n, oo) == E
2124
+ assert limit((1 - cos(x))/x**2, x, 0) == S.Half
2125
+
2126
+
2127
+ def test_T2():
2128
+ assert limit((3**x + 5**x)**(1/x), x, oo) == 5
2129
+
2130
+
2131
+ def test_T3():
2132
+ assert limit(log(x)/(log(x) + sin(x)), x, oo) == 1
2133
+
2134
+
2135
+ def test_T4():
2136
+ assert limit((exp(x*exp(-x)/(exp(-x) + exp(-2*x**2/(x + 1))))
2137
+ - exp(x))/x, x, oo) == -exp(2)
2138
+
2139
+
2140
+ def test_T5():
2141
+ assert limit(x*log(x)*log(x*exp(x) - x**2)**2/log(log(x**2
2142
+ + 2*exp(exp(3*x**3*log(x))))), x, oo) == R(1, 3)
2143
+
2144
+
2145
+ def test_T6():
2146
+ assert limit(1/n * factorial(n)**(1/n), n, oo) == exp(-1)
2147
+
2148
+
2149
+ def test_T7():
2150
+ limit(1/n * gamma(n + 1)**(1/n), n, oo)
2151
+
2152
+
2153
+ def test_T8():
2154
+ a, z = symbols('a z', positive=True)
2155
+ assert limit(gamma(z + a)/gamma(z)*exp(-a*log(z)), z, oo) == 1
2156
+
2157
+
2158
+ @XFAIL
2159
+ def test_T9():
2160
+ z, k = symbols('z k', positive=True)
2161
+ # raises NotImplementedError:
2162
+ # Don't know how to calculate the mrv of '(1, k)'
2163
+ assert limit(hyper((1, k), (1,), z/k), k, oo) == exp(z)
2164
+
2165
+
2166
+ @XFAIL
2167
+ def test_T10():
2168
+ # No longer raises PoleError, but should return euler-mascheroni constant
2169
+ assert limit(zeta(x) - 1/(x - 1), x, 1) == integrate(-1/x + 1/floor(x), (x, 1, oo))
2170
+
2171
+ @XFAIL
2172
+ def test_T11():
2173
+ n, k = symbols('n k', integer=True, positive=True)
2174
+ # evaluates to 0
2175
+ assert limit(n**x/(x*product((1 + x/k), (k, 1, n))), n, oo) == gamma(x)
2176
+
2177
+
2178
+ def test_T12():
2179
+ x, t = symbols('x t', real=True)
2180
+ # Does not evaluate the limit but returns an expression with erf
2181
+ assert limit(x * integrate(exp(-t**2), (t, 0, x))/(1 - exp(-x**2)),
2182
+ x, 0) == 1
2183
+
2184
+
2185
+ def test_T13():
2186
+ x = symbols('x', real=True)
2187
+ assert [limit(x/abs(x), x, 0, dir='-'),
2188
+ limit(x/abs(x), x, 0, dir='+')] == [-1, 1]
2189
+
2190
+
2191
+ def test_T14():
2192
+ x = symbols('x', real=True)
2193
+ assert limit(atan(-log(x)), x, 0, dir='+') == pi/2
2194
+
2195
+
2196
+ def test_U1():
2197
+ x = symbols('x', real=True)
2198
+ assert diff(abs(x), x) == sign(x)
2199
+
2200
+
2201
+ def test_U2():
2202
+ f = Lambda(x, Piecewise((-x, x < 0), (x, x >= 0)))
2203
+ assert diff(f(x), x) == Piecewise((-1, x < 0), (1, x >= 0))
2204
+
2205
+
2206
+ def test_U3():
2207
+ f = Lambda(x, Piecewise((x**2 - 1, x == 1), (x**3, x != 1)))
2208
+ f1 = Lambda(x, diff(f(x), x))
2209
+ assert f1(x) == 3*x**2
2210
+ assert f1(1) == 3
2211
+
2212
+
2213
+ @XFAIL
2214
+ def test_U4():
2215
+ n = symbols('n', integer=True, positive=True)
2216
+ x = symbols('x', real=True)
2217
+ d = diff(x**n, x, n)
2218
+ assert d.rewrite(factorial) == factorial(n)
2219
+
2220
+
2221
+ def test_U5():
2222
+ # issue 6681
2223
+ t = symbols('t')
2224
+ ans = (
2225
+ Derivative(f(g(t)), g(t))*Derivative(g(t), (t, 2)) +
2226
+ Derivative(f(g(t)), (g(t), 2))*Derivative(g(t), t)**2)
2227
+ assert f(g(t)).diff(t, 2) == ans
2228
+ assert ans.doit() == ans
2229
+
2230
+
2231
+ def test_U6():
2232
+ h = Function('h')
2233
+ T = integrate(f(y), (y, h(x), g(x)))
2234
+ assert T.diff(x) == (
2235
+ f(g(x))*Derivative(g(x), x) - f(h(x))*Derivative(h(x), x))
2236
+
2237
+
2238
+ @XFAIL
2239
+ def test_U7():
2240
+ p, t = symbols('p t', real=True)
2241
+ # Exact differential => d(V(P, T)) => dV/dP DP + dV/dT DT
2242
+ # raises ValueError: Since there is more than one variable in the
2243
+ # expression, the variable(s) of differentiation must be supplied to
2244
+ # differentiate f(p,t)
2245
+ diff(f(p, t))
2246
+
2247
+
2248
+ def test_U8():
2249
+ x, y = symbols('x y', real=True)
2250
+ eq = cos(x*y) + x
2251
+ # If SymPy had implicit_diff() function this hack could be avoided
2252
+ # TODO: Replace solve with solveset, current test fails for solveset
2253
+ assert idiff(y - eq, y, x) == (-y*sin(x*y) + 1)/(x*sin(x*y) + 1)
2254
+
2255
+
2256
+ def test_U9():
2257
+ # Wester sample case for Maple:
2258
+ # O29 := diff(f(x, y), x) + diff(f(x, y), y);
2259
+ # /d \ /d \
2260
+ # |-- f(x, y)| + |-- f(x, y)|
2261
+ # \dx / \dy /
2262
+ #
2263
+ # O30 := factor(subs(f(x, y) = g(x^2 + y^2), %));
2264
+ # 2 2
2265
+ # 2 D(g)(x + y ) (x + y)
2266
+ x, y = symbols('x y', real=True)
2267
+ su = diff(f(x, y), x) + diff(f(x, y), y)
2268
+ s2 = su.subs(f(x, y), g(x**2 + y**2))
2269
+ s3 = s2.doit().factor()
2270
+ # Subs not performed, s3 = 2*(x + y)*Subs(Derivative(
2271
+ # g(_xi_1), _xi_1), _xi_1, x**2 + y**2)
2272
+ # Derivative(g(x*2 + y**2), x**2 + y**2) is not valid in SymPy,
2273
+ # and probably will remain that way. You can take derivatives with respect
2274
+ # to other expressions only if they are atomic, like a symbol or a
2275
+ # function.
2276
+ # D operator should be added to SymPy
2277
+ # See https://github.com/sympy/sympy/issues/4719.
2278
+ assert s3 == (x + y)*Subs(Derivative(g(x), x), x, x**2 + y**2)*2
2279
+
2280
+
2281
+ def test_U10():
2282
+ # see issue 2519:
2283
+ assert residue((z**3 + 5)/((z**4 - 1)*(z + 1)), z, -1) == R(-9, 4)
2284
+
2285
+ @XFAIL
2286
+ def test_U11():
2287
+ # assert (2*dx + dz) ^ (3*dx + dy + dz) ^ (dx + dy + 4*dz) == 8*dx ^ dy ^dz
2288
+ raise NotImplementedError
2289
+
2290
+
2291
+ @XFAIL
2292
+ def test_U12():
2293
+ # Wester sample case:
2294
+ # (c41) /* d(3 x^5 dy /\ dz + 5 x y^2 dz /\ dx + 8 z dx /\ dy)
2295
+ # => (15 x^4 + 10 x y + 8) dx /\ dy /\ dz */
2296
+ # factor(ext_diff(3*x^5 * dy ~ dz + 5*x*y^2 * dz ~ dx + 8*z * dx ~ dy));
2297
+ # 4
2298
+ # (d41) (10 x y + 15 x + 8) dx dy dz
2299
+ raise NotImplementedError(
2300
+ "External diff of differential form not supported")
2301
+
2302
+
2303
+ def test_U13():
2304
+ assert minimum(x**4 - x + 1, x) == -3*2**R(1,3)/8 + 1
2305
+
2306
+
2307
+ @XFAIL
2308
+ def test_U14():
2309
+ #f = 1/(x**2 + y**2 + 1)
2310
+ #assert [minimize(f), maximize(f)] == [0,1]
2311
+ raise NotImplementedError("minimize(), maximize() not supported")
2312
+
2313
+
2314
+ @XFAIL
2315
+ def test_U15():
2316
+ raise NotImplementedError("minimize() not supported and also solve does \
2317
+ not support multivariate inequalities")
2318
+
2319
+
2320
+ @XFAIL
2321
+ def test_U16():
2322
+ raise NotImplementedError("minimize() not supported in SymPy and also \
2323
+ solve does not support multivariate inequalities")
2324
+
2325
+
2326
+ @XFAIL
2327
+ def test_U17():
2328
+ raise NotImplementedError("Linear programming, symbolic simplex not \
2329
+ supported in SymPy")
2330
+
2331
+
2332
+ def test_V1():
2333
+ x = symbols('x', real=True)
2334
+ assert integrate(abs(x), x) == Piecewise((-x**2/2, x <= 0), (x**2/2, True))
2335
+
2336
+
2337
+ def test_V2():
2338
+ assert integrate(Piecewise((-x, x < 0), (x, x >= 0)), x
2339
+ ) == Piecewise((-x**2/2, x < 0), (x**2/2, True))
2340
+
2341
+
2342
+ def test_V3():
2343
+ assert integrate(1/(x**3 + 2),x).diff().simplify() == 1/(x**3 + 2)
2344
+
2345
+
2346
+ def test_V4():
2347
+ assert integrate(2**x/sqrt(1 + 4**x), x) == asinh(2**x)/log(2)
2348
+
2349
+
2350
+ @XFAIL
2351
+ def test_V5():
2352
+ # Returns (-45*x**2 + 80*x - 41)/(5*sqrt(2*x - 1)*(4*x**2 - 4*x + 1))
2353
+ assert (integrate((3*x - 5)**2/(2*x - 1)**R(7, 2), x).simplify() ==
2354
+ (-41 + 80*x - 45*x**2)/(5*(2*x - 1)**R(5, 2)))
2355
+
2356
+
2357
+ @XFAIL
2358
+ def test_V6():
2359
+ # returns RootSum(40*_z**2 - 1, Lambda(_i, _i*log(-4*_i + exp(-m*x))))/m
2360
+ assert (integrate(1/(2*exp(m*x) - 5*exp(-m*x)), x) == sqrt(10)*(
2361
+ log(2*exp(m*x) - sqrt(10)) - log(2*exp(m*x) + sqrt(10)))/(20*m))
2362
+
2363
+
2364
+ def test_V7():
2365
+ r1 = integrate(sinh(x)**4/cosh(x)**2)
2366
+ assert r1.simplify() == x*R(-3, 2) + sinh(x)**3/(2*cosh(x)) + 3*tanh(x)/2
2367
+
2368
+
2369
+ @XFAIL
2370
+ def test_V8_V9():
2371
+ #Macsyma test case:
2372
+ #(c27) /* This example involves several symbolic parameters
2373
+ # => 1/sqrt(b^2 - a^2) log([sqrt(b^2 - a^2) tan(x/2) + a + b]/
2374
+ # [sqrt(b^2 - a^2) tan(x/2) - a - b]) (a^2 < b^2)
2375
+ # [Gradshteyn and Ryzhik 2.553(3)] */
2376
+ #assume(b^2 > a^2)$
2377
+ #(c28) integrate(1/(a + b*cos(x)), x);
2378
+ #(c29) trigsimp(ratsimp(diff(%, x)));
2379
+ # 1
2380
+ #(d29) ------------
2381
+ # b cos(x) + a
2382
+ raise NotImplementedError(
2383
+ "Integrate with assumption not supported")
2384
+
2385
+
2386
+ def test_V10():
2387
+ assert integrate(1/(3 + 3*cos(x) + 4*sin(x)), x) == log(4*tan(x/2) + 3)/4
2388
+
2389
+
2390
+ def test_V11():
2391
+ r1 = integrate(1/(4 + 3*cos(x) + 4*sin(x)), x)
2392
+ r2 = factor(r1)
2393
+ assert (logcombine(r2, force=True) ==
2394
+ log(((tan(x/2) + 1)/(tan(x/2) + 7))**R(1, 3)))
2395
+
2396
+
2397
+ def test_V12():
2398
+ r1 = integrate(1/(5 + 3*cos(x) + 4*sin(x)), x)
2399
+ assert r1 == -1/(tan(x/2) + 2)
2400
+
2401
+
2402
+ @XFAIL
2403
+ def test_V13():
2404
+ r1 = integrate(1/(6 + 3*cos(x) + 4*sin(x)), x)
2405
+ # expression not simplified, returns: -sqrt(11)*I*log(tan(x/2) + 4/3
2406
+ # - sqrt(11)*I/3)/11 + sqrt(11)*I*log(tan(x/2) + 4/3 + sqrt(11)*I/3)/11
2407
+ assert r1.simplify() == 2*sqrt(11)*atan(sqrt(11)*(3*tan(x/2) + 4)/11)/11
2408
+
2409
+
2410
+ @slow
2411
+ @XFAIL
2412
+ def test_V14():
2413
+ r1 = integrate(log(abs(x**2 - y**2)), x)
2414
+ # Piecewise result does not simplify to the desired result.
2415
+ assert (r1.simplify() == x*log(abs(x**2 - y**2))
2416
+ + y*log(x + y) - y*log(x - y) - 2*x)
2417
+
2418
+
2419
+ def test_V15():
2420
+ r1 = integrate(x*acot(x/y), x)
2421
+ assert simplify(r1 - (x*y + (x**2 + y**2)*acot(x/y))/2) == 0
2422
+
2423
+
2424
+ @XFAIL
2425
+ def test_V16():
2426
+ # Integral not calculated
2427
+ assert integrate(cos(5*x)*Ci(2*x), x) == Ci(2*x)*sin(5*x)/5 - (Si(3*x) + Si(7*x))/10
2428
+
2429
+ @XFAIL
2430
+ def test_V17():
2431
+ r1 = integrate((diff(f(x), x)*g(x)
2432
+ - f(x)*diff(g(x), x))/(f(x)**2 - g(x)**2), x)
2433
+ # integral not calculated
2434
+ assert simplify(r1 - (f(x) - g(x))/(f(x) + g(x))/2) == 0
2435
+
2436
+
2437
+ @XFAIL
2438
+ def test_W1():
2439
+ # The function has a pole at y.
2440
+ # The integral has a Cauchy principal value of zero but SymPy returns -I*pi
2441
+ # https://github.com/sympy/sympy/issues/7159
2442
+ assert integrate(1/(x - y), (x, y - 1, y + 1)) == 0
2443
+
2444
+
2445
+ @XFAIL
2446
+ def test_W2():
2447
+ # The function has a pole at y.
2448
+ # The integral is divergent but SymPy returns -2
2449
+ # https://github.com/sympy/sympy/issues/7160
2450
+ # Test case in Macsyma:
2451
+ # (c6) errcatch(integrate(1/(x - a)^2, x, a - 1, a + 1));
2452
+ # Integral is divergent
2453
+ assert integrate(1/(x - y)**2, (x, y - 1, y + 1)) is zoo
2454
+
2455
+
2456
+ @XFAIL
2457
+ @slow
2458
+ def test_W3():
2459
+ # integral is not calculated
2460
+ # https://github.com/sympy/sympy/issues/7161
2461
+ assert integrate(sqrt(x + 1/x - 2), (x, 0, 1)) == R(4, 3)
2462
+
2463
+
2464
+ @XFAIL
2465
+ @slow
2466
+ def test_W4():
2467
+ # integral is not calculated
2468
+ assert integrate(sqrt(x + 1/x - 2), (x, 1, 2)) == -2*sqrt(2)/3 + R(4, 3)
2469
+
2470
+
2471
+ @XFAIL
2472
+ @slow
2473
+ def test_W5():
2474
+ # integral is not calculated
2475
+ assert integrate(sqrt(x + 1/x - 2), (x, 0, 2)) == -2*sqrt(2)/3 + R(8, 3)
2476
+
2477
+
2478
+ @XFAIL
2479
+ @slow
2480
+ def test_W6():
2481
+ # integral is not calculated
2482
+ assert integrate(sqrt(2 - 2*cos(2*x))/2, (x, pi*R(-3, 4), -pi/4)) == sqrt(2)
2483
+
2484
+
2485
+ def test_W7():
2486
+ a = symbols('a', positive=True)
2487
+ r1 = integrate(cos(x)/(x**2 + a**2), (x, -oo, oo))
2488
+ assert r1.simplify() == pi*exp(-a)/a
2489
+
2490
+
2491
+ @XFAIL
2492
+ def test_W8():
2493
+ # Test case in Mathematica:
2494
+ # In[19]:= Integrate[t^(a - 1)/(1 + t), {t, 0, Infinity},
2495
+ # Assumptions -> 0 < a < 1]
2496
+ # Out[19]= Pi Csc[a Pi]
2497
+ raise NotImplementedError(
2498
+ "Integrate with assumption 0 < a < 1 not supported")
2499
+
2500
+
2501
+ @XFAIL
2502
+ @slow
2503
+ def test_W9():
2504
+ # Integrand with a residue at infinity => -2 pi [sin(pi/5) + sin(2pi/5)]
2505
+ # (principal value) [Levinson and Redheffer, p. 234] *)
2506
+ r1 = integrate(5*x**3/(1 + x + x**2 + x**3 + x**4), (x, -oo, oo))
2507
+ r2 = r1.doit()
2508
+ assert r2 == -2*pi*(sqrt(-sqrt(5)/8 + 5/8) + sqrt(sqrt(5)/8 + 5/8))
2509
+
2510
+
2511
+ @XFAIL
2512
+ def test_W10():
2513
+ # integrate(1/[1 + x + x^2 + ... + x^(2 n)], x = -infinity..infinity) =
2514
+ # 2 pi/(2 n + 1) [1 + cos(pi/[2 n + 1])] csc(2 pi/[2 n + 1])
2515
+ # [Levinson and Redheffer, p. 255] => 2 pi/5 [1 + cos(pi/5)] csc(2 pi/5) */
2516
+ r1 = integrate(x/(1 + x + x**2 + x**4), (x, -oo, oo))
2517
+ r2 = r1.doit()
2518
+ assert r2 == 2*pi*(sqrt(5)/4 + 5/4)*csc(pi*R(2, 5))/5
2519
+
2520
+
2521
+ @XFAIL
2522
+ def test_W11():
2523
+ # integral not calculated
2524
+ assert (integrate(sqrt(1 - x**2)/(1 + x**2), (x, -1, 1)) ==
2525
+ pi*(-1 + sqrt(2)))
2526
+
2527
+
2528
+ def test_W12():
2529
+ p = symbols('p', positive=True)
2530
+ q = symbols('q', real=True)
2531
+ r1 = integrate(x*exp(-p*x**2 + 2*q*x), (x, -oo, oo))
2532
+ assert r1.simplify() == sqrt(pi)*q*exp(q**2/p)/p**R(3, 2)
2533
+
2534
+
2535
+ @XFAIL
2536
+ def test_W13():
2537
+ # Integral not calculated. Expected result is 2*(Euler_mascheroni_constant)
2538
+ r1 = integrate(1/log(x) + 1/(1 - x) - log(log(1/x)), (x, 0, 1))
2539
+ assert r1 == 2*EulerGamma
2540
+
2541
+
2542
+ def test_W14():
2543
+ assert integrate(sin(x)/x*exp(2*I*x), (x, -oo, oo)) == 0
2544
+
2545
+
2546
+ @XFAIL
2547
+ def test_W15():
2548
+ # integral not calculated
2549
+ assert integrate(log(gamma(x))*cos(6*pi*x), (x, 0, 1)) == R(1, 12)
2550
+
2551
+
2552
+ def test_W16():
2553
+ assert integrate((1 + x)**3*legendre_poly(1, x)*legendre_poly(2, x),
2554
+ (x, -1, 1)) == R(36, 35)
2555
+
2556
+
2557
+ def test_W17():
2558
+ a, b = symbols('a b', positive=True)
2559
+ assert integrate(exp(-a*x)*besselj(0, b*x),
2560
+ (x, 0, oo)) == 1/(b*sqrt(a**2/b**2 + 1))
2561
+
2562
+
2563
+ def test_W18():
2564
+ assert integrate((besselj(1, x)/x)**2, (x, 0, oo)) == 4/(3*pi)
2565
+
2566
+
2567
+ @XFAIL
2568
+ def test_W19():
2569
+ # Integral not calculated
2570
+ # Expected result is (cos 7 - 1)/7 [Gradshteyn and Ryzhik 6.782(3)]
2571
+ assert integrate(Ci(x)*besselj(0, 2*sqrt(7*x)), (x, 0, oo)) == (cos(7) - 1)/7
2572
+
2573
+
2574
+ @XFAIL
2575
+ def test_W20():
2576
+ # integral not calculated
2577
+ assert (integrate(x**2*polylog(3, 1/(x + 1)), (x, 0, 1)) ==
2578
+ -pi**2/36 - R(17, 108) + zeta(3)/4 +
2579
+ (-pi**2/2 - 4*log(2) + log(2)**2 + 35/3)*log(2)/9)
2580
+
2581
+
2582
+ def test_W21():
2583
+ assert abs(N(integrate(x**2*polylog(3, 1/(x + 1)), (x, 0, 1)))
2584
+ - 0.210882859565594) < 1e-15
2585
+
2586
+
2587
+ def test_W22():
2588
+ t, u = symbols('t u', real=True)
2589
+ s = Lambda(x, Piecewise((1, And(x >= 1, x <= 2)), (0, True)))
2590
+ assert integrate(s(t)*cos(t), (t, 0, u)) == Piecewise(
2591
+ (0, u < 0),
2592
+ (-sin(Min(1, u)) + sin(Min(2, u)), True))
2593
+
2594
+
2595
+ @slow
2596
+ def test_W23():
2597
+ a, b = symbols('a b', positive=True)
2598
+ r1 = integrate(integrate(x/(x**2 + y**2), (x, a, b)), (y, -oo, oo))
2599
+ assert r1.collect(pi).cancel() == -pi*a + pi*b
2600
+
2601
+
2602
+ def test_W23b():
2603
+ # like W23 but limits are reversed
2604
+ a, b = symbols('a b', positive=True)
2605
+ r2 = integrate(integrate(x/(x**2 + y**2), (y, -oo, oo)), (x, a, b))
2606
+ assert r2.collect(pi) == pi*(-a + b)
2607
+
2608
+
2609
+ @XFAIL
2610
+ @tooslow
2611
+ def test_W24():
2612
+ # Not that slow, but does not fully evaluate so simplify is slow.
2613
+ # Maybe also require doit()
2614
+ x, y = symbols('x y', real=True)
2615
+ r1 = integrate(integrate(sqrt(x**2 + y**2), (x, 0, 1)), (y, 0, 1))
2616
+ assert (r1 - (sqrt(2) + asinh(1))/3).simplify() == 0
2617
+
2618
+
2619
+ @XFAIL
2620
+ @tooslow
2621
+ def test_W25():
2622
+ a, x, y = symbols('a x y', real=True)
2623
+ i1 = integrate(
2624
+ sin(a)*sin(y)/sqrt(1 - sin(a)**2*sin(x)**2*sin(y)**2),
2625
+ (x, 0, pi/2))
2626
+ i2 = integrate(i1, (y, 0, pi/2))
2627
+ assert (i2 - pi*a/2).simplify() == 0
2628
+
2629
+
2630
+ def test_W26():
2631
+ x, y = symbols('x y', real=True)
2632
+ assert integrate(integrate(abs(y - x**2), (y, 0, 2)),
2633
+ (x, -1, 1)) == R(46, 15)
2634
+
2635
+
2636
+ def test_W27():
2637
+ a, b, c = symbols('a b c')
2638
+ assert integrate(integrate(integrate(1, (z, 0, c*(1 - x/a - y/b))),
2639
+ (y, 0, b*(1 - x/a))),
2640
+ (x, 0, a)) == a*b*c/6
2641
+
2642
+
2643
+ def test_X1():
2644
+ v, c = symbols('v c', real=True)
2645
+ assert (series(1/sqrt(1 - (v/c)**2), v, x0=0, n=8) ==
2646
+ 5*v**6/(16*c**6) + 3*v**4/(8*c**4) + v**2/(2*c**2) + 1 + O(v**8))
2647
+
2648
+
2649
+ def test_X2():
2650
+ v, c = symbols('v c', real=True)
2651
+ s1 = series(1/sqrt(1 - (v/c)**2), v, x0=0, n=8)
2652
+ assert (1/s1**2).series(v, x0=0, n=8) == -v**2/c**2 + 1 + O(v**8)
2653
+
2654
+
2655
+ def test_X3():
2656
+ s1 = (sin(x).series()/cos(x).series()).series()
2657
+ s2 = tan(x).series()
2658
+ assert s2 == x + x**3/3 + 2*x**5/15 + O(x**6)
2659
+ assert s1 == s2
2660
+
2661
+
2662
+ def test_X4():
2663
+ s1 = log(sin(x)/x).series()
2664
+ assert s1 == -x**2/6 - x**4/180 + O(x**6)
2665
+ assert log(series(sin(x)/x)).series() == s1
2666
+
2667
+
2668
+ @XFAIL
2669
+ def test_X5():
2670
+ # test case in Mathematica syntax:
2671
+ # In[21]:= (* => [a f'(a d) + g(b d) + integrate(h(c y), y = 0..d)]
2672
+ # + [a^2 f''(a d) + b g'(b d) + h(c d)] (x - d) *)
2673
+ # In[22]:= D[f[a*x], x] + g[b*x] + Integrate[h[c*y], {y, 0, x}]
2674
+ # Out[22]= g[b x] + Integrate[h[c y], {y, 0, x}] + a f'[a x]
2675
+ # In[23]:= Series[%, {x, d, 1}]
2676
+ # Out[23]= (g[b d] + Integrate[h[c y], {y, 0, d}] + a f'[a d]) +
2677
+ # 2 2
2678
+ # (h[c d] + b g'[b d] + a f''[a d]) (-d + x) + O[-d + x]
2679
+ h = Function('h')
2680
+ a, b, c, d = symbols('a b c d', real=True)
2681
+ # series() raises NotImplementedError:
2682
+ # The _eval_nseries method should be added to <class
2683
+ # 'sympy.core.function.Subs'> to give terms up to O(x**n) at x=0
2684
+ series(diff(f(a*x), x) + g(b*x) + integrate(h(c*y), (y, 0, x)),
2685
+ x, x0=d, n=2)
2686
+ # assert missing, until exception is removed
2687
+
2688
+
2689
+ def test_X6():
2690
+ # Taylor series of nonscalar objects (noncommutative multiplication)
2691
+ # expected result => (B A - A B) t^2/2 + O(t^3) [Stanly Steinberg]
2692
+ a, b = symbols('a b', commutative=False, scalar=False)
2693
+ assert (series(exp((a + b)*x) - exp(a*x) * exp(b*x), x, x0=0, n=3) ==
2694
+ x**2*(-a*b/2 + b*a/2) + O(x**3))
2695
+
2696
+
2697
+ def test_X7():
2698
+ # => sum( Bernoulli[k]/k! x^(k - 2), k = 1..infinity )
2699
+ # = 1/x^2 - 1/(2 x) + 1/12 - x^2/720 + x^4/30240 + O(x^6)
2700
+ # [Levinson and Redheffer, p. 173]
2701
+ assert (series(1/(x*(exp(x) - 1)), x, 0, 7) == x**(-2) - 1/(2*x) +
2702
+ R(1, 12) - x**2/720 + x**4/30240 - x**6/1209600 + O(x**7))
2703
+
2704
+
2705
+ def test_X8():
2706
+ # Puiseux series (terms with fractional degree):
2707
+ # => 1/sqrt(x - 3/2 pi) + (x - 3/2 pi)^(3/2) / 12 + O([x - 3/2 pi]^(7/2))
2708
+
2709
+ # see issue 7167:
2710
+ x = symbols('x', real=True)
2711
+ assert (series(sqrt(sec(x)), x, x0=pi*3/2, n=4) ==
2712
+ 1/sqrt(x - pi*R(3, 2)) + (x - pi*R(3, 2))**R(3, 2)/12 +
2713
+ (x - pi*R(3, 2))**R(7, 2)/160 + O((x - pi*R(3, 2))**4, (x, pi*R(3, 2))))
2714
+
2715
+
2716
+ def test_X9():
2717
+ assert (series(x**x, x, x0=0, n=4) == 1 + x*log(x) + x**2*log(x)**2/2 +
2718
+ x**3*log(x)**3/6 + O(x**4*log(x)**4))
2719
+
2720
+
2721
+ def test_X10():
2722
+ z, w = symbols('z w')
2723
+ assert (series(log(sinh(z)) + log(cosh(z + w)), z, x0=0, n=2) ==
2724
+ log(cosh(w)) + log(z) + z*sinh(w)/cosh(w) + O(z**2))
2725
+
2726
+
2727
+ def test_X11():
2728
+ z, w = symbols('z w')
2729
+ assert (series(log(sinh(z) * cosh(z + w)), z, x0=0, n=2) ==
2730
+ log(cosh(w)) + log(z) + z*sinh(w)/cosh(w) + O(z**2))
2731
+
2732
+
2733
+ @XFAIL
2734
+ def test_X12():
2735
+ # Look at the generalized Taylor series around x = 1
2736
+ # Result => (x - 1)^a/e^b [1 - (a + 2 b) (x - 1) / 2 + O((x - 1)^2)]
2737
+ a, b, x = symbols('a b x', real=True)
2738
+ # series returns O(log(x-1)**2)
2739
+ # https://github.com/sympy/sympy/issues/7168
2740
+ assert (series(log(x)**a*exp(-b*x), x, x0=1, n=2) ==
2741
+ (x - 1)**a/exp(b)*(1 - (a + 2*b)*(x - 1)/2 + O((x - 1)**2)))
2742
+
2743
+
2744
+ def test_X13():
2745
+ assert series(sqrt(2*x**2 + 1), x, x0=oo, n=1) == sqrt(2)*x + O(1/x, (x, oo))
2746
+
2747
+
2748
+ @XFAIL
2749
+ def test_X14():
2750
+ # Wallis' product => 1/sqrt(pi n) + ... [Knopp, p. 385]
2751
+ assert series(1/2**(2*n)*binomial(2*n, n),
2752
+ n, x==oo, n=1) == 1/(sqrt(pi)*sqrt(n)) + O(1/x, (x, oo))
2753
+
2754
+
2755
+ @SKIP("https://github.com/sympy/sympy/issues/7164")
2756
+ def test_X15():
2757
+ # => 0!/x - 1!/x^2 + 2!/x^3 - 3!/x^4 + O(1/x^5) [Knopp, p. 544]
2758
+ x, t = symbols('x t', real=True)
2759
+ # raises RuntimeError: maximum recursion depth exceeded
2760
+ # https://github.com/sympy/sympy/issues/7164
2761
+ # 2019-02-17: Raises
2762
+ # PoleError:
2763
+ # Asymptotic expansion of Ei around [-oo] is not implemented.
2764
+ e1 = integrate(exp(-t)/t, (t, x, oo))
2765
+ assert (series(e1, x, x0=oo, n=5) ==
2766
+ 6/x**4 + 2/x**3 - 1/x**2 + 1/x + O(x**(-5), (x, oo)))
2767
+
2768
+
2769
+ def test_X16():
2770
+ # Multivariate Taylor series expansion => 1 - (x^2 + 2 x y + y^2)/2 + O(x^4)
2771
+ assert (series(cos(x + y), x + y, x0=0, n=4) == 1 - (x + y)**2/2 +
2772
+ O(x**4 + x**3*y + x**2*y**2 + x*y**3 + y**4, x, y))
2773
+
2774
+
2775
+ @XFAIL
2776
+ def test_X17():
2777
+ # Power series (compute the general formula)
2778
+ # (c41) powerseries(log(sin(x)/x), x, 0);
2779
+ # /aquarius/data2/opt/local/macsyma_422/library1/trgred.so being loaded.
2780
+ # inf
2781
+ # ==== i1 2 i1 2 i1
2782
+ # \ (- 1) 2 bern(2 i1) x
2783
+ # (d41) > ------------------------------
2784
+ # / 2 i1 (2 i1)!
2785
+ # ====
2786
+ # i1 = 1
2787
+ # fps does not calculate
2788
+ assert fps(log(sin(x)/x)) == \
2789
+ Sum((-1)**k*2**(2*k - 1)*bernoulli(2*k)*x**(2*k)/(k*factorial(2*k)), (k, 1, oo))
2790
+
2791
+
2792
+ @XFAIL
2793
+ def test_X18():
2794
+ # Power series (compute the general formula). Maple FPS:
2795
+ # > FormalPowerSeries(exp(-x)*sin(x), x = 0);
2796
+ # infinity
2797
+ # ----- (1/2 k) k
2798
+ # \ 2 sin(3/4 k Pi) x
2799
+ # ) -------------------------
2800
+ # / k!
2801
+ # -----
2802
+ #
2803
+ # Now, SymPy returns
2804
+ # oo
2805
+ # _____
2806
+ # \ `
2807
+ # \ / k k\
2808
+ # \ k |I*(-1 - I) I*(-1 + I) |
2809
+ # \ x *|----------- - -----------|
2810
+ # / \ 2 2 /
2811
+ # / ------------------------------
2812
+ # / k!
2813
+ # /____,
2814
+ # k = 0
2815
+ k = Dummy('k')
2816
+ assert fps(exp(-x)*sin(x)) == \
2817
+ Sum(2**(S.Half*k)*sin(R(3, 4)*k*pi)*x**k/factorial(k), (k, 0, oo))
2818
+
2819
+
2820
+ @XFAIL
2821
+ def test_X19():
2822
+ # (c45) /* Derive an explicit Taylor series solution of y as a function of
2823
+ # x from the following implicit relation:
2824
+ # y = x - 1 + (x - 1)^2/2 + 2/3 (x - 1)^3 + (x - 1)^4 +
2825
+ # 17/10 (x - 1)^5 + ...
2826
+ # */
2827
+ # x = sin(y) + cos(y);
2828
+ # Time= 0 msecs
2829
+ # (d45) x = sin(y) + cos(y)
2830
+ #
2831
+ # (c46) taylor_revert(%, y, 7);
2832
+ raise NotImplementedError("Solve using series not supported. \
2833
+ Inverse Taylor series expansion also not supported")
2834
+
2835
+
2836
+ @XFAIL
2837
+ def test_X20():
2838
+ # Pade (rational function) approximation => (2 - x)/(2 + x)
2839
+ # > numapprox[pade](exp(-x), x = 0, [1, 1]);
2840
+ # bytes used=9019816, alloc=3669344, time=13.12
2841
+ # 1 - 1/2 x
2842
+ # ---------
2843
+ # 1 + 1/2 x
2844
+ # mpmath support numeric Pade approximant but there is
2845
+ # no symbolic implementation in SymPy
2846
+ # https://en.wikipedia.org/wiki/Pad%C3%A9_approximant
2847
+ raise NotImplementedError("Symbolic Pade approximant not supported")
2848
+
2849
+
2850
+ def test_X21():
2851
+ """
2852
+ Test whether `fourier_series` of x periodical on the [-p, p] interval equals
2853
+ `- (2 p / pi) sum( (-1)^n / n sin(n pi x / p), n = 1..infinity )`.
2854
+ """
2855
+ p = symbols('p', positive=True)
2856
+ n = symbols('n', positive=True, integer=True)
2857
+ s = fourier_series(x, (x, -p, p))
2858
+
2859
+ # All cosine coefficients are equal to 0
2860
+ assert s.an.formula == 0
2861
+
2862
+ # Check for sine coefficients
2863
+ assert s.bn.formula.subs(s.bn.variables[0], 0) == 0
2864
+ assert s.bn.formula.subs(s.bn.variables[0], n) == \
2865
+ -2*p/pi * (-1)**n / n * sin(n*pi*x/p)
2866
+
2867
+
2868
+ @XFAIL
2869
+ def test_X22():
2870
+ # (c52) /* => p / 2
2871
+ # - (2 p / pi^2) sum( [1 - (-1)^n] cos(n pi x / p) / n^2,
2872
+ # n = 1..infinity ) */
2873
+ # fourier_series(abs(x), x, p);
2874
+ # p
2875
+ # (e52) a = -
2876
+ # 0 2
2877
+ #
2878
+ # %nn
2879
+ # (2 (- 1) - 2) p
2880
+ # (e53) a = ------------------
2881
+ # %nn 2 2
2882
+ # %pi %nn
2883
+ #
2884
+ # (e54) b = 0
2885
+ # %nn
2886
+ #
2887
+ # Time= 5290 msecs
2888
+ # inf %nn %pi %nn x
2889
+ # ==== (2 (- 1) - 2) cos(---------)
2890
+ # \ p
2891
+ # p > -------------------------------
2892
+ # / 2
2893
+ # ==== %nn
2894
+ # %nn = 1 p
2895
+ # (d54) ----------------------------------------- + -
2896
+ # 2 2
2897
+ # %pi
2898
+ raise NotImplementedError("Fourier series not supported")
2899
+
2900
+
2901
+ def test_Y1():
2902
+ t = symbols('t', positive=True)
2903
+ w = symbols('w', real=True)
2904
+ s = symbols('s')
2905
+ F, _, _ = laplace_transform(cos((w - 1)*t), t, s)
2906
+ assert F == s/(s**2 + (w - 1)**2)
2907
+
2908
+
2909
+ def test_Y2():
2910
+ t = symbols('t', positive=True)
2911
+ w = symbols('w', real=True)
2912
+ s = symbols('s')
2913
+ f = inverse_laplace_transform(s/(s**2 + (w - 1)**2), s, t, simplify=True)
2914
+ assert f == cos(t*(w - 1))
2915
+
2916
+
2917
+ def test_Y3():
2918
+ t = symbols('t', positive=True)
2919
+ w = symbols('w', real=True)
2920
+ s = symbols('s')
2921
+ F, _, _ = laplace_transform(sinh(w*t)*cosh(w*t), t, s, simplify=True)
2922
+ assert F == w/(s**2 - 4*w**2)
2923
+
2924
+
2925
+ def test_Y4():
2926
+ t = symbols('t', positive=True)
2927
+ s = symbols('s')
2928
+ F, _, _ = laplace_transform(erf(3/sqrt(t)), t, s, simplify=True)
2929
+ assert F == 1/s - exp(-6*sqrt(s))/s
2930
+
2931
+
2932
+ def test_Y5_Y6():
2933
+ # Solve y'' + y = 4 [H(t - 1) - H(t - 2)], y(0) = 1, y'(0) = 0 where H is the
2934
+ # Heaviside (unit step) function (the RHS describes a pulse of magnitude 4 and
2935
+ # duration 1). See David A. Sanchez, Richard C. Allen, Jr. and Walter T.
2936
+ # Kyner, _Differential Equations: An Introduction_, Addison-Wesley Publishing
2937
+ # Company, 1983, p. 211. First, take the Laplace transform of the ODE
2938
+ # => s^2 Y(s) - s + Y(s) = 4/s [e^(-s) - e^(-2 s)]
2939
+ # where Y(s) is the Laplace transform of y(t)
2940
+ t = symbols('t', real=True)
2941
+ s = symbols('s')
2942
+ y = Function('y')
2943
+ Y = Function('Y')
2944
+ F = laplace_correspondence(laplace_transform(diff(y(t), t, 2) + y(t)
2945
+ - 4*(Heaviside(t - 1) - Heaviside(t - 2)),
2946
+ t, s, noconds=True), {y: Y})
2947
+ D = (
2948
+ -F + s**2*Y(s) - s*y(0) + Y(s) - Subs(Derivative(y(t), t), t, 0) -
2949
+ 4*exp(-s)/s + 4*exp(-2*s)/s)
2950
+ assert D == 0
2951
+ # Now, solve for Y(s) and then take the inverse Laplace transform
2952
+ # => Y(s) = s/(s^2 + 1) + 4 [1/s - s/(s^2 + 1)] [e^(-s) - e^(-2 s)]
2953
+ # => y(t) = cos t + 4 {[1 - cos(t - 1)] H(t - 1) - [1 - cos(t - 2)] H(t - 2)}
2954
+ Yf = solve(F, Y(s))[0]
2955
+ Yf = laplace_initial_conds(Yf, t, {y: [1, 0]})
2956
+ assert Yf == (s**2*exp(2*s) + 4*exp(s) - 4)*exp(-2*s)/(s*(s**2 + 1))
2957
+ yf = inverse_laplace_transform(Yf, s, t)
2958
+ yf = yf.collect(Heaviside(t-1)).collect(Heaviside(t-2))
2959
+ assert yf == (
2960
+ (4 - 4*cos(t - 1))*Heaviside(t - 1) +
2961
+ (4*cos(t - 2) - 4)*Heaviside(t - 2) +
2962
+ cos(t)*Heaviside(t))
2963
+
2964
+
2965
+ @XFAIL
2966
+ def test_Y7():
2967
+ # What is the Laplace transform of an infinite square wave?
2968
+ # => 1/s + 2 sum( (-1)^n e^(- s n a)/s, n = 1..infinity )
2969
+ # [Sanchez, Allen and Kyner, p. 213]
2970
+ t = symbols('t', positive=True)
2971
+ a = symbols('a', real=True)
2972
+ s = symbols('s')
2973
+ F, _, _ = laplace_transform(1 + 2*Sum((-1)**n*Heaviside(t - n*a),
2974
+ (n, 1, oo)), t, s)
2975
+ # returns 2*LaplaceTransform(Sum((-1)**n*Heaviside(-a*n + t),
2976
+ # (n, 1, oo)), t, s) + 1/s
2977
+ # https://github.com/sympy/sympy/issues/7177
2978
+ assert F == 2*Sum((-1)**n*exp(-a*n*s)/s, (n, 1, oo)) + 1/s
2979
+
2980
+
2981
+ @XFAIL
2982
+ def test_Y8():
2983
+ assert fourier_transform(1, x, z) == DiracDelta(z)
2984
+
2985
+
2986
+ def test_Y9():
2987
+ assert (fourier_transform(exp(-9*x**2), x, z) ==
2988
+ sqrt(pi)*exp(-pi**2*z**2/9)/3)
2989
+
2990
+
2991
+ def test_Y10():
2992
+ assert (fourier_transform(abs(x)*exp(-3*abs(x)), x, z).cancel() ==
2993
+ (-8*pi**2*z**2 + 18)/(16*pi**4*z**4 + 72*pi**2*z**2 + 81))
2994
+
2995
+
2996
+ @SKIP("https://github.com/sympy/sympy/issues/7181")
2997
+ @slow
2998
+ def test_Y11():
2999
+ # => pi cot(pi s) (0 < Re s < 1) [Gradshteyn and Ryzhik 17.43(5)]
3000
+ x, s = symbols('x s')
3001
+ # raises RuntimeError: maximum recursion depth exceeded
3002
+ # https://github.com/sympy/sympy/issues/7181
3003
+ # Update 2019-02-17 raises:
3004
+ # TypeError: cannot unpack non-iterable MellinTransform object
3005
+ F, _, _ = mellin_transform(1/(1 - x), x, s)
3006
+ assert F == pi*cot(pi*s)
3007
+
3008
+
3009
+ @XFAIL
3010
+ def test_Y12():
3011
+ # => 2^(s - 4) gamma(s/2)/gamma(4 - s/2) (0 < Re s < 1)
3012
+ # [Gradshteyn and Ryzhik 17.43(16)]
3013
+ x, s = symbols('x s')
3014
+ # returns Wrong value -2**(s - 4)*gamma(s/2 - 3)/gamma(-s/2 + 1)
3015
+ # https://github.com/sympy/sympy/issues/7182
3016
+ F, _, _ = mellin_transform(besselj(3, x)/x**3, x, s)
3017
+ assert F == -2**(s - 4)*gamma(s/2)/gamma(-s/2 + 4)
3018
+
3019
+
3020
+ @XFAIL
3021
+ def test_Y13():
3022
+ # Z[H(t - m T)] => z/[z^m (z - 1)] (H is the Heaviside (unit step) function) z
3023
+ raise NotImplementedError("z-transform not supported")
3024
+
3025
+
3026
+ @XFAIL
3027
+ def test_Y14():
3028
+ # Z[H(t - m T)] => z/[z^m (z - 1)] (H is the Heaviside (unit step) function)
3029
+ raise NotImplementedError("z-transform not supported")
3030
+
3031
+
3032
+ def test_Z1():
3033
+ r = Function('r')
3034
+ assert (rsolve(r(n + 2) - 2*r(n + 1) + r(n) - 2, r(n),
3035
+ {r(0): 1, r(1): m}).simplify() == n**2 + n*(m - 2) + 1)
3036
+
3037
+
3038
+ def test_Z2():
3039
+ r = Function('r')
3040
+ assert (rsolve(r(n) - (5*r(n - 1) - 6*r(n - 2)), r(n), {r(0): 0, r(1): 1})
3041
+ == -2**n + 3**n)
3042
+
3043
+
3044
+ def test_Z3():
3045
+ # => r(n) = Fibonacci[n + 1] [Cohen, p. 83]
3046
+ r = Function('r')
3047
+ # recurrence solution is correct, Wester expects it to be simplified to
3048
+ # fibonacci(n+1), but that is quite hard
3049
+ expected = ((S(1)/2 - sqrt(5)/2)**n*(S(1)/2 - sqrt(5)/10)
3050
+ + (S(1)/2 + sqrt(5)/2)**n*(sqrt(5)/10 + S(1)/2))
3051
+ sol = rsolve(r(n) - (r(n - 1) + r(n - 2)), r(n), {r(1): 1, r(2): 2})
3052
+ assert sol == expected
3053
+
3054
+
3055
+ @XFAIL
3056
+ def test_Z4():
3057
+ # => [c^(n+1) [c^(n+1) - 2 c - 2] + (n+1) c^2 + 2 c - n] / [(c-1)^3 (c+1)]
3058
+ # [Joan Z. Yu and Robert Israel in sci.math.symbolic]
3059
+ r = Function('r')
3060
+ c = symbols('c')
3061
+ # raises ValueError: Polynomial or rational function expected,
3062
+ # got '(c**2 - c**n)/(c - c**n)
3063
+ s = rsolve(r(n) - ((1 + c - c**(n-1) - c**(n+1))/(1 - c**n)*r(n - 1)
3064
+ - c*(1 - c**(n-2))/(1 - c**(n-1))*r(n - 2) + 1),
3065
+ r(n), {r(1): 1, r(2): (2 + 2*c + c**2)/(1 + c)})
3066
+ assert (s - (c*(n + 1)*(c*(n + 1) - 2*c - 2) +
3067
+ (n + 1)*c**2 + 2*c - n)/((c-1)**3*(c+1)) == 0)
3068
+
3069
+
3070
+ @XFAIL
3071
+ def test_Z5():
3072
+ # Second order ODE with initial conditions---solve directly
3073
+ # transform: f(t) = sin(2 t)/8 - t cos(2 t)/4
3074
+ C1, C2 = symbols('C1 C2')
3075
+ # initial conditions not supported, this is a manual workaround
3076
+ # https://github.com/sympy/sympy/issues/4720
3077
+ eq = Derivative(f(x), x, 2) + 4*f(x) - sin(2*x)
3078
+ sol = dsolve(eq, f(x))
3079
+ f0 = Lambda(x, sol.rhs)
3080
+ assert f0(x) == C2*sin(2*x) + (C1 - x/4)*cos(2*x)
3081
+ f1 = Lambda(x, diff(f0(x), x))
3082
+ # TODO: Replace solve with solveset, when it works for solveset
3083
+ const_dict = solve((f0(0), f1(0)))
3084
+ result = f0(x).subs(C1, const_dict[C1]).subs(C2, const_dict[C2])
3085
+ assert result == -x*cos(2*x)/4 + sin(2*x)/8
3086
+ # Result is OK, but ODE solving with initial conditions should be
3087
+ # supported without all this manual work
3088
+ raise NotImplementedError('ODE solving with initial conditions \
3089
+ not supported')
3090
+
3091
+
3092
+ @XFAIL
3093
+ def test_Z6():
3094
+ # Second order ODE with initial conditions---solve using Laplace
3095
+ # transform: f(t) = sin(2 t)/8 - t cos(2 t)/4
3096
+ t = symbols('t', positive=True)
3097
+ s = symbols('s')
3098
+ eq = Derivative(f(t), t, 2) + 4*f(t) - sin(2*t)
3099
+ F, _, _ = laplace_transform(eq, t, s)
3100
+ # Laplace transform for diff() not calculated
3101
+ # https://github.com/sympy/sympy/issues/7176
3102
+ assert (F == s**2*LaplaceTransform(f(t), t, s) +
3103
+ 4*LaplaceTransform(f(t), t, s) - 2/(s**2 + 4))
3104
+ # rest of test case not implemented
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_xxe.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # A test file for XXE injection
2
+ # Username: Test
3
+ # Password: Test