MilesCranmer committed on
Commit
9fa2182
1 Parent(s): 7f0b93d

Refactor GUI to multiple files

Browse files
Files changed (4) hide show
  1. gui/app.py +4 -250
  2. gui/data.py +22 -0
  3. gui/plots.py +84 -0
  4. gui/processing.py +150 -0
gui/app.py CHANGED
@@ -1,184 +1,8 @@
1
- import multiprocessing as mp
2
- import os
3
- import tempfile
4
- import time
5
- from pathlib import Path
6
-
7
  import gradio as gr
8
- import numpy as np
9
- import pandas as pd
10
- from matplotlib import pyplot as plt
11
-
12
- plt.ioff()
13
- plt.rcParams["font.family"] = [
14
- "IBM Plex Mono",
15
- # Fallback fonts:
16
- "DejaVu Sans Mono",
17
- "Courier New",
18
- "monospace",
19
- ]
20
-
21
- empty_df = lambda: pd.DataFrame(
22
- {
23
- "equation": [],
24
- "loss": [],
25
- "complexity": [],
26
- }
27
- )
28
-
29
- test_equations = ["sin(2*x)/x + 0.1*x"]
30
-
31
-
32
- def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
33
- rstate = np.random.RandomState(data_seed)
34
- x = rstate.uniform(-10, 10, num_points)
35
- for k, v in {
36
- "sin": "np.sin",
37
- "cos": "np.cos",
38
- "exp": "np.exp",
39
- "log": "np.log",
40
- "tan": "np.tan",
41
- "^": "**",
42
- }.items():
43
- s = s.replace(k, v)
44
- y = eval(s)
45
- noise = rstate.normal(0, noise_level, y.shape)
46
- y_noisy = y + noise
47
- return pd.DataFrame({"x": x}), y_noisy
48
-
49
-
50
- def _greet_dispatch(
51
- file_input,
52
- force_run,
53
- test_equation,
54
- num_points,
55
- noise_level,
56
- data_seed,
57
- niterations,
58
- maxsize,
59
- binary_operators,
60
- unary_operators,
61
- plot_update_delay,
62
- parsimony,
63
- populations,
64
- population_size,
65
- ncycles_per_iteration,
66
- elementwise_loss,
67
- adaptive_parsimony_scaling,
68
- optimizer_algorithm,
69
- optimizer_iterations,
70
- batching,
71
- batch_size,
72
- ):
73
- """Load data, then spawn a process to run the greet function."""
74
- if file_input is not None:
75
- # Look at some statistics of the file:
76
- df = pd.read_csv(file_input)
77
- if len(df) == 0:
78
- return (
79
- empty_df(),
80
- "The file is empty!",
81
- )
82
- if len(df.columns) == 1:
83
- return (
84
- empty_df(),
85
- "The file has only one column!",
86
- )
87
- if len(df) > 10_000 and not force_run:
88
- return (
89
- empty_df(),
90
- "You have uploaded a file with more than 10,000 rows. "
91
- "This will take very long to run. "
92
- "Please upload a subsample of the data, "
93
- "or check the box 'Ignore Warnings'.",
94
- )
95
-
96
- col_to_fit = df.columns[-1]
97
- y = np.array(df[col_to_fit])
98
- X = df.drop([col_to_fit], axis=1)
99
- else:
100
- X, y = generate_data(test_equation, num_points, noise_level, data_seed)
101
-
102
- with tempfile.TemporaryDirectory() as tmpdirname:
103
- base = Path(tmpdirname)
104
- equation_file = base / "hall_of_fame.csv"
105
- equation_file_bkup = base / "hall_of_fame.csv.bkup"
106
- process = mp.Process(
107
- target=greet,
108
- kwargs=dict(
109
- X=X,
110
- y=y,
111
- niterations=niterations,
112
- maxsize=maxsize,
113
- binary_operators=binary_operators,
114
- unary_operators=unary_operators,
115
- equation_file=equation_file,
116
- parsimony=parsimony,
117
- populations=populations,
118
- population_size=population_size,
119
- ncycles_per_iteration=ncycles_per_iteration,
120
- elementwise_loss=elementwise_loss,
121
- adaptive_parsimony_scaling=adaptive_parsimony_scaling,
122
- optimizer_algorithm=optimizer_algorithm,
123
- optimizer_iterations=optimizer_iterations,
124
- batching=batching,
125
- batch_size=batch_size,
126
- ),
127
- )
128
- process.start()
129
- last_yield_time = None
130
- while process.is_alive():
131
- if equation_file_bkup.exists():
132
- try:
133
- # First, copy the file to a the copy file
134
- equation_file_copy = base / "hall_of_fame_copy.csv"
135
- os.system(f"cp {equation_file_bkup} {equation_file_copy}")
136
- equations = pd.read_csv(equation_file_copy)
137
- # Ensure it is pareto dominated, with more complex expressions
138
- # having higher loss. Otherwise remove those rows.
139
- # TODO: Not sure why this occurs; could be the result of a late copy?
140
- equations.sort_values("Complexity", ascending=True, inplace=True)
141
- equations.reset_index(inplace=True)
142
- bad_idx = []
143
- min_loss = None
144
- for i in equations.index:
145
- if min_loss is None or equations.loc[i, "Loss"] < min_loss:
146
- min_loss = float(equations.loc[i, "Loss"])
147
- else:
148
- bad_idx.append(i)
149
- equations.drop(index=bad_idx, inplace=True)
150
-
151
- while (
152
- last_yield_time is not None
153
- and time.time() - last_yield_time < plot_update_delay
154
- ):
155
- time.sleep(0.1)
156
-
157
- yield equations[["Complexity", "Loss", "Equation"]]
158
-
159
- last_yield_time = time.time()
160
- except pd.errors.EmptyDataError:
161
- pass
162
-
163
- process.join()
164
 
