MilesCranmer commited on
Commit
42cd6af
1 Parent(s): 4c39e04

Add jax, pytorch, sympy output from Regressor

Browse files
Files changed (2) hide show
  1. docs/options.md +15 -11
  2. pysr/sklearn.py +13 -5
docs/options.md CHANGED
@@ -198,17 +198,18 @@ over `X` (as a PyTorch tensor). This is differentiable, and the
198
  parameters of this PyTorch module correspond to the learned parameters
199
  in the equation, and are trainable.
200
  ```python
201
- output = model.pytorch()
202
- output['callable'](X)
203
  ```
 
204
 
205
- For JAX, you can equivalently set the argument `output_jax_format=True`.
206
  This will return a dictionary containing a `'callable'` (a JAX function),
207
  and `'parameters'` (a list of parameters in the equation).
208
  You can execute this function with:
209
  ```python
210
- output = model.jax()
211
- output['callable'](X, output['parameters'])
212
  ```
213
  Since the parameter list is a jax array, this therefore lets you also
214
  train the parameters within JAX (and is differentiable).
@@ -226,26 +227,29 @@ Here are some additional examples:
226
 
227
  abs(x-y) loss
228
  ```python
229
- pysr(..., loss="f(x, y) = abs(x - y)^1.5")
230
  ```
231
  Note that the function name doesn't matter:
232
  ```python
233
- pysr(..., loss="loss(x, y) = abs(x * y)")
234
  ```
235
  With weights:
236
  ```python
237
- pysr(..., weights=weights, loss="myloss(x, y, w) = w * abs(x - y)")
 
238
  ```
239
  Weights can be used in arbitrary ways:
240
  ```python
241
- pysr(..., weights=weights, loss="myloss(x, y, w) = abs(x - y)^2/w^2")
 
242
  ```
243
  Built-in loss (faster) (see [losses](https://astroautomata.com/SymbolicRegression.jl/dev/losses/)).
244
  This one computes the L3 norm:
245
  ```python
246
- pysr(..., loss="LPDistLoss{3}()")
247
  ```
248
  Can also uses these losses for weighted (weighted-average):
249
  ```python
250
- pysr(..., weights=weights, loss="LPDistLoss{3}()")
 
251
  ```
 
198
  parameters of this PyTorch module correspond to the learned parameters
199
  in the equation, and are trainable.
200
  ```python
201
+ torch_model = model.pytorch()
202
+ torch_model(X)
203
  ```
204
+ **Warning: If you are using custom operators, you must define `extra_torch_mappings` or `extra_jax_mappings` (both are `dict` of callables) to provide an equivalent definition of the functions.** (At any time you can set these parameters or any others with `model.set_params`.)
205
 
206
+ For JAX, you can equivalently call `model.jax()`
207
  This will return a dictionary containing a `'callable'` (a JAX function),
208
  and `'parameters'` (a list of parameters in the equation).
209
  You can execute this function with:
210
  ```python
211
+ jax_model = model.jax()
212
+ jax_model['callable'](X, jax_model['parameters'])
213
  ```
214
  Since the parameter list is a jax array, this therefore lets you also
215
  train the parameters within JAX (and is differentiable).
 
227
 
228
  abs(x-y) loss
229
  ```python
230
+ PySRRegressor(..., loss="f(x, y) = abs(x - y)^1.5")
231
  ```
232
  Note that the function name doesn't matter:
233
  ```python
234
+ PySRRegressor(..., loss="loss(x, y) = abs(x * y)")
235
  ```
236
  With weights:
237
  ```python
238
+ model = PySRRegressor(..., loss="myloss(x, y, w) = w * abs(x - y)")
239
+ model.fit(..., weights=weights)
240
  ```
241
  Weights can be used in arbitrary ways:
242
  ```python
243
+ model = PySRRegressor(..., weights=weights, loss="myloss(x, y, w) = abs(x - y)^2/w^2")
244
+ model.fit(..., weights=weights)
245
  ```
246
  Built-in loss (faster) (see [losses](https://astroautomata.com/SymbolicRegression.jl/dev/losses/)).
247
  This one computes the L3 norm:
248
  ```python
249
+ PySRRegressor(..., loss="LPDistLoss{3}()")
250
  ```
251
  Can also uses these losses for weighted (weighted-average):
252
  ```python
253
+ model = PySRRegressor(..., weights=weights, loss="LPDistLoss{3}()")
254
+ model.fit(..., weights=weights)
255
  ```
pysr/sklearn.py CHANGED
@@ -1,4 +1,4 @@
1
- from pysr import pysr, best_row
2
  from sklearn.base import BaseEstimator, RegressorMixin
3
  import inspect
4
  import pandas as pd
@@ -94,14 +94,22 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
94
  return self
95
 
96
  def predict(self, X):
97
- equation_row = self.get_best()
98
- np_format = equation_row["lambda_format"]
99
-
100
  return np_format(X)
101
 
 
 
102
 
103
- # Add the docs from pysr() to PySRRegressor():
 
 
 
 
 
 
104
 
 
 
105
  _pysr_docstring_split = []
106
  _start_recording = False
107
  for line in inspect.getdoc(pysr).split("\n"):
 
1
+ from pysr import pysr, best_row, get_hof
2
  from sklearn.base import BaseEstimator, RegressorMixin
3
  import inspect
4
  import pandas as pd
 
94
  return self
95
 
96
  def predict(self, X):
97
+ np_format = self.get_best()["lambda_format"]
 
 
98
  return np_format(X)
99
 
100
+ def sympy(self):
101
+ return self.get_best()["sympy_format"]
102
 
103
+ def jax(self):
104
+ self.equations = get_hof(output_jax_format=True)
105
+ return self.get_best()["jax_format"]
106
+
107
+ def pytorch(self):
108
+ self.equations = get_hof(output_torch_format=True)
109
+ return self.get_best()["torch_format"]
110
 
111
+
112
+ # Add the docs from pysr() to PySRRegressor():
113
  _pysr_docstring_split = []
114
  _start_recording = False
115
  for line in inspect.getdoc(pysr).split("\n"):