Thomas Lucchetta
fix warnings
7abca29 unverified
import streamlit as st
import nibabel as nib
import os.path
import os
from nilearn import plotting
import torch
from monai.transforms import (
EnsureChannelFirst,
Compose,
Resize,
ScaleIntensity,
LoadImage,
)
import torch.nn.functional as F
import numpy as np
from statistics import mean
from constants import CLASSES
from model.download_model import load_model
from huggingface_hub import hf_hub_download
# SET PAGE TITLE
st.set_page_config(page_title="Alzheimer Classifier", page_icon=":brain:", layout="wide")

# LOAD MODEL
# Streamlit re-executes this whole script on every widget interaction. Wrapping the
# loader in st.cache_resource makes all reruns (and sessions) share one model object
# instead of calling load_model() again each time; harmless if load_model already caches.
# NOTE(review): confirm load_model() returns the network in eval mode (dropout/batch-norm).
model = st.cache_resource(load_model)()

# SET NIFTI FILE LOADING AND PROCESSING CONFIGURATIONS
# Each input volume is intensity-scaled, given a leading channel axis and resized to 96x96x96.
transforms = Compose([
    ScaleIntensity(),
    EnsureChannelFirst(),
    Resize((96, 96, 96)),
])
# image_only=True: return just the image rather than an (image, metadata) pair.
load_img = LoadImage(image_only=True)

# SET CLASSES
class_names = CLASSES
# SET STREAMLIT SESSION STATES
# Both click flags default to False the first time a session runs the script;
# existing values are left alone so they survive reruns.
for _flag_name in ("clicked_pp", "clicked_pred"):
    if _flag_name not in st.session_state:
        st.session_state[_flag_name] = False
def click_pp_true():
    """Button callback: mark that the user pressed "Preprocess Image"."""
    st.session_state["clicked_pp"] = True
def click_pred_true():
    """Button callback: mark that the user pressed "Run Prediction"."""
    st.session_state["clicked_pred"] = True
def click_false():
    """Selectbox callback: reset both click flags (preprocess first, then prediction)."""
    for flag_name in ("clicked_pp", "clicked_pred"):
        st.session_state[flag_name] = False
###########################################################
###################### STREAMLIT APP ######################
###########################################################
# Sidebar: title, image picker, and the two action buttons laid out side by side.
# The "Run Prediction" button only appears once preprocessing has been requested.
st.sidebar.title("Alzheimer Classifier Demo")
img_path = st.sidebar.selectbox(
    "Select Image",
    tuple(class_names),
    on_change=click_false,
)
preprocess_col, predict_col = st.sidebar.columns((1, 1))
run_preprocess = preprocess_col.button("Preprocess Image", on_click=click_pp_true)
if st.session_state.clicked_pp:
    run_pred = predict_col.button("Run Prediction", on_click=click_pred_true)
# Hugging Face Hub dataset holding the "raw" and "preprocessed" demo scans.
_HF_REPO_ID = "rootstrap-org/Alzheimer-Classifier-Demo"
# Markdown colour per predicted class index; indices past the end fall back to gray.
_CLASS_COLORS = ("red", "blue", "green")


def _download_image(subfolder, filename):
    """Fetch *filename* from *subfolder* of the demo dataset and return its local path."""
    return hf_hub_download(
        repo_id=_HF_REPO_ID,
        repo_type="dataset",
        subfolder=subfolder,
        filename=filename,
    )


def _cut_sliders(image, label_suffix=""):
    """Add coronal, sagittal and axial sliders to the sidebar; return the chosen (x, y, z) cut.

    Each slider spans nilearn's automatic mask bounds for that axis and starts at
    the midpoint of that range. ``label_suffix`` is appended to every slider label
    (the preprocessed view's labels carry a trailing space, the raw view's do not).
    """
    # NOTE: _get_auto_mask_bounds is a private nilearn helper; it may change between releases.
    bounds = plotting.find_cuts._get_auto_mask_bounds(image)
    st.sidebar.write("#")
    cut = {}
    # Creation order fixes the on-screen order: coronal (y), sagittal (x), axial (z).
    for axis, view in ((1, "coronal"), (0, "sagittal"), (2, "axial")):
        low, high = bounds[axis][0], bounds[axis][1]
        cut[axis] = st.sidebar.slider(
            f"Move the slider to adjust the {view} cut{label_suffix}",
            low,
            high,
            mean([low, high]),
        )
    return cut[0], cut[1], cut[2]


def _show_cuts(image, cut_coords):
    """Render an orthogonal-slice view of *image* at *cut_coords*, then release the figure."""
    display = plotting.plot_img(image, cmap="grey", cut_coords=cut_coords, black_bg=True)
    # Pass the figure explicitly: st.pyplot() with no figure uses the deprecated
    # global-figure mode. plot_img opens a new figure on every call, so close it
    # afterwards to stop figures piling up across reruns.
    st.pyplot(display.frame_axes.figure)
    display.close()


@st.cache_data(show_spinner=False)
def _predict(image_file):
    """Classify the NIfTI volume stored at *image_file*.

    Returns ``(class_index, class_name, probability_percent)`` for the most likely
    class. Cached per file path so that moving a cut slider (which reruns the
    script) does not run the network again.
    """
    volume = transforms(load_img(image_file))
    batch = torch.from_numpy(np.array([volume]))  # prepend a batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
    percentages = F.softmax(logits, dim=1).numpy()[0] * 100
    class_index = int(np.argmax(percentages))
    return class_index, class_names[class_index], float(percentages[class_index])


with st.container():
    # Truthiness also guards against None, which selectbox returns when it has no options.
    if img_path:
        if st.session_state.clicked_pp:
            with st.container():
                # Download once; the same file feeds both the viewer and the classifier.
                pp_file = _download_image("preprocessed", img_path + ".nii.gz")
                pred_image = nib.load(pp_file)
                cut_coords = _cut_sliders(pred_image, label_suffix=" ")
                if st.session_state.clicked_pred:
                    class_index, class_name, probability = _predict(pp_file)
                    st.sidebar.write("#")
                    if class_index < len(_CLASS_COLORS):
                        color_name = _CLASS_COLORS[class_index]
                    else:
                        color_name = "gray"
                    # Streamlit's :color[...] Markdown supports blue/green/orange/red/violet/
                    # gray/rainbow but (in most releases) not "yellow", so use orange for the middle band.
                    if probability > 80:
                        color_prob = "green"
                    elif probability > 60:
                        color_prob = "orange"
                    else:
                        color_prob = "red"
                    class_col, pred_col = st.columns((1, 1))
                    with class_col:
                        st.write(f"### Predicted Class: :{color_name}[{class_name}]")
                    with pred_col:
                        st.write(f"### Probability: :{color_prob}[{probability:.2f}%]")
                _show_cuts(pred_image, cut_coords)
        else:
            raw_image = nib.load(_download_image("raw", img_path + ".nii"))
            _show_cuts(raw_image, _cut_sliders(raw_image))