MilesCranmer commited on
Commit
205d866
1 Parent(s): d05f2e2

Assert _set_globals can work with both X and n_features

Browse files
Files changed (1) hide show
  1. pysr/sr.py +10 -2
pysr/sr.py CHANGED
@@ -593,7 +593,6 @@ Tried to activate project {julia_project} but failed."""
593
 
594
  def _set_globals(
595
  *,
596
- X,
597
  equation_file,
598
  variable_names,
599
  extra_sympy_mappings,
@@ -605,10 +604,19 @@ def _set_globals(
605
  nout,
606
  selection,
607
  raw_julia_output,
 
 
608
  ):
609
  global global_state
610
 
611
- global_state["n_features"] = X.shape[1]
 
 
 
 
 
 
 
612
  global_state["equation_file"] = equation_file
613
  global_state["variable_names"] = variable_names
614
  global_state["extra_sympy_mappings"] = extra_sympy_mappings
 
593
 
594
  def _set_globals(
595
  *,
 
596
  equation_file,
597
  variable_names,
598
  extra_sympy_mappings,
 
604
  nout,
605
  selection,
606
  raw_julia_output,
607
+ X=None,
608
+ n_features=None
609
  ):
610
  global global_state
611
 
612
+ if n_features is None and X is not None:
613
+ global_state["n_features"] = X.shape[1]
614
+ elif X is None and n_features is not None:
615
+ global_state["n_features"] = n_features
616
+ elif X is not None and n_features is not None:
617
+ assert X.shape[1] == n_features
618
+ global_state["n_features"] = n_features
619
+
620
  global_state["equation_file"] = equation_file
621
  global_state["variable_names"] = variable_names
622
  global_state["extra_sympy_mappings"] = extra_sympy_mappings