MilesCranmer commited on
Commit
aef1f27
1 Parent(s): e63cf2d

Add warning if training on pandas dataframe then torch

Browse files
Files changed (1) hide show
  1. pysr/sr.py +15 -0
pysr/sr.py CHANGED
@@ -796,6 +796,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
796
  return sympy.latex(sympy_representation)
797
 
798
  def jax(self):
 
 
 
 
 
 
799
  self.set_params(output_jax_format=True)
800
  self.refresh()
801
  best = self.get_best()
@@ -804,6 +810,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
804
  return best["jax_format"]
805
 
806
  def pytorch(self):
 
 
 
 
 
 
807
  self.set_params(output_torch_format=True)
808
  self.refresh()
809
  best = self.get_best()
@@ -854,6 +866,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
854
 
855
  variable_names = list(X.columns)
856
  X = np.array(X)
 
 
 
857
 
858
  if len(X.shape) == 1:
859
  X = X[:, None]
 
796
  return sympy.latex(sympy_representation)
797
 
798
  def jax(self):
799
+ if self.using_pandas:
800
+ warnings.warn(
801
+ "PySR's JAX modules are not set up to work with a "
802
+ "model that was trained on pandas dataframes. "
803
+ "Train on an array instead to ensure everything works as planned."
804
+ )
805
  self.set_params(output_jax_format=True)
806
  self.refresh()
807
  best = self.get_best()
 
810
  return best["jax_format"]
811
 
812
  def pytorch(self):
813
+ if self.using_pandas:
814
+ warnings.warn(
815
+ "PySR's PyTorch modules are not set up to work with a "
816
+ "model that was trained on pandas dataframes. "
817
+ "Train on an array instead to ensure everything works as planned."
818
+ )
819
  self.set_params(output_torch_format=True)
820
  self.refresh()
821
  best = self.get_best()
 
866
 
867
  variable_names = list(X.columns)
868
  X = np.array(X)
869
+ self.using_pandas = True
870
+ else:
871
+ self.using_pandas = False
872
 
873
  if len(X.shape) == 1:
874
  X = X[:, None]