Thomas Lucchetta committed
Commit a3a57c6 · unverified · 1 Parent(s): 38c9612

Add files via upload

Files changed (6)
  1. README.md +9 -1
  2. app.py +169 -0
  3. constants.py +15 -0
  4. download_pictures.py +6 -0
  5. model/download_model.py +30 -0
  6. requirements.txt +9 -0
README.md CHANGED
@@ -1 +1,9 @@
- # Alzheimer-Classifier-Demo
+ ---
+ license: mit
+ emoji: 🧠
+ title: Alzheimer Classifier
+ sdk: streamlit
+ colorFrom: gray
+ colorTo: purple
+ ---
+ # MRI-classifier-streamlit
app.py ADDED
@@ -0,0 +1,169 @@
+ import streamlit as st
+ import nibabel as nib
+ import os.path
+ import os
+ from nilearn import plotting
+
+ import torch
+ from monai.transforms import (
+     EnsureChannelFirst,
+     Compose,
+     Resize,
+     ScaleIntensity,
+     LoadImage,
+ )
+ import torch.nn.functional as F
+ import numpy as np
+ from statistics import mean
+
+ from constants import CLASSES
+ from model.download_model import load_model
+ from download_pictures import download_images
+
+ # SET PAGE TITLE
+ st.set_page_config(page_title="Alzheimer Classifier", page_icon=":brain:", layout="wide")
+
+ # LOAD MODEL
+ model = load_model()
+
+ # LOAD IMAGES
+ download_images()
+
+ # SET NIFTI FILE LOADING AND PROCESSING CONFIGURATIONS
+ transforms = Compose([
+     ScaleIntensity(),
+     EnsureChannelFirst(),
+     Resize((96, 96, 96)),
+ ])
+ load_img = LoadImage(image_only=True)
+
+ # SET CLASSES
+ class_names = CLASSES
+
+ # SET IMAGE PATH LIST FOR STREAMLIT'S SELECT BOX
+ filelist = [""]
+ for root, dirs, files in os.walk("images/raw"):
+     for file in files:
+         filename = file.split(".")[0]
+         filelist.append(filename)
+ filelist = tuple(filelist)
+
+ # SILENCE STREAMLIT WARNING
+ st.set_option('deprecation.showPyplotGlobalUse', False)
+
+ # SET STREAMLIT SESSION STATES
+ if 'clicked_pp' not in st.session_state:
+     st.session_state.clicked_pp = False
+
+ if 'clicked_pred' not in st.session_state:
+     st.session_state.clicked_pred = False
+
+ def click_pp_true():
+     st.session_state.clicked_pp = True
+
+ def click_pred_true():
+     st.session_state.clicked_pred = True
+
+ def click_false():
+     st.session_state.clicked_pp = False
+     st.session_state.clicked_pred = False
+
+ ###########################################################
+ ###################### STREAMLIT APP ######################
+ ###########################################################
+
+ with st.sidebar:
+     st.title("Alzheimer Classifier Demo")
+     img_path = st.selectbox(
+         "Select Image",
+         filelist,
+         on_change=click_false,
+     )
+     col1, col2 = st.columns((1, 1))
+     with col1:
+         run_preprocess = st.button("Preprocess Image", on_click=click_pp_true)
+     if st.session_state.clicked_pp:
+         with col2:
+             run_pred = st.button("Run Prediction", on_click=click_pred_true)
+
+ with st.container():
+     if img_path != "":
+         if st.session_state.clicked_pp:
+             if not st.session_state.clicked_pred:
+                 with st.container():
+                     pred_image = nib.load(os.path.join("images/preprocessed", img_path + ".nii.gz"))
+
+                     bounds_pred = plotting.find_cuts._get_auto_mask_bounds(pred_image)
+
+                     st.sidebar.write("#")
+                     y_value_pred = st.sidebar.slider('Move the slider to adjust the coronal cut ', bounds_pred[1][0], bounds_pred[1][1], mean([bounds_pred[1][0], bounds_pred[1][1]]))
+                     x_value_pred = st.sidebar.slider('Move the slider to adjust the sagittal cut ', bounds_pred[0][0], bounds_pred[0][1], mean([bounds_pred[0][0], bounds_pred[0][1]]))
+                     z_value_pred = st.sidebar.slider('Move the slider to adjust the axial cut ', bounds_pred[2][0], bounds_pred[2][1], mean([bounds_pred[2][0], bounds_pred[2][1]]))
+
+                     plotting.plot_img(pred_image, cmap="gray", cut_coords=(x_value_pred, y_value_pred, z_value_pred), black_bg=True)
+                     st.pyplot()
+
+             else:
+                 with st.container():
+                     pred_image = nib.load(os.path.join("images/preprocessed", img_path + ".nii.gz"))
+
+                     bounds_pred = plotting.find_cuts._get_auto_mask_bounds(pred_image)
+
+                     st.sidebar.write("#")
+                     y_value_pred = st.sidebar.slider('Move the slider to adjust the coronal cut ', bounds_pred[1][0], bounds_pred[1][1], mean([bounds_pred[1][0], bounds_pred[1][1]]))
+                     x_value_pred = st.sidebar.slider('Move the slider to adjust the sagittal cut ', bounds_pred[0][0], bounds_pred[0][1], mean([bounds_pred[0][0], bounds_pred[0][1]]))
+                     z_value_pred = st.sidebar.slider('Move the slider to adjust the axial cut ', bounds_pred[2][0], bounds_pred[2][1], mean([bounds_pred[2][0], bounds_pred[2][1]]))
+
+                     img_array = load_img(os.path.join("images/preprocessed", img_path + ".nii.gz"))
+                     new_data = transforms(img_array)
+                     new_data_tensor = torch.from_numpy(np.array([new_data]))
+
+                     with torch.no_grad():
+                         output = model(new_data_tensor)
+
+                     probabilities = F.softmax(output, dim=1)
+                     probabilities_np = probabilities.numpy()
+                     probabilities_item = probabilities_np[0]
+                     probabilities_percentage = probabilities_item * 100
+                     predicted_class_index = np.argmax(probabilities_np[0])
+                     predicted_class_name = class_names[predicted_class_index]
+                     predicted_probability = probabilities_percentage[predicted_class_index]
+
+                     st.sidebar.write("#")
+                     if predicted_class_index == 0:
+                         color_name = "red"
+                     elif predicted_class_index == 1:
+                         color_name = "blue"
+                     elif predicted_class_index == 2:
+                         color_name = "green"
+
+                     if predicted_probability > 80:
+                         color_prob = "green"
+                     elif predicted_probability > 60:
+                         color_prob = "yellow"
+                     else:
+                         color_prob = "red"
+
+                     class_col, pred_col = st.columns((1, 1))
+
+                     with class_col:
+                         st.write(f"### Predicted Class: :{color_name}[{predicted_class_name}]")
+
+                     with pred_col:
+                         st.write(f"### Probability: :{color_prob}[{predicted_probability:.2f}%]")
+
+                     plotting.plot_img(pred_image, cmap="gray", cut_coords=(x_value_pred, y_value_pred, z_value_pred), black_bg=True)
+                     st.pyplot()
+
+         else:
+             raw_image = nib.load(os.path.join("images/raw", img_path + ".nii"))
+
+             bounds_raw = plotting.find_cuts._get_auto_mask_bounds(raw_image)
+
+             st.sidebar.write("#")
+             y_value_raw = st.sidebar.slider('Move the slider to adjust the coronal cut', bounds_raw[1][0], bounds_raw[1][1], mean([bounds_raw[1][0], bounds_raw[1][1]]))
+             x_value_raw = st.sidebar.slider('Move the slider to adjust the sagittal cut', bounds_raw[0][0], bounds_raw[0][1], mean([bounds_raw[0][0], bounds_raw[0][1]]))
+             z_value_raw = st.sidebar.slider('Move the slider to adjust the axial cut', bounds_raw[2][0], bounds_raw[2][1], mean([bounds_raw[2][0], bounds_raw[2][1]]))
+
+             plotting.plot_img(raw_image, cmap="gray", cut_coords=(x_value_raw, y_value_raw, z_value_raw), black_bg=True)
+             st.pyplot()
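
The prediction path above can also be exercised without the Streamlit UI. The following is a minimal sketch, assuming the helpers shipped in this commit (load_model, constants.CLASSES) and the same images/preprocessed/ layout; the filename example.nii.gz is a placeholder, not a file from the dataset.

# predict_local.py — hypothetical helper, not part of this commit.
# Mirrors the preprocessing + inference steps that app.py runs on a single scan.
import os

import numpy as np
import torch
import torch.nn.functional as F
from monai.transforms import Compose, EnsureChannelFirst, LoadImage, Resize, ScaleIntensity

from constants import CLASSES
from model.download_model import load_model

transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize((96, 96, 96))])
load_img = LoadImage(image_only=True)

def predict(path):
    model = load_model()
    data = transforms(load_img(path))            # channel-first volume resized to 96x96x96
    batch = torch.from_numpy(np.array([data]))   # add batch dimension -> (1, 1, 96, 96, 96)
    with torch.no_grad():
        probs = F.softmax(model(batch), dim=1).numpy()[0]
    idx = int(np.argmax(probs))
    return CLASSES[idx], float(probs[idx] * 100)

if __name__ == "__main__":
    # "example.nii.gz" is a placeholder filename
    print(predict(os.path.join("images/preprocessed", "example.nii.gz")))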
constants.py ADDED
@@ -0,0 +1,15 @@
+ import os
+
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ # Root dir
+ ROOT_DIR = os.getcwd()
+
+ # Model checkpoints and repo names
+ MODEL_FILENAME = os.getenv("MODEL_FILENAME")
+ HF_MODEL_REPO_NAME = os.getenv("HF_MODEL_REPO_NAME")
+
+ # Other constants
+ CLASSES = ["Alzheimer's Disease", "Mild Cognitive Impairment", "Control"]
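
constants.py reads MODEL_FILENAME and HF_MODEL_REPO_NAME from the environment via python-dotenv, so a .env file (or Space secrets) has to supply them. A sketch of such a file, with placeholder values only:

MODEL_FILENAME=model_checkpoint.pth
HF_MODEL_REPO_NAME=your-username/your-repo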
download_pictures.py ADDED
@@ -0,0 +1,6 @@
+ from huggingface_hub import snapshot_download
+ from constants import HF_MODEL_REPO_NAME, ROOT_DIR
+ import os
+
+ def download_images():
+     snapshot_download(repo_id=HF_MODEL_REPO_NAME, repo_type="dataset", local_dir=os.path.join(ROOT_DIR, "images"))
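
download_images() mirrors the whole dataset repo into images/, which presumably contains the raw/ and preprocessed/ folders that app.py reads (images/raw/<name>.nii paired with images/preprocessed/<name>.nii.gz). A hypothetical sanity check, assuming the environment variables from constants.py are set:

# check_images.py — hypothetical, not part of this commit.
import os

from download_pictures import download_images

download_images()
print(sorted(os.listdir("images/raw")))           # raw .nii scans offered in the select box
print(sorted(os.listdir("images/preprocessed")))  # matching preprocessed .nii.gz volumes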
model/download_model.py ADDED
@@ -0,0 +1,30 @@
+ import os.path
+ import os
+
+ import monai.networks.nets as nets
+ import torch
+
+ from huggingface_hub import hf_hub_download
+
+ from constants import ROOT_DIR, MODEL_FILENAME, HF_MODEL_REPO_NAME
+
+ def load_model():
+     """
+     Load the pretrained model.
+     """
+
+     model_path = os.path.join(ROOT_DIR, "model", MODEL_FILENAME)
+
+     # If the model doesn't exist locally, download it from the Hugging Face Hub
+     if not os.path.exists(model_path):
+         hf_hub_download(HF_MODEL_REPO_NAME, MODEL_FILENAME, local_dir=os.path.join(ROOT_DIR, "model"))
+
+     model = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=3)
+     if torch.cuda.is_available():
+         checkpoint = torch.load(model_path)
+     else:
+         checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
+     model.load_state_dict(checkpoint)
+     model.eval()
+
+     return model
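
As a quick smoke test (not part of this commit), the loaded DenseNet121 can be run on a random volume with the shape app.py feeds it after Resize((96, 96, 96)); the three output logits line up with constants.CLASSES.

# Hypothetical check; requires MODEL_FILENAME and HF_MODEL_REPO_NAME in the environment.
import torch

from model.download_model import load_model

model = load_model()
dummy = torch.randn(1, 1, 96, 96, 96)  # one single-channel 96x96x96 volume
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # expected: torch.Size([1, 3])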
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ huggingface-hub==0.20.1
+ matplotlib==3.8.2
+ monai==1.3.0
+ nibabel==5.2.0
+ nilearn==0.10.2
+ numpy==1.26.2
+ python-dotenv==1.0.0
+ streamlit==1.29.0
+ torch==2.1.1