Zeel commited on
Commit
a5b88b1
1 Parent(s): bdc81a4

upload first version

Browse files
Files changed (2) hide show
  1. app.py +128 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import matplotlib.pyplot as plt
5
+
6
+
7
+ def parabola_fn(x):
8
+ return x**0.5
9
+
10
+
11
+ def circle_fn(x):
12
+ return (1 - x**2) ** 0.5
13
+
14
+
15
+ d_parabola_fn = jax.grad(parabola_fn)
16
+ d_circle_fn = jax.grad(circle_fn)
17
+
18
+
19
+ def loss_fn(params):
20
+ x1 = params["x1"]
21
+ x2 = params["x2"]
22
+ # parpendicular line to the tangent of the parabola: y = m1 * x + c1
23
+ m1 = -1 / d_parabola_fn(x1)
24
+ c1 = parabola_fn(x1) - m1 * x1
25
+
26
+ def perpendicular_parabola_fn(x):
27
+ return m1 * x + c1
28
+
29
+ # parpendicular line to the tangent of the circle: y = m2 * x + c2
30
+ m2 = -1 / d_circle_fn(x2)
31
+ c2 = circle_fn(x2) - m2 * x2
32
+
33
+ def perpendicular_circle_fn(x):
34
+ return m2 * x + c2
35
+
36
+ # x_star and y_star are the intersection of the two lines
37
+ x_star = (c2 - c1) / (m1 - m2)
38
+ y_star = m1 * x_star + c1
39
+
40
+ # three quantities should be equal to each other
41
+ # 1. distance between intersection and parabola
42
+ # 2. distance between intersection and circle
43
+ # 3. distance between intersection and x=0 line
44
+
45
+ d1 = (x_star - x1) ** 2 + (y_star - parabola_fn(x1)) ** 2
46
+ d2 = (x_star - x2) ** 2 + (y_star - circle_fn(x2)) ** 2
47
+ d3 = x_star**2
48
+
49
+ aux = {
50
+ "x_star": x_star,
51
+ "y_star": y_star,
52
+ "perpendicular_parabola_fn": perpendicular_parabola_fn,
53
+ "perpendicular_circle_fn": perpendicular_circle_fn,
54
+ "r": d1**0.5,
55
+ }
56
+ # final loss
57
+ loss = (d1 - d2) ** 2 + (d1 - d3) ** 2 + (d2 - d3) ** 2
58
+
59
+ return loss, aux
60
+
61
+
62
+ x = jnp.linspace(0, 1, 100)
63
+
64
+ st.title("Radius of the Circle: Optimization Playground")
65
+
66
+ col1, col2 = st.columns(2)
67
+
68
+ x1 = col1.slider("initial x1 (x intersection with parabola)", 0.0, 1.0, 0.5)
69
+ x2 = col1.slider("initial x2 (x intersection with the circle)", 0.0, 1.0, 0.5)
70
+ n_epochs = col2.slider("n_epochs", 0, 1000, 50)
71
+ lr = col2.slider("lr", 0.0, 1.0, value=0.1, step=0.01)
72
+
73
+ # submit button
74
+ submit = st.button("submit")
75
+
76
+ # when submit button is clicked run the following code
77
+ params = {"x1": x1, "x2": x2}
78
+ losses = []
79
+ value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
80
+
81
+ # initialize plot
82
+ fig, axes = plt.subplots(1, 2, figsize=(15, 5))
83
+ axes[0].set_xlim(0, 1)
84
+ axes[0].set_ylim(0, 1)
85
+
86
+ value, aux = loss_fn(params)
87
+ (pbola_plot,) = axes[0].plot(x, parabola_fn(x), color="red")
88
+ (pbola_perpendicular_plot,) = axes[0].plot(x, aux["perpendicular_parabola_fn"](x), color="red", linestyle="--")
89
+ (cicle_plot,) = axes[0].plot(x, circle_fn(x), color="blue")
90
+ (circle_perpendicular_plot,) = axes[0].plot(x, aux["perpendicular_circle_fn"](x), color="blue", linestyle="--")
91
+ x_star, y_star = aux["x_star"], aux["y_star"]
92
+ radius = aux["r"]
93
+ axes[0].add_patch(plt.Circle((x_star, y_star), radius, fill=False))
94
+
95
+ axes[1].set_xlim(0, n_epochs)
96
+ axes[1].set_ylim(0, value)
97
+ (loss_plot,) = axes[1].plot(losses, color="black")
98
+
99
+ pbar = st.progress(0)
100
+
101
+ with st.empty():
102
+ st.pyplot(fig)
103
+ if submit:
104
+ for i in range(n_epochs):
105
+ (value, _), grad = value_and_grad_fn(params)
106
+
107
+ params["x1"] -= lr * grad["x1"]
108
+ params["x2"] -= lr * grad["x2"]
109
+ losses.append(value)
110
+
111
+ _, aux = loss_fn(params)
112
+ print(params, grad, lr)
113
+ pbola_plot.set_data(x, parabola_fn(x))
114
+ pbola_perpendicular_plot.set_data(x, aux["perpendicular_parabola_fn"](x))
115
+ cicle_plot.set_data(x, circle_fn(x))
116
+ circle_perpendicular_plot.set_data(x, aux["perpendicular_circle_fn"](x))
117
+ x_star, y_star = aux["x_star"], aux["y_star"]
118
+ radius = aux["r"]
119
+ axes[0].add_patch(plt.Circle((x_star, y_star), radius, fill=False))
120
+
121
+ loss_plot.set_data(range(len(losses)), losses)
122
+
123
+ pbar.progress(i / n_epochs)
124
+
125
+ axes[0].set_title(f"x1: {params['x1']:.3f}, x2: {params['x2']:.3f} \n r: {radius:.4f}")
126
+ axes[1].set_title(f"epoch: {i}, loss: {value:.5f}")
127
+
128
+ st.pyplot(fig)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ matplotlib
3
+ numpy
4
+ jaxlib
5
+ jax