Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
e0c68fc
1
Parent(s):
a29e818
Propagate and check torch/jax mappings
Browse files- pysr/export_jax.py +14 -5
- pysr/sr.py +26 -2
pysr/export_jax.py
CHANGED
@@ -51,7 +51,7 @@ _jnp_func_lookup = {
|
|
51 |
}
|
52 |
|
53 |
|
54 |
-
def sympy2jaxtext(expr, parameters, symbols_in):
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
@@ -61,8 +61,15 @@ def sympy2jaxtext(expr, parameters, symbols_in):
|
|
61 |
return (
|
62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
63 |
)
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
if _func == MUL:
|
67 |
return " * ".join(["(" + arg + ")" for arg in args])
|
68 |
if _func == ADD:
|
@@ -92,7 +99,7 @@ def _initialize_jax():
|
|
92 |
jsp = _jsp
|
93 |
|
94 |
|
95 |
-
def sympy2jax(expression, symbols_in, selection=None):
|
96 |
"""Returns a function f and its parameters;
|
97 |
the function takes an input matrix, and a list of arguments:
|
98 |
f(X, parameters)
|
@@ -170,7 +177,9 @@ def sympy2jax(expression, symbols_in, selection=None):
|
|
170 |
global jsp
|
171 |
|
172 |
parameters = []
|
173 |
-
functional_form_text = sympy2jaxtext(
|
|
|
|
|
174 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
175 |
text = f"def {hash_string}(X, parameters):\n"
|
176 |
if selection is not None:
|
|
|
51 |
}
|
52 |
|
53 |
|
54 |
+
def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
|
|
61 |
return (
|
62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
63 |
)
|
64 |
+
if extra_jax_mappings is None:
|
65 |
+
extra_jax_mappings = {}
|
66 |
+
_func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
|
67 |
+
args = [
|
68 |
+
sympy2jaxtext(
|
69 |
+
arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
|
70 |
+
)
|
71 |
+
for arg in expr.args
|
72 |
+
]
|
73 |
if _func == MUL:
|
74 |
return " * ".join(["(" + arg + ")" for arg in args])
|
75 |
if _func == ADD:
|
|
|
99 |
jsp = _jsp
|
100 |
|
101 |
|
102 |
+
def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
|
103 |
"""Returns a function f and its parameters;
|
104 |
the function takes an input matrix, and a list of arguments:
|
105 |
f(X, parameters)
|
|
|
177 |
global jsp
|
178 |
|
179 |
parameters = []
|
180 |
+
functional_form_text = sympy2jaxtext(
|
181 |
+
expression, parameters, symbols_in, extra_jax_mappings
|
182 |
+
)
|
183 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
184 |
text = f"def {hash_string}(X, parameters):\n"
|
185 |
if selection is not None:
|
pysr/sr.py
CHANGED
@@ -289,6 +289,20 @@ def pysr(
|
|
289 |
if len(variable_names) == 0:
|
290 |
variable_names = [f"x{i}" for i in range(X.shape[1])]
|
291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
use_custom_variable_names = len(variable_names) != 0
|
293 |
|
294 |
_check_assertions(
|
@@ -996,14 +1010,24 @@ def get_hof(
|
|
996 |
if output_jax_format:
|
997 |
from .export_jax import sympy2jax
|
998 |
|
999 |
-
func, params = sympy2jax(
|
|
|
|
|
|
|
|
|
|
|
1000 |
jax_format.append({"callable": func, "parameters": params})
|
1001 |
|
1002 |
# Torch:
|
1003 |
if output_torch_format:
|
1004 |
from .export_torch import sympy2torch
|
1005 |
|
1006 |
-
module = sympy2torch(
|
|
|
|
|
|
|
|
|
|
|
1007 |
torch_format.append(module)
|
1008 |
|
1009 |
curMSE = output.loc[i, "MSE"]
|
|
|
289 |
if len(variable_names) == 0:
|
290 |
variable_names = [f"x{i}" for i in range(X.shape[1])]
|
291 |
|
292 |
+
if extra_jax_mappings is not None:
|
293 |
+
for key, value in extra_jax_mappings:
|
294 |
+
if not isinstance(value, str):
|
295 |
+
raise NotImplementedError(
|
296 |
+
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
297 |
+
)
|
298 |
+
|
299 |
+
if extra_torch_mappings is not None:
|
300 |
+
for key, value in extra_jax_mappings:
|
301 |
+
if not callable(value):
|
302 |
+
raise NotImplementedError(
|
303 |
+
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
304 |
+
)
|
305 |
+
|
306 |
use_custom_variable_names = len(variable_names) != 0
|
307 |
|
308 |
_check_assertions(
|
|
|
1010 |
if output_jax_format:
|
1011 |
from .export_jax import sympy2jax
|
1012 |
|
1013 |
+
func, params = sympy2jax(
|
1014 |
+
eqn,
|
1015 |
+
sympy_symbols,
|
1016 |
+
selection=selection,
|
1017 |
+
extra_jax_mappings=extra_jax_mappings,
|
1018 |
+
)
|
1019 |
jax_format.append({"callable": func, "parameters": params})
|
1020 |
|
1021 |
# Torch:
|
1022 |
if output_torch_format:
|
1023 |
from .export_torch import sympy2torch
|
1024 |
|
1025 |
+
module = sympy2torch(
|
1026 |
+
eqn,
|
1027 |
+
sympy_symbols,
|
1028 |
+
selection=selection,
|
1029 |
+
extra_torch_mappings=extra_torch_mappings,
|
1030 |
+
)
|
1031 |
torch_format.append(module)
|
1032 |
|
1033 |
curMSE = output.loc[i, "MSE"]
|