# Hugging Face Spaces app (Space status when this source was captured: Sleeping)
import json | |
import os | |
import joblib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import streamlit as st | |
from matplotlib.colors import ListedColormap | |
from tensorflow.keras.models import load_model | |
from utils.utils import input_predict | |
# Directory holding the trained artifacts. It is anchored to this script's own
# location rather than the process CWD. That is the same directory Streamlit
# puts on sys.path, which is what makes `from utils.utils import ...` resolve.
# Anchoring here means the app also starts when launched as
# `streamlit run /some/other/dir/app.py`.
directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "utils")
model_1_path = os.path.join(directory, "nnmodel_117d.keras")
model_2_path = os.path.join(directory, "nnmodel_2d.keras")
# NOTE: historical name ("adim") kept for compatibility; points at the 2-D PCA pickle.
pca_adim_path = os.path.join(directory, "pca_2d.pkl")
json_file_path = os.path.join(directory, "columns.json")
mushroom_attributes_path = os.path.join(directory, "mushroom_attributes.json")
prediction_grid_path = os.path.join(directory, "pred_grid.npz")


@st.cache_resource
def _load_artifacts(model_1_file, model_2_file, pca_file, columns_file,
                    attributes_file, grid_file):
    """Load every model/data artifact once per Streamlit server process.

    Streamlit re-executes this whole script on each widget interaction.
    Without caching, both Keras models, the PCA pickle and the JSON/NPZ files
    would be re-read from disk on every click. The cached objects are shared
    read-only across reruns and sessions.

    Args:
        model_1_file: path to the first Keras model (``nnmodel_117d.keras``).
        model_2_file: path to the second Keras model (``nnmodel_2d.keras``).
        pca_file: path to the joblib-pickled PCA object.
        columns_file: path to ``columns.json``.
        attributes_file: path to ``mushroom_attributes.json``.
        grid_file: path to the ``.npz`` file holding ``xx``, ``yy`` and ``Z``.

    Returns:
        Tuple ``(model_1, model_2, pca, df_columns, loaded_dicts, xx, yy, Z)``.
    """
    first_model = load_model(model_1_file)
    second_model = load_model(model_2_file)
    # joblib.load unpickles, so only ever point it at trusted, bundled files.
    pca = joblib.load(pca_file)
    with open(columns_file, 'r', encoding='utf-8') as fh:
        columns = json.load(fh)
    with open(attributes_file, 'r', encoding='utf-8') as fh:
        attributes = json.load(fh)
    # Index the NpzFile while it is still open so the arrays are read into
    # memory before the `with` closes the underlying file.
    with np.load(grid_file) as grid:
        return (first_model, second_model, pca, columns, attributes,
                grid['xx'], grid['yy'], grid['Z'])


(model_1, model_2, pca_2d_loaded, df_columns, loaded_dicts,
 xx, yy, Z) = _load_artifacts(model_1_path, model_2_path, pca_adim_path,
                              json_file_path, mushroom_attributes_path,
                              prediction_grid_path)
# Default value pre-selected for each attribute's selectbox, in the order the
# widgets are laid out. Each value is only used as the default if it appears
# among that attribute's options; otherwise the first option is selected.
example_input = dict([
    # cap
    ('cap-shape', 'convex'),
    ('cap-surface', 'smooth'),
    ('cap-color', 'white'),
    ('bruises', 'bruises'),
    ('odor', 'pungent'),
    # gills
    ('gill-attachment', 'free'),
    ('gill-spacing', 'close'),
    ('gill-size', 'narrow'),
    ('gill-color', 'black'),
    # stalk
    ('stalk-shape', 'enlarging'),
    ('stalk-root', 'equal'),
    ('stalk-surface-above-ring', 'smooth'),
    ('stalk-surface-below-ring', 'smooth'),
    ('stalk-color-above-ring', 'white'),
    ('stalk-color-below-ring', 'white'),
    # veil / ring
    ('veil-type', 'partial'),
    ('veil-color', 'white'),
    ('ring-number', 'one'),
    ('ring-type', 'pendant'),
    # spores / surroundings
    ('spore-print-color', 'black'),
    ('population', 'several'),
    ('habitat', 'grasses'),
])
st.title('Mushroom Classification')

