Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
d0788ef
1
Parent(s):
44216ab
Fix syntax error in JAX converter
Browse files- pysr/export.py +1 -1
- pysr/sr.py +2 -2
pysr/export.py
CHANGED
@@ -62,7 +62,7 @@ def sympy2jaxtext(expr, parameters, symbols_in):
|
|
62 |
parameters.append(float(expr))
|
63 |
return f"parameters[{len(parameters) - 1}]"
|
64 |
elif issubclass(expr.func, sympy.Integer):
|
65 |
-
return "{int(expr)}"
|
66 |
elif issubclass(expr.func, sympy.Symbol):
|
67 |
return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
68 |
else:
|
|
|
62 |
parameters.append(float(expr))
|
63 |
return f"parameters[{len(parameters) - 1}]"
|
64 |
elif issubclass(expr.func, sympy.Integer):
|
65 |
+
return f"{int(expr)}"
|
66 |
elif issubclass(expr.func, sympy.Symbol):
|
67 |
return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
68 |
else:
|
pysr/sr.py
CHANGED
@@ -686,7 +686,7 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
686 |
sympy_format.append(eqn)
|
687 |
if output_jax_format:
|
688 |
func, params = sympy2jax(eqn, sympy_symbols)
|
689 |
-
jax_format.append({'callable': func, 'parameters':
|
690 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
691 |
curMSE = output.loc[i, 'MSE']
|
692 |
curComplexity = output.loc[i, 'Complexity']
|
@@ -705,7 +705,7 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
705 |
output['lambda_format'] = lambda_format
|
706 |
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
707 |
if output_jax_format:
|
708 |
-
output_cols += 'jax_format'
|
709 |
output['jax_format'] = jax_format
|
710 |
|
711 |
return output[output_cols]
|
|
|
686 |
sympy_format.append(eqn)
|
687 |
if output_jax_format:
|
688 |
func, params = sympy2jax(eqn, sympy_symbols)
|
689 |
+
jax_format.append({'callable': func, 'parameters': params})
|
690 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
691 |
curMSE = output.loc[i, 'MSE']
|
692 |
curComplexity = output.loc[i, 'Complexity']
|
|
|
705 |
output['lambda_format'] = lambda_format
|
706 |
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
707 |
if output_jax_format:
|
708 |
+
output_cols += ['jax_format']
|
709 |
output['jax_format'] = jax_format
|
710 |
|
711 |
return output[output_cols]
|