Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
aef1f27
1
Parent(s):
e63cf2d
Add warning if training on pandas dataframe then torch
Browse files- pysr/sr.py +15 -0
pysr/sr.py
CHANGED
@@ -796,6 +796,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
796 |
return sympy.latex(sympy_representation)
|
797 |
|
798 |
def jax(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
799 |
self.set_params(output_jax_format=True)
|
800 |
self.refresh()
|
801 |
best = self.get_best()
|
@@ -804,6 +810,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
804 |
return best["jax_format"]
|
805 |
|
806 |
def pytorch(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
807 |
self.set_params(output_torch_format=True)
|
808 |
self.refresh()
|
809 |
best = self.get_best()
|
@@ -854,6 +866,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
854 |
|
855 |
variable_names = list(X.columns)
|
856 |
X = np.array(X)
|
|
|
|
|
|
|
857 |
|
858 |
if len(X.shape) == 1:
|
859 |
X = X[:, None]
|
|
|
796 |
return sympy.latex(sympy_representation)
|
797 |
|
798 |
def jax(self):
|
799 |
+
if self.using_pandas:
|
800 |
+
warnings.warn(
|
801 |
+
"PySR's JAX modules are not set up to work with a "
|
802 |
+
"model that was trained on pandas dataframes. "
|
803 |
+
"Train on an array instead to ensure everything works as planned."
|
804 |
+
)
|
805 |
self.set_params(output_jax_format=True)
|
806 |
self.refresh()
|
807 |
best = self.get_best()
|
|
|
810 |
return best["jax_format"]
|
811 |
|
812 |
def pytorch(self):
|
813 |
+
if self.using_pandas:
|
814 |
+
warnings.warn(
|
815 |
+
"PySR's PyTorch modules are not set up to work with a "
|
816 |
+
"model that was trained on pandas dataframes. "
|
817 |
+
"Train on an array instead to ensure everything works as planned."
|
818 |
+
)
|
819 |
self.set_params(output_torch_format=True)
|
820 |
self.refresh()
|
821 |
best = self.get_best()
|
|
|
866 |
|
867 |
variable_names = list(X.columns)
|
868 |
X = np.array(X)
|
869 |
+
self.using_pandas = True
|
870 |
+
else:
|
871 |
+
self.using_pandas = False
|
872 |
|
873 |
if len(X.shape) == 1:
|
874 |
X = X[:, None]
|