darshanjani commited on
Commit
3a0062c
1 Parent(s): 2beb6cf

utils function for inference

Browse files
Utilities/callbacks.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+
3
+ from . import config
4
+ from .utils import (
5
+ check_class_accuracy,
6
+ get_evaluation_bboxes,
7
+ mean_average_precision,
8
+ plot_couple_examples,
9
+ )
10
+
11
+ class PlotTestExamplesCallback(pl.Callback):
12
+ def __init__(self, every_n_epochs: int = 1) -> None:
13
+ super().__init__()
14
+ self.every_n_epochs = every_n_epochs
15
+
16
+ def on_train_epoch_end(self, trainer:pl.Trainer, pl_module:pl.LightningModule) -> None:
17
+ if (trainer.current_epoch + 1) % self.every_n_epochs == 0:
18
+ plot_couple_examples(
19
+ model=pl_module,
20
+ loader=trainer.datamodule.test_dataloader(),
21
+ thresh=0.6,
22
+ iou_thresh=0.5,
23
+ anchors=pl_module.scaled_anchors
24
+ )
25
+
26
+ class CheckClassAccuracyCallback(pl.Callback):
27
+ def __init__(
28
+ self, train_every_n_epochs: int = 1, test_every_n_epochs: int = 3
29
+ ) -> None:
30
+ super().__init__()
31
+ self.train_every_n_epochs = train_every_n_epochs
32
+ self.test_every_n_epochs = test_every_n_epochs
33
+
34
+ def on_train_epoch_end(
35
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
36
+ ) -> None:
37
+ if (trainer.current_epoch + 1) % self.train_every_n_epochs == 0:
38
+ print("+++ TRAIN ACCURACIES")
39
+ class_acc, no_obj_acc, obj_acc = check_class_accuracy(
40
+ model=pl_module,
41
+ loader=trainer.datamodule.train_dataloader(),
42
+ threshold=config.CONF_THRESHOLD,
43
+ )
44
+ pl_module.log_dict(
45
+ {
46
+ "train_class_acc": class_acc,
47
+ "train_no_obj_acc": no_obj_acc,
48
+ "train_obj_acc": obj_acc,
49
+ },
50
+ logger=True,
51
+ )
52
+
53
+ if (trainer.current_epoch + 1) % self.test_every_n_epochs == 0:
54
+ print("+++ TEST ACCURACIES")
55
+ class_acc, no_obj_acc, obj_acc = check_class_accuracy(
56
+ model=pl_module,
57
+ loader=trainer.datamodule.test_dataloader(),
58
+ threshold=config.CONF_THRESHOLD,
59
+ )
60
+ pl_module.log_dict(
61
+ {
62
+ "test_class_acc": class_acc,
63
+ "test_no_obj_acc": no_obj_acc,
64
+ "test_obj_acc": obj_acc,
65
+ },
66
+ logger=True,
67
+ )
68
+ class MAPCallback(pl.Callback):
69
+ def __init__(self, every_n_epochs: int = 3) -> None:
70
+ super().__init__()
71
+ self.every_n_epochs = every_n_epochs
72
+
73
+ def on_train_epoch_end(
74
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
75
+ ) -> None:
76
+ if (trainer.current_epoch + 1) % self.every_n_epochs == 0:
77
+ pred_boxes, true_boxes = get_evaluation_bboxes(
78
+ loader=trainer.datamodule.test_dataloader(),
79
+ model=pl_module,
80
+ iou_threshold=config.NMS_IOU_THRESH,
81
+ anchors=config.ANCHORS,
82
+ threshold=config.CONF_THRESHOLD,
83
+ device=config.DEVICE,
84
+ )
85
+
86
+ map_val = mean_average_precision(
87
+ pred_boxes=pred_boxes,
88
+ true_boxes=true_boxes,
89
+ iou_threshold=config.MAP_IOU_THRESH,
90
+ box_format="midpoint",
91
+ num_classes=config.NUM_CLASSES,
92
+ )
93
+ print("+++ MAP: ", map_val.item())
94
+ pl_module.log(
95
+ "MAP",
96
+ map_val.item(),
97
+ logger=True,
98
+ )
99
+ pl_module.train()
Utilities/config.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+
5
+ MAIN_DIR = "/kaggle/working/"
6
+ # DATASET = os.path.join(MAIN_DIR, "../data/PASCAL_VOC")
7
+ DATASET = "/kaggle/input/pascal-voc-dataset-used-in-yolov3-video/PASCAL_VOC"
8
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+ # DEVICE = "mps"
10
+ # seed_everything() # If you want deterministic behavior
11
+ NUM_WORKERS = 2
12
+ BATCH_SIZE = 40
13
+ IMAGE_SIZE = 416
14
+ INPUT_RESOLUTIONS = [416, 544]
15
+ INPUT_RESOLUTIONS_CUM_PROBS = [50, 100]
16
+ NUM_CLASSES = 20
17
+ LEARNING_RATE = 1e-5
18
+ WEIGHT_DECAY = 1e-4
19
+ NUM_EPOCHS = 40
20
+ CONF_THRESHOLD = 0.05
21
+ MAP_IOU_THRESH = 0.5
22
+ NMS_IOU_THRESH = 0.45
23
+ S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
24
+ PIN_MEMORY = True
25
+ LOAD_MODEL = False
26
+ SAVE_MODEL = True
27
+ CHECKPOINT_PATH = os.path.join(MAIN_DIR, "Store/checkpoints/")
28
+ IMG_DIR = DATASET + "/images/"
29
+ LABEL_DIR = DATASET + "/labels/"
30
+ TRAIN_MOSAIC_PERCENTAGE = 0.5
31
+ TEST_MOSAIC_PERCENTAGE = 0.00
32
+ MODEL_STATE_DICT_PATH = os.path.join(MAIN_DIR, "Store/checkpoints/yolov3.pth")
33
+
34
+ ANCHORS = [
35
+ [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
36
+ [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
37
+ [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
38
+ ] # Note these have been rescaled to be between [0, 1]
39
+
40
+ means = [0.485, 0.456, 0.406]
41
+
42
+ scale = 1.1
43
+
44
+ PASCAL_CLASSES = [
45
+ "aeroplane",
46
+ "bicycle",
47
+ "bird",
48
+ "boat",
49
+ "bottle",
50
+ "bus",
51
+ "car",
52
+ "cat",
53
+ "chair",
54
+ "cow",
55
+ "diningtable",
56
+ "dog",
57
+ "horse",
58
+ "motorbike",
59
+ "person",
60
+ "pottedplant",
61
+ "sheep",
62
+ "sofa",
63
+ "train",
64
+ "tvmonitor",
65
+ ]
66
+
67
+ COCO_LABELS = [
68
+ "person",
69
+ "bicycle",
70
+ "car",
71
+ "motorcycle",
72
+ "airplane",
73
+ "bus",
74
+ "train",
75
+ "truck",
76
+ "boat",
77
+ "traffic light",
78
+ "fire hydrant",
79
+ "stop sign",
80
+ "parking meter",
81
+ "bench",
82
+ "bird",
83
+ "cat",
84
+ "dog",
85
+ "horse",
86
+ "sheep",
87
+ "cow",
88
+ "elephant",
89
+ "bear",
90
+ "zebra",
91
+ "giraffe",
92
+ "backpack",
93
+ "umbrella",
94
+ "handbag",
95
+ "tie",
96
+ "suitcase",
97
+ "frisbee",
98
+ "skis",
99
+ "snowboard",
100
+ "sports ball",
101
+ "kite",
102
+ "baseball bat",
103
+ "baseball glove",
104
+ "skateboard",
105
+ "surfboard",
106
+ "tennis racket",
107
+ "bottle",
108
+ "wine glass",
109
+ "cup",
110
+ "fork",
111
+ "knife",
112
+ "spoon",
113
+ "bowl",
114
+ "banana",
115
+ "apple",
116
+ "sandwich",
117
+ "orange",
118
+ "broccoli",
119
+ "carrot",
120
+ "hot dog",
121
+ "pizza",
122
+ "donut",
123
+ "cake",
124
+ "chair",
125
+ "couch",
126
+ "potted plant",
127
+ "bed",
128
+ "dining table",
129
+ "toilet",
130
+ "tv",
131
+ "laptop",
132
+ "mouse",
133
+ "remote",
134
+ "keyboard",
135
+ "cell phone",
136
+ "microwave",
137
+ "oven",
138
+ "toaster",
139
+ "sink",
140
+ "refrigerator",
141
+ "book",
142
+ "clock",
143
+ "vase",
144
+ "scissors",
145
+ "teddy bear",
146
+ "hair drier",
147
+ "toothbrush",
148
+ ]
Utilities/dataset.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from PIL import Image, ImageFile
9
+ from torch.utils.data import DataLoader, Dataset
10
+ from torchvision.transforms import Resize
11
+
12
+ from . import config, transforms
13
+ from .utils import cells_to_bboxes
14
+ from .utils import iou_width_height as iou
15
+ from .utils import non_max_suppression as nms
16
+ from .utils import plot_image, xyxy2xywhn, xywhn2xyxy
17
+
18
+
19
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
20
+
21
+ class YOLODataset(Dataset):
22
+ def __init__(
23
+ self,
24
+ csv_file,
25
+ img_dir,
26
+ label_dir,
27
+ anchors,
28
+ image_size=416,
29
+ S=[13, 26, 52],
30
+ C=20,
31
+ transform=None,
32
+ mosaic_percentage=0.67,
33
+ ):
34
+ self.annotations = pd.read_csv(csv_file)
35
+ self.img_dir = img_dir
36
+ self.label_dir = label_dir
37
+ self.image_size = image_size
38
+ self.mosaic_border = [image_size // 2, image_size // 2]
39
+ self.transform = transform
40
+ self.S = S
41
+ self.anchors = torch.tensor(
42
+ anchors[0] + anchors[1] + anchors[2]
43
+ ) # for all 3 scales
44
+ self.num_anchors = self.anchors.shape[0]
45
+ self.num_anchors_per_scale = self.num_anchors // 3
46
+ self.C = C
47
+ self.ignore_iou_thresh = 0.5
48
+ self.mosaic_percentage = mosaic_percentage
49
+
50
+ def __len__(self):
51
+ return len(self.annotations)
52
+
53
+ def load_mosaic(self, index):
54
+ # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
55
+ labels4 = []
56
+ s = self.image_size
57
+ yc, xc = (
58
+ int(random.uniform(x, 2 * s - x)) for x in self.mosaic_border
59
+ ) # mosaic center x, y
60
+ indices = [index] + random.choices(
61
+ range(len(self)), k=3
62
+ ) # 3 additional image indices
63
+ random.shuffle(indices)
64
+ for i, index in enumerate(indices):
65
+ # Load image
66
+ label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
67
+ bboxes = np.roll(
68
+ np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1
69
+ ).tolist()
70
+ img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
71
+ img = np.array(Image.open(img_path).convert("RGB"))
72
+
73
+ h, w = img.shape[0], img.shape[1]
74
+ labels = np.array(bboxes)
75
+
76
+ # place img in img4
77
+ if i == 0: # top left
78
+ img4 = np.full(
79
+ (s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8
80
+ ) # base image with 4 tiles
81
+ x1a, y1a, x2a, y2a = (
82
+ max(xc - w, 0),
83
+ max(yc - h, 0),
84
+ xc,
85
+ yc,
86
+ ) # xmin, ymin, xmax, ymax (large image)
87
+ x1b, y1b, x2b, y2b = (
88
+ w - (x2a - x1a),
89
+ h - (y2a - y1a),
90
+ w,
91
+ h,
92
+ ) # xmin, ymin, xmax, ymax (small image)
93
+ elif i == 1: # top right
94
+ x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
95
+ x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
96
+ elif i == 2: # bottom left
97
+ x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
98
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
99
+ elif i == 3: # bottom right
100
+ x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
101
+ x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
102
+
103
+ img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
104
+ padw = x1a - x1b
105
+ padh = y1a - y1b
106
+
107
+ # Labels
108
+ if labels.size:
109
+ labels[:, :-1] = xywhn2xyxy(
110
+ labels[:, :-1], w, h, padw, padh
111
+ ) # normalized xywh to pixel xyxy format
112
+ labels4.append(labels)
113
+
114
+ # Concat/clip labels
115
+ labels4 = np.concatenate(labels4, 0)
116
+ for x in (labels4[:, :-1],):
117
+ np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
118
+ # img4, labels4 = replicate(img4, labels4) # replicate
119
+ labels4[:, :-1] = xyxy2xywhn(labels4[:, :-1], 2 * s, 2 * s)
120
+ labels4[:, :-1] = np.clip(labels4[:, :-1], 0, 1)
121
+ labels4 = labels4[labels4[:, 2] > 0]
122
+ labels4 = labels4[labels4[:, 3] > 0]
123
+ return img4, labels4
124
+
125
+ def load_single_img(self, index):
126
+ label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
127
+ bboxes = np.roll(
128
+ np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1
129
+ ).tolist()
130
+ img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
131
+ image = np.array(Image.open(img_path).convert("RGB"))
132
+ return image, bboxes
133
+
134
+ def __getitem__(self, index):
135
+ if random.random() < self.mosaic_percentage:
136
+ image, bboxes = self.load_mosaic(index)
137
+ else:
138
+ image, bboxes = self.load_single_img(index)
139
+
140
+ if self.transform:
141
+ augmentations = self.transform(image=image, bboxes=bboxes)
142
+ image = augmentations["image"]
143
+ bboxes = augmentations["bboxes"]
144
+
145
+ # e.g. = (3, 13, 13, 6), (3, 26, 26, 6), (3, 52, 52, 6) || 6 = [x, y, w, h, obj, class] for each anchor box
146
+ targets = [torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]
147
+ for box in bboxes:
148
+ iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)
149
+ anchor_indices = iou_anchors.argsort(descending=True, dim=0)
150
+ x, y, width, height, class_label = box
151
+ has_anchor = [False] * 3 # each scale should have one anchor
152
+ for anchor_idx in anchor_indices:
153
+ scale_idx = anchor_idx // self.num_anchors_per_scale
154
+ anchor_on_scale = anchor_idx % self.num_anchors_per_scale
155
+ S = self.S[scale_idx]
156
+ i, j = int(S * y), int(S * x) # which cell
157
+ anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
158
+ if not anchor_taken and not has_anchor[scale_idx]:
159
+ targets[scale_idx][anchor_on_scale, i, j, 0] = 1
160
+ x_cell, y_cell = S * x - j, S * y - i # both between [0,1]
161
+ width_cell, height_cell = (
162
+ width * S,
163
+ height * S,
164
+ ) # can be greater than 1 since it's relative to cell
165
+ box_coordinates = torch.tensor(
166
+ [x_cell, y_cell, width_cell, height_cell]
167
+ )
168
+ targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
169
+ targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
170
+ has_anchor[scale_idx] = True
171
+
172
+ elif (
173
+ not anchor_taken
174
+ and iou_anchors[anchor_idx] > self.ignore_iou_thresh
175
+ ):
176
+ targets[scale_idx][
177
+ anchor_on_scale, i, j, 0
178
+ ] = -1 # ignore prediction
179
+
180
+ return image, tuple(targets)
181
+
182
+ class YOLODataModule(pl.LightningDataModule):
183
+ def __init__(self, train_csv_path, test_csv_path):
184
+ super().__init__()
185
+ self.train_csv_path = train_csv_path
186
+ self.test_csv_path = test_csv_path
187
+ self.train_dataset = None
188
+ self.eval_dataset = None
189
+ self.test_dataset = None
190
+
191
+ def setup(self, stage=None):
192
+ self.train_dataset = YOLODataset(
193
+ self.train_csv_path,
194
+ transform=transforms.train_transforms,
195
+ S=[
196
+ config.IMAGE_SIZE // 32,
197
+ config.IMAGE_SIZE // 16,
198
+ config.IMAGE_SIZE // 8
199
+ ],
200
+ img_dir=config.IMG_DIR,
201
+ label_dir=config.LABEL_DIR,
202
+ anchors=config.ANCHORS,
203
+ mosaic_percentage=config.TRAIN_MOSAIC_PERCENTAGE
204
+ )
205
+
206
+ self.eval_dataset = YOLODataset(
207
+ self.train_csv_path,
208
+ transform=transforms.test_transforms,
209
+ S=[
210
+ config.IMAGE_SIZE // 32,
211
+ config.IMAGE_SIZE // 16,
212
+ config.IMAGE_SIZE // 8
213
+ ],
214
+ img_dir=config.IMG_DIR,
215
+ label_dir=config.LABEL_DIR,
216
+ anchors=config.ANCHORS,
217
+ mosaic_percentage=config.TRAIN_MOSAIC_PERCENTAGE # should be 0?
218
+ )
219
+
220
+ self.test_dataset = YOLODataset(
221
+ self.test_csv_path,
222
+ transform=transforms.test_transforms,
223
+ S=[
224
+ config.IMAGE_SIZE // 32,
225
+ config.IMAGE_SIZE // 16,
226
+ config.IMAGE_SIZE // 8
227
+ ],
228
+ img_dir=config.IMG_DIR,
229
+ label_dir=config.LABEL_DIR,
230
+ anchors=config.ANCHORS,
231
+ mosaic_percentage=config.TEST_MOSAIC_PERCENTAGE
232
+ )
233
+
234
+ def train_dataloader(self):
235
+ return DataLoader(
236
+ dataset=self.train_dataset,
237
+ batch_size=config.BATCH_SIZE,
238
+ shuffle=True,
239
+ num_workers=config.NUM_WORKERS,
240
+ pin_memory=config.PIN_MEMORY,
241
+ drop_last=False
242
+ )
243
+ def val_dataloader(self):
244
+ return DataLoader(
245
+ dataset=self.eval_dataset,
246
+ batch_size=config.BATCH_SIZE,
247
+ shuffle=False,
248
+ num_workers=config.NUM_WORKERS,
249
+ pin_memory=config.PIN_MEMORY,
250
+ drop_last=False
251
+ )
252
+ def test_dataloader(self):
253
+ return DataLoader(
254
+ dataset=self.test_dataset,
255
+ batch_size=config.BATCH_SIZE,
256
+ shuffle=False,
257
+ num_workers=config.NUM_WORKERS,
258
+ pin_memory=config.PIN_MEMORY,
259
+ drop_last=False
260
+ )
261
+
262
+ def test():
263
+
264
+ anchors = config.ANCHORS
265
+
266
+ transform = config.test_transforms
267
+
268
+ dataset = YOLODataset(
269
+ "../data/PASCAL_VOC/2examples.csv",
270
+ "../data/PASCAL_VOC/images",
271
+ "../data/PASCAL_VOC/labels",
272
+ S=[13, 26, 52],
273
+ anchors=anchors,
274
+ transform=transform
275
+ )
276
+
277
+ S = [13, 26, 52]
278
+ scaled_anchors = torch.tensor(anchors) / (
279
+ 1 / torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
280
+ )
281
+
282
+ loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
283
+ for x, y in loader:
284
+ boxes = []
285
+
286
+ for i in range(y[0].shape[1]):
287
+ anchor = scaled_anchors[i]
288
+ print(anchor.shape)
289
+ print(y[i].shape)
290
+ boxes += cells_to_bboxes(
291
+ y[i], is_preds=False, S=y[i].shape[2], anchors=anchor
292
+ )[0]
293
+ boxes = nms(boxes, iou_threshold=1, threshold=0.7, box_format="midpoint")
294
+ print(boxes)
295
+ plot_image(x[0].permute(1, 2, 0).to("cpu"), boxes)
296
+
297
+ if __name__ == "__main__":
298
+ test()
Utilities/loss.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of Yolo Loss Function similar to the one in Yolov3 paper,
3
+ the difference from what I can tell is I use CrossEntropy for the classes
4
+ instead of BinaryCrossEntropy.
5
+ """
6
+
7
+ import random
8
+
9
+ import pytorch_lightning as pl
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from .utils import intersection_over_union
14
+
15
+ class YoloLoss(pl.LightningModule):
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.mse = nn.MSELoss()
19
+ self.bce = nn.BCEWithLogitsLoss()
20
+ self.entropy = nn.CrossEntropyLoss()
21
+ self.sigmoid = nn.Sigmoid()
22
+
23
+ # constants for the loss function
24
+ self.lambda_class = 1
25
+ self.lambda_noobj = 5
26
+ self.lambda_obj = 1
27
+ self.lambda_box = 1
28
+
29
+ def forward(self, predictions, target, anchors):
30
+ # Check where obj and noobj (we ignore if target == -1)
31
+ obj = target[..., 0] == 1
32
+ noobj = target[..., 0] == 0
33
+
34
+ # ======================= #
35
+ # FOR NO OBJECT LOSS #
36
+ # ======================= #
37
+
38
+ no_object_loss = self.bce(
39
+ (predictions[..., 0:1][noobj]),
40
+ (target[..., 0:1][noobj])
41
+ )
42
+
43
+ # ==================== #
44
+ # FOR OBJECT LOSS #
45
+ # ==================== #
46
+
47
+ anchors = anchors.reshape(1, 3, 1, 1, 2)
48
+
49
+ box_preds = torch.cat(
50
+ [
51
+ self.sigmoid(predictions[..., 1:3]),
52
+ torch.exp(predictions[..., 3:5]) * anchors,
53
+ ],
54
+ dim=-1,
55
+ )
56
+
57
+ ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
58
+
59
+ object_loss = self.mse(
60
+ self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj]
61
+ )
62
+
63
+ # ======================== #
64
+ # FOR BOX COORDINATES #
65
+ # ======================== #
66
+
67
+ predictions[..., 1:3] = self.sigmoid(predictions[..., 1:3]) # x,y coordinates
68
+ target[..., 3:5] = torch.log(
69
+ (1e-16 + target[..., 3:5] / anchors)
70
+ ) # width, height coordinates
71
+ box_loss = self.mse(predictions[..., 1:5][obj], target[..., 1:5][obj])
72
+
73
+ # ================== #
74
+ # FOR CLASS LOSS #
75
+ # ================== #
76
+
77
+ class_loss = self.entropy(
78
+ (predictions[..., 5:][obj]),
79
+ (target[..., 5][obj].long()),
80
+ )
81
+
82
+ # print("__________________________________")
83
+ # print(self.lambda_box * box_loss)
84
+ # print(self.lambda_obj * object_loss)
85
+ # print(self.lambda_noobj * no_object_loss)
86
+ # print(self.lambda_class * class_loss)
87
+ # print("\n")
88
+
89
+ return (
90
+ self.lambda_box * box_loss
91
+ + self.lambda_obj * object_loss
92
+ + self.lambda_noobj * no_object_loss
93
+ + self.lambda_class * class_loss
94
+ )
95
+
96
+
Utilities/model.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of YOLOv3 architecture
3
+ """
4
+
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ from torch.optim.lr_scheduler import OneCycleLR
10
+
11
+
12
+ from . import config
13
+ from .loss import YoloLoss
14
+
15
+ model_config = [
16
+ (32, 3, 1),
17
+ (64, 3, 2),
18
+ ["B", 1],
19
+ (128, 3, 2),
20
+ ["B", 2],
21
+ (256, 3, 2),
22
+ ["B", 8],
23
+ (512, 3, 2),
24
+ ["B", 8],
25
+ (1024, 3, 2),
26
+ ["B", 4], # darknet 53 ends here
27
+
28
+ (512, 1, 1),
29
+ (1024, 3, 1),
30
+ "S",
31
+
32
+ (256, 1, 1),
33
+ "U",
34
+ (256, 1, 1),
35
+ (512, 3, 1),
36
+ "S",
37
+
38
+ (128, 1, 1),
39
+ "U",
40
+ (128, 1, 1),
41
+ (256, 3, 1),
42
+ "S"
43
+ ]
44
+
45
+ class CNNBlock(pl.LightningModule):
46
+ def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
47
+ super().__init__()
48
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
49
+ self.bn = nn.BatchNorm2d(out_channels)
50
+ self.leaky = nn.LeakyReLU(0.1)
51
+ self.use_bn_act = bn_act
52
+
53
+ def forward(self, x):
54
+ if self.use_bn_act:
55
+ return self.leaky(self.bn((self.conv(x))))
56
+ else:
57
+ return self.conv(x)
58
+
59
+ class ResidualBlock(pl.LightningModule):
60
+ def __init__(self, channels, use_residual=True, num_repeats=1):
61
+ super().__init__()
62
+ self.layers = nn.ModuleList()
63
+ for repeat in range(num_repeats):
64
+ self.layers += [
65
+ nn.Sequential(
66
+ CNNBlock(channels, channels//2, kernel_size=1),
67
+ CNNBlock(channels//2, channels, kernel_size=3, padding=1)
68
+ )
69
+ ]
70
+ self.use_residual = use_residual
71
+ self.num_repeats = num_repeats
72
+
73
+ def forward(self, x):
74
+ for layer in self.layers:
75
+ if self.use_residual:
76
+ x = x + layer(x)
77
+ else:
78
+ x = layer(x)
79
+
80
+ return x
81
+
82
+ class ScalePrediction(pl.LightningModule):
83
+ def __init__(self, in_channels, num_classes):
84
+ super().__init__()
85
+ self.pred = nn.Sequential(
86
+ CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
87
+ CNNBlock(2 * in_channels, (num_classes + 5) * 3, kernel_size=1, bn_act=False)
88
+ )
89
+ self.num_classes = num_classes
90
+
91
+ def forward(self, x):
92
+ return (
93
+ self.pred(x).
94
+ reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3]).
95
+ permute(0, 1, 3, 4, 2)
96
+ )
97
+
98
+ class YOLOv3(pl.LightningModule):
99
+ def __init__(self, in_channels=3, num_classes=20):
100
+ super().__init__()
101
+ self.num_classes = num_classes
102
+ self.in_channels = in_channels
103
+ self.layers = self._create_conv_layers()
104
+
105
+ self.scaled_anchors = (
106
+ torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2) # ?
107
+ ).to(config.DEVICE)
108
+
109
+ self.learning_rate = config.LEARNING_RATE
110
+ self.weight_decay = config.WEIGHT_DECAY
111
+ self.best_lr = 1e-3 ## ?
112
+
113
+ def forward(self, x): # ?
114
+ outputs = [] # for each scale
115
+ route_connections = []
116
+ for layer in self.layers:
117
+ if isinstance(layer, ScalePrediction):
118
+ outputs.append(layer(x))
119
+ continue
120
+
121
+ x = layer(x)
122
+
123
+ if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
124
+ route_connections.append(x)
125
+ elif isinstance(layer, nn.Upsample):
126
+ x = torch.cat([x, route_connections[-1]], dim=1)
127
+ route_connections.pop()
128
+
129
+ return outputs
130
+
131
+ def _create_conv_layers(self):
132
+ layers = nn.ModuleList()
133
+ in_channels = self.in_channels
134
+
135
+ for module in model_config:
136
+ if isinstance(module, tuple):
137
+ out_channels, kernel_size, stride = module
138
+ layers.append(
139
+ CNNBlock(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1 if kernel_size==3 else 0)
140
+ )
141
+ in_channels = out_channels
142
+
143
+ elif isinstance(module, list):
144
+ num_repeats = module[1]
145
+ layers.append(
146
+ ResidualBlock(in_channels, num_repeats=num_repeats)
147
+ )
148
+ elif isinstance(module, str):
149
+ if module == "S":
150
+ layers += [
151
+ ResidualBlock(in_channels, use_residual=False, num_repeats=1),
152
+ CNNBlock(in_channels, in_channels//2, kernel_size=1),
153
+ ScalePrediction(in_channels//2, num_classes=self.num_classes)
154
+ ]
155
+ in_channels = in_channels // 2
156
+
157
+ elif module == "U":
158
+ layers.append(nn.Upsample(scale_factor=2))
159
+ in_channels = in_channels * 3
160
+
161
+ return layers
162
+
163
+
164
+ def yololoss(self):
165
+ return YoloLoss()
166
+
167
+ def training_step(self, batch, batch_idx):
168
+ x, y = batch
169
+ y0, y1, y2 = y[0], y[1], y[2]
170
+ out = self.forward(x)
171
+ # print(out[0].shape, y0.shape)
172
+
173
+ loss = ( # ?
174
+ self.yololoss()(out[0], y0, self.scaled_anchors[0])
175
+ + self.yololoss()(out[1], y1, self.scaled_anchors[1])
176
+ + self.yololoss()(out[2], y2, self.scaled_anchors[2])
177
+ )
178
+
179
+ self.log(
180
+ "train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
181
+ )
182
+ return loss
183
+
184
+ def test_step(self, batch, batch_idx):
185
+ x, y = batch
186
+ y0, y1, y2 = y[0], y[1], y[2]
187
+ out = self.forward(x)
188
+
189
+ loss = (
190
+ self.yololoss()(out[0], y0, self.scaled_anchors[0])
191
+ + self.yololoss()(out[1], y1, self.scaled_anchors[1])
192
+ + self.yololoss()(out[2], y2, self.scaled_anchors[2])
193
+ )
194
+
195
+ self.log(
196
+ "test_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
197
+ )
198
+
199
+ return loss
200
+
201
+ def on_train_epoch_end(self) -> None:
202
+ print(
203
+ f"Epoch: {self.current_epoch}, Loss: {self.trainer.callback_metrics['train_loss_epoch']}"
204
+ )
205
+
206
+ def on_test_epoch_end(self) -> None:
207
+ print(
208
+ f"Epoch: {self.current_epoch}, Loss: {self.trainer.callback_metrics['test_loss_epoch']}"
209
+ )
210
+
211
+ def configure_optimizers(self):
212
+ optimizer = optim.Adam(
213
+ self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
214
+ )
215
+
216
+ scheduler = OneCycleLR(
217
+ optimizer,
218
+ max_lr=self.best_lr,
219
+ steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
220
+ epochs=config.NUM_EPOCHS,
221
+ pct_start=8 / config.NUM_EPOCHS,
222
+ div_factor=100,
223
+ three_phase=False,
224
+ final_div_factor=100,
225
+ anneal_strategy="linear"
226
+ )
227
+
228
+ return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]
229
+
230
+ def on_train_end(self) -> None:
231
+ torch.save(self.state_dict(), config.MODEL_STATE_DICT_PATH)
232
+
233
+ if __name__ == "main":
234
+ num_classes = 20
235
+ IMAGE_SIZE = 416
236
+ model = YOLOv3(num_classes=num_classes)
237
+ x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
238
+ out = model(x)
239
+ assert model(x)[0].shape == (
240
+ 2,
241
+ 3,
242
+ IMAGE_SIZE // 32,
243
+ IMAGE_SIZE // 32,
244
+ num_classes + 5
245
+ )
246
+ assert model(x)[1].shape == (
247
+ 2,
248
+ 3,
249
+ IMAGE_SIZE // 16,
250
+ IMAGE_SIZE // 16,
251
+ num_classes + 5
252
+ )
253
+ assert model(x)[2].shape == (
254
+ 2,
255
+ 3,
256
+ IMAGE_SIZE // 8,
257
+ IMAGE_SIZE // 8,
258
+ num_classes + 5
259
+ )
260
+ print("Image size compatibility check passed!")
261
+
262
+
263
+
264
+
265
+
266
+
267
+
268
+
269
+
270
+
Utilities/runtime_utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pytorch_lightning as pl
3
+ import torch
4
+ from pytorch_grad_cam import GradCAM
5
+ from pytorch_grad_cam.utils.image import show_cam_on_image
6
+
7
+ from Utilities.transforms import test_transforms
8
+ # from Utilities.config import S
9
+ from Utilities.utils import cells_to_bboxes, non_max_suppression, plot_image
10
+
11
+
12
+ def plot_bboxes(
13
+ input_img,
14
+ model,
15
+ thresh=0.6,
16
+ iou_thresh=0.5,
17
+ anchors=None,
18
+ ):
19
+ input_img = test_transforms(image=input_img)["image"]
20
+ input_img = input_img.unsqueeze(0)
21
+ model.eval()
22
+ with torch.no_grad():
23
+ out = model(input_img)
24
+
25
+ for i in range(3):
26
+ batch_size, A, S, _, _ = out[i].shape
27
+ anchor = anchors[i]
28
+ boxes_scale_i = cells_to_bboxes(out[i], anchor, S=S, is_preds=True)
29
+ bboxes = boxes_scale_i[0]
30
+
31
+ nms_boxes = non_max_suppression(
32
+ bboxes,
33
+ iou_threshold=iou_thresh,
34
+ threshold=thresh,
35
+ box_formet="midpoint",
36
+ )
37
+ fig = plot_image(input_img[0].permute(1, 2, 0).detach().cpu(), nms_boxes)
38
+ return fig, input_img
39
+
40
+ def return_top_objectness_class_preds(model, input_img, gradcam_output_stream):
41
+ out = model(input_img)[gradcam_output_stream]
42
+
43
+ # 1. get objectness score
44
+ objectness_scores = out[..., 0]
45
+
46
+ # 2. get index of highest objectness score
47
+ max_obj_arg = torch.argmax(objectness_scores)
48
+
49
+ max_obj_arg_onehot = torch.zeros(objectness_scores.flatten().shape[0])
50
+ max_obj_arg_onehot[max_obj_arg] = 1
51
+
52
+ max_obj_arg_onehot = max_obj_arg_onehot.reshape_as(objectness_scores).int()
53
+
54
+ selected_elements = out[max_obj_arg_onehot == 1]
55
+ selected_elements = selected_elements[:, 5:]
56
+
57
+ return selected_elements
58
+
59
+ class TopObjectnessClassPreds(pl.LightningModule):
60
+ def __init__(self, model, gradcam_output_stream):
61
+ super().__init__()
62
+ self.model = model
63
+ self.gradcam_output_stream = gradcam_output_stream
64
+
65
+ def forward(self, x):
66
+ return return_top_objectness_class_preds(self.model, x, self.gradcam_output_stream)
67
+
68
+ def generate_gradcam_output(org_img, model, input_img, gradcam_output_stream: int = 0):
69
+ TopObjectnessClassPredsObj = TopObjectnessClassPreds(model, gradcam_output_stream)
70
+ gradcam_model_layer = [15, 22, 29]
71
+ cam = GradCAM(
72
+ model=TopObjectnessClassPredsObj,
73
+ target_layers=[
74
+ TopObjectnessClassPredsObj.model.layers[
75
+ gradcam_model_layer[gradcam_output_stream]
76
+ ]
77
+ ],
78
+ )
79
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
80
+ grayscale_cam = np.sum(grayscale_cam, axis=-1)
81
+ grayscale_cam = grayscale_cam[0, :]
82
+
83
+ visualization = show_cam_on_image(
84
+ org_img / 255,
85
+ grayscale_cam,
86
+ use_rgb=True,
87
+ image_weight=0.5,
88
+ )
89
+ return visualization
Utilities/transforms.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+ import cv2
3
+ from albumentations.pytorch import ToTensorV2
4
+
5
+ from .config import IMAGE_SIZE, scale
6
+
7
+ # train_transforms = A.Compose(
8
+ # [
9
+ # A.LongestMaxSize(max_size=int(IMAGE_SIZE * scale)),
10
+ # A.PadIfNeeded(
11
+ # min_height=int(IMAGE_SIZE * scale),
12
+ # min_width=int(IMAGE_SIZE * scale),
13
+ # border_mode=cv2.BORDER_CONSTANT,
14
+ # ),
15
+ # A.Rotate(limit=10, interpolation=1, border_mode=4),
16
+ # A.RandomCrop(width=IMAGE_SIZE, height=IMAGE_SIZE),
17
+ # A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
18
+ # A.OneOf(
19
+ # [
20
+ # A.ShiftScaleRotate(
21
+ # rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
22
+ # ),
23
+ # # A.Affine(shear=15, p=0.5, mode="constant"),
24
+ # ],
25
+ # p=1.0,
26
+ # ),
27
+ # A.HorizontalFlip(p=0.5),
28
+ # A.Blur(p=0.1),
29
+ # A.CLAHE(p=0.1),
30
+ # A.Posterize(p=0.1),
31
+ # A.ToGray(p=0.1),
32
+ # A.ChannelShuffle(p=0.05),
33
+ # A.Normalize(
34
+ # mean=[0, 0, 0],
35
+ # std=[1, 1, 1],
36
+ # max_pixel_value=255,
37
+ # ),
38
+ # ToTensorV2(),
39
+ # ],
40
+ # bbox_params=A.BboxParams(
41
+ # format="yolo",
42
+ # min_visibility=0.4,
43
+ # label_fields=[],
44
+ # ),
45
+ # )
46
+
47
+ test_transforms = A.Compose(
48
+ [
49
+ A.LongestMaxSize(max_size=IMAGE_SIZE),
50
+ A.PadIfNeeded(
51
+ min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
52
+ ),
53
+ A.Normalize(
54
+ mean=[0, 0, 0],
55
+ std=[1, 1, 1],
56
+ max_pixel_value=255,
57
+ ),
58
+ ToTensorV2(),
59
+ ],
60
+ bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
61
+ )
62
+
63
+ resize_transforms = A.Compose(
64
+ [
65
+ A.LongestMaxSize(max_size=IMAGE_SIZE),
66
+ A.PadIfNeeded(
67
+ min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
68
+ ),
69
+ ]
70
+ )
Utilities/utils.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from collections import Counter
4
+
5
+ import matplotlib.patches as patches
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from . import config
12
+
13
+ def iou_width_height(boxes1, boxes2):
14
+ """
15
+ Parameters:
16
+ boxes1 (tensor): width and height of the first bounding boxes
17
+ boxes2 (tensor): width and height of the second bounding boxes
18
+ Returns:
19
+ tensor: Intersection over union of the corresponding boxes
20
+ """
21
+ intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
22
+ boxes1[..., 1], boxes2[..., 1]
23
+ )
24
+ union = (
25
+ boxes1[..., 0] * boxes1[..., 1] + boxes2[..., 0] * boxes2[..., 1] - intersection
26
+ )
27
+ return intersection / union
28
+
29
+ def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
30
+ """
31
+ Video explanation of this function:
32
+ https://youtu.be/XXYG5ZWtjj0
33
+
34
+ This function calculates intersection over union (iou) given pred boxes
35
+ and target boxes.
36
+
37
+ Parameters:
38
+ boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
39
+ boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
40
+ box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
41
+
42
+ Returns:
43
+ tensor: Intersection over union for all examples
44
+ """
45
+
46
+ if box_format == "midpoint":
47
+ box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
48
+ box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
49
+ box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
50
+ box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
51
+ box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
52
+ box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
53
+ box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
54
+ box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
55
+
56
+ if box_format == "corners":
57
+ box1_x1 = boxes_preds[..., 0:1]
58
+ box1_y1 = boxes_preds[..., 1:2]
59
+ box1_x2 = boxes_preds[..., 2:3]
60
+ box1_y2 = boxes_preds[..., 3:4]
61
+ box2_x1 = boxes_labels[..., 0:1]
62
+ box2_y1 = boxes_labels[..., 1:2]
63
+ box2_x2 = boxes_labels[..., 2:3]
64
+ box2_y2 = boxes_labels[..., 3:4]
65
+
66
+ x1 = torch.max(box1_x1, box2_x1)
67
+ y1 = torch.max(box1_y1, box2_y1)
68
+ x2 = torch.min(box1_x2, box2_x2)
69
+ y2 = torch.min(box1_y2, box2_y2)
70
+
71
+ intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
72
+ box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
73
+ box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
74
+
75
+ return intersection / (box1_area + box2_area - intersection + 1e-6)
76
+
77
+ def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
78
+ """
79
+ Video explanation of this function:
80
+ https://youtu.be/YDkjWEN8jNA
81
+
82
+ Does Non Max Suppression given bboxes
83
+
84
+ Parameters:
85
+ bboxes (list): list of lists containing all bboxes with each bboxes
86
+ specified as [class_pred, prob_score, x1, y1, x2, y2]
87
+ iou_threshold (float): threshold where predicted bboxes is correct
88
+ threshold (float): threshold to remove predicted bboxes (independent of IoU)
89
+ box_format (str): "midpoint" or "corners" used to specify bboxes
90
+
91
+ Returns:
92
+ list: bboxes after performing NMS given a specific IoU threshold
93
+ """
94
+
95
+ assert type(bboxes) == list
96
+
97
+ bboxes = [box for box in bboxes if box[1] > threshold]
98
+ bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
99
+ bboxes_after_nms = []
100
+
101
+ while bboxes:
102
+ chosen_box = bboxes.pop(0)
103
+
104
+ bboxes = [
105
+ box
106
+ for box in bboxes
107
+ if box[0] != chosen_box[0]
108
+ or intersection_over_union(
109
+ torch.tensor(chosen_box[2:]),
110
+ torch.tensor(box[2:]),
111
+ box_format=box_format,
112
+ )
113
+ < iou_threshold
114
+ ]
115
+
116
+ bboxes_after_nms.append(chosen_box)
117
+
118
+ return bboxes_after_nms
119
+
120
+ def mean_average_precision(
121
+ pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
122
+ ):
123
+ """
124
+ Video explanation of this function:
125
+ https://youtu.be/FppOzcDvaDI
126
+
127
+ This function calculates mean average precision (mAP)
128
+
129
+ Parameters:
130
+ pred_boxes (list): list of lists containing all bboxes with each bboxes
131
+ specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
132
+ true_boxes (list): Similar as pred_boxes except all the correct ones
133
+ iou_threshold (float): threshold where predicted bboxes is correct
134
+ box_format (str): "midpoint" or "corners" used to specify bboxes
135
+ num_classes (int): number of classes
136
+
137
+ Returns:
138
+ float: mAP value across all classes given a specific IoU threshold
139
+ """
140
+
141
+ # list storing all AP for respective classes
142
+ average_precisions = []
143
+
144
+ # used for numerical stability later on
145
+ epsilon = 1e-6
146
+
147
+ for c in range(num_classes):
148
+ detections = []
149
+ ground_truths = []
150
+
151
+ # Go through all predictions and targets,
152
+ # and only add the ones that belong to the
153
+ # current class c
154
+ for detection in pred_boxes:
155
+ if detection[1] == c:
156
+ detections.append(detection)
157
+
158
+ for true_box in true_boxes:
159
+ if true_box[1] == c:
160
+ ground_truths.append(true_box)
161
+
162
+ # find the amount of bboxes for each training example
163
+ # Counter here finds how many ground truth bboxes we get
164
+ # for each training example, so let's say img 0 has 3,
165
+ # img 1 has 5 then we will obtain a dictionary with:
166
+ # amount_bboxes = {0:3, 1:5}
167
+ amount_bboxes = Counter([gt[0] for gt in ground_truths])
168
+
169
+ # We then go through each key, val in this dictionary
170
+ # and convert to the following (w.r.t same example):
171
+ # ammount_bboxes = {0:torch.tensor[0,0,0], 1:torch.tensor[0,0,0,0,0]}
172
+ for key, val in amount_bboxes.items():
173
+ amount_bboxes[key] = torch.zeros(val)
174
+
175
+ # sort by box probabilities which is index 2
176
+ detections.sort(key=lambda x: x[2], reverse=True)
177
+ TP = torch.zeros((len(detections)))
178
+ FP = torch.zeros((len(detections)))
179
+ total_true_bboxes = len(ground_truths)
180
+
181
+ # If none exists for this class then we can safely skip
182
+ if total_true_bboxes == 0:
183
+ continue
184
+
185
+ for detection_idx, detection in enumerate(detections):
186
+ # Only take out the ground_truths that have the same
187
+ # training idx as detection
188
+ ground_truth_img = [
189
+ bbox for bbox in ground_truths if bbox[0] == detection[0]
190
+ ]
191
+
192
+ num_gts = len(ground_truth_img)
193
+ best_iou = 0
194
+
195
+ for idx, gt in enumerate(ground_truth_img):
196
+ iou = intersection_over_union(
197
+ torch.tensor(detection[3:]),
198
+ torch.tensor(gt[3:]),
199
+ box_format=box_format,
200
+ )
201
+
202
+ if iou > best_iou:
203
+ best_iou = iou
204
+ best_gt_idx = idx
205
+
206
+ if best_iou > iou_threshold:
207
+ # only detect ground truth detection once
208
+ if amount_bboxes[detection[0]][best_gt_idx] == 0:
209
+ # true positive and add this bounding box to seen
210
+ TP[detection_idx] = 1
211
+ amount_bboxes[detection[0]][best_gt_idx] = 1
212
+ else:
213
+ FP[detection_idx] = 1
214
+
215
+ # if IOU is lower then the detection is a false positive
216
+ else:
217
+ FP[detection_idx] = 1
218
+
219
+ TP_cumsum = torch.cumsum(TP, dim=0)
220
+ FP_cumsum = torch.cumsum(FP, dim=0)
221
+ recalls = TP_cumsum / (total_true_bboxes + epsilon)
222
+ precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
223
+ precisions = torch.cat((torch.tensor([1]), precisions))
224
+ recalls = torch.cat((torch.tensor([0]), recalls))
225
+ # torch.trapz for numerical integration
226
+ average_precisions.append(torch.trapz(precisions, recalls))
227
+
228
+ return sum(average_precisions) / len(average_precisions)
229
+
230
+ def plot_image(image, boxes):
231
+ """Plots predicted bounding boxes on the image"""
232
+ cmap = plt.get_cmap("tab20b")
233
+ class_labels = (
234
+ config.COCO_LABELS if config.DATASET == "COCO" else config.PASCAL_CLASSES
235
+ )
236
+ colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
237
+ im = np.array(image)
238
+ height, width, _ = im.shape
239
+
240
+ # Create figure and axes
241
+ fig, ax = plt.subplots(1)
242
+ # Display the image
243
+ ax.imshow(im)
244
+
245
+ # box[0] is x midpoint, box[2] is width
246
+ # box[1] is y midpoint, box[3] is height
247
+
248
+ # Create a Rectangle patch
249
+ for box in boxes:
250
+ assert (
251
+ len(box) == 6
252
+ ), "box should contain class pred, confidence, x, y, width, height"
253
+ class_pred = box[0]
254
+ box = box[2:]
255
+ upper_left_x = box[0] - box[2] / 2
256
+ upper_left_y = box[1] - box[3] / 2
257
+ rect = patches.Rectangle(
258
+ (upper_left_x * width, upper_left_y * height),
259
+ box[2] * width,
260
+ box[3] * height,
261
+ linewidth=2,
262
+ edgecolor=colors[int(class_pred)],
263
+ facecolor="none",
264
+ )
265
+ # Add the patch to the Axes
266
+ ax.add_patch(rect)
267
+ plt.text(
268
+ upper_left_x * width,
269
+ upper_left_y * height,
270
+ s=class_labels[int(class_pred)],
271
+ color="white",
272
+ verticalalignment="top",
273
+ bbox={"color": colors[int(class_pred)], "pad": 0},
274
+ )
275
+
276
+ plt.show()
277
+
278
+ def get_evaluation_bboxes(
279
+ loader,
280
+ model,
281
+ iou_threshold,
282
+ anchors,
283
+ threshold,
284
+ box_format="midpoint",
285
+ device="cuda",
286
+ ):
287
+ # make sure model is in eval before get bboxes
288
+ model.eval()
289
+ train_idx = 0
290
+ all_pred_boxes = []
291
+ all_true_boxes = []
292
+ for batch_idx, (x, labels) in enumerate(tqdm(loader)):
293
+ x = x.to(device)
294
+
295
+ with torch.no_grad():
296
+ predictions = model(x)
297
+
298
+ batch_size = x.shape[0]
299
+ bboxes = [[] for _ in range(batch_size)]
300
+ for i in range(3):
301
+ S = predictions[i].shape[2]
302
+ anchor = torch.tensor([*anchors[i]]).to(device) * S
303
+ boxes_scale_i = cells_to_bboxes(predictions[i], anchor, S=S, is_preds=True)
304
+ for idx, (box) in enumerate(boxes_scale_i):
305
+ bboxes[idx] += box
306
+
307
+ # we just want one bbox for each label, not one for each scale
308
+ true_bboxes = cells_to_bboxes(labels[2], anchor, S=S, is_preds=False)
309
+
310
+ for idx in range(batch_size):
311
+ nms_boxes = non_max_suppression(
312
+ bboxes[idx],
313
+ iou_threshold=iou_threshold,
314
+ threshold=threshold,
315
+ box_format=box_format,
316
+ )
317
+
318
+ for nms_box in nms_boxes:
319
+ all_pred_boxes.append([train_idx] + nms_box)
320
+
321
+ for box in true_bboxes[idx]:
322
+ if box[1] > threshold:
323
+ all_true_boxes.append([train_idx] + box)
324
+
325
+ train_idx += 1
326
+
327
+ model.train()
328
+ return all_pred_boxes, all_true_boxes
329
+
330
+ def cells_to_bboxes(predictions, anchors, S, is_preds=True):
331
+ """
332
+ Scales the predictions coming from the model to
333
+ be relative to the entire image such that they for example later
334
+ can be plotted or.
335
+ INPUT:
336
+ predictions: tensor of size (N, 3, S, S, num_classes+5)
337
+ anchors: the anchors used for the predictions
338
+ S: the number of cells the image is divided in on the width (and height)
339
+ is_preds: whether the input is predictions or the true bounding boxes
340
+ OUTPUT:
341
+ converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index,
342
+ object score, bounding box coordinates
343
+ """
344
+ BATCH_SIZE = predictions.shape[0]
345
+ num_anchors = len(anchors)
346
+ box_predictions = predictions[..., 1:5]
347
+ if is_preds:
348
+ anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
349
+ box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
350
+ box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
351
+ scores = torch.sigmoid(predictions[..., 0:1])
352
+ best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
353
+ else:
354
+ scores = predictions[..., 0:1]
355
+ best_class = predictions[..., 5:6]
356
+
357
+ cell_indices = (
358
+ torch.arange(S)
359
+ .repeat(predictions.shape[0], 3, S, 1)
360
+ .unsqueeze(-1)
361
+ .to(predictions.device)
362
+ )
363
+ x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
364
+ y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
365
+ w_h = 1 / S * box_predictions[..., 2:4]
366
+ converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(
367
+ BATCH_SIZE, num_anchors * S * S, 6
368
+ )
369
+ return converted_bboxes.tolist()
370
+
371
+ def check_class_accuracy(model, loader, threshold):
372
+ model.eval()
373
+ tot_class_preds, correct_class = 0, 0
374
+ tot_noobj, correct_noobj = 0, 0
375
+ tot_obj, correct_obj = 0, 0
376
+
377
+ for idx, (x, y) in enumerate(tqdm(loader)):
378
+ x = x.to(config.DEVICE)
379
+ with torch.no_grad():
380
+ out = model(x)
381
+
382
+ for i in range(3):
383
+ y[i] = y[i].to(config.DEVICE)
384
+ obj = y[i][..., 0] == 1 # in paper this is Iobj_i
385
+ noobj = y[i][..., 0] == 0 # in paper this is Iobj_i
386
+
387
+ correct_class += torch.sum(
388
+ torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
389
+ )
390
+ tot_class_preds += torch.sum(obj)
391
+
392
+ obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
393
+ correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
394
+ tot_obj += torch.sum(obj)
395
+ correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
396
+ tot_noobj += torch.sum(noobj)
397
+
398
+ class_acc = (correct_class / (tot_class_preds + 1e-16)) * 100
399
+ no_obj_acc = (correct_noobj / (tot_noobj + 1e-16)) * 100
400
+ obj_acc = (correct_obj / (tot_obj + 1e-16)) * 100
401
+
402
+ print(f"Class accuracy is: {class_acc:2f}%")
403
+ print(f"No obj accuracy is: {no_obj_acc:2f}%")
404
+ print(f"Obj accuracy is: {obj_acc:2f}%")
405
+ model.train()
406
+ return class_acc, no_obj_acc, obj_acc
407
+
408
+ def get_mean_std(loader):
409
+ # var[X] = E[X**2] - E[X]**2
410
+ channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
411
+
412
+ for data, _ in tqdm(loader):
413
+ channels_sum += torch.mean(data, dim=[0, 2, 3])
414
+ channels_sqrd_sum += torch.mean(data**2, dim=[0, 2, 3])
415
+ num_batches += 1
416
+
417
+ mean = channels_sum / num_batches
418
+ std = (channels_sqrd_sum / num_batches - mean**2) ** 0.5
419
+
420
+ return mean, std
421
+
422
+ def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
423
+ print("=> Saving checkpoint")
424
+ checkpoint = {
425
+ "state_dict": model.state_dict(),
426
+ "optimizer": optimizer.state_dict(),
427
+ }
428
+ torch.save(checkpoint, filename)
429
+
430
+ def load_checkpoint(checkpoint_file, model, optimizer, lr):
431
+ print("=> Loading checkpoint")
432
+ checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
433
+ model.load_state_dict(checkpoint["state_dict"])
434
+ optimizer.load_state_dict(checkpoint["optimizer"])
435
+
436
+ # If we don't do this then it will just have learning rate of old checkpoint
437
+ # and it will lead to many hours of debugging \:
438
+ for param_group in optimizer.param_groups:
439
+ param_group["lr"] = lr
440
+
441
+ def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
442
+ model.eval()
443
+ x, y = next(iter(loader))
444
+ x = x.to(config.DEVICE)
445
+
446
+ with torch.no_grad():
447
+ out = model(x)
448
+ bboxes = [[] for _ in range(x.shape[0])]
449
+ for i in range(3): # should not be hard coded
450
+ batch_size, A, S, _, _ = out[i].shape
451
+ anchor = anchors[i]
452
+ boxes_scale_i = cells_to_bboxes(out[i], anchor, S=S, is_preds=True)
453
+ for idx, (box) in enumerate(boxes_scale_i):
454
+ bboxes[idx] += box
455
+
456
+ model.train() #correct indetation?
457
+
458
+ for i in range(batch_size // 4):
459
+ nms_boxes = non_max_suppression(
460
+ bboxes[i],
461
+ iou_threshold=iou_thresh,
462
+ threshold=thresh,
463
+ box_format="midpoint",
464
+ )
465
+ plot_image(x[i].permute(1, 2, 0).detach().cpu(), nms_boxes)
466
+
467
+ def seed_everything(seed=42):
468
+ os.environ["PYTHONHASHSEED"] = str(seed)
469
+ random.seed(seed)
470
+ np.random.seed(seed)
471
+ torch.manual_seed(seed)
472
+ torch.cuda.manual_seed(seed)
473
+ torch.cuda.manual_seed_all(seed)
474
+ torch.backends.cudnn.deterministic = True
475
+ torch.backends.cudnn.benchmark = False
476
+
477
+
478
+ def clip_coords(boxes, img_shape):
479
+ # Clip bounding xyxy bounding boxes to image shape (height, width)
480
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
481
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
482
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
483
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
484
+
485
+
486
+ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
487
+ # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
488
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
489
+ y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
490
+ y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
491
+ y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
492
+ y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
493
+ return y
494
+
495
+
496
+ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
497
+ # Convert normalized segments into pixel segments, shape (n,2)
498
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
499
+ y[..., 0] = w * x[..., 0] + padw # top left x
500
+ y[..., 1] = h * x[..., 1] + padh # top left y
501
+ return y
502
+
503
+
504
+ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
505
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
506
+ if clip:
507
+ clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
508
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
509
+ y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
510
+ y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
511
+ y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
512
+ y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
513
+ return y
514
+
515
+ def clip_boxes(boxes, shape):
516
+ # Clip boxes (xyxy) to image shape (height, width)
517
+ if isinstance(boxes, torch.Tensor): # faster individually
518
+ boxes[..., 0].clamp_(0, shape[1]) # x1
519
+ boxes[..., 1].clamp_(0, shape[0]) # y1
520
+ boxes[..., 2].clamp_(0, shape[1]) # x2
521
+ boxes[..., 3].clamp_(0, shape[0]) # y2
522
+ else: # np.array (faster grouped)
523
+ boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
524
+ boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2