alkzar90 commited on
Commit
d119a57
1 Parent(s): 2667960

Update data generation

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import streamlit as st
 
2
  import jax.numpy as jnp
3
  import jax
4
  import matplotlib.pyplot as plt
@@ -21,15 +22,14 @@ noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise',
21
  cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
22
 
23
  # Generate random data
24
- X = jnp.column_stack((jnp.ones(number_of_observations),
25
- jax.random.uniform(key, shape=(number_of_observations,), minval=0., maxval=1.)))
26
- w = jnp.array([3.0, -20.0, 32.0]) # coefficients
27
  X = jnp.column_stack((X, X[:,1] ** 2)) # add x**2 column
28
- additional_noise = 8 * jax.random.bernoulli(key, p=0.08, shape=[number_of_observations,])
29
- y = jnp.dot(X, w) + noise_standard_deviation * jax.random.normal(key, shape=[number_of_observations,]) \
30
- + additional_noise
31
-
32
 
 
 
 
 
33
  # Plot the data
34
  fig, ax = plt.subplots(dpi=320)
35
  ax.set_xlim((0,1))
 
1
  import streamlit as st
2
+ import numpy as np
3
  import jax.numpy as jnp
4
  import jax
5
  import matplotlib.pyplot as plt
 
22
  cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
23
 
24
  # Generate random data
25
+ X = np.column_stack((np.ones(number_of_observations),
26
+ np.random.random(number_of_observations)))
 
27
  X = jnp.column_stack((X, X[:,1] ** 2)) # add x**2 column
 
 
 
 
28
 
29
+ additional_noise = 8 * np.random.binomial(1, 0.03, size = number_of_observations)
30
+ y = jnp.array(np.dot(X, w) + noise_standard_deviation * np.random.randn(number_of_observations) \
31
+ + additional_noise)
32
+
33
  # Plot the data
34
  fig, ax = plt.subplots(dpi=320)
35
  ax.set_xlim((0,1))