DhrubaAdhikary1991 committed on
Commit
abb203b
1 Parent(s): ce67e64

push changes

Files changed (5)
  1. app.py +86 -0
  2. config.py +185 -0
  3. model.py +176 -0
  4. requirements.txt +8 -0
  5. utils.py +215 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ import albumentations as A
+ import cv2
+
+ from albumentations.pytorch import ToTensorV2
+ import numpy as np
+ import config
+ from PIL import Image
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ from model import YOLOv3
+ import gradio as gr
+ import os
+ import matplotlib.pyplot as plt
+ from utils import *
+
+
+ model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
+ model.load_state_dict(torch.load("custom_yolo_v3.pt", map_location=torch.device('cpu')), strict=False)
+ model.eval()
+
+ IMAGE_SIZE = config.IMAGE_SIZE
+ test_transforms = A.Compose(
+     [
+         A.LongestMaxSize(max_size=IMAGE_SIZE),
+         A.PadIfNeeded(
+             min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
+         ),
+         A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
+         ToTensorV2(),
+     ]
+ )
+
+ anchors = (
+     torch.tensor(config.ANCHORS)
+     * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
+ ).to(config.DEVICE)
+
+
+ def inference(input_img, thresh=0.75, iou_thresh=0.75, transparency=0.8):
+     x = test_transforms(image=input_img)['image'].unsqueeze(0).to(config.DEVICE)
+     with torch.no_grad():
+         out = model(x)
+         bboxes = [[] for _ in range(x.shape[0])]
+         for i in range(3):
+             _, num_anchors, S, _, _ = out[i].shape
+             anchor = anchors[i]
+             boxes_scale_i = cells_to_bboxes(
+                 out[i], anchor, S=S, is_preds=True
+             )
+             for idx, box in enumerate(boxes_scale_i):
+                 bboxes[idx] += box
+
+     nms_boxes = non_max_suppression(
+         bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
+     )
+     visualization = plot_image(x[0].permute(1, 2, 0).detach().cpu(), nms_boxes)
+
+     return visualization
+
+
+ title = "Object Detection using YOLOv3"
+ description = "A simple Gradio interface to show object detection on images"
+ examples = [['./images/006294.jpg'],
+             ['./images/005898.jpg'],
+             ['./images/003785.jpg'],
+             ['./images/001624.jpg'],
+             ['./images/006796.jpg'],
+             ['./images/003388.jpg'],
+             ['./images/002216.jpg'],
+             ['./images/000341.jpg'],
+             ['./images/006818.jpg'],
+             ]
+
+
+ demo = gr.Interface(
+     inference,
+     inputs=[gr.Image(label="Input Image", type='numpy'),
+             gr.Slider(0, 1, value=0.75, label="Threshold"),
+             gr.Slider(0, 1, value=0.75, label="IoU Threshold"),
+             gr.Slider(0, 1, value=0.8, label="Opacity of GradCAM")],
+     outputs=[gr.Image(label="Output").style(width=600, height=600)],
+     title=title,
+     description=description,
+     examples=examples,
+ )
+ demo.launch()
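A quick note on the scaled-anchor tensor built above: config.ANCHORS stores anchors normalized to [0, 1], and multiplying by config.S expresses each anchor's width and height in grid-cell units of its own scale. The minimal sketch below is not part of the commit; it only reuses the values defined in config.py to show the resulting shape and units.

import torch

ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]
S = [13, 26, 52]  # 416 // 32, 416 // 16, 416 // 8

# (3, 1, 1) grid sizes repeated to (3, 3, 2) so every anchor pair at a scale
# is multiplied by that scale's grid size
scaled = torch.tensor(ANCHORS) * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
print(scaled.shape)   # torch.Size([3, 3, 2])
print(scaled[0, 0])   # tensor([3.6400, 2.8600]), i.e. 0.28 * 13, 0.22 * 13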
config.py ADDED
@@ -0,0 +1,185 @@
+ import albumentations as A
+ import cv2
+ import torch
+
+ from albumentations.pytorch import ToTensorV2
+ # from utils import seed_everything
+
+ DATASET = 'PASCAL_VOC'
+ # DATASET = '/kaggle/input/pascal-voc-dataset-used-in-yolov3-video/PASCAL_VOC'
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ # seed_everything()  # If you want deterministic behavior
+ NUM_WORKERS = 2
+ BATCH_SIZE = 32
+ IMAGE_SIZE = 416
+ NUM_CLASSES = 20
+ LEARNING_RATE = 1e-3
+ WEIGHT_DECAY = 1e-4
+ NUM_EPOCHS = 100
+ CONF_THRESHOLD = 0.05
+ MAP_IOU_THRESH = 0.5
+ NMS_IOU_THRESH = 0.45
+ S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
+ PIN_MEMORY = True
+ LOAD_MODEL = False
+ SAVE_MODEL = True
+ CHECKPOINT_FILE = "checkpoint.pth.tar"
+ IMG_DIR = DATASET + "/images/"
+ LABEL_DIR = DATASET + "/labels/"
+
+ ANCHORS = [
+     [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
+     [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
+     [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
+ ]  # Note these have been rescaled to be between [0, 1]
+
+ means = [0.485, 0.456, 0.406]
+
+ scale = 1.1
+ train_transforms = A.Compose(
+     [
+         A.LongestMaxSize(max_size=int(IMAGE_SIZE * scale)),
+         A.PadIfNeeded(
+             min_height=int(IMAGE_SIZE * scale),
+             min_width=int(IMAGE_SIZE * scale),
+             border_mode=cv2.BORDER_CONSTANT,
+         ),
+         A.Rotate(limit=10, interpolation=1, border_mode=4),
+         A.RandomCrop(width=IMAGE_SIZE, height=IMAGE_SIZE),
+         A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
+         A.OneOf(
+             [
+                 A.ShiftScaleRotate(
+                     rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
+                 ),
+                 # A.Affine(shear=15, p=0.5, mode="constant"),
+             ],
+             p=1.0,
+         ),
+         A.HorizontalFlip(p=0.5),
+         A.Blur(p=0.1),
+         A.CLAHE(p=0.1),
+         A.Posterize(p=0.1),
+         A.ToGray(p=0.1),
+         A.ChannelShuffle(p=0.05),
+         A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
+         ToTensorV2(),
+     ],
+     bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
+ )
+ test_transforms = A.Compose(
+     [
+         A.LongestMaxSize(max_size=IMAGE_SIZE),
+         A.PadIfNeeded(
+             min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
+         ),
+         A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
+         ToTensorV2(),
+     ],
+     bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
+ )
+
+ PASCAL_CLASSES = [
+     "aeroplane",
+     "bicycle",
+     "bird",
+     "boat",
+     "bottle",
+     "bus",
+     "car",
+     "cat",
+     "chair",
+     "cow",
+     "diningtable",
+     "dog",
+     "horse",
+     "motorbike",
+     "person",
+     "pottedplant",
+     "sheep",
+     "sofa",
+     "train",
+     "tvmonitor",
+ ]
+
+ COCO_LABELS = ['person',
+                'bicycle',
+                'car',
+                'motorcycle',
+                'airplane',
+                'bus',
+                'train',
+                'truck',
+                'boat',
+                'traffic light',
+                'fire hydrant',
+                'stop sign',
+                'parking meter',
+                'bench',
+                'bird',
+                'cat',
+                'dog',
+                'horse',
+                'sheep',
+                'cow',
+                'elephant',
+                'bear',
+                'zebra',
+                'giraffe',
+                'backpack',
+                'umbrella',
+                'handbag',
+                'tie',
+                'suitcase',
+                'frisbee',
+                'skis',
+                'snowboard',
+                'sports ball',
+                'kite',
+                'baseball bat',
+                'baseball glove',
+                'skateboard',
+                'surfboard',
+                'tennis racket',
+                'bottle',
+                'wine glass',
+                'cup',
+                'fork',
+                'knife',
+                'spoon',
+                'bowl',
+                'banana',
+                'apple',
+                'sandwich',
+                'orange',
+                'broccoli',
+                'carrot',
+                'hot dog',
+                'pizza',
+                'donut',
+                'cake',
+                'chair',
+                'couch',
+                'potted plant',
+                'bed',
+                'dining table',
+                'toilet',
+                'tv',
+                'laptop',
+                'mouse',
+                'remote',
+                'keyboard',
+                'cell phone',
+                'microwave',
+                'oven',
+                'toaster',
+                'sink',
+                'refrigerator',
+                'book',
+                'clock',
+                'vase',
+                'scissors',
+                'teddy bear',
+                'hair drier',
+                'toothbrush',
+                ]
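As a sanity check on this configuration, the sketch below (not part of the commit) runs test_transforms on a dummy image to confirm the letterbox-and-pad pipeline yields a 3 x 416 x 416 tensor; the empty bboxes list is only needed because bbox_params is set on the Compose.

import numpy as np
import config

# random H x W x C uint8 image with a non-square aspect ratio
dummy = np.random.randint(0, 256, (375, 500, 3), dtype=np.uint8)

out = config.test_transforms(image=dummy, bboxes=[])
print(out["image"].shape)   # torch.Size([3, 416, 416])
print(out["image"].dtype)   # torch.float32, values scaled to [0, 1]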
model.py ADDED
@@ -0,0 +1,176 @@
+ """
+ Implementation of YOLOv3 architecture
+ """
+
+ import torch
+ import torch.nn as nn
+
+ """
+ Information about architecture config:
+ Tuple is structured by (filters, kernel_size, stride)
+ Every conv is a same convolution.
+ List is structured by "B" indicating a residual block followed by the number of repeats
+ "S" is for scale prediction block and computing the yolo loss
+ "U" is for upsampling the feature map and concatenating with a previous layer
+ """
+ config = [
+     (32, 3, 1),
+     (64, 3, 2),
+     ["B", 1],
+     (128, 3, 2),
+     ["B", 2],
+     (256, 3, 2),
+     ["B", 8],
+     (512, 3, 2),
+     ["B", 8],
+     (1024, 3, 2),
+     ["B", 4],  # To this point is Darknet-53
+     (512, 1, 1),
+     (1024, 3, 1),
+     "S",
+     (256, 1, 1),
+     "U",
+     (256, 1, 1),
+     (512, 3, 1),
+     "S",
+     (128, 1, 1),
+     "U",
+     (128, 1, 1),
+     (256, 3, 1),
+     "S",
+ ]
+
+
+ class CNNBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
+         super().__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.leaky = nn.LeakyReLU(0.1)
+         self.use_bn_act = bn_act
+
+     def forward(self, x):
+         if self.use_bn_act:
+             return self.leaky(self.bn(self.conv(x)))
+         else:
+             return self.conv(x)
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels, use_residual=True, num_repeats=1):
+         super().__init__()
+         self.layers = nn.ModuleList()
+         for repeat in range(num_repeats):
+             self.layers += [
+                 nn.Sequential(
+                     CNNBlock(channels, channels // 2, kernel_size=1),
+                     CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
+                 )
+             ]
+
+         self.use_residual = use_residual
+         self.num_repeats = num_repeats
+
+     def forward(self, x):
+         for layer in self.layers:
+             if self.use_residual:
+                 x = x + layer(x)
+             else:
+                 x = layer(x)
+
+         return x
+
+
+ class ScalePrediction(nn.Module):
+     def __init__(self, in_channels, num_classes):
+         super().__init__()
+         self.pred = nn.Sequential(
+             CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
+             CNNBlock(
+                 2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
+             ),
+         )
+         self.num_classes = num_classes
+
+     def forward(self, x):
+         return (
+             self.pred(x)
+             .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
+             .permute(0, 1, 3, 4, 2)
+         )
+
+
+ class YOLOv3(nn.Module):
+     def __init__(self, in_channels=3, num_classes=80):
+         super().__init__()
+         self.num_classes = num_classes
+         self.in_channels = in_channels
+         self.layers = self._create_conv_layers()
+
+     def forward(self, x):
+         outputs = []  # for each scale
+         route_connections = []
+         for layer in self.layers:
+             if isinstance(layer, ScalePrediction):
+                 outputs.append(layer(x))
+                 continue
+
+             x = layer(x)
+
+             if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
+                 route_connections.append(x)
+
+             elif isinstance(layer, nn.Upsample):
+                 x = torch.cat([x, route_connections[-1]], dim=1)
+                 route_connections.pop()
+
+         return outputs
+
+     def _create_conv_layers(self):
+         layers = nn.ModuleList()
+         in_channels = self.in_channels
+
+         for module in config:
+             if isinstance(module, tuple):
+                 out_channels, kernel_size, stride = module
+                 layers.append(
+                     CNNBlock(
+                         in_channels,
+                         out_channels,
+                         kernel_size=kernel_size,
+                         stride=stride,
+                         padding=1 if kernel_size == 3 else 0,
+                     )
+                 )
+                 in_channels = out_channels
+
+             elif isinstance(module, list):
+                 num_repeats = module[1]
+                 layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))
+
+             elif isinstance(module, str):
+                 if module == "S":
+                     layers += [
+                         ResidualBlock(in_channels, use_residual=False, num_repeats=1),
+                         CNNBlock(in_channels, in_channels // 2, kernel_size=1),
+                         ScalePrediction(in_channels // 2, num_classes=self.num_classes),
+                     ]
+                     in_channels = in_channels // 2
+
+                 elif module == "U":
+                     layers.append(nn.Upsample(scale_factor=2))
+                     in_channels = in_channels * 3
+
+         return layers
+
+
+ if __name__ == "__main__":
+     num_classes = 20
+     IMAGE_SIZE = 416
+     model = YOLOv3(num_classes=num_classes)
+     x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
+     out = model(x)
+     assert model(x)[0].shape == (2, 3, IMAGE_SIZE // 32, IMAGE_SIZE // 32, num_classes + 5)
+     assert model(x)[1].shape == (2, 3, IMAGE_SIZE // 16, IMAGE_SIZE // 16, num_classes + 5)
+     assert model(x)[2].shape == (2, 3, IMAGE_SIZE // 8, IMAGE_SIZE // 8, num_classes + 5)
+     print("Success!")
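Beyond the shape asserts in the __main__ block, a small sketch like the one below (not part of the commit) can confirm that the three "S" entries in the architecture config produce exactly three ScalePrediction heads and report the parameter count of the network.

import torch
from model import YOLOv3, ScalePrediction

model = YOLOv3(num_classes=20)
heads = [m for m in model.layers if isinstance(m, ScalePrediction)]
print(len(heads))  # 3 -- one detection head per scale (13x13, 26x26, 52x52 for a 416 input)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")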
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ git+https://github.com/albu/albumentations
+ torch
+ torchvision
+ torch-lr-finder
+ pytorch-lightning
+ grad-cam
+ pillow
+ numpy
utils.py ADDED
@@ -0,0 +1,215 @@
+ import config
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ import numpy as np
+ import os
+ import random
+ import torch
+ from PIL import Image
+ from io import BytesIO
+
+ from collections import Counter
+
+
+ def iou_width_height(boxes1, boxes2):
+     """
+     Parameters:
+         boxes1 (tensor): width and height of the first bounding boxes
+         boxes2 (tensor): width and height of the second bounding boxes
+     Returns:
+         tensor: Intersection over union of the corresponding boxes
+     """
+     intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
+         boxes1[..., 1], boxes2[..., 1]
+     )
+     union = (
+         boxes1[..., 0] * boxes1[..., 1] + boxes2[..., 0] * boxes2[..., 1] - intersection
+     )
+     return intersection / union
+
+
+ def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
+     """
+     Video explanation of this function:
+     https://youtu.be/XXYG5ZWtjj0
+
+     This function calculates intersection over union (IoU) given pred boxes
+     and target boxes.
+
+     Parameters:
+         boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
+         boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
+         box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
+
+     Returns:
+         tensor: Intersection over union for all examples
+     """
+
+     if box_format == "midpoint":
+         box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
+         box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
+         box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
+         box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
+         box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
+         box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
+         box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
+         box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
+
+     if box_format == "corners":
+         box1_x1 = boxes_preds[..., 0:1]
+         box1_y1 = boxes_preds[..., 1:2]
+         box1_x2 = boxes_preds[..., 2:3]
+         box1_y2 = boxes_preds[..., 3:4]
+         box2_x1 = boxes_labels[..., 0:1]
+         box2_y1 = boxes_labels[..., 1:2]
+         box2_x2 = boxes_labels[..., 2:3]
+         box2_y2 = boxes_labels[..., 3:4]
+
+     x1 = torch.max(box1_x1, box2_x1)
+     y1 = torch.max(box1_y1, box2_y1)
+     x2 = torch.min(box1_x2, box2_x2)
+     y2 = torch.min(box1_y2, box2_y2)
+
+     intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
+     box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
+     box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
+
+     return intersection / (box1_area + box2_area - intersection + 1e-6)
+
+
+ def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
+     """
+     Video explanation of this function:
+     https://youtu.be/YDkjWEN8jNA
+
+     Does Non Max Suppression given bboxes
+
+     Parameters:
+         bboxes (list): list of lists containing all bboxes, with each bbox
+             specified as [class_pred, prob_score, x1, y1, x2, y2]
+         iou_threshold (float): IoU at or above which overlapping boxes of the same class are suppressed
+         threshold (float): confidence threshold to remove predicted bboxes (independent of IoU)
+         box_format (str): "midpoint" or "corners" used to specify bboxes
+
+     Returns:
+         list: bboxes after performing NMS given a specific IoU threshold
+     """
+
+     assert type(bboxes) == list
+
+     bboxes = [box for box in bboxes if box[1] > threshold]
+     bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
+     bboxes_after_nms = []
+
+     while bboxes:
+         chosen_box = bboxes.pop(0)
+
+         bboxes = [
+             box
+             for box in bboxes
+             if box[0] != chosen_box[0]
+             or intersection_over_union(
+                 torch.tensor(chosen_box[2:]),
+                 torch.tensor(box[2:]),
+                 box_format=box_format,
+             )
+             < iou_threshold
+         ]
+
+         bboxes_after_nms.append(chosen_box)
+
+     return bboxes_after_nms
+
+
+ def plot_image(image, boxes):
+     """Plots predicted bounding boxes on the image"""
+     cmap = plt.get_cmap("tab20b")
+     class_labels = config.COCO_LABELS if config.DATASET == 'COCO' else config.PASCAL_CLASSES
+     colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
+     im = np.array(image)
+     height, width, _ = im.shape
+
+     # Create figure and axes
+     fig, ax = plt.subplots(1)
+     # Display the image
+     ax.imshow(im)
+
+     # box[0] is x midpoint, box[2] is width
+     # box[1] is y midpoint, box[3] is height
+
+     # Create a Rectangle patch for each box
+     for box in boxes:
+         assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
+         class_pred = box[0]
+         box = box[2:]
+         upper_left_x = box[0] - box[2] / 2
+         upper_left_y = box[1] - box[3] / 2
+         rect = patches.Rectangle(
+             (upper_left_x * width, upper_left_y * height),
+             box[2] * width,
+             box[3] * height,
+             linewidth=2,
+             edgecolor=colors[int(class_pred)],
+             facecolor="none",
+         )
+         # Add the patch to the Axes
+         ax.add_patch(rect)
+         plt.text(
+             upper_left_x * width,
+             upper_left_y * height,
+             s=class_labels[int(class_pred)],
+             color="white",
+             verticalalignment="top",
+             bbox={"color": colors[int(class_pred)], "pad": 0},
+         )
+
+     buffer = BytesIO()
+     plt.axis('off')
+     plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
+     plt.close(fig)
+     visualization = Image.open(buffer)
+
+     return visualization
+
+
+ def cells_to_bboxes(predictions, anchors, S, is_preds=True):
+     """
+     Scales the predictions coming from the model to be relative to the
+     entire image so that they can, for example, later be plotted or evaluated.
+     INPUT:
+         predictions: tensor of size (N, 3, S, S, num_classes+5)
+         anchors: the anchors used for the predictions
+         S: the number of cells the image is divided in on the width (and height)
+         is_preds: whether the input is predictions or the true bounding boxes
+     OUTPUT:
+         converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index,
+             object score, bounding box coordinates
+     """
+     BATCH_SIZE = predictions.shape[0]
+     num_anchors = len(anchors)
+     box_predictions = predictions[..., 1:5]
+     if is_preds:
+         anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
+         box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
+         box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
+         scores = torch.sigmoid(predictions[..., 0:1])
+         best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
+     else:
+         scores = predictions[..., 0:1]
+         best_class = predictions[..., 5:6]
+
+     cell_indices = (
+         torch.arange(S)
+         .repeat(predictions.shape[0], 3, S, 1)
+         .unsqueeze(-1)
+         .to(predictions.device)
+     )
+     x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
+     y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
+     w_h = 1 / S * box_predictions[..., 2:4]
+     converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(BATCH_SIZE, num_anchors * S * S, 6)
+     return converted_bboxes.tolist()
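To illustrate the box format non_max_suppression expects from cells_to_bboxes ([class, objectness, x, y, w, h] in midpoint form), here is a minimal sketch (not part of the commit) on three hand-made boxes; the lower-scoring duplicate of class 0 is suppressed while the class 1 box survives.

from utils import non_max_suppression

boxes = [
    [0, 0.9, 0.50, 0.5, 0.4, 0.4],  # class 0, high confidence
    [0, 0.6, 0.52, 0.5, 0.4, 0.4],  # class 0, heavy overlap with the box above -> suppressed
    [1, 0.8, 0.20, 0.8, 0.2, 0.2],  # class 1, kept regardless of overlap with class 0
]
kept = non_max_suppression(boxes, iou_threshold=0.5, threshold=0.5, box_format="midpoint")
print(len(kept))             # 2
print([b[0] for b in kept])  # [0, 1]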