"""Streamlit demo for the Alzheimer classifier.

The user picks a demo volume (the selectbox options are the names in
``CLASSES``; each name is also used as the file stem). The app shows the raw
NIfTI volume, then the preprocessed one, and finally runs the model on the
preprocessed volume. All volumes are downloaded from a Hugging Face dataset.
"""
import streamlit as st
import nibabel as nib
import os.path
import os
from nilearn import plotting
import torch
from monai.transforms import (
    EnsureChannelFirst,
    Compose,
    Resize,
    ScaleIntensity,
    LoadImage,
)
import torch.nn.functional as F
import numpy as np
from statistics import mean
from constants import CLASSES
from model.download_model import load_model
from huggingface_hub import hf_hub_download

# SET PAGE TITLE (must be the first Streamlit command executed by the script)
st.set_page_config(page_title="Alzheimer Classifier", page_icon=":brain:", layout="wide")

# Hugging Face dataset repository that hosts the raw and preprocessed volumes.
HF_REPO_ID = "rootstrap-org/Alzheimer-Classifier-Demo"

# Streamlit re-executes this whole script on every widget interaction, so an
# uncached load_model() would run on every click and slider move.
# ``st.cache_resource`` only exists in newer Streamlit releases; fall back to
# no caching when it is unavailable.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _get_model():
    """Return the classifier built by ``load_model()``, cached across reruns when possible."""
    return load_model()


# LOAD MODEL
model = _get_model()

# SET NIFTI FILE LOADING AND PROCESSING CONFIGURATIONS
# Applied to the loaded preprocessed volume right before inference.
transforms = Compose([
    ScaleIntensity(),
    EnsureChannelFirst(),
    Resize((96, 96, 96)),
])
load_img = LoadImage(image_only=True)

# SET CLASSES
class_names = CLASSES

# Markdown text colour per predicted class index; indices not listed here fall
# back to gray instead of leaving the colour undefined.
_CLASS_COLORS = ("red", "blue", "green")
_FALLBACK_CLASS_COLOR = "gray"

# SET STREAMLIT SESSION STATES
# clicked_pp: the user asked to see the preprocessed volume.
# clicked_pred: the user asked to run the prediction.
if "clicked_pp" not in st.session_state:
    st.session_state.clicked_pp = False
if "clicked_pred" not in st.session_state:
    st.session_state.clicked_pred = False


def click_pp_true():
    """Button callback: switch the view to the preprocessed volume."""
    st.session_state.clicked_pp = True


def click_pred_true():
    """Button callback: request a prediction on the preprocessed volume."""
    st.session_state.clicked_pred = True


def click_false():
    """Selectbox callback: a new image was chosen, so go back to the raw view."""
    st.session_state.clicked_pp = False
    st.session_state.clicked_pred = False


def _download_volume(name, subfolder, extension):
    """Download ``<subfolder>/<name><extension>`` from the demo dataset and return its local path."""
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        subfolder=subfolder,
        filename=name + extension,
    )


def _cut_sliders(image, label_suffix=""):
    """Render the coronal/sagittal/axial cut sliders in the sidebar.

    Each slider spans the image's auto-mask bounds on its axis and starts at
    the midpoint. ``label_suffix`` is appended to every label so that the
    preprocessed-view sliders get labels distinct from the raw-view ones.

    Returns:
        (x, y, z) cut coordinates, in the order expected by ``plot_img``.
    """
    # NOTE(review): _get_auto_mask_bounds is a private nilearn helper and may
    # change between nilearn releases.
    bounds = plotting.find_cuts._get_auto_mask_bounds(image)
    st.sidebar.write("#")
    # Widgets are created in this order (coronal, sagittal, axial) to keep the
    # existing sidebar layout.
    y_value = st.sidebar.slider(
        "Move the slider to adjust the coronal cut" + label_suffix,
        bounds[1][0], bounds[1][1], mean([bounds[1][0], bounds[1][1]]),
    )
    x_value = st.sidebar.slider(
        "Move the slider to adjust the sagittal cut" + label_suffix,
        bounds[0][0], bounds[0][1], mean([bounds[0][0], bounds[0][1]]),
    )
    z_value = st.sidebar.slider(
        "Move the slider to adjust the axial cut" + label_suffix,
        bounds[2][0], bounds[2][1], mean([bounds[2][0], bounds[2][1]]),
    )
    return x_value, y_value, z_value


def _show_volume(image, cut_coords):
    """Plot orthogonal cuts of ``image`` and render the figure in the app."""
    display = plotting.plot_img(image, cmap="grey", cut_coords=cut_coords, black_bg=True)
    # Pass the display's own figure: calling st.pyplot() with no figure relies
    # on the deprecated global pyplot figure shared by all sessions.
    st.pyplot(display.frame_axes.figure)
    # Release the matplotlib figure so figures do not accumulate across reruns.
    display.close()


def _predict(image_file):
    """Run the model on the volume at ``image_file``.

    Returns:
        (predicted class index, probability of that class in percent).
    """
    img_array = load_img(image_file)
    new_data = transforms(img_array)
    # Add a leading batch dimension of size 1.
    new_data_tensor = torch.from_numpy(np.array([new_data]))
    with torch.no_grad():
        output = model(new_data_tensor)
    probabilities = F.softmax(output, dim=1).numpy()[0]
    predicted_class_index = int(np.argmax(probabilities))
    return predicted_class_index, probabilities[predicted_class_index] * 100


def _probability_color(probability):
    """Map a probability in percent to a Streamlit markdown text colour."""
    if probability > 80:
        return "green"
    if probability > 60:
        # "orange" is used because "yellow" is not in the colored-text
        # palette of older Streamlit versions.
        return "orange"
    return "red"


###########################################################
###################### STREAMLIT APP ######################
###########################################################

with st.sidebar:
    st.title("Alzheimer Classifier Demo")
    img_path = st.selectbox(
        "Select Image",
        tuple(class_names),
        on_change=click_false,
    )
    col1, col2 = st.columns((1, 1))
    with col1:
        st.button("Preprocess Image", on_click=click_pp_true)
    # The prediction button only appears once the preprocessed view is shown.
    if st.session_state.clicked_pp:
        with col2:
            st.button("Run Prediction", on_click=click_pred_true)

with st.container():
    # st.selectbox returns None when there are no options, so test truthiness.
    if img_path:
        if not st.session_state.clicked_pp:
            # Raw volume view.
            raw_image = nib.load(_download_volume(img_path, "raw", ".nii"))
            raw_cuts = _cut_sliders(raw_image)
            _show_volume(raw_image, raw_cuts)
        elif not st.session_state.clicked_pred:
            # Preprocessed volume view, before any prediction was requested.
            with st.container():
                pred_image = nib.load(_download_volume(img_path, "preprocessed", ".nii.gz"))
                pred_cuts = _cut_sliders(pred_image, " ")
                _show_volume(pred_image, pred_cuts)
        else:
            # Preprocessed volume view plus the model's prediction.
            with st.container():
                # Download once and reuse the path for both display and inference.
                preprocessed_file = _download_volume(img_path, "preprocessed", ".nii.gz")
                pred_image = nib.load(preprocessed_file)
                pred_cuts = _cut_sliders(pred_image, " ")

                predicted_class_index, predicted_probability = _predict(preprocessed_file)
                predicted_class_name = class_names[predicted_class_index]

                st.sidebar.write("#")
                if predicted_class_index < len(_CLASS_COLORS):
                    color_name = _CLASS_COLORS[predicted_class_index]
                else:
                    color_name = _FALLBACK_CLASS_COLOR
                color_prob = _probability_color(predicted_probability)

                class_col, pred_col = st.columns((1, 1))
                with class_col:
                    st.write(f"### Predicted Class: :{color_name}[{predicted_class_name}]")
                with pred_col:
                    st.write(f"### Probability: :{color_prob}[{predicted_probability:.2f}%]")

                _show_volume(pred_image, pred_cuts)