MilesCranmer commited on
Commit
ec8124e
·
1 Parent(s): af14165

Get PySRRegressor working with multi-output

Browse files
Files changed (2) hide show
  1. pysr/sr.py +74 -30
  2. test/test.py +5 -9
pysr/sr.py CHANGED
@@ -665,27 +665,46 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
665
  if self.equations is None:
666
  return "PySRRegressor.equations = None"
667
 
 
 
668
  equations = self.equations
669
- selected = ["" for _ in range(len(equations))]
670
- if self.model_selection == "accuracy":
671
- chosen_row = -1
672
- elif self.model_selection == "best":
673
- chosen_row = equations["score"].idxmax()
674
  else:
675
- raise NotImplementedError
676
- selected[chosen_row] = ">>>>"
677
- output = "PySRRegressor.equations = [\n"
678
- repr_equations = pd.DataFrame(
679
- dict(
680
- pick=selected,
681
- score=equations["score"],
682
- equation=equations["equation"],
683
- loss=equations["loss"],
684
- complexity=equations["complexity"],
 
 
 
 
 
 
 
 
 
685
  )
686
- )
687
- output += repr_equations.__repr__()
688
- output += "\n]"
 
 
 
 
 
 
 
 
 
 
 
689
  return output
690
 
691
  def set_params(self, **params):
@@ -708,13 +727,19 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
708
 
709
  def get_best(self):
710
  if self.equations is None:
711
- return 0.0
712
  if self.model_selection == "accuracy":
 
 
713
  return self.equations.iloc[-1]
714
  elif self.model_selection == "best":
715
- return best_row(self.equations)
 
 
716
  else:
717
- raise NotImplementedError
 
 
718
 
719
  def fit(self, X, y, weights=None, variable_names=None):
720
  """Search for equations to fit the dataset.
@@ -747,26 +772,40 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
747
 
748
  def predict(self, X):
749
  self.refresh()
750
- np_format = self.get_best()["lambda_format"]
751
- return np_format(X)
 
 
752
 
753
  def sympy(self):
754
  self.refresh()
755
- return self.get_best()["sympy_format"]
 
 
 
756
 
757
  def latex(self):
758
  self.refresh()
759
- return self.sympy().simplify()
 
 
 
760
 
761
  def jax(self):
762
  self.set_params(output_jax_format=True)
763
  self.refresh()
764
- return self.get_best()["jax_format"]
765
-
 
 
 
766
  def pytorch(self):
767
  self.set_params(output_torch_format=True)
768
  self.refresh()
769
- return self.get_best()["torch_format"]
 
 
 
770
 
771
  def _run(self, X, y, weights, variable_names):
772
  global already_ran
@@ -846,11 +885,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
846
 
847
  if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
848
  self.multioutput = False
849
- nout = 1
850
  y = y.reshape(-1)
851
  elif len(y.shape) == 2:
852
  self.multioutput = True
853
- nout = y.shape[1]
854
  else:
855
  raise NotImplementedError("y shape not supported!")
856
 
@@ -1182,3 +1221,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
1182
  if self.multioutput:
1183
  return ret_outputs
1184
  return ret_outputs[0]
 
 
 
 
 
 
665
  if self.equations is None:
666
  return "PySRRegressor.equations = None"
667
 
668
+ output = "PySRRegressor.equations = [\n"
669
+
670
  equations = self.equations
671
+ if not isinstance(equations, list):
672
+ all_equations = [equations]
 
 
 
673
  else:
674
+ all_equations = equations
675
+
676
+ for i, equations in enumerate(all_equations):
677
+ selected = ["" for _ in range(len(equations))]
678
+ if self.model_selection == "accuracy":
679
+ chosen_row = -1
680
+ elif self.model_selection == "best":
681
+ chosen_row = equations["score"].idxmax()
682
+ else:
683
+ raise NotImplementedError
684
+ selected[chosen_row] = ">>>>"
685
+ repr_equations = pd.DataFrame(
686
+ dict(
687
+ pick=selected,
688
+ score=equations["score"],
689
+ equation=equations["equation"],
690
+ loss=equations["loss"],
691
+ complexity=equations["complexity"],
692
+ )
693
  )
694
+
695
+ if len(all_equations) > 1:
696
+ output += "[\n"
697
+
698
+ for line in repr_equations.__repr__().split("\n"):
699
+ output += "\t" + line + "\n"
700
+
701
+ if len(all_equations) > 1:
702
+ output += "]"
703
+
704
+ if i < len(all_equations) - 1:
705
+ output += ", "
706
+
707
+ output += "]"
708
  return output
709
 
710
  def set_params(self, **params):
 
727
 
728
  def get_best(self):
729
  if self.equations is None:
730
+ raise ValueError("No equations have been generated yet.")
731
  if self.model_selection == "accuracy":
732
+ if isinstance(self.equations, list):
733
+ return [eq.iloc[-1] for eq in self.equations]
734
  return self.equations.iloc[-1]
735
  elif self.model_selection == "best":
736
+ if isinstance(self.equations, list):
737
+ return [eq.iloc[eq["score"].idxmax()] for eq in self.equations]
738
+ return self.equations.iloc[self.equations["score"].idxmax()]
739
  else:
740
+ raise NotImplementedError(
741
+ f"{self.model_selection} is not a valid model selection strategy."
742
+ )
743
 
744
  def fit(self, X, y, weights=None, variable_names=None):
745
  """Search for equations to fit the dataset.
 
