Spaces:
Runtime error
Runtime error
import logging | |
from typing import List | |
import albumentations as A | |
import streamlit as st | |
import torch | |
from albumentations import pytorch | |
from src.model_architecture import Net | |
# Anchor sizes as (width, height) pairs, one row of three anchors per prediction scale.
# predict() pairs anchors[i] with the i-th model output and multiplies it by that output's
# grid size S before cells_to_bboxes scales the predicted widths/heights with it.
anchors = torch.tensor(
    (
        ((0.28, 0.22), (0.38, 0.48), (0.90, 0.78)),
        ((0.07, 0.15), (0.15, 0.11), (0.14, 0.29)),
        ((0.02, 0.03), (0.04, 0.07), (0.08, 0.06)),
    )
)
# Preprocessing applied to input images: resize to 192x192 (interpolation=1, presumably
# cv2.INTER_LINEAR), normalize with albumentations' default Normalize() settings, and
# convert the result to a torch tensor.
transforms = A.Compose(
    transforms=[
        A.Resize(height=192, width=192, interpolation=1, always_apply=False, p=1),
        A.Normalize(),
        pytorch.transforms.ToTensorV2(),
    ]
)
def cells_to_bboxes(
    predictions: torch.Tensor, tensor_anchors: torch.Tensor, s: int, is_preds: bool = True
) -> List[List]:
    """
    Convert per-cell predictions (or targets) into boxes relative to the entire image.

    The x/y offsets inside each grid cell are shifted by the cell's column/row index and
    divided by ``s``, so the resulting coordinates are fractions of the image size and can,
    for example, be fed to non-max suppression or plotted.

    Args:
        predictions: tensor of size (N, num_anchors, S, S, num_classes + 5). Channel 0 is the
            objectness value, channels 1:5 the box. For predictions (``is_preds=True``), channels
            5: hold class logits; for targets, channel 5 holds the class index.
        tensor_anchors: anchor (width, height) pairs of shape (num_anchors, 2). They are only used
            when ``is_preds`` is True, and should be in grid-cell units because widths/heights are
            divided by ``s`` afterwards.
        s: the number of cells the image is divided in on the width (and height).
        is_preds: whether the input is raw model predictions (True) or target boxes (False).

    Returns:
        A nested Python list of shape (N, num_anchors * S * S, 6). Each box is
        [class_index, object_score, x, y, w, h]. The input tensor is not modified.
    """
    batch_size = predictions.shape[0]
    num_anchors = len(tensor_anchors)
    # Clone: a plain slice is a view, and the in-place activations below would otherwise
    # write through it into the caller's ``predictions`` tensor.
    box_predictions = predictions[..., 1:5].clone()
    if is_preds:
        tensor_anchors = tensor_anchors.reshape(1, num_anchors, 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * tensor_anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]
    # cell_indices[n, a, row, col, 0] == col. Swapping the two grid axes gives the row index.
    # The anchor axis follows num_anchors instead of a hard-coded 3.
    cell_indices = (
        torch.arange(s, device=predictions.device)
        .repeat(batch_size, num_anchors, s, 1)
        .unsqueeze(-1)
    )
    x = 1 / s * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / s * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / s * box_predictions[..., 2:4]
    converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(
        batch_size, num_anchors * s * s, 6
    )
    return converted_bboxes.tolist()
def non_max_suppression(
    bboxes: List[List], iou_threshold: float, threshold: float, box_format: str = 'corners'
) -> List[List]:
    """
    Apply greedy, per-class non-max suppression to ``bboxes``.

    Video explanation of this function:
    https://youtu.be/YDkjWEN8jNA

    Boxes whose score is not above ``threshold`` are dropped. The rest are visited in
    descending score order. Each visited box is kept and removes every remaining box of the
    same class whose IoU with it is not below ``iou_threshold``.

    Args:
        bboxes (list): list of boxes, each specified as [class_pred, prob_score, c1, c2, c3, c4].
            The four coordinates follow ``box_format``.
        iou_threshold (float): same-class boxes whose IoU with a kept box is >= this value are
            suppressed.
        threshold (float): boxes with prob_score <= this value are discarded (independent of IoU).
        box_format (str): 'midpoint' for (x, y, w, h) or 'corners' for (x1, y1, x2, y2).

    Returns:
        list: the kept boxes (the original list objects), in descending score order.
    """
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold), key=lambda box: box[1], reverse=True
    )
    bboxes_after_nms = []
    while candidates:
        chosen_box = candidates.pop(0)
        bboxes_after_nms.append(chosen_box)
        if not candidates:
            break
        # One broadcast IoU call against all remaining boxes, instead of building two tensors
        # per pair. The comparison stays in tensor dtype, as in the per-pair version.
        ious = intersection_over_union(
            torch.tensor(chosen_box[2:]),
            torch.tensor([box[2:] for box in candidates]),
            box_format=box_format,
        ).reshape(-1)
        below_iou_threshold = (ious < iou_threshold).tolist()
        candidates = [
            box
            for box, low_overlap in zip(candidates, below_iou_threshold)
            if box[0] != chosen_box[0] or low_overlap
        ]
    return bboxes_after_nms


