Spaces:
Runtime error
Runtime error
Upload glaucoma_detection_model.py
Browse files- glaucoma_detection_model.py +283 -0
glaucoma_detection_model.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# In[1]:
|
5 |
+
|
6 |
+
|
7 |
+
from fastai.data.all import *
|
8 |
+
from fastai.vision.all import *
|
9 |
+
import cv2
|
10 |
+
import os
|
11 |
+
from pathlib import Path
|
12 |
+
import pandas as pd
|
13 |
+
|
14 |
+
# Load your CSV file into a pandas DataFrame
|
15 |
+
path_csv_combined = Path('D:\\Documents\\Machine Learning - Glaucoma\\combined_csv.csv')
|
16 |
+
|
17 |
+
# Define the image path with images from all three databases combined
|
18 |
+
path_image_combined = Path('D:\\Documents\\Machine Learning - Glaucoma\\combined images')
|
19 |
+
|
20 |
+
# Load dataframe
|
21 |
+
combined_df = pd.read_csv(path_csv_combined)
|
22 |
+
|
23 |
+
|
24 |
+
# In[2]:
|
25 |
+
|
26 |
+
|
27 |
+
from sklearn.model_selection import train_test_split
|
28 |
+
|
29 |
+
train_df, test_df = train_test_split(combined_df, test_size=0.15, random_state=42, stratify=combined_df['label'])
|
30 |
+
train_df, val_df = train_test_split(train_df, test_size=0.15, random_state=42, stratify=train_df['label'])
|
31 |
+
|
32 |
+
|
33 |
+
# Display the sizes of the datasets
|
34 |
+
print(f"Training set size: {len(train_df)} samples")
|
35 |
+
print(f"Validation set size: {len(val_df)} samples")
|
36 |
+
print(f"Test set size: {len(test_df)} samples")
|
37 |
+
|
38 |
+
|
39 |
+
# In[3]:
|
40 |
+
|
41 |
+
|
42 |
+
print(combined_df['label'].value_counts())
|
43 |
+
|
44 |
+
|
45 |
+
# In[4]:
|
46 |
+
|
47 |
+
|
48 |
+
import matplotlib.pyplot as plt
|
49 |
+
|
50 |
+
combined_df['label'].value_counts().plot(kind='bar')
|
51 |
+
plt.title('Class distribution')
|
52 |
+
plt.xlabel('Class')
|
53 |
+
plt.ylabel('Count')
|
54 |
+
plt.show()
|
55 |
+
|
56 |
+
|
57 |
+
# In[5]:
|
58 |
+
|
59 |
+
|
60 |
+
print(train_df['label'].value_counts())
|
61 |
+
|
62 |
+
|
63 |
+
# In[6]:
|
64 |
+
|
65 |
+
|
66 |
+
import matplotlib.pyplot as plt
|
67 |
+
|
68 |
+
train_df['label'].value_counts().plot(kind='bar')
|
69 |
+
plt.title('Class distribution')
|
70 |
+
plt.xlabel('Class')
|
71 |
+
plt.ylabel('Count')
|
72 |
+
plt.show()
|
73 |
+
|
74 |
+
|
75 |
+
# In[7]:
|
76 |
+
|
77 |
+
|
78 |
+
import cv2
|
79 |
+
import numpy as np
|
80 |
+
from skimage import filters
|
81 |
+
from PIL import Image
|
82 |
+
|
83 |
+
# Define how to get the labels
|
84 |
+
def get_y(row):
|
85 |
+
return row['label'] # adjust this depending on how your csv is structured
|
86 |
+
|
87 |
+
# Define the transformations
|
88 |
+
def custom_transform(image_path):
|
89 |
+
image = cv2.imread(str(image_path)) # Read the image file.
|
90 |
+
if image is None:
|
91 |
+
return None
|
92 |
+
|
93 |
+
# Convert the image from BGR to RGB
|
94 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
95 |
+
|
96 |
+
# Apply filters and transformations
|
97 |
+
# Gaussian filter
|
98 |
+
image = cv2.GaussianBlur(image, (5, 5), 0)
|
99 |
+
|
100 |
+
# Histogram Equalization
|
101 |
+
img_yuv = cv2.cvtColor(image, cv2.COLOR_RGB2YUV)
|
102 |
+
img_yuv[:,:,0] = cv2.equalizeHist(img_yuv[:,:,0])
|
103 |
+
image = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2RGB)
|
104 |
+
|
105 |
+
# Median filter
|
106 |
+
image = cv2.medianBlur(image, 3)
|
107 |
+
|
108 |
+
# Bypass filter (leaving the image unchanged)
|
109 |
+
# (add any specific implementation if needed)
|
110 |
+
|
111 |
+
# Sharpening filter
|
112 |
+
kernel = np.array([[0, -1, 0],
|
113 |
+
[-1, 5,-1],
|
114 |
+
[0, -1, 0]])
|
115 |
+
image = cv2.filter2D(image, -1, kernel)
|
116 |
+
|
117 |
+
# Resize the image to a target size of 224x224 pixels.
|
118 |
+
image = cv2.resize(image, (224, 224))
|
119 |
+
return image
|
120 |
+
|
121 |
+
from albumentations import (
|
122 |
+
Compose, Rotate, RandomBrightnessContrast, OpticalDistortion, IAAPerspective
|
123 |
+
)
|
124 |
+
import albumentations.pytorch as A
|
125 |
+
|
126 |
+
def additional_augmentations(image):
|
127 |
+
transform = Compose([
|
128 |
+
Rotate(limit=10, p=0.75), # max_rotate=10.0
|
129 |
+
RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.75), # max_lighting=0.2, p_lighting=0.75
|
130 |
+
OpticalDistortion(distort_limit=0.2, shift_limit=0.2, p=0.75), # max_warp=0.2, p_affine=0.75
|
131 |
+
# No flipping is performed as do_flip and flip_vert are both set to False
|
132 |
+
], p=1) # p=1 ensures that the augmentations are always applied
|
133 |
+
|
134 |
+
augmented_image = transform(image=image)['image']
|
135 |
+
return augmented_image
|
136 |
+
|
137 |
+
def get_x(row, is_test=False):
|
138 |
+
image_path = path_image_combined / (row['id_code'])
|
139 |
+
transformed_image = custom_transform(image_path)
|
140 |
+
|
141 |
+
# Check the label of the current image and apply augmentations if it belongs to the minority class
|
142 |
+
if not is_test and row['label'] == 1:
|
143 |
+
transformed_image = additional_augmentations(transformed_image)
|
144 |
+
|
145 |
+
return Image.fromarray(transformed_image)
|
146 |
+
|
147 |
+
# Define a DataBlock
|
148 |
+
dblock = DataBlock(
|
149 |
+
blocks=(ImageBlock(cls=PILImage), CategoryBlock),
|
150 |
+
get_x=get_x,
|
151 |
+
get_y=get_y,
|
152 |
+
item_tfms=None,
|
153 |
+
batch_tfms=None)
|
154 |
+
|
155 |
+
# Create a DataLoader for training data
|
156 |
+
dls = dblock.dataloaders(train_df,bs = 128)
|
157 |
+
|
158 |
+
|
159 |
+
# In[8]:
|
160 |
+
|
161 |
+
|
162 |
+
#Print the first few rows of the 'df_train' DataFrame
|
163 |
+
print(train_df.head())
|
164 |
+
|
165 |
+
|
166 |
+
# In[9]:
|
167 |
+
|
168 |
+
|
169 |
+
# Extract all the rows from the training dataset where the 'label' column has a value of 0,
|
170 |
+
# which represents the majority class in this context.
|
171 |
+
majority_class = train_df[train_df['label'] == 0]
|
172 |
+
|
173 |
+
# Extract all the rows from the training dataset where the 'label' column has a value of 1,
|
174 |
+
# which represents the minority class in this context.
|
175 |
+
minority_class = train_df[train_df['label'] == 1]
|
176 |
+
|
177 |
+
|
178 |
+
# In[10]:
|
179 |
+
|
180 |
+
|
181 |
+
# Oversample the minority class to have the same number of samples as the majority class
|
182 |
+
oversampled_minority_class = minority_class.sample(n=len(majority_class), replace=True, random_state=42)
|
183 |
+
|
184 |
+
# Concatenate the oversampled minority class DataFrame with the majority class DataFrame to create a balanced dataset
|
185 |
+
oversampled_train_df = pd.concat([majority_class, oversampled_minority_class], axis=0)
|
186 |
+
|
187 |
+
# Shuffle the oversampled DataFrame to ensure a random distribution of classes
|
188 |
+
oversampled_train_df = oversampled_train_df.sample(frac=1, random_state=42).reset_index(drop=True)
|
189 |
+
|
190 |
+
# Create a DataLoader using the balanced DataFrame and a batch size of 128
|
191 |
+
dls = dblock.dataloaders(oversampled_train_df, bs=128)
|
192 |
+
|
193 |
+
|
194 |
+
# In[11]:
|
195 |
+
|
196 |
+
|
197 |
+
#Display a batch of data from the training dataloader
|
198 |
+
dls.show_batch()
|
199 |
+
|
200 |
+
|
201 |
+
# In[12]:
|
202 |
+
|
203 |
+
|
204 |
+
from fastai.metrics import AccumMetric
|
205 |
+
from sklearn.metrics import roc_auc_score
|
206 |
+
|
207 |
+
def custom_roc_auc_score(preds, targs):
|
208 |
+
# preds are assumed to be from a binary classification model with n_out=2
|
209 |
+
# taking the probability of the positive class (usually the second column)
|
210 |
+
probs = preds[:, 1]
|
211 |
+
return roc_auc_score(targs, probs)
|
212 |
+
|
213 |
+
# Now use this custom metric in your learner
|
214 |
+
learn = cnn_learner(dls, resnet50,
|
215 |
+
n_out=2, # For binary classification
|
216 |
+
loss_func=CrossEntropyLossFlat(),
|
217 |
+
metrics=[
|
218 |
+
accuracy,
|
219 |
+
Precision(average='binary'),
|
220 |
+
Recall(average='binary'),
|
221 |
+
F1Score(average='binary'),
|
222 |
+
AccumMetric(custom_roc_auc_score, flatten=False) # Custom ROC AUC
|
223 |
+
],
|
224 |
+
cbs=[
|
225 |
+
EarlyStoppingCallback(monitor='valid_loss', patience=3),
|
226 |
+
SaveModelCallback(monitor='valid_loss', fname='best_model')
|
227 |
+
]
|
228 |
+
)
|
229 |
+
|
230 |
+
|
231 |
+
# In[13]:
|
232 |
+
|
233 |
+
|
234 |
+
# Train the model
|
235 |
+
# Monitor the loss during training; it should typically decrease over epochs
|
236 |
+
learn.fit_one_cycle(10, 5e-02)
|
237 |
+
|
238 |
+
|
239 |
+
# In[14]:
|
240 |
+
|
241 |
+
|
242 |
+
interp = ClassificationInterpretation.from_learner(learn)
|
243 |
+
interp.plot_confusion_matrix(figsize=(8,8))
|
244 |
+
|
245 |
+
|
246 |
+
# In[15]:
|
247 |
+
|
248 |
+
|
249 |
+
from sklearn.metrics import accuracy_score, precision_score, classification_report
|
250 |
+
|
251 |
+
# Assuming dls is your DataLoaders object
|
252 |
+
test_dl = dls.test_dl(test_df, with_labels=True)
|
253 |
+
|
254 |
+
# Modify the get_x function in the DataLoader to indicate it's for testing
|
255 |
+
test_dl.dataset.get_x = partial(get_x, is_test=True)
|
256 |
+
|
257 |
+
# Get predictions and targets
|
258 |
+
preds, targs = learn.get_preds(dl=test_dl)
|
259 |
+
|
260 |
+
# Get the prediction indices
|
261 |
+
preds_argmax = preds.argmax(dim=-1)
|
262 |
+
|
263 |
+
# Calculate and print accuracy, precision, and other metrics
|
264 |
+
accuracy = accuracy_score(targs.numpy(), preds_argmax.numpy())
|
265 |
+
print(f'Accuracy: {accuracy * 100:.2f}%')
|
266 |
+
|
267 |
+
precision = precision_score(targs.numpy(), preds_argmax.numpy())
|
268 |
+
print(f'Precision: {precision * 100:.2f}%')
|
269 |
+
|
270 |
+
report = classification_report(targs.numpy(), preds_argmax.numpy())
|
271 |
+
print(report)
|
272 |
+
|
273 |
+
|
274 |
+
# In[16]:
|
275 |
+
|
276 |
+
|
277 |
+
import os
|
278 |
+
from fastai.vision.all import *
|
279 |
+
|
280 |
+
# Export the model to the directory
|
281 |
+
model_export_path = 'D:/Documents/Machine Learning - Glaucoma/your_model.pkl'
|
282 |
+
learn.export(model_export_path)
|
283 |
+
|