import streamlit as st
import pickle
import time
import os
import pandas as pd
import plotly.express as px
from PIL import Image
from utils import load_data_pickle
# import gradcam
# from gradcam.utils import visualize_cam
# from gradcam import GradCAM, GradCAMpp
#add_indentation()
# Wide layout for the whole page; Streamlit requires this call to come
# before any other Streamlit command on the page.
st.set_page_config(layout="wide")

# Folder holding the sample MRI scans offered in the "Classify MRI scans"
# section, and the fine-tuned ResNet-18 weights (only referenced by the
# currently commented-out model-loading code).
DATA_DIR = "data/image_classification/images"
MODEL_PATH = "pretrained_models/image_classification/resnet18_braintumor.pt"

# Pre-rendered explainability overlays shown in the "Explainability" tab,
# one per class (meningioma, no tumor, pituitary).
gradcam_images_paths = [
    "images/meningioma_tumor.png",
    "images/no_tumor.png",
    "images/pituitary.png",
]
# PREPROCESSING
# def preprocess(image):
# # Il faut que l'image' est une image PIL. Si 'image' est un tableau numpy, on le convertit en image PIL.
# if isinstance(image, np.ndarray):
# image = Image.fromarray(image)
# transform = transforms.Compose([
# transforms.Resize((224, 224)),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalisez l'image.
# ])
# # On applique les transformations définies sur l'image.
# image = transform(image)
# return image
# Chargement du modèle pré-entraîné
# def load_pretrained_model(num_classes=3):
# model = models.resnet18(pretrained=False)
# num_ftrs = model.fc.in_features
# model.fc = torch.nn.Linear(num_ftrs, num_classes)
# # Chargement des poids pré-entraînés tout en ignorant la dernière couche 'fc'
# state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
# state_dict.pop('fc.weight', None)
# state_dict.pop('fc.bias', None)
# model.load_state_dict(state_dict, strict=False)
# model.eval()
# return model
# model = load_pretrained_model(num_classes=3) #On a supprimés une des classes
# # PREDICTION
# def predict(image_preprocessed, model):
# # Si image_preprocessed est déjà un tensor PyTorch, on doit s'assurer qu'il soit de dimension 3 : [batch_size, channels, height, width]
# # La fonction unsqueeze(0) ajoute une dimension de batch_size au début pour le faire correspondre à cette attente
# if image_preprocessed.dim() == 3:
# image_preprocessed = image_preprocessed.unsqueeze(0)
# with torch.no_grad():
# output = model(image_preprocessed)
# _, predicted = torch.max(output, 1)
# return predicted, output
###################################### TITLE ####################################
# Introductory section: what image classification is and where it is used.
st.markdown("# Image Classification 🖼️")
st.markdown("### What is Image classification ?")
st.info("""**Image classification** is a process in Machine Learning and Computer Vision where an algorithm is trained to recognize and categorize images into predefined classes. It involves analyzing the visual content of an image and assigning it to a specific label based on its features.""")
st.markdown(" ")
st.markdown("""State-of-the-art image classification models use **neural networks** to predict whether an image belongs to a specific class.
Each of the possible predicted classes is given a probability, then the class with the highest value is assigned to the input image.""",
            unsafe_allow_html=True)

# Center the illustration by padding it with two narrow empty columns.
_, col, _ = st.columns([0.2, 0.8, 0.2])
with col:
    # Pass the file path directly (like the other images on this page)
    # instead of keeping an opened PIL image around.
    st.image('images/cnn_example.png',
             caption="An example of an image classification model, with the 'backbone model' as the neural network.")
