Spaces:
Running
Running
MilesCranmer
commited on
Commit
·
ec8124e
1
Parent(s):
af14165
Get PySRRegressor working with multi-output
Browse files- pysr/sr.py +74 -30
- test/test.py +5 -9
pysr/sr.py
CHANGED
@@ -665,27 +665,46 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
665 |
if self.equations is None:
|
666 |
return "PySRRegressor.equations = None"
|
667 |
|
|
|
|
|
668 |
equations = self.equations
|
669 |
-
|
670 |
-
|
671 |
-
chosen_row = -1
|
672 |
-
elif self.model_selection == "best":
|
673 |
-
chosen_row = equations["score"].idxmax()
|
674 |
else:
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
685 |
)
|
686 |
-
|
687 |
-
|
688 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
689 |
return output
|
690 |
|
691 |
def set_params(self, **params):
|
@@ -708,13 +727,19 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
708 |
|
709 |
def get_best(self):
|
710 |
if self.equations is None:
|
711 |
-
|
712 |
if self.model_selection == "accuracy":
|
|
|
|
|
713 |
return self.equations.iloc[-1]
|
714 |
elif self.model_selection == "best":
|
715 |
-
|
|
|
|
|
716 |
else:
|
717 |
-
raise NotImplementedError
|
|
|
|
|
718 |
|
719 |
def fit(self, X, y, weights=None, variable_names=None):
|
720 |
"""Search for equations to fit the dataset.
|
@@ -747,26 +772,40 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
747 |
|
748 |
def predict(self, X):
|
749 |
self.refresh()
|
750 |
-
|
751 |
-
|
|
|
|
|
752 |
|
753 |
def sympy(self):
|
754 |
self.refresh()
|
755 |
-
|
|
|
|
|
|
|
756 |
|
757 |
def latex(self):
|
758 |
self.refresh()
|
759 |
-
|
|
|
|
|
|
|
760 |
|
761 |
def jax(self):
|
762 |
self.set_params(output_jax_format=True)
|
763 |
self.refresh()
|
764 |
-
|
765 |
-
|
|
|
|
|
|
|
766 |
def pytorch(self):
|
767 |
self.set_params(output_torch_format=True)
|
768 |
self.refresh()
|
769 |
-
|
|
|
|
|
|
|
770 |
|
771 |
def _run(self, X, y, weights, variable_names):
|
772 |
global already_ran
|
@@ -846,11 +885,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
846 |
|
847 |
if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
|
848 |
self.multioutput = False
|
849 |
-
nout = 1
|
850 |
y = y.reshape(-1)
|
851 |
elif len(y.shape) == 2:
|
852 |
self.multioutput = True
|
853 |
-
nout = y.shape[1]
|
854 |
else:
|
855 |
raise NotImplementedError("y shape not supported!")
|
856 |
|
@@ -1182,3 +1221,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
1182 |
if self.multioutput:
|
1183 |
return ret_outputs
|
1184 |
return ret_outputs[0]
|
|
|
|
|
|
|
|
|
|
|
|
665 |
if self.equations is None:
|
666 |
return "PySRRegressor.equations = None"
|
667 |
|
668 |
+
output = "PySRRegressor.equations = [\n"
|
669 |
+
|
670 |
equations = self.equations
|
671 |
+
if not isinstance(equations, list):
|
672 |
+
all_equations = [equations]
|
|
|
|
|
|
|
673 |
else:
|
674 |
+
all_equations = equations
|
675 |
+
|
676 |
+
for i, equations in enumerate(all_equations):
|
677 |
+
selected = ["" for _ in range(len(equations))]
|
678 |
+
if self.model_selection == "accuracy":
|
679 |
+
chosen_row = -1
|
680 |
+
elif self.model_selection == "best":
|
681 |
+
chosen_row = equations["score"].idxmax()
|
682 |
+
else:
|
683 |
+
raise NotImplementedError
|
684 |
+
selected[chosen_row] = ">>>>"
|
685 |
+
repr_equations = pd.DataFrame(
|
686 |
+
dict(
|
687 |
+
pick=selected,
|
688 |
+
score=equations["score"],
|
689 |
+
equation=equations["equation"],
|
690 |
+
loss=equations["loss"],
|
691 |
+
complexity=equations["complexity"],
|
692 |
+
)
|
693 |
)
|
694 |
+
|
695 |
+
if len(all_equations) > 1:
|
696 |
+
output += "[\n"
|
697 |
+
|
698 |
+
for line in repr_equations.__repr__().split("\n"):
|
699 |
+
output += "\t" + line + "\n"
|
700 |
+
|
701 |
+
if len(all_equations) > 1:
|
702 |
+
output += "]"
|
703 |
+
|
704 |
+
if i < len(all_equations) - 1:
|
705 |
+
output += ", "
|
706 |
+
|
707 |
+
output += "]"
|
708 |
return output
|
709 |
|
710 |
def set_params(self, **params):
|
|
|
727 |
|
728 |
def get_best(self):
|
729 |
if self.equations is None:
|
730 |
+
raise ValueError("No equations have been generated yet.")
|
731 |
if self.model_selection == "accuracy":
|
732 |
+
if isinstance(self.equations, list):
|
733 |
+
return [eq.iloc[-1] for eq in self.equations]
|
734 |
return self.equations.iloc[-1]
|
735 |
elif self.model_selection == "best":
|
736 |
+
if isinstance(self.equations, list):
|
737 |
+
return [eq.iloc[eq["score"].idxmax()] for eq in self.equations]
|
738 |
+
return self.equations.iloc[self.equations["score"].idxmax()]
|
739 |
else:
|
740 |
+
raise NotImplementedError(
|
741 |
+
f"{self.model_selection} is not a valid model selection strategy."
|
742 |
+
)
|
743 |
|
744 |
def fit(self, X, y, weights=None, variable_names=None):
|
745 |
"""Search for equations to fit the dataset.
|
|
|
772 |
|
773 |
def predict(self, X):
|
774 |
self.refresh()
|
775 |
+
best = self.get_best()
|
776 |
+
if self.multioutput:
|
777 |
+
return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
|
778 |
+
return best["lambda_format"](X)
|
779 |
|
780 |
def sympy(self):
|
781 |
self.refresh()
|
782 |
+
best = self.get_best()
|
783 |
+
if self.multioutput:
|
784 |
+
return [eq["sympy_format"] for eq in best]
|
785 |
+
return best["sympy_format"]
|
786 |
|
787 |
def latex(self):
|
788 |
self.refresh()
|
789 |
+
sympy_representation = self.sympy()
|
790 |
+
if self.multioutput:
|
791 |
+
return [sympy.latex(s) for s in sympy_representation]
|
792 |
+
return sympy.latex(sympy_representation)
|
793 |
|
794 |
def jax(self):
|
795 |
self.set_params(output_jax_format=True)
|
796 |
self.refresh()
|
797 |
+
best = self.get_best()
|
798 |
+
if self.multioutput:
|
799 |
+
return [eq["jax_format"] for eq in best]
|
800 |
+
return best["jax_format"]
|
801 |
+
|
802 |
def pytorch(self):
|
803 |
self.set_params(output_torch_format=True)
|
804 |
self.refresh()
|
805 |
+
best = self.get_best()
|
806 |
+
if self.multioutput:
|
807 |
+
return [eq["torch_format"] for eq in best]
|
808 |
+
return best["torch_format"]
|
809 |
|
810 |
def _run(self, X, y, weights, variable_names):
|
811 |
global already_ran
|
|
|
885 |
|
886 |
if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
|
887 |
self.multioutput = False
|
888 |
+
self.nout = 1
|
889 |
y = y.reshape(-1)
|
890 |
elif len(y.shape) == 2:
|
891 |
self.multioutput = True
|
892 |
+
self.nout = y.shape[1]
|
893 |
else:
|
894 |
raise NotImplementedError("y shape not supported!")
|
895 |
|
|
|
1221 |
if self.multioutput:
|
1222 |
return ret_outputs
|
1223 |
return ret_outputs[0]
|
1224 |
+
|
1225 |
+
def score(self, X, y):
|
1226 |
+
del X
|
1227 |
+
del y
|
1228 |
+
raise NotImplementedError
|
test/test.py
CHANGED
@@ -171,13 +171,13 @@ class TestBest(unittest.TestCase):
|
|
171 |
def setUp(self):
|
172 |
equations = pd.DataFrame(
|
173 |
{
|
174 |
-
"
|
175 |
-
"
|
176 |
-
"
|
177 |
}
|
178 |
)
|
179 |
|
180 |
-
equations["
|
181 |
"equation_file.csv.bkup", sep="|"
|
182 |
)
|
183 |
|
@@ -195,19 +195,15 @@ class TestBest(unittest.TestCase):
|
|
195 |
self.model.equations = self.equations
|
196 |
|
197 |
def test_best(self):
|
198 |
-
self.assertEqual(best(self.equations), sympy.cos(sympy.Symbol("x0")) ** 2)
|
199 |
-
self.assertEqual(best(), sympy.cos(sympy.Symbol("x0")) ** 2)
|
200 |
self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
|
201 |
|
202 |
def test_best_tex(self):
|
203 |
-
self.assertEqual(best_tex(self.equations), "\\cos^{2}{\\left(x_{0} \\right)}")
|
204 |
-
self.assertEqual(best_tex(), "\\cos^{2}{\\left(x_{0} \\right)}")
|
205 |
self.assertEqual(self.model.latex(), "\\cos^{2}{\\left(x_{0} \\right)}")
|
206 |
|
207 |
def test_best_lambda(self):
|
208 |
X = np.random.randn(10, 2)
|
209 |
y = np.cos(X[:, 0]) ** 2
|
210 |
-
for f in [
|
211 |
np.testing.assert_almost_equal(f(X), y, decimal=4)
|
212 |
|
213 |
|
|
|
171 |
def setUp(self):
|
172 |
equations = pd.DataFrame(
|
173 |
{
|
174 |
+
"equation": ["1.0", "cos(x0)", "square(cos(x0))"],
|
175 |
+
"loss": [1.0, 0.1, 1e-5],
|
176 |
+
"complexity": [1, 2, 3],
|
177 |
}
|
178 |
)
|
179 |
|
180 |
+
equations["complexity loss equation".split(" ")].to_csv(
|
181 |
"equation_file.csv.bkup", sep="|"
|
182 |
)
|
183 |
|
|
|
195 |
self.model.equations = self.equations
|
196 |
|
197 |
def test_best(self):
|
|
|
|
|
198 |
self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
|
199 |
|
200 |
def test_best_tex(self):
|
|
|
|
|
201 |
self.assertEqual(self.model.latex(), "\\cos^{2}{\\left(x_{0} \\right)}")
|
202 |
|
203 |
def test_best_lambda(self):
|
204 |
X = np.random.randn(10, 2)
|
205 |
y = np.cos(X[:, 0]) ** 2
|
206 |
+
for f in [self.model.predict, self.equations.iloc[-1]['lambda_format']]:
|
207 |
np.testing.assert_almost_equal(f(X), y, decimal=4)
|
208 |
|
209 |
|