MilesCranmer commited on
Commit
b955f86
1 Parent(s): b58f1db

feat(gui): add stop button

Browse files
Files changed (3) hide show
  1. gui/app.py +3 -1
  2. gui/plots.py +3 -0
  3. gui/processing.py +42 -25
gui/app.py CHANGED
@@ -5,7 +5,7 @@ import numpy as np
5
  from data import TEST_EQUATIONS
6
  from gradio.components.base import Component
7
  from plots import plot_example_data, plot_pareto_curve
8
- from processing import processing
9
 
10
 
11
  class ExampleData:
@@ -239,6 +239,7 @@ class AppInterface:
239
  with gr.Column():
240
  self.results = Results()
241
  self.run = gr.Button()
 
242
 
243
  # Update plot when dataframe is updated:
244
  self.results.df.change(
@@ -263,6 +264,7 @@ class AppInterface:
263
  ],
264
  show_progress=True,
265
  )
 
266
 
267
 
268
  def last_part(k: str) -> str:
 
5
  from data import TEST_EQUATIONS
6
  from gradio.components.base import Component
7
  from plots import plot_example_data, plot_pareto_curve
8
+ from processing import processing, stop
9
 
10
 
11
  class ExampleData:
 
239
  with gr.Column():
240
  self.results = Results()
241
  self.run = gr.Button()
242
+ self.stop = gr.Button(value="Stop")
243
 
244
  # Update plot when dataframe is updated:
245
  self.results.df.change(
 
264
  ],
265
  show_progress=True,
266
  )
267
+ self.stop.click(stop)
268
 
269
 
270
  def last_part(k: str) -> str:
gui/plots.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import numpy as np
2
  import pandas as pd
3
  from matplotlib import pyplot as plt
@@ -10,6 +12,7 @@ plt.rcParams["font.family"] = [
10
  "Courier New",
11
  "monospace",
12
  ]
 
13
 
14
  from data import generate_data
15
 
 
1
+ import logging
2
+
3
  import numpy as np
4
  import pandas as pd
5
  from matplotlib import pyplot as plt
 
12
  "Courier New",
13
  "monospace",
14
  ]
15
+ logging.getLogger("matplotlib.font_manager").disabled = True
16
 
17
  from data import generate_data
18
 
gui/processing.py CHANGED
@@ -5,6 +5,7 @@ import time
5
  from pathlib import Path
6
  from typing import Callable
7
 
 
8
  import pandas as pd
9
  from data import generate_data, read_csv
10
  from plots import plot_predictions
@@ -89,8 +90,11 @@ class ProcessWrapper:
89
  self.process.start()
90
 
91
 
92
- PERSISTENT_WRITER = None
93
- PERSISTENT_READER = None
 
 
 
94
 
95
 
