Spaces:
Sleeping
Sleeping
MilesCranmer
committed on
Commit
•
9fa2182
1
Parent(s):
7f0b93d
Refactor GUI to multiple files
Browse files
- gui/app.py +4 -250
- gui/data.py +22 -0
- gui/plots.py +84 -0
- gui/processing.py +150 -0
gui/app.py
CHANGED
@@ -1,184 +1,8 @@
|
|
1 |
-
import multiprocessing as mp
|
2 |
-
import os
|
3 |
-
import tempfile
|
4 |
-
import time
|
5 |
-
from pathlib import Path
|
6 |
-
|
7 |
import gradio as gr
|
8 |
-
import numpy as np
|
9 |
-
import pandas as pd
|
10 |
-
from matplotlib import pyplot as plt
|
11 |
-
|
12 |
-
plt.ioff()
|
13 |
-
plt.rcParams["font.family"] = [
|
14 |
-
"IBM Plex Mono",
|
15 |
-
# Fallback fonts:
|
16 |
-
"DejaVu Sans Mono",
|
17 |
-
"Courier New",
|
18 |
-
"monospace",
|
19 |
-
]
|
20 |
-
|
21 |
-
empty_df = lambda: pd.DataFrame(
|
22 |
-
{
|
23 |
-
"equation": [],
|
24 |
-
"loss": [],
|
25 |
-
"complexity": [],
|
26 |
-
}
|
27 |
-
)
|
28 |
-
|
29 |
-
test_equations = ["sin(2*x)/x + 0.1*x"]
|
30 |
-
|
31 |
-
|
32 |
-
def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
|
33 |
-
rstate = np.random.RandomState(data_seed)
|
34 |
-
x = rstate.uniform(-10, 10, num_points)
|
35 |
-
for k, v in {
|
36 |
-
"sin": "np.sin",
|
37 |
-
"cos": "np.cos",
|
38 |
-
"exp": "np.exp",
|
39 |
-
"log": "np.log",
|
40 |
-
"tan": "np.tan",
|
41 |
-
"^": "**",
|
42 |
-
}.items():
|
43 |
-
s = s.replace(k, v)
|
44 |
-
y = eval(s)
|
45 |
-
noise = rstate.normal(0, noise_level, y.shape)
|
46 |
-
y_noisy = y + noise
|
47 |
-
return pd.DataFrame({"x": x}), y_noisy
|
48 |
-
|
49 |
-
|
50 |
-
def _greet_dispatch(
|
51 |
-
file_input,
|
52 |
-
force_run,
|
53 |
-
test_equation,
|
54 |
-
num_points,
|
55 |
-
noise_level,
|
56 |
-
data_seed,
|
57 |
-
niterations,
|
58 |
-
maxsize,
|
59 |
-
binary_operators,
|
60 |
-
unary_operators,
|
61 |
-
plot_update_delay,
|
62 |
-
parsimony,
|
63 |
-
populations,
|
64 |
-
population_size,
|
65 |
-
ncycles_per_iteration,
|
66 |
-
elementwise_loss,
|
67 |
-
adaptive_parsimony_scaling,
|
68 |
-
optimizer_algorithm,
|
69 |
-
optimizer_iterations,
|
70 |
-
batching,
|
71 |
-
batch_size,
|
72 |
-
):
|
73 |
-
"""Load data, then spawn a process to run the greet function."""
|
74 |
-
if file_input is not None:
|
75 |
-
# Look at some statistics of the file:
|
76 |
-
df = pd.read_csv(file_input)
|
77 |
-
if len(df) == 0:
|
78 |
-
return (
|
79 |
-
empty_df(),
|
80 |
-
"The file is empty!",
|
81 |
-
)
|
82 |
-
if len(df.columns) == 1:
|
83 |
-
return (
|
84 |
-
empty_df(),
|
85 |
-
"The file has only one column!",
|
86 |
-
)
|
87 |
-
if len(df) > 10_000 and not force_run:
|
88 |
-
return (
|
89 |
-
empty_df(),
|
90 |
-
"You have uploaded a file with more than 10,000 rows. "
|
91 |
-
"This will take very long to run. "
|
92 |
-
"Please upload a subsample of the data, "
|
93 |
-
"or check the box 'Ignore Warnings'.",
|
94 |
-
)
|
95 |
-
|
96 |
-
col_to_fit = df.columns[-1]
|
97 |
-
y = np.array(df[col_to_fit])
|
98 |
-
X = df.drop([col_to_fit], axis=1)
|
99 |
-
else:
|
100 |
-
X, y = generate_data(test_equation, num_points, noise_level, data_seed)
|
101 |
-
|
102 |
-
with tempfile.TemporaryDirectory() as tmpdirname:
|
103 |
-
base = Path(tmpdirname)
|
104 |
-
equation_file = base / "hall_of_fame.csv"
|
105 |
-
equation_file_bkup = base / "hall_of_fame.csv.bkup"
|
106 |
-
process = mp.Process(
|
107 |
-
target=greet,
|
108 |
-
kwargs=dict(
|
109 |
-
X=X,
|
110 |
-
y=y,
|
111 |
-
niterations=niterations,
|
112 |
-
maxsize=maxsize,
|
113 |
-
binary_operators=binary_operators,
|
114 |
-
unary_operators=unary_operators,
|
115 |
-
equation_file=equation_file,
|
116 |
-
parsimony=parsimony,
|
117 |
-
populations=populations,
|
118 |
-
population_size=population_size,
|
119 |
-
ncycles_per_iteration=ncycles_per_iteration,
|
120 |
-
elementwise_loss=elementwise_loss,
|
121 |
-
adaptive_parsimony_scaling=adaptive_parsimony_scaling,
|
122 |
-
optimizer_algorithm=optimizer_algorithm,
|
123 |
-
optimizer_iterations=optimizer_iterations,
|
124 |
-
batching=batching,
|
125 |
-
batch_size=batch_size,
|
126 |
-
),
|
127 |
-
)
|
128 |
-
process.start()
|
129 |
-
last_yield_time = None
|
130 |
-
while process.is_alive():
|
131 |
-
if equation_file_bkup.exists():
|
132 |
-
try:
|
133 |
-
# First, copy the file to a the copy file
|
134 |
-
equation_file_copy = base / "hall_of_fame_copy.csv"
|
135 |
-
os.system(f"cp {equation_file_bkup} {equation_file_copy}")
|
136 |
-
equations = pd.read_csv(equation_file_copy)
|
137 |
-
# Ensure it is pareto dominated, with more complex expressions
|
138 |
-
# having higher loss. Otherwise remove those rows.
|
139 |
-
# TODO: Not sure why this occurs; could be the result of a late copy?
|
140 |
-
equations.sort_values("Complexity", ascending=True, inplace=True)
|
141 |
-
equations.reset_index(inplace=True)
|
142 |
-
bad_idx = []
|
143 |
-
min_loss = None
|
144 |
-
for i in equations.index:
|
145 |
-
if min_loss is None or equations.loc[i, "Loss"] < min_loss:
|
146 |
-
min_loss = float(equations.loc[i, "Loss"])
|
147 |
-
else:
|
148 |
-
bad_idx.append(i)
|
149 |
-
equations.drop(index=bad_idx, inplace=True)
|
150 |
-
|
151 |
-
while (
|
152 |
-
last_yield_time is not None
|
153 |
-
and time.time() - last_yield_time < plot_update_delay
|
154 |
-
):
|
155 |
-
time.sleep(0.1)
|
156 |
-
|
157 |
-
yield equations[["Complexity", "Loss", "Equation"]]
|
158 |
-
|
159 |
-
last_yield_time = time.time()
|
160 |
-
except pd.errors.EmptyDataError:
|
161 |
-
pass
|
162 |
-
|
163 |
-
process.join()
|
164 |
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
X,
|
169 |
-
y,
|
170 |
-
**pysr_kwargs,
|
171 |
-
):
|
172 |
-
import pysr
|
173 |
-
|
174 |
-
model = pysr.PySRRegressor(
|
175 |
-
progress=False,
|
176 |
-
timeout_in_seconds=1000,
|
177 |
-
**pysr_kwargs,
|
178 |
-
)
|
179 |
-
model.fit(X, y)
|
180 |
-
|
181 |
-
return 0
|
182 |
|
183 |
|
184 |
def _data_layout():
|
@@ -372,7 +196,7 @@ def main():
|
|
372 |
blocks["run"] = gr.Button()
|
373 |
|
374 |
blocks["run"].click(
|
375 |
-
|
376 |
inputs=[
|
377 |
blocks[k]
|
378 |
for k in [
|
@@ -423,75 +247,5 @@ def main():
|
|
423 |
demo.launch(debug=True)
|
424 |
|
425 |
|
426 |
-
def replot_pareto(df, maxsize):
|
427 |
-
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
428 |
-
|
429 |
-
if len(df) == 0 or "Equation" not in df.columns:
|
430 |
-
return fig
|
431 |
-
|
432 |
-
# Plotting the data
|
433 |
-
ax.loglog(
|
434 |
-
df["Complexity"],
|
435 |
-
df["Loss"],
|
436 |
-
marker="o",
|
437 |
-
linestyle="-",
|
438 |
-
color="#333f48",
|
439 |
-
linewidth=1.5,
|
440 |
-
markersize=6,
|
441 |
-
)
|
442 |
-
|
443 |
-
# Set the axis limits
|
444 |
-
ax.set_xlim(0.5, maxsize + 1)
|
445 |
-
ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
|
446 |
-
ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
|
447 |
-
ax.set_ylim(ybottom, ytop)
|
448 |
-
|
449 |
-
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
|
450 |
-
ax.spines["top"].set_visible(False)
|
451 |
-
ax.spines["right"].set_visible(False)
|
452 |
-
|
453 |
-
# Range-frame the plot
|
454 |
-
for direction in ["bottom", "left"]:
|
455 |
-
ax.spines[direction].set_position(("outward", 10))
|
456 |
-
|
457 |
-
# Delete far ticks
|
458 |
-
ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
|
459 |
-
ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
|
460 |
-
|
461 |
-
ax.set_xlabel("Complexity")
|
462 |
-
ax.set_ylabel("Loss")
|
463 |
-
fig.tight_layout(pad=2)
|
464 |
-
|
465 |
-
return fig
|
466 |
-
|
467 |
-
|
468 |
-
def replot(test_equation, num_points, noise_level, data_seed):
|
469 |
-
X, y = generate_data(test_equation, num_points, noise_level, data_seed)
|
470 |
-
x = X["x"]
|
471 |
-
|
472 |
-
plt.rcParams["font.family"] = "IBM Plex Mono"
|
473 |
-
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
474 |
-
|
475 |
-
ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)
|
476 |
-
|
477 |
-
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
|
478 |
-
ax.spines["top"].set_visible(False)
|
479 |
-
ax.spines["right"].set_visible(False)
|
480 |
-
|
481 |
-
# Range-frame the plot
|
482 |
-
for direction in ["bottom", "left"]:
|
483 |
-
ax.spines[direction].set_position(("outward", 10))
|
484 |
-
|
485 |
-
# Delete far ticks
|
486 |
-
ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
|
487 |
-
ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
|
488 |
-
|
489 |
-
ax.set_xlabel("x")
|
490 |
-
ax.set_ylabel("y")
|
491 |
-
fig.tight_layout(pad=2)
|
492 |
-
|
493 |
-
return fig
|
494 |
-
|
495 |
-
|
496 |
if __name__ == "__main__":
|
497 |
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
from .data import test_equations
|
4 |
+
from .plots import replot, replot_pareto
|
5 |
+
from .processing import process
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
def _data_layout():
|
|
|
196 |
blocks["run"] = gr.Button()
|
197 |
|
198 |
blocks["run"].click(
|
199 |
+
process,
|
200 |
inputs=[
|
201 |
blocks[k]
|
202 |
for k in [
|
|
|
247 |
demo.launch(debug=True)
|
248 |
|
249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
if __name__ == "__main__":
|
251 |
main()
|
gui/data.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
# Example expression(s) written in terms of `x`; the entry below has the
# form that generate_data() in this module evaluates.
test_equations = ["sin(2*x)/x + 0.1*x"]
|
5 |
+
|
6 |
+
|
7 |
+
def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
    """Sample ``x`` uniformly on [-10, 10] and evaluate the expression ``s`` on it.

    Args:
        s: Expression in the variable ``x``. A name that starts with ``sin``,
            ``cos``, ``exp``, ``log`` or ``tan`` is prefixed with ``np.`` (so
            ``tanh`` resolves to ``np.tanh``), and ``^`` is read as ``**``.
            Names that are already qualified (``np.sin``) are left untouched.
            Only ``x``, ``np`` and the Python builtins are visible to the
            expression.
        num_points: Number of samples to draw.
        noise_level: Standard deviation of the zero-mean Gaussian noise that
            is added to the evaluated expression.
        data_seed: Seed of the ``numpy.random.RandomState`` used for both the
            samples and the noise, so equal arguments give equal data.

    Returns:
        A tuple ``(X, y_noisy)``: a one-column DataFrame (column ``"x"``) and
        an array of the noisy expression values, one per sample.
    """
    import re  # function-scope import keeps this block self-contained

    rstate = np.random.RandomState(data_seed)
    x = rstate.uniform(-10, 10, num_points)

    # One pass instead of chained str.replace calls. The lookbehind skips
    # names that follow a "." or another identifier character, so "np.sin"
    # is not rewritten to "np.np.sin".
    expr = re.sub(r"(?<![\w.])(sin|cos|exp|log|tan)", r"np.\1", s)
    expr = expr.replace("^", "**")

    # SECURITY: eval() executes arbitrary Python. The explicit namespaces only
    # keep this function's locals out of reach; they are NOT a sandbox. Do not
    # expose this to untrusted users without a real expression parser.
    y = np.asarray(eval(expr, {"np": np}, {"x": x}))
    if y.ndim == 0:
        # A constant expression such as "2.5" yields a scalar; give it one
        # value per sample so it lines up with X.
        y = np.full(x.shape, float(y))

    noise = rstate.normal(0, noise_level, y.shape)
    return pd.DataFrame({"x": x}), y + noise
|
gui/plots.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from matplotlib import pyplot as plt
|
4 |
+
|
5 |
+
# Module-level matplotlib configuration, applied once when this module is
# imported. Interactive mode is switched off so that creating a figure does
# not try to display it.
plt.ioff()
# Font families in order of preference.
plt.rcParams["font.family"] = [
    "IBM Plex Mono",
    # Fallback fonts:
    "DejaVu Sans Mono",
    "Courier New",
    "monospace",
]
|
13 |
+
|
14 |
+
from .data import generate_data
|
15 |
+
|
16 |
+
|
17 |
+
def replot_pareto(df: pd.DataFrame, maxsize: int):
    """Plot loss against complexity on log-log axes.

    Args:
        df: Table with ``Complexity``, ``Loss`` and ``Equation`` columns. An
            empty table, or one without an ``Equation`` column, produces an
            empty figure.
        maxsize: Largest complexity to show; the x-axis runs from 0.5 to
            ``maxsize + 1``.

    Returns:
        The matplotlib figure that was created for the plot.
    """
    fig, ax = plt.subplots(figsize=(6, 6), dpi=100)

    if len(df) == 0 or "Equation" not in df.columns:
        return fig

    # Plotting the data
    ax.loglog(
        df["Complexity"],
        df["Loss"],
        marker="o",
        linestyle="-",
        color="#333f48",
        linewidth=1.5,
        markersize=6,
    )

    # Set the axis limits: snap the y-range outward to powers of two. The
    # 1e-20 floor keeps log2 finite when a loss is exactly zero.
    ax.set_xlim(0.5, maxsize + 1)
    loss = df["Loss"]
    ytop = 2 ** (np.ceil(np.log2(max(loss.max(), 1e-20))))
    ybottom = 2 ** (np.floor(np.log2(loss.min() + 1e-20)))
    if ytop <= ybottom:
        # All losses sit on one power of two; open the range by one octave
        # so the limits are not identical.
        ytop = 2 * ybottom
    ax.set_ylim(ybottom, ytop)

    _style_axes(ax, "Complexity", "Loss")
    fig.tight_layout(pad=2)

    return fig


def replot(test_equation, num_points, noise_level, data_seed):
    """Scatter-plot the synthetic data generated for ``test_equation``.

    The four arguments are passed unchanged to ``generate_data``; its ``x``
    column and noisy ``y`` values are drawn against each other.

    Returns:
        The matplotlib figure that was created for the plot.
    """
    X, y = generate_data(test_equation, num_points, noise_level, data_seed)
    x = X["x"]

    # The font family is configured once at module import (with fallbacks);
    # it is deliberately not overridden here, which used to discard the
    # fallback list on every call.
    fig, ax = plt.subplots(figsize=(6, 6), dpi=100)

    ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)

    _style_axes(ax, "x", "y")
    fig.tight_layout(pad=2)

    return fig


def _style_axes(ax, xlabel: str, ylabel: str) -> None:
    """Apply the grid, spine, tick and label styling shared by both plots."""
    ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Range-frame the plot
    for direction in ["bottom", "left"]:
        ax.spines[direction].set_position(("outward", 10))

    # Outward-pointing ticks; minor ticks are shorter with smaller labels.
    ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
    ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
|
gui/processing.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import multiprocessing as mp
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
import time
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
from .data import generate_data
|
11 |
+
|
12 |
+
def EMPTY_DF():
    """Return a new, empty table with ``Equation``, ``Loss`` and ``Complexity`` columns.

    A fresh DataFrame is built on every call, so callers never share one
    object. The upper-case name is kept because callers invoke ``EMPTY_DF()``.
    """
    return pd.DataFrame(
        {
            "Equation": [],
            "Loss": [],
            "Complexity": [],
        }
    )
|
19 |
+
|
20 |
+
|
21 |
+
def process(
    file_input,
    force_run,
    test_equation,
    num_points,
    noise_level,
    data_seed,
    niterations,
    maxsize,
    binary_operators,
    unary_operators,
    plot_update_delay,
    parsimony,
    populations,
    population_size,
    ncycles_per_iteration,
    elementwise_loss,
    adaptive_parsimony_scaling,
    optimizer_algorithm,
    optimizer_iterations,
    batching,
    batch_size,
):
    """Load the data, fit it in a child process, and stream the equations found.

    This is a generator. While the child process running ``pysr_fit`` is
    alive, it repeatedly reads the backup of the equation file and yields the
    current table.

    Args:
        file_input: CSV path or buffer accepted by ``pandas.read_csv``, or
            ``None`` to use synthetic data. The last column is the target;
            all other columns are features.
        force_run: When true, files with more than 10,000 rows are accepted.
        test_equation, num_points, noise_level, data_seed: Passed to
            ``generate_data`` when ``file_input`` is ``None``.
        plot_update_delay: Minimum number of seconds between two yields.
        niterations, maxsize, binary_operators, unary_operators, parsimony,
        populations, population_size, ncycles_per_iteration,
        elementwise_loss, adaptive_parsimony_scaling, optimizer_algorithm,
        optimizer_iterations, batching, batch_size: Forwarded unchanged as
            keyword arguments to ``pysr_fit``.

    Yields:
        A DataFrame with the columns ``Complexity``, ``Loss`` and
        ``Equation``, sorted by complexity and reduced to rows whose loss is
        strictly lower than that of every simpler row.
    """
    if file_input is not None:
        # Look at some statistics of the file:
        df = pd.read_csv(file_input)
        # NOTE(review): this function is a generator, so the tuples returned
        # below only become ``StopIteration.value``; a consumer that simply
        # iterates never sees the message. Confirm how the caller is meant to
        # surface these errors before changing them to ``yield``.
        if len(df) == 0:
            return (
                EMPTY_DF(),
                "The file is empty!",
            )
        if len(df.columns) == 1:
            return (
                EMPTY_DF(),
                "The file has only one column!",
            )
        if len(df) > 10_000 and not force_run:
            return (
                EMPTY_DF(),
                "You have uploaded a file with more than 10,000 rows. "
                "This will take very long to run. "
                "Please upload a subsample of the data, "
                "or check the box 'Ignore Warnings'.",
            )

        col_to_fit = df.columns[-1]
        y = np.array(df[col_to_fit])
        X = df.drop([col_to_fit], axis=1)
    else:
        X, y = generate_data(test_equation, num_points, noise_level, data_seed)

    with tempfile.TemporaryDirectory() as tmpdirname:
        base = Path(tmpdirname)
        equation_file = base / "hall_of_fame.csv"
        # Presumably written by the fitting backend next to equation_file;
        # this function only ever reads it -- TODO confirm.
        equation_file_bkup = base / "hall_of_fame.csv.bkup"
        equation_file_copy = base / "hall_of_fame_copy.csv"

        # Named ``worker`` so the local does not shadow this function's name.
        worker = mp.Process(
            target=pysr_fit,
            kwargs=dict(
                X=X,
                y=y,
                niterations=niterations,
                maxsize=maxsize,
                binary_operators=binary_operators,
                unary_operators=unary_operators,
                equation_file=equation_file,
                parsimony=parsimony,
                populations=populations,
                population_size=population_size,
                ncycles_per_iteration=ncycles_per_iteration,
                elementwise_loss=elementwise_loss,
                adaptive_parsimony_scaling=adaptive_parsimony_scaling,
                optimizer_algorithm=optimizer_algorithm,
                optimizer_iterations=optimizer_iterations,
                batching=batching,
                batch_size=batch_size,
            ),
        )
        worker.start()
        try:
            last_yield_time = None
            while worker.is_alive():
                if not equation_file_bkup.exists():
                    # Nothing to read yet; sleep instead of spinning.
                    time.sleep(0.1)
                    continue

                try:
                    # Work on a private copy so the table is parsed from a
                    # file that nobody else is writing to. Copying through
                    # pathlib avoids the shell, so it is portable and safe
                    # for paths that contain spaces.
                    equation_file_copy.write_bytes(equation_file_bkup.read_bytes())
                    equations = pd.read_csv(equation_file_copy)
                except (pd.errors.EmptyDataError, FileNotFoundError):
                    # The backup is empty or vanished between the existence
                    # check and the read; try again shortly.
                    time.sleep(0.1)
                    continue

                # Ensure it is pareto dominated, with more complex expressions
                # having higher loss. Otherwise remove those rows.
                # TODO: Not sure why this occurs; could be the result of a late copy?
                equations = _pareto_front(equations)

                # Throttle: keep at least plot_update_delay seconds between yields.
                if last_yield_time is not None:
                    remaining = plot_update_delay - (time.time() - last_yield_time)
                    if remaining > 0:
                        time.sleep(remaining)

                yield equations[["Complexity", "Loss", "Equation"]]

                last_yield_time = time.time()
        finally:
            # If the consumer stops iterating early, do not leave the child
            # running after the temporary directory has been removed.
            if worker.is_alive():
                worker.terminate()
            worker.join()


def _pareto_front(equations):
    """Sort by complexity and keep rows that strictly improve on every simpler row."""
    ordered = equations.sort_values("Complexity", ascending=True).reset_index(drop=True)
    keep = []
    min_loss = None
    for loss in ordered["Loss"]:
        improved = min_loss is None or loss < min_loss
        if improved:
            min_loss = float(loss)
        keep.append(improved)
    return ordered[keep]
|
135 |
+
|
136 |
+
|
137 |
+
def pysr_fit(
    *,
    X,
    y,
    **pysr_kwargs,
):
    """Fit a ``PySRRegressor`` to ``X`` and ``y``.

    All arguments are keyword-only. ``pysr_kwargs`` are forwarded to the
    ``PySRRegressor`` constructor together with ``progress=False`` and
    ``timeout_in_seconds=1000``. Nothing is returned.
    """
    # Imported at call time, so importing this module does not import pysr.
    from pysr import PySRRegressor

    options = dict(
        progress=False,
        timeout_in_seconds=1000,
        **pysr_kwargs,
    )
    PySRRegressor(**options).fit(X, y)
|