digit-draw-detect / src /ml_utils.py
Andrey
Erlemar patch 1 (#18)
40a6ccf unverified
raw
history blame
7.43 kB
import logging
from typing import List
import albumentations as A
import streamlit as st
import torch
from albumentations import pytorch
from src.model_architecture import Net
# Anchor (width, height) pairs, three anchors per prediction scale (one row per scale).
# Values are < 1, so they appear to be fractions of the image size; predict() multiplies
# anchors[i] by the grid size of the i-th model output -- TODO confirm against training config.
anchors = torch.tensor(
    (
        ((0.28, 0.22), (0.38, 0.48), (0.90, 0.78)),
        ((0.07, 0.15), (0.15, 0.11), (0.14, 0.29)),
        ((0.02, 0.03), (0.04, 0.07), (0.08, 0.06)),
    )
)
# Inference-time preprocessing pipeline, applied in order:
#   1. resize every image to 192x192 (interpolation=1 is cv2.INTER_LINEAR), always applied (p=1);
#   2. normalize with albumentations' default mean/std;
#   3. convert the result to a torch tensor.
transforms = A.Compose(
    [
        A.Resize(height=192, width=192, interpolation=1, always_apply=False, p=1),
        A.Normalize(),
        pytorch.transforms.ToTensorV2(),
    ]
)
def cells_to_bboxes(
    predictions: torch.Tensor, tensor_anchors: torch.Tensor, s: int, is_preds: bool = True
) -> List[List]:
    """
    Convert per-grid-cell predictions (or targets) into image-relative boxes.

    Box coordinates are rescaled from "offset inside a grid cell" to fractions of the
    whole image so that they can, for example, be passed to NMS or plotted later.

    Args:
        predictions: tensor of shape (N, num_anchors, S, S, C). Channel 0 is the objectness
            value and channels 1:5 are the box (x, y, w, h). For model outputs
            (is_preds=True) channels 5: are per-class scores; for targets (is_preds=False)
            channel 5 holds the class index.
        tensor_anchors: (num_anchors, 2) anchor (width, height) pairs. They are presumably in
            grid-cell units (predict() passes anchors multiplied by S). Only used when
            is_preds is True.
        s: the number of cells the image is divided into along the width (and height).
        is_preds: True for raw model outputs (sigmoid/exp decoding is applied), False for
            already-decoded ground-truth boxes.

    Returns:
        A nested list of shape (N, num_anchors * S * S, 6). Each inner list is
        [class_index, objectness_score, x_center, y_center, width, height], with the
        coordinates relative to the full image.

    The input tensor is not modified: the box channels are cloned before decoding.
    """
    batch_size, num_anchors = predictions.shape[0], predictions.shape[1]
    # Clone: slicing returns a view, and decoding in place would overwrite the caller's tensor.
    box_predictions = predictions[..., 1:5].clone()
    if is_preds:
        # Broadcast the (num_anchors, 2) anchors over batch and grid dims, on the same device.
        anchors_b = tensor_anchors.to(predictions.device).reshape(1, -1, 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors_b
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]
    # cell_indices[n, a, row, col, 0] == col; swapping the two grid axes yields the row index.
    cell_indices = (
        torch.arange(s, device=predictions.device)
        .repeat(batch_size, num_anchors, s, 1)
        .unsqueeze(-1)
    )
    x = 1 / s * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / s * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / s * box_predictions[..., 2:4]
    converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(
        batch_size, num_anchors * s * s, 6
    )
    return converted_bboxes.tolist()
def non_max_suppression(
    bboxes: List[List], iou_threshold: float, threshold: float, box_format: str = 'corners'
) -> List[List]:
    """
    Apply class-wise Non-Max Suppression to a list of boxes.

    Video explanation of this function:
    https://youtu.be/YDkjWEN8jNA

    Args:
        bboxes: list of boxes, each [class_pred, prob_score, *coords]. coords is
            (x1, y1, x2, y2) for 'corners' or (x, y, w, h) for 'midpoint'.
            The input list itself is not modified.
        iou_threshold: a box is suppressed when it has the same class as an already kept
            box and their IoU is >= this value.
        threshold: boxes with prob_score <= this value are discarded before NMS.
        box_format: 'midpoint' or 'corners'; forwarded to intersection_over_union.

    Returns:
        The kept boxes (the same inner list objects as in the input), in descending
        order of prob_score.
    """
    # sorted() is stable, so boxes with equal scores keep their input order.
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold), key=lambda box: box[1], reverse=True
    )
    kept = []
    while candidates:
        chosen_box = candidates.pop(0)
        # Build the kept box's tensor once instead of once per remaining candidate.
        chosen_coords = torch.tensor(chosen_box[2:])
        candidates = [
            box
            for box in candidates
            if box[0] != chosen_box[0]
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]
        kept.append(chosen_box)
    return kept
