helboukkouri commited on
Commit
09447d8
1 Parent(s): 785ae13

initial commit

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +170 -0
  3. requirements.txt +210 -0
README.md CHANGED
@@ -10,4 +10,4 @@ pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
10
  license: apache-2.0
11
  ---
12
 
13
+ This is a basic space showcasing how you can have an interactive matplotlib plot refresh according to various input fields, all using Gradio.
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import sympy as sp
4
+ import seaborn as sns
5
+ from matplotlib import pyplot as plt
6
+
7
+ sns.set_style(style="darkgrid")
8
+ sns.set_context(context="notebook", font_scale=1.2)
9
+
10
+ MAX_NOISE = 20
11
+ DEFAULT_NOISE = 6
12
+ SLIDE_NOISE_STEP = 2
13
+
14
+ MAX_POINTS = 100
15
+ DEFAULT_POINTS = 20
16
+ SLIDE_POINTS_STEP = 5
17
+
18
+ def generate_equation(process_params):
19
+ process_params = process_params.astype(float).values.tolist()
20
+
21
+ # Define symbols
22
+ x = sp.symbols('x')
23
+ coefficients = sp.symbols('a b c d e')
24
+
25
+ # Create the polynomial expression
26
+ polynomial_expression = None
27
+ for i, coef in enumerate(reversed(coefficients)):
28
+ polynomial_expression = polynomial_expression + coef * x**i if polynomial_expression else coef * x**i
29
+
30
+ # Parameter mapping
31
+ parameters = {coef: value for coef, value in zip(coefficients, process_params[0])}
32
+
33
+ # Substitute parameter values into the expression
34
+ polynomial_with_values = polynomial_expression.subs(parameters)
35
+ latex_representation = sp.latex(polynomial_with_values)
36
+ return fr"$${latex_representation}$$"
37
+
38
+
39
+ def true_process(x, process_params):
40
+ """The true process we want to model."""
41
+ process_params = process_params.astype(float).values.tolist()
42
+ return (
43
+ process_params[0][0] * (x ** 4)
44
+ + process_params[0][1] * (x ** 3)
45
+ + process_params[0][2] * (x ** 2)
46
+ + process_params[0][3] * x
47
+ + process_params[0][4]
48
+ )
49
+
50
+
51
+ def generate_data(num_points, noise_level, process_params):
52
+
53
+ # x is the list of input values
54
+ input_values = np.linspace(-5, 2, num_points)
55
+ input_values_dense = np.linspace(-5, 2, MAX_POINTS)
56
+
57
+ # y = f(x) is the underlying process we want to model
58
+ y = [true_process(x, process_params) for x in input_values]
59
+ y_dense = [true_process(x, process_params) for x in input_values_dense]
60
+
61
+ # however, we can only observe a noisy version of f(x)
62
+ noise = np.random.normal(0, noise_level, len(input_values))
63
+ y_noisy = y + noise
64
+
65
+ return input_values, input_values_dense, y, y_dense, y_noisy
66
+
67
+
68
+ def make_plot(
69
+ num_points, noise_level, process_params,
70
+ show_true_process, show_original_points, show_added_noise, show_noisy_points,
71
+ ):
72
+
73
+ x, x_dense, y, y_dense, y_noisy = generate_data(num_points, noise_level, process_params)
74
+
75
+ fig = plt.figure(figsize=(10, 10), dpi=300)
76
+ if show_true_process:
77
+ plt.plot(
78
+ x_dense, y_dense, "-", color="#363A4F",
79
+ label="True Process",
80
+ )
81
+ if show_added_noise:
82
+ plt.vlines(
83
+ x, y, y_noisy, color="#556D9A",
84
+ linestyles="dashed",
85
+ alpha=0.75,
86
+ lw=1.2,
87
+ label="Added Noise",
88
+ )
89
+ if show_original_points:
90
+ plt.plot(
91
+ x, y, "-o", color="none",
92
+ ms=8,
93
+ markerfacecolor="white",
94
+ markeredgecolor="#556D9A",
95
+ markeredgewidth=1.5,
96
+ label="Original Points",
97
+ )
98
+ if show_noisy_points:
99
+ plt.plot(
100
+ x, y_noisy, "-o", color="none",
101
+ ms=8.5,
102
+ markerfacecolor="#556D9A",
103
+ markeredgecolor="none",
104
+ markeredgewidth=1.5,
105
+ alpha=1,
106
+ label="Noisy Points",
107
+ )
108
+
109
+ plt.xlabel("\nX")
110
+ plt.ylabel("\nY")
111
+ plt.legend(fontsize=11)
112
+ plt.tight_layout()
113
+ plt.show()
114
+ return fig
115
+
116
+
117
+ with gr.Blocks() as demo:
118
+ with gr.Row():
119
+ with gr.Column():
120
+ with gr.Row():
121
+ process_params = gr.DataFrame(
122
+ value=[[0.5, 2, -0.5, -2, 1]],
123
+ label="Underlying Process Coefficients",
124
+ type="pandas",
125
+ column_widths=("2", "1", "1", "1", "1w"),
126
+ headers=["x ** 4", "x ** 3", "x ** 2", "x", "1"],
127
+ interactive=True
128
+ )
129
+ equation = gr.Markdown()
130
+
131
+ with gr.Row():
132
+ with gr.Column():
133
+ num_points = gr.Slider(
134
+ minimum=5,
135
+ maximum=MAX_POINTS,
136
+ value=DEFAULT_POINTS,
137
+ step=SLIDE_POINTS_STEP,
138
+ label="Number of Points"
139
+ )
140
+ with gr.Column():
141
+ noise_level = gr.Slider(
142
+ minimum=0,
143
+ maximum=MAX_NOISE,
144
+ value=DEFAULT_NOISE,
145
+ step=SLIDE_NOISE_STEP,
146
+ label="Noise Level"
147
+ )
148
+
149
+ show_params = []
150
+ with gr.Row():
151
+ with gr.Column():
152
+ show_params.append(gr.Checkbox(label="Show Underlying Process", value=True))
153
+ show_params.append(gr.Checkbox(label="Show Original Points", value=True))
154
+ with gr.Column():
155
+ show_params.append(gr.Checkbox(label="Show Added Noise", value=True))
156
+ show_params.append(gr.Checkbox(label="Show Noisy Points", value=True))
157
+
158
+ scatter_plot = gr.Plot(scale=1)
159
+
160
+ num_points.change(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
161
+ noise_level.change(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
162
+ process_params.change(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
163
+ process_params.change(fn=generate_equation, inputs=[process_params], outputs=equation)
164
+ for component in show_params:
165
+ component.change(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
166
+ demo.load(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
167
+ demo.load(fn=generate_equation, inputs=[process_params], outputs=equation)
168
+
169
+ if __name__ == "__main__":
170
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is autogenerated by pip-compile with Python 3.10
3
+ # by the following command:
4
+ #
5
+ # pip-compile requirements.in
6
+ #
7
+ aiofiles==23.2.1
8
+ # via gradio
9
+ altair==5.2.0
10
+ # via gradio
11
+ annotated-types==0.6.0
12
+ # via pydantic
13
+ anyio==4.3.0
14
+ # via
15
+ # httpx
16
+ # starlette
17
+ attrs==23.2.0
18
+ # via
19
+ # jsonschema
20
+ # referencing
21
+ certifi==2024.2.2
22
+ # via
23
+ # httpcore
24
+ # httpx
25
+ # requests
26
+ charset-normalizer==3.3.2
27
+ # via requests
28
+ click==8.1.7
29
+ # via
30
+ # typer
31
+ # uvicorn
32
+ colorama==0.4.6
33
+ # via typer
34
+ contourpy==1.2.0
35
+ # via matplotlib
36
+ cycler==0.12.1
37
+ # via matplotlib
38
+ exceptiongroup==1.2.0
39
+ # via anyio
40
+ fastapi==0.110.0
41
+ # via gradio
42
+ ffmpy==0.3.2
43
+ # via gradio
44
+ filelock==3.13.1
45
+ # via huggingface-hub
46
+ fonttools==4.49.0
47
+ # via matplotlib
48
+ fsspec==2024.2.0
49
+ # via
50
+ # gradio-client
51
+ # huggingface-hub
52
+ gradio==4.19.2
53
+ # via -r requirements.in
54
+ gradio-client==0.10.1
55
+ # via gradio
56
+ h11==0.14.0
57
+ # via
58
+ # httpcore
59
+ # uvicorn
60
+ httpcore==1.0.4
61
+ # via httpx
62
+ httpx==0.27.0
63
+ # via
64
+ # gradio
65
+ # gradio-client
66
+ huggingface-hub==0.21.3
67
+ # via
68
+ # gradio
69
+ # gradio-client
70
+ idna==3.6
71
+ # via
72
+ # anyio
73
+ # httpx
74
+ # requests
75
+ importlib-resources==6.1.2
76
+ # via gradio
77
+ jinja2==3.1.3
78
+ # via
79
+ # altair
80
+ # gradio
81
+ jsonschema==4.21.1
82
+ # via altair
83
+ jsonschema-specifications==2023.12.1
84
+ # via jsonschema
85
+ kiwisolver==1.4.5
86
+ # via matplotlib
87
+ markdown-it-py==3.0.0
88
+ # via rich
89
+ markupsafe==2.1.5
90
+ # via
91
+ # gradio
92
+ # jinja2
93
+ matplotlib==3.8.3
94
+ # via
95
+ # gradio
96
+ # seaborn
97
+ mdurl==0.1.2
98
+ # via markdown-it-py
99
+ mpmath==1.3.0
100
+ # via sympy
101
+ numpy==1.26.4
102
+ # via
103
+ # -r requirements.in
104
+ # altair
105
+ # contourpy
106
+ # gradio
107
+ # matplotlib
108
+ # pandas
109
+ # seaborn
110
+ orjson==3.9.15
111
+ # via gradio
112
+ packaging==23.2
113
+ # via
114
+ # altair
115
+ # gradio
116
+ # gradio-client
117
+ # huggingface-hub
118
+ # matplotlib
119
+ pandas==2.2.1
120
+ # via
121
+ # -r requirements.in
122
+ # altair
123
+ # gradio
124
+ # seaborn
125
+ pillow==10.2.0
126
+ # via
127
+ # gradio
128
+ # matplotlib
129
+ pydantic==2.6.3
130
+ # via
131
+ # fastapi
132
+ # gradio
133
+ pydantic-core==2.16.3
134
+ # via pydantic
135
+ pydub==0.25.1
136
+ # via gradio
137
+ pygments==2.17.2
138
+ # via rich
139
+ pyparsing==3.1.1
140
+ # via matplotlib
141
+ python-dateutil==2.9.0.post0
142
+ # via
143
+ # matplotlib
144
+ # pandas
145
+ python-multipart==0.0.9
146
+ # via gradio
147
+ pytz==2024.1
148
+ # via pandas
149
+ pyyaml==6.0.1
150
+ # via
151
+ # gradio
152
+ # huggingface-hub
153
+ referencing==0.33.0
154
+ # via
155
+ # jsonschema
156
+ # jsonschema-specifications
157
+ requests==2.31.0
158
+ # via huggingface-hub
159
+ rich==13.7.1
160
+ # via typer
161
+ rpds-py==0.18.0
162
+ # via
163
+ # jsonschema
164
+ # referencing
165
+ ruff==0.3.0
166
+ # via gradio
167
+ seaborn==0.13.2
168
+ # via -r requirements.in
169
+ semantic-version==2.10.0
170
+ # via gradio
171
+ shellingham==1.5.4
172
+ # via typer
173
+ six==1.16.0
174
+ # via python-dateutil
175
+ sniffio==1.3.1
176
+ # via
177
+ # anyio
178
+ # httpx
179
+ starlette==0.36.3
180
+ # via fastapi
181
+ sympy==1.12
182
+ # via -r requirements.in
183
+ tomlkit==0.12.0
184
+ # via gradio
185
+ toolz==0.12.1
186
+ # via altair
187
+ tqdm==4.66.2
188
+ # via huggingface-hub
189
+ typer[all]==0.9.0
190
+ # via gradio
191
+ typing-extensions==4.10.0
192
+ # via
193
+ # altair
194
+ # anyio
195
+ # fastapi
196
+ # gradio
197
+ # gradio-client
198
+ # huggingface-hub
199
+ # pydantic
200
+ # pydantic-core
201
+ # typer
202
+ # uvicorn
203
+ tzdata==2024.1
204
+ # via pandas
205
+ urllib3==2.2.1
206
+ # via requests
207
+ uvicorn==0.27.1
208
+ # via gradio
209
+ websockets==11.0.3
210
+ # via gradio-client