MilesCranmer commited on
Commit
84b46ac
1 Parent(s): a2492c3

Working prediction plot

Browse files
Files changed (3) hide show
  1. gui/app.py +2 -1
  2. gui/plots.py +14 -0
  3. gui/processing.py +5 -1
gui/app.py CHANGED
@@ -222,7 +222,8 @@ def main():
222
  "batch_size",
223
  ]
224
  ],
225
- outputs=blocks["df"],
 
226
  )
227
 
228
  # Any update to the equation choice will trigger a plot_example_data:
 
222
  "batch_size",
223
  ]
224
  ],
225
+ outputs=[blocks["df"], blocks["predictions_plot"]],
226
+ show_progress=True,
227
  )
228
 
229
  # Any update to the equation choice will trigger a plot_example_data:
gui/plots.py CHANGED
@@ -61,6 +61,20 @@ def plot_example_data(test_equation, num_points, noise_level, data_seed):
61
  return fig
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def stylize_axis(ax):
65
  ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
66
  ax.spines["top"].set_visible(False)
 
61
  return fig
62
 
63
 
64
+ def plot_predictions(y, ypred):
65
+ fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
66
+
67
+ ax.scatter(y, ypred, alpha=0.7, edgecolors="w", s=50)
68
+
69
+ stylize_axis(ax)
70
+
71
+ ax.set_xlabel("true")
72
+ ax.set_ylabel("prediction")
73
+ fig.tight_layout(pad=2)
74
+
75
+ return fig
76
+
77
+
78
  def stylize_axis(ax):
79
  ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
80
  ax.spines["top"].set_visible(False)
gui/processing.py CHANGED
@@ -7,6 +7,7 @@ from typing import Callable
7
 
8
  import pandas as pd
9
  from data import generate_data, read_csv
 
10
 
11
  EMPTY_DF = lambda: pd.DataFrame(
12
  {
@@ -188,7 +189,10 @@ def processing(
188
  )
189
  )
190
  out = PERSISTENT_READER.out_queue.get()
 
191
  equations = out["equations"]
192
- yield equations[["Complexity", "Loss", "Equation"]]
 
 
193
 
194
  time.sleep(0.1)
 
7
 
8
  import pandas as pd
9
  from data import generate_data, read_csv
10
+ from plots import plot_predictions
11
 
12
  EMPTY_DF = lambda: pd.DataFrame(
13
  {
 
189
  )
190
  )
191
  out = PERSISTENT_READER.out_queue.get()
192
+ predictions = out["ypred"]
193
  equations = out["equations"]
194
+ yield equations[["Complexity", "Loss", "Equation"]], plot_predictions(
195
+ y, predictions
196
+ )
197
 
198
  time.sleep(0.1)