Mushroom-IAAE / app.py
mateo-sanro's picture
Update app.py
667cf4b verified
raw
history blame contribute delete
No virus
4.25 kB
import json
import os
import joblib
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from matplotlib.colors import ListedColormap
from tensorflow.keras.models import load_model
from utils.utils import input_predict
# All bundled artifacts (models, PCA, metadata, decision grid) are read from this folder.
directory = "utils"
model_1_path = os.path.join(directory, "nnmodel_117d.keras")
model_2_path = os.path.join(directory, "nnmodel_2d.keras")
pca_adim_path = os.path.join(directory, "pca_2d.pkl")
json_file_path = os.path.join(directory, "columns.json")
mushroom_attributes_path = os.path.join(directory, "mushroom_attributes.json")
prediction_grid_path = os.path.join(directory, "pred_grid.npz")


@st.cache_resource
def _load_models():
    """Load both Keras models and the pickled PCA object.

    Streamlit re-executes this script on every widget interaction; caching
    the loader keeps the deserialized objects in memory so they are read
    from disk only on the first run instead of on every rerun.

    Returns:
        A ``(model_1, model_2, pca)`` tuple.
    """
    return (
        load_model(model_1_path),
        load_model(model_2_path),
        joblib.load(pca_adim_path),
    )


model_1, model_2, pca_2d_loaded = _load_models()

# JSON is UTF-8 by specification; pass the encoding explicitly so decoding
# does not depend on the platform's default locale.
with open(json_file_path, 'r', encoding='utf-8') as json_file:
    df_columns = json.load(json_file)

with open(mushroom_attributes_path, 'r', encoding='utf-8') as json_file:
    loaded_dicts = json.load(json_file)

# The three arrays are later passed together to ``ax.contourf(xx, yy, Z, ...)``.
# They are pulled out inside the ``with`` so the archive can be closed.
with np.load(prediction_grid_path) as data:
    xx = data['xx']
    yy = data['yy']
    Z = data['Z']
# Default selection for every attribute of the input form.
# Key order matters: the form iterates this mapping to lay out its widgets.
example_input = {
    # Cap
    "cap-shape": "convex",
    "cap-surface": "smooth",
    "cap-color": "white",
    # Flesh
    "bruises": "bruises",
    "odor": "pungent",
    # Gills
    "gill-attachment": "free",
    "gill-spacing": "close",
    "gill-size": "narrow",
    "gill-color": "black",
    # Stalk
    "stalk-shape": "enlarging",
    "stalk-root": "equal",
    "stalk-surface-above-ring": "smooth",
    "stalk-surface-below-ring": "smooth",
    "stalk-color-above-ring": "white",
    "stalk-color-below-ring": "white",
    # Veil and ring
    "veil-type": "partial",
    "veil-color": "white",
    "ring-number": "one",
    "ring-type": "pendant",
    # Spores and environment
    "spore-print-color": "black",
    "population": "several",
    "habitat": "grasses",
}
st.title('Mushroom Classification')

# Build one select box per attribute, four per row, and collect the chosen
# values in ``input_data`` under the original hyphenated attribute names.
input_data = {}
keys = list(example_input.keys())
for i in range(0, len(keys), 4):
    cols = st.columns(4)
    for col, key in zip(cols, keys[i:i + 4]):
        # The choices for an attribute are the values of the
        # ``<attribute_with_underscores>_dict`` entry in ``loaded_dicts``;
        # a missing entry yields an empty option list.
        options = list(loaded_dicts.get(key.replace('-', '_') + '_dict', {}).values())
        input_data[key] = col.selectbox(
            # Attribute names are hyphen-separated, so the hyphen (not the
            # underscore) must be replaced to get a label like "Cap Shape".
            label=key.replace('-', ' ').title(),
            options=options,
            # Pre-select the example value when it is offered, else the first option.
            index=options.index(example_input[key]) if example_input[key] in options else 0,
        )
# Fill colours for the two regions of the 2-D prediction plot. The same
# colours are reused for the legend entries so regions and legend match.
boundary_color1, boundary_color2 = 'palegreen', 'lightcoral'
custom_cmap = ListedColormap([boundary_color1, boundary_color2])
submitted = st.button('Submit')
if submitted:
    # ``input_predict`` is a project helper. From the way its result is used
    # below it is indexable as:
    #   [0] a 2-D array whose first two columns are the plot coordinates,
    #   [1] a sequence whose first element is the predicted class label,
    #   [2] a flag; when falsy the "misleading plot" warning is shown.
    # NOTE(review): presumably [2] says whether the 2-D model agrees with the
    # full model -- confirm against utils.input_predict.
    predict = input_predict(input_data, df_columns=df_columns, model_1=model_1, model_2=model_2, pca=pca_2d_loaded)
    predicted_label = str(predict[1][0])
    print(predicted_label)  # echo the prediction to the server console

    # Red banner for a poisonous prediction, green for anything else.
    header_col = 'red' if predicted_label == 'poisonous' else 'green'
    st.header(f':{header_col}[{predicted_label.upper()}]', anchor=False, divider=header_col)

    if not predict[2]:
        st.markdown(
            '<span style="color:red; font-weight:bold;">&#9888; Potential Misleading Plot</span>',
            unsafe_allow_html=True
        )
        st.markdown(
            '<span style="color:red; font-size:small;">The plotted data may be misleading because the Principal Component Analysis (PCA) used for dimensionality reduction only explains 30% of the variance in the original dataset.</span>',
            unsafe_allow_html=True
        )

    point_2d = predict[0]
    legend_marker_size = 50

    fig, ax = plt.subplots(figsize=(9, 7))
    # Three levels over [0, 1] give two filled bands, one per colormap colour.
    ax.contourf(xx, yy, Z, np.linspace(0, 1, 3), alpha=0.3, cmap=custom_cmap)
    prediction_test = ax.scatter(point_2d[:, 0], point_2d[:, 1], marker='x', c='k', s=200,
                                 label='Prediction')
    # Empty scatters draw nothing; they exist only to provide legend entries
    # for the two region colours.
    train_edible_marker = ax.scatter([], [], c=boundary_color1, label='Edible', s=legend_marker_size, marker='o')
    train_poisonous_marker = ax.scatter([], [], c=boundary_color2, label='Poisonous', s=legend_marker_size, marker='o')
    ax.legend(handles=[train_edible_marker, train_poisonous_marker, prediction_test])
    ax.axis('off')
    # Adjust this figure directly instead of pyplot's global "current figure".
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    st.pyplot(fig)
    # Figures created through pyplot stay registered until closed; without
    # this every submit would leave another figure alive in the server process.
    plt.close(fig)

    st.markdown(
        '<div style="text-align: center; font-style: italic; font-size:small;">- 2D plot of the model -</div>',
        unsafe_allow_html=True
    )
# Run locally from the project root with: streamlit run app.py