"""Streamlit app: fit a quadratic regression model with JAX gradient descent.

The sidebar controls the number of observations, the noise level and the cost
function. The app generates synthetic quadratic data with a few outliers, fits
a quadratic model by plain gradient descent and plots the intermediate and
final fits on top of the data.
"""
import streamlit as st
import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

# Set random key
# NOTE(review): `key` is not used below; the data and the initial weights are
# drawn from NumPy's global RNG, which is seeded further down.
seed = 321
key = jax.random.PRNGKey(seed)

st.title('Fitting simple models with JAX')
st.header('A quadratic regression example')
st.markdown('*\"Parametrised models are simply functions that depend on inputs and trainable parameters. There is no fundamental difference between the two, except that trainable parameters are shared across training samples whereas the input varies from sample to sample.\"* [(Yann LeCun, Deep learning course)](https://atcold.github.io/pytorch-Deep-Learning/en/week02/02-1/#Parametrised-models)')
st.latex(r'''h(\boldsymbol x, \boldsymbol w)= \sum_{k=1}^{K}\boldsymbol w_{k} \phi_{k}(\boldsymbol x)''')

# Sidebar inputs
number_of_observations = st.sidebar.slider('Number of observations', min_value=50, max_value=150, value=100)
noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise', min_value=0.0, max_value=2.0, value=1.0)
cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))

# Generate random data: y = 3 - 20 x + 32 x^2 + Gaussian noise + sparse outliers.
# The order of the np.random calls below fixes the generated data; keep it.
np.random.seed(2)
w = jnp.array([3.0, -20.0, 32.0])  # true coefficients for the columns [1, x, x**2]
X = np.column_stack((np.ones(number_of_observations), np.random.random(number_of_observations)))
X = jnp.column_stack((X, X[:, 1] ** 2))  # add x**2 column
# Each sample independently gets a +8 offset with probability 0.03 (outliers).
additional_noise = 8 * np.random.binomial(1, 0.03, size=number_of_observations)
y = jnp.array(np.dot(X, w)
              + noise_standard_deviation * np.random.randn(number_of_observations)
              + additional_noise)


def _scatter_data(ax):
    """Draw the observations on *ax* using the axis limits shared by both plots."""
    ax.set_xlim((0, 1))
    ax.set_ylim((-5, 26))
    ax.scatter(X[:, 1], y, c='#e76254', edgecolors='firebrick')


def _plot_fit(ax, w, fmt, **kwargs):
    """Draw the model prediction X @ w on *ax* as a curve over x.

    Points are ordered by x, not by the prediction. The fitted quadratic is
    not monotonic on [0, 1], so ordering by prediction gives a jagged line.
    """
    order = jnp.argsort(X[:, 1])
    ax.plot(X[order, 1], jnp.dot(X, w)[order], fmt, **kwargs)


def _fit_by_gradient_descent(loss, w, ax, learning_rate, num_iter):
    """Minimise *loss* by plain gradient descent starting from weights *w*.

    Shows a progress bar and status text in the app. Every 100 iterations it
    reports the current loss (in the app and on stdout) and draws the current
    fit on *ax* as a dashed line. The final fit is drawn with the label
    'Final line'. Returns the final weights.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    # jit compiles the gradient once; X and y are captured as constants.
    grad_loss = jax.jit(jax.grad(loss))
    for i in range(num_iter):
        # st.progress takes an int percentage in [0, 100]; passing i + 1
        # directly overflowed (and raised) after 100 iterations.
        progress_bar.progress(100 * (i + 1) // num_iter)
        # Update parameters.
        w = w - learning_rate * grad_loss(w)
        if (i + 1) % 100 == 0:
            current_loss = loss(w)
            status_text.text('Trained loss is: %s' % current_loss)
            # add a line to the plot
            _plot_fit(ax, w, '--')
            print(f"Trained loss at epoch #{i}:", current_loss)
    _plot_fit(ax, w, '--', label='Final line')
    status_text.text('Done!')
    return w


st.subheader('Train a model')
st.markdown('*\"A Gradient Based Method is a method/algorithm that finds the minima of a function, assuming that one can easily compute the gradient of that function. '
            'It assumes that the function is continuous and differentiable almost everywhere (it need not be differentiable everywhere).\"* [(Yann LeCun, Deep learning course)](https://atcold.github.io/pytorch-Deep-Learning/en/week02/02-1/#Parametrised-models)')
st.markdown('Using gradient descent we find the minima of the loss adjusting the weights in each step given the following formula:')
st.latex(r'''\bf{w}\leftarrow \bf{w}-\eta \frac{\partial\ell(\bf{X},\bf{y}, \bf{w})}{\partial \bf{w}}''')

# Plot the data
fig, ax = plt.subplots(dpi=320)
_scatter_data(ax)
st.pyplot(fig)

# Fitting by the respective cost_function
w = jnp.array(np.random.random(3))  # random initial weights (NumPy RNG, seeded above)
learning_rate = 0.05
NUM_ITER = 1000
HUBER_DELTA = 1.0  # threshold between the quadratic and linear regimes of the Huber loss

fig, ax = plt.subplots(dpi=120)
_scatter_data(ax)

if cost_function == 'RMSE-Loss':
    # NOTE: despite the 'RMSE' label, this is the mean squared error (no square
    # root), which matches the formulas displayed below.
    def loss(w):
        """Mean squared residual of the linear model X @ w against y."""
        residual = jnp.dot(X, w) - y
        # Same value as norm(residual)**2 / m, but the gradient stays finite at
        # a zero residual (the gradient of norm is NaN there).
        return jnp.mean(residual ** 2)

    st.write('You selected the RMSE loss function.')
    st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
    st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
    st.latex(r'''\ell(X, y, w)= \frac{1}{m}\sum_1^m (\hat{y}_i - y_i)^2''')
else:
    def loss(w):
        """Mean Huber loss (the form displayed in the app) of X @ w against y."""
        residual = jnp.dot(X, w) - y
        abs_residual = jnp.abs(residual)
        per_sample = jnp.where(abs_residual <= HUBER_DELTA,
                               residual ** 2,
                               2 * HUBER_DELTA * abs_residual - HUBER_DELTA ** 2)
        return jnp.mean(per_sample)

    st.write("You selected the Huber loss function.")
    st.latex(r'''
    \ell_{H} =
    \begin{cases}
    (y^{(i)}-\hat{y}^{(i)})^2 & \text{for }\quad |y^{(i)}-\hat{y}^{(i)}|\leq \delta \\
    2\delta|y^{(i)}-\hat{y}^{(i)}| - \delta^2 & \text{otherwise}
    \end{cases}''')
    st.markdown(rf'The per-sample losses are averaged over the $m$ observations, using $\delta = {HUBER_DELTA}$.')

w = _fit_by_gradient_descent(loss, w, ax, learning_rate, NUM_ITER)
ax.legend()
st.pyplot(fig)

st.markdown('The training loop:')
code = '''NUM_ITER = 1000
# initialize parameters
w = np.array([3., -2., -8.])
for i in range(NUM_ITER):
    # update parameters
    w -= learning_rate * grad_loss(w)'''
st.code(code, language='python')