st.markdown(" ")
st.markdown("""Real-life applications of image classification include:
- **Medical Imaging 👨⚕️**: Diagnose diseases and medical conditions from images such as X-rays, MRIs and CT scans, for example to identify tumors and classify different types of cancers.
- **Autonomous Vehicles** 🏎️: Classify objects such as pedestrians, vehicles, traffic signs, lane markings, and obstacles, which is crucial for navigation and collision avoidance.
- **Satellite and Remote Sensing 🛰️**: Analyze satellite imagery to identify land use patterns, monitor vegetation health, assess environmental changes, and detect natural disasters such as wildfires and floods.
- **Quality Control 🛂**: Inspect products and identify defects to ensure compliance with quality standards during the manufacturing process.
""")
# st.markdown("""Real-life applications of Brain Tumor includes:
# - **Research and development💰**: The technologies and methodologies developed for brain tumor classification can advance research in neuroscience, oncology, and the development of new diagnostic tools and treatments.
# - **Healthcare👨⚕️**: Data derived from the classification and analysis of brain tumors can inform public health decisions, healthcare policies, and resource allocation, emphasizing areas with higher incidences of certain types of tumors.
# - **Insurance Industry 🏬**: Predict future demand for products to optimize inventory levels, reduce holding costs, and improve supply chain efficiency.
# """)
###################################### USE CASE #######################################
# BEGINNING OF USE CASE: brain tumor classification on MRI scans.
st.divider()
st.markdown("# Brain Tumor Classification 🧠")
st.info("""In this use case, a **brain tumor classification** model is leveraged to accurately identify the presence of tumors in MRI scans of the brain.
This application can be a great resource for healthcare professionals to facilitate early detection and consequently improve treatment outcomes for patients.""")
st.markdown(" ")

# Centered banner image (narrow empty columns on both sides).
_, col, _ = st.columns([0.1, 0.8, 0.1])
with col:
    st.image("images/brain_tumor.jpg")
st.markdown(" ")
st.markdown(" ")

### WHAT ARE BRAIN TUMORS ?
st.markdown(" ### What is a Brain Tumor ?")
st.markdown("""Before introducing the use case, let's give a short description of what a brain tumor is.
A brain tumor occurs when **abnormal cells form within the brain**. Two main types of tumors exist: **cancerous (malignant) tumors** and **benign tumors**.
- **Cancerous tumors** are malignant tumors that have the ability to invade nearby tissues and spread to other parts of the body through a process called metastasis.
- **Benign tumors** can become quite large but will not invade nearby tissue or spread to other parts of the body. They can still cause serious health problems depending on their size, location and rate of growth.
""", unsafe_allow_html=True)
st.markdown(" ")
st.markdown(" ")

st.markdown("### About the data 📋")
# NOTE(review): the commented-out interpretation notes later on this page
# describe meningiomas as "most often benign"; confirm that labelling
# meningioma as "(cancerous)" below is intended.
st.markdown("""You were provided with a large dataset which contains **anonymized patient MRI scans** categorized into three distinct classes: **pituitary tumor** (in most cases benign), **meningioma tumor** (cancerous) and **no tumor**.
This dataset will serve as the foundation for training our classification model, offering a comprehensive view of varied tumor presentations within the brain.""")
_, col, _ = st.columns([0.15, 0.7, 0.15])
with col:
    st.image("images/tumors_types_class.png")
# see_data = st.checkbox('**See the data**', key="image_class\seedata")
# if see_data:
# st.warning("You can view here a few examples of the MRI training data.")
# # image selection
# images = os.listdir(DATA_DIR)
# selected_image1 = st.selectbox("Choose an image to visualize 🔎 :", images, key="selectionbox_key_2")
# # show image
# image_path = os.path.join(DATA_DIR, selected_image1)
# image = Image.open(image_path)
# st.image(image, caption="Image selected", width=450)
# st.info("""**Note**: This dataset will serve as the foundation for training our classification model, offering a comprehensive view of varied tumor presentations within the brain.
# By analyzing these images, the model learns to discern the subtle differences between each class, thereby enabling the precise identification of tumor types.""")
st.markdown(" ")
st.markdown(" ")
st.markdown("### Train the algorithm ⚙️")
st.markdown("""**Training an AI model** means feeding it data that contains multiple examples/images of each type of tumor to be detected.
By analyzing the provided MRI images, the model learns to discern the subtle differences between the classes, thereby enabling the precise identification of tumor types.""")

### CONDITION ##
# Keep the "trained" flag in session state so it survives Streamlit reruns
# (any widget interaction re-executes this script from the top).
if 'model_train' not in st.session_state:
    st.session_state['model_train'] = False

