DmitriiKhizbullin committed
Commit f036ad4
Parent(s): 9d09bdd

Partial docstrings

Files changed:
- README.md +26 -4
- app.py +34 -2
- environment.yml +0 -2
- labelmap.py +2 -0
- requirements.txt +2 -0
- train.py +69 -15
README.md CHANGED
@@ -1,8 +1,30 @@
-#
+# Diabetic Retinopathy Detection with AI
 
-
+## Setup
+
+### Gradio app environment
+
+TODO
+
+### Training environment
 
 Create conda environment from YAML:
 ```bash
-mamba env create -n
+mamba env create -n retinopathy_train -f environment.yml
 ```
+
+Download the data from [Kaggle](https://www.kaggle.com/competitions/diabetic-retinopathy-detection/data) or use the Kaggle API:
+
+```bash
+pip install kaggle
+kaggle competitions download -c diabetic-retinopathy-detection
+mkdir retinopathy_data/
+unzip diabetic-retinopathy-detection.zip -d retinopathy_data/
+```
+
+Launch training:
+```bash
+conda activate retinopathy_train
+python train.py
+```
+The trained model will be saved to `lightning_logs/`.
app.py CHANGED
@@ -13,7 +13,10 @@ from labelmap import DR_LABELMAP
 
 
 class App:
+    """ Demonstration of the Diabetic Retinopathy model as a Gradio app. """
+
     def __init__(self) -> None:
+        """ Constructor. """
 
         ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
 
@@ -66,9 +69,19 @@ class App:
         self.ui = ui
 
     def launch(self) -> None:
+        """ Launch the application, blocking. """
         self.ui.queue().launch(share=True)
 
-    def predict(self, image: Optional[np.ndarray]):
+    def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
+        """ Gradio callback for processing of an image.
+
+        Args:
+            image (Optional[np.ndarray]): Provided image.
+
+        Returns:
+            Dict[str, float]: Label-compatible dict.
+        """
+
         if image is None:
             return dict()
         cls_name, prob, probs = self._infer(image)
@@ -79,6 +92,19 @@ class App:
         return probs_dict
 
     def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
+        """ Low-level method to perform neural network inference.
+
+        Args:
+            image_chw (np.ndarray): Provided image.
+
+        Returns:
+            Tuple[str, float, np.ndarray]:
+                - Most probable class name.
+                - Probability of the most probable class.
+                - Probabilities of all classes, in the order they are
+                  listed in the label map.
+        """
+
         assert isinstance(self.model, ResNetForImageClassification)
 
         inputs = self.image_processor(image_chw, return_tensors="pt")
@@ -98,6 +124,11 @@ class App:
 
     @staticmethod
     def _load_example_lists() -> Dict[int, List[str]]:
+        """ Load example retina images from disk.
+
+        Returns:
+            Dict[int, List[str]]: Dictionary of cls_id -> list of image paths.
+        """
 
         example_flat_list = glob.glob("demo_data/train/**/*.jpeg")
 
@@ -115,6 +146,7 @@ class App:
 
 
 def main():
+    """ App entry point. """
     app = App()
     app.launch()
 
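For context on the new `predict` signature: Gradio's `gr.Label` component consumes exactly the kind of name-to-probability mapping the docstring calls a "label-compatible dict". A minimal, self-contained sketch of that contract (illustrative only, not the repo's actual UI wiring; the three class names beyond 'No DR' and 'Mild' are assumed from the standard five-grade DR scale, and the probabilities are made up):

```python
import gradio as gr

def fake_predict(image) -> dict:
    # Mirrors App.predict's contract: an empty dict when no image is given,
    # otherwise a mapping of class name -> probability for gr.Label to render.
    if image is None:
        return {}
    return {"No DR": 0.82, "Mild": 0.10, "Moderate": 0.05,
            "Severe": 0.02, "Proliferative DR": 0.01}  # illustrative values

demo = gr.Interface(fn=fake_predict, inputs=gr.Image(), outputs=gr.Label())
# demo.launch()  # like App.launch(), minus queueing and share=True
```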
environment.yml CHANGED
@@ -69,7 +69,6 @@ dependencies:
   - parso=0.8.3=pyhd3eb1b0_0
   - pexpect=4.8.0=pyhd3eb1b0_3
   - pickleshare=0.7.5=pyhd3eb1b0_1003
-  - pip=23.3.1=py310h06a4308_0
   - platformdirs=3.10.0=py310h06a4308_0
   - prometheus_client=0.14.1=py310h06a4308_0
   - prompt-toolkit=3.0.36=py310h06a4308_0
@@ -104,7 +103,6 @@ dependencies:
   - tornado=6.3.3=py310h5eee18b_0
   - webencodings=0.5.1=py310h06a4308_1
   - wheel=0.41.2=py310h06a4308_0
-  - xz=5.4.5=h5eee18b_0
   - y-py=0.5.9=py310h52d8a92_0
   - yaml=0.2.5=h7b6447c_0
   - ypy-websocket=0.8.2=py310h06a4308_0
labelmap.py CHANGED
@@ -1,3 +1,5 @@
+""" Mapping of class IDs to labels. """
+
 DR_LABELMAP = {
     0: 'No DR',
     1: 'Mild',
requirements.txt ADDED
@@ -0,0 +1,2 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.1.2+cpu
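These two lines pin a CPU-only PyTorch wheel, presumably so the Gradio demo can run inference without a GPU (an assumption; the commit message does not say). A quick sanity check that the pin took effect:

```python
import torch

# On a correctly installed CPU-only pin, the version string carries the
# "+cpu" local tag and CUDA is reported as unavailable.
print(torch.__version__)          # expected to report 2.1.2+cpu
print(torch.cuda.is_available())  # False on a CPU-only build
```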
train.py CHANGED
@@ -49,7 +49,15 @@ DataRecord = Tuple[Image.Image, int]
 
 
 class RetinopathyDataset(data.Dataset[DataRecord]):
+    """ A class to access the pre-downloaded Diabetic Retinopathy dataset. """
+
     def __init__(self, data_path: str) -> None:
+        """ Constructor.
+
+        Args:
+            data_path (str): Path to the dataset, e.g. "retinopathy_data",
+                containing "trainLabels.csv" and "train/".
+        """
         super().__init__()
 
         self.data_path = data_path
@@ -88,21 +96,25 @@ class RetinopathyDataset(data.Dataset[DataRecord]):
         return img_path
 
 
 class Purpose(Enum):
+    """ Purpose of a split: training or validation. """
     Train = 0
     Val = 1
 
-
 FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
                                    Callable[..., torch.Tensor]]
+""" Augmentation transforms for an image and a label. """
 
 TensorRecord = Tuple[torch.Tensor, torch.Tensor]
+""" Feature (image) and target (label) tensors. """
 
-def normalize(arr: np.ndarray) -> np.ndarray:
-    return arr / np.sum(arr)
-
 
 class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
+    """ A view on a part of a dataset. Split holds the information
+    about which samples go to training and which to validation,
+    without the need to move the two groups of files into separate
+    folders.
+    """
     def __init__(self, dataset: RetinopathyDataset,
                  indices: np.ndarray,
                  purpose: Purpose,
@@ -111,7 +123,24 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
                  stratify_classes: bool = False,
                  use_log_frequencies: bool = False,
                  ):
-
+        """ Constructor.
+
+        Args:
+            dataset (RetinopathyDataset): The dataset the Split "views".
+            indices (np.ndarray): Externally provided indices of the
+                samples that are "viewed".
+            purpose (Purpose): Either train or val, to be able to replicate
+                the data of the train split for efficient worker utilization.
+            transforms (FeatureAndTargetTransforms): Functors of the feature
+                and target transforms.
+            oversample_factor (int, optional): Expand the training dataset
+                by replication to avoid dataloader stalls on epoch ends.
+                Defaults to 1.
+            stratify_classes (bool, optional): Whether to apply stratified
+                sampling. Defaults to False.
+            use_log_frequencies (bool, optional): If stratify_classes=True,
+                whether to use a logarithmic sampling strategy. If False,
+                apply regular even sampling. Defaults to False.
+        """
         self.dataset = dataset
         self.indices = indices
         self.purpose = purpose
@@ -124,22 +153,26 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
         self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
         self.frequencies: Optional[Dict[int, float]] = None
         if self.stratify_classes:
-            self.bucketize_indices()
+            self._bucketize_indices()
         if self.use_log_frequencies:
-            self.calc_frequencies()
+            self._calc_frequencies()
 
-    def calc_frequencies(self):
+    def _calc_frequencies(self):
         assert self.per_class_indices is not None
         counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
         counts = np.array(list(counts_dict.values()))
-        counts_nrm = normalize(counts)
+        counts_nrm = self._normalize(counts)
         temperature = 50.0  # > 1 to even out frequencies
-        freqs = normalize(np.log1p(counts_nrm * temperature))
+        freqs = self._normalize(np.log1p(counts_nrm * temperature))
         self.frequencies = {k: freq.item() for k, freq
                             in zip(self.per_class_indices.keys(), freqs)}
         print(self.frequencies)
 
-    def bucketize_indices(self):
+    @staticmethod
+    def _normalize(arr: np.ndarray) -> np.ndarray:
+        return arr / np.sum(arr)
+
+    def _bucketize_indices(self):
         buckets = defaultdict(list)
         for index in self.indices:
             label = self.dataset.get_label_at(index)
@@ -191,6 +224,14 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
               seed: int = 54,
               ) -> Tuple['Split', 'Split']:
+        """ Prepare train and val splits deterministically.
+
+        Returns:
+            Tuple[Split, Split]:
+                - Train split
+                - Val split
+        """
+
         prng = RandomState(seed)
 
         num_train = int(len(all_data) * train_fraction)
@@ -204,7 +245,8 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
         return train_data, val_data
 
 
-def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
+def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
+                     split_name: str) -> None:
     labels = []
     for _, label in dataset:
         if isinstance(label, torch.Tensor):
@@ -261,7 +303,16 @@ class Metrics:
         return self
 
 
-def worker_init_fn(worker_id):
+def worker_init_fn(worker_id: int) -> None:
+    """ Initialize workers in a way that they draw different random
+    samples and do not repeat identical pseudorandom sequences of each
+    other, which may be the case with fork multiprocessing.
+
+    Args:
+        worker_id (int): id of a preprocessing worker process launched
+            by one DDP training process.
+    """
    state = np.random.get_state()
     assert isinstance(state, tuple)
     assert isinstance(state[1], np.ndarray)
@@ -274,6 +325,7 @@ def worker_init_fn(worker_id):
 
 
 class ViTLightningModule(L.LightningModule):
+    """ Lightning Module that implements the neural network training hooks. """
     def __init__(self, debug: bool) -> None:
         super().__init__()
 
@@ -443,6 +495,7 @@ class ViTLightningModule(L.LightningModule):
         return loss
 
     def _dump_train_images(self) -> None:
+        """ Save augmented images to disk for inspection. """
        img_batch, label_batch = next(iter(self._train_dataloader))
         for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
             img_np = img.cpu().numpy()
@@ -494,18 +547,19 @@ class ViTLightningModule(L.LightningModule):
 
 
 def main():
+    """ Neural network trainer entry point. """
 
     parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
     parser.add_argument('--tag', action='store', type=str,
                         help='Extra suffix to put on the artefact dir name')
-    parser.add_argument('--debug', action='store_true'
+    parser.add_argument('--debug', action='store_true',
+                        help='Dummy training cycle for testing purposes')
     parser.add_argument('--convert-checkpoint', action='store', type=str,
                         help='Convert a checkpoint from training to pickle-independent '
                              'predictor-compatible directory')
 
     args = parser.parse_args()
 
-
     torch.set_float32_matmul_precision('high')  # for V100/A100
 
     if args.convert_checkpoint is not None:
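The renamed `_calc_frequencies` is the heart of the log-frequency sampling that the `use_log_frequencies` docstring mentions. A standalone sketch of its effect, with hypothetical class counts (the real ones come from `per_class_indices`):

```python
import numpy as np

def normalize(arr: np.ndarray) -> np.ndarray:
    return arr / np.sum(arr)

counts = np.array([25000, 2400, 5300, 870, 700])  # hypothetical class counts
counts_nrm = normalize(counts)
temperature = 50.0  # > 1 to even out frequencies
freqs = normalize(np.log1p(counts_nrm * temperature))

# counts_nrm is dominated by class 0 (~0.73), while freqs is much flatter,
# so rare classes are sampled far more often than their raw share.
print(counts_nrm.round(3), freqs.round(3))
```

Similarly, the new `worker_init_fn` docstring explains the intent but the hunk truncates the body after the asserts. A common per-worker reseeding pattern that matches the docstring (a sketch under that assumption, not the commit's exact code):

```python
import numpy as np

def worker_init_fn(worker_id: int) -> None:
    # Derive a distinct NumPy seed per dataloader worker so forked workers
    # do not replay identical pseudorandom sequences.
    state = np.random.get_state()
    assert isinstance(state, tuple) and isinstance(state[1], np.ndarray)
    seed = (int(state[1][0]) + worker_id) % 2**32
    np.random.seed(seed)
```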