Ethium committed on
Commit
75192d4
1 Parent(s): e3d027a

Upload glaucoma_detection_model.py

Files changed (1)
  1. glaucoma_detection_model.py +283 -0
glaucoma_detection_model.py ADDED
@@ -0,0 +1,283 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # In[1]:
+
+
+ from fastai.data.all import *
+ from fastai.vision.all import *
+ import cv2
+ import os
+ from pathlib import Path
+ import pandas as pd
+
+ # Load your CSV file into a pandas DataFrame
+ path_csv_combined = Path('D:\\Documents\\Machine Learning - Glaucoma\\combined_csv.csv')
+
+ # Define the image path with images from all three databases combined
+ path_image_combined = Path('D:\\Documents\\Machine Learning - Glaucoma\\combined images')
+
+ # Load dataframe
+ combined_df = pd.read_csv(path_csv_combined)
+
+
+ # In[2]:
+
+
+ from sklearn.model_selection import train_test_split
+
+ train_df, test_df = train_test_split(combined_df, test_size=0.15, random_state=42, stratify=combined_df['label'])
+ train_df, val_df = train_test_split(train_df, test_size=0.15, random_state=42, stratify=train_df['label'])
+
+
+ # Display the sizes of the datasets
+ print(f"Training set size: {len(train_df)} samples")
+ print(f"Validation set size: {len(val_df)} samples")
+ print(f"Test set size: {len(test_df)} samples")
+
+
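+ # Note: the two stratified 15% splits leave roughly 72.25% of the data for
+ # training, 12.75% for validation, and 15% for testing (0.85 * 0.15 = 0.1275).
+
+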
+ # In[3]:
+
+
+ print(combined_df['label'].value_counts())
+
+
+ # In[4]:
+
+
+ import matplotlib.pyplot as plt
+
+ combined_df['label'].value_counts().plot(kind='bar')
+ plt.title('Class distribution (all data)')
+ plt.xlabel('Class')
+ plt.ylabel('Count')
+ plt.show()
+
+
+ # In[5]:
+
+
+ print(train_df['label'].value_counts())
+
+
+ # In[6]:
+
+
+ train_df['label'].value_counts().plot(kind='bar')
+ plt.title('Class distribution (training set)')
+ plt.xlabel('Class')
+ plt.ylabel('Count')
+ plt.show()
+
+
+ # In[7]:
+
+
+ import numpy as np
+ from PIL import Image
+
+ # Define how to get the labels
+ def get_y(row):
+     return row['label']  # adjust this depending on how your csv is structured
+
+ # Define the preprocessing applied to every image
+ def custom_transform(image_path):
+     image = cv2.imread(str(image_path))  # Read the image file
+     if image is None:
+         # Fail loudly instead of returning None, which would only surface later
+         # as an opaque error inside the augmentation pipeline
+         raise FileNotFoundError(f"Could not read image: {image_path}")
+
+     # Convert the image from BGR to RGB
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     # Gaussian filter
+     image = cv2.GaussianBlur(image, (5, 5), 0)
+
+     # Histogram equalization on the luma channel
+     img_yuv = cv2.cvtColor(image, cv2.COLOR_RGB2YUV)
+     img_yuv[:, :, 0] = cv2.equalizeHist(img_yuv[:, :, 0])
+     image = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2RGB)
+
+     # Median filter
+     image = cv2.medianBlur(image, 3)
+
+     # Bypass filter: intentionally a no-op, the image passes through unchanged
+
+     # Sharpening filter
+     kernel = np.array([[0, -1, 0],
+                        [-1, 5, -1],
+                        [0, -1, 0]])
+     image = cv2.filter2D(image, -1, kernel)
+
+     # Resize the image to a target size of 224x224 pixels
+     image = cv2.resize(image, (224, 224))
+     return image
+
+ from albumentations import Compose, Rotate, RandomBrightnessContrast, OpticalDistortion
+
+ def additional_augmentations(image):
+     transform = Compose([
+         Rotate(limit=10, p=0.75),  # max_rotate=10.0
+         RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.75),  # max_lighting=0.2, p_lighting=0.75
+         OpticalDistortion(distort_limit=0.2, shift_limit=0.2, p=0.75),  # max_warp=0.2, p_affine=0.75
+         # No flipping is performed, as do_flip and flip_vert are both set to False
+     ], p=1)  # p=1 ensures that the augmentations are always applied
+
+     augmented_image = transform(image=image)['image']
+     return augmented_image
+
+ def get_x(row, is_test=False):
+     image_path = path_image_combined / row['id_code']
+     transformed_image = custom_transform(image_path)
+
+     # Apply the extra augmentations only to minority-class images outside of testing
+     if not is_test and row['label'] == 1:
+         transformed_image = additional_augmentations(transformed_image)
+
+     return Image.fromarray(transformed_image)
+
+ # Define a DataBlock; with no splitter given, fastai falls back to its default
+ # RandomSplitter, so the val_df created above is not consumed here
+ dblock = DataBlock(
+     blocks=(ImageBlock(cls=PILImage), CategoryBlock),
+     get_x=get_x,
+     get_y=get_y,
+     item_tfms=None,
+     batch_tfms=None)
+
+ # Create a DataLoaders object for the training data
+ dls = dblock.dataloaders(train_df, bs=128)
+
+
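+ # Sketch (not in the original): one way to use the explicit val_df instead of
+ # the default random split, via a hypothetical `is_valid` flag column and
+ # fastai's ColSplitter:
+ #
+ #     combined = pd.concat([train_df.assign(is_valid=False),
+ #                           val_df.assign(is_valid=True)])
+ #     dblock = DataBlock(blocks=(ImageBlock(cls=PILImage), CategoryBlock),
+ #                        get_x=get_x, get_y=get_y,
+ #                        splitter=ColSplitter('is_valid'))
+ #     dls = dblock.dataloaders(combined, bs=128)
+
+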
+ # In[8]:
+
+
+ # Print the first few rows of the train_df DataFrame
+ print(train_df.head())
+
+
+ # In[9]:
+
+
+ # Rows where 'label' == 0 make up the majority class in this dataset
+ majority_class = train_df[train_df['label'] == 0]
+
+ # Rows where 'label' == 1 make up the minority class
+ minority_class = train_df[train_df['label'] == 1]
+
+
+ # In[10]:
+
+
+ # Oversample the minority class (with replacement) to match the majority class in size
+ oversampled_minority_class = minority_class.sample(n=len(majority_class), replace=True, random_state=42)
+
+ # Concatenate the oversampled minority class with the majority class to create a balanced dataset
+ oversampled_train_df = pd.concat([majority_class, oversampled_minority_class], axis=0)
+
+ # Shuffle the oversampled DataFrame to ensure a random distribution of classes
+ oversampled_train_df = oversampled_train_df.sample(frac=1, random_state=42).reset_index(drop=True)
+
+ # Recreate the DataLoaders from the balanced DataFrame with a batch size of 128
+ dls = dblock.dataloaders(oversampled_train_df, bs=128)
+
+
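+ # Caveat (not in the original): the default RandomSplitter carves the validation
+ # split out of this oversampled frame, so duplicated minority rows can land on
+ # both sides of the split and inflate validation metrics. Oversampling only the
+ # training portion (e.g. with the explicit-split sketch above) avoids this.
+
+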
+ # In[11]:
+
+
+ # Display a batch of data from the training DataLoader
+ dls.show_batch()
+
+
+ # In[12]:
+
+
+ from fastai.metrics import AccumMetric
+ from sklearn.metrics import roc_auc_score
+
+ def custom_roc_auc_score(preds, targs):
+     # preds arrive as raw logits from a two-output binary classifier, so apply
+     # softmax before taking the probability of the positive class (second column)
+     probs = torch.softmax(preds, dim=1)[:, 1]
+     return roc_auc_score(targs, probs)
+
+ # Now use this custom metric in the learner
+ learn = cnn_learner(dls, resnet50,
+                     n_out=2,  # For binary classification
+                     loss_func=CrossEntropyLossFlat(),
+                     metrics=[
+                         accuracy,
+                         Precision(average='binary'),
+                         Recall(average='binary'),
+                         F1Score(average='binary'),
+                         AccumMetric(custom_roc_auc_score, flatten=False)  # Custom ROC AUC
+                     ],
+                     cbs=[
+                         EarlyStoppingCallback(monitor='valid_loss', patience=3),
+                         SaveModelCallback(monitor='valid_loss', fname='best_model')
+                     ])
+
+
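+ # Alternative (not in the original): fastai's built-in RocAucBinary() metric
+ # applies the activation and takes the positive-class probability internally,
+ # so it could replace the custom AccumMetric above:
+ #
+ #     metrics=[accuracy, Precision(average='binary'), Recall(average='binary'),
+ #              F1Score(average='binary'), RocAucBinary()]
+
+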
+ # In[13]:
+
+
+ # Train the model with the one-cycle policy
+ # Monitor the loss during training; it should typically decrease over epochs
+ learn.fit_one_cycle(10, 5e-02)
+
+
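+ # Sketch (not in the original): 5e-2 is an aggressive rate for a pretrained
+ # backbone; a common sanity check is to pick the rate from the LR finder first:
+ #
+ #     lrs = learn.lr_find()          # plots loss vs. learning rate
+ #     learn.fit_one_cycle(10, lrs.valley)
+
+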
+ # In[14]:
+
+
+ # Inspect validation-set errors with a confusion matrix
+ interp = ClassificationInterpretation.from_learner(learn)
+ interp.plot_confusion_matrix(figsize=(8, 8))
+
+
+ # In[15]:
+
+
+ from sklearn.metrics import accuracy_score, precision_score, classification_report
+
+ # Build a labelled test DataLoader from the existing DataLoaders object
+ test_dl = dls.test_dl(test_df, with_labels=True)
+
+ # Flag get_x as test-time so minority-class augmentation is skipped. Caveat: the
+ # DataBlock pipeline captured its own reference to get_x at definition time, so
+ # this assignment may have no effect depending on the fastai version.
+ test_dl.dataset.get_x = partial(get_x, is_test=True)
+
+ # Get predictions and targets
+ preds, targs = learn.get_preds(dl=test_dl)
+
+ # Get the predicted class indices
+ preds_argmax = preds.argmax(dim=-1)
+
+ # Calculate and print accuracy, precision, and a full classification report
+ test_accuracy = accuracy_score(targs.numpy(), preds_argmax.numpy())
+ print(f'Accuracy: {test_accuracy * 100:.2f}%')
+
+ test_precision = precision_score(targs.numpy(), preds_argmax.numpy())
+ print(f'Precision: {test_precision * 100:.2f}%')
+
+ report = classification_report(targs.numpy(), preds_argmax.numpy())
+ print(report)
+
+
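+ # Sketch (not in the original): learn.get_preds applies the softmax implied by
+ # CrossEntropyLossFlat, so a test-set ROC AUC can be read off the positive-class
+ # column of preds:
+ #
+ #     from sklearn.metrics import roc_auc_score
+ #     print(f'Test ROC AUC: {roc_auc_score(targs.numpy(), preds[:, 1].numpy()):.3f}')
+
+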
+ # In[16]:
+
+
+ # Export the trained Learner (weights plus data pipeline) for later inference
+ model_export_path = 'D:/Documents/Machine Learning - Glaucoma/your_model.pkl'
+ learn.export(model_export_path)
+
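+ # Usage sketch (not in the original): the exported pickle references get_x, get_y,
+ # custom_transform and additional_augmentations, so those functions must be defined
+ # or importable wherever the model is reloaded. 'some_image.png' is a placeholder:
+ #
+ #     learn_inf = load_learner(model_export_path)
+ #     pred_class, pred_idx, probs = learn_inf.predict(path_image_combined / 'some_image.png')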