MilesCranmer commited on
Commit
b4cb407
·
1 Parent(s): ce5b119

Fix feature selection for JAX export

Browse files
Files changed (2) hide show
  1. pysr/export_jax.py +4 -1
  2. pysr/sr.py +1 -0
pysr/export_jax.py CHANGED
@@ -109,7 +109,7 @@ def _initialize_jax():
109
  jsp = _jsp
110
 
111
 
112
- def sympy2jax(expression, symbols_in, extra_jax_mappings=None):
113
  """Returns a function f and its parameters;
114
  the function takes an input matrix, and a list of arguments:
115
  f(X, parameters)
@@ -192,6 +192,9 @@ def sympy2jax(expression, symbols_in, extra_jax_mappings=None):
192
  )
193
  hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
194
  text = f"def {hash_string}(X, parameters):\n"
 
 
 
195
  text += " return "
196
  text += functional_form_text
197
  ldict = {}
 
109
  jsp = _jsp
110
 
111
 
112
+ def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
113
  """Returns a function f and its parameters;
114
  the function takes an input matrix, and a list of arguments:
115
  f(X, parameters)
 
192
  )
193
  hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
194
  text = f"def {hash_string}(X, parameters):\n"
195
+ if selection is not None:
196
+ # Impose the feature selection:
197
+ text += f" X = X[:, {list(selection)}]\n"
198
  text += " return "
199
  text += functional_form_text
200
  ldict = {}
pysr/sr.py CHANGED
@@ -1740,6 +1740,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1740
  func, params = sympy2jax(
1741
  eqn,
1742
  sympy_symbols,
 
1743
  extra_jax_mappings=self.extra_jax_mappings,
1744
  )
1745
  jax_format.append({"callable": func, "parameters": params})
 
1740
  func, params = sympy2jax(
1741
  eqn,
1742
  sympy_symbols,
1743
+ selection=self.selection_mask_,
1744
  extra_jax_mappings=self.extra_jax_mappings,
1745
  )
1746
  jax_format.append({"callable": func, "parameters": params})