MilesCranmer commited on
Commit
43bc86a
1 Parent(s): 406ae3e

Process params in __init__ instead of fit

Browse files
Files changed (1) hide show
  1. pysr/sr.py +6 -13
pysr/sr.py CHANGED
@@ -754,6 +754,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
754
  f"{k} is not a valid keyword argument for PySRRegressor"
755
  )
756
 
 
 
757
  def __repr__(self):
758
  """
759
  Prints all current equations fitted by the model.
@@ -858,20 +860,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
858
  f"{self.model_selection} is not a valid model selection strategy."
859
  )
860
 
861
- def _validate_params(self, n_samples):
862
  """
863
  Perform validation on the parameters defined in init for the
864
- dataset specified in :term`fit`.
865
-
866
- Parameters
867
- ----------
868
- n_samples : int
869
- Number of samples in the dataset to be fitted.
870
-
871
- Returns
872
- -------
873
- self : object
874
- Reference to `self` with validated parameters.
875
 
876
  Raises
877
  ------
@@ -1406,7 +1400,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1406
  self.raw_julia_state_ = None
1407
 
1408
  # Parameter input validation (for parameters defined in __init__)
1409
- self._validate_params(n_samples=X.shape[0])
1410
  X, y, Xresampled, variable_names = self._validate_fit_params(
1411
  X, y, Xresampled, variable_names
1412
  )
 
754
  f"{k} is not a valid keyword argument for PySRRegressor"
755
  )
756
 
757
+ self._process_params()
758
+
759
  def __repr__(self):
760
  """
761
  Prints all current equations fitted by the model.
 
860
  f"{self.model_selection} is not a valid model selection strategy."
861
  )
862
 
863
+ def _process_params(self):
864
  """
865
  Perform validation on the parameters defined in init for the
866
+ dataset specified in :term`fit`, and update them if necessary.
867
+ For example, this will change :param`binary_operators`
868
+ into `["+", "-", "*", "/"]` if `binary_operators` is `None`.
 
 
 
 
 
 
 
 
869
 
870
  Raises
871
  ------
 
1400
  self.raw_julia_state_ = None
1401
 
1402
  # Parameter input validation (for parameters defined in __init__)
 
1403
  X, y, Xresampled, variable_names = self._validate_fit_params(
1404
  X, y, Xresampled, variable_names
1405
  )