alkzar90 commited on
Commit
cc456a2
1 Parent(s): e49b76e

Change numpy code by jax.numpy

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -1,7 +1,12 @@
1
  import streamlit as st
2
- import numpy as np
 
3
  import matplotlib.pyplot as plt
4
 
 
 
 
 
5
  st.title('Fitting simple models with JAX')
6
  st.header('A quadratric regression example')
7
 
@@ -15,19 +20,18 @@ number_of_observations = st.sidebar.slider('Number of observations', min_value=5
15
  noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise', min_value = 0.0, max_value=2.0, value=1.0)
16
  cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
17
 
18
- np.random.seed(2)
19
-
20
  X = np.column_stack((np.ones(number_of_observations),
21
  np.random.random(number_of_observations)))
22
 
23
  w = np.array([3.0, -20.0, 32.0]) # coefficients
24
-
25
  X = np.column_stack((X, X[:,1] ** 2)) # add x**2 column
26
- additional_noise = 8 * np.random.binomial(1, 0.03, size = number_of_observations)
27
- y = np.dot(X, w) + noise_standard_deviation * np.random.randn(number_of_observations) \
28
  + additional_noise
29
 
30
 
 
31
  fig, ax = plt.subplots(dpi=320)
32
  ax.set_xlim((0,1))
33
  ax.set_ylim((-5,26))
@@ -46,6 +50,10 @@ st.latex(r'''\bf{w}\leftarrow \bf{w}-\eta \frac{\partial\ell(\bf{X},\bf{y}, \bf{
46
 
47
  # Fitting by the respective cost_function
48
  if cost_function == 'RMSE-Loss':
 
 
 
 
49
  st.write('You selected the RMSE loss function.')
50
  st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
51
  st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
 
1
  import streamlit as st
2
+ import jax.numpy as jnp
3
+ import jax
4
  import matplotlib.pyplot as plt
5
 
6
+ # Set random key
7
+ seed=321
8
+ key = jax.random.PRNGKey(seed)
9
+
10
  st.title('Fitting simple models with JAX')
11
  st.header('A quadratric regression example')
12
 
 
20
  noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise', min_value = 0.0, max_value=2.0, value=1.0)
21
  cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
22
 
23
+ # Generate random data
 
24
  X = np.column_stack((np.ones(number_of_observations),
25
  np.random.random(number_of_observations)))
26
 
27
  w = np.array([3.0, -20.0, 32.0]) # coefficients
 
28
  X = np.column_stack((X, X[:,1] ** 2)) # add x**2 column
29
+ additional_noise = 8 * jax.random.bernoulli(key, p=0.08, shape=[number_of_observations,])
30
+ y = jnp.dot(X, w) + noise_standard_deviation * jax.random.normal(key, shape=[number_of_observations,]) \
31
  + additional_noise
32
 
33
 
34
+ # Plot the data
35
  fig, ax = plt.subplots(dpi=320)
36
  ax.set_xlim((0,1))
37
  ax.set_ylim((-5,26))
 
50
 
51
  # Fitting by the respective cost_function
52
  if cost_function == 'RMSE-Loss':
53
+
54
+ def loss(w):
55
+ return 1/X.shape[0] * jax.numpy.linalg.norm(jnp.dot(X, w) - y)**2
56
+
57
  st.write('You selected the RMSE loss function.')
58
  st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
59
  st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')