772
 
773
  def predict(self, X):
774
  self.refresh()
775
+ best = self.get_best()
776
+ if self.multioutput:
777
+ return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
778
+ return best["lambda_format"](X)
779
 
780
  def sympy(self):
781
  self.refresh()
782
+ best = self.get_best()
783
+ if self.multioutput:
784
+ return [eq["sympy_format"] for eq in best]
785
+ return best["sympy_format"]
786
 
787
  def latex(self):
788
  self.refresh()
789
+ sympy_representation = self.sympy()
790
+ if self.multioutput:
791
+ return [sympy.latex(s) for s in sympy_representation]
792
+ return sympy.latex(sympy_representation)
793
 
794
  def jax(self):
795
  self.set_params(output_jax_format=True)
796
  self.refresh()
797
+ best = self.get_best()
798
+ if self.multioutput:
799
+ return [eq["jax_format"] for eq in best]
800
+ return best["jax_format"]
801
+
802
  def pytorch(self):
803
  self.set_params(output_torch_format=True)
804
  self.refresh()
805
+ best = self.get_best()
806
+ if self.multioutput:
807
+ return [eq["torch_format"] for eq in best]
808
+ return best["torch_format"]
809
 
810
  def _run(self, X, y, weights, variable_names):
811
  global already_ran
 
885
 
886
  if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
887
  self.multioutput = False
888
+ self.nout = 1
889
  y = y.reshape(-1)
890
  elif len(y.shape) == 2:
891
  self.multioutput = True
892
+ self.nout = y.shape[1]
893
  else:
894
  raise NotImplementedError("y shape not supported!")
895
 
 
1221
  if self.multioutput:
1222
  return ret_outputs
1223
  return ret_outputs[0]
1224
+
1225
+ def score(self, X, y):
1226
+ del X
1227
+ del y
1228
+ raise NotImplementedError
test/test.py CHANGED
@@ -171,13 +171,13 @@ class TestBest(unittest.TestCase):
171
  def setUp(self):
172
  equations = pd.DataFrame(
173
  {
174
- "Equation": ["1.0", "cos(x0)", "square(cos(x0))"],
175
- "MSE": [1.0, 0.1, 1e-5],
176
- "Complexity": [1, 2, 3],
177
  }
178
  )
179
 
180
- equations["Complexity MSE Equation".split(" ")].to_csv(
181
  "equation_file.csv.bkup", sep="|"
182
  )
183
 
@@ -195,19 +195,15 @@ class TestBest(unittest.TestCase):
195
  self.model.equations = self.equations
196
 
197
  def test_best(self):
198
- self.assertEqual(best(self.equations), sympy.cos(sympy.Symbol("x0")) ** 2)
199
- self.assertEqual(best(), sympy.cos(sympy.Symbol("x0")) ** 2)
200
  self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
201
 
202
  def test_best_tex(self):
203
- self.assertEqual(best_tex(self.equations), "\\cos^{2}{\\left(x_{0} \\right)}")
204
- self.assertEqual(best_tex(), "\\cos^{2}{\\left(x_{0} \\right)}")
205
  self.assertEqual(self.model.latex(), "\\cos^{2}{\\left(x_{0} \\right)}")
206
 
207
  def test_best_lambda(self):
208
  X = np.random.randn(10, 2)
209
  y = np.cos(X[:, 0]) ** 2
210
- for f in [best_callable(), best_callable(self.equations)]:
211
  np.testing.assert_almost_equal(f(X), y, decimal=4)
212
 
213
 
 
171
  def setUp(self):
172
  equations = pd.DataFrame(
173
  {
174
+ "equation": ["1.0", "cos(x0)", "square(cos(x0))"],
175
+ "loss": [1.0, 0.1, 1e-5],
176
+ "complexity": [1, 2, 3],
177
  }
178
  )
179
 
180
+ equations["complexity loss equation".split(" ")].to_csv(
181
  "equation_file.csv.bkup", sep="|"
182
  )
183
 
 
195
  self.model.equations = self.equations
196
 
197
  def test_best(self):
 
 
198
  self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
199
 
200
  def test_best_tex(self):
 
 
201
  self.assertEqual(self.model.latex(), "\\cos^{2}{\\left(x_{0} \\right)}")
202
 
203
  def test_best_lambda(self):
204
  X = np.random.randn(10, 2)
205
  y = np.cos(X[:, 0]) ** 2
206
+ for f in [self.model.predict, self.equations.iloc[-1]['lambda_format']]:
207
  np.testing.assert_almost_equal(f(X), y, decimal=4)
208
 
209