File size: 7,428 Bytes
b683920
cf41825
b683920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99d3d67
cf41825
 
b683920
 
 
 
 
 
99d3d67
b683920
 
 
 
 
 
 
99d3d67
b683920
 
99d3d67
b683920
99d3d67
b683920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf41825
 
b683920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf41825
b683920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99d3d67
b683920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40a6ccf
b683920
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import logging
from typing import List

import albumentations as A
import streamlit as st
import torch
from albumentations import pytorch

from src.model_architecture import Net

# Anchor box sizes as (width, height) fractions of the image size: three anchors per
# prediction scale, one row per model output. `predict` multiplies row i by the grid
# size of model output i before decoding boxes.
# NOTE(review): row order is assumed to match the order of the model's outputs; confirm against Net.
anchors = torch.tensor(
    [
        [[0.2800, 0.2200], [0.3800, 0.4800], [0.9000, 0.7800]],
        [[0.0700, 0.1500], [0.1500, 0.1100], [0.1400, 0.2900]],
        [[0.0200, 0.0300], [0.0400, 0.0700], [0.0800, 0.0600]],
    ]
)

# Preprocessing for input images: resize to 192x192 (interpolation=1 is cv2.INTER_LINEAR),
# normalize with A.Normalize's default mean/std, and convert the image to a torch tensor.
transforms = A.Compose(
    [
        # `always_apply` is deprecated in recent albumentations releases; `p=1` alone
        # already guarantees the resize is applied to every image.
        A.Resize(height=192, width=192, interpolation=1, p=1),
        A.Normalize(),
        pytorch.transforms.ToTensorV2(),
    ]
)


def cells_to_bboxes(
    predictions: torch.Tensor, tensor_anchors: torch.Tensor, s: int, is_preds: bool = True
) -> List[List]:
    """
    Convert per-cell predictions into boxes expressed relative to the whole image.

    The resulting boxes can later be plotted or fed to non-max suppression.

    Args:
        predictions: tensor of shape (N, num_anchors, S, S, 5 + num_classes) laid out as
            [objectness, x, y, w, h, class logits...] when ``is_preds`` is True, or
            [objectness, x, y, w, h, class_index] for already-decoded targets.
        tensor_anchors: tensor of shape (num_anchors, 2) holding anchor (width, height)
            in grid-cell units; only used when ``is_preds`` is True.
        s: number of grid cells along the image width (and height).
        is_preds: whether ``predictions`` holds raw model outputs (sigmoid/exp are applied
            and the class is the argmax of the class logits) or target boxes.

    Returns:
        A nested list of shape (N, num_anchors * S * S, 6); each inner list is
        [class_index, score, x, y, w, h], with x, y, w, h divided by ``s`` so they are
        relative to the image size. The input ``predictions`` tensor is not modified.
    """
    batch_size = predictions.shape[0]
    num_anchors = len(tensor_anchors)
    # Clone: a plain slice is a view into `predictions`, so the in-place sigmoid/exp
    # below would otherwise overwrite the caller's tensor.
    box_predictions = predictions[..., 1:5].clone()
    if is_preds:
        anchors_b = tensor_anchors.to(predictions.device).reshape(1, num_anchors, 1, 1, 2)
        # x, y offsets inside the cell are squashed into (0, 1).
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        # w, h are predicted as log-scale factors of the anchor size.
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors_b
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    # Column index of every cell, shaped (N, num_anchors, S, S, 1); swapping the two grid
    # axes gives the row index used for y.
    cell_indices = (
        torch.arange(s, device=predictions.device)
        .repeat(batch_size, num_anchors, s, 1)
        .unsqueeze(-1)
    )
    x = 1 / s * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / s * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / s * box_predictions[..., 2:4]
    converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(
        batch_size, num_anchors * s * s, 6
    )
    return converted_bboxes.tolist()


def non_max_suppression(
    bboxes: List[List], iou_threshold: float, threshold: float, box_format: str = 'corners'
) -> List[List]:
    """
    Greedy per-class non-max suppression.

    Boxes whose score is not strictly above ``threshold`` are discarded first. The rest
    are visited from highest to lowest score: each visited box is kept, and every
    remaining box of the same class whose IoU with it is >= ``iou_threshold`` is dropped.
    Boxes of different classes never suppress each other.

    Reference walkthrough: https://youtu.be/YDkjWEN8jNA

    Args:
        bboxes: boxes given as [class_pred, prob_score, c1, c2, c3, c4], with the four
            coordinates interpreted according to ``box_format``.
        iou_threshold: a same-class box overlapping a kept box by at least this IoU
            is suppressed.
        threshold: minimum score (exclusive) for a box to be considered at all.
        box_format: 'midpoint' for (x, y, w, h) or 'corners' for (x1, y1, x2, y2).

    Returns:
        The surviving boxes, ordered by descending score. The input list is not modified.
    """
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    kept: List[List] = []

    while candidates:
        best, *rest = candidates
        kept.append(best)
        best_coords = torch.tensor(best[2:])

        survivors = []
        for other in rest:
            if not other[0] == best[0]:
                # a different class can never be suppressed by this box
                survivors.append(other)
                continue
            overlap = intersection_over_union(
                best_coords,
                torch.tensor(other[2:]),
                box_format=box_format,
            )
            if overlap < iou_threshold:
                survivors.append(other)
        candidates = survivors

    return kept


