Spaces:
Sleeping
Sleeping
Lennard Schober
commited on
Commit
·
16a2246
1
Parent(s):
c188d2e
Fix latex formatting
Browse files
app.py
CHANGED
@@ -13,6 +13,7 @@ glob_c = 7.5
|
|
13 |
|
14 |
n_x, n_t = 10, 10
|
15 |
|
|
|
16 |
def clear_npz():
|
17 |
current_directory = os.getcwd() # Get the current working directory
|
18 |
for filename in os.listdir(current_directory):
|
@@ -93,7 +94,6 @@ def plot_heat_equation(m, approx_type):
|
|
93 |
|
94 |
# Layout for the Plotly plot without controls
|
95 |
layout = go.Layout(
|
96 |
-
title=f"Heat Equation Approximation | Kernel = {approx_type} | m = {m}",
|
97 |
scene=dict(
|
98 |
camera=dict(
|
99 |
eye=dict(x=0, y=-2, z=0), # Front view
|
@@ -105,28 +105,6 @@ def plot_heat_equation(m, approx_type):
|
|
105 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
106 |
)
|
107 |
|
108 |
-
# Config to remove modebar buttons except the save image button
|
109 |
-
config = {
|
110 |
-
"modeBarButtonsToRemove": [
|
111 |
-
"pan",
|
112 |
-
"resetCameraLastSave",
|
113 |
-
"hoverClosest3d",
|
114 |
-
"hoverCompareCartesian",
|
115 |
-
"zoomIn",
|
116 |
-
"zoomOut",
|
117 |
-
"select2d",
|
118 |
-
"lasso2d",
|
119 |
-
"zoomIn2d",
|
120 |
-
"zoomOut2d",
|
121 |
-
"sendDataToCloud",
|
122 |
-
"zoom3d",
|
123 |
-
"orbitRotation",
|
124 |
-
"tableRotation",
|
125 |
-
],
|
126 |
-
"displayModeBar": True, # Keep the modebar visible
|
127 |
-
"displaylogo": False, # Hide the Plotly logo
|
128 |
-
}
|
129 |
-
|
130 |
# Create the figure
|
131 |
fig = go.Figure(data=traces, layout=layout)
|
132 |
|
@@ -148,7 +126,8 @@ def plot_heat_equation(m, approx_type):
|
|
148 |
"tableRotation",
|
149 |
"toImage",
|
150 |
"resetCameraDefault3d",
|
151 |
-
]
|
|
|
152 |
)
|
153 |
|
154 |
return fig
|
@@ -198,7 +177,6 @@ def plot_errors(m, approx_type):
|
|
198 |
|
199 |
# Layout for the Plotly plot without controls
|
200 |
layout = go.Layout(
|
201 |
-
title=f"Heat Equation Approximation Error | Kernel = {approx_type} | m = {m}",
|
202 |
scene=dict(
|
203 |
camera=dict(
|
204 |
eye=dict(x=0, y=-2, z=0), # Front view
|
@@ -210,28 +188,6 @@ def plot_errors(m, approx_type):
|
|
210 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
211 |
)
|
212 |
|
213 |
-
# Config to remove modebar buttons except the save image button
|
214 |
-
config = {
|
215 |
-
"modeBarButtonsToRemove": [
|
216 |
-
"pan",
|
217 |
-
"resetCameraLastSave",
|
218 |
-
"hoverClosest3d",
|
219 |
-
"hoverCompareCartesian",
|
220 |
-
"zoomIn",
|
221 |
-
"zoomOut",
|
222 |
-
"select2d",
|
223 |
-
"lasso2d",
|
224 |
-
"zoomIn2d",
|
225 |
-
"zoomOut2d",
|
226 |
-
"sendDataToCloud",
|
227 |
-
"zoom3d",
|
228 |
-
"orbitRotation",
|
229 |
-
"tableRotation",
|
230 |
-
],
|
231 |
-
"displayModeBar": True, # Keep the modebar visible
|
232 |
-
"displaylogo": False, # Hide the Plotly logo
|
233 |
-
}
|
234 |
-
|
235 |
# Create the figure
|
236 |
fig = go.Figure(data=traces, layout=layout)
|
237 |
|
@@ -253,7 +209,8 @@ def plot_errors(m, approx_type):
|
|
253 |
"tableRotation",
|
254 |
"toImage",
|
255 |
"resetCameraDefault3d",
|
256 |
-
]
|
|
|
257 |
)
|
258 |
|
259 |
return fig
|
@@ -340,7 +297,9 @@ def train_coefficients(m, kernel):
|
|
340 |
Phi = design_matrix(a_train, theta, kernel)
|
341 |
alpha = learn_coefficients(Phi, u_train)
|
342 |
# Validate and animate results
|
343 |
-
u_real = np.array(
|
|
|
|
|
344 |
a_test = np.c_[np.meshgrid(x, t)[0].ravel(), np.meshgrid(x, t)[1].ravel()]
|
345 |
u_approx = approximate_solution(a_test, alpha, theta, kernel).reshape(n_t, n_x)
|
346 |
|
@@ -415,7 +374,8 @@ def plot_function(k, a, b, c):
|
|
415 |
"tableRotation",
|
416 |
"toImage",
|
417 |
"resetCameraDefault3d",
|
418 |
-
]
|
|
|
419 |
)
|
420 |
|
421 |
return fig
|
@@ -434,36 +394,46 @@ def plot_all(m, kernel):
|
|
434 |
gr.update(visible=True, value=error_fig),
|
435 |
)
|
436 |
|
|
|
437 |
def change_quality(quality):
|
438 |
global n_x, n_t
|
439 |
-
|
440 |
if quality == "Low":
|
441 |
n_x, n_t = 10, 10
|
442 |
elif quality == "Mid":
|
443 |
n_x, n_t = 20, 20
|
444 |
elif quality == "High":
|
445 |
n_x, n_t = 40, 40
|
446 |
-
|
447 |
|
448 |
# Gradio interface
|
449 |
def create_gradio_ui():
|
450 |
global glob_k, glob_a, glob_b, glob_c
|
451 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
452 |
# Get the initial available files
|
453 |
with gr.Blocks() as demo:
|
454 |
gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
|
455 |
|
456 |
# Function parameter inputs
|
457 |
-
gr.Markdown(
|
458 |
-
|
459 |
-
## Function: $$u(x, t)\\coloneqq\\exp(-\\textcolor{magenta}{k}(\\textcolor{cyan}{a}\\pi)^2t)\\sin(\\textcolor{cyan}{a}\\pi x)+0.5\\exp(-\\textcolor{magenta}{k}(\\textcolor{lime}{b}\\pi)^2t)\\sin(\\textcolor{lime}{b}\\pi x)+0.25\\exp(-\\textcolor{magenta}{k}(\\textcolor{orange}{c}\\pi)^2t)\\sin(\\textcolor{orange}{c}\\pi x)$$
|
460 |
-
|
461 |
-
Adjust the values for <span style='color: magenta;'>k</span>, <span style='color: cyan;'>a</span>, <span style='color: lime;'>b</span> and <span style='color: orange;'>c</span> with the sliders below.
|
462 |
-
"""
|
463 |
-
)
|
464 |
|
465 |
with gr.Row():
|
466 |
-
with gr.Column():
|
467 |
k_slider = gr.Slider(
|
468 |
minimum=0.0001, maximum=0.1, step=0.0001, value=glob_k, label="k"
|
469 |
)
|
@@ -476,8 +446,8 @@ def create_gradio_ui():
|
|
476 |
c_slider = gr.Slider(
|
477 |
minimum=-10, maximum=10, step=0.1, value=glob_c, label="c"
|
478 |
)
|
479 |
-
|
480 |
-
|
481 |
|
482 |
k_slider.change(
|
483 |
fn=plot_function,
|
@@ -506,7 +476,9 @@ def create_gradio_ui():
|
|
506 |
quality_dropdown = gr.Dropdown(
|
507 |
label="Choose Quality", choices=["Low", "Mid", "High"], value="Low"
|
508 |
)
|
509 |
-
quality_dropdown.change(
|
|
|
|
|
510 |
kernel_dropdown = gr.Dropdown(
|
511 |
label="Choose Kernel", choices=["SINE", "GFF"], value="SINE"
|
512 |
)
|
@@ -530,9 +502,9 @@ def create_gradio_ui():
|
|
530 |
approx_button = gr.Button("Plot Approximation")
|
531 |
|
532 |
with gr.Row():
|
533 |
-
with gr.Column(
|
534 |
approx_plot = gr.Plot(visible=False)
|
535 |
-
with gr.Column(
|
536 |
error_plot = gr.Plot(visible=False)
|
537 |
|
538 |
approx_button.click(
|
|
|
13 |
|
14 |
n_x, n_t = 10, 10
|
15 |
|
16 |
+
|
17 |
def clear_npz():
|
18 |
current_directory = os.getcwd() # Get the current working directory
|
19 |
for filename in os.listdir(current_directory):
|
|
|
94 |
|
95 |
# Layout for the Plotly plot without controls
|
96 |
layout = go.Layout(
|
|
|
97 |
scene=dict(
|
98 |
camera=dict(
|
99 |
eye=dict(x=0, y=-2, z=0), # Front view
|
|
|
105 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
106 |
)
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
# Create the figure
|
109 |
fig = go.Figure(data=traces, layout=layout)
|
110 |
|
|
|
126 |
"tableRotation",
|
127 |
"toImage",
|
128 |
"resetCameraDefault3d",
|
129 |
+
],
|
130 |
+
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
|
131 |
)
|
132 |
|
133 |
return fig
|
|
|
177 |
|
178 |
# Layout for the Plotly plot without controls
|
179 |
layout = go.Layout(
|
|
|
180 |
scene=dict(
|
181 |
camera=dict(
|
182 |
eye=dict(x=0, y=-2, z=0), # Front view
|
|
|
188 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
189 |
)
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
# Create the figure
|
192 |
fig = go.Figure(data=traces, layout=layout)
|
193 |
|
|
|
209 |
"tableRotation",
|
210 |
"toImage",
|
211 |
"resetCameraDefault3d",
|
212 |
+
],
|
213 |
+
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
|
214 |
)
|
215 |
|
216 |
return fig
|
|
|
297 |
Phi = design_matrix(a_train, theta, kernel)
|
298 |
alpha = learn_coefficients(Phi, u_train)
|
299 |
# Validate and animate results
|
300 |
+
u_real = np.array(
|
301 |
+
[complex_heat_eq_solution(x, t_i, glob_k, glob_a, glob_b, glob_c) for t_i in t]
|
302 |
+
)
|
303 |
a_test = np.c_[np.meshgrid(x, t)[0].ravel(), np.meshgrid(x, t)[1].ravel()]
|
304 |
u_approx = approximate_solution(a_test, alpha, theta, kernel).reshape(n_t, n_x)
|
305 |
|
|
|
374 |
"tableRotation",
|
375 |
"toImage",
|
376 |
"resetCameraDefault3d",
|
377 |
+
],
|
378 |
+
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
|
379 |
)
|
380 |
|
381 |
return fig
|
|
|
394 |
gr.update(visible=True, value=error_fig),
|
395 |
)
|
396 |
|
397 |
+
|
398 |
def change_quality(quality):
|
399 |
global n_x, n_t
|
400 |
+
|
401 |
if quality == "Low":
|
402 |
n_x, n_t = 10, 10
|
403 |
elif quality == "Mid":
|
404 |
n_x, n_t = 20, 20
|
405 |
elif quality == "High":
|
406 |
n_x, n_t = 40, 40
|
407 |
+
|
408 |
|
409 |
# Gradio interface
|
410 |
def create_gradio_ui():
|
411 |
global glob_k, glob_a, glob_b, glob_c
|
412 |
|
413 |
+
markdown_content = r"""
|
414 |
+
## Function:
|
415 |
+
$$
|
416 |
+
\begin{alignat*}{5}
|
417 |
+
u(x, t)
|
418 |
+
\coloneqq &\exp(-\textcolor{magenta}{k}(&\textcolor{cyan}{a}&\pi)^2t)\sin(&\textcolor{cyan}{a}&\pi x) \\
|
419 |
+
+ &\exp(-\textcolor{magenta}{k}(&\textcolor{lime}{b}&\pi)^2t)\sin(&\textcolor{lime}{b}&\pi x) \\
|
420 |
+
+ &\exp(-\textcolor{magenta}{k}(&\textcolor{orange}{c}&\pi)^2t)\sin(&\textcolor{orange}{c}&\pi x)
|
421 |
+
\end{alignat*}
|
422 |
+
$$
|
423 |
+
|
424 |
+
Adjust the values for <span style='color: magenta;'>k</span>, <span style='color: cyan;'>a</span>, <span style='color: lime;'>b</span> and <span style='color: orange;'>c</span> with the sliders below.
|
425 |
+
"""
|
426 |
+
|
427 |
# Get the initial available files
|
428 |
with gr.Blocks() as demo:
|
429 |
gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
|
430 |
|
431 |
# Function parameter inputs
|
432 |
+
gr.Markdown(markdown_content)
|
433 |
+
|
|
|
|
|
|
|
|
|
|
|
434 |
|
435 |
with gr.Row():
|
436 |
+
with gr.Column(min_width=500):
|
437 |
k_slider = gr.Slider(
|
438 |
minimum=0.0001, maximum=0.1, step=0.0001, value=glob_k, label="k"
|
439 |
)
|
|
|
446 |
c_slider = gr.Slider(
|
447 |
minimum=-10, maximum=10, step=0.1, value=glob_c, label="c"
|
448 |
)
|
449 |
+
with gr.Column(min_width=500):
|
450 |
+
plot_output = gr.Plot()
|
451 |
|
452 |
k_slider.change(
|
453 |
fn=plot_function,
|
|
|
476 |
quality_dropdown = gr.Dropdown(
|
477 |
label="Choose Quality", choices=["Low", "Mid", "High"], value="Low"
|
478 |
)
|
479 |
+
quality_dropdown.change(
|
480 |
+
fn=change_quality, inputs=quality_dropdown, outputs=None
|
481 |
+
)
|
482 |
kernel_dropdown = gr.Dropdown(
|
483 |
label="Choose Kernel", choices=["SINE", "GFF"], value="SINE"
|
484 |
)
|
|
|
502 |
approx_button = gr.Button("Plot Approximation")
|
503 |
|
504 |
with gr.Row():
|
505 |
+
with gr.Column(min_width=500):
|
506 |
approx_plot = gr.Plot(visible=False)
|
507 |
+
with gr.Column(min_width=500):
|
508 |
error_plot = gr.Plot(visible=False)
|
509 |
|
510 |
approx_button.click(
|