Spaces:
Runtime error
Runtime error
| import logging | |
| from typing import List | |
| import albumentations as A | |
| import streamlit as st | |
| import torch | |
| from albumentations import pytorch | |
| from src.model_architecture import Net | |
| anchors = torch.tensor( | |
| [ | |
| [[0.2800, 0.2200], [0.3800, 0.4800], [0.9000, 0.7800]], | |
| [[0.0700, 0.1500], [0.1500, 0.1100], [0.1400, 0.2900]], | |
| [[0.0200, 0.0300], [0.0400, 0.0700], [0.0800, 0.0600]], | |
| ] | |
| ) | |
| transforms = A.Compose( | |
| [ | |
| A.Resize(always_apply=False, p=1, height=192, width=192, interpolation=1), | |
| A.Normalize(), | |
| pytorch.transforms.ToTensorV2(), | |
| ] | |
| ) | |
| def cells_to_bboxes( | |
| predictions: torch.Tensor, tensor_anchors: torch.Tensor, s: int, is_preds: bool = True | |
| ) -> List[List]: | |
| """ | |
| Scale the predictions coming from the model_files to | |
| be relative to the entire image such that they for example later | |
| can be plotted or. | |
| Args: | |
| predictions: tensor of size (N, 3, S, S, num_classes+5) | |
| tensor_anchors: the anchors used for the predictions | |
| s: the number of cells the image is divided in on the width (and height) | |
| is_preds: whether the input is predictions or the true bounding boxes | |
| Returns: | |
| converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index, | |
| object score, bounding box coordinates | |
| """ | |
| batch_size = predictions.shape[0] | |
| num_anchors = len(tensor_anchors) | |
| box_predictions = predictions[..., 1:5] | |
| if is_preds: | |
| tensor_anchors = tensor_anchors.reshape(1, len(tensor_anchors), 1, 1, 2) | |
| box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2]) | |
| box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * tensor_anchors | |
| scores = torch.sigmoid(predictions[..., 0:1]) | |
| best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1) | |
| else: | |
| scores = predictions[..., 0:1] | |
| best_class = predictions[..., 5:6] | |
| cell_indices = torch.arange(s).repeat(predictions.shape[0], 3, s, 1).unsqueeze(-1).to(predictions.device) | |
| x = 1 / s * (box_predictions[..., 0:1] + cell_indices) | |
| y = 1 / s * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4)) | |
| w_h = 1 / s * box_predictions[..., 2:4] | |
| converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(batch_size, num_anchors * s * s, 6) | |
| return converted_bboxes.tolist() | |
| def non_max_suppression( | |
| bboxes: List[List], iou_threshold: float, threshold: float, box_format: str = 'corners' | |
| ) -> List[List]: | |
| """ | |
| Apply nms to the bboxes. | |
| Video explanation of this function: | |
| https://youtu.be/YDkjWEN8jNA | |
| Does Non Max Suppression given bboxes | |
| Args: | |
| bboxes (list): list of lists containing all bboxes with each bboxes | |
| specified as [class_pred, prob_score, x1, y1, x2, y2] | |
| iou_threshold (float): threshold where predicted bboxes is correct | |
| threshold (float): threshold to remove predicted bboxes (independent of IoU) | |
| box_format (str): 'midpoint' or 'corners' used to specify bboxes | |
| Returns: | |
| list: bboxes after performing NMS given a specific IoU threshold | |
| """ | |
| bboxes = [box for box in bboxes if box[1] > threshold] | |
| bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True) | |
| bboxes_after_nms = [] | |
| while bboxes: | |
| chosen_box = bboxes.pop(0) | |
| bboxes = [ | |
| box | |
| for box in bboxes | |
| if box[0] != chosen_box[0] | |
| or intersection_over_union( | |
| torch.tensor(chosen_box[2:]), | |
| torch.tensor(box[2:]), | |
| box_format=box_format, | |
| ) | |
| < iou_threshold | |
| ] | |
| bboxes_after_nms.append(chosen_box) | |
| return bboxes_after_nms | |
| def intersection_over_union( | |
| boxes_preds: torch.Tensor, boxes_labels: torch.Tensor, box_format: str = 'midpoint' | |
| ) -> torch.Tensor: | |
| """ | |
| Calculate iou. | |
| Video explanation of this function: | |
| https://youtu.be/XXYG5ZWtjj0 | |
| This function calculates intersection over union (iou) given pred boxes | |
| and target boxes. | |
| Args: | |
| boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4) | |
| boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4) | |
| box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2) | |
| Returns: | |
| tensor: Intersection over union for all examples | |
| """ | |
| if box_format == 'midpoint': | |
| box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2 | |
| box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2 | |
| box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2 | |
| box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2 | |
| box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2 | |
| box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2 | |
| box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2 | |
| box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2 | |
| if box_format == 'corners': | |
| box1_x1 = boxes_preds[..., 0:1] | |
| box1_y1 = boxes_preds[..., 1:2] | |
| box1_x2 = boxes_preds[..., 2:3] | |
| box1_y2 = boxes_preds[..., 3:4] | |
| box2_x1 = boxes_labels[..., 0:1] | |
| box2_y1 = boxes_labels[..., 1:2] | |
| box2_x2 = boxes_labels[..., 2:3] | |
| box2_y2 = boxes_labels[..., 3:4] | |
| x1 = torch.max(box1_x1, box2_x1) | |
| y1 = torch.max(box1_y1, box2_y1) | |
| x2 = torch.min(box1_x2, box2_x2) | |
| y2 = torch.min(box1_y2, box2_y2) | |
| intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) | |
| box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1)) | |
| box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1)) | |
| return intersection / (box1_area + box2_area - intersection + 1e-6) | |
| def predict( | |
| model: torch.nn.Module, image: torch.Tensor, iou_threshold: float = 1.0, threshold: float = 0.05 | |
| ) -> List[List]: | |
| """ | |
| Apply the model_files to the predictions and to postprocessing | |
| Args: | |
| model: a trained pytorch model_files. | |
| image: image as a torch tensor | |
| iou_threshold: a threshold for intersection_over_union function | |
| threshold: a threshold for bbox probability | |
| Returns: | |
| predicted bboxes | |
| """ | |
| # apply model_files. add a dimension to imitate a batch size of 1 | |
| logits = model(image[None, :]) | |
| logging.info('predicted') | |
| # postprocess. In fact, we could remove indexing with idx here, as there is a single image. | |
| # But I prefer to keep it so that this code could be easier changed for cases with batch size > 1 | |
| bboxes: List[List] = [[] for _ in range(1)] | |
| idx = 0 | |
| for i in range(3): | |
| S = logits[i].shape[2] | |
| # it could be better to initialize anchors inside the function, but I don't want to do it for every prediction. | |
| anchor = anchors[i] * S | |
| boxes_scale_i = cells_to_bboxes(logits[i], anchor, s=S, is_preds=True) | |
| for idx, (box) in enumerate(boxes_scale_i): | |
| bboxes[idx] += box | |
| logging.info('Starting nms') | |
| nms_boxes = non_max_suppression( | |
| bboxes[idx], | |
| iou_threshold=iou_threshold, | |
| threshold=threshold, | |
| box_format='midpoint', | |
| ) | |
| return nms_boxes | |
| def get_model(): | |
| model_name = 'model_files/best_model.pth' | |
| model = Net() | |
| model.load_state_dict(torch.load(model_name)) | |
| model.eval() | |
| return model | |