Spaces:
Running
Running
deepsource-autofix[bot]
commited on
Commit
•
5bb2875
1
Parent(s):
b5d0afb
Refactor unnecessary `else` / `elif` when `if` block has a `return` statement
Browse files- pysr/export_jax.py +9 -11
- pysr/sr.py +8 -15
pysr/export_jax.py
CHANGED
@@ -55,21 +55,19 @@ def sympy2jaxtext(expr, parameters, symbols_in):
|
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
58 |
-
|
59 |
return f"{int(expr)}"
|
60 |
-
|
61 |
return (
|
62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
63 |
)
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
else:
|
72 |
-
return f'{_func}({", ".join(args)})'
|
73 |
|
74 |
|
75 |
jax_initialized = False
|
|
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
58 |
+
if issubclass(expr.func, sympy.Integer):
|
59 |
return f"{int(expr)}"
|
60 |
+
if issubclass(expr.func, sympy.Symbol):
|
61 |
return (
|
62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
63 |
)
|
64 |
+
_func = _jnp_func_lookup[expr.func]
|
65 |
+
args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
|
66 |
+
if _func == MUL:
|
67 |
+
return " * ".join(["(" + arg + ")" for arg in args])
|
68 |
+
if _func == ADD:
|
69 |
+
return " + ".join(["(" + arg + ")" for arg in args])
|
70 |
+
return f'{_func}({", ".join(args)})'
|
|
|
|
|
71 |
|
72 |
|
73 |
jax_initialized = False
|
pysr/sr.py
CHANGED
@@ -643,10 +643,9 @@ def _make_hyperparams_julia_str(
|
|
643 |
def tuple_fix(ops):
|
644 |
if len(ops) > 1:
|
645 |
return ", ".join(ops)
|
646 |
-
|
647 |
return ""
|
648 |
-
|
649 |
-
return ops[0] + ","
|
650 |
|
651 |
def_hyperparams += f"""\n
|
652 |
plus=(+)
|
@@ -1025,8 +1024,7 @@ def get_hof(
|
|
1025 |
|
1026 |
if multioutput:
|
1027 |
return ret_outputs
|
1028 |
-
|
1029 |
-
return ret_outputs[0]
|
1030 |
|
1031 |
|
1032 |
def best_row(equations=None):
|
@@ -1037,8 +1035,7 @@ def best_row(equations=None):
|
|
1037 |
equations = get_hof()
|
1038 |
if isinstance(equations, list):
|
1039 |
return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
|
1040 |
-
|
1041 |
-
return equations.iloc[np.argmax(equations["score"])]
|
1042 |
|
1043 |
|
1044 |
def best_tex(equations=None):
|
@@ -1051,8 +1048,7 @@ def best_tex(equations=None):
|
|
1051 |
return [
|
1052 |
sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
|
1053 |
]
|
1054 |
-
|
1055 |
-
return sympy.latex(best_row(equations)["sympy_format"].simplify())
|
1056 |
|
1057 |
|
1058 |
def best(equations=None):
|
@@ -1063,8 +1059,7 @@ def best(equations=None):
|
|
1063 |
equations = get_hof()
|
1064 |
if isinstance(equations, list):
|
1065 |
return [best_row(eq)["sympy_format"].simplify() for eq in equations]
|
1066 |
-
|
1067 |
-
return best_row(equations)["sympy_format"].simplify()
|
1068 |
|
1069 |
|
1070 |
def best_callable(equations=None):
|
@@ -1075,8 +1070,7 @@ def best_callable(equations=None):
|
|
1075 |
equations = get_hof()
|
1076 |
if isinstance(equations, list):
|
1077 |
return [best_row(eq)["lambda_format"] for eq in equations]
|
1078 |
-
|
1079 |
-
return best_row(equations)["lambda_format"]
|
1080 |
|
1081 |
|
1082 |
def _escape_filename(filename):
|
@@ -1114,5 +1108,4 @@ class CallableEquation(object):
|
|
1114 |
def __call__(self, X):
|
1115 |
if self._selection is not None:
|
1116 |
return self._lambda(*X[:, self._selection].T)
|
1117 |
-
|
1118 |
-
return self._lambda(*X.T)
|
|
|
643 |
def tuple_fix(ops):
|
644 |
if len(ops) > 1:
|
645 |
return ", ".join(ops)
|
646 |
+
if len(ops) == 0:
|
647 |
return ""
|
648 |
+
return ops[0] + ","
|
|
|
649 |
|
650 |
def_hyperparams += f"""\n
|
651 |
plus=(+)
|
|
|
1024 |
|
1025 |
if multioutput:
|
1026 |
return ret_outputs
|
1027 |
+
return ret_outputs[0]
|
|
|
1028 |
|
1029 |
|
1030 |
def best_row(equations=None):
|
|
|
1035 |
equations = get_hof()
|
1036 |
if isinstance(equations, list):
|
1037 |
return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
|
1038 |
+
return equations.iloc[np.argmax(equations["score"])]
|
|
|
1039 |
|
1040 |
|
1041 |
def best_tex(equations=None):
|
|
|
1048 |
return [
|
1049 |
sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
|
1050 |
]
|
1051 |
+
return sympy.latex(best_row(equations)["sympy_format"].simplify())
|
|
|
1052 |
|
1053 |
|
1054 |
def best(equations=None):
|
|
|
1059 |
equations = get_hof()
|
1060 |
if isinstance(equations, list):
|
1061 |
return [best_row(eq)["sympy_format"].simplify() for eq in equations]
|
1062 |
+
return best_row(equations)["sympy_format"].simplify()
|
|
|
1063 |
|
1064 |
|
1065 |
def best_callable(equations=None):
|
|
|
1070 |
equations = get_hof()
|
1071 |
if isinstance(equations, list):
|
1072 |
return [best_row(eq)["lambda_format"] for eq in equations]
|
1073 |
+
return best_row(equations)["lambda_format"]
|
|
|
1074 |
|
1075 |
|
1076 |
def _escape_filename(filename):
|
|
|
1108 |
def __call__(self, X):
|
1109 |
if self._selection is not None:
|
1110 |
return self._lambda(*X[:, self._selection].T)
|
1111 |
+
return self._lambda(*X.T)
|
|