165
-
166
- def greet(
167
- *,
168
- X,
169
- y,
170
- **pysr_kwargs,
171
- ):
172
- import pysr
173
-
174
- model = pysr.PySRRegressor(
175
- progress=False,
176
- timeout_in_seconds=1000,
177
- **pysr_kwargs,
178
- )
179
- model.fit(X, y)
180
-
181
- return 0
182
 
183
 
184
  def _data_layout():
@@ -372,7 +196,7 @@ def main():
372
  blocks["run"] = gr.Button()
373
 
374
  blocks["run"].click(
375
- _greet_dispatch,
376
  inputs=[
377
  blocks[k]
378
  for k in [
@@ -423,75 +247,5 @@ def main():
423
  demo.launch(debug=True)
424
 
425
 
426
- def replot_pareto(df, maxsize):
427
- fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
428
-
429
- if len(df) == 0 or "Equation" not in df.columns:
430
- return fig
431
-
432
- # Plotting the data
433
- ax.loglog(
434
- df["Complexity"],
435
- df["Loss"],
436
- marker="o",
437
- linestyle="-",
438
- color="#333f48",
439
- linewidth=1.5,
440
- markersize=6,
441
- )
442
-
443
- # Set the axis limits
444
- ax.set_xlim(0.5, maxsize + 1)
445
- ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
446
- ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
447
- ax.set_ylim(ybottom, ytop)
448
-
449
- ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
450
- ax.spines["top"].set_visible(False)
451
- ax.spines["right"].set_visible(False)
452
-
453
- # Range-frame the plot
454
- for direction in ["bottom", "left"]:
455
- ax.spines[direction].set_position(("outward", 10))
456
-
457
- # Delete far ticks
458
- ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
459
- ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
460
-
461
- ax.set_xlabel("Complexity")
462
- ax.set_ylabel("Loss")
463
- fig.tight_layout(pad=2)
464
-
465
- return fig
466
-
467
-
468
- def replot(test_equation, num_points, noise_level, data_seed):
469
- X, y = generate_data(test_equation, num_points, noise_level, data_seed)
470
- x = X["x"]
471
-
472
- plt.rcParams["font.family"] = "IBM Plex Mono"
473
- fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
474
-
475
- ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)
476
-
477
- ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
478
- ax.spines["top"].set_visible(False)
479
- ax.spines["right"].set_visible(False)
480
-
481
- # Range-frame the plot
482
- for direction in ["bottom", "left"]:
483
- ax.spines[direction].set_position(("outward", 10))
484
-
485
- # Delete far ticks
486
- ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
487
- ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
488
-
489
- ax.set_xlabel("x")
490
- ax.set_ylabel("y")
491
- fig.tight_layout(pad=2)
492
-
493
- return fig
494
-
495
-
496
  if __name__ == "__main__":
497
  main()
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ from .data import test_equations
4
+ from .plots import replot, replot_pareto
5
+ from .processing import process
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  def _data_layout():
 
196
  blocks["run"] = gr.Button()
197
 
198
  blocks["run"].click(
199
+ process,
200
  inputs=[
201
  blocks[k]
202
  for k in [
 
247
  demo.launch(debug=True)
248
 
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  if __name__ == "__main__":
251
  main()
gui/data.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ test_equations = ["sin(2*x)/x + 0.1*x"]
5
+
6
+
7
def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
    """Sample noisy data points from an equation given as a string.

    ``x`` is drawn uniformly from [-10, 10) with a ``RandomState`` seeded by
    ``data_seed``; the expression is evaluated on ``x`` and zero-mean Gaussian
    noise with standard deviation ``noise_level`` is added to the result.

    Args:
        s: Expression in the variable ``x``. The substrings ``sin``, ``cos``,
            ``exp``, ``log`` and ``tan`` are rewritten to their ``np.``
            counterparts and ``^`` to ``**`` by plain text substitution.
        num_points: Number of samples to draw.
        noise_level: Standard deviation of the additive noise.
        data_seed: Seed for the random number generator.

    Returns:
        A tuple ``(X, y_noisy)`` where ``X`` is a DataFrame with the single
        column ``"x"`` and ``y_noisy`` is a NumPy array of the noisy targets.
    """
    rstate = np.random.RandomState(data_seed)
    x = rstate.uniform(-10, 10, num_points)

    substitutions = {
        "sin": "np.sin",
        "cos": "np.cos",
        "exp": "np.exp",
        "log": "np.log",
        "tan": "np.tan",
        "^": "**",
    }
    for old, new in substitutions.items():
        s = s.replace(old, new)

    # SECURITY: eval() runs arbitrary Python taken from the equation text.
    # NOTE(review): if this text can come from untrusted users, replace this
    # with a real expression parser; it is kept as-is here to preserve behavior.
    y = eval(s)

    # An expression that does not involve ``x`` (e.g. "3") evaluates to a plain
    # scalar, which has no ``.shape``; expand it to one value per sample.
    if np.ndim(y) == 0:
        y = np.full(x.shape, y, dtype=float)

    noise = rstate.normal(0, noise_level, y.shape)
    y_noisy = y + noise
    return pd.DataFrame({"x": x}), y_noisy
gui/plots.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from matplotlib import pyplot as plt
4
+
5
+ plt.ioff()
6
+ plt.rcParams["font.family"] = [
7
+ "IBM Plex Mono",
8
+ # Fallback fonts:
9
+ "DejaVu Sans Mono",
10
+ "Courier New",
11
+ "monospace",
12
+ ]
13
+
14
+ from .data import generate_data
15
+
16
+
17
def replot_pareto(df: pd.DataFrame, maxsize: int):
    """Draw the loss-versus-complexity curve on log-log axes.

    Args:
        df: Table with "Complexity", "Loss" and "Equation" columns. If it is
            empty or has no "Equation" column, an empty figure is returned.
        maxsize: Largest complexity to show; the x-axis spans
            ``0.5`` to ``maxsize + 1``.

    Returns:
        The matplotlib figure that was created.
    """
    # NOTE(review): figures created through pyplot stay registered until they
    # are closed; confirm the caller releases them in a long-running app.
    fig, ax = plt.subplots(figsize=(6, 6), dpi=100)

    if len(df) == 0 or "Equation" not in df.columns:
        return fig

    # Plotting the data
    ax.loglog(
        df["Complexity"],
        df["Loss"],
        marker="o",
        linestyle="-",
        color="#333f48",
        linewidth=1.5,
        markersize=6,
    )

    # Set the axis limits: snap the y-range outward to powers of two.
    ax.set_xlim(0.5, maxsize + 1)
    # The epsilon keeps log2 finite when a loss is exactly zero; it is applied
    # to both ends so neither limit can become non-positive on the log axis.
    eps = 1e-20
    ytop = 2 ** (np.ceil(np.log2(df["Loss"].max() + eps)))
    ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + eps)))
    if ytop <= ybottom:
        # A single loss that is an exact power of two would otherwise give
        # identical limits; widen by one octave instead.
        ytop = 2 * ybottom
    ax.set_ylim(ybottom, ytop)

    ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Range-frame the plot
    for direction in ["bottom", "left"]:
        ax.spines[direction].set_position(("outward", 10))

    # Tick styling
    ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
    ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)

    ax.set_xlabel("Complexity")
    ax.set_ylabel("Loss")
    fig.tight_layout(pad=2)

    return fig
57
+
58
+
59
def replot(test_equation, num_points, noise_level, data_seed):
    """Scatter-plot synthetic data generated from ``test_equation``.

    The arguments are forwarded unchanged to ``generate_data``; the ``"x"``
    column of its returned frame is plotted against the returned targets.

    Returns:
        The matplotlib figure that was created.
    """
    X, y = generate_data(test_equation, num_points, noise_level, data_seed)
    x = X["x"]

    # The font family (with fallbacks) is configured once at module import.
    # It is deliberately not reassigned here: setting the global rcParams to a
    # single font would drop the fallback list for every subsequent plot.
    fig, ax = plt.subplots(figsize=(6, 6), dpi=100)

    ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)

    ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Range-frame the plot
    for direction in ["bottom", "left"]:
        ax.spines[direction].set_position(("outward", 10))

    # Tick styling
    ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
    ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    fig.tight_layout(pad=2)

    return fig
