MilesCranmer commited on
Commit
86dd9ce
1 Parent(s): 9068541

Fix output_torch_format option for pysr

Browse files
Files changed (1) hide show
  1. pysr/sr.py +2 -2
pysr/sr.py CHANGED
@@ -800,8 +800,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
800
  =======
801
  if output_torch_format:
802
  from .export_torch import sympy2torch
803
- func, params = sympy2torch(eqn, sympy_symbols)
804
- torch_format.append({'callable': func, 'parameters': params})
805
  lambda_format.append(lambdify(sympy_symbols, eqn))
806
  >>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
807
  curMSE = output.loc[i, 'MSE']
 
800
  =======
801
  if output_torch_format:
802
  from .export_torch import sympy2torch
803
+ module = sympy2torch(eqn, sympy_symbols)
804
+ torch_format.append(module)
805
  lambda_format.append(lambdify(sympy_symbols, eqn))
806
  >>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
807
  curMSE = output.loc[i, 'MSE']