tttc3 commited on
Commit
73c6ffd
1 Parent(s): 32a2de6

Fixed jax export compatibility with refactor

Browse files
Files changed (1) hide show
  1. pysr/export_jax.py +1 -4
pysr/export_jax.py CHANGED
@@ -109,7 +109,7 @@ def _initialize_jax():
109
  jsp = _jsp
110
 
111
 
112
- def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
113
  """Returns a function f and its parameters;
114
  the function takes an input matrix, and a list of arguments:
115
  f(X, parameters)
@@ -192,9 +192,6 @@ def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
192
  )
193
  hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
194
  text = f"def {hash_string}(X, parameters):\n"
195
- if selection is not None:
196
- # Impose the feature selection:
197
- text += f" X = X[:, {list(selection)}]\n"
198
  text += " return "
199
  text += functional_form_text
200
  ldict = {}
 
109
  jsp = _jsp
110
 
111
 
112
+ def sympy2jax(expression, symbols_in, extra_jax_mappings=None):
113
  """Returns a function f and its parameters;
114
  the function takes an input matrix, and a list of arguments:
115
  f(X, parameters)
 
192
  )
193
  hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
194
  text = f"def {hash_string}(X, parameters):\n"
 
 
 
195
  text += " return "
196
  text += functional_form_text
197
  ldict = {}