def intersection_over_union(
    boxes_preds: torch.Tensor, boxes_labels: torch.Tensor, box_format: str = 'midpoint'
) -> torch.Tensor:
    """
    Compute the intersection over union (IoU) of pairs of boxes, element-wise.

    Video explanation of this function: https://youtu.be/XXYG5ZWtjj0

    Args:
        boxes_preds: predicted boxes, shape (..., 4).
        boxes_labels: target boxes, shape (..., 4), broadcastable against ``boxes_preds``.
        box_format: 'midpoint' if boxes are (x, y, w, h) with (x, y) the box centre,
            or 'corners' if boxes are (x1, y1, x2, y2).

    Returns:
        Tensor of shape (..., 1) with the IoU of each pair. A 1e-6 term in the
        denominator avoids division by zero for degenerate boxes.

    Raises:
        ValueError: if ``box_format`` is neither 'midpoint' nor 'corners'.
    """
    if box_format == 'midpoint':
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == 'corners':
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        # Without this, an unknown format fell through and crashed with UnboundLocalError.
        raise ValueError(f"box_format must be 'midpoint' or 'corners', got {box_format!r}")

    # Corners of the intersection rectangle.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) yields zero intersection when the boxes do not overlap.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = ((box1_x2 - box1_x1) * (box1_y2 - box1_y1)).abs()
    box2_area = ((box2_x2 - box2_x1) * (box2_y2 - box2_y1)).abs()

    return intersection / (box1_area + box2_area - intersection + 1e-6)


def predict(
    model: torch.nn.Module, image: torch.Tensor, iou_threshold: float = 1.0, threshold: float = 0.05
) -> List[List]:
    """
    Run the model on a single image and post-process its outputs into final boxes.

    Args:
        model: a trained PyTorch model; its output is indexed per prediction scale
            (the first three outputs are used, one per row of the module-level anchors).
        image: a single preprocessed image tensor without a batch dimension.
        iou_threshold: IoU at or above which a same-class box is suppressed by NMS.
            NOTE(review): the default of 1.0 effectively disables suppression, because
            ``intersection_over_union`` adds 1e-6 to its denominator; confirm this is intended.
        threshold: minimum (exclusive) objectness score for a box to be kept.

    Returns:
        Boxes after NMS, each as [class_index, score, x, y, w, h] (midpoint format,
        coordinates relative to the image size).
    """
    # Inference only: skip building the autograd graph (saves memory and time; outputs are unchanged).
    with torch.no_grad():
        # Add a leading dimension to imitate a batch of size 1.
        logits = model(image[None, :])
    logging.info('predicted')

    # One list of boxes per image in the batch. There is a single image here, but keeping the
    # per-image structure makes it easier to extend this code to batch sizes > 1.
    bboxes: List[List] = [[] for _ in range(1)]
    for i in range(3):
        scale_logits = logits[i]
        grid_size = scale_logits.shape[2]
        # Anchors are stored relative to the image; convert them to grid-cell units for this scale.
        scale_anchors = anchors[i].to(scale_logits.device) * grid_size
        boxes_scale_i = cells_to_bboxes(scale_logits, scale_anchors, s=grid_size, is_preds=True)
        for idx, boxes in enumerate(boxes_scale_i):
            bboxes[idx] += boxes
    logging.info('Starting nms')
    # Only one image was passed in, so NMS runs on its (first and only) box list.
    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_threshold,
        threshold=threshold,
        box_format='midpoint',
    )

    return nms_boxes


@st.cache_resource
def get_model():
    """
    Load the trained network from ``model_files/best_model.pth`` and switch it to eval mode.

    Cached with ``st.cache_resource`` so one model instance is created once and reused across
    Streamlit reruns. ``st.cache_data`` would pickle the model and return a freshly
    deserialized copy on every call, which is wasteful for a large, read-only object.

    Returns:
        The ``Net`` instance with the checkpoint weights loaded, in eval mode.
    """
    model_name = 'model_files/best_model.pth'

    model = Net()
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts; the weights are
    # copied into the CPU-resident Net parameters either way.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    model.load_state_dict(torch.load(model_name, map_location='cpu'))
    model.eval()

    return model