run_model = st.button("Train the model")
if run_model:
    # Training is simulated: mark the model as trained and show a short spinner.
    st.session_state.model_train = True
    with st.spinner('Training the model...'):
        time.sleep(2)
    st.success("The model has been trained.")
elif st.session_state.model_train:
    # Trained during an earlier run: keep reporting it rather than falling
    # back to the "not trained yet" message while results are displayed.
    st.success("The model has been trained.")
else:
    st.info("The model hasn't been trained yet.")
# Show the results only once the (simulated) training has been run.
if st.session_state.model_train:
    st.markdown(" ")
    st.markdown(" ")
    st.markdown("### See the results ☑️")
    tab1, tab2 = st.tabs(["Performance", "Explainability"])

    with tab1:
        st.info("""**Evaluating a model's performance** helps provide a quantitative measurement of its ability to make accurate predictions.
In this use case, the performance of the brain tumor classification model was measured by comparing the patient's true diagnosis with the class predicted by the trained model.""")
        class_accuracy_path = "data/image_classification/class_accuracies.pkl"
        # Load the per-class accuracies from a fixed file inside the app's data
        # folder. pickle can execute arbitrary code, so this file must stay trusted.
        try:
            with open(class_accuracy_path, 'rb') as file:
                class_accuracy = pickle.load(file)
        except Exception as e:
            st.error(f"Erreur lors du chargement du fichier : {e}")
            class_accuracy = {}

        if not isinstance(class_accuracy, dict):
            st.error(f"Expected a dictionary, but got: {type(class_accuracy)}")
        elif class_accuracy:
            # Only draw the chart when there is data; an empty dict (e.g. after
            # the load error reported above) would otherwise give an empty plot.
            # Values are presumably fractions in [0, 1], hence the * 100 -- confirm.
            df_accuracy = pd.DataFrame(list(class_accuracy.items()), columns=['Tumor Type', 'Accuracy'])
            df_accuracy['Accuracy'] = ((df_accuracy['Accuracy'] * 100).round()).astype(int)
            fig = px.bar(df_accuracy, x='Tumor Type', y='Accuracy',
                         text='Accuracy', color='Tumor Type',
                         title="Model Performance",
                         labels={'Accuracy': 'Accuracy (%)', 'Tumor Type': 'Tumor Type'})
            fig.update_traces(texttemplate='%{text}%', textposition='outside')
            st.plotly_chart(fig, use_container_width=True)

        st.markdown("""The model's accuracy was evaluated across two types of tumors (pituitary and meningioma) and the 'no tumor' class.
This evaluation is vital for determining if the model performs consistently across different tumor classifications, or if it encounters difficulties in accurately distinguishing between these two types of tumors.""",
                    unsafe_allow_html=True)
        st.markdown(" ")
        # NOTE(review): the percentages below are hard-coded and will not follow
        # changes to class_accuracies.pkl.
        st.markdown("""**Interpretation**:
Our model demonstrates high accuracy in predicting cancerous type tumors (meningioma) as well as 'healthy' brain scans (no tumor) with a 98% accuracy for both.
It is observed that the model's performance is lower for pituitary type tumors, as it is around 81%.
This discrepancy may indicate that the model finds it more challenging to distinguish pituitary tumors from other tumor
types, possibly due to their unique characteristics or lower representation in the training data.
""", unsafe_allow_html=True)

    with tab2:
        st.info("""**Explainability in AI** refers to the ability to **understand and interpret how AI systems make predictions** and how to quantify the impact of the provided data on its results.
In the case of image classification, explainability can be measured by analyzing which of the image's pixels had the most impact on the model's output.""")
        st.markdown(" ")
        st.markdown("""The following images show the output of image classification explainability applied on three images used during training.
Pixels that are colored in 'red' had a larger impact on the model's output and thus its ability to distinguish different tumor types (or none).
""", unsafe_allow_html=True)
        st.markdown(" ")
        # Uses the module-level gradcam_images_paths; captions follow the same order.
        class_names = ["Meningioma Tumor", "No Tumor", "Pituitary Tumor"]
        for path, class_name in zip(gradcam_images_paths, class_names):
            st.image(path, caption=f"Explainability for {class_name}")
