MTerryJack commited on
Commit
b2fa06a
·
verified ·
1 Parent(s): 5a67a10

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_cartan_matrix.py +10 -0
  2. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_B.py +17 -0
  3. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_C.py +22 -0
  4. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_F.py +24 -0
  5. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_G.py +16 -0
  6. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/__init__.py +22 -0
  7. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/availability.py +77 -0
  8. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/compilation.py +657 -0
  9. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/runners.py +301 -0
  10. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/__init__.py +0 -0
  11. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/test_compilation.py +104 -0
  12. .venv/lib/python3.13/site-packages/sympy/utilities/_compilation/util.py +312 -0
  13. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/__init__.py +0 -0
  14. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmlctop.xsl +0 -0
  15. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmltex.xsl +0 -0
  16. .venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/simple_mmlctop.xsl +0 -0
  17. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen.py +1632 -0
  18. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_julia.py +620 -0
  19. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_decorator.py +129 -0
  20. .venv/lib/python3.13/site-packages/sympy/utilities/tests/test_enumerative.py +179 -0
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_cartan_matrix.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.cartan_matrix import CartanMatrix
2
+ from sympy.matrices import Matrix
3
+
4
+ def test_CartanMatrix():
5
+ c = CartanMatrix("A3")
6
+ m = Matrix(3, 3, [2, -1, 0, -1, 2, -1, 0, -1, 2])
7
+ assert c == m
8
+ a = CartanMatrix(["G",2])
9
+ mt = Matrix(2, 2, [2, -1, -3, 2])
10
+ assert a == mt
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_B.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.cartan_type import CartanType
2
+ from sympy.matrices import Matrix
3
+
4
+ def test_type_B():
5
+ c = CartanType("B3")
6
+ m = Matrix(3, 3, [2, -1, 0, -1, 2, -2, 0, -1, 2])
7
+ assert m == c.cartan_matrix()
8
+ assert c.dimension() == 3
9
+ assert c.roots() == 18
10
+ assert c.simple_root(3) == [0, 0, 1]
11
+ assert c.basis() == 3
12
+ assert c.lie_algebra() == "so(6)"
13
+ diag = "0---0=>=0\n1 2 3"
14
+ assert c.dynkin_diagram() == diag
15
+ assert c.positive_roots() == {1: [1, -1, 0], 2: [1, 1, 0], 3: [1, 0, -1],
16
+ 4: [1, 0, 1], 5: [0, 1, -1], 6: [0, 1, 1], 7: [1, 0, 0],
17
+ 8: [0, 1, 0], 9: [0, 0, 1]}
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_C.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.cartan_type import CartanType
2
+ from sympy.matrices import Matrix
3
+
4
+ def test_type_C():
5
+ c = CartanType("C4")
6
+ m = Matrix(4, 4, [2, -1, 0, 0, -1, 2, -1, 0, 0, -1, 2, -1, 0, 0, -2, 2])
7
+ assert c.cartan_matrix() == m
8
+ assert c.dimension() == 4
9
+ assert c.simple_root(4) == [0, 0, 0, 2]
10
+ assert c.roots() == 32
11
+ assert c.basis() == 36
12
+ assert c.lie_algebra() == "sp(8)"
13
+ t = CartanType(['C', 3])
14
+ assert t.dimension() == 3
15
+ diag = "0---0---0=<=0\n1 2 3 4"
16
+ assert c.dynkin_diagram() == diag
17
+ assert c.positive_roots() == {1: [1, -1, 0, 0], 2: [1, 1, 0, 0],
18
+ 3: [1, 0, -1, 0], 4: [1, 0, 1, 0], 5: [1, 0, 0, -1],
19
+ 6: [1, 0, 0, 1], 7: [0, 1, -1, 0], 8: [0, 1, 1, 0],
20
+ 9: [0, 1, 0, -1], 10: [0, 1, 0, 1], 11: [0, 0, 1, -1],
21
+ 12: [0, 0, 1, 1], 13: [2, 0, 0, 0], 14: [0, 2, 0, 0], 15: [0, 0, 2, 0],
22
+ 16: [0, 0, 0, 2]}
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_F.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.cartan_type import CartanType
2
+ from sympy.matrices import Matrix
3
+ from sympy.core.backend import S
4
+
5
+ def test_type_F():
6
+ c = CartanType("F4")
7
+ m = Matrix(4, 4, [2, -1, 0, 0, -1, 2, -2, 0, 0, -1, 2, -1, 0, 0, -1, 2])
8
+ assert c.cartan_matrix() == m
9
+ assert c.dimension() == 4
10
+ assert c.simple_root(1) == [1, -1, 0, 0]
11
+ assert c.simple_root(2) == [0, 1, -1, 0]
12
+ assert c.simple_root(3) == [0, 0, 0, 1]
13
+ assert c.simple_root(4) == [-S.Half, -S.Half, -S.Half, -S.Half]
14
+ assert c.roots() == 48
15
+ assert c.basis() == 52
16
+ diag = "0---0=>=0---0\n" + " ".join(str(i) for i in range(1, 5))
17
+ assert c.dynkin_diagram() == diag
18
+ assert c.positive_roots() == {1: [1, -1, 0, 0], 2: [1, 1, 0, 0], 3: [1, 0, -1, 0],
19
+ 4: [1, 0, 1, 0], 5: [1, 0, 0, -1], 6: [1, 0, 0, 1], 7: [0, 1, -1, 0],
20
+ 8: [0, 1, 1, 0], 9: [0, 1, 0, -1], 10: [0, 1, 0, 1], 11: [0, 0, 1, -1],
21
+ 12: [0, 0, 1, 1], 13: [1, 0, 0, 0], 14: [0, 1, 0, 0], 15: [0, 0, 1, 0],
22
+ 16: [0, 0, 0, 1], 17: [S.Half, S.Half, S.Half, S.Half], 18: [S.Half, -S.Half, S.Half, S.Half],
23
+ 19: [S.Half, S.Half, -S.Half, S.Half], 20: [S.Half, S.Half, S.Half, -S.Half], 21: [S.Half, S.Half, -S.Half, -S.Half],
24
+ 22: [S.Half, -S.Half, S.Half, -S.Half], 23: [S.Half, -S.Half, -S.Half, S.Half], 24: [S.Half, -S.Half, -S.Half, -S.Half]}
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_G.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from sympy.liealgebras.cartan_type import CartanType
3
+ from sympy.matrices import Matrix
4
+
5
+ def test_type_G():
6
+ c = CartanType("G2")
7
+ m = Matrix(2, 2, [2, -1, -3, 2])
8
+ assert c.cartan_matrix() == m
9
+ assert c.simple_root(2) == [1, -2, 1]
10
+ assert c.basis() == 14
11
+ assert c.roots() == 12
12
+ assert c.dimension() == 3
13
+ diag = "0≡<≡0\n1 2"
14
+ assert diag == c.dynkin_diagram()
15
+ assert c.positive_roots() == {1: [0, 1, -1], 2: [1, -2, 1], 3: [1, -1, 0],
16
+ 4: [1, 0, 1], 5: [1, 1, -2], 6: [2, -1, -1]}
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ This sub-module is private, i.e. external code should not depend on it.
2
+
3
+ These functions are used by tests run as part of continuous integration.
4
+ Once the implementation is mature (it should support the major
5
+ platforms: Windows, OS X & Linux) it may become official API which
6
+ may be relied upon by downstream libraries. Until then API may break
7
+ without prior notice.
8
+
9
+ TODO:
10
+ - (optionally) clean up after tempfile.mkdtemp()
11
+ - cross-platform testing
12
+ - caching of compiler choice and intermediate files
13
+
14
+ """
15
+
16
+ from .compilation import compile_link_import_strings, compile_run_strings
17
+ from .availability import has_fortran, has_c, has_cxx
18
+
19
+ __all__ = [
20
+ 'compile_link_import_strings', 'compile_run_strings',
21
+ 'has_fortran', 'has_c', 'has_cxx',
22
+ ]
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/availability.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .compilation import compile_run_strings
3
+ from .util import CompilerNotFoundError
4
+
5
+ def has_fortran():
6
+ if not hasattr(has_fortran, 'result'):
7
+ try:
8
+ (stdout, stderr), info = compile_run_strings(
9
+ [('main.f90', (
10
+ 'program foo\n'
11
+ 'print *, "hello world"\n'
12
+ 'end program'
13
+ ))], clean=True
14
+ )
15
+ except CompilerNotFoundError:
16
+ has_fortran.result = False
17
+ if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
18
+ raise
19
+ else:
20
+ if info['exit_status'] != os.EX_OK or 'hello world' not in stdout:
21
+ if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
22
+ raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr))
23
+ has_fortran.result = False
24
+ else:
25
+ has_fortran.result = True
26
+ return has_fortran.result
27
+
28
+
29
+ def has_c():
30
+ if not hasattr(has_c, 'result'):
31
+ try:
32
+ (stdout, stderr), info = compile_run_strings(
33
+ [('main.c', (
34
+ '#include <stdio.h>\n'
35
+ 'int main(){\n'
36
+ 'printf("hello world\\n");\n'
37
+ 'return 0;\n'
38
+ '}'
39
+ ))], clean=True
40
+ )
41
+ except CompilerNotFoundError:
42
+ has_c.result = False
43
+ if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
44
+ raise
45
+ else:
46
+ if info['exit_status'] != os.EX_OK or 'hello world' not in stdout:
47
+ if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
48
+ raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr))
49
+ has_c.result = False
50
+ else:
51
+ has_c.result = True
52
+ return has_c.result
53
+
54
+
55
+ def has_cxx():
56
+ if not hasattr(has_cxx, 'result'):
57
+ try:
58
+ (stdout, stderr), info = compile_run_strings(
59
+ [('main.cxx', (
60
+ '#include <iostream>\n'
61
+ 'int main(){\n'
62
+ 'std::cout << "hello world" << std::endl;\n'
63
+ '}'
64
+ ))], clean=True
65
+ )
66
+ except CompilerNotFoundError:
67
+ has_cxx.result = False
68
+ if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
69
+ raise
70
+ else:
71
+ if info['exit_status'] != os.EX_OK or 'hello world' not in stdout:
72
+ if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1':
73
+ raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr))
74
+ has_cxx.result = False
75
+ else:
76
+ has_cxx.result = True
77
+ return has_cxx.result
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/compilation.py ADDED
@@ -0,0 +1,657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import shutil
4
+ import subprocess
5
+ import sys
6
+ import tempfile
7
+ import warnings
8
+ from pathlib import Path
9
+ from sysconfig import get_config_var, get_config_vars, get_path
10
+
11
+ from .runners import (
12
+ CCompilerRunner,
13
+ CppCompilerRunner,
14
+ FortranCompilerRunner
15
+ )
16
+ from .util import (
17
+ get_abspath, make_dirs, copy, Glob, ArbitraryDepthGlob,
18
+ glob_at_depth, import_module_from_file, pyx_is_cplus,
19
+ sha256_of_string, sha256_of_file, CompileError
20
+ )
21
+
22
+ if os.name == 'posix':
23
+ objext = '.o'
24
+ elif os.name == 'nt':
25
+ objext = '.obj'
26
+ else:
27
+ warnings.warn("Unknown os.name: {}".format(os.name))
28
+ objext = '.o'
29
+
30
+
31
+ def compile_sources(files, Runner=None, destdir=None, cwd=None, keep_dir_struct=False,
32
+ per_file_kwargs=None, **kwargs):
33
+ """ Compile source code files to object files.
34
+
35
+ Parameters
36
+ ==========
37
+
38
+ files : iterable of str
39
+ Paths to source files, if ``cwd`` is given, the paths are taken as relative.
40
+ Runner: CompilerRunner subclass (optional)
41
+ Could be e.g. ``FortranCompilerRunner``. Will be inferred from filename
42
+ extensions if missing.
43
+ destdir: str
44
+ Output directory, if cwd is given, the path is taken as relative.
45
+ cwd: str
46
+ Working directory. Specify to have compiler run in other directory.
47
+ also used as root of relative paths.
48
+ keep_dir_struct: bool
49
+ Reproduce directory structure in `destdir`. default: ``False``
50
+ per_file_kwargs: dict
51
+ Dict mapping instances in ``files`` to keyword arguments.
52
+ \\*\\*kwargs: dict
53
+ Default keyword arguments to pass to ``Runner``.
54
+
55
+ Returns
56
+ =======
57
+ List of strings (paths of object files).
58
+ """
59
+ _per_file_kwargs = {}
60
+
61
+ if per_file_kwargs is not None:
62
+ for k, v in per_file_kwargs.items():
63
+ if isinstance(k, Glob):
64
+ for path in glob.glob(k.pathname):
65
+ _per_file_kwargs[path] = v
66
+ elif isinstance(k, ArbitraryDepthGlob):
67
+ for path in glob_at_depth(k.filename, cwd):
68
+ _per_file_kwargs[path] = v
69
+ else:
70
+ _per_file_kwargs[k] = v
71
+
72
+ # Set up destination directory
73
+ destdir = destdir or '.'
74
+ if not os.path.isdir(destdir):
75
+ if os.path.exists(destdir):
76
+ raise OSError("{} is not a directory".format(destdir))
77
+ else:
78
+ make_dirs(destdir)
79
+ if cwd is None:
80
+ cwd = '.'
81
+ for f in files:
82
+ copy(f, destdir, only_update=True, dest_is_dir=True)
83
+
84
+ # Compile files and return list of paths to the objects
85
+ dstpaths = []
86
+ for f in files:
87
+ if keep_dir_struct:
88
+ name, ext = os.path.splitext(f)
89
+ else:
90
+ name, ext = os.path.splitext(os.path.basename(f))
91
+ file_kwargs = kwargs.copy()
92
+ file_kwargs.update(_per_file_kwargs.get(f, {}))
93
+ dstpaths.append(src2obj(f, Runner, cwd=cwd, **file_kwargs))
94
+ return dstpaths
95
+
96
+
97
+ def get_mixed_fort_c_linker(vendor=None, cplus=False, cwd=None):
98
+ vendor = vendor or os.environ.get('SYMPY_COMPILER_VENDOR', 'gnu')
99
+
100
+ if vendor.lower() == 'intel':
101
+ if cplus:
102
+ return (FortranCompilerRunner,
103
+ {'flags': ['-nofor_main', '-cxxlib']}, vendor)
104
+ else:
105
+ return (FortranCompilerRunner,
106
+ {'flags': ['-nofor_main']}, vendor)
107
+ elif vendor.lower() == 'gnu' or 'llvm':
108
+ if cplus:
109
+ return (CppCompilerRunner,
110
+ {'lib_options': ['fortran']}, vendor)
111
+ else:
112
+ return (FortranCompilerRunner,
113
+ {}, vendor)
114
+ else:
115
+ raise ValueError("No vendor found.")
116
+
117
+
118
+ def link(obj_files, out_file=None, shared=False, Runner=None,
119
+ cwd=None, cplus=False, fort=False, extra_objs=None, **kwargs):
120
+ """ Link object files.
121
+
122
+ Parameters
123
+ ==========
124
+
125
+ obj_files: iterable of str
126
+ Paths to object files.
127
+ out_file: str (optional)
128
+ Path to executable/shared library, if ``None`` it will be
129
+ deduced from the last item in obj_files.
130
+ shared: bool
131
+ Generate a shared library?
132
+ Runner: CompilerRunner subclass (optional)
133
+ If not given the ``cplus`` and ``fort`` flags will be inspected
134
+ (fallback is the C compiler).
135
+ cwd: str
136
+ Path to the root of relative paths and working directory for compiler.
137
+ cplus: bool
138
+ C++ objects? default: ``False``.
139
+ fort: bool
140
+ Fortran objects? default: ``False``.
141
+ extra_objs: list
142
+ List of paths to extra object files / static libraries.
143
+ \\*\\*kwargs: dict
144
+ Keyword arguments passed to ``Runner``.
145
+
146
+ Returns
147
+ =======
148
+
149
+ The absolute path to the generated shared object / executable.
150
+
151
+ """
152
+ if out_file is None:
153
+ out_file, ext = os.path.splitext(os.path.basename(obj_files[-1]))
154
+ if shared:
155
+ out_file += get_config_var('EXT_SUFFIX')
156
+
157
+ if not Runner:
158
+ if fort:
159
+ Runner, extra_kwargs, vendor = \
160
+ get_mixed_fort_c_linker(
161
+ vendor=kwargs.get('vendor', None),
162
+ cplus=cplus,
163
+ cwd=cwd,
164
+ )
165
+ for k, v in extra_kwargs.items():
166
+ if k in kwargs:
167
+ kwargs[k].expand(v)
168
+ else:
169
+ kwargs[k] = v
170
+ else:
171
+ if cplus:
172
+ Runner = CppCompilerRunner
173
+ else:
174
+ Runner = CCompilerRunner
175
+
176
+ flags = kwargs.pop('flags', [])
177
+ if shared:
178
+ if '-shared' not in flags:
179
+ flags.append('-shared')
180
+ run_linker = kwargs.pop('run_linker', True)
181
+ if not run_linker:
182
+ raise ValueError("run_linker was set to False (nonsensical).")
183
+
184
+ out_file = get_abspath(out_file, cwd=cwd)
185
+ runner = Runner(obj_files+(extra_objs or []), out_file, flags, cwd=cwd, **kwargs)
186
+ runner.run()
187
+ return out_file
188
+
189
+
190
+ def link_py_so(obj_files, so_file=None, cwd=None, libraries=None,
191
+ cplus=False, fort=False, extra_objs=None, **kwargs):
192
+ """ Link Python extension module (shared object) for importing
193
+
194
+ Parameters
195
+ ==========
196
+
197
+ obj_files: iterable of str
198
+ Paths to object files to be linked.
199
+ so_file: str
200
+ Name (path) of shared object file to create. If not specified it will
201
+ have the basname of the last object file in `obj_files` but with the
202
+ extension '.so' (Unix).
203
+ cwd: path string
204
+ Root of relative paths and working directory of linker.
205
+ libraries: iterable of strings
206
+ Libraries to link against, e.g. ['m'].
207
+ cplus: bool
208
+ Any C++ objects? default: ``False``.
209
+ fort: bool
210
+ Any Fortran objects? default: ``False``.
211
+ extra_objs: list
212
+ List of paths of extra object files / static libraries to link against.
213
+ kwargs**: dict
214
+ Keyword arguments passed to ``link(...)``.
215
+
216
+ Returns
217
+ =======
218
+
219
+ Absolute path to the generate shared object.
220
+ """
221
+ libraries = libraries or []
222
+
223
+ include_dirs = kwargs.pop('include_dirs', [])
224
+ library_dirs = kwargs.pop('library_dirs', [])
225
+
226
+ # Add Python include and library directories
227
+ # PY_LDFLAGS does not available on all python implementations
228
+ # e.g. when with pypy, so it's LDFLAGS we need to use
229
+ if sys.platform == "win32":
230
+ warnings.warn("Windows not yet supported.")
231
+ elif sys.platform == 'darwin':
232
+ cfgDict = get_config_vars()
233
+ kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']]
234
+ library_dirs += [cfgDict['LIBDIR']]
235
+
236
+ # In macOS, linker needs to compile frameworks
237
+ # e.g. "-framework CoreFoundation"
238
+ is_framework = False
239
+ for opt in cfgDict['LIBS'].split():
240
+ if is_framework:
241
+ kwargs['linkline'] = kwargs.get('linkline', []) + ['-framework', opt]
242
+ is_framework = False
243
+ elif opt.startswith('-l'):
244
+ libraries.append(opt[2:])
245
+ elif opt.startswith('-framework'):
246
+ is_framework = True
247
+ # The python library is not included in LIBS
248
+ libfile = cfgDict['LIBRARY']
249
+ libname = ".".join(libfile.split('.')[:-1])[3:]
250
+ libraries.append(libname)
251
+
252
+ elif sys.platform[:3] == 'aix':
253
+ # Don't use the default code below
254
+ pass
255
+ else:
256
+ if get_config_var('Py_ENABLE_SHARED'):
257
+ cfgDict = get_config_vars()
258
+ kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']]
259
+ library_dirs += [cfgDict['LIBDIR']]
260
+ for opt in cfgDict['BLDLIBRARY'].split():
261
+ if opt.startswith('-l'):
262
+ libraries += [opt[2:]]
263
+ else:
264
+ pass
265
+
266
+ flags = kwargs.pop('flags', [])
267
+ needed_flags = ('-pthread',)
268
+ for flag in needed_flags:
269
+ if flag not in flags:
270
+ flags.append(flag)
271
+
272
+ return link(obj_files, shared=True, flags=flags, cwd=cwd, cplus=cplus, fort=fort,
273
+ include_dirs=include_dirs, libraries=libraries,
274
+ library_dirs=library_dirs, extra_objs=extra_objs, **kwargs)
275
+
276
+
277
+ def simple_cythonize(src, destdir=None, cwd=None, **cy_kwargs):
278
+ """ Generates a C file from a Cython source file.
279
+
280
+ Parameters
281
+ ==========
282
+
283
+ src: str
284
+ Path to Cython source.
285
+ destdir: str (optional)
286
+ Path to output directory (default: '.').
287
+ cwd: path string (optional)
288
+ Root of relative paths (default: '.').
289
+ **cy_kwargs:
290
+ Second argument passed to cy_compile. Generates a .cpp file if ``cplus=True`` in ``cy_kwargs``,
291
+ else a .c file.
292
+ """
293
+ from Cython.Compiler.Main import (
294
+ default_options, CompilationOptions
295
+ )
296
+ from Cython.Compiler.Main import compile as cy_compile
297
+
298
+ assert src.lower().endswith('.pyx') or src.lower().endswith('.py')
299
+ cwd = cwd or '.'
300
+ destdir = destdir or '.'
301
+
302
+ ext = '.cpp' if cy_kwargs.get('cplus', False) else '.c'
303
+ c_name = os.path.splitext(os.path.basename(src))[0] + ext
304
+
305
+ dstfile = os.path.join(destdir, c_name)
306
+
307
+ if cwd:
308
+ ori_dir = os.getcwd()
309
+ else:
310
+ ori_dir = '.'
311
+ os.chdir(cwd)
312
+ try:
313
+ cy_options = CompilationOptions(default_options)
314
+ cy_options.__dict__.update(cy_kwargs)
315
+ # Set language_level if not set by cy_kwargs
316
+ # as not setting it is deprecated
317
+ if 'language_level' not in cy_kwargs:
318
+ cy_options.__dict__['language_level'] = 3
319
+ cy_result = cy_compile([src], cy_options)
320
+ if cy_result.num_errors > 0:
321
+ raise ValueError("Cython compilation failed.")
322
+
323
+ # Move generated C file to destination
324
+ # In macOS, the generated C file is in the same directory as the source
325
+ # but the /var is a symlink to /private/var, so we need to use realpath
326
+ if os.path.realpath(os.path.dirname(src)) != os.path.realpath(destdir):
327
+ if os.path.exists(dstfile):
328
+ os.unlink(dstfile)
329
+ shutil.move(os.path.join(os.path.dirname(src), c_name), destdir)
330
+ finally:
331
+ os.chdir(ori_dir)
332
+ return dstfile
333
+
334
+
335
+ extension_mapping = {
336
+ '.c': (CCompilerRunner, None),
337
+ '.cpp': (CppCompilerRunner, None),
338
+ '.cxx': (CppCompilerRunner, None),
339
+ '.f': (FortranCompilerRunner, None),
340
+ '.for': (FortranCompilerRunner, None),
341
+ '.ftn': (FortranCompilerRunner, None),
342
+ '.f90': (FortranCompilerRunner, None), # ifort only knows about .f90
343
+ '.f95': (FortranCompilerRunner, 'f95'),
344
+ '.f03': (FortranCompilerRunner, 'f2003'),
345
+ '.f08': (FortranCompilerRunner, 'f2008'),
346
+ }
347
+
348
+
349
+ def src2obj(srcpath, Runner=None, objpath=None, cwd=None, inc_py=False, **kwargs):
350
+ """ Compiles a source code file to an object file.
351
+
352
+ Files ending with '.pyx' assumed to be cython files and
353
+ are dispatched to pyx2obj.
354
+
355
+ Parameters
356
+ ==========
357
+
358
+ srcpath: str
359
+ Path to source file.
360
+ Runner: CompilerRunner subclass (optional)
361
+ If ``None``: deduced from extension of srcpath.
362
+ objpath : str (optional)
363
+ Path to generated object. If ``None``: deduced from ``srcpath``.
364
+ cwd: str (optional)
365
+ Working directory and root of relative paths. If ``None``: current dir.
366
+ inc_py: bool
367
+ Add Python include path to kwarg "include_dirs". Default: False
368
+ \\*\\*kwargs: dict
369
+ keyword arguments passed to Runner or pyx2obj
370
+
371
+ """
372
+ name, ext = os.path.splitext(os.path.basename(srcpath))
373
+ if objpath is None:
374
+ if os.path.isabs(srcpath):
375
+ objpath = '.'
376
+ else:
377
+ objpath = os.path.dirname(srcpath)
378
+ objpath = objpath or '.' # avoid objpath == ''
379
+
380
+ if os.path.isdir(objpath):
381
+ objpath = os.path.join(objpath, name + objext)
382
+
383
+ include_dirs = kwargs.pop('include_dirs', [])
384
+ if inc_py:
385
+ py_inc_dir = get_path('include')
386
+ if py_inc_dir not in include_dirs:
387
+ include_dirs.append(py_inc_dir)
388
+
389
+ if ext.lower() == '.pyx':
390
+ return pyx2obj(srcpath, objpath=objpath, include_dirs=include_dirs, cwd=cwd,
391
+ **kwargs)
392
+
393
+ if Runner is None:
394
+ Runner, std = extension_mapping[ext.lower()]
395
+ if 'std' not in kwargs:
396
+ kwargs['std'] = std
397
+
398
+ flags = kwargs.pop('flags', [])
399
+ needed_flags = ('-fPIC',)
400
+ for flag in needed_flags:
401
+ if flag not in flags:
402
+ flags.append(flag)
403
+
404
+ # src2obj implies not running the linker...
405
+ run_linker = kwargs.pop('run_linker', False)
406
+ if run_linker:
407
+ raise CompileError("src2obj called with run_linker=True")
408
+
409
+ runner = Runner([srcpath], objpath, include_dirs=include_dirs,
410
+ run_linker=run_linker, cwd=cwd, flags=flags, **kwargs)
411
+ runner.run()
412
+ return objpath
413
+
414
+
415
+ def pyx2obj(pyxpath, objpath=None, destdir=None, cwd=None,
416
+ include_dirs=None, cy_kwargs=None, cplus=None, **kwargs):
417
+ """
418
+ Convenience function
419
+
420
+ If cwd is specified, pyxpath and dst are taken to be relative
421
+ If only_update is set to `True` the modification time is checked
422
+ and compilation is only run if the source is newer than the
423
+ destination
424
+
425
+ Parameters
426
+ ==========
427
+
428
+ pyxpath: str
429
+ Path to Cython source file.
430
+ objpath: str (optional)
431
+ Path to object file to generate.
432
+ destdir: str (optional)
433
+ Directory to put generated C file. When ``None``: directory of ``objpath``.
434
+ cwd: str (optional)
435
+ Working directory and root of relative paths.
436
+ include_dirs: iterable of path strings (optional)
437
+ Passed onto src2obj and via cy_kwargs['include_path']
438
+ to simple_cythonize.
439
+ cy_kwargs: dict (optional)
440
+ Keyword arguments passed onto `simple_cythonize`
441
+ cplus: bool (optional)
442
+ Indicate whether C++ is used. default: auto-detect using ``.util.pyx_is_cplus``.
443
+ compile_kwargs: dict
444
+ keyword arguments passed onto src2obj
445
+
446
+ Returns
447
+ =======
448
+
449
+ Absolute path of generated object file.
450
+
451
+ """
452
+ assert pyxpath.endswith('.pyx')
453
+ cwd = cwd or '.'
454
+ objpath = objpath or '.'
455
+ destdir = destdir or os.path.dirname(objpath)
456
+
457
+ abs_objpath = get_abspath(objpath, cwd=cwd)
458
+
459
+ if os.path.isdir(abs_objpath):
460
+ pyx_fname = os.path.basename(pyxpath)
461
+ name, ext = os.path.splitext(pyx_fname)
462
+ objpath = os.path.join(objpath, name + objext)
463
+
464
+ cy_kwargs = cy_kwargs or {}
465
+ cy_kwargs['output_dir'] = cwd
466
+ if cplus is None:
467
+ cplus = pyx_is_cplus(pyxpath)
468
+ cy_kwargs['cplus'] = cplus
469
+
470
+ interm_c_file = simple_cythonize(pyxpath, destdir=destdir, cwd=cwd, **cy_kwargs)
471
+
472
+ include_dirs = include_dirs or []
473
+ flags = kwargs.pop('flags', [])
474
+ needed_flags = ('-fwrapv', '-pthread', '-fPIC')
475
+ for flag in needed_flags:
476
+ if flag not in flags:
477
+ flags.append(flag)
478
+
479
+ options = kwargs.pop('options', [])
480
+
481
+ if kwargs.pop('strict_aliasing', False):
482
+ raise CompileError("Cython requires strict aliasing to be disabled.")
483
+
484
+ # Let's be explicit about standard
485
+ if cplus:
486
+ std = kwargs.pop('std', 'c++98')
487
+ else:
488
+ std = kwargs.pop('std', 'c99')
489
+
490
+ return src2obj(interm_c_file, objpath=objpath, cwd=cwd,
491
+ include_dirs=include_dirs, flags=flags, std=std,
492
+ options=options, inc_py=True, strict_aliasing=False,
493
+ **kwargs)
494
+
495
+
496
+ def _any_X(srcs, cls):
497
+ for src in srcs:
498
+ name, ext = os.path.splitext(src)
499
+ key = ext.lower()
500
+ if key in extension_mapping:
501
+ if extension_mapping[key][0] == cls:
502
+ return True
503
+ return False
504
+
505
+
506
+ def any_fortran_src(srcs):
507
+ return _any_X(srcs, FortranCompilerRunner)
508
+
509
+
510
+ def any_cplus_src(srcs):
511
+ return _any_X(srcs, CppCompilerRunner)
512
+
513
+
514
+ def compile_link_import_py_ext(sources, extname=None, build_dir='.', compile_kwargs=None,
515
+ link_kwargs=None, extra_objs=None):
516
+ """ Compiles sources to a shared object (Python extension) and imports it
517
+
518
+ Sources in ``sources`` which is imported. If shared object is newer than the sources, they
519
+ are not recompiled but instead it is imported.
520
+
521
+ Parameters
522
+ ==========
523
+
524
+ sources : list of strings
525
+ List of paths to sources.
526
+ extname : string
527
+ Name of extension (default: ``None``).
528
+ If ``None``: taken from the last file in ``sources`` without extension.
529
+ build_dir: str
530
+ Path to directory in which objects files etc. are generated.
531
+ compile_kwargs: dict
532
+ keyword arguments passed to ``compile_sources``
533
+ link_kwargs: dict
534
+ keyword arguments passed to ``link_py_so``
535
+ extra_objs: list
536
+ List of paths to (prebuilt) object files / static libraries to link against.
537
+
538
+ Returns
539
+ =======
540
+
541
+ The imported module from of the Python extension.
542
+ """
543
+ if extname is None:
544
+ extname = os.path.splitext(os.path.basename(sources[-1]))[0]
545
+
546
+ compile_kwargs = compile_kwargs or {}
547
+ link_kwargs = link_kwargs or {}
548
+
549
+ try:
550
+ mod = import_module_from_file(os.path.join(build_dir, extname), sources)
551
+ except ImportError:
552
+ objs = compile_sources(list(map(get_abspath, sources)), destdir=build_dir,
553
+ cwd=build_dir, **compile_kwargs)
554
+ so = link_py_so(objs, cwd=build_dir, fort=any_fortran_src(sources),
555
+ cplus=any_cplus_src(sources), extra_objs=extra_objs, **link_kwargs)
556
+ mod = import_module_from_file(so)
557
+ return mod
558
+
559
+
560
+ def _write_sources_to_build_dir(sources, build_dir):
561
+ build_dir = build_dir or tempfile.mkdtemp()
562
+ if not os.path.isdir(build_dir):
563
+ raise OSError("Non-existent directory: ", build_dir)
564
+
565
+ source_files = []
566
+ for name, src in sources:
567
+ dest = os.path.join(build_dir, name)
568
+ differs = True
569
+ sha256_in_mem = sha256_of_string(src.encode('utf-8')).hexdigest()
570
+ if os.path.exists(dest):
571
+ if os.path.exists(dest + '.sha256'):
572
+ sha256_on_disk = Path(dest + '.sha256').read_text()
573
+ else:
574
+ sha256_on_disk = sha256_of_file(dest).hexdigest()
575
+
576
+ differs = sha256_on_disk != sha256_in_mem
577
+ if differs:
578
+ with open(dest, 'wt') as fh:
579
+ fh.write(src)
580
+ with open(dest + '.sha256', 'wt') as fh:
581
+ fh.write(sha256_in_mem)
582
+ source_files.append(dest)
583
+ return source_files, build_dir
584
+
585
+
586
+ def compile_link_import_strings(sources, build_dir=None, **kwargs):
587
+ """ Compiles, links and imports extension module from source.
588
+
589
+ Parameters
590
+ ==========
591
+
592
+ sources : iterable of name/source pair tuples
593
+ build_dir : string (default: None)
594
+ Path. ``None`` implies use a temporary directory.
595
+ **kwargs:
596
+ Keyword arguments passed onto `compile_link_import_py_ext`.
597
+
598
+ Returns
599
+ =======
600
+
601
+ mod : module
602
+ The compiled and imported extension module.
603
+ info : dict
604
+ Containing ``build_dir`` as 'build_dir'.
605
+
606
+ """
607
+ source_files, build_dir = _write_sources_to_build_dir(sources, build_dir)
608
+ mod = compile_link_import_py_ext(source_files, build_dir=build_dir, **kwargs)
609
+ info = {"build_dir": build_dir}
610
+ return mod, info
611
+
612
+
613
+ def compile_run_strings(sources, build_dir=None, clean=False, compile_kwargs=None, link_kwargs=None):
614
+ """ Compiles, links and runs a program built from sources.
615
+
616
+ Parameters
617
+ ==========
618
+
619
+ sources : iterable of name/source pair tuples
620
+ build_dir : string (default: None)
621
+ Path. ``None`` implies use a temporary directory.
622
+ clean : bool
623
+ Whether to remove build_dir after use. This will only have an
624
+ effect if ``build_dir`` is ``None`` (which creates a temporary directory).
625
+ Passing ``clean == True`` and ``build_dir != None`` raises a ``ValueError``.
626
+ This will also set ``build_dir`` in returned info dictionary to ``None``.
627
+ compile_kwargs: dict
628
+ Keyword arguments passed onto ``compile_sources``
629
+ link_kwargs: dict
630
+ Keyword arguments passed onto ``link``
631
+
632
+ Returns
633
+ =======
634
+
635
+ (stdout, stderr): pair of strings
636
+ info: dict
637
+ Containing exit status as 'exit_status' and ``build_dir`` as 'build_dir'
638
+
639
+ """
640
+ if clean and build_dir is not None:
641
+ raise ValueError("Automatic removal of build_dir is only available for temporary directory.")
642
+ try:
643
+ source_files, build_dir = _write_sources_to_build_dir(sources, build_dir)
644
+ objs = compile_sources(list(map(get_abspath, source_files)), destdir=build_dir,
645
+ cwd=build_dir, **(compile_kwargs or {}))
646
+ prog = link(objs, cwd=build_dir,
647
+ fort=any_fortran_src(source_files),
648
+ cplus=any_cplus_src(source_files), **(link_kwargs or {}))
649
+ p = subprocess.Popen([prog], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
650
+ exit_status = p.wait()
651
+ stdout, stderr = [txt.decode('utf-8') for txt in p.communicate()]
652
+ finally:
653
+ if clean and os.path.isdir(build_dir):
654
+ shutil.rmtree(build_dir)
655
+ build_dir = None
656
+ info = {"exit_status": exit_status, "build_dir": build_dir}
657
+ return (stdout, stderr), info
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/runners.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Callable, Optional
3
+
4
+ from collections import OrderedDict
5
+ import os
6
+ import re
7
+ import subprocess
8
+ import warnings
9
+
10
+ from .util import (
11
+ find_binary_of_command, unique_list, CompileError
12
+ )
13
+
14
+
15
+ class CompilerRunner:
16
+ """ CompilerRunner base class.
17
+
18
+ Parameters
19
+ ==========
20
+
21
+ sources : list of str
22
+ Paths to sources.
23
+ out : str
24
+ flags : iterable of str
25
+ Compiler flags.
26
+ run_linker : bool
27
+ compiler_name_exe : (str, str) tuple
28
+ Tuple of compiler name & command to call.
29
+ cwd : str
30
+ Path of root of relative paths.
31
+ include_dirs : list of str
32
+ Include directories.
33
+ libraries : list of str
34
+ Libraries to link against.
35
+ library_dirs : list of str
36
+ Paths to search for shared libraries.
37
+ std : str
38
+ Standard string, e.g. ``'c++11'``, ``'c99'``, ``'f2003'``.
39
+ define: iterable of strings
40
+ macros to define
41
+ undef : iterable of strings
42
+ macros to undefine
43
+ preferred_vendor : string
44
+ name of preferred vendor e.g. 'gnu' or 'intel'
45
+
46
+ Methods
47
+ =======
48
+
49
+ run():
50
+ Invoke compilation as a subprocess.
51
+
52
+ """
53
+
54
+ environ_key_compiler: str # e.g. 'CC', 'CXX', ...
55
+ environ_key_flags: str # e.g. 'CFLAGS', 'CXXFLAGS', ...
56
+ environ_key_ldflags: str = "LDFLAGS" # typically 'LDFLAGS'
57
+
58
+ # Subclass to vendor/binary dict
59
+ compiler_dict: dict[str, str]
60
+
61
+ # Standards should be a tuple of supported standards
62
+ # (first one will be the default)
63
+ standards: tuple[None | str, ...]
64
+
65
+ # Subclass to dict of binary/formater-callback
66
+ std_formater: dict[str, Callable[[Optional[str]], str]]
67
+
68
+ # subclass to be e.g. {'gcc': 'gnu', ...}
69
+ compiler_name_vendor_mapping: dict[str, str]
70
+
71
+ def __init__(self, sources, out, flags=None, run_linker=True, compiler=None, cwd='.',
72
+ include_dirs=None, libraries=None, library_dirs=None, std=None, define=None,
73
+ undef=None, strict_aliasing=None, preferred_vendor=None, linkline=None, **kwargs):
74
+ if isinstance(sources, str):
75
+ raise ValueError("Expected argument sources to be a list of strings.")
76
+ self.sources = list(sources)
77
+ self.out = out
78
+ self.flags = flags or []
79
+ if os.environ.get(self.environ_key_flags):
80
+ self.flags += os.environ[self.environ_key_flags].split()
81
+ self.cwd = cwd
82
+ if compiler:
83
+ self.compiler_name, self.compiler_binary = compiler
84
+ elif os.environ.get(self.environ_key_compiler):
85
+ self.compiler_binary = os.environ[self.environ_key_compiler]
86
+ for k, v in self.compiler_dict.items():
87
+ if k in self.compiler_binary:
88
+ self.compiler_vendor = k
89
+ self.compiler_name = v
90
+ break
91
+ else:
92
+ self.compiler_vendor, self.compiler_name = list(self.compiler_dict.items())[0]
93
+ warnings.warn("failed to determine what kind of compiler %s is, assuming %s" %
94
+ (self.compiler_binary, self.compiler_name))
95
+ else:
96
+ # Find a compiler
97
+ if preferred_vendor is None:
98
+ preferred_vendor = os.environ.get('SYMPY_COMPILER_VENDOR', None)
99
+ self.compiler_name, self.compiler_binary, self.compiler_vendor = self.find_compiler(preferred_vendor)
100
+ if self.compiler_binary is None:
101
+ raise ValueError("No compiler found (searched: {})".format(', '.join(self.compiler_dict.values())))
102
+ self.define = define or []
103
+ self.undef = undef or []
104
+ self.include_dirs = include_dirs or []
105
+ self.libraries = libraries or []
106
+ self.library_dirs = library_dirs or []
107
+ self.std = std or self.standards[0]
108
+ self.run_linker = run_linker
109
+ if self.run_linker:
110
+ # both gnu and intel compilers use '-c' for disabling linker
111
+ self.flags = list(filter(lambda x: x != '-c', self.flags))
112
+ else:
113
+ if '-c' not in self.flags:
114
+ self.flags.append('-c')
115
+
116
+ if self.std:
117
+ self.flags.append(self.std_formater[
118
+ self.compiler_name](self.std))
119
+
120
+ self.linkline = (linkline or []) + [lf for lf in map(
121
+ str.strip, os.environ.get(self.environ_key_ldflags, "").split()
122
+ ) if lf != ""]
123
+
124
+ if strict_aliasing is not None:
125
+ nsa_re = re.compile("no-strict-aliasing$")
126
+ sa_re = re.compile("strict-aliasing$")
127
+ if strict_aliasing is True:
128
+ if any(map(nsa_re.match, flags)):
129
+ raise CompileError("Strict aliasing cannot be both enforced and disabled")
130
+ elif any(map(sa_re.match, flags)):
131
+ pass # already enforced
132
+ else:
133
+ flags.append('-fstrict-aliasing')
134
+ elif strict_aliasing is False:
135
+ if any(map(nsa_re.match, flags)):
136
+ pass # already disabled
137
+ else:
138
+ if any(map(sa_re.match, flags)):
139
+ raise CompileError("Strict aliasing cannot be both enforced and disabled")
140
+ else:
141
+ flags.append('-fno-strict-aliasing')
142
+ else:
143
+ msg = "Expected argument strict_aliasing to be True/False, got {}"
144
+ raise ValueError(msg.format(strict_aliasing))
145
+
146
+ @classmethod
147
+ def find_compiler(cls, preferred_vendor=None):
148
+ """ Identify a suitable C/fortran/other compiler. """
149
+ candidates = list(cls.compiler_dict.keys())
150
+ if preferred_vendor:
151
+ if preferred_vendor in candidates:
152
+ candidates = [preferred_vendor]+candidates
153
+ else:
154
+ raise ValueError("Unknown vendor {}".format(preferred_vendor))
155
+ name, path = find_binary_of_command([cls.compiler_dict[x] for x in candidates])
156
+ return name, path, cls.compiler_name_vendor_mapping[name]
157
+
158
+ def cmd(self):
159
+ """ List of arguments (str) to be passed to e.g. ``subprocess.Popen``. """
160
+ cmd = (
161
+ [self.compiler_binary] +
162
+ self.flags +
163
+ ['-U'+x for x in self.undef] +
164
+ ['-D'+x for x in self.define] +
165
+ ['-I'+x for x in self.include_dirs] +
166
+ self.sources
167
+ )
168
+ if self.run_linker:
169
+ cmd += (['-L'+x for x in self.library_dirs] +
170
+ ['-l'+x for x in self.libraries] +
171
+ self.linkline)
172
+ counted = []
173
+ for envvar in re.findall(r'\$\{(\w+)\}', ' '.join(cmd)):
174
+ if os.getenv(envvar) is None:
175
+ if envvar not in counted:
176
+ counted.append(envvar)
177
+ msg = "Environment variable '{}' undefined.".format(envvar)
178
+ raise CompileError(msg)
179
+ return cmd
180
+
181
+ def run(self):
182
+ self.flags = unique_list(self.flags)
183
+
184
+ # Append output flag and name to tail of flags
185
+ self.flags.extend(['-o', self.out])
186
+ env = os.environ.copy()
187
+ env['PWD'] = self.cwd
188
+
189
+ # NOTE: intel compilers seems to need shell=True
190
+ p = subprocess.Popen(' '.join(self.cmd()),
191
+ shell=True,
192
+ cwd=self.cwd,
193
+ stdin=subprocess.PIPE,
194
+ stdout=subprocess.PIPE,
195
+ stderr=subprocess.STDOUT,
196
+ env=env)
197
+ comm = p.communicate()
198
+ try:
199
+ self.cmd_outerr = comm[0].decode('utf-8')
200
+ except UnicodeDecodeError:
201
+ self.cmd_outerr = comm[0].decode('iso-8859-1') # win32
202
+ self.cmd_returncode = p.returncode
203
+
204
+ # Error handling
205
+ if self.cmd_returncode != 0:
206
+ msg = "Error executing '{}' in {} (exited status {}):\n {}\n".format(
207
+ ' '.join(self.cmd()), self.cwd, str(self.cmd_returncode), self.cmd_outerr
208
+ )
209
+ raise CompileError(msg)
210
+
211
+ return self.cmd_outerr, self.cmd_returncode
212
+
213
+
214
+ class CCompilerRunner(CompilerRunner):
215
+
216
+ environ_key_compiler = 'CC'
217
+ environ_key_flags = 'CFLAGS'
218
+
219
+ compiler_dict = OrderedDict([
220
+ ('gnu', 'gcc'),
221
+ ('intel', 'icc'),
222
+ ('llvm', 'clang'),
223
+ ])
224
+
225
+ standards = ('c89', 'c90', 'c99', 'c11') # First is default
226
+
227
+ std_formater = {
228
+ 'gcc': '-std={}'.format,
229
+ 'icc': '-std={}'.format,
230
+ 'clang': '-std={}'.format,
231
+ }
232
+
233
+ compiler_name_vendor_mapping = {
234
+ 'gcc': 'gnu',
235
+ 'icc': 'intel',
236
+ 'clang': 'llvm'
237
+ }
238
+
239
+
240
+ def _mk_flag_filter(cmplr_name): # helper for class initialization
241
+ not_welcome = {'g++': ("Wimplicit-interface",)} # "Wstrict-prototypes",)}
242
+ if cmplr_name in not_welcome:
243
+ def fltr(x):
244
+ for nw in not_welcome[cmplr_name]:
245
+ if nw in x:
246
+ return False
247
+ return True
248
+ else:
249
+ def fltr(x):
250
+ return True
251
+ return fltr
252
+
253
+
254
+ class CppCompilerRunner(CompilerRunner):
255
+
256
+ environ_key_compiler = 'CXX'
257
+ environ_key_flags = 'CXXFLAGS'
258
+
259
+ compiler_dict = OrderedDict([
260
+ ('gnu', 'g++'),
261
+ ('intel', 'icpc'),
262
+ ('llvm', 'clang++'),
263
+ ])
264
+
265
+ # First is the default, c++0x == c++11
266
+ standards = ('c++98', 'c++0x')
267
+
268
+ std_formater = {
269
+ 'g++': '-std={}'.format,
270
+ 'icpc': '-std={}'.format,
271
+ 'clang++': '-std={}'.format,
272
+ }
273
+
274
+ compiler_name_vendor_mapping = {
275
+ 'g++': 'gnu',
276
+ 'icpc': 'intel',
277
+ 'clang++': 'llvm'
278
+ }
279
+
280
+
281
+ class FortranCompilerRunner(CompilerRunner):
282
+
283
+ environ_key_compiler = 'FC'
284
+ environ_key_flags = 'FFLAGS'
285
+
286
+ standards = (None, 'f77', 'f95', 'f2003', 'f2008')
287
+
288
+ std_formater = {
289
+ 'gfortran': lambda x: '-std=gnu' if x is None else '-std=legacy' if x == 'f77' else '-std={}'.format(x),
290
+ 'ifort': lambda x: '-stand f08' if x is None else '-stand f{}'.format(x[-2:]), # f2008 => f08
291
+ }
292
+
293
+ compiler_dict = OrderedDict([
294
+ ('gnu', 'gfortran'),
295
+ ('intel', 'ifort'),
296
+ ])
297
+
298
+ compiler_name_vendor_mapping = {
299
+ 'gfortran': 'gnu',
300
+ 'ifort': 'intel',
301
+ }
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/tests/test_compilation.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import os
3
+ import subprocess
4
+ import tempfile
5
+ from sympy.external import import_module
6
+ from sympy.testing.pytest import skip, skip_under_pyodide
7
+
8
+ from sympy.utilities._compilation.compilation import compile_link_import_py_ext, compile_link_import_strings, compile_sources, get_abspath
9
+
10
+ numpy = import_module('numpy')
11
+ cython = import_module('cython')
12
+
13
+ _sources1 = [
14
+ ('sigmoid.c', r"""
15
+ #include <math.h>
16
+
17
+ void sigmoid(int n, const double * const restrict in,
18
+ double * const restrict out, double lim){
19
+ for (int i=0; i<n; ++i){
20
+ const double x = in[i];
21
+ out[i] = x*pow(pow(x/lim, 8)+1, -1./8.);
22
+ }
23
+ }
24
+ """),
25
+ ('_sigmoid.pyx', r"""
26
+ import numpy as np
27
+ cimport numpy as cnp
28
+
29
+ cdef extern void c_sigmoid "sigmoid" (int, const double * const,
30
+ double * const, double)
31
+
32
+ def sigmoid(double [:] inp, double lim=350.0):
33
+ cdef cnp.ndarray[cnp.float64_t, ndim=1] out = np.empty(
34
+ inp.size, dtype=np.float64)
35
+ c_sigmoid(inp.size, &inp[0], &out[0], lim)
36
+ return out
37
+ """)
38
+ ]
39
+
40
+
41
+ def npy(data, lim=350.0):
42
+ return data/((data/lim)**8+1)**(1/8.)
43
+
44
+
45
+ def test_compile_link_import_strings():
46
+ if not numpy:
47
+ skip("numpy not installed.")
48
+ if not cython:
49
+ skip("cython not installed.")
50
+
51
+ from sympy.utilities._compilation import has_c
52
+ if not has_c():
53
+ skip("No C compiler found.")
54
+
55
+ compile_kw = {"std": 'c99', "include_dirs": [numpy.get_include()]}
56
+ info = None
57
+ try:
58
+ mod, info = compile_link_import_strings(_sources1, compile_kwargs=compile_kw)
59
+ data = numpy.random.random(1024*1024*8) # 64 MB of RAM needed..
60
+ res_mod = mod.sigmoid(data)
61
+ res_npy = npy(data)
62
+ assert numpy.allclose(res_mod, res_npy)
63
+ finally:
64
+ if info and info['build_dir']:
65
+ shutil.rmtree(info['build_dir'])
66
+
67
+
68
+ @skip_under_pyodide("Emscripten does not support subprocesses")
69
+ def test_compile_sources():
70
+ tmpdir = tempfile.mkdtemp()
71
+
72
+ from sympy.utilities._compilation import has_c
73
+ if not has_c():
74
+ skip("No C compiler found.")
75
+
76
+ build_dir = str(tmpdir)
77
+ _handle, file_path = tempfile.mkstemp('.c', dir=build_dir)
78
+ with open(file_path, 'wt') as ofh:
79
+ ofh.write("""
80
+ int foo(int bar) {
81
+ return 2*bar;
82
+ }
83
+ """)
84
+ obj, = compile_sources([file_path], cwd=build_dir)
85
+ obj_path = get_abspath(obj, cwd=build_dir)
86
+ assert os.path.exists(obj_path)
87
+ try:
88
+ _ = subprocess.check_output(["nm", "--help"])
89
+ except subprocess.CalledProcessError:
90
+ pass # we cannot test contents of object file
91
+ else:
92
+ nm_out = subprocess.check_output(["nm", obj_path])
93
+ assert 'foo' in nm_out.decode('utf-8')
94
+
95
+ if not cython:
96
+ return # the final (optional) part of the test below requires Cython.
97
+
98
+ _handle, pyx_path = tempfile.mkstemp('.pyx', dir=build_dir)
99
+ with open(pyx_path, 'wt') as ofh:
100
+ ofh.write(("cdef extern int foo(int)\n"
101
+ "def _foo(arg):\n"
102
+ " return foo(arg)"))
103
+ mod = compile_link_import_py_ext([pyx_path], extra_objs=[obj_path], build_dir=build_dir)
104
+ assert mod._foo(21) == 42
.venv/lib/python3.13/site-packages/sympy/utilities/_compilation/util.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from hashlib import sha256
3
+ import os
4
+ import shutil
5
+ import sys
6
+ import fnmatch
7
+
8
+ from sympy.testing.pytest import XFAIL
9
+
10
+
11
+ def may_xfail(func):
12
+ if sys.platform.lower() == 'darwin' or os.name == 'nt':
13
+ # sympy.utilities._compilation needs more testing on Windows and macOS
14
+ # once those two platforms are reliably supported this xfail decorator
15
+ # may be removed.
16
+ return XFAIL(func)
17
+ else:
18
+ return func
19
+
20
+
21
+ class CompilerNotFoundError(FileNotFoundError):
22
+ pass
23
+
24
+
25
+ class CompileError (Exception):
26
+ """Failure to compile one or more C/C++ source files."""
27
+
28
+
29
+ def get_abspath(path, cwd='.'):
30
+ """ Returns the absolute path.
31
+
32
+ Parameters
33
+ ==========
34
+
35
+ path : str
36
+ (relative) path.
37
+ cwd : str
38
+ Path to root of relative path.
39
+ """
40
+ if os.path.isabs(path):
41
+ return path
42
+ else:
43
+ if not os.path.isabs(cwd):
44
+ cwd = os.path.abspath(cwd)
45
+ return os.path.abspath(
46
+ os.path.join(cwd, path)
47
+ )
48
+
49
+
50
+ def make_dirs(path):
51
+ """ Create directories (equivalent of ``mkdir -p``). """
52
+ if path[-1] == '/':
53
+ parent = os.path.dirname(path[:-1])
54
+ else:
55
+ parent = os.path.dirname(path)
56
+
57
+ if len(parent) > 0:
58
+ if not os.path.exists(parent):
59
+ make_dirs(parent)
60
+
61
+ if not os.path.exists(path):
62
+ os.mkdir(path, 0o777)
63
+ else:
64
+ assert os.path.isdir(path)
65
+
66
+ def missing_or_other_newer(path, other_path, cwd=None):
67
+ """
68
+ Investigate if path is non-existent or older than provided reference
69
+ path.
70
+
71
+ Parameters
72
+ ==========
73
+ path: string
74
+ path to path which might be missing or too old
75
+ other_path: string
76
+ reference path
77
+ cwd: string
78
+ working directory (root of relative paths)
79
+
80
+ Returns
81
+ =======
82
+ True if path is older or missing.
83
+ """
84
+ cwd = cwd or '.'
85
+ path = get_abspath(path, cwd=cwd)
86
+ other_path = get_abspath(other_path, cwd=cwd)
87
+ if not os.path.exists(path):
88
+ return True
89
+ if os.path.getmtime(other_path) - 1e-6 >= os.path.getmtime(path):
90
+ # 1e-6 is needed because http://stackoverflow.com/questions/17086426/
91
+ return True
92
+ return False
93
+
94
+ def copy(src, dst, only_update=False, copystat=True, cwd=None,
95
+ dest_is_dir=False, create_dest_dirs=False):
96
+ """ Variation of ``shutil.copy`` with extra options.
97
+
98
+ Parameters
99
+ ==========
100
+
101
+ src : str
102
+ Path to source file.
103
+ dst : str
104
+ Path to destination.
105
+ only_update : bool
106
+ Only copy if source is newer than destination
107
+ (returns None if it was newer), default: ``False``.
108
+ copystat : bool
109
+ See ``shutil.copystat``. default: ``True``.
110
+ cwd : str
111
+ Path to working directory (root of relative paths).
112
+ dest_is_dir : bool
113
+ Ensures that dst is treated as a directory. default: ``False``
114
+ create_dest_dirs : bool
115
+ Creates directories if needed.
116
+
117
+ Returns
118
+ =======
119
+
120
+ Path to the copied file.
121
+
122
+ """
123
+ if cwd: # Handle working directory
124
+ if not os.path.isabs(src):
125
+ src = os.path.join(cwd, src)
126
+ if not os.path.isabs(dst):
127
+ dst = os.path.join(cwd, dst)
128
+
129
+ if not os.path.exists(src): # Make sure source file exists
130
+ raise FileNotFoundError("Source: `{}` does not exist".format(src))
131
+
132
+ # We accept both (re)naming destination file _or_
133
+ # passing a (possible non-existent) destination directory
134
+ if dest_is_dir:
135
+ if not dst[-1] == '/':
136
+ dst = dst+'/'
137
+ else:
138
+ if os.path.exists(dst) and os.path.isdir(dst):
139
+ dest_is_dir = True
140
+
141
+ if dest_is_dir:
142
+ dest_dir = dst
143
+ dest_fname = os.path.basename(src)
144
+ dst = os.path.join(dest_dir, dest_fname)
145
+ else:
146
+ dest_dir = os.path.dirname(dst)
147
+
148
+ if not os.path.exists(dest_dir):
149
+ if create_dest_dirs:
150
+ make_dirs(dest_dir)
151
+ else:
152
+ raise FileNotFoundError("You must create directory first.")
153
+
154
+ if only_update:
155
+ if not missing_or_other_newer(dst, src):
156
+ return
157
+
158
+ if os.path.islink(dst):
159
+ dst = os.path.abspath(os.path.realpath(dst), cwd=cwd)
160
+
161
+ shutil.copy(src, dst)
162
+ if copystat:
163
+ shutil.copystat(src, dst)
164
+
165
+ return dst
166
+
167
+ Glob = namedtuple('Glob', 'pathname')
168
+ ArbitraryDepthGlob = namedtuple('ArbitraryDepthGlob', 'filename')
169
+
170
+ def glob_at_depth(filename_glob, cwd=None):
171
+ if cwd is not None:
172
+ cwd = '.'
173
+ globbed = []
174
+ for root, dirs, filenames in os.walk(cwd):
175
+ for fn in filenames:
176
+ # This is not tested:
177
+ if fnmatch.fnmatch(fn, filename_glob):
178
+ globbed.append(os.path.join(root, fn))
179
+ return globbed
180
+
181
+ def sha256_of_file(path, nblocks=128):
182
+ """ Computes the SHA256 hash of a file.
183
+
184
+ Parameters
185
+ ==========
186
+
187
+ path : string
188
+ Path to file to compute hash of.
189
+ nblocks : int
190
+ Number of blocks to read per iteration.
191
+
192
+ Returns
193
+ =======
194
+
195
+ hashlib sha256 hash object. Use ``.digest()`` or ``.hexdigest()``
196
+ on returned object to get binary or hex encoded string.
197
+ """
198
+ sh = sha256()
199
+ with open(path, 'rb') as f:
200
+ for chunk in iter(lambda: f.read(nblocks*sh.block_size), b''):
201
+ sh.update(chunk)
202
+ return sh
203
+
204
+
205
+ def sha256_of_string(string):
206
+ """ Computes the SHA256 hash of a string. """
207
+ sh = sha256()
208
+ sh.update(string)
209
+ return sh
210
+
211
+
212
+ def pyx_is_cplus(path):
213
+ """
214
+ Inspect a Cython source file (.pyx) and look for comment line like:
215
+
216
+ # distutils: language = c++
217
+
218
+ Returns True if such a file is present in the file, else False.
219
+ """
220
+ with open(path) as fh:
221
+ for line in fh:
222
+ if line.startswith('#') and '=' in line:
223
+ splitted = line.split('=')
224
+ if len(splitted) != 2:
225
+ continue
226
+ lhs, rhs = splitted
227
+ if lhs.strip().split()[-1].lower() == 'language' and \
228
+ rhs.strip().split()[0].lower() == 'c++':
229
+ return True
230
+ return False
231
+
232
+ def import_module_from_file(filename, only_if_newer_than=None):
233
+ """ Imports Python extension (from shared object file)
234
+
235
+ Provide a list of paths in `only_if_newer_than` to check
236
+ timestamps of dependencies. import_ raises an ImportError
237
+ if any is newer.
238
+
239
+ Word of warning: The OS may cache shared objects which makes
240
+ reimporting same path of an shared object file very problematic.
241
+
242
+ It will not detect the new time stamp, nor new checksum, but will
243
+ instead silently use old module. Use unique names for this reason.
244
+
245
+ Parameters
246
+ ==========
247
+
248
+ filename : str
249
+ Path to shared object.
250
+ only_if_newer_than : iterable of strings
251
+ Paths to dependencies of the shared object.
252
+
253
+ Raises
254
+ ======
255
+
256
+ ``ImportError`` if any of the files specified in ``only_if_newer_than`` are newer
257
+ than the file given by filename.
258
+ """
259
+ path, name = os.path.split(filename)
260
+ name, ext = os.path.splitext(name)
261
+ name = name.split('.')[0]
262
+ if sys.version_info[0] == 2:
263
+ from imp import find_module, load_module
264
+ fobj, filename, data = find_module(name, [path])
265
+ if only_if_newer_than:
266
+ for dep in only_if_newer_than:
267
+ if os.path.getmtime(filename) < os.path.getmtime(dep):
268
+ raise ImportError("{} is newer than {}".format(dep, filename))
269
+ mod = load_module(name, fobj, filename, data)
270
+ else:
271
+ import importlib.util
272
+ spec = importlib.util.spec_from_file_location(name, filename)
273
+ if spec is None:
274
+ raise ImportError("Failed to import: '%s'" % filename)
275
+ mod = importlib.util.module_from_spec(spec)
276
+ spec.loader.exec_module(mod)
277
+ return mod
278
+
279
+
280
+ def find_binary_of_command(candidates):
281
+ """ Finds binary first matching name among candidates.
282
+
283
+ Calls ``which`` from shutils for provided candidates and returns
284
+ first hit.
285
+
286
+ Parameters
287
+ ==========
288
+
289
+ candidates : iterable of str
290
+ Names of candidate commands
291
+
292
+ Raises
293
+ ======
294
+
295
+ CompilerNotFoundError if no candidates match.
296
+ """
297
+ from shutil import which
298
+ for c in candidates:
299
+ binary_path = which(c)
300
+ if c and binary_path:
301
+ return c, binary_path
302
+
303
+ raise CompilerNotFoundError('No binary located for candidates: {}'.format(candidates))
304
+
305
+
306
+ def unique_list(l):
307
+ """ Uniquify a list (skip duplicate items). """
308
+ result = []
309
+ for x in l:
310
+ if x not in result:
311
+ result.append(x)
312
+ return result
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmlctop.xsl ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/mmltex.xsl ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/utilities/mathml/data/simple_mmlctop.xsl ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen.py ADDED
@@ -0,0 +1,1632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import StringIO
2
+
3
+ from sympy.core import symbols, Eq, pi, Catalan, Lambda, Dummy
4
+ from sympy.core.relational import Equality
5
+ from sympy.core.symbol import Symbol
6
+ from sympy.functions.special.error_functions import erf
7
+ from sympy.integrals.integrals import Integral
8
+ from sympy.matrices import Matrix, MatrixSymbol
9
+ from sympy.utilities.codegen import (
10
+ codegen, make_routine, CCodeGen, C89CodeGen, C99CodeGen, InputArgument,
11
+ CodeGenError, FCodeGen, CodeGenArgumentListError, OutputArgument,
12
+ InOutArgument)
13
+ from sympy.testing.pytest import raises
14
+ from sympy.utilities.lambdify import implemented_function
15
+
16
+ #FIXME: Fails due to circular import in with core
17
+ # from sympy import codegen
18
+
19
+
20
+ def get_string(dump_fn, routines, prefix="file", header=False, empty=False):
21
+ """Wrapper for dump_fn. dump_fn writes its results to a stream object and
22
+ this wrapper returns the contents of that stream as a string. This
23
+ auxiliary function is used by many tests below.
24
+
25
+ The header and the empty lines are not generated to facilitate the
26
+ testing of the output.
27
+ """
28
+ output = StringIO()
29
+ dump_fn(routines, output, prefix, header, empty)
30
+ source = output.getvalue()
31
+ output.close()
32
+ return source
33
+
34
+
35
+ def test_Routine_argument_order():
36
+ a, x, y, z = symbols('a x y z')
37
+ expr = (x + y)*z
38
+ raises(CodeGenArgumentListError, lambda: make_routine("test", expr,
39
+ argument_sequence=[z, x]))
40
+ raises(CodeGenArgumentListError, lambda: make_routine("test", Eq(a,
41
+ expr), argument_sequence=[z, x, y]))
42
+ r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
43
+ assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
44
+ assert [ type(arg) for arg in r.arguments ] == [
45
+ InputArgument, InputArgument, OutputArgument, InputArgument ]
46
+ r = make_routine('test', Eq(z, expr), argument_sequence=[z, x, y])
47
+ assert [ type(arg) for arg in r.arguments ] == [
48
+ InOutArgument, InputArgument, InputArgument ]
49
+
50
+ from sympy.tensor import IndexedBase, Idx
51
+ A, B = map(IndexedBase, ['A', 'B'])
52
+ m = symbols('m', integer=True)
53
+ i = Idx('i', m)
54
+ r = make_routine('test', Eq(A[i], B[i]), argument_sequence=[B, A, m])
55
+ assert [ arg.name for arg in r.arguments ] == [B.label, A.label, m]
56
+
57
+ expr = Integral(x*y*z, (x, 1, 2), (y, 1, 3))
58
+ r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
59
+ assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
60
+
61
+
62
+ def test_empty_c_code():
63
+ code_gen = C89CodeGen()
64
+ source = get_string(code_gen.dump_c, [])
65
+ assert source == "#include \"file.h\"\n#include <math.h>\n"
66
+
67
+
68
+ def test_empty_c_code_with_comment():
69
+ code_gen = C89CodeGen()
70
+ source = get_string(code_gen.dump_c, [], header=True)
71
+ assert source[:82] == (
72
+ "/******************************************************************************\n *"
73
+ )
74
+ # " Code generated with SymPy 0.7.2-git "
75
+ assert source[158:] == ( "*\n"
76
+ " * *\n"
77
+ " * See http://www.sympy.org/ for more information. *\n"
78
+ " * *\n"
79
+ " * This file is part of 'project' *\n"
80
+ " ******************************************************************************/\n"
81
+ "#include \"file.h\"\n"
82
+ "#include <math.h>\n"
83
+ )
84
+
85
+
86
+ def test_empty_c_header():
87
+ code_gen = C99CodeGen()
88
+ source = get_string(code_gen.dump_h, [])
89
+ assert source == "#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n#endif\n"
90
+
91
+
92
+ def test_simple_c_code():
93
+ x, y, z = symbols('x,y,z')
94
+ expr = (x + y)*z
95
+ routine = make_routine("test", expr)
96
+ code_gen = C89CodeGen()
97
+ source = get_string(code_gen.dump_c, [routine])
98
+ expected = (
99
+ "#include \"file.h\"\n"
100
+ "#include <math.h>\n"
101
+ "double test(double x, double y, double z) {\n"
102
+ " double test_result;\n"
103
+ " test_result = z*(x + y);\n"
104
+ " return test_result;\n"
105
+ "}\n"
106
+ )
107
+ assert source == expected
108
+
109
+
110
+ def test_c_code_reserved_words():
111
+ x, y, z = symbols('if, typedef, while')
112
+ expr = (x + y) * z
113
+ routine = make_routine("test", expr)
114
+ code_gen = C99CodeGen()
115
+ source = get_string(code_gen.dump_c, [routine])
116
+ expected = (
117
+ "#include \"file.h\"\n"
118
+ "#include <math.h>\n"
119
+ "double test(double if_, double typedef_, double while_) {\n"
120
+ " double test_result;\n"
121
+ " test_result = while_*(if_ + typedef_);\n"
122
+ " return test_result;\n"
123
+ "}\n"
124
+ )
125
+ assert source == expected
126
+
127
+
128
+ def test_numbersymbol_c_code():
129
+ routine = make_routine("test", pi**Catalan)
130
+ code_gen = C89CodeGen()
131
+ source = get_string(code_gen.dump_c, [routine])
132
+ expected = (
133
+ "#include \"file.h\"\n"
134
+ "#include <math.h>\n"
135
+ "double test() {\n"
136
+ " double test_result;\n"
137
+ " double const Catalan = %s;\n"
138
+ " test_result = pow(M_PI, Catalan);\n"
139
+ " return test_result;\n"
140
+ "}\n"
141
+ ) % Catalan.evalf(17)
142
+ assert source == expected
143
+
144
+
145
+ def test_c_code_argument_order():
146
+ x, y, z = symbols('x,y,z')
147
+ expr = x + y
148
+ routine = make_routine("test", expr, argument_sequence=[z, x, y])
149
+ code_gen = C89CodeGen()
150
+ source = get_string(code_gen.dump_c, [routine])
151
+ expected = (
152
+ "#include \"file.h\"\n"
153
+ "#include <math.h>\n"
154
+ "double test(double z, double x, double y) {\n"
155
+ " double test_result;\n"
156
+ " test_result = x + y;\n"
157
+ " return test_result;\n"
158
+ "}\n"
159
+ )
160
+ assert source == expected
161
+
162
+
163
+ def test_simple_c_header():
164
+ x, y, z = symbols('x,y,z')
165
+ expr = (x + y)*z
166
+ routine = make_routine("test", expr)
167
+ code_gen = C89CodeGen()
168
+ source = get_string(code_gen.dump_h, [routine])
169
+ expected = (
170
+ "#ifndef PROJECT__FILE__H\n"
171
+ "#define PROJECT__FILE__H\n"
172
+ "double test(double x, double y, double z);\n"
173
+ "#endif\n"
174
+ )
175
+ assert source == expected
176
+
177
+
178
+ def test_simple_c_codegen():
179
+ x, y, z = symbols('x,y,z')
180
+ expr = (x + y)*z
181
+ expected = [
182
+ ("file.c",
183
+ "#include \"file.h\"\n"
184
+ "#include <math.h>\n"
185
+ "double test(double x, double y, double z) {\n"
186
+ " double test_result;\n"
187
+ " test_result = z*(x + y);\n"
188
+ " return test_result;\n"
189
+ "}\n"),
190
+ ("file.h",
191
+ "#ifndef PROJECT__FILE__H\n"
192
+ "#define PROJECT__FILE__H\n"
193
+ "double test(double x, double y, double z);\n"
194
+ "#endif\n")
195
+ ]
196
+ result = codegen(("test", expr), "C", "file", header=False, empty=False)
197
+ assert result == expected
198
+
199
+
200
+ def test_multiple_results_c():
201
+ x, y, z = symbols('x,y,z')
202
+ expr1 = (x + y)*z
203
+ expr2 = (x - y)*z
204
+ routine = make_routine(
205
+ "test",
206
+ [expr1, expr2]
207
+ )
208
+ code_gen = C99CodeGen()
209
+ raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
210
+
211
+
212
+ def test_no_results_c():
213
+ raises(ValueError, lambda: make_routine("test", []))
214
+
215
+
216
+ def test_ansi_math1_codegen():
217
+ # not included: log10
218
+ from sympy.functions.elementary.complexes import Abs
219
+ from sympy.functions.elementary.exponential import log
220
+ from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
221
+ from sympy.functions.elementary.integers import (ceiling, floor)
222
+ from sympy.functions.elementary.miscellaneous import sqrt
223
+ from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
224
+ x = symbols('x')
225
+ name_expr = [
226
+ ("test_fabs", Abs(x)),
227
+ ("test_acos", acos(x)),
228
+ ("test_asin", asin(x)),
229
+ ("test_atan", atan(x)),
230
+ ("test_ceil", ceiling(x)),
231
+ ("test_cos", cos(x)),
232
+ ("test_cosh", cosh(x)),
233
+ ("test_floor", floor(x)),
234
+ ("test_log", log(x)),
235
+ ("test_ln", log(x)),
236
+ ("test_sin", sin(x)),
237
+ ("test_sinh", sinh(x)),
238
+ ("test_sqrt", sqrt(x)),
239
+ ("test_tan", tan(x)),
240
+ ("test_tanh", tanh(x)),
241
+ ]
242
+ result = codegen(name_expr, "C89", "file", header=False, empty=False)
243
+ assert result[0][0] == "file.c"
244
+ assert result[0][1] == (
245
+ '#include "file.h"\n#include <math.h>\n'
246
+ 'double test_fabs(double x) {\n double test_fabs_result;\n test_fabs_result = fabs(x);\n return test_fabs_result;\n}\n'
247
+ 'double test_acos(double x) {\n double test_acos_result;\n test_acos_result = acos(x);\n return test_acos_result;\n}\n'
248
+ 'double test_asin(double x) {\n double test_asin_result;\n test_asin_result = asin(x);\n return test_asin_result;\n}\n'
249
+ 'double test_atan(double x) {\n double test_atan_result;\n test_atan_result = atan(x);\n return test_atan_result;\n}\n'
250
+ 'double test_ceil(double x) {\n double test_ceil_result;\n test_ceil_result = ceil(x);\n return test_ceil_result;\n}\n'
251
+ 'double test_cos(double x) {\n double test_cos_result;\n test_cos_result = cos(x);\n return test_cos_result;\n}\n'
252
+ 'double test_cosh(double x) {\n double test_cosh_result;\n test_cosh_result = cosh(x);\n return test_cosh_result;\n}\n'
253
+ 'double test_floor(double x) {\n double test_floor_result;\n test_floor_result = floor(x);\n return test_floor_result;\n}\n'
254
+ 'double test_log(double x) {\n double test_log_result;\n test_log_result = log(x);\n return test_log_result;\n}\n'
255
+ 'double test_ln(double x) {\n double test_ln_result;\n test_ln_result = log(x);\n return test_ln_result;\n}\n'
256
+ 'double test_sin(double x) {\n double test_sin_result;\n test_sin_result = sin(x);\n return test_sin_result;\n}\n'
257
+ 'double test_sinh(double x) {\n double test_sinh_result;\n test_sinh_result = sinh(x);\n return test_sinh_result;\n}\n'
258
+ 'double test_sqrt(double x) {\n double test_sqrt_result;\n test_sqrt_result = sqrt(x);\n return test_sqrt_result;\n}\n'
259
+ 'double test_tan(double x) {\n double test_tan_result;\n test_tan_result = tan(x);\n return test_tan_result;\n}\n'
260
+ 'double test_tanh(double x) {\n double test_tanh_result;\n test_tanh_result = tanh(x);\n return test_tanh_result;\n}\n'
261
+ )
262
+ assert result[1][0] == "file.h"
263
+ assert result[1][1] == (
264
+ '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
265
+ 'double test_fabs(double x);\ndouble test_acos(double x);\n'
266
+ 'double test_asin(double x);\ndouble test_atan(double x);\n'
267
+ 'double test_ceil(double x);\ndouble test_cos(double x);\n'
268
+ 'double test_cosh(double x);\ndouble test_floor(double x);\n'
269
+ 'double test_log(double x);\ndouble test_ln(double x);\n'
270
+ 'double test_sin(double x);\ndouble test_sinh(double x);\n'
271
+ 'double test_sqrt(double x);\ndouble test_tan(double x);\n'
272
+ 'double test_tanh(double x);\n#endif\n'
273
+ )
274
+
275
+
276
+ def test_ansi_math2_codegen():
277
+ # not included: frexp, ldexp, modf, fmod
278
+ from sympy.functions.elementary.trigonometric import atan2
279
+ x, y = symbols('x,y')
280
+ name_expr = [
281
+ ("test_atan2", atan2(x, y)),
282
+ ("test_pow", x**y),
283
+ ]
284
+ result = codegen(name_expr, "C89", "file", header=False, empty=False)
285
+ assert result[0][0] == "file.c"
286
+ assert result[0][1] == (
287
+ '#include "file.h"\n#include <math.h>\n'
288
+ 'double test_atan2(double x, double y) {\n double test_atan2_result;\n test_atan2_result = atan2(x, y);\n return test_atan2_result;\n}\n'
289
+ 'double test_pow(double x, double y) {\n double test_pow_result;\n test_pow_result = pow(x, y);\n return test_pow_result;\n}\n'
290
+ )
291
+ assert result[1][0] == "file.h"
292
+ assert result[1][1] == (
293
+ '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
294
+ 'double test_atan2(double x, double y);\n'
295
+ 'double test_pow(double x, double y);\n'
296
+ '#endif\n'
297
+ )
298
+
299
+
300
+ def test_complicated_codegen():
301
+ from sympy.functions.elementary.trigonometric import (cos, sin, tan)
302
+ x, y, z = symbols('x,y,z')
303
+ name_expr = [
304
+ ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
305
+ ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
306
+ ]
307
+ result = codegen(name_expr, "C89", "file", header=False, empty=False)
308
+ assert result[0][0] == "file.c"
309
+ assert result[0][1] == (
310
+ '#include "file.h"\n#include <math.h>\n'
311
+ 'double test1(double x, double y, double z) {\n'
312
+ ' double test1_result;\n'
313
+ ' test1_result = '
314
+ 'pow(sin(x), 7) + '
315
+ '7*pow(sin(x), 6)*cos(y) + '
316
+ '7*pow(sin(x), 6)*tan(z) + '
317
+ '21*pow(sin(x), 5)*pow(cos(y), 2) + '
318
+ '42*pow(sin(x), 5)*cos(y)*tan(z) + '
319
+ '21*pow(sin(x), 5)*pow(tan(z), 2) + '
320
+ '35*pow(sin(x), 4)*pow(cos(y), 3) + '
321
+ '105*pow(sin(x), 4)*pow(cos(y), 2)*tan(z) + '
322
+ '105*pow(sin(x), 4)*cos(y)*pow(tan(z), 2) + '
323
+ '35*pow(sin(x), 4)*pow(tan(z), 3) + '
324
+ '35*pow(sin(x), 3)*pow(cos(y), 4) + '
325
+ '140*pow(sin(x), 3)*pow(cos(y), 3)*tan(z) + '
326
+ '210*pow(sin(x), 3)*pow(cos(y), 2)*pow(tan(z), 2) + '
327
+ '140*pow(sin(x), 3)*cos(y)*pow(tan(z), 3) + '
328
+ '35*pow(sin(x), 3)*pow(tan(z), 4) + '
329
+ '21*pow(sin(x), 2)*pow(cos(y), 5) + '
330
+ '105*pow(sin(x), 2)*pow(cos(y), 4)*tan(z) + '
331
+ '210*pow(sin(x), 2)*pow(cos(y), 3)*pow(tan(z), 2) + '
332
+ '210*pow(sin(x), 2)*pow(cos(y), 2)*pow(tan(z), 3) + '
333
+ '105*pow(sin(x), 2)*cos(y)*pow(tan(z), 4) + '
334
+ '21*pow(sin(x), 2)*pow(tan(z), 5) + '
335
+ '7*sin(x)*pow(cos(y), 6) + '
336
+ '42*sin(x)*pow(cos(y), 5)*tan(z) + '
337
+ '105*sin(x)*pow(cos(y), 4)*pow(tan(z), 2) + '
338
+ '140*sin(x)*pow(cos(y), 3)*pow(tan(z), 3) + '
339
+ '105*sin(x)*pow(cos(y), 2)*pow(tan(z), 4) + '
340
+ '42*sin(x)*cos(y)*pow(tan(z), 5) + '
341
+ '7*sin(x)*pow(tan(z), 6) + '
342
+ 'pow(cos(y), 7) + '
343
+ '7*pow(cos(y), 6)*tan(z) + '
344
+ '21*pow(cos(y), 5)*pow(tan(z), 2) + '
345
+ '35*pow(cos(y), 4)*pow(tan(z), 3) + '
346
+ '35*pow(cos(y), 3)*pow(tan(z), 4) + '
347
+ '21*pow(cos(y), 2)*pow(tan(z), 5) + '
348
+ '7*cos(y)*pow(tan(z), 6) + '
349
+ 'pow(tan(z), 7);\n'
350
+ ' return test1_result;\n'
351
+ '}\n'
352
+ 'double test2(double x, double y, double z) {\n'
353
+ ' double test2_result;\n'
354
+ ' test2_result = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n'
355
+ ' return test2_result;\n'
356
+ '}\n'
357
+ )
358
+ assert result[1][0] == "file.h"
359
+ assert result[1][1] == (
360
+ '#ifndef PROJECT__FILE__H\n'
361
+ '#define PROJECT__FILE__H\n'
362
+ 'double test1(double x, double y, double z);\n'
363
+ 'double test2(double x, double y, double z);\n'
364
+ '#endif\n'
365
+ )
366
+
367
+
368
+ def test_loops_c():
369
+ from sympy.tensor import IndexedBase, Idx
370
+ from sympy.core.symbol import symbols
371
+ n, m = symbols('n m', integer=True)
372
+ A = IndexedBase('A')
373
+ x = IndexedBase('x')
374
+ y = IndexedBase('y')
375
+ i = Idx('i', m)
376
+ j = Idx('j', n)
377
+
378
+ (f1, code), (f2, interface) = codegen(
379
+ ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
380
+
381
+ assert f1 == 'file.c'
382
+ expected = (
383
+ '#include "file.h"\n'
384
+ '#include <math.h>\n'
385
+ 'void matrix_vector(double *A, int m, int n, double *x, double *y) {\n'
386
+ ' for (int i=0; i<m; i++){\n'
387
+ ' y[i] = 0;\n'
388
+ ' }\n'
389
+ ' for (int i=0; i<m; i++){\n'
390
+ ' for (int j=0; j<n; j++){\n'
391
+ ' y[i] = %(rhs)s + y[i];\n'
392
+ ' }\n'
393
+ ' }\n'
394
+ '}\n'
395
+ )
396
+
397
+ assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*n + j)} or
398
+ code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*n)} or
399
+ code == expected % {'rhs': 'x[j]*A[%s]' % (i*n + j)} or
400
+ code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*n)})
401
+ assert f2 == 'file.h'
402
+ assert interface == (
403
+ '#ifndef PROJECT__FILE__H\n'
404
+ '#define PROJECT__FILE__H\n'
405
+ 'void matrix_vector(double *A, int m, int n, double *x, double *y);\n'
406
+ '#endif\n'
407
+ )
408
+
409
+
410
+ def test_dummy_loops_c():
411
+ from sympy.tensor import IndexedBase, Idx
412
+ i, m = symbols('i m', integer=True, cls=Dummy)
413
+ x = IndexedBase('x')
414
+ y = IndexedBase('y')
415
+ i = Idx(i, m)
416
+ expected = (
417
+ '#include "file.h"\n'
418
+ '#include <math.h>\n'
419
+ 'void test_dummies(int m_%(mno)i, double *x, double *y) {\n'
420
+ ' for (int i_%(ino)i=0; i_%(ino)i<m_%(mno)i; i_%(ino)i++){\n'
421
+ ' y[i_%(ino)i] = x[i_%(ino)i];\n'
422
+ ' }\n'
423
+ '}\n'
424
+ ) % {'ino': i.label.dummy_index, 'mno': m.dummy_index}
425
+ r = make_routine('test_dummies', Eq(y[i], x[i]))
426
+ c89 = C89CodeGen()
427
+ c99 = C99CodeGen()
428
+ code = get_string(c99.dump_c, [r])
429
+ assert code == expected
430
+ with raises(NotImplementedError):
431
+ get_string(c89.dump_c, [r])
432
+
433
+ def test_partial_loops_c():
434
+ # check that loop boundaries are determined by Idx, and array strides
435
+ # determined by shape of IndexedBase object.
436
+ from sympy.tensor import IndexedBase, Idx
437
+ from sympy.core.symbol import symbols
438
+ n, m, o, p = symbols('n m o p', integer=True)
439
+ A = IndexedBase('A', shape=(m, p))
440
+ x = IndexedBase('x')
441
+ y = IndexedBase('y')
442
+ i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
443
+ j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
444
+
445
+ (f1, code), (f2, interface) = codegen(
446
+ ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
447
+
448
+ assert f1 == 'file.c'
449
+ expected = (
450
+ '#include "file.h"\n'
451
+ '#include <math.h>\n'
452
+ 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y) {\n'
453
+ ' for (int i=o; i<%(upperi)s; i++){\n'
454
+ ' y[i] = 0;\n'
455
+ ' }\n'
456
+ ' for (int i=o; i<%(upperi)s; i++){\n'
457
+ ' for (int j=0; j<n; j++){\n'
458
+ ' y[i] = %(rhs)s + y[i];\n'
459
+ ' }\n'
460
+ ' }\n'
461
+ '}\n'
462
+ ) % {'upperi': m - 4, 'rhs': '%(rhs)s'}
463
+
464
+ assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*p + j)} or
465
+ code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*p)} or
466
+ code == expected % {'rhs': 'x[j]*A[%s]' % (i*p + j)} or
467
+ code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*p)})
468
+ assert f2 == 'file.h'
469
+ assert interface == (
470
+ '#ifndef PROJECT__FILE__H\n'
471
+ '#define PROJECT__FILE__H\n'
472
+ 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y);\n'
473
+ '#endif\n'
474
+ )
475
+
476
+
477
+ def test_output_arg_c():
478
+ from sympy.core.relational import Equality
479
+ from sympy.functions.elementary.trigonometric import (cos, sin)
480
+ x, y, z = symbols("x,y,z")
481
+ r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
482
+ c = C89CodeGen()
483
+ result = c.write([r], "test", header=False, empty=False)
484
+ assert result[0][0] == "test.c"
485
+ expected = (
486
+ '#include "test.h"\n'
487
+ '#include <math.h>\n'
488
+ 'double foo(double x, double *y) {\n'
489
+ ' (*y) = sin(x);\n'
490
+ ' double foo_result;\n'
491
+ ' foo_result = cos(x);\n'
492
+ ' return foo_result;\n'
493
+ '}\n'
494
+ )
495
+ assert result[0][1] == expected
496
+
497
+
498
+ def test_output_arg_c_reserved_words():
499
+ from sympy.core.relational import Equality
500
+ from sympy.functions.elementary.trigonometric import (cos, sin)
501
+ x, y, z = symbols("if, while, z")
502
+ r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
503
+ c = C89CodeGen()
504
+ result = c.write([r], "test", header=False, empty=False)
505
+ assert result[0][0] == "test.c"
506
+ expected = (
507
+ '#include "test.h"\n'
508
+ '#include <math.h>\n'
509
+ 'double foo(double if_, double *while_) {\n'
510
+ ' (*while_) = sin(if_);\n'
511
+ ' double foo_result;\n'
512
+ ' foo_result = cos(if_);\n'
513
+ ' return foo_result;\n'
514
+ '}\n'
515
+ )
516
+ assert result[0][1] == expected
517
+
518
+
519
+ def test_multidim_c_argument_cse():
520
+ A_sym = MatrixSymbol('A', 3, 3)
521
+ b_sym = MatrixSymbol('b', 3, 1)
522
+ A = Matrix(A_sym)
523
+ b = Matrix(b_sym)
524
+ c = A*b
525
+ cgen = CCodeGen(project="test", cse=True)
526
+ r = cgen.routine("c", c)
527
+ r.arguments[-1].result_var = "out"
528
+ r.arguments[-1]._name = "out"
529
+ code = get_string(cgen.dump_c, [r], prefix="test")
530
+ expected = (
531
+ '#include "test.h"\n'
532
+ "#include <math.h>\n"
533
+ "void c(double *A, double *b, double *out) {\n"
534
+ " out[0] = A[0]*b[0] + A[1]*b[1] + A[2]*b[2];\n"
535
+ " out[1] = A[3]*b[0] + A[4]*b[1] + A[5]*b[2];\n"
536
+ " out[2] = A[6]*b[0] + A[7]*b[1] + A[8]*b[2];\n"
537
+ "}\n"
538
+ )
539
+ assert code == expected
540
+
541
+
542
+ def test_ccode_results_named_ordered():
543
+ x, y, z = symbols('x,y,z')
544
+ B, C = symbols('B,C')
545
+ A = MatrixSymbol('A', 1, 3)
546
+ expr1 = Equality(A, Matrix([[1, 2, x]]))
547
+ expr2 = Equality(C, (x + y)*z)
548
+ expr3 = Equality(B, 2*x)
549
+ name_expr = ("test", [expr1, expr2, expr3])
550
+ expected = (
551
+ '#include "test.h"\n'
552
+ '#include <math.h>\n'
553
+ 'void test(double x, double *C, double z, double y, double *A, double *B) {\n'
554
+ ' (*C) = z*(x + y);\n'
555
+ ' A[0] = 1;\n'
556
+ ' A[1] = 2;\n'
557
+ ' A[2] = x;\n'
558
+ ' (*B) = 2*x;\n'
559
+ '}\n'
560
+ )
561
+
562
+ result = codegen(name_expr, "c", "test", header=False, empty=False,
563
+ argument_sequence=(x, C, z, y, A, B))
564
+ source = result[0][1]
565
+ assert source == expected
566
+
567
+
568
+ def test_ccode_matrixsymbol_slice():
569
+ A = MatrixSymbol('A', 5, 3)
570
+ B = MatrixSymbol('B', 1, 3)
571
+ C = MatrixSymbol('C', 1, 3)
572
+ D = MatrixSymbol('D', 5, 1)
573
+ name_expr = ("test", [Equality(B, A[0, :]),
574
+ Equality(C, A[1, :]),
575
+ Equality(D, A[:, 2])])
576
+ result = codegen(name_expr, "c99", "test", header=False, empty=False)
577
+ source = result[0][1]
578
+ expected = (
579
+ '#include "test.h"\n'
580
+ '#include <math.h>\n'
581
+ 'void test(double *A, double *B, double *C, double *D) {\n'
582
+ ' B[0] = A[0];\n'
583
+ ' B[1] = A[1];\n'
584
+ ' B[2] = A[2];\n'
585
+ ' C[0] = A[3];\n'
586
+ ' C[1] = A[4];\n'
587
+ ' C[2] = A[5];\n'
588
+ ' D[0] = A[2];\n'
589
+ ' D[1] = A[5];\n'
590
+ ' D[2] = A[8];\n'
591
+ ' D[3] = A[11];\n'
592
+ ' D[4] = A[14];\n'
593
+ '}\n'
594
+ )
595
+ assert source == expected
596
+
597
+ def test_ccode_cse():
598
+ a, b, c, d = symbols('a b c d')
599
+ e = MatrixSymbol('e', 3, 1)
600
+ name_expr = ("test", [Equality(e, Matrix([[a*b], [a*b + c*d], [a*b*c*d]]))])
601
+ generator = CCodeGen(cse=True)
602
+ result = codegen(name_expr, code_gen=generator, header=False, empty=False)
603
+ source = result[0][1]
604
+ expected = (
605
+ '#include "test.h"\n'
606
+ '#include <math.h>\n'
607
+ 'void test(double a, double b, double c, double d, double *e) {\n'
608
+ ' const double x0 = a*b;\n'
609
+ ' const double x1 = c*d;\n'
610
+ ' e[0] = x0;\n'
611
+ ' e[1] = x0 + x1;\n'
612
+ ' e[2] = x0*x1;\n'
613
+ '}\n'
614
+ )
615
+ assert source == expected
616
+
617
+ def test_ccode_unused_array_arg():
618
+ x = MatrixSymbol('x', 2, 1)
619
+ # x does not appear in output
620
+ name_expr = ("test", 1.0)
621
+ generator = CCodeGen()
622
+ result = codegen(name_expr, code_gen=generator, header=False, empty=False, argument_sequence=(x,))
623
+ source = result[0][1]
624
+ # note: x should appear as (double *)
625
+ expected = (
626
+ '#include "test.h"\n'
627
+ '#include <math.h>\n'
628
+ 'double test(double *x) {\n'
629
+ ' double test_result;\n'
630
+ ' test_result = 1.0;\n'
631
+ ' return test_result;\n'
632
+ '}\n'
633
+ )
634
+ assert source == expected
635
+
636
+ def test_ccode_unused_array_arg_func():
637
+ # issue 16689
638
+ X = MatrixSymbol('X',3,1)
639
+ Y = MatrixSymbol('Y',3,1)
640
+ z = symbols('z',integer = True)
641
+ name_expr = ('testBug', X[0] + X[1])
642
+ result = codegen(name_expr, language='C', header=False, empty=False, argument_sequence=(X, Y, z))
643
+ source = result[0][1]
644
+ expected = (
645
+ '#include "testBug.h"\n'
646
+ '#include <math.h>\n'
647
+ 'double testBug(double *X, double *Y, int z) {\n'
648
+ ' double testBug_result;\n'
649
+ ' testBug_result = X[0] + X[1];\n'
650
+ ' return testBug_result;\n'
651
+ '}\n'
652
+ )
653
+ assert source == expected
654
+
655
+ def test_empty_f_code():
656
+ code_gen = FCodeGen()
657
+ source = get_string(code_gen.dump_f95, [])
658
+ assert source == ""
659
+
660
+
661
+ def test_empty_f_code_with_header():
662
+ code_gen = FCodeGen()
663
+ source = get_string(code_gen.dump_f95, [], header=True)
664
+ assert source[:82] == (
665
+ "!******************************************************************************\n!*"
666
+ )
667
+ # " Code generated with SymPy 0.7.2-git "
668
+ assert source[158:] == ( "*\n"
669
+ "!* *\n"
670
+ "!* See http://www.sympy.org/ for more information. *\n"
671
+ "!* *\n"
672
+ "!* This file is part of 'project' *\n"
673
+ "!******************************************************************************\n"
674
+ )
675
+
676
+
677
+ def test_empty_f_header():
678
+ code_gen = FCodeGen()
679
+ source = get_string(code_gen.dump_h, [])
680
+ assert source == ""
681
+
682
+
683
+ def test_simple_f_code():
684
+ x, y, z = symbols('x,y,z')
685
+ expr = (x + y)*z
686
+ routine = make_routine("test", expr)
687
+ code_gen = FCodeGen()
688
+ source = get_string(code_gen.dump_f95, [routine])
689
+ expected = (
690
+ "REAL*8 function test(x, y, z)\n"
691
+ "implicit none\n"
692
+ "REAL*8, intent(in) :: x\n"
693
+ "REAL*8, intent(in) :: y\n"
694
+ "REAL*8, intent(in) :: z\n"
695
+ "test = z*(x + y)\n"
696
+ "end function\n"
697
+ )
698
+ assert source == expected
699
+
700
+
701
+ def test_numbersymbol_f_code():
702
+ routine = make_routine("test", pi**Catalan)
703
+ code_gen = FCodeGen()
704
+ source = get_string(code_gen.dump_f95, [routine])
705
+ expected = (
706
+ "REAL*8 function test()\n"
707
+ "implicit none\n"
708
+ "REAL*8, parameter :: Catalan = %sd0\n"
709
+ "REAL*8, parameter :: pi = %sd0\n"
710
+ "test = pi**Catalan\n"
711
+ "end function\n"
712
+ ) % (Catalan.evalf(17), pi.evalf(17))
713
+ assert source == expected
714
+
715
+ def test_erf_f_code():
716
+ x = symbols('x')
717
+ routine = make_routine("test", erf(x) - erf(-2 * x))
718
+ code_gen = FCodeGen()
719
+ source = get_string(code_gen.dump_f95, [routine])
720
+ expected = (
721
+ "REAL*8 function test(x)\n"
722
+ "implicit none\n"
723
+ "REAL*8, intent(in) :: x\n"
724
+ "test = erf(x) + erf(2.0d0*x)\n"
725
+ "end function\n"
726
+ )
727
+ assert source == expected, source
728
+
729
+ def test_f_code_argument_order():
730
+ x, y, z = symbols('x,y,z')
731
+ expr = x + y
732
+ routine = make_routine("test", expr, argument_sequence=[z, x, y])
733
+ code_gen = FCodeGen()
734
+ source = get_string(code_gen.dump_f95, [routine])
735
+ expected = (
736
+ "REAL*8 function test(z, x, y)\n"
737
+ "implicit none\n"
738
+ "REAL*8, intent(in) :: z\n"
739
+ "REAL*8, intent(in) :: x\n"
740
+ "REAL*8, intent(in) :: y\n"
741
+ "test = x + y\n"
742
+ "end function\n"
743
+ )
744
+ assert source == expected
745
+
746
+
747
+ def test_simple_f_header():
748
+ x, y, z = symbols('x,y,z')
749
+ expr = (x + y)*z
750
+ routine = make_routine("test", expr)
751
+ code_gen = FCodeGen()
752
+ source = get_string(code_gen.dump_h, [routine])
753
+ expected = (
754
+ "interface\n"
755
+ "REAL*8 function test(x, y, z)\n"
756
+ "implicit none\n"
757
+ "REAL*8, intent(in) :: x\n"
758
+ "REAL*8, intent(in) :: y\n"
759
+ "REAL*8, intent(in) :: z\n"
760
+ "end function\n"
761
+ "end interface\n"
762
+ )
763
+ assert source == expected
764
+
765
+
766
+ def test_simple_f_codegen():
767
+ x, y, z = symbols('x,y,z')
768
+ expr = (x + y)*z
769
+ result = codegen(
770
+ ("test", expr), "F95", "file", header=False, empty=False)
771
+ expected = [
772
+ ("file.f90",
773
+ "REAL*8 function test(x, y, z)\n"
774
+ "implicit none\n"
775
+ "REAL*8, intent(in) :: x\n"
776
+ "REAL*8, intent(in) :: y\n"
777
+ "REAL*8, intent(in) :: z\n"
778
+ "test = z*(x + y)\n"
779
+ "end function\n"),
780
+ ("file.h",
781
+ "interface\n"
782
+ "REAL*8 function test(x, y, z)\n"
783
+ "implicit none\n"
784
+ "REAL*8, intent(in) :: x\n"
785
+ "REAL*8, intent(in) :: y\n"
786
+ "REAL*8, intent(in) :: z\n"
787
+ "end function\n"
788
+ "end interface\n")
789
+ ]
790
+ assert result == expected
791
+
792
+
793
+ def test_multiple_results_f():
794
+ x, y, z = symbols('x,y,z')
795
+ expr1 = (x + y)*z
796
+ expr2 = (x - y)*z
797
+ routine = make_routine(
798
+ "test",
799
+ [expr1, expr2]
800
+ )
801
+ code_gen = FCodeGen()
802
+ raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
803
+
804
+
805
+ def test_no_results_f():
806
+ raises(ValueError, lambda: make_routine("test", []))
807
+
808
+
809
+ def test_intrinsic_math_codegen():
810
+ # not included: log10
811
+ from sympy.functions.elementary.complexes import Abs
812
+ from sympy.functions.elementary.exponential import log
813
+ from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
814
+ from sympy.functions.elementary.miscellaneous import sqrt
815
+ from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
816
+ x = symbols('x')
817
+ name_expr = [
818
+ ("test_abs", Abs(x)),
819
+ ("test_acos", acos(x)),
820
+ ("test_asin", asin(x)),
821
+ ("test_atan", atan(x)),
822
+ ("test_cos", cos(x)),
823
+ ("test_cosh", cosh(x)),
824
+ ("test_log", log(x)),
825
+ ("test_ln", log(x)),
826
+ ("test_sin", sin(x)),
827
+ ("test_sinh", sinh(x)),
828
+ ("test_sqrt", sqrt(x)),
829
+ ("test_tan", tan(x)),
830
+ ("test_tanh", tanh(x)),
831
+ ]
832
+ result = codegen(name_expr, "F95", "file", header=False, empty=False)
833
+ assert result[0][0] == "file.f90"
834
+ expected = (
835
+ 'REAL*8 function test_abs(x)\n'
836
+ 'implicit none\n'
837
+ 'REAL*8, intent(in) :: x\n'
838
+ 'test_abs = abs(x)\n'
839
+ 'end function\n'
840
+ 'REAL*8 function test_acos(x)\n'
841
+ 'implicit none\n'
842
+ 'REAL*8, intent(in) :: x\n'
843
+ 'test_acos = acos(x)\n'
844
+ 'end function\n'
845
+ 'REAL*8 function test_asin(x)\n'
846
+ 'implicit none\n'
847
+ 'REAL*8, intent(in) :: x\n'
848
+ 'test_asin = asin(x)\n'
849
+ 'end function\n'
850
+ 'REAL*8 function test_atan(x)\n'
851
+ 'implicit none\n'
852
+ 'REAL*8, intent(in) :: x\n'
853
+ 'test_atan = atan(x)\n'
854
+ 'end function\n'
855
+ 'REAL*8 function test_cos(x)\n'
856
+ 'implicit none\n'
857
+ 'REAL*8, intent(in) :: x\n'
858
+ 'test_cos = cos(x)\n'
859
+ 'end function\n'
860
+ 'REAL*8 function test_cosh(x)\n'
861
+ 'implicit none\n'
862
+ 'REAL*8, intent(in) :: x\n'
863
+ 'test_cosh = cosh(x)\n'
864
+ 'end function\n'
865
+ 'REAL*8 function test_log(x)\n'
866
+ 'implicit none\n'
867
+ 'REAL*8, intent(in) :: x\n'
868
+ 'test_log = log(x)\n'
869
+ 'end function\n'
870
+ 'REAL*8 function test_ln(x)\n'
871
+ 'implicit none\n'
872
+ 'REAL*8, intent(in) :: x\n'
873
+ 'test_ln = log(x)\n'
874
+ 'end function\n'
875
+ 'REAL*8 function test_sin(x)\n'
876
+ 'implicit none\n'
877
+ 'REAL*8, intent(in) :: x\n'
878
+ 'test_sin = sin(x)\n'
879
+ 'end function\n'
880
+ 'REAL*8 function test_sinh(x)\n'
881
+ 'implicit none\n'
882
+ 'REAL*8, intent(in) :: x\n'
883
+ 'test_sinh = sinh(x)\n'
884
+ 'end function\n'
885
+ 'REAL*8 function test_sqrt(x)\n'
886
+ 'implicit none\n'
887
+ 'REAL*8, intent(in) :: x\n'
888
+ 'test_sqrt = sqrt(x)\n'
889
+ 'end function\n'
890
+ 'REAL*8 function test_tan(x)\n'
891
+ 'implicit none\n'
892
+ 'REAL*8, intent(in) :: x\n'
893
+ 'test_tan = tan(x)\n'
894
+ 'end function\n'
895
+ 'REAL*8 function test_tanh(x)\n'
896
+ 'implicit none\n'
897
+ 'REAL*8, intent(in) :: x\n'
898
+ 'test_tanh = tanh(x)\n'
899
+ 'end function\n'
900
+ )
901
+ assert result[0][1] == expected
902
+
903
+ assert result[1][0] == "file.h"
904
+ expected = (
905
+ 'interface\n'
906
+ 'REAL*8 function test_abs(x)\n'
907
+ 'implicit none\n'
908
+ 'REAL*8, intent(in) :: x\n'
909
+ 'end function\n'
910
+ 'end interface\n'
911
+ 'interface\n'
912
+ 'REAL*8 function test_acos(x)\n'
913
+ 'implicit none\n'
914
+ 'REAL*8, intent(in) :: x\n'
915
+ 'end function\n'
916
+ 'end interface\n'
917
+ 'interface\n'
918
+ 'REAL*8 function test_asin(x)\n'
919
+ 'implicit none\n'
920
+ 'REAL*8, intent(in) :: x\n'
921
+ 'end function\n'
922
+ 'end interface\n'
923
+ 'interface\n'
924
+ 'REAL*8 function test_atan(x)\n'
925
+ 'implicit none\n'
926
+ 'REAL*8, intent(in) :: x\n'
927
+ 'end function\n'
928
+ 'end interface\n'
929
+ 'interface\n'
930
+ 'REAL*8 function test_cos(x)\n'
931
+ 'implicit none\n'
932
+ 'REAL*8, intent(in) :: x\n'
933
+ 'end function\n'
934
+ 'end interface\n'
935
+ 'interface\n'
936
+ 'REAL*8 function test_cosh(x)\n'
937
+ 'implicit none\n'
938
+ 'REAL*8, intent(in) :: x\n'
939
+ 'end function\n'
940
+ 'end interface\n'
941
+ 'interface\n'
942
+ 'REAL*8 function test_log(x)\n'
943
+ 'implicit none\n'
944
+ 'REAL*8, intent(in) :: x\n'
945
+ 'end function\n'
946
+ 'end interface\n'
947
+ 'interface\n'
948
+ 'REAL*8 function test_ln(x)\n'
949
+ 'implicit none\n'
950
+ 'REAL*8, intent(in) :: x\n'
951
+ 'end function\n'
952
+ 'end interface\n'
953
+ 'interface\n'
954
+ 'REAL*8 function test_sin(x)\n'
955
+ 'implicit none\n'
956
+ 'REAL*8, intent(in) :: x\n'
957
+ 'end function\n'
958
+ 'end interface\n'
959
+ 'interface\n'
960
+ 'REAL*8 function test_sinh(x)\n'
961
+ 'implicit none\n'
962
+ 'REAL*8, intent(in) :: x\n'
963
+ 'end function\n'
964
+ 'end interface\n'
965
+ 'interface\n'
966
+ 'REAL*8 function test_sqrt(x)\n'
967
+ 'implicit none\n'
968
+ 'REAL*8, intent(in) :: x\n'
969
+ 'end function\n'
970
+ 'end interface\n'
971
+ 'interface\n'
972
+ 'REAL*8 function test_tan(x)\n'
973
+ 'implicit none\n'
974
+ 'REAL*8, intent(in) :: x\n'
975
+ 'end function\n'
976
+ 'end interface\n'
977
+ 'interface\n'
978
+ 'REAL*8 function test_tanh(x)\n'
979
+ 'implicit none\n'
980
+ 'REAL*8, intent(in) :: x\n'
981
+ 'end function\n'
982
+ 'end interface\n'
983
+ )
984
+ assert result[1][1] == expected
985
+
986
+
987
+ def test_intrinsic_math2_codegen():
988
+ # not included: frexp, ldexp, modf, fmod
989
+ from sympy.functions.elementary.trigonometric import atan2
990
+ x, y = symbols('x,y')
991
+ name_expr = [
992
+ ("test_atan2", atan2(x, y)),
993
+ ("test_pow", x**y),
994
+ ]
995
+ result = codegen(name_expr, "F95", "file", header=False, empty=False)
996
+ assert result[0][0] == "file.f90"
997
+ expected = (
998
+ 'REAL*8 function test_atan2(x, y)\n'
999
+ 'implicit none\n'
1000
+ 'REAL*8, intent(in) :: x\n'
1001
+ 'REAL*8, intent(in) :: y\n'
1002
+ 'test_atan2 = atan2(x, y)\n'
1003
+ 'end function\n'
1004
+ 'REAL*8 function test_pow(x, y)\n'
1005
+ 'implicit none\n'
1006
+ 'REAL*8, intent(in) :: x\n'
1007
+ 'REAL*8, intent(in) :: y\n'
1008
+ 'test_pow = x**y\n'
1009
+ 'end function\n'
1010
+ )
1011
+ assert result[0][1] == expected
1012
+
1013
+ assert result[1][0] == "file.h"
1014
+ expected = (
1015
+ 'interface\n'
1016
+ 'REAL*8 function test_atan2(x, y)\n'
1017
+ 'implicit none\n'
1018
+ 'REAL*8, intent(in) :: x\n'
1019
+ 'REAL*8, intent(in) :: y\n'
1020
+ 'end function\n'
1021
+ 'end interface\n'
1022
+ 'interface\n'
1023
+ 'REAL*8 function test_pow(x, y)\n'
1024
+ 'implicit none\n'
1025
+ 'REAL*8, intent(in) :: x\n'
1026
+ 'REAL*8, intent(in) :: y\n'
1027
+ 'end function\n'
1028
+ 'end interface\n'
1029
+ )
1030
+ assert result[1][1] == expected
1031
+
1032
+
1033
+ def test_complicated_codegen_f95():
1034
+ from sympy.functions.elementary.trigonometric import (cos, sin, tan)
1035
+ x, y, z = symbols('x,y,z')
1036
+ name_expr = [
1037
+ ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
1038
+ ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
1039
+ ]
1040
+ result = codegen(name_expr, "F95", "file", header=False, empty=False)
1041
+ assert result[0][0] == "file.f90"
1042
+ expected = (
1043
+ 'REAL*8 function test1(x, y, z)\n'
1044
+ 'implicit none\n'
1045
+ 'REAL*8, intent(in) :: x\n'
1046
+ 'REAL*8, intent(in) :: y\n'
1047
+ 'REAL*8, intent(in) :: z\n'
1048
+ 'test1 = sin(x)**7 + 7*sin(x)**6*cos(y) + 7*sin(x)**6*tan(z) + 21*sin(x) &\n'
1049
+ ' **5*cos(y)**2 + 42*sin(x)**5*cos(y)*tan(z) + 21*sin(x)**5*tan(z) &\n'
1050
+ ' **2 + 35*sin(x)**4*cos(y)**3 + 105*sin(x)**4*cos(y)**2*tan(z) + &\n'
1051
+ ' 105*sin(x)**4*cos(y)*tan(z)**2 + 35*sin(x)**4*tan(z)**3 + 35*sin( &\n'
1052
+ ' x)**3*cos(y)**4 + 140*sin(x)**3*cos(y)**3*tan(z) + 210*sin(x)**3* &\n'
1053
+ ' cos(y)**2*tan(z)**2 + 140*sin(x)**3*cos(y)*tan(z)**3 + 35*sin(x) &\n'
1054
+ ' **3*tan(z)**4 + 21*sin(x)**2*cos(y)**5 + 105*sin(x)**2*cos(y)**4* &\n'
1055
+ ' tan(z) + 210*sin(x)**2*cos(y)**3*tan(z)**2 + 210*sin(x)**2*cos(y) &\n'
1056
+ ' **2*tan(z)**3 + 105*sin(x)**2*cos(y)*tan(z)**4 + 21*sin(x)**2*tan &\n'
1057
+ ' (z)**5 + 7*sin(x)*cos(y)**6 + 42*sin(x)*cos(y)**5*tan(z) + 105* &\n'
1058
+ ' sin(x)*cos(y)**4*tan(z)**2 + 140*sin(x)*cos(y)**3*tan(z)**3 + 105 &\n'
1059
+ ' *sin(x)*cos(y)**2*tan(z)**4 + 42*sin(x)*cos(y)*tan(z)**5 + 7*sin( &\n'
1060
+ ' x)*tan(z)**6 + cos(y)**7 + 7*cos(y)**6*tan(z) + 21*cos(y)**5*tan( &\n'
1061
+ ' z)**2 + 35*cos(y)**4*tan(z)**3 + 35*cos(y)**3*tan(z)**4 + 21*cos( &\n'
1062
+ ' y)**2*tan(z)**5 + 7*cos(y)*tan(z)**6 + tan(z)**7\n'
1063
+ 'end function\n'
1064
+ 'REAL*8 function test2(x, y, z)\n'
1065
+ 'implicit none\n'
1066
+ 'REAL*8, intent(in) :: x\n'
1067
+ 'REAL*8, intent(in) :: y\n'
1068
+ 'REAL*8, intent(in) :: z\n'
1069
+ 'test2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n'
1070
+ 'end function\n'
1071
+ )
1072
+ assert result[0][1] == expected
1073
+ assert result[1][0] == "file.h"
1074
+ expected = (
1075
+ 'interface\n'
1076
+ 'REAL*8 function test1(x, y, z)\n'
1077
+ 'implicit none\n'
1078
+ 'REAL*8, intent(in) :: x\n'
1079
+ 'REAL*8, intent(in) :: y\n'
1080
+ 'REAL*8, intent(in) :: z\n'
1081
+ 'end function\n'
1082
+ 'end interface\n'
1083
+ 'interface\n'
1084
+ 'REAL*8 function test2(x, y, z)\n'
1085
+ 'implicit none\n'
1086
+ 'REAL*8, intent(in) :: x\n'
1087
+ 'REAL*8, intent(in) :: y\n'
1088
+ 'REAL*8, intent(in) :: z\n'
1089
+ 'end function\n'
1090
+ 'end interface\n'
1091
+ )
1092
+ assert result[1][1] == expected
1093
+
1094
+
1095
+ def test_loops():
1096
+ from sympy.tensor import IndexedBase, Idx
1097
+ from sympy.core.symbol import symbols
1098
+
1099
+ n, m = symbols('n,m', integer=True)
1100
+ A, x, y = map(IndexedBase, 'Axy')
1101
+ i = Idx('i', m)
1102
+ j = Idx('j', n)
1103
+
1104
+ (f1, code), (f2, interface) = codegen(
1105
+ ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
1106
+
1107
+ assert f1 == 'file.f90'
1108
+ expected = (
1109
+ 'subroutine matrix_vector(A, m, n, x, y)\n'
1110
+ 'implicit none\n'
1111
+ 'INTEGER*4, intent(in) :: m\n'
1112
+ 'INTEGER*4, intent(in) :: n\n'
1113
+ 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
1114
+ 'REAL*8, intent(in), dimension(1:n) :: x\n'
1115
+ 'REAL*8, intent(out), dimension(1:m) :: y\n'
1116
+ 'INTEGER*4 :: i\n'
1117
+ 'INTEGER*4 :: j\n'
1118
+ 'do i = 1, m\n'
1119
+ ' y(i) = 0\n'
1120
+ 'end do\n'
1121
+ 'do i = 1, m\n'
1122
+ ' do j = 1, n\n'
1123
+ ' y(i) = %(rhs)s + y(i)\n'
1124
+ ' end do\n'
1125
+ 'end do\n'
1126
+ 'end subroutine\n'
1127
+ )
1128
+
1129
+ assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
1130
+ code == expected % {'rhs': 'x(j)*A(i, j)'}
1131
+ assert f2 == 'file.h'
1132
+ assert interface == (
1133
+ 'interface\n'
1134
+ 'subroutine matrix_vector(A, m, n, x, y)\n'
1135
+ 'implicit none\n'
1136
+ 'INTEGER*4, intent(in) :: m\n'
1137
+ 'INTEGER*4, intent(in) :: n\n'
1138
+ 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
1139
+ 'REAL*8, intent(in), dimension(1:n) :: x\n'
1140
+ 'REAL*8, intent(out), dimension(1:m) :: y\n'
1141
+ 'end subroutine\n'
1142
+ 'end interface\n'
1143
+ )
1144
+
1145
+
1146
+ def test_dummy_loops_f95():
1147
+ from sympy.tensor import IndexedBase, Idx
1148
+ i, m = symbols('i m', integer=True, cls=Dummy)
1149
+ x = IndexedBase('x')
1150
+ y = IndexedBase('y')
1151
+ i = Idx(i, m)
1152
+ expected = (
1153
+ 'subroutine test_dummies(m_%(mcount)i, x, y)\n'
1154
+ 'implicit none\n'
1155
+ 'INTEGER*4, intent(in) :: m_%(mcount)i\n'
1156
+ 'REAL*8, intent(in), dimension(1:m_%(mcount)i) :: x\n'
1157
+ 'REAL*8, intent(out), dimension(1:m_%(mcount)i) :: y\n'
1158
+ 'INTEGER*4 :: i_%(icount)i\n'
1159
+ 'do i_%(icount)i = 1, m_%(mcount)i\n'
1160
+ ' y(i_%(icount)i) = x(i_%(icount)i)\n'
1161
+ 'end do\n'
1162
+ 'end subroutine\n'
1163
+ ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
1164
+ r = make_routine('test_dummies', Eq(y[i], x[i]))
1165
+ c = FCodeGen()
1166
+ code = get_string(c.dump_f95, [r])
1167
+ assert code == expected
1168
+
1169
+
1170
+ def test_loops_InOut():
1171
+ from sympy.tensor import IndexedBase, Idx
1172
+ from sympy.core.symbol import symbols
1173
+
1174
+ i, j, n, m = symbols('i,j,n,m', integer=True)
1175
+ A, x, y = symbols('A,x,y')
1176
+ A = IndexedBase(A)[Idx(i, m), Idx(j, n)]
1177
+ x = IndexedBase(x)[Idx(j, n)]
1178
+ y = IndexedBase(y)[Idx(i, m)]
1179
+
1180
+ (f1, code), (f2, interface) = codegen(
1181
+ ('matrix_vector', Eq(y, y + A*x)), "F95", "file", header=False, empty=False)
1182
+
1183
+ assert f1 == 'file.f90'
1184
+ expected = (
1185
+ 'subroutine matrix_vector(A, m, n, x, y)\n'
1186
+ 'implicit none\n'
1187
+ 'INTEGER*4, intent(in) :: m\n'
1188
+ 'INTEGER*4, intent(in) :: n\n'
1189
+ 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
1190
+ 'REAL*8, intent(in), dimension(1:n) :: x\n'
1191
+ 'REAL*8, intent(inout), dimension(1:m) :: y\n'
1192
+ 'INTEGER*4 :: i\n'
1193
+ 'INTEGER*4 :: j\n'
1194
+ 'do i = 1, m\n'
1195
+ ' do j = 1, n\n'
1196
+ ' y(i) = %(rhs)s + y(i)\n'
1197
+ ' end do\n'
1198
+ 'end do\n'
1199
+ 'end subroutine\n'
1200
+ )
1201
+
1202
+ assert (code == expected % {'rhs': 'A(i, j)*x(j)'} or
1203
+ code == expected % {'rhs': 'x(j)*A(i, j)'})
1204
+ assert f2 == 'file.h'
1205
+ assert interface == (
1206
+ 'interface\n'
1207
+ 'subroutine matrix_vector(A, m, n, x, y)\n'
1208
+ 'implicit none\n'
1209
+ 'INTEGER*4, intent(in) :: m\n'
1210
+ 'INTEGER*4, intent(in) :: n\n'
1211
+ 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
1212
+ 'REAL*8, intent(in), dimension(1:n) :: x\n'
1213
+ 'REAL*8, intent(inout), dimension(1:m) :: y\n'
1214
+ 'end subroutine\n'
1215
+ 'end interface\n'
1216
+ )
1217
+
1218
+
1219
+ def test_partial_loops_f():
1220
+ # check that loop boundaries are determined by Idx, and array strides
1221
+ # determined by shape of IndexedBase object.
1222
+ from sympy.tensor import IndexedBase, Idx
1223
+ from sympy.core.symbol import symbols
1224
+ n, m, o, p = symbols('n m o p', integer=True)
1225
+ A = IndexedBase('A', shape=(m, p))
1226
+ x = IndexedBase('x')
1227
+ y = IndexedBase('y')
1228
+ i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
1229
+ j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
1230
+
1231
+ (f1, code), (f2, interface) = codegen(
1232
+ ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
1233
+
1234
+ expected = (
1235
+ 'subroutine matrix_vector(A, m, n, o, p, x, y)\n'
1236
+ 'implicit none\n'
1237
+ 'INTEGER*4, intent(in) :: m\n'
1238
+ 'INTEGER*4, intent(in) :: n\n'
1239
+ 'INTEGER*4, intent(in) :: o\n'
1240
+ 'INTEGER*4, intent(in) :: p\n'
1241
+ 'REAL*8, intent(in), dimension(1:m, 1:p) :: A\n'
1242
+ 'REAL*8, intent(in), dimension(1:n) :: x\n'
1243
+ 'REAL*8, intent(out), dimension(1:%(iup-ilow)s) :: y\n'
1244
+ 'INTEGER*4 :: i\n'
1245
+ 'INTEGER*4 :: j\n'
1246
+ 'do i = %(ilow)s, %(iup)s\n'
1247
+ ' y(i) = 0\n'
1248
+ 'end do\n'
1249
+ 'do i = %(ilow)s, %(iup)s\n'
1250
+ ' do j = 1, n\n'
1251
+ ' y(i) = %(rhs)s + y(i)\n'
1252
+ ' end do\n'
1253
+ 'end do\n'
1254
+ 'end subroutine\n'
1255
+ ) % {
1256
+ 'rhs': '%(rhs)s',
1257
+ 'iup': str(m - 4),
1258
+ 'ilow': str(1 + o),
1259
+ 'iup-ilow': str(m - 4 - o)
1260
+ }
1261
+
1262
+ assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
1263
+ code == expected % {'rhs': 'x(j)*A(i, j)'}
1264
+
1265
+
1266
+ def test_output_arg_f():
1267
+ from sympy.core.relational import Equality
1268
+ from sympy.functions.elementary.trigonometric import (cos, sin)
1269
+ x, y, z = symbols("x,y,z")
1270
+ r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
1271
+ c = FCodeGen()
1272
+ result = c.write([r], "test", header=False, empty=False)
1273
+ assert result[0][0] == "test.f90"
1274
+ assert result[0][1] == (
1275
+ 'REAL*8 function foo(x, y)\n'
1276
+ 'implicit none\n'
1277
+ 'REAL*8, intent(in) :: x\n'
1278
+ 'REAL*8, intent(out) :: y\n'
1279
+ 'y = sin(x)\n'
1280
+ 'foo = cos(x)\n'
1281
+ 'end function\n'
1282
+ )
1283
+
1284
+
1285
+ def test_inline_function():
1286
+ from sympy.tensor import IndexedBase, Idx
1287
+ from sympy.core.symbol import symbols
1288
+ n, m = symbols('n m', integer=True)
1289
+ A, x, y = map(IndexedBase, 'Axy')
1290
+ i = Idx('i', m)
1291
+ p = FCodeGen()
1292
+ func = implemented_function('func', Lambda(n, n*(n + 1)))
1293
+ routine = make_routine('test_inline', Eq(y[i], func(x[i])))
1294
+ code = get_string(p.dump_f95, [routine])
1295
+ expected = (
1296
+ 'subroutine test_inline(m, x, y)\n'
1297
+ 'implicit none\n'
1298
+ 'INTEGER*4, intent(in) :: m\n'
1299
+ 'REAL*8, intent(in), dimension(1:m) :: x\n'
1300
+ 'REAL*8, intent(out), dimension(1:m) :: y\n'
1301
+ 'INTEGER*4 :: i\n'
1302
+ 'do i = 1, m\n'
1303
+ ' y(i) = %s*%s\n'
1304
+ 'end do\n'
1305
+ 'end subroutine\n'
1306
+ )
1307
+ args = ('x(i)', '(x(i) + 1)')
1308
+ assert code == expected % args or\
1309
+ code == expected % args[::-1]
1310
+
1311
+
1312
+ def test_f_code_call_signature_wrap():
1313
+ # Issue #7934
1314
+ x = symbols('x:20')
1315
+ expr = 0
1316
+ for sym in x:
1317
+ expr += sym
1318
+ routine = make_routine("test", expr)
1319
+ code_gen = FCodeGen()
1320
+ source = get_string(code_gen.dump_f95, [routine])
1321
+ expected = """\
1322
+ REAL*8 function test(x0, x1, x10, x11, x12, x13, x14, x15, x16, x17, x18, &
1323
+ x19, x2, x3, x4, x5, x6, x7, x8, x9)
1324
+ implicit none
1325
+ REAL*8, intent(in) :: x0
1326
+ REAL*8, intent(in) :: x1
1327
+ REAL*8, intent(in) :: x10
1328
+ REAL*8, intent(in) :: x11
1329
+ REAL*8, intent(in) :: x12
1330
+ REAL*8, intent(in) :: x13
1331
+ REAL*8, intent(in) :: x14
1332
+ REAL*8, intent(in) :: x15
1333
+ REAL*8, intent(in) :: x16
1334
+ REAL*8, intent(in) :: x17
1335
+ REAL*8, intent(in) :: x18
1336
+ REAL*8, intent(in) :: x19
1337
+ REAL*8, intent(in) :: x2
1338
+ REAL*8, intent(in) :: x3
1339
+ REAL*8, intent(in) :: x4
1340
+ REAL*8, intent(in) :: x5
1341
+ REAL*8, intent(in) :: x6
1342
+ REAL*8, intent(in) :: x7
1343
+ REAL*8, intent(in) :: x8
1344
+ REAL*8, intent(in) :: x9
1345
+ test = x0 + x1 + x10 + x11 + x12 + x13 + x14 + x15 + x16 + x17 + x18 + &
1346
+ x19 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
1347
+ end function
1348
+ """
1349
+ assert source == expected
1350
+
1351
+
1352
+ def test_check_case():
1353
+ x, X = symbols('x,X')
1354
+ raises(CodeGenError, lambda: codegen(('test', x*X), 'f95', 'prefix'))
1355
+
1356
+
1357
+ def test_check_case_false_positive():
1358
+ # The upper case/lower case exception should not be triggered by SymPy
1359
+ # objects that differ only because of assumptions. (It may be useful to
1360
+ # have a check for that as well, but here we only want to test against
1361
+ # false positives with respect to case checking.)
1362
+ x1 = symbols('x')
1363
+ x2 = symbols('x', my_assumption=True)
1364
+ try:
1365
+ codegen(('test', x1*x2), 'f95', 'prefix')
1366
+ except CodeGenError as e:
1367
+ if e.args[0].startswith("Fortran ignores case."):
1368
+ raise AssertionError("This exception should not be raised!")
1369
+
1370
+
1371
+ def test_c_fortran_omit_routine_name():
1372
+ x, y = symbols("x,y")
1373
+ name_expr = [("foo", 2*x)]
1374
+ result = codegen(name_expr, "F95", header=False, empty=False)
1375
+ expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
1376
+ assert result[0][1] == expresult[0][1]
1377
+
1378
+ name_expr = ("foo", x*y)
1379
+ result = codegen(name_expr, "F95", header=False, empty=False)
1380
+ expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
1381
+ assert result[0][1] == expresult[0][1]
1382
+
1383
+ name_expr = ("foo", Matrix([[x, y], [x+y, x-y]]))
1384
+ result = codegen(name_expr, "C89", header=False, empty=False)
1385
+ expresult = codegen(name_expr, "C89", "foo", header=False, empty=False)
1386
+ assert result[0][1] == expresult[0][1]
1387
+
1388
+
1389
+ def test_fcode_matrix_output():
1390
+ x, y, z = symbols('x,y,z')
1391
+ e1 = x + y
1392
+ e2 = Matrix([[x, y], [z, 16]])
1393
+ name_expr = ("test", (e1, e2))
1394
+ result = codegen(name_expr, "f95", "test", header=False, empty=False)
1395
+ source = result[0][1]
1396
+ expected = (
1397
+ "REAL*8 function test(x, y, z, out_%(hash)s)\n"
1398
+ "implicit none\n"
1399
+ "REAL*8, intent(in) :: x\n"
1400
+ "REAL*8, intent(in) :: y\n"
1401
+ "REAL*8, intent(in) :: z\n"
1402
+ "REAL*8, intent(out), dimension(1:2, 1:2) :: out_%(hash)s\n"
1403
+ "out_%(hash)s(1, 1) = x\n"
1404
+ "out_%(hash)s(2, 1) = z\n"
1405
+ "out_%(hash)s(1, 2) = y\n"
1406
+ "out_%(hash)s(2, 2) = 16\n"
1407
+ "test = x + y\n"
1408
+ "end function\n"
1409
+ )
1410
+ # look for the magic number
1411
+ a = source.splitlines()[5]
1412
+ b = a.split('_')
1413
+ out = b[1]
1414
+ expected = expected % {'hash': out}
1415
+ assert source == expected
1416
+
1417
+
1418
+ def test_fcode_results_named_ordered():
1419
+ x, y, z = symbols('x,y,z')
1420
+ B, C = symbols('B,C')
1421
+ A = MatrixSymbol('A', 1, 3)
1422
+ expr1 = Equality(A, Matrix([[1, 2, x]]))
1423
+ expr2 = Equality(C, (x + y)*z)
1424
+ expr3 = Equality(B, 2*x)
1425
+ name_expr = ("test", [expr1, expr2, expr3])
1426
+ result = codegen(name_expr, "f95", "test", header=False, empty=False,
1427
+ argument_sequence=(x, z, y, C, A, B))
1428
+ source = result[0][1]
1429
+ expected = (
1430
+ "subroutine test(x, z, y, C, A, B)\n"
1431
+ "implicit none\n"
1432
+ "REAL*8, intent(in) :: x\n"
1433
+ "REAL*8, intent(in) :: z\n"
1434
+ "REAL*8, intent(in) :: y\n"
1435
+ "REAL*8, intent(out) :: C\n"
1436
+ "REAL*8, intent(out) :: B\n"
1437
+ "REAL*8, intent(out), dimension(1:1, 1:3) :: A\n"
1438
+ "C = z*(x + y)\n"
1439
+ "A(1, 1) = 1\n"
1440
+ "A(1, 2) = 2\n"
1441
+ "A(1, 3) = x\n"
1442
+ "B = 2*x\n"
1443
+ "end subroutine\n"
1444
+ )
1445
+ assert source == expected
1446
+
1447
+
1448
+ def test_fcode_matrixsymbol_slice():
1449
+ A = MatrixSymbol('A', 2, 3)
1450
+ B = MatrixSymbol('B', 1, 3)
1451
+ C = MatrixSymbol('C', 1, 3)
1452
+ D = MatrixSymbol('D', 2, 1)
1453
+ name_expr = ("test", [Equality(B, A[0, :]),
1454
+ Equality(C, A[1, :]),
1455
+ Equality(D, A[:, 2])])
1456
+ result = codegen(name_expr, "f95", "test", header=False, empty=False)
1457
+ source = result[0][1]
1458
+ expected = (
1459
+ "subroutine test(A, B, C, D)\n"
1460
+ "implicit none\n"
1461
+ "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
1462
+ "REAL*8, intent(out), dimension(1:1, 1:3) :: B\n"
1463
+ "REAL*8, intent(out), dimension(1:1, 1:3) :: C\n"
1464
+ "REAL*8, intent(out), dimension(1:2, 1:1) :: D\n"
1465
+ "B(1, 1) = A(1, 1)\n"
1466
+ "B(1, 2) = A(1, 2)\n"
1467
+ "B(1, 3) = A(1, 3)\n"
1468
+ "C(1, 1) = A(2, 1)\n"
1469
+ "C(1, 2) = A(2, 2)\n"
1470
+ "C(1, 3) = A(2, 3)\n"
1471
+ "D(1, 1) = A(1, 3)\n"
1472
+ "D(2, 1) = A(2, 3)\n"
1473
+ "end subroutine\n"
1474
+ )
1475
+ assert source == expected
1476
+
1477
+
1478
+ def test_fcode_matrixsymbol_slice_autoname():
1479
+ # see issue #8093
1480
+ A = MatrixSymbol('A', 2, 3)
1481
+ name_expr = ("test", A[:, 1])
1482
+ result = codegen(name_expr, "f95", "test", header=False, empty=False)
1483
+ source = result[0][1]
1484
+ expected = (
1485
+ "subroutine test(A, out_%(hash)s)\n"
1486
+ "implicit none\n"
1487
+ "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
1488
+ "REAL*8, intent(out), dimension(1:2, 1:1) :: out_%(hash)s\n"
1489
+ "out_%(hash)s(1, 1) = A(1, 2)\n"
1490
+ "out_%(hash)s(2, 1) = A(2, 2)\n"
1491
+ "end subroutine\n"
1492
+ )
1493
+ # look for the magic number
1494
+ a = source.splitlines()[3]
1495
+ b = a.split('_')
1496
+ out = b[1]
1497
+ expected = expected % {'hash': out}
1498
+ assert source == expected
1499
+
1500
+
1501
+ def test_global_vars():
1502
+ x, y, z, t = symbols("x y z t")
1503
+ result = codegen(('f', x*y), "F95", header=False, empty=False,
1504
+ global_vars=(y,))
1505
+ source = result[0][1]
1506
+ expected = (
1507
+ "REAL*8 function f(x)\n"
1508
+ "implicit none\n"
1509
+ "REAL*8, intent(in) :: x\n"
1510
+ "f = x*y\n"
1511
+ "end function\n"
1512
+ )
1513
+ assert source == expected
1514
+
1515
+ expected = (
1516
+ '#include "f.h"\n'
1517
+ '#include <math.h>\n'
1518
+ 'double f(double x, double y) {\n'
1519
+ ' double f_result;\n'
1520
+ ' f_result = x*y + z;\n'
1521
+ ' return f_result;\n'
1522
+ '}\n'
1523
+ )
1524
+ result = codegen(('f', x*y+z), "C", header=False, empty=False,
1525
+ global_vars=(z, t))
1526
+ source = result[0][1]
1527
+ assert source == expected
1528
+
1529
+ def test_custom_codegen():
1530
+ from sympy.printing.c import C99CodePrinter
1531
+ from sympy.functions.elementary.exponential import exp
1532
+
1533
+ printer = C99CodePrinter(settings={'user_functions': {'exp': 'fastexp'}})
1534
+
1535
+ x, y = symbols('x y')
1536
+ expr = exp(x + y)
1537
+
1538
+ # replace math.h with a different header
1539
+ gen = C99CodeGen(printer=printer,
1540
+ preprocessor_statements=['#include "fastexp.h"'])
1541
+
1542
+ expected = (
1543
+ '#include "expr.h"\n'
1544
+ '#include "fastexp.h"\n'
1545
+ 'double expr(double x, double y) {\n'
1546
+ ' double expr_result;\n'
1547
+ ' expr_result = fastexp(x + y);\n'
1548
+ ' return expr_result;\n'
1549
+ '}\n'
1550
+ )
1551
+
1552
+ result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
1553
+ source = result[0][1]
1554
+ assert source == expected
1555
+
1556
+ # use both math.h and an external header
1557
+ gen = C99CodeGen(printer=printer)
1558
+ gen.preprocessor_statements.append('#include "fastexp.h"')
1559
+
1560
+ expected = (
1561
+ '#include "expr.h"\n'
1562
+ '#include <math.h>\n'
1563
+ '#include "fastexp.h"\n'
1564
+ 'double expr(double x, double y) {\n'
1565
+ ' double expr_result;\n'
1566
+ ' expr_result = fastexp(x + y);\n'
1567
+ ' return expr_result;\n'
1568
+ '}\n'
1569
+ )
1570
+
1571
+ result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
1572
+ source = result[0][1]
1573
+ assert source == expected
1574
+
1575
+ def test_c_with_printer():
1576
+ # issue 13586
1577
+ from sympy.printing.c import C99CodePrinter
1578
+ class CustomPrinter(C99CodePrinter):
1579
+ def _print_Pow(self, expr):
1580
+ return "fastpow({}, {})".format(self._print(expr.base),
1581
+ self._print(expr.exp))
1582
+
1583
+ x = symbols('x')
1584
+ expr = x**3
1585
+ expected =[
1586
+ ("file.c",
1587
+ "#include \"file.h\"\n"
1588
+ "#include <math.h>\n"
1589
+ "double test(double x) {\n"
1590
+ " double test_result;\n"
1591
+ " test_result = fastpow(x, 3);\n"
1592
+ " return test_result;\n"
1593
+ "}\n"),
1594
+ ("file.h",
1595
+ "#ifndef PROJECT__FILE__H\n"
1596
+ "#define PROJECT__FILE__H\n"
1597
+ "double test(double x);\n"
1598
+ "#endif\n")
1599
+ ]
1600
+ result = codegen(("test", expr), "C","file", header=False, empty=False, printer = CustomPrinter())
1601
+ assert result == expected
1602
+
1603
+
1604
+ def test_fcode_complex():
1605
+ import sympy.utilities.codegen
1606
+ sympy.utilities.codegen.COMPLEX_ALLOWED = True
1607
+ x = Symbol('x', real=True)
1608
+ y = Symbol('y',real=True)
1609
+ result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
1610
+ source = (result[0][1])
1611
+ expected = (
1612
+ "REAL*8 function test(x, y)\n"
1613
+ "implicit none\n"
1614
+ "REAL*8, intent(in) :: x\n"
1615
+ "REAL*8, intent(in) :: y\n"
1616
+ "test = x + y\n"
1617
+ "end function\n")
1618
+ assert source == expected
1619
+ x = Symbol('x')
1620
+ y = Symbol('y',real=True)
1621
+ result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
1622
+ source = (result[0][1])
1623
+ expected = (
1624
+ "COMPLEX*16 function test(x, y)\n"
1625
+ "implicit none\n"
1626
+ "COMPLEX*16, intent(in) :: x\n"
1627
+ "REAL*8, intent(in) :: y\n"
1628
+ "test = x + y\n"
1629
+ "end function\n"
1630
+ )
1631
+ assert source==expected
1632
+ sympy.utilities.codegen.COMPLEX_ALLOWED = False
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_codegen_julia.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import StringIO
2
+
3
+ from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function
4
+ from sympy.core.relational import Equality
5
+ from sympy.functions.elementary.piecewise import Piecewise
6
+ from sympy.matrices import Matrix, MatrixSymbol
7
+ from sympy.utilities.codegen import JuliaCodeGen, codegen, make_routine
8
+ from sympy.testing.pytest import XFAIL
9
+ import sympy
10
+
11
+
12
+ x, y, z = symbols('x,y,z')
13
+
14
+
15
+ def test_empty_jl_code():
16
+ code_gen = JuliaCodeGen()
17
+ output = StringIO()
18
+ code_gen.dump_jl([], output, "file", header=False, empty=False)
19
+ source = output.getvalue()
20
+ assert source == ""
21
+
22
+
23
+ def test_jl_simple_code():
24
+ name_expr = ("test", (x + y)*z)
25
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
26
+ assert result[0] == "test.jl"
27
+ source = result[1]
28
+ expected = (
29
+ "function test(x, y, z)\n"
30
+ " out1 = z .* (x + y)\n"
31
+ " return out1\n"
32
+ "end\n"
33
+ )
34
+ assert source == expected
35
+
36
+
37
+ def test_jl_simple_code_with_header():
38
+ name_expr = ("test", (x + y)*z)
39
+ result, = codegen(name_expr, "Julia", header=True, empty=False)
40
+ assert result[0] == "test.jl"
41
+ source = result[1]
42
+ expected = (
43
+ "# Code generated with SymPy " + sympy.__version__ + "\n"
44
+ "#\n"
45
+ "# See http://www.sympy.org/ for more information.\n"
46
+ "#\n"
47
+ "# This file is part of 'project'\n"
48
+ "function test(x, y, z)\n"
49
+ " out1 = z .* (x + y)\n"
50
+ " return out1\n"
51
+ "end\n"
52
+ )
53
+ assert source == expected
54
+
55
+
56
+ def test_jl_simple_code_nameout():
57
+ expr = Equality(z, (x + y))
58
+ name_expr = ("test", expr)
59
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
60
+ source = result[1]
61
+ expected = (
62
+ "function test(x, y)\n"
63
+ " z = x + y\n"
64
+ " return z\n"
65
+ "end\n"
66
+ )
67
+ assert source == expected
68
+
69
+
70
+ def test_jl_numbersymbol():
71
+ name_expr = ("test", pi**Catalan)
72
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
73
+ source = result[1]
74
+ expected = (
75
+ "function test()\n"
76
+ " out1 = pi ^ catalan\n"
77
+ " return out1\n"
78
+ "end\n"
79
+ )
80
+ assert source == expected
81
+
82
+
83
+ @XFAIL
84
+ def test_jl_numbersymbol_no_inline():
85
+ # FIXME: how to pass inline=False to the JuliaCodePrinter?
86
+ name_expr = ("test", [pi**Catalan, EulerGamma])
87
+ result, = codegen(name_expr, "Julia", header=False,
88
+ empty=False, inline=False)
89
+ source = result[1]
90
+ expected = (
91
+ "function test()\n"
92
+ " Catalan = 0.915965594177219\n"
93
+ " EulerGamma = 0.5772156649015329\n"
94
+ " out1 = pi ^ Catalan\n"
95
+ " out2 = EulerGamma\n"
96
+ " return out1, out2\n"
97
+ "end\n"
98
+ )
99
+ assert source == expected
100
+
101
+
102
+ def test_jl_code_argument_order():
103
+ expr = x + y
104
+ routine = make_routine("test", expr, argument_sequence=[z, x, y], language="julia")
105
+ code_gen = JuliaCodeGen()
106
+ output = StringIO()
107
+ code_gen.dump_jl([routine], output, "test", header=False, empty=False)
108
+ source = output.getvalue()
109
+ expected = (
110
+ "function test(z, x, y)\n"
111
+ " out1 = x + y\n"
112
+ " return out1\n"
113
+ "end\n"
114
+ )
115
+ assert source == expected
116
+
117
+
118
+ def test_multiple_results_m():
119
+ # Here the output order is the input order
120
+ expr1 = (x + y)*z
121
+ expr2 = (x - y)*z
122
+ name_expr = ("test", [expr1, expr2])
123
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
124
+ source = result[1]
125
+ expected = (
126
+ "function test(x, y, z)\n"
127
+ " out1 = z .* (x + y)\n"
128
+ " out2 = z .* (x - y)\n"
129
+ " return out1, out2\n"
130
+ "end\n"
131
+ )
132
+ assert source == expected
133
+
134
+
135
+ def test_results_named_unordered():
136
+ # Here output order is based on name_expr
137
+ A, B, C = symbols('A,B,C')
138
+ expr1 = Equality(C, (x + y)*z)
139
+ expr2 = Equality(A, (x - y)*z)
140
+ expr3 = Equality(B, 2*x)
141
+ name_expr = ("test", [expr1, expr2, expr3])
142
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
143
+ source = result[1]
144
+ expected = (
145
+ "function test(x, y, z)\n"
146
+ " C = z .* (x + y)\n"
147
+ " A = z .* (x - y)\n"
148
+ " B = 2 * x\n"
149
+ " return C, A, B\n"
150
+ "end\n"
151
+ )
152
+ assert source == expected
153
+
154
+
155
+ def test_results_named_ordered():
156
+ A, B, C = symbols('A,B,C')
157
+ expr1 = Equality(C, (x + y)*z)
158
+ expr2 = Equality(A, (x - y)*z)
159
+ expr3 = Equality(B, 2*x)
160
+ name_expr = ("test", [expr1, expr2, expr3])
161
+ result = codegen(name_expr, "Julia", header=False, empty=False,
162
+ argument_sequence=(x, z, y))
163
+ assert result[0][0] == "test.jl"
164
+ source = result[0][1]
165
+ expected = (
166
+ "function test(x, z, y)\n"
167
+ " C = z .* (x + y)\n"
168
+ " A = z .* (x - y)\n"
169
+ " B = 2 * x\n"
170
+ " return C, A, B\n"
171
+ "end\n"
172
+ )
173
+ assert source == expected
174
+
175
+
176
+ def test_complicated_jl_codegen():
177
+ from sympy.functions.elementary.trigonometric import (cos, sin, tan)
178
+ name_expr = ("testlong",
179
+ [ ((sin(x) + cos(y) + tan(z))**3).expand(),
180
+ cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
181
+ ])
182
+ result = codegen(name_expr, "Julia", header=False, empty=False)
183
+ assert result[0][0] == "testlong.jl"
184
+ source = result[0][1]
185
+ expected = (
186
+ "function testlong(x, y, z)\n"
187
+ " out1 = sin(x) .^ 3 + 3 * sin(x) .^ 2 .* cos(y) + 3 * sin(x) .^ 2 .* tan(z)"
188
+ " + 3 * sin(x) .* cos(y) .^ 2 + 6 * sin(x) .* cos(y) .* tan(z) + 3 * sin(x) .* tan(z) .^ 2"
189
+ " + cos(y) .^ 3 + 3 * cos(y) .^ 2 .* tan(z) + 3 * cos(y) .* tan(z) .^ 2 + tan(z) .^ 3\n"
190
+ " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n"
191
+ " return out1, out2\n"
192
+ "end\n"
193
+ )
194
+ assert source == expected
195
+
196
+
197
+ def test_jl_output_arg_mixed_unordered():
198
+ # named outputs are alphabetical, unnamed output appear in the given order
199
+ from sympy.functions.elementary.trigonometric import (cos, sin)
200
+ a = symbols("a")
201
+ name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
202
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
203
+ assert result[0] == "foo.jl"
204
+ source = result[1]
205
+ expected = (
206
+ 'function foo(x)\n'
207
+ ' out1 = cos(2 * x)\n'
208
+ ' y = sin(x)\n'
209
+ ' out3 = cos(x)\n'
210
+ ' a = sin(2 * x)\n'
211
+ ' return out1, y, out3, a\n'
212
+ 'end\n'
213
+ )
214
+ assert source == expected
215
+
216
+
217
+ def test_jl_piecewise_():
218
+ pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
219
+ name_expr = ("pwtest", pw)
220
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
221
+ source = result[1]
222
+ expected = (
223
+ "function pwtest(x)\n"
224
+ " out1 = ((x < -1) ? (0) :\n"
225
+ " (x <= 1) ? (x .^ 2) :\n"
226
+ " (x > 1) ? (2 - x) : (1))\n"
227
+ " return out1\n"
228
+ "end\n"
229
+ )
230
+ assert source == expected
231
+
232
+
233
+ @XFAIL
234
+ def test_jl_piecewise_no_inline():
235
+ # FIXME: how to pass inline=False to the JuliaCodePrinter?
236
+ pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
237
+ name_expr = ("pwtest", pw)
238
+ result, = codegen(name_expr, "Julia", header=False, empty=False,
239
+ inline=False)
240
+ source = result[1]
241
+ expected = (
242
+ "function pwtest(x)\n"
243
+ " if (x < -1)\n"
244
+ " out1 = 0\n"
245
+ " elseif (x <= 1)\n"
246
+ " out1 = x .^ 2\n"
247
+ " elseif (x > 1)\n"
248
+ " out1 = -x + 2\n"
249
+ " else\n"
250
+ " out1 = 1\n"
251
+ " end\n"
252
+ " return out1\n"
253
+ "end\n"
254
+ )
255
+ assert source == expected
256
+
257
+
258
+ def test_jl_multifcns_per_file():
259
+ name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
260
+ result = codegen(name_expr, "Julia", header=False, empty=False)
261
+ assert result[0][0] == "foo.jl"
262
+ source = result[0][1]
263
+ expected = (
264
+ "function foo(x, y)\n"
265
+ " out1 = 2 * x\n"
266
+ " out2 = 3 * y\n"
267
+ " return out1, out2\n"
268
+ "end\n"
269
+ "function bar(y)\n"
270
+ " out1 = y .^ 2\n"
271
+ " out2 = 4 * y\n"
272
+ " return out1, out2\n"
273
+ "end\n"
274
+ )
275
+ assert source == expected
276
+
277
+
278
+ def test_jl_multifcns_per_file_w_header():
279
+ name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
280
+ result = codegen(name_expr, "Julia", header=True, empty=False)
281
+ assert result[0][0] == "foo.jl"
282
+ source = result[0][1]
283
+ expected = (
284
+ "# Code generated with SymPy " + sympy.__version__ + "\n"
285
+ "#\n"
286
+ "# See http://www.sympy.org/ for more information.\n"
287
+ "#\n"
288
+ "# This file is part of 'project'\n"
289
+ "function foo(x, y)\n"
290
+ " out1 = 2 * x\n"
291
+ " out2 = 3 * y\n"
292
+ " return out1, out2\n"
293
+ "end\n"
294
+ "function bar(y)\n"
295
+ " out1 = y .^ 2\n"
296
+ " out2 = 4 * y\n"
297
+ " return out1, out2\n"
298
+ "end\n"
299
+ )
300
+ assert source == expected
301
+
302
+
303
+ def test_jl_filename_match_prefix():
304
+ name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
305
+ result, = codegen(name_expr, "Julia", prefix="baz", header=False,
306
+ empty=False)
307
+ assert result[0] == "baz.jl"
308
+
309
+
310
+ def test_jl_matrix_named():
311
+ e2 = Matrix([[x, 2*y, pi*z]])
312
+ name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2))
313
+ result = codegen(name_expr, "Julia", header=False, empty=False)
314
+ assert result[0][0] == "test.jl"
315
+ source = result[0][1]
316
+ expected = (
317
+ "function test(x, y, z)\n"
318
+ " myout1 = [x 2 * y pi * z]\n"
319
+ " return myout1\n"
320
+ "end\n"
321
+ )
322
+ assert source == expected
323
+
324
+
325
+ def test_jl_matrix_named_matsym():
326
+ myout1 = MatrixSymbol('myout1', 1, 3)
327
+ e2 = Matrix([[x, 2*y, pi*z]])
328
+ name_expr = ("test", Equality(myout1, e2, evaluate=False))
329
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
330
+ source = result[1]
331
+ expected = (
332
+ "function test(x, y, z)\n"
333
+ " myout1 = [x 2 * y pi * z]\n"
334
+ " return myout1\n"
335
+ "end\n"
336
+ )
337
+ assert source == expected
338
+
339
+
340
+ def test_jl_matrix_output_autoname():
341
+ expr = Matrix([[x, x+y, 3]])
342
+ name_expr = ("test", expr)
343
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
344
+ source = result[1]
345
+ expected = (
346
+ "function test(x, y)\n"
347
+ " out1 = [x x + y 3]\n"
348
+ " return out1\n"
349
+ "end\n"
350
+ )
351
+ assert source == expected
352
+
353
+
354
+ def test_jl_matrix_output_autoname_2():
355
+ e1 = (x + y)
356
+ e2 = Matrix([[2*x, 2*y, 2*z]])
357
+ e3 = Matrix([[x], [y], [z]])
358
+ e4 = Matrix([[x, y], [z, 16]])
359
+ name_expr = ("test", (e1, e2, e3, e4))
360
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
361
+ source = result[1]
362
+ expected = (
363
+ "function test(x, y, z)\n"
364
+ " out1 = x + y\n"
365
+ " out2 = [2 * x 2 * y 2 * z]\n"
366
+ " out3 = [x, y, z]\n"
367
+ " out4 = [x y;\n"
368
+ " z 16]\n"
369
+ " return out1, out2, out3, out4\n"
370
+ "end\n"
371
+ )
372
+ assert source == expected
373
+
374
+
375
+ def test_jl_results_matrix_named_ordered():
376
+ B, C = symbols('B,C')
377
+ A = MatrixSymbol('A', 1, 3)
378
+ expr1 = Equality(C, (x + y)*z)
379
+ expr2 = Equality(A, Matrix([[1, 2, x]]))
380
+ expr3 = Equality(B, 2*x)
381
+ name_expr = ("test", [expr1, expr2, expr3])
382
+ result, = codegen(name_expr, "Julia", header=False, empty=False,
383
+ argument_sequence=(x, z, y))
384
+ source = result[1]
385
+ expected = (
386
+ "function test(x, z, y)\n"
387
+ " C = z .* (x + y)\n"
388
+ " A = [1 2 x]\n"
389
+ " B = 2 * x\n"
390
+ " return C, A, B\n"
391
+ "end\n"
392
+ )
393
+ assert source == expected
394
+
395
+
396
+ def test_jl_matrixsymbol_slice():
397
+ A = MatrixSymbol('A', 2, 3)
398
+ B = MatrixSymbol('B', 1, 3)
399
+ C = MatrixSymbol('C', 1, 3)
400
+ D = MatrixSymbol('D', 2, 1)
401
+ name_expr = ("test", [Equality(B, A[0, :]),
402
+ Equality(C, A[1, :]),
403
+ Equality(D, A[:, 2])])
404
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
405
+ source = result[1]
406
+ expected = (
407
+ "function test(A)\n"
408
+ " B = A[1,:]\n"
409
+ " C = A[2,:]\n"
410
+ " D = A[:,3]\n"
411
+ " return B, C, D\n"
412
+ "end\n"
413
+ )
414
+ assert source == expected
415
+
416
+
417
+ def test_jl_matrixsymbol_slice2():
418
+ A = MatrixSymbol('A', 3, 4)
419
+ B = MatrixSymbol('B', 2, 2)
420
+ C = MatrixSymbol('C', 2, 2)
421
+ name_expr = ("test", [Equality(B, A[0:2, 0:2]),
422
+ Equality(C, A[0:2, 1:3])])
423
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
424
+ source = result[1]
425
+ expected = (
426
+ "function test(A)\n"
427
+ " B = A[1:2,1:2]\n"
428
+ " C = A[1:2,2:3]\n"
429
+ " return B, C\n"
430
+ "end\n"
431
+ )
432
+ assert source == expected
433
+
434
+
435
+ def test_jl_matrixsymbol_slice3():
436
+ A = MatrixSymbol('A', 8, 7)
437
+ B = MatrixSymbol('B', 2, 2)
438
+ C = MatrixSymbol('C', 4, 2)
439
+ name_expr = ("test", [Equality(B, A[6:, 1::3]),
440
+ Equality(C, A[::2, ::3])])
441
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
442
+ source = result[1]
443
+ expected = (
444
+ "function test(A)\n"
445
+ " B = A[7:end,2:3:end]\n"
446
+ " C = A[1:2:end,1:3:end]\n"
447
+ " return B, C\n"
448
+ "end\n"
449
+ )
450
+ assert source == expected
451
+
452
+
453
+ def test_jl_matrixsymbol_slice_autoname():
454
+ A = MatrixSymbol('A', 2, 3)
455
+ B = MatrixSymbol('B', 1, 3)
456
+ name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]])
457
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
458
+ source = result[1]
459
+ expected = (
460
+ "function test(A)\n"
461
+ " B = A[1,:]\n"
462
+ " out2 = A[2,:]\n"
463
+ " out3 = A[:,1]\n"
464
+ " out4 = A[:,2]\n"
465
+ " return B, out2, out3, out4\n"
466
+ "end\n"
467
+ )
468
+ assert source == expected
469
+
470
+
471
+ def test_jl_loops():
472
+ # Note: an Julia programmer would probably vectorize this across one or
473
+ # more dimensions. Also, size(A) would be used rather than passing in m
474
+ # and n. Perhaps users would expect us to vectorize automatically here?
475
+ # Or is it possible to represent such things using IndexedBase?
476
+ from sympy.tensor import IndexedBase, Idx
477
+ from sympy.core.symbol import symbols
478
+ n, m = symbols('n m', integer=True)
479
+ A = IndexedBase('A')
480
+ x = IndexedBase('x')
481
+ y = IndexedBase('y')
482
+ i = Idx('i', m)
483
+ j = Idx('j', n)
484
+ result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Julia",
485
+ header=False, empty=False)
486
+ source = result[1]
487
+ expected = (
488
+ 'function mat_vec_mult(y, A, m, n, x)\n'
489
+ ' for i = 1:m\n'
490
+ ' y[i] = 0\n'
491
+ ' end\n'
492
+ ' for i = 1:m\n'
493
+ ' for j = 1:n\n'
494
+ ' y[i] = %(rhs)s + y[i]\n'
495
+ ' end\n'
496
+ ' end\n'
497
+ ' return y\n'
498
+ 'end\n'
499
+ )
500
+ assert (source == expected % {'rhs': 'A[%s,%s] .* x[j]' % (i, j)} or
501
+ source == expected % {'rhs': 'x[j] .* A[%s,%s]' % (i, j)})
502
+
503
+
504
+ def test_jl_tensor_loops_multiple_contractions():
505
+ # see comments in previous test about vectorizing
506
+ from sympy.tensor import IndexedBase, Idx
507
+ from sympy.core.symbol import symbols
508
+ n, m, o, p = symbols('n m o p', integer=True)
509
+ A = IndexedBase('A')
510
+ B = IndexedBase('B')
511
+ y = IndexedBase('y')
512
+ i = Idx('i', m)
513
+ j = Idx('j', n)
514
+ k = Idx('k', o)
515
+ l = Idx('l', p)
516
+ result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])),
517
+ "Julia", header=False, empty=False)
518
+ source = result[1]
519
+ expected = (
520
+ 'function tensorthing(y, A, B, m, n, o, p)\n'
521
+ ' for i = 1:m\n'
522
+ ' y[i] = 0\n'
523
+ ' end\n'
524
+ ' for i = 1:m\n'
525
+ ' for j = 1:n\n'
526
+ ' for k = 1:o\n'
527
+ ' for l = 1:p\n'
528
+ ' y[i] = A[i,j,k,l] .* B[j,k,l] + y[i]\n'
529
+ ' end\n'
530
+ ' end\n'
531
+ ' end\n'
532
+ ' end\n'
533
+ ' return y\n'
534
+ 'end\n'
535
+ )
536
+ assert source == expected
537
+
538
+
539
+ def test_jl_InOutArgument():
540
+ expr = Equality(x, x**2)
541
+ name_expr = ("mysqr", expr)
542
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
543
+ source = result[1]
544
+ expected = (
545
+ "function mysqr(x)\n"
546
+ " x = x .^ 2\n"
547
+ " return x\n"
548
+ "end\n"
549
+ )
550
+ assert source == expected
551
+
552
+
553
+ def test_jl_InOutArgument_order():
554
+ # can specify the order as (x, y)
555
+ expr = Equality(x, x**2 + y)
556
+ name_expr = ("test", expr)
557
+ result, = codegen(name_expr, "Julia", header=False,
558
+ empty=False, argument_sequence=(x,y))
559
+ source = result[1]
560
+ expected = (
561
+ "function test(x, y)\n"
562
+ " x = x .^ 2 + y\n"
563
+ " return x\n"
564
+ "end\n"
565
+ )
566
+ assert source == expected
567
+ # make sure it gives (x, y) not (y, x)
568
+ expr = Equality(x, x**2 + y)
569
+ name_expr = ("test", expr)
570
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
571
+ source = result[1]
572
+ expected = (
573
+ "function test(x, y)\n"
574
+ " x = x .^ 2 + y\n"
575
+ " return x\n"
576
+ "end\n"
577
+ )
578
+ assert source == expected
579
+
580
+
581
+ def test_jl_not_supported():
582
+ f = Function('f')
583
+ name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
584
+ result, = codegen(name_expr, "Julia", header=False, empty=False)
585
+ source = result[1]
586
+ expected = (
587
+ "function test(x)\n"
588
+ " # unsupported: Derivative(f(x), x)\n"
589
+ " # unsupported: zoo\n"
590
+ " out1 = Derivative(f(x), x)\n"
591
+ " out2 = zoo\n"
592
+ " return out1, out2\n"
593
+ "end\n"
594
+ )
595
+ assert source == expected
596
+
597
+
598
+ def test_global_vars_octave():
599
+ x, y, z, t = symbols("x y z t")
600
+ result = codegen(('f', x*y), "Julia", header=False, empty=False,
601
+ global_vars=(y,))
602
+ source = result[0][1]
603
+ expected = (
604
+ "function f(x)\n"
605
+ " out1 = x .* y\n"
606
+ " return out1\n"
607
+ "end\n"
608
+ )
609
+ assert source == expected
610
+
611
+ result = codegen(('f', x*y+z), "Julia", header=False, empty=False,
612
+ argument_sequence=(x, y), global_vars=(z, t))
613
+ source = result[0][1]
614
+ expected = (
615
+ "function f(x, y)\n"
616
+ " out1 = x .* y + z\n"
617
+ " return out1\n"
618
+ "end\n"
619
+ )
620
+ assert source == expected
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_decorator.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+
3
+ from sympy.utilities.decorator import threaded, xthreaded, memoize_property, deprecated
4
+ from sympy.testing.pytest import warns_deprecated_sympy
5
+
6
+ from sympy.core.basic import Basic
7
+ from sympy.core.relational import Eq
8
+ from sympy.matrices.dense import Matrix
9
+
10
+ from sympy.abc import x, y
11
+
12
+
13
+ def test_threaded():
14
+ @threaded
15
+ def function(expr, *args):
16
+ return 2*expr + sum(args)
17
+
18
+ assert function(Matrix([[x, y], [1, x]]), 1, 2) == \
19
+ Matrix([[2*x + 3, 2*y + 3], [5, 2*x + 3]])
20
+
21
+ assert function(Eq(x, y), 1, 2) == Eq(2*x + 3, 2*y + 3)
22
+
23
+ assert function([x, y], 1, 2) == [2*x + 3, 2*y + 3]
24
+ assert function((x, y), 1, 2) == (2*x + 3, 2*y + 3)
25
+
26
+ assert function({x, y}, 1, 2) == {2*x + 3, 2*y + 3}
27
+
28
+ @threaded
29
+ def function(expr, n):
30
+ return expr**n
31
+
32
+ assert function(x + y, 2) == x**2 + y**2
33
+ assert function(x, 2) == x**2
34
+
35
+
36
+ def test_xthreaded():
37
+ @xthreaded
38
+ def function(expr, n):
39
+ return expr**n
40
+
41
+ assert function(x + y, 2) == (x + y)**2
42
+
43
+
44
+ def test_wraps():
45
+ def my_func(x):
46
+ """My function. """
47
+
48
+ my_func.is_my_func = True
49
+
50
+ new_my_func = threaded(my_func)
51
+ new_my_func = wraps(my_func)(new_my_func)
52
+
53
+ assert new_my_func.__name__ == 'my_func'
54
+ assert new_my_func.__doc__ == 'My function. '
55
+ assert hasattr(new_my_func, 'is_my_func')
56
+ assert new_my_func.is_my_func is True
57
+
58
+
59
+ def test_memoize_property():
60
+ class TestMemoize(Basic):
61
+ @memoize_property
62
+ def prop(self):
63
+ return Basic()
64
+
65
+ member = TestMemoize()
66
+ obj1 = member.prop
67
+ obj2 = member.prop
68
+ assert obj1 is obj2
69
+
70
+ def test_deprecated():
71
+ @deprecated('deprecated_function is deprecated',
72
+ deprecated_since_version='1.10',
73
+ # This is the target at the top of the file, which will never
74
+ # go away.
75
+ active_deprecations_target='active-deprecations')
76
+ def deprecated_function(x):
77
+ return x
78
+
79
+ with warns_deprecated_sympy():
80
+ assert deprecated_function(1) == 1
81
+
82
+ @deprecated('deprecated_class is deprecated',
83
+ deprecated_since_version='1.10',
84
+ active_deprecations_target='active-deprecations')
85
+ class deprecated_class:
86
+ pass
87
+
88
+ with warns_deprecated_sympy():
89
+ assert isinstance(deprecated_class(), deprecated_class)
90
+
91
+ # Ensure the class decorator works even when the class never returns
92
+ # itself
93
+ @deprecated('deprecated_class_new is deprecated',
94
+ deprecated_since_version='1.10',
95
+ active_deprecations_target='active-deprecations')
96
+ class deprecated_class_new:
97
+ def __new__(cls, arg):
98
+ return arg
99
+
100
+ with warns_deprecated_sympy():
101
+ assert deprecated_class_new(1) == 1
102
+
103
+ @deprecated('deprecated_class_init is deprecated',
104
+ deprecated_since_version='1.10',
105
+ active_deprecations_target='active-deprecations')
106
+ class deprecated_class_init:
107
+ def __init__(self, arg):
108
+ self.arg = 1
109
+
110
+ with warns_deprecated_sympy():
111
+ assert deprecated_class_init(1).arg == 1
112
+
113
+ @deprecated('deprecated_class_new_init is deprecated',
114
+ deprecated_since_version='1.10',
115
+ active_deprecations_target='active-deprecations')
116
+ class deprecated_class_new_init:
117
+ def __new__(cls, arg):
118
+ if arg == 0:
119
+ return arg
120
+ return object.__new__(cls)
121
+
122
+ def __init__(self, arg):
123
+ self.arg = 1
124
+
125
+ with warns_deprecated_sympy():
126
+ assert deprecated_class_new_init(0) == 0
127
+
128
+ with warns_deprecated_sympy():
129
+ assert deprecated_class_new_init(1).arg == 1
.venv/lib/python3.13/site-packages/sympy/utilities/tests/test_enumerative.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import string
2
+ from itertools import zip_longest
3
+
4
+ from sympy.utilities.enumerative import (
5
+ list_visitor,
6
+ MultisetPartitionTraverser,
7
+ multiset_partitions_taocp
8
+ )
9
+ from sympy.utilities.iterables import _set_partitions
10
+
11
+ # first some functions only useful as test scaffolding - these provide
12
+ # straightforward, but slow reference implementations against which to
13
+ # compare the real versions, and also a comparison to verify that
14
+ # different versions are giving identical results.
15
+
16
+ def part_range_filter(partition_iterator, lb, ub):
17
+ """
18
+ Filters (on the number of parts) a multiset partition enumeration
19
+
20
+ Arguments
21
+ =========
22
+
23
+ lb, and ub are a range (in the Python slice sense) on the lpart
24
+ variable returned from a multiset partition enumeration. Recall
25
+ that lpart is 0-based (it points to the topmost part on the part
26
+ stack), so if you want to return parts of sizes 2,3,4,5 you would
27
+ use lb=1 and ub=5.
28
+ """
29
+ for state in partition_iterator:
30
+ f, lpart, pstack = state
31
+ if lpart >= lb and lpart < ub:
32
+ yield state
33
+
34
+ def multiset_partitions_baseline(multiplicities, components):
35
+ """Enumerates partitions of a multiset
36
+
37
+ Parameters
38
+ ==========
39
+
40
+ multiplicities
41
+ list of integer multiplicities of the components of the multiset.
42
+
43
+ components
44
+ the components (elements) themselves
45
+
46
+ Returns
47
+ =======
48
+
49
+ Set of partitions. Each partition is tuple of parts, and each
50
+ part is a tuple of components (with repeats to indicate
51
+ multiplicity)
52
+
53
+ Notes
54
+ =====
55
+
56
+ Multiset partitions can be created as equivalence classes of set
57
+ partitions, and this function does just that. This approach is
58
+ slow and memory intensive compared to the more advanced algorithms
59
+ available, but the code is simple and easy to understand. Hence
60
+ this routine is strictly for testing -- to provide a
61
+ straightforward baseline against which to regress the production
62
+ versions. (This code is a simplified version of an earlier
63
+ production implementation.)
64
+ """
65
+
66
+ canon = [] # list of components with repeats
67
+ for ct, elem in zip(multiplicities, components):
68
+ canon.extend([elem]*ct)
69
+
70
+ # accumulate the multiset partitions in a set to eliminate dups
71
+ cache = set()
72
+ n = len(canon)
73
+ for nc, q in _set_partitions(n):
74
+ rv = [[] for i in range(nc)]
75
+ for i in range(n):
76
+ rv[q[i]].append(canon[i])
77
+ canonical = tuple(
78
+ sorted([tuple(p) for p in rv]))
79
+ cache.add(canonical)
80
+ return cache
81
+
82
+
83
+ def compare_multiset_w_baseline(multiplicities):
84
+ """
85
+ Enumerates the partitions of multiset with AOCP algorithm and
86
+ baseline implementation, and compare the results.
87
+
88
+ """
89
+ letters = string.ascii_lowercase
90
+ bl_partitions = multiset_partitions_baseline(multiplicities, letters)
91
+
92
+ # The partitions returned by the different algorithms may have
93
+ # their parts in different orders. Also, they generate partitions
94
+ # in different orders. Hence the sorting, and set comparison.
95
+
96
+ aocp_partitions = set()
97
+ for state in multiset_partitions_taocp(multiplicities):
98
+ p1 = tuple(sorted(
99
+ [tuple(p) for p in list_visitor(state, letters)]))
100
+ aocp_partitions.add(p1)
101
+
102
+ assert bl_partitions == aocp_partitions
103
+
104
+ def compare_multiset_states(s1, s2):
105
+ """compare for equality two instances of multiset partition states
106
+
107
+ This is useful for comparing different versions of the algorithm
108
+ to verify correctness."""
109
+ # Comparison is physical, the only use of semantics is to ignore
110
+ # trash off the top of the stack.
111
+ f1, lpart1, pstack1 = s1
112
+ f2, lpart2, pstack2 = s2
113
+
114
+ if (lpart1 == lpart2) and (f1[0:lpart1+1] == f2[0:lpart2+1]):
115
+ if pstack1[0:f1[lpart1+1]] == pstack2[0:f2[lpart2+1]]:
116
+ return True
117
+ return False
118
+
119
+ def test_multiset_partitions_taocp():
120
+ """Compares the output of multiset_partitions_taocp with a baseline
121
+ (set partition based) implementation."""
122
+
123
+ # Test cases should not be too large, since the baseline
124
+ # implementation is fairly slow.
125
+ multiplicities = [2,2]
126
+ compare_multiset_w_baseline(multiplicities)
127
+
128
+ multiplicities = [4,3,1]
129
+ compare_multiset_w_baseline(multiplicities)
130
+
131
+ def test_multiset_partitions_versions():
132
+ """Compares Knuth-based versions of multiset_partitions"""
133
+ multiplicities = [5,2,2,1]
134
+ m = MultisetPartitionTraverser()
135
+ for s1, s2 in zip_longest(m.enum_all(multiplicities),
136
+ multiset_partitions_taocp(multiplicities)):
137
+ assert compare_multiset_states(s1, s2)
138
+
139
+ def subrange_exercise(mult, lb, ub):
140
+ """Compare filter-based and more optimized subrange implementations
141
+
142
+ Helper for tests, called with both small and larger multisets.
143
+ """
144
+ m = MultisetPartitionTraverser()
145
+ assert m.count_partitions(mult) == \
146
+ m.count_partitions_slow(mult)
147
+
148
+ # Note - multiple traversals from the same
149
+ # MultisetPartitionTraverser object cannot execute at the same
150
+ # time, hence make several instances here.
151
+ ma = MultisetPartitionTraverser()
152
+ mc = MultisetPartitionTraverser()
153
+ md = MultisetPartitionTraverser()
154
+
155
+ # Several paths to compute just the size two partitions
156
+ a_it = ma.enum_range(mult, lb, ub)
157
+ b_it = part_range_filter(multiset_partitions_taocp(mult), lb, ub)
158
+ c_it = part_range_filter(mc.enum_small(mult, ub), lb, sum(mult))
159
+ d_it = part_range_filter(md.enum_large(mult, lb), 0, ub)
160
+
161
+ for sa, sb, sc, sd in zip_longest(a_it, b_it, c_it, d_it):
162
+ assert compare_multiset_states(sa, sb)
163
+ assert compare_multiset_states(sa, sc)
164
+ assert compare_multiset_states(sa, sd)
165
+
166
+ def test_subrange():
167
+ # Quick, but doesn't hit some of the corner cases
168
+ mult = [4,4,2,1] # mississippi
169
+ lb = 1
170
+ ub = 2
171
+ subrange_exercise(mult, lb, ub)
172
+
173
+
174
+ def test_subrange_large():
175
+ # takes a second or so, depending on cpu, Python version, etc.
176
+ mult = [6,3,2,1]
177
+ lb = 4
178
+ ub = 7
179
+ subrange_exercise(mult, lb, ub)