Spaces:
Runtime error
Runtime error
File size: 1,766 Bytes
a3fdf79 1bdb2db c5bcab5 1bdb2db e007ae4 1bdb2db 654009f e007ae4 a3fdf79 fbe9715 a3fdf79 d3a71f8 a3fdf79 19c2016 e007ae4 19c2016 a3fdf79 3272bb6 d3a71f8 c5bcab5 d3a71f8 3272bb6 d3a71f8 a3fdf79 04125fd b5bb6ce 04125fd c5bcab5 017e8c0 04125fd 3181393 d3a71f8 e6273c9 a3fdf79 fbe9715 1bdb2db d3a71f8 fbe9715 a3fdf79 fbe9715 a3fdf79 d734b3a d3a71f8 d734b3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import streamlit as st
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error
from sklearn.neighbors import KNeighborsRegressor
# Streamlit demo: fit a K-nearest-neighbour regressor on a small synthetic
# dataset and show its test-set predictions and mean squared error.
st.subheader("K nearest neighbor (KNN) Regressor")
st_col = st.columns(1)[0]

# Range of K offered by the slider; also the range searched for the best K below.
K_MIN, K_MAX = 1, 10
K = st.slider(
    "Number of nearest neighbors (K)", min_value=K_MIN, max_value=K_MAX, value=5, step=1
)

# Maps each selectbox label to the scikit-learn metric name it stands for.
METRICS = {
    "L1(Manhattan)": "manhattan",
    "L2(Euclidean Distance)": "euclidean",
}
option = st.selectbox("Select Distance Metric", tuple(METRICS))
metric = METRICS[option]

# Synthetic 1-D regression problem with a fixed seed, so every rerun of the
# app sees the same data.
# NOTE: with a single feature, Manhattan and Euclidean distance are both
# |x - x'|, so the two metric options produce identical predictions here.
X, y = make_regression(n_samples=100, n_features=1, noise=0.3, random_state=42)
ntrain = 30
x_train, y_train = X[:ntrain], y[:ntrain]
x_test, y_test = X[ntrain:], y[ntrain:]

knn = KNeighborsRegressor(n_neighbors=K, metric=metric)
knn.fit(x_train, y_train)
y_pred = knn.predict(x_test)

n_show = 30  # number of test points drawn in the plot
# Use an explicit Figure instead of pyplot's global state; st.pyplot expects a Figure.
fig, ax = plt.subplots()
# Plot against the actual feature value so the "X" axis label is accurate.
x_show = x_test[:n_show, 0]
ax.plot(x_show, y_test[:n_show], "C0s", label="True Points (Test)")
ax.plot(x_show, y_pred[:n_show], "C1*", label="Predictions (Test)")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend(loc="upper left")
# Fixed limits keep the y-axis from rescaling as the user changes K or the metric.
ax.set_ylim(-90, 90)
sns.despine(ax=ax, right=True, top=True)
with st_col:
    st.pyplot(fig)
# Streamlit reruns the whole script on every widget change; close the figure
# so pyplot does not keep one open figure per rerun.
plt.close(fig)

error = mean_squared_error(y_test, y_pred)
st.write(f"The mean squared error is {error:.2f}")

# Hide Streamlit's menu and footer; st.subheader renders as an <h3>, so the
# last rule centres the title.
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
h3 {text-align: center;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)


def _test_mse(k):
    """Return the test-set MSE of a KNN regressor using ``k`` neighbours and the selected metric."""
    model = KNeighborsRegressor(n_neighbors=k, metric=metric).fit(x_train, y_train)
    return mean_squared_error(y_test, model.predict(x_test))


# Derive the "optimal K" shown below from the data rather than hard-coding it;
# min() returns the smallest K on ties.
best_k = min(range(K_MIN, K_MAX + 1), key=_test_mse)

st.markdown(
    f"""
The above plot shows the True values and Predictions for {n_show} points in the test set.
It can be observed that the optimal value of K (lowest test MSE for K = {K_MIN} to {K_MAX}) is {best_k} for the given dataset.
""",
    unsafe_allow_html=True,
)
|