import streamlit as st
import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

# Set the JAX random key (the synthetic data below is generated with NumPy's RNG)
seed = 321
key = jax.random.PRNGKey(seed)

st.title('Fitting simple models with JAX')
st.header('A quadratic regression example')

st.markdown('*\"Parametrised models are simply functions that depend on inputs and trainable parameters. There is no fundamental difference between the two, except that trainable parameters are shared across training samples whereas the input varies from sample to sample.\"* [(Yann LeCun, Deep learning course)](https://atcold.github.io/pytorch-Deep-Learning/en/week02/02-1/#Parametrised-models)')

st.latex(r'''h(\boldsymbol x, \boldsymbol w)= \sum_{k=1}^{K}\boldsymbol w_{k} \phi_{k}(\boldsymbol x)''')
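# For the quadratic example below, the basis functions are phi_1(x) = 1,
# phi_2(x) = x and phi_3(x) = x**2, so h(x, w) = w_1 + w_2*x + w_3*x**2;
# the design matrix X built further down stacks exactly these three columns.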


# Sidebar inputs
number_of_observations = st.sidebar.slider('Number of observations', min_value=50, max_value=150, value=100)
noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise', min_value=0.0, max_value=2.0, value=1.0)
cost_function = st.sidebar.radio('Which cost function do you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))

# Generate random data
np.random.seed(2)

w = jnp.array([3.0, -20.0, 32.0])  # true coefficients: intercept, linear and quadratic terms
X = np.column_stack((np.ones(number_of_observations), 
                     np.random.random(number_of_observations)))
X = jnp.column_stack((X, X[:,1] ** 2))   # add x**2 column

additional_noise = 8 * np.random.binomial(1, 0.03, size=number_of_observations)  # sparse outliers: ~3% of points shifted up by 8
y = jnp.array(np.dot(X, w) + noise_standard_deviation * np.random.randn(number_of_observations) \
        + additional_noise)
        
# Plot the data
fig, ax = plt.subplots(dpi=320)
ax.set_xlim((0,1))
ax.set_ylim((-5,26))
ax.scatter(X[:, 1], y, c='#e76254', edgecolors='firebrick')

st.pyplot(fig)

st.subheader('Train a model')

st.markdown('*\"A Gradient Based Method is a method/algorithm that finds the minima of a function, assuming that one can easily compute the gradient of that function. It assumes that the function is continuous and differentiable almost everywhere (it need not be differentiable everywhere).\"* [(Yann LeCun, Deep learning course)](https://atcold.github.io/pytorch-Deep-Learning/en/week02/02-1/#Parametrised-models)')

st.markdown('Using gradient descent we find the minima of the loss adjusting the weights in each step given the following formula:')

st.latex(r'''\bf{w}\leftarrow \bf{w}-\eta \frac{\partial\ell(\bf{X},\bf{y}, \bf{w})}{\partial \bf{w}}''')

st.markdown('The training loop:')

code = '''NUM_ITER = 1000
# initialize parameters
w = np.array([3., -2., -8.])
for i in range(NUM_ITER):
    # update parameters
    w -= learning_rate * grad_loss(w)'''
     
st.code(code, language='python')

# Fit with the selected cost function
w = jnp.array(np.random.random(3))  # random initial weights
learning_rate = 0.05
NUM_ITER = 1000

if cost_function == 'RMSE-Loss':

     def loss(w):
         return 1 / X.shape[0] * jnp.linalg.norm(jnp.dot(X, w) - y) ** 2
         
     st.write('You selected the RMSE loss function.')
     st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
     st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
     st.latex(r'''\ell(X, y, w)= \frac{1}{m}\sum_1^m (\hat{y}_i - y_i)^2''')
     
     
     progress_bar = st.progress(0)
     status_text = st.empty()
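     # jax.grad(loss) returns a function that evaluates d(loss)/dw by automatic
     # differentiation; for this mean-squared-error loss the closed form would be
     # (2/m) * X.T @ (X @ w - y), shown here only for reference.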
     grad_loss = jax.grad(loss)
    
     # Perform gradient descent
     progress_counter = 0
     for i in range(1, NUM_ITER + 1):
         if i % 10 == 0:
             # Update progress bar.
             progress_counter += 1
             progress_bar.progress(progress_counter)

         # Update parameters.
         w -= learning_rate * grad_loss(w)

         # Update status text.
         if i % 100 == 0:
             # Report the loss at the current iteration.
             status_text.text(
                 'Training loss at iteration %s is %s' % (i, loss(w)))
     # Plot the fitted curve over the data
     fig, ax = plt.subplots(dpi=120)
     ax.set_xlim((0,1))
     ax.set_ylim((-5,26))
     ax.scatter(X[:, 1], y, c='#e76254', edgecolors='firebrick')
     order = X[:, 1].argsort()  # sort by x so the curve is drawn left to right
     ax.plot(X[order, 1], jnp.dot(X, w)[order], 'k-', label='Fitted curve')
     st.pyplot(fig)
     status_text.text('Done!')


else:
     st.write("You selected the Huber loss function.")
     st.latex(r'''
\ell_{H} = 
\begin{cases} 
      (y^{(i)}-\hat{y}^{(i)})^2 & \text{for }\quad |y^{(i)}-\hat{y}^{(i)}|\leq \delta \\
      2\delta|y^{(i)}-\hat{y}^{(i)}| - \delta^2 & \text{otherwise}
\end{cases}''')
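
     # A minimal sketch of a matching Huber loss in JAX (not the original script's
     # implementation); the threshold delta = 1.0 and the name huber_loss are
     # illustrative assumptions.
     def huber_loss(w, delta=1.0):
         residual = jnp.dot(X, w) - y
         abs_residual = jnp.abs(residual)
         # Quadratic penalty for small residuals, linear penalty for outliers,
         # mirroring the piecewise definition rendered above.
         return jnp.mean(jnp.where(abs_residual <= delta,
                                   residual ** 2,
                                   2 * delta * abs_residual - delta ** 2))
     # This could be minimised exactly like the RMSE-Loss branch, e.g. with
     # w -= learning_rate * jax.grad(huber_loss)(w) inside the same kind of loop.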