# st.markdown("""
# Interpretation:
# ### Meningioma Tumors
# **Meningiomas** are tumors that originate from the meninges, the layers of tissue
# that envelop the brain and spinal cord. Although they are most often benign
# (noncancerous) and grow slowly, their location can cause significant issues by
# exerting pressure on the brain or spinal cord. Meningiomas can occur at various
# places around the brain and spinal cord and are more common in women than in men.
# ### Pituitary Tumors
# **Pituitary** are growths that develop in the pituitary gland, a small gland located at the
# base of the brain, behind the nose, and between the ears. Despite their critical location,
# the majority of pituitary tumors are benign and grow slowly. This gland regulates many of the
# hormones that control various body functions, so even a small tumor can affect hormone production,
# leading to a variety of symptoms.""", unsafe_allow_html=True)
#################################################
st.markdown(" ")
st.markdown(" ")
st.markdown("### Classify MRI scans 🆕")
st.info("**Note**: The brain tumor classification model can classify new MRI images only if it has been previously trained.")
st.markdown("""Here, you are provided the MRI scans of nine new patients.
Select an image and press 'Make predictions' to classify the MRI as either a pituitary tumor, a meningioma tumor or no tumor.""")

# Tumor categories, indexed by the class index of the (currently
# commented-out) live model prediction path below.
categories = ["pituitary tumor", "no tumor", "meningioma tumor"]


def _normalize_prediction(raw_label):
    """Map a raw label from results.pkl to 'pituitary', 'no tumor' or 'meningioma'.

    Matching ignores case and surrounding whitespace. It accepts both the
    short labels ('pituitary', 'no tumor') and the longer '... tumor' forms
    used in `categories`. It also accepts the misspelling 'meningnoma'.
    Returns None when the label is not recognized.
    """
    label = str(raw_label).strip().lower()
    if label.startswith("pituitary"):
        return "pituitary"
    if label in ("no tumor", "no_tumor", "notumor"):
        return "no tumor"
    if label.startswith(("meningioma", "meningnoma")):
        return "meningioma"
    return None


# Offer the sample scans in a stable (sorted) order and skip hidden files
# (e.g. .DS_Store) that PIL cannot open.
images = sorted(name for name in os.listdir(DATA_DIR) if not name.startswith("."))
if not images:
    st.warning(f"No MRI images were found in '{DATA_DIR}'.")
else:
    selected_image2 = st.selectbox("Choose an image", images, key="selectionbox_key_1")
    # Show the selected image.
    image_path = os.path.join(DATA_DIR, selected_image2)
    image = Image.open(image_path)
    st.markdown("#### You've selected the following image.")
    st.image(image, caption="Image selected", width=300)

    if st.button('**Make predictions**', key='another_action_button'):
        if not st.session_state.get('model_train', False):
            # Enforce the note above: no predictions before (simulated) training.
            st.warning("Please train the model first (see 'Train the algorithm' above).")
        else:
            # Predictions are precomputed and looked up by image file name.
            results_path = r"data/image_classification"
            df_results = load_data_pickle(results_path, "results.pkl")
            matches = df_results.loc[df_results["image"] == selected_image2, "class"]
            # Take a single scalar label; comparing a numpy array against a
            # string in an `if` raises when several rows match.
            predicted_category = None if matches.empty else _normalize_prediction(matches.iloc[0])
            # # Live model path (disabled): preprocessing and prediction
            # image_preprocessed = preprocess(image)
            # predicted_tensor, _ = predict(image_preprocessed, model)
            # predicted_idx = predicted_tensor.item()
            # predicted_category = categories[predicted_idx]

            # Display the prediction for the selected category.
            if predicted_category == "pituitary":
                st.warning("**Results**: Pituitary tumor was detected. ")
            elif predicted_category == "no tumor":
                st.success("**Results**: No tumor was detected.")
            elif predicted_category == "meningioma":
                st.error("**Results**: Meningioma was detected.")
            else:
                st.warning("No prediction is available for this image.")