96
  def processing(
@@ -118,17 +122,17 @@ def processing(
118
  batch_size,
119
  **kwargs,
120
  ):
121
- """Load data, then spawn a process to run the greet function."""
122
- global PERSISTENT_WRITER
123
- global PERSISTENT_READER
 
124
 
125
- if PERSISTENT_WRITER is None:
126
- print("Starting PySR fit process")
127
- PERSISTENT_WRITER = ProcessWrapper(pysr_fit)
128
 
129
- if PERSISTENT_READER is None:
130
- print("Starting PySR predict process")
131
- PERSISTENT_READER = ProcessWrapper(pysr_predict)
132
 
133
  if file_input is not None:
134
  try:
@@ -143,23 +147,23 @@ def processing(
143
  equation_file = base / "hall_of_fame.csv"
144
  # Check if queue is empty, if not, kill the process
145
  # and start a new one
146
- if not PERSISTENT_WRITER.queue.empty():
147
  print("Restarting PySR fit process")
148
- if PERSISTENT_WRITER.process.is_alive():
149
- PERSISTENT_WRITER.process.terminate()
150
- PERSISTENT_WRITER.process.join()
151
 
152
- PERSISTENT_WRITER = ProcessWrapper(pysr_fit)
153
 
154
- if not PERSISTENT_READER.queue.empty():
155
  print("Restarting PySR predict process")
156
- if PERSISTENT_READER.process.is_alive():
157
- PERSISTENT_READER.process.terminate()
158
- PERSISTENT_READER.process.join()
159
 
160
- PERSISTENT_READER = ProcessWrapper(pysr_predict)
161
 
162
- PERSISTENT_WRITER.queue.put(
163
  dict(
164
  X=X,
165
  y=y,
@@ -191,20 +195,20 @@ def processing(
191
 
192
  yield last_yield
193
 
194
- while PERSISTENT_WRITER.out_queue.empty():
195
  if (
196
  equation_file.exists()
197
  and Path(str(equation_file).replace(".csv", ".pkl")).exists()
198
  ):
199
  # First, copy the file to a the copy file
200
- PERSISTENT_READER.queue.put(
201
  dict(
202
  X=X,
203
  equation_file=equation_file,
204
  index=-1,
205
  )
206
  )
207
- out = PERSISTENT_READER.out_queue.get()
208
  predictions = out["ypred"]
209
  equations = out["equations"]
210
  last_yield = (
@@ -214,6 +218,19 @@ def processing(
214
  )
215
  yield last_yield
216
 
 
 
 
 
 
 
217
  time.sleep(0.1)
218
 
219
  yield (*last_yield[:-1], "Done")
 
 
 
 
 
 
 
 
5
  from pathlib import Path
6
  from typing import Callable
7
 
8
+ import numpy as np
9
  import pandas as pd
10
  from data import generate_data, read_csv
11
  from plots import plot_predictions
 
90
  self.process.start()
91
 
92
 
93
+ ACTIVE_PROCESS = None
94
+
95
+
96
+ def _random_string():
97
+ return "".join(list(np.random.choice("abcdefghijklmnopqrstuvwxyz".split(), 16)))
98
 
99
 
100
  def processing(
 
122
  batch_size,
123
  **kwargs,
124
  ):
125
+ # random string:
126
+ global ACTIVE_PROCESS
127
+ cur_process = _random_string()
128
+ ACTIVE_PROCESS = cur_process
129
 
130
+ """Load data, then spawn a process to run the greet function."""
131
+ print("Starting PySR fit process")
132
+ writer = ProcessWrapper(pysr_fit)
133
 
134
+ print("Starting PySR predict process")
135
+ reader = ProcessWrapper(pysr_predict)
 
136
 
137
  if file_input is not None:
138
  try:
 
147
  equation_file = base / "hall_of_fame.csv"
148
  # Check if queue is empty, if not, kill the process
149
  # and start a new one
150
+ if not writer.queue.empty():
151
  print("Restarting PySR fit process")
152
+ if writer.process.is_alive():
153
+ writer.process.terminate()
154
+ writer.process.join()
155
 
156
+ writer = ProcessWrapper(pysr_fit)
157
 
158
+ if not reader.queue.empty():
159
  print("Restarting PySR predict process")
160
+ if reader.process.is_alive():
161
+ reader.process.terminate()
162
+ reader.process.join()
163
 
164
+ reader = ProcessWrapper(pysr_predict)
165
 
166
+ writer.queue.put(
167
  dict(
168
  X=X,
169
  y=y,
 
195
 
196
  yield last_yield
197
 
198
+ while writer.out_queue.empty():
199
  if (
200
  equation_file.exists()
201
  and Path(str(equation_file).replace(".csv", ".pkl")).exists()
202
  ):
203
  # First, copy the file to a the copy file
204
+ reader.queue.put(
205
  dict(
206
  X=X,
207
  equation_file=equation_file,
208
  index=-1,
209
  )
210
  )
211
+ out = reader.out_queue.get()
212
  predictions = out["ypred"]
213
  equations = out["equations"]
214
  last_yield = (
 
218
  )
219
  yield last_yield
220
 
221
+ if cur_process != ACTIVE_PROCESS:
222
+ # Kill both reader and writer
223
+ writer.process.terminate()
224
+ reader.process.terminate()
225
+ return
226
+
227
  time.sleep(0.1)
228
 
229
  yield (*last_yield[:-1], "Done")
230
+ return
231
+
232
+
233
+ def stop():
234
+ global ACTIVE_PROCESS
235
+ ACTIVE_PROCESS = None
236
+ return