Spaces:
Runtime error
Runtime error
Update data generation
Browse files
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import streamlit as st
|
|
|
2 |
import jax.numpy as jnp
|
3 |
import jax
|
4 |
import matplotlib.pyplot as plt
|
@@ -21,15 +22,14 @@ noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise',
|
|
21 |
cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
|
22 |
|
23 |
# Generate random data
|
24 |
-
X =
|
25 |
-
|
26 |
-
w = jnp.array([3.0, -20.0, 32.0]) # coefficients
|
27 |
X = jnp.column_stack((X, X[:,1] ** 2)) # add x**2 column
|
28 |
-
additional_noise = 8 * jax.random.bernoulli(key, p=0.08, shape=[number_of_observations,])
|
29 |
-
y = jnp.dot(X, w) + noise_standard_deviation * jax.random.normal(key, shape=[number_of_observations,]) \
|
30 |
-
+ additional_noise
|
31 |
-
|
32 |
|
|
|
|
|
|
|
|
|
33 |
# Plot the data
|
34 |
fig, ax = plt.subplots(dpi=320)
|
35 |
ax.set_xlim((0,1))
|
|
|
1 |
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
import jax.numpy as jnp
|
4 |
import jax
|
5 |
import matplotlib.pyplot as plt
|
|
|
22 |
cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
|
23 |
|
24 |
# Generate random data
|
25 |
+
X = np.column_stack((np.ones(number_of_observations),
|
26 |
+
np.random.random(number_of_observations)))
|
|
|
27 |
X = jnp.column_stack((X, X[:,1] ** 2)) # add x**2 column
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
additional_noise = 8 * np.random.binomial(1, 0.03, size = number_of_observations)
|
30 |
+
y = jnp.array(np.dot(X, w) + noise_standard_deviation * np.random.randn(number_of_observations) \
|
31 |
+
+ additional_noise)
|
32 |
+
|
33 |
# Plot the data
|
34 |
fig, ax = plt.subplots(dpi=320)
|
35 |
ax.set_xlim((0,1))
|