gui/processing.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing as mp
2
+ import os
3
+ import tempfile
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from .data import generate_data
11
+
12
def EMPTY_DF():
    """Return a new, empty DataFrame with "Equation", "Loss" and "Complexity" columns.

    A fresh frame is built on every call, so callers never share state.
    """
    return pd.DataFrame(
        {
            "Equation": [],
            "Loss": [],
            "Complexity": [],
        }
    )
19
+
20
+
21
def process(
    file_input,
    force_run,
    test_equation,
    num_points,
    noise_level,
    data_seed,
    niterations,
    maxsize,
    binary_operators,
    unary_operators,
    plot_update_delay,
    parsimony,
    populations,
    population_size,
    ncycles_per_iteration,
    elementwise_loss,
    adaptive_parsimony_scaling,
    optimizer_algorithm,
    optimizer_iterations,
    batching,
    batch_size,
):
    """Load data, run ``pysr_fit`` in a child process, and stream its results.

    Data is read from the CSV ``file_input`` when one is given (the last column
    is the target, the remaining columns are the features); otherwise it is
    synthesized with ``generate_data``. The fit runs in a separate
    ``multiprocessing.Process`` and is given an equation file inside a
    temporary directory. While the child is alive, the ``.bkup`` sibling of
    that file is polled, snapshotted, filtered, and yielded.

    Yields:
        A DataFrame with the columns "Complexity", "Loss" and "Equation",
        no more often than once every ``plot_update_delay`` seconds.
    """
    import shutil  # local import: only needed for the snapshot copy below

    if file_input is not None:
        # Look at some statistics of the file:
        df = pd.read_csv(file_input)
        # NOTE(review): this function is a generator (it contains ``yield``),
        # so the tuples returned below become the StopIteration value instead
        # of a yielded item. Confirm the caller actually surfaces them.
        if len(df) == 0:
            return (
                EMPTY_DF(),
                "The file is empty!",
            )
        if len(df.columns) == 1:
            return (
                EMPTY_DF(),
                "The file has only one column!",
            )
        if len(df) > 10_000 and not force_run:
            return (
                EMPTY_DF(),
                "You have uploaded a file with more than 10,000 rows. "
                "This will take very long to run. "
                "Please upload a subsample of the data, "
                "or check the box 'Ignore Warnings'.",
            )

        col_to_fit = df.columns[-1]
        y = np.array(df[col_to_fit])
        X = df.drop([col_to_fit], axis=1)
    else:
        X, y = generate_data(test_equation, num_points, noise_level, data_seed)

    poll_interval = 0.1  # seconds to wait when there is nothing new to read

    with tempfile.TemporaryDirectory() as tmpdirname:
        base = Path(tmpdirname)
        equation_file = base / "hall_of_fame.csv"
        equation_file_bkup = base / "hall_of_fame.csv.bkup"
        equation_file_copy = base / "hall_of_fame_copy.csv"

        # Named ``worker`` so the local does not shadow this function's name.
        worker = mp.Process(
            target=pysr_fit,
            kwargs=dict(
                X=X,
                y=y,
                niterations=niterations,
                maxsize=maxsize,
                binary_operators=binary_operators,
                unary_operators=unary_operators,
                equation_file=equation_file,
                parsimony=parsimony,
                populations=populations,
                population_size=population_size,
                ncycles_per_iteration=ncycles_per_iteration,
                elementwise_loss=elementwise_loss,
                adaptive_parsimony_scaling=adaptive_parsimony_scaling,
                optimizer_algorithm=optimizer_algorithm,
                optimizer_iterations=optimizer_iterations,
                batching=batching,
                batch_size=batch_size,
            ),
        )
        worker.start()
        try:
            last_yield_time = None
            while worker.is_alive():
                if not equation_file_bkup.exists():
                    # Nothing written yet; sleep instead of spinning the CPU.
                    time.sleep(poll_interval)
                    continue

                try:
                    # Read from a private snapshot rather than the file the
                    # child may be rewriting. shutil.copy avoids building a
                    # shell command from the paths.
                    shutil.copy(equation_file_bkup, equation_file_copy)
                    equations = pd.read_csv(equation_file_copy)
                except (OSError, pd.errors.EmptyDataError):
                    # The file vanished or is still empty; try again shortly.
                    time.sleep(poll_interval)
                    continue

                # Keep only rows whose loss is strictly lower than that of
                # every less complex row; drop the rest.
                # TODO: Not sure why such rows occur; could be the result of a
                # late copy?
                equations.sort_values("Complexity", ascending=True, inplace=True)
                equations.reset_index(inplace=True)
                bad_idx = []
                min_loss = None
                for i in equations.index:
                    if min_loss is None or equations.loc[i, "Loss"] < min_loss:
                        min_loss = float(equations.loc[i, "Loss"])
                    else:
                        bad_idx.append(i)
                equations.drop(index=bad_idx, inplace=True)

                # Throttle updates to one per ``plot_update_delay`` seconds.
                while (
                    last_yield_time is not None
                    and time.time() - last_yield_time < plot_update_delay
                ):
                    time.sleep(0.1)

                yield equations[["Complexity", "Loss", "Equation"]]

                last_yield_time = time.time()
        finally:
            # If the consumer stops iterating early or an error escapes, do
            # not leave the child running against a directory that is about to
            # be deleted.
            if worker.is_alive():
                worker.terminate()
            worker.join()
135
+
136
+
137
def pysr_fit(*, X, y, **pysr_kwargs):
    """Fit a ``pysr.PySRRegressor`` to ``X`` and ``y``.

    All arguments are keyword-only. Any extra keyword arguments are passed
    straight to the ``PySRRegressor`` constructor, alongside the fixed
    settings ``progress=False`` and ``timeout_in_seconds=1000``.
    Nothing is returned.
    """
    # Imported inside the function rather than at module level; presumably so
    # that pysr is only loaded in the process that performs the fit.
    from pysr import PySRRegressor

    regressor = PySRRegressor(
        progress=False,
        timeout_in_seconds=1000,
        **pysr_kwargs,
    )
    regressor.fit(X, y)