alkzar90 commited on
Commit
0833d64
1 Parent(s): d975533

Fix jax.numpy typos errors

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -21,9 +21,8 @@ noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise',
21
  cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
22
 
23
  # Generate random data
24
- X = np.column_stack((np.ones(number_of_observations),
25
- np.random.random(number_of_observations)))
26
-
27
  w = np.array([3.0, -20.0, 32.0]) # coefficients
28
  X = np.column_stack((X, X[:,1] ** 2)) # add x**2 column
29
  additional_noise = 8 * jax.random.bernoulli(key, p=0.08, shape=[number_of_observations,])
 
21
  cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
22
 
23
  # Generate random data
24
+ X = np.column_stack((jnp.ones(number_of_observations),
25
+ jax.random.uniform(key, shape=(number_of_observations,), minval=0., maxval=1.))
 
26
  w = np.array([3.0, -20.0, 32.0]) # coefficients
27
  X = np.column_stack((X, X[:,1] ** 2)) # add x**2 column
28
  additional_noise = 8 * jax.random.bernoulli(key, p=0.08, shape=[number_of_observations,])