# Render one selectbox per attribute, laid out in rows of N_INPUT_COLUMNS.
# The chosen value of each widget is collected in `input_data`, keyed by the
# hyphenated attribute name and in `example_input` order.
N_INPUT_COLUMNS = 4
input_data = {}
keys = list(example_input.keys())
for i in range(0, len(keys), N_INPUT_COLUMNS):
    cols = st.columns(N_INPUT_COLUMNS)
    for col, key in zip(cols, keys[i:i + N_INPUT_COLUMNS]):
        # Options are looked up under '<attribute_with_underscores>_dict' in
        # mushroom_attributes.json; only that dict's values are offered.
        # A missing entry yields an empty option list.
        options = list(loaded_dicts.get(key.replace('-', '_') + '_dict', {}).values())
        default_value = example_input[key]
        input_data[key] = col.selectbox(
            # Keys are hyphenated ('cap-shape'), so split on '-' (not '_') to
            # get a readable label such as 'Cap Shape'.
            label=key.replace('-', ' ').title(),
            options=options,
            index=options.index(default_value) if default_value in options else 0,
        )
# Fill colours for the two decision-region bands drawn by contourf. The lower
# band of Z takes the first colour and the upper band the second. The plot
# legend reuses these same colours for its edible / poisonous entries.
boundary_color1, boundary_color2 = 'palegreen', 'lightcoral'
custom_cmap = ListedColormap([boundary_color1, boundary_color2])
submitted = st.button('Submit')
if submitted:
    # input_predict lives in utils.utils (not visible here). Judging from how
    # its result is used below:
    #   [0]    is an array indexed [:, 0] / [:, 1] (the input in the 2-D plot space),
    #   [1][0] is the predicted class label,
    #   [2]    is a flag whose falsiness triggers the PCA warning.
    # TODO(review): confirm this contract against utils.utils.
    predict = input_predict(input_data, df_columns=df_columns, model_1=model_1, model_2=model_2, pca=pca_2d_loaded)
    predicted_label = str(predict[1][0])
    header_col = 'red' if predicted_label == 'poisonous' else 'green'
    # ':red[...]' / ':green[...]' is Streamlit's coloured-text markdown syntax.
    st.header(f':{header_col}[{predicted_label.upper()}]', anchor=False, divider=header_col)
    if not predict[2]:
        st.markdown(
            '<span style="color:red; font-weight:bold;">⚠ Potential Misleading Plot</span>',
            unsafe_allow_html=True
        )
        st.markdown(
            '<span style="color:red; font-size:small;">The plotted data may be misleading because the Principal Component Analysis (PCA) used for dimensionality reduction only explains 30% of the variance in the original dataset.</span>',
            unsafe_allow_html=True
        )

    vec_2d_input = predict[0]
    train_size = 50
    fig, ax = plt.subplots(figsize=(9, 7))
    # Shade the precomputed decision grid. Levels [0, 0.5, 1] give two bands,
    # coloured by custom_cmap.
    ax.contourf(xx, yy, Z, np.linspace(0, 1, 3), alpha=0.3, cmap=custom_cmap)
    prediction_test = ax.scatter(vec_2d_input[:, 0], vec_2d_input[:, 1], marker='x', c='k', s=200,
                                 label='Prediction')
    # Empty scatters exist only to create legend entries for the two regions.
    train_edible_marker = ax.scatter([], [], c=boundary_color1, label='Edible', s=train_size, marker='o')
    train_poisonous_marker = ax.scatter([], [], c=boundary_color2, label='Poisonous', s=train_size, marker='o')
    ax.legend(handles=[train_edible_marker, train_poisonous_marker, prediction_test])
    ax.axis('off')
    # Adjust this figure explicitly instead of whatever pyplot considers "current".
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    st.pyplot(fig)
    # Streamlit reruns the script on every interaction. Without closing,
    # pyplot's global figure registry grows by one figure per prediction.
    plt.close(fig)
    st.markdown(
        '<div style="text-align: center; font-style: italic; font-size:small;">- 2D plot of the model -</div>',
        unsafe_allow_html=True
    )
# Launch locally with: streamlit run app.py