def intersection_over_union(
    boxes_preds: torch.Tensor, boxes_labels: torch.Tensor, box_format: str = 'midpoint'
) -> torch.Tensor:
    """
    Calculate intersection over union (IoU) between predicted and target boxes.

    Video explanation of this function:
    https://youtu.be/XXYG5ZWtjj0

    Args:
        boxes_preds (tensor): predicted boxes with shape (..., 4).
        boxes_labels (tensor): target boxes with shape (..., 4). Leading dimensions broadcast
            against ``boxes_preds``, so one box can be compared with many.
        box_format (str): 'midpoint' if boxes are (x, y, w, h), 'corners' if (x1, y1, x2, y2).

    Returns:
        tensor: IoU with shape (..., 1) (the broadcast leading shape).

    Raises:
        ValueError: if ``box_format`` is neither 'midpoint' nor 'corners'.
    """
    if box_format == 'midpoint':
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == 'corners':
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        # Previously an unknown format fell through and crashed with UnboundLocalError.
        raise ValueError(f"box_format must be 'midpoint' or 'corners', got {box_format!r}")
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)
    # clamp(0) makes non-overlapping boxes contribute zero intersection.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
    # The small epsilon avoids division by zero for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)
def predict(
    model: torch.nn.Module, image: torch.Tensor, iou_threshold: float = 1.0, threshold: float = 0.05
) -> List[List]:
    """
    Run the model on a single image and post-process its outputs into boxes.

    Args:
        model: a trained PyTorch model. Its output is indexed as three per-scale tensors.
            Each is presumably shaped (1, num_anchors, S, S, num_classes + 5); TODO confirm
            against Net.
        image: a single image tensor without a batch dimension. A batch axis of size 1 is
            added here.
        iou_threshold: passed to non_max_suppression. Same-class boxes whose IoU with a kept box
            is not below this value are removed. The default of 1.0 effectively disables
            overlap suppression.
        threshold: boxes whose object score is not above this value are discarded.

    Returns:
        The boxes kept by NMS, each as [class_index, object_score, x, y, w, h] in midpoint
        format, with coordinates relative to the whole image.
    """
    # Inference only: model.eval() does not turn off autograd. Without no_grad, every forward
    # pass would record a computation graph and keep its activations alive.
    with torch.no_grad():
        # Add a dimension to imitate a batch size of 1.
        logits = model(image[None, :])
    logging.info('predicted')
    # Postprocess. Indexing with idx could be removed since there is a single image, but it is
    # kept so this code is easier to adapt to batch sizes > 1.
    bboxes: List[List] = [[] for _ in range(1)]
    idx = 0
    for i in range(3):
        S = logits[i].shape[2]
        # Scale the module-level anchors (fractions of the image) to grid-cell units for this
        # scale. They are not rebuilt per call.
        anchor = anchors[i] * S
        boxes_scale_i = cells_to_bboxes(logits[i], anchor, s=S, is_preds=True)
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box
    logging.info('Starting nms')
    nms_boxes = non_max_suppression(
        bboxes[idx],
        iou_threshold=iou_threshold,
        threshold=threshold,
        box_format='midpoint',
    )
    return nms_boxes
def get_model(model_name: str = 'model_files/best_model.pth') -> torch.nn.Module:
    """
    Build a ``Net``, load trained weights into it, and switch it to eval mode.

    Args:
        model_name: path to the saved state dict. It defaults to the bundled checkpoint.

    Returns:
        The ``Net`` instance with the loaded weights, in eval mode.
    """
    model = Net()
    # map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a CPU-only
    # host. Without it, torch.load tries to restore the tensors onto their original device.
    state_dict = torch.load(model_name, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model