alkzar90's picture
Fix typo
3ab53fa
raw history blame
No virus
1.23 kB
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
st.title('Fitting simple models with JAX')
st.header('A quadratic regression example')
st.markdown('**This is a simple text** to specify the goal of this simple data app.')
# Raw string: in a normal literal, '\b' in '\boldsymbol' is a backspace escape
# and '\s' / '\p' are invalid escapes, which would corrupt the LaTeX source.
st.latex(r'h(\boldsymbol x, \boldsymbol w)= \sum_{k=1}^{K}\boldsymbol w_{k} \phi_{k}(\boldsymbol x)')

# Sidebar controls for the size and noise level of the simulated dataset.
number_of_observations = st.sidebar.slider(
    'Number of observations', min_value=50, max_value=150, value=50
)
noise_standard_deviation = st.sidebar.slider(
    'Standard deviation of the noise', min_value=0.0, max_value=2.0, value=0.25
)
# Simulate a noisy quadratic: y = 3 - 20 x + 32 x**2 + Gaussian noise
# (+ rare outliers).
# A local RandomState seeded with 2 yields exactly the same draws, in the same
# order, as np.random.seed(2) followed by np.random.* calls, but without
# mutating NumPy's global random state.
rng = np.random.RandomState(2)
x = rng.random_sample(number_of_observations)
# Design matrix with columns [1, x, x**2].
X = np.column_stack((np.ones(number_of_observations), x, x ** 2))
w = np.array([3.0, -20.0, 32.0])  # true coefficients for [1, x, x**2]
# With probability 0.03 a point gets +8 added, producing occasional outliers.
additional_noise = 8 * rng.binomial(1, 0.03, size=number_of_observations)
y = (
    X @ w
    + noise_standard_deviation * rng.randn(number_of_observations)
    + additional_noise
)
# Scatter plot of the simulated observations on a fixed viewport.
fig, ax = plt.subplots(dpi=320)
ax.set_xlim((0, 1))
# NOTE(review): the true curve reaches 15 at x=1, so an outlier (+8) near x=1
# lands above 20 and is clipped out of view; confirm this limit is intended.
ax.set_ylim((-5, 20))
ax.scatter(X[:, 1], y, c='r', edgecolors='black')
st.pyplot(fig)
# Release the figure: a new one is created every time this code runs, and
# pyplot keeps a reference to each figure until it is explicitly closed.
plt.close(fig)

# Show the first five rows of the design matrix [1, x, x**2].
st.write(X[:5, :])