Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
84b46ac
1
Parent(s):
a2492c3
Working prediction plot
Browse files- gui/app.py +2 -1
- gui/plots.py +14 -0
- gui/processing.py +5 -1
gui/app.py
CHANGED
@@ -222,7 +222,8 @@ def main():
|
|
222 |
"batch_size",
|
223 |
]
|
224 |
],
|
225 |
-
outputs=blocks["df"],
|
|
|
226 |
)
|
227 |
|
228 |
# Any update to the equation choice will trigger a plot_example_data:
|
|
|
222 |
"batch_size",
|
223 |
]
|
224 |
],
|
225 |
+
outputs=[blocks["df"], blocks["predictions_plot"]],
|
226 |
+
show_progress=True,
|
227 |
)
|
228 |
|
229 |
# Any update to the equation choice will trigger a plot_example_data:
|
gui/plots.py
CHANGED
@@ -61,6 +61,20 @@ def plot_example_data(test_equation, num_points, noise_level, data_seed):
|
|
61 |
return fig
|
62 |
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def stylize_axis(ax):
|
65 |
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
|
66 |
ax.spines["top"].set_visible(False)
|
|
|
61 |
return fig
|
62 |
|
63 |
|
64 |
+
def plot_predictions(y, ypred):
|
65 |
+
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
66 |
+
|
67 |
+
ax.scatter(y, ypred, alpha=0.7, edgecolors="w", s=50)
|
68 |
+
|
69 |
+
stylize_axis(ax)
|
70 |
+
|
71 |
+
ax.set_xlabel("true")
|
72 |
+
ax.set_ylabel("prediction")
|
73 |
+
fig.tight_layout(pad=2)
|
74 |
+
|
75 |
+
return fig
|
76 |
+
|
77 |
+
|
78 |
def stylize_axis(ax):
|
79 |
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
|
80 |
ax.spines["top"].set_visible(False)
|
gui/processing.py
CHANGED
@@ -7,6 +7,7 @@ from typing import Callable
|
|
7 |
|
8 |
import pandas as pd
|
9 |
from data import generate_data, read_csv
|
|
|
10 |
|
11 |
EMPTY_DF = lambda: pd.DataFrame(
|
12 |
{
|
@@ -188,7 +189,10 @@ def processing(
|
|
188 |
)
|
189 |
)
|
190 |
out = PERSISTENT_READER.out_queue.get()
|
|
|
191 |
equations = out["equations"]
|
192 |
-
yield equations[["Complexity", "Loss", "Equation"]]
|
|
|
|
|
193 |
|
194 |
time.sleep(0.1)
|
|
|
7 |
|
8 |
import pandas as pd
|
9 |
from data import generate_data, read_csv
|
10 |
+
from plots import plot_predictions
|
11 |
|
12 |
EMPTY_DF = lambda: pd.DataFrame(
|
13 |
{
|
|
|
189 |
)
|
190 |
)
|
191 |
out = PERSISTENT_READER.out_queue.get()
|
192 |
+
predictions = out["ypred"]
|
193 |
equations = out["equations"]
|
194 |
+
yield equations[["Complexity", "Loss", "Equation"]], plot_predictions(
|
195 |
+
y, predictions
|
196 |
+
)
|
197 |
|
198 |
time.sleep(0.1)
|