Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
43bc86a
1
Parent(s):
406ae3e
Process params in __init__ instead of fit
Browse files- pysr/sr.py +6 -13
pysr/sr.py
CHANGED
@@ -754,6 +754,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
754 |
f"{k} is not a valid keyword argument for PySRRegressor"
|
755 |
)
|
756 |
|
|
|
|
|
757 |
def __repr__(self):
|
758 |
"""
|
759 |
Prints all current equations fitted by the model.
|
@@ -858,20 +860,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
858 |
f"{self.model_selection} is not a valid model selection strategy."
|
859 |
)
|
860 |
|
861 |
-
def
|
862 |
"""
|
863 |
Perform validation on the parameters defined in init for the
|
864 |
-
dataset specified in :term`fit
|
865 |
-
|
866 |
-
|
867 |
-
----------
|
868 |
-
n_samples : int
|
869 |
-
Number of samples in the dataset to be fitted.
|
870 |
-
|
871 |
-
Returns
|
872 |
-
-------
|
873 |
-
self : object
|
874 |
-
Reference to `self` with validated parameters.
|
875 |
|
876 |
Raises
|
877 |
------
|
@@ -1406,7 +1400,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1406 |
self.raw_julia_state_ = None
|
1407 |
|
1408 |
# Parameter input validation (for parameters defined in __init__)
|
1409 |
-
self._validate_params(n_samples=X.shape[0])
|
1410 |
X, y, Xresampled, variable_names = self._validate_fit_params(
|
1411 |
X, y, Xresampled, variable_names
|
1412 |
)
|
|
|
754 |
f"{k} is not a valid keyword argument for PySRRegressor"
|
755 |
)
|
756 |
|
757 |
+
self._process_params()
|
758 |
+
|
759 |
def __repr__(self):
|
760 |
"""
|
761 |
Prints all current equations fitted by the model.
|
|
|
860 |
f"{self.model_selection} is not a valid model selection strategy."
|
861 |
)
|
862 |
|
863 |
+
def _process_params(self):
|
864 |
"""
|
865 |
Perform validation on the parameters defined in init for the
|
866 |
+
dataset specified in :term`fit`, and update them if necessary.
|
867 |
+
For example, this will change :param`binary_operators`
|
868 |
+
into `["+", "-", "*", "/"]` if `binary_operators` is `None`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
869 |
|
870 |
Raises
|
871 |
------
|
|
|
1400 |
self.raw_julia_state_ = None
|
1401 |
|
1402 |
# Parameter input validation (for parameters defined in __init__)
|
|
|
1403 |
X, y, Xresampled, variable_names = self._validate_fit_params(
|
1404 |
X, y, Xresampled, variable_names
|
1405 |
)
|