|
import streamlit as st |
|
import nibabel as nib |
|
import os.path |
|
import os |
|
from nilearn import plotting |
|
|
|
import torch |
|
from monai.transforms import ( |
|
EnsureChannelFirst, |
|
Compose, |
|
Resize, |
|
ScaleIntensity, |
|
LoadImage, |
|
) |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from statistics import mean |
|
|
|
from constants import CLASSES |
|
from model.download_model import load_model |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Page configuration must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="Alzheimer Classifier",
    page_icon=":brain:",
    layout="wide",
)

# Classifier returned by the project's model-loading helper.
model = load_model()

# Pre-inference pipeline: rescale intensities, put the channel axis first,
# then resample every volume to a fixed 96x96x96 grid.
transforms = Compose(
    [
        ScaleIntensity(),
        EnsureChannelFirst(),
        Resize(spatial_size=(96, 96, 96)),
    ]
)

# Loader that returns only the image data (no separate metadata dict).
load_img = LoadImage(image_only=True)

# Labels used both as the selectable demo-image names and as prediction names.
class_names = CLASSES
|
|
|
|
|
# Default both workflow flags to False on the first run of a session;
# later reruns keep whatever the button callbacks have set.
for _flag_name in ("clicked_pp", "clicked_pred"):
    if _flag_name not in st.session_state:
        st.session_state[_flag_name] = False
|
|
|
def click_pp_true():
    """Button callback: record that the preprocessed view was requested."""
    st.session_state["clicked_pp"] = True
|
|
|
def click_pred_true():
    """Button callback: record that a prediction was requested."""
    st.session_state["clicked_pred"] = True
|
|
|
def click_false():
    """Reset both workflow flags, returning the UI to the raw-image view."""
    for flag_name in ("clicked_pp", "clicked_pred"):
        st.session_state[flag_name] = False
|
|
|
|
|
|
|
|
|
|
|
with st.sidebar:
    st.title("Alzheimer Classifier Demo")

    # Changing the selected image resets the preprocess/predict flags.
    img_path = st.selectbox(
        "Select Image",
        tuple(class_names),
        on_change=click_false,
    )

    col1, col2 = st.columns((1, 1))
    run_preprocess = col1.button("Preprocess Image", on_click=click_pp_true)

    # The prediction button only appears once preprocessing has been requested.
    if st.session_state.clicked_pp:
        run_pred = col2.button("Run Prediction", on_click=click_pred_true)
|
|
|
# Hugging Face dataset repository holding the raw and preprocessed demo volumes.
_DATASET_REPO_ID = "rootstrap-org/Alzheimer-Classifier-Demo"

# Markdown color per predicted class index; any other index falls back to gray
# (previously an unexpected index left the color variable unbound).
_CLASS_COLORS = ("red", "blue", "green")


def _download_volume(name, subfolder, extension):
    """Download ``<subfolder>/<name><extension>`` from the demo dataset and return its local path."""
    return hf_hub_download(
        repo_id=_DATASET_REPO_ID,
        repo_type="dataset",
        subfolder=subfolder,
        filename=name + extension,
    )


def _cut_sliders(image, label_suffix=""):
    """Render coronal, sagittal and axial sliders in the sidebar; return the (x, y, z) cut.

    Slider ranges come from nilearn's automatic mask bounds and start at the
    midpoint. ``label_suffix`` is appended to each label. The preprocessed
    view passes ``" "`` so its sliders get different labels (and therefore
    distinct widget state) from the raw-image sliders.
    """
    bounds = plotting.find_cuts._get_auto_mask_bounds(image)
    st.sidebar.write("#")
    cuts = {}
    # Creation order (coronal, sagittal, axial) fixes the order in the sidebar.
    for axis, plane in ((1, "coronal"), (0, "sagittal"), (2, "axial")):
        low, high = bounds[axis][0], bounds[axis][1]
        cuts[axis] = st.sidebar.slider(
            f"Move the slider to adjust the {plane} cut{label_suffix}",
            low,
            high,
            mean([low, high]),
        )
    return cuts[0], cuts[1], cuts[2]


def _show_cuts(image, cut_coords):
    """Plot orthogonal cuts of ``image`` and render that exact figure in Streamlit.

    The figure is passed to ``st.pyplot`` explicitly (the implicit global
    figure is deprecated) and closed afterwards so figures do not accumulate
    across reruns.
    """
    display = plotting.plot_img(image, cmap="grey", cut_coords=cut_coords, black_bg=True)
    try:
        st.pyplot(display.frame_axes.figure)
    finally:
        display.close()


def _predict(volume_path):
    """Classify one preprocessed volume; return (class index, probability in percent)."""
    img_array = load_img(volume_path)
    new_data = transforms(img_array)
    # Wrap the single sample in a leading batch axis for the model.
    new_data_tensor = torch.from_numpy(np.array([new_data]))
    with torch.no_grad():
        output = model(new_data_tensor)
    probabilities = F.softmax(output, dim=1).numpy()[0]
    predicted_class_index = int(np.argmax(probabilities))
    return predicted_class_index, probabilities[predicted_class_index] * 100


def _probability_color(percent):
    """Pick a Streamlit Markdown color for a confidence value given in percent."""
    if percent > 80:
        return "green"
    if percent > 60:
        # "orange" is one of Streamlit's documented text colors; "yellow" is not.
        return "orange"
    return "red"


with st.container():
    if img_path != "":
        if st.session_state.clicked_pp:
            with st.container():
                volume_path = _download_volume(img_path, "preprocessed", ".nii.gz")
                pred_image = nib.load(volume_path)
                cut_coords = _cut_sliders(pred_image, label_suffix=" ")

                if st.session_state.clicked_pred:
                    predicted_class_index, predicted_probability = _predict(volume_path)
                    predicted_class_name = class_names[predicted_class_index]

                    st.sidebar.write("#")

                    if predicted_class_index < len(_CLASS_COLORS):
                        color_name = _CLASS_COLORS[predicted_class_index]
                    else:
                        color_name = "gray"
                    color_prob = _probability_color(predicted_probability)

                    class_col, pred_col = st.columns((1, 1))
                    with class_col:
                        st.write(f"### Predicted Class: :{color_name}[{predicted_class_name}]")
                    with pred_col:
                        st.write(f"### Probability: :{color_prob}[{predicted_probability:.2f}%]")

                _show_cuts(pred_image, cut_coords)
        else:
            raw_image = nib.load(_download_volume(img_path, "raw", ".nii"))
            raw_cut_coords = _cut_sliders(raw_image)
            _show_cuts(raw_image, raw_cut_coords)
|
|