Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Merge pull request #658 from MilesCranmer/fix-number-symbol
Browse files- Dockerfile +2 -2
- pysr/export_jax.py +3 -1
- pysr/export_sympy.py +6 -1
- pysr/export_torch.py +5 -0
- pysr/test/test.py +3 -3
- pysr/test/test_jax.py +36 -8
- pysr/test/test_torch.py +35 -1
Dockerfile
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
# This builds a dockerfile containing a working copy of PySR
|
2 |
# with all pre-requisites installed.
|
3 |
|
4 |
-
ARG JLVERSION=1.
|
5 |
-
ARG PYVERSION=3.
|
6 |
ARG BASE_IMAGE=bullseye
|
7 |
|
8 |
FROM julia:${JLVERSION}-${BASE_IMAGE} AS jl
|
|
|
1 |
# This builds a dockerfile containing a working copy of PySR
|
2 |
# with all pre-requisites installed.
|
3 |
|
4 |
+
ARG JLVERSION=1.10.4
|
5 |
+
ARG PYVERSION=3.12.2
|
6 |
ARG BASE_IMAGE=bullseye
|
7 |
|
8 |
FROM julia:${JLVERSION}-${BASE_IMAGE} AS jl
|
pysr/export_jax.py
CHANGED
@@ -55,7 +55,9 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
|
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
58 |
-
elif issubclass(expr.func, sympy.Rational)
|
|
|
|
|
59 |
return f"{float(expr)}"
|
60 |
elif issubclass(expr.func, sympy.Integer):
|
61 |
return f"{int(expr)}"
|
|
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
58 |
+
elif issubclass(expr.func, sympy.Rational) or issubclass(
|
59 |
+
expr.func, sympy.NumberSymbol
|
60 |
+
):
|
61 |
return f"{float(expr)}"
|
62 |
elif issubclass(expr.func, sympy.Integer):
|
63 |
return f"{int(expr)}"
|
pysr/export_sympy.py
CHANGED
@@ -87,7 +87,12 @@ def pysr2sympy(
|
|
87 |
**sympy_mappings,
|
88 |
}
|
89 |
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
|
93 |
def assert_valid_sympy_symbol(var_name: str) -> None:
|
|
|
87 |
**sympy_mappings,
|
88 |
}
|
89 |
|
90 |
+
try:
|
91 |
+
return sympify(equation, locals=local_sympy_mappings, evaluate=False)
|
92 |
+
except TypeError as e:
|
93 |
+
if "got an unexpected keyword argument 'evaluate'" in str(e):
|
94 |
+
return sympify(equation, locals=local_sympy_mappings)
|
95 |
+
raise TypeError(f"Error processing equation '{equation}'") from e
|
96 |
|
97 |
|
98 |
def assert_valid_sympy_symbol(var_name: str) -> None:
|
pysr/export_torch.py
CHANGED
@@ -116,6 +116,11 @@ def _initialize_torch():
|
|
116 |
self._value = int(expr)
|
117 |
self._torch_func = lambda: self._value
|
118 |
self._args = ()
|
|
|
|
|
|
|
|
|
|
|
119 |
elif issubclass(expr.func, sympy.Symbol):
|
120 |
self._name = expr.name
|
121 |
self._torch_func = lambda value: value
|
|
|
116 |
self._value = int(expr)
|
117 |
self._torch_func = lambda: self._value
|
118 |
self._args = ()
|
119 |
+
elif issubclass(expr.func, sympy.NumberSymbol):
|
120 |
+
# Can get here from exp(1) or exact pi
|
121 |
+
self._value = float(expr)
|
122 |
+
self._torch_func = lambda: self._value
|
123 |
+
self._args = ()
|
124 |
elif issubclass(expr.func, sympy.Symbol):
|
125 |
self._name = expr.name
|
126 |
self._torch_func = lambda value: value
|
pysr/test/test.py
CHANGED
@@ -674,7 +674,7 @@ class TestMiscellaneous(unittest.TestCase):
|
|
674 |
pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
|
675 |
|
676 |
y_predictions2 = model2.predict(X)
|
677 |
-
np.testing.
|
678 |
|
679 |
def test_scikit_learn_compatibility(self):
|
680 |
"""Test PySRRegressor compatibility with scikit-learn."""
|
@@ -1039,7 +1039,7 @@ class TestLaTeXTable(unittest.TestCase):
|
|
1039 |
middle_part_2 = r"""
|
1040 |
$y_{1} = x_{1}$ & $1$ & $1.32$ & $0.0$ \\
|
1041 |
$y_{1} = \cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
|
1042 |
-
$y_{1} = x_{0}
|
1043 |
"""
|
1044 |
true_latex_table_str = "\n\n".join(
|
1045 |
self.create_true_latex(part, include_score=True)
|
@@ -1092,7 +1092,7 @@ class TestLaTeXTable(unittest.TestCase):
|
|
1092 |
middle_part = r"""
|
1093 |
$y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
|
1094 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
1095 |
-
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}
|
1096 |
"""
|
1097 |
true_latex_table_str = (
|
1098 |
TRUE_PREAMBLE
|
|
|
674 |
pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
|
675 |
|
676 |
y_predictions2 = model2.predict(X)
|
677 |
+
np.testing.assert_array_almost_equal(y_predictions, y_predictions2)
|
678 |
|
679 |
def test_scikit_learn_compatibility(self):
|
680 |
"""Test PySRRegressor compatibility with scikit-learn."""
|
|
|
1039 |
middle_part_2 = r"""
|
1040 |
$y_{1} = x_{1}$ & $1$ & $1.32$ & $0.0$ \\
|
1041 |
$y_{1} = \cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
|
1042 |
+
$y_{1} = x_{0} x_{0} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
|
1043 |
"""
|
1044 |
true_latex_table_str = "\n\n".join(
|
1045 |
self.create_true_latex(part, include_score=True)
|
|
|
1092 |
middle_part = r"""
|
1093 |
$y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
|
1094 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
1095 |
+
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0} x_{0} x_{0} + x_{0} x_{0} x_{0} x_{0} x_{0} + 3.20 x_{0} - 1.20 x_{1} + x_{1} x_{1} x_{1} + 5.20 \sin{\left(- 2.60 x_{0} + 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0} x_{0} x_{0} + 3.20 x_{0} - 1.20 x_{1} + x_{1} x_{1} x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
|
1096 |
"""
|
1097 |
true_latex_table_str = (
|
1098 |
TRUE_PREAMBLE
|
pysr/test/test_jax.py
CHANGED
@@ -5,27 +5,29 @@ import numpy as np
|
|
5 |
import pandas as pd
|
6 |
import sympy
|
7 |
|
|
|
8 |
from pysr import PySRRegressor, sympy2jax
|
9 |
|
10 |
|
11 |
class TestJAX(unittest.TestCase):
|
12 |
def setUp(self):
|
13 |
np.random.seed(0)
|
|
|
|
|
|
|
14 |
|
15 |
def test_sympy2jax(self):
|
16 |
-
from jax import numpy as jnp
|
17 |
from jax import random
|
18 |
|
19 |
x, y, z = sympy.symbols("x y z")
|
20 |
cosx = 1.0 * sympy.cos(x) + y
|
21 |
key = random.PRNGKey(0)
|
22 |
X = random.normal(key, (1000, 2))
|
23 |
-
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
|
24 |
f, params = sympy2jax(cosx, [x, y, z])
|
25 |
-
self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
|
26 |
|
27 |
def test_pipeline_pandas(self):
|
28 |
-
from jax import numpy as jnp
|
29 |
|
30 |
X = pd.DataFrame(np.random.randn(100, 10))
|
31 |
y = np.ones(X.shape[0])
|
@@ -52,14 +54,12 @@ class TestJAX(unittest.TestCase):
|
|
52 |
jformat = model.jax()
|
53 |
|
54 |
np.testing.assert_almost_equal(
|
55 |
-
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
56 |
np.square(np.cos(X.values[:, 1])), # Select feature 1
|
57 |
decimal=3,
|
58 |
)
|
59 |
|
60 |
def test_pipeline(self):
|
61 |
-
from jax import numpy as jnp
|
62 |
-
|
63 |
X = np.random.randn(100, 10)
|
64 |
y = np.ones(X.shape[0])
|
65 |
model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
|
@@ -81,11 +81,39 @@ class TestJAX(unittest.TestCase):
|
|
81 |
jformat = model.jax()
|
82 |
|
83 |
np.testing.assert_almost_equal(
|
84 |
-
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
85 |
np.square(np.cos(X[:, 1])), # Select feature 1
|
86 |
decimal=3,
|
87 |
)
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
def test_feature_selection_custom_operators(self):
|
90 |
rstate = np.random.RandomState(0)
|
91 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
|
|
5 |
import pandas as pd
|
6 |
import sympy
|
7 |
|
8 |
+
import pysr
|
9 |
from pysr import PySRRegressor, sympy2jax
|
10 |
|
11 |
|
12 |
class TestJAX(unittest.TestCase):
|
13 |
def setUp(self):
|
14 |
np.random.seed(0)
|
15 |
+
from jax import numpy as jnp
|
16 |
+
|
17 |
+
self.jnp = jnp
|
18 |
|
19 |
def test_sympy2jax(self):
|
|
|
20 |
from jax import random
|
21 |
|
22 |
x, y, z = sympy.symbols("x y z")
|
23 |
cosx = 1.0 * sympy.cos(x) + y
|
24 |
key = random.PRNGKey(0)
|
25 |
X = random.normal(key, (1000, 2))
|
26 |
+
true = 1.0 * self.jnp.cos(X[:, 0]) + X[:, 1]
|
27 |
f, params = sympy2jax(cosx, [x, y, z])
|
28 |
+
self.assertTrue(self.jnp.all(self.jnp.isclose(f(X, params), true)).item())
|
29 |
|
30 |
def test_pipeline_pandas(self):
|
|
|
31 |
|
32 |
X = pd.DataFrame(np.random.randn(100, 10))
|
33 |
y = np.ones(X.shape[0])
|
|
|
54 |
jformat = model.jax()
|
55 |
|
56 |
np.testing.assert_almost_equal(
|
57 |
+
np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
|
58 |
np.square(np.cos(X.values[:, 1])), # Select feature 1
|
59 |
decimal=3,
|
60 |
)
|
61 |
|
62 |
def test_pipeline(self):
|
|
|
|
|
63 |
X = np.random.randn(100, 10)
|
64 |
y = np.ones(X.shape[0])
|
65 |
model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
|
|
|
81 |
jformat = model.jax()
|
82 |
|
83 |
np.testing.assert_almost_equal(
|
84 |
+
np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
|
85 |
np.square(np.cos(X[:, 1])), # Select feature 1
|
86 |
decimal=3,
|
87 |
)
|
88 |
|
89 |
+
def test_avoid_simplification(self):
|
90 |
+
ex = pysr.export_sympy.pysr2sympy(
|
91 |
+
"square(exp(sign(0.44796443))) + 1.5 * x1",
|
92 |
+
feature_names_in=["x1"],
|
93 |
+
extra_sympy_mappings={"square": lambda x: x**2},
|
94 |
+
)
|
95 |
+
f, params = pysr.export_jax.sympy2jax(ex, [sympy.symbols("x1")])
|
96 |
+
key = np.random.RandomState(0)
|
97 |
+
X = key.randn(10, 1)
|
98 |
+
np.testing.assert_almost_equal(
|
99 |
+
np.array(f(self.jnp.array(X), params)),
|
100 |
+
np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
|
101 |
+
decimal=3,
|
102 |
+
)
|
103 |
+
|
104 |
+
def test_issue_656(self):
|
105 |
+
import sympy
|
106 |
+
|
107 |
+
E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
|
108 |
+
f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
|
109 |
+
key = np.random.RandomState(0)
|
110 |
+
X = key.randn(10, 1)
|
111 |
+
np.testing.assert_almost_equal(
|
112 |
+
np.array(f(self.jnp.array(X), params)),
|
113 |
+
np.exp(1) + X[:, 0],
|
114 |
+
decimal=3,
|
115 |
+
)
|
116 |
+
|
117 |
def test_feature_selection_custom_operators(self):
|
118 |
rstate = np.random.RandomState(0)
|
119 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
pysr/test/test_torch.py
CHANGED
@@ -4,6 +4,7 @@ import numpy as np
|
|
4 |
import pandas as pd
|
5 |
import sympy
|
6 |
|
|
|
7 |
from pysr import PySRRegressor, sympy2torch
|
8 |
|
9 |
|
@@ -153,10 +154,43 @@ class TestTorch(unittest.TestCase):
|
|
153 |
decimal=3,
|
154 |
)
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
def test_feature_selection_custom_operators(self):
|
157 |
rstate = np.random.RandomState(0)
|
158 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
159 |
-
|
|
|
|
|
|
|
160 |
y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
|
161 |
|
162 |
model = PySRRegressor(
|
|
|
4 |
import pandas as pd
|
5 |
import sympy
|
6 |
|
7 |
+
import pysr
|
8 |
from pysr import PySRRegressor, sympy2torch
|
9 |
|
10 |
|
|
|
154 |
decimal=3,
|
155 |
)
|
156 |
|
157 |
+
def test_avoid_simplification(self):
|
158 |
+
# SymPy should not simplify without permission
|
159 |
+
torch = self.torch
|
160 |
+
ex = pysr.export_sympy.pysr2sympy(
|
161 |
+
"square(exp(sign(0.44796443))) + 1.5 * x1",
|
162 |
+
# ^ Normally this would become exp1 and require
|
163 |
+
# its own mapping
|
164 |
+
feature_names_in=["x1"],
|
165 |
+
extra_sympy_mappings={"square": lambda x: x**2},
|
166 |
+
)
|
167 |
+
m = pysr.export_torch.sympy2torch(ex, ["x1"])
|
168 |
+
rng = np.random.RandomState(0)
|
169 |
+
X = rng.randn(10, 1)
|
170 |
+
np.testing.assert_almost_equal(
|
171 |
+
m(torch.tensor(X)).detach().numpy(),
|
172 |
+
np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
|
173 |
+
decimal=3,
|
174 |
+
)
|
175 |
+
|
176 |
+
def test_issue_656(self):
|
177 |
+
# Should correctly map numeric symbols to floats
|
178 |
+
E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
|
179 |
+
m = pysr.export_torch.sympy2torch(E_plus_x1, ["x1"])
|
180 |
+
X = np.random.randn(10, 1)
|
181 |
+
np.testing.assert_almost_equal(
|
182 |
+
m(self.torch.tensor(X)).detach().numpy(),
|
183 |
+
np.exp(1) + X[:, 0],
|
184 |
+
decimal=3,
|
185 |
+
)
|
186 |
+
|
187 |
def test_feature_selection_custom_operators(self):
|
188 |
rstate = np.random.RandomState(0)
|
189 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
190 |
+
|
191 |
+
def cos_approx(x):
|
192 |
+
return 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
|
193 |
+
|
194 |
y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
|
195 |
|
196 |
model = PySRRegressor(
|