def _box_corners(boxes: torch.Tensor, box_format: str) -> tuple:
    """Return (x1, y1, x2, y2) slices of `boxes`, each keeping a trailing dim of size 1."""
    if box_format == 'midpoint':
        half_w = boxes[..., 2:3] / 2
        half_h = boxes[..., 3:4] / 2
        return (
            boxes[..., 0:1] - half_w,
            boxes[..., 1:2] - half_h,
            boxes[..., 0:1] + half_w,
            boxes[..., 1:2] + half_h,
        )
    if box_format == 'corners':
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]
    raise ValueError(f"box_format must be 'midpoint' or 'corners', got {box_format!r}")


def intersection_over_union(
    boxes_preds: torch.Tensor, boxes_labels: torch.Tensor, box_format: str = 'midpoint'
) -> torch.Tensor:
    """
    Calculate intersection over union (IoU) between predicted and target boxes.

    Video explanation of this function:
    https://youtu.be/XXYG5ZWtjj0

    Args:
        boxes_preds: predicted boxes, shape (..., 4), e.g. (BATCH_SIZE, 4).
        boxes_labels: target boxes, shape broadcastable with boxes_preds.
        box_format: 'midpoint' if boxes are (x, y, w, h), 'corners' if (x1, y1, x2, y2).

    Returns:
        IoU for every box pair, shape (..., 1). A 1e-6 epsilon in the denominator avoids
        division by zero for degenerate boxes.

    Raises:
        ValueError: if box_format is neither 'midpoint' nor 'corners'.
    """
    box1_x1, box1_y1, box1_x2, box1_y2 = _box_corners(boxes_preds, box_format)
    box2_x1, box2_y1, box2_x2, box2_y2 = _box_corners(boxes_labels, box_format)
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)
    # clamp(0): non-overlapping boxes give a negative extent, which must count as zero
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
    return intersection / (box1_area + box2_area - intersection + 1e-6)
def predict(
    model: torch.nn.Module, image: torch.Tensor, iou_threshold: float = 1.0, threshold: float = 0.05
) -> List[List]:
    """
    Run the model on a single image and post-process its output into boxes.

    Args:
        model: a trained PyTorch model. Its output is indexed per prediction scale
            (logits[i]), with the grid size at logits[i].shape[2].
        image: a single image tensor without a batch dimension (one is added here).
        iou_threshold: IoU threshold passed to non_max_suppression. With the default 1.0,
            only same-class boxes with IoU >= 1.0 against a kept box are suppressed.
        threshold: minimum objectness score (exclusive) for a box to be kept.

    Returns:
        Boxes after NMS, each [class_index, score, x, y, w, h], with coordinates relative
        to the whole image (midpoint format).
    """
    # Inference only: no autograd graph is needed for post-processing.
    with torch.no_grad():
        # apply model_files. add a dimension to imitate a batch size of 1
        logits = model(image[None, :])
    logging.info('predicted')
    # Per-image box lists. There is a single image here, but keeping the list makes the
    # code easier to extend to batch size > 1.
    bboxes: List[List] = [[]]
    for scale_idx, scale_anchors in enumerate(anchors):
        scale_logits = logits[scale_idx]
        grid_size = scale_logits.shape[2]
        # Anchors look image-relative; multiplying by S expresses them in grid-cell units.
        # Kept at module level so they are not rebuilt for every prediction.
        scaled_anchors = scale_anchors.to(scale_logits.device) * grid_size
        boxes_scale = cells_to_bboxes(scale_logits, scaled_anchors, s=grid_size, is_preds=True)
        for image_idx, image_boxes in enumerate(boxes_scale):
            bboxes[image_idx] += image_boxes
    logging.info('Starting nms')
    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_threshold,
        threshold=threshold,
        box_format='midpoint',
    )
    return nms_boxes
@st.cache_resource
def get_model():
    """
    Load the trained network from 'model_files/best_model.pth' and set it to eval mode.

    Cached with st.cache_resource so the same model object is reused across reruns and
    sessions. st.cache_data, used previously, would pickle and copy the model on every call.

    Returns:
        The Net instance with the loaded weights, in evaluation mode.
    """
    model_name = 'model_files/best_model.pth'
    model = Net()
    # map_location='cpu' lets a checkpoint saved on GPU load on CPU-only hosts.
    # load_state_dict copies the values into Net's existing (CPU) parameters, so the
    # resulting model is on CPU either way.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(model_name, map_location='cpu'))
    model.eval()
    return model