NchourupouoM's picture
Update app.py
77e8f60 verified
import streamlit as st
import torch
from torchvision import transforms, models
from PIL import Image
import torch.nn.functional as F
import pandas as pd
import pathlib
import matplotlib.pyplot as plt
from torchvision import models
import torch.nn as nn
me = ['Acceuil','Prediction']
p = st.sidebar.selectbox('Menu', me)
# style
def load_css(file_path):
with open(file=file_path) as f:
st.html(f"<style>{f.read()}</style>")
css_path = pathlib.Path("css_file.css")
load_css(css_path)
@st.cache_resource
def charger_modele():
try:
model = torch.load('Brain_tumor_effmodel.pth', map_location=torch.device('cpu'), weights_only=False) # Charge sur CPU par défaut
model.eval() # Mettre le modèle en mode évaluation
return model
except Exception as e: # Gestion des erreurs
st.error(f"Erreur lors du chargement du modèle : {e}") # Affichage de l'erreur dans Streamlit
return None # Ou gérer l'erreur comme vous le souhaitez
with st.spinner('Chargement du model pytorch..'):
modele_charge = charger_modele()
# L'image doit être comme celles utilisées de l'entraînement du model
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5),0.5)])
if p == 'Acceuil':
st.title('SN DE TRANSFERT LEARNING')
st.subheader("Nom & Prenom: Nchourupouo Mohamed")
st.subheader("Application de prediction de tumeur cerebrale")
c1, c2 = st.columns(2)
with c1:
st.write("Bienvenue sur cette plateforme innovante qui utilise l'intelligence artificielle pour fournir des prédictions éclairées sur les tumeurs de cerveau.Elle offre une évaluation rapide et précise du risque de malignité. Alors vous souhaitez essayer notre Application ? cliquer sur le bouton prediction dans le sidebar.")
with c2:
st.write("Près pour un test ? Deroulez le menu `Acceuil` ➤ `Prediction` ➤ `Browse files` et Uploader un fichier image comme celui presente ci-dessous")
st.image("mri_healthy (1).jpeg", caption="Image chargée", use_container_width=True)
elif p=='Prediction':
st.title('TRANSFERT LEARNING')
st.write("Pour pouvoir faire une prediction, veuillez uploader une image en upuyant le bouton `Browse files`.")
upload_file = st.sidebar.file_uploader('Choisissez une image',type=['jpg','jpeg','png'])
if upload_file:
image = Image.open(upload_file).convert("RGB")
st.image(image, caption="Image chargée", use_container_width=True)
st.write("Cliquez sur le bouton `Resultat` pour voir vos resultats, puis scroller vers le bat pour voir les `probabilités`.")
# Prétraitement de l'image
img_tensor = transform(image).unsqueeze(0)
classes_p = {'HEALTHY': 0, "Brain Tumor":1}
def output_proba():
output = modele_charge(img_tensor)
probabilities = F.softmax(output, dim=1)
return output, probabilities
if st.sidebar.checkbox("Resultat"):
# Prédiction
with torch.no_grad():
output,_ =output_proba()
predicted_class = torch.argmax(output, dim=1).item()
for key, value in classes_p.items():
if value == predicted_class:
if key == "HEALTHY":
st.write(f"Je suis heureux de vous annoncer que les résultats sont très encourageants. Il n'y a aucune indication de cancer du cerveau. La prédiction du modèle est la suivante ➤ <strong>{key}</strong>", unsafe_allow_html=True)
st.balloons()
else:
st.write(f"Je suis profondément désolé de devoir vous annoncer que les résultats ne sont pas ceux que nous espérions. la prediction du model est la suivante ➤ <strong>{key}</strong>", unsafe_allow_html=True)
_,probabilities = output_proba()
df = pd.DataFrame({"Classe": classes_p.keys(),"Probabilité": probabilities.tolist()[0]})
couleurs = ['skyblue', 'lightcoral']
fig, ax = plt.subplots(figsize=(8, 6)) # Créer une figure et un axe
ax.bar(df['Classe'], df['Probabilité'], color=couleurs)
# Ajouter des étiquettes et un titre
ax.set_xlabel('Classe')
ax.set_ylabel('Probabilité')
ax.set_title('Probabilites de prediction')
# Personnaliser l'apparence du graphique (facultatif)
ax.grid(axis='y', linestyle='--', alpha=0.7) # Ajouter une grille en arrière-plan
ax.tick_params(axis='x', rotation=45) # Faire pivoter les étiquettes de l'axe x si elles sont trop longues
# Afficher le graphique dans Streamlit
st.pyplot(fig)
print(df.columns)