|
import math |
|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from torch import Tensor |
|
from torchvision.ops import batched_nms |
|
|
|
from yolo.config.config import MatcherConfig, ModelConfig, NMSConfig |
|
|
|
|
|
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
    """Compute pairwise IoU, DIoU or CIoU between two sets of xyxy boxes.

    Args:
        bbox1: Boxes as [x_min, y_min, x_max, y_max] in the last dimension.
            A 2-D input [A x 4] is paired against every box of a 2-D `bbox2`;
            a 3-D input [batch x A x 4] is paired per batch element.
        bbox2: Boxes in the same layout as `bbox1` ([B x 4] or [batch x B x 4]).
            For any other rank combination the inputs are used as-is and
            ordinary broadcasting applies.
        metrics: "iou", "diou" or "ciou" (case-insensitive). Any value other
            than "iou" or "diou" falls through to CIoU.

    Returns:
        The requested metric with shape [A x B] (2-D inputs) or
        [batch x A x B] (3-D inputs), cast back to the dtype of `bbox1`.
    """
    metrics = metrics.lower()
    EPS = 1e-9
    dtype = bbox1.dtype
    # Do the arithmetic in float32 so low-precision inputs do not lose accuracy.
    bbox1 = bbox1.to(torch.float32)
    bbox2 = bbox2.to(torch.float32)

    # Insert singleton axes so every box of bbox1 is broadcast against every box of bbox2.
    if bbox1.ndim == 2 and bbox2.ndim == 2:
        bbox1 = bbox1.unsqueeze(1)  # [A x 4] -> [A x 1 x 4]
        bbox2 = bbox2.unsqueeze(0)  # [B x 4] -> [1 x B x 4]
    elif bbox1.ndim == 3 and bbox2.ndim == 3:
        bbox1 = bbox1.unsqueeze(2)  # [batch x A x 4] -> [batch x A x 1 x 4]
        bbox2 = bbox2.unsqueeze(1)  # [batch x B x 4] -> [batch x 1 x B x 4]

    # Intersection rectangle.
    xmin_inter = torch.max(bbox1[..., 0], bbox2[..., 0])
    ymin_inter = torch.max(bbox1[..., 1], bbox2[..., 1])
    xmax_inter = torch.min(bbox1[..., 2], bbox2[..., 2])
    ymax_inter = torch.min(bbox1[..., 3], bbox2[..., 3])

    # Clamp at zero so disjoint boxes contribute no (negative) area.
    intersection_area = torch.clamp(xmax_inter - xmin_inter, min=0) * torch.clamp(ymax_inter - ymin_inter, min=0)

    # Area of each box.
    area_bbox1 = (bbox1[..., 2] - bbox1[..., 0]) * (bbox1[..., 3] - bbox1[..., 1])
    area_bbox2 = (bbox2[..., 2] - bbox2[..., 0]) * (bbox2[..., 3] - bbox2[..., 1])

    union_area = area_bbox1 + area_bbox2 - intersection_area

    # EPS guards against division by zero for degenerate boxes.
    iou = intersection_area / (union_area + EPS)
    if metrics == "iou":
        return iou.to(dtype)

    # Squared distance between the two box centres.
    cx1 = (bbox1[..., 2] + bbox1[..., 0]) / 2
    cy1 = (bbox1[..., 3] + bbox1[..., 1]) / 2
    cx2 = (bbox2[..., 2] + bbox2[..., 0]) / 2
    cy2 = (bbox2[..., 3] + bbox2[..., 1]) / 2
    cent_dis = (cx1 - cx2) ** 2 + (cy1 - cy2) ** 2

    # Squared diagonal of the smallest box enclosing both boxes.
    c_x = torch.max(bbox1[..., 2], bbox2[..., 2]) - torch.min(bbox1[..., 0], bbox2[..., 0])
    c_y = torch.max(bbox1[..., 3], bbox2[..., 3]) - torch.min(bbox1[..., 1], bbox2[..., 1])
    diag_dis = c_x**2 + c_y**2 + EPS

    diou = iou - (cent_dis / diag_dis)
    if metrics == "diou":
        return diou.to(dtype)

    # Aspect-ratio consistency term: difference of atan(width / height) of the two boxes.
    arctan = torch.atan((bbox1[..., 2] - bbox1[..., 0]) / (bbox1[..., 3] - bbox1[..., 1] + EPS)) - torch.atan(
        (bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
    )
    v = (4 / (math.pi**2)) * (arctan**2)
    alpha = v / (v - iou + 1 + EPS)

    ciou = diou - alpha * v
    return ciou.to(dtype)
|
|
|
|
|
def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
    """Convert bounding boxes between the "xyxy", "xywh" and "xycwh" layouts.

    Layouts (all in the last dimension of `bbox`):
        xyxy:  [x_min, y_min, x_max, y_max]
        xywh:  [x_min, y_min, width, height]
        xycwh: [x_center, y_center, width, height]

    Args:
        bbox: Tensor whose last dimension holds the four box values.
        indicator: "<input layout> -> <output layout>"; spaces are ignored.

    Returns:
        A new tensor in the requested layout, cast to the dtype of `bbox`.

    Raises:
        ValueError: If `indicator` is not exactly two supported layouts
            separated by "->".
    """
    supported = ("xyxy", "xywh", "xycwh")
    data_type = bbox.dtype

    # Validate the shape of the indicator before unpacking, so a malformed
    # string reports the same error as an unknown layout name.
    parts = indicator.replace(" ", "").split("->")
    if len(parts) != 2 or any(fmt not in supported for fmt in parts):
        raise ValueError("Invalid input or output format")
    in_type, out_type = parts

    # Normalise every input layout to corner coordinates first.
    if in_type == "xywh":
        x_min = bbox[..., 0]
        y_min = bbox[..., 1]
        x_max = bbox[..., 0] + bbox[..., 2]
        y_max = bbox[..., 1] + bbox[..., 3]
    elif in_type == "xyxy":
        x_min = bbox[..., 0]
        y_min = bbox[..., 1]
        x_max = bbox[..., 2]
        y_max = bbox[..., 3]
    else:  # "xycwh"
        x_min = bbox[..., 0] - bbox[..., 2] / 2
        y_min = bbox[..., 1] - bbox[..., 3] / 2
        x_max = bbox[..., 0] + bbox[..., 2] / 2
        y_max = bbox[..., 1] + bbox[..., 3] / 2

    # Then emit the requested layout from the corners.
    if out_type == "xywh":
        bbox = torch.stack([x_min, y_min, x_max - x_min, y_max - y_min], dim=-1)
    elif out_type == "xyxy":
        bbox = torch.stack([x_min, y_min, x_max, y_max], dim=-1)
    else:  # "xycwh"
        bbox = torch.stack([(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min], dim=-1)

    return bbox.to(dtype=data_type)
|
|
|
|
|
def generate_anchors(image_size: List[int], strides: List[int], device):
    """Build the anchor-point grid and per-anchor stride for every feature level.

    For each stride, anchor points are placed every `stride` pixels, offset by
    `stride // 2`. Points are emitted row by row (y outer, x inner), which is
    the order produced by flattening a feature map as "(h w)".

    Args:
        image_size: [W, H] of the input image.
        strides: Stride of each feature level.
        device: Device on which the tensors are created.

    Returns:
        all_anchors: [anchors x 2] tensor of (x, y) anchor points, all levels concatenated.
        all_scalers: [anchors] tensor holding the stride of each anchor.
    """
    W, H = image_size
    anchors = []
    scaler = []
    for stride in strides:
        shift = stride // 2
        ys = torch.arange(0, H, stride, device=device) + shift
        xs = torch.arange(0, W, stride, device=device) + shift
        # "ij" indexing with (ys, xs) makes x the fastest-varying axis after flatten.
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        anchor = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1)
        anchors.append(anchor)
        # Size the scaler from the anchors actually generated so the two always
        # agree, including when W or H is not a multiple of the stride.
        scaler.append(torch.full((anchor.size(0),), stride, device=device))
    all_anchors = torch.cat(anchors, dim=0)
    all_scalers = torch.cat(scaler, dim=0)
    return all_anchors, all_scalers
|
|
|
|
|
class AnchorBoxConverter:
    """Decode raw per-level detection outputs into class scores and xyxy boxes."""

    def __init__(self, model_cfg: ModelConfig, image_size: List[int], device: torch.device) -> None:
        """Read the head layout from `model_cfg` and precompute the anchor grid.

        Args:
            model_cfg: Provides `anchor.reg_max`, `anchor.strides` and `class_num`.
            image_size: Image size forwarded to `generate_anchors`.
            device: Device on which the anchor tensors are created.
        """
        self.reg_max = model_cfg.anchor.reg_max
        self.class_num = model_cfg.class_num
        self.strides = model_cfg.anchor.strides

        self.anchors, self.scaler = generate_anchors(image_size, self.strides, device)
        # Bin values 0..reg_max-1 used to take the expectation over the distance distribution.
        self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)

    def __call__(self, predicts: List[Tensor], with_logits=False) -> Tuple[Tensor, Tensor]:
        """Convert the per-level head outputs into class scores plus decoded boxes.

        Args:
            predicts: One tensor per feature level, each [B x AnchorClass x h x w]
                with AnchorClass = 4 * reg_max + class_num (e.g. 4 * 16 + 80).
            with_logits: When True, a sigmoid is applied to the class channels
                before they are returned; otherwise they are passed through unchanged.

        Returns:
            predicts: [B x HW x (class_num + 4)] where HW is the sum of h * w over all
                levels; class channels come first, followed by the box as
                (x_min, y_min, x_max, y_max).
            preds_anc: [B x HW x 4 x reg_max] raw distance-distribution values
                for the four box sides (before softmax).
        """
        # Flatten each level spatially and concatenate all levels along the anchor axis.
        preds = torch.concat([rearrange(pred, "B AC h w -> B (h w) AC") for pred in predicts], dim=1)

        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
        preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
        if with_logits:
            preds_cls = preds_cls.sigmoid()

        # Expected bin index per side (softmax-weighted sum), scaled by the anchor's stride.
        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)

        # First two sides are subtracted from the anchor point, last two are added to it.
        lt, rb = pred_LTRB.chunk(2, dim=-1)
        pred_minXY = self.anchors - lt
        pred_maxXY = self.anchors + rb
        preds_box = torch.cat([pred_minXY, pred_maxXY], dim=-1)

        predicts = torch.cat([preds_cls, preds_box], dim=-1)

        return predicts, preds_anc
|
|
|
|
|
class BoxMatcher:
    """Assign ground-truth targets to anchors using an IoU- and class-score-based suitability."""

    def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
        """Store the anchors and copy every entry of `cfg` onto the instance.

        Args:
            cfg: Matcher configuration; iterated as a mapping. `__call__` reads the
                resulting `iou`, `topk` and `factor` attributes.
            class_num: Number of object classes.
            anchors: [anchors x 2] anchor points as (x, y).
        """
        self.class_num = class_num
        self.anchors = anchors
        for attr_name in cfg:
            setattr(self, attr_name, cfg[attr_name])

    def get_valid_matrix(self, target_bbox: Tensor):
        """
        Get a boolean mask that indicates whether each anchor point lies strictly inside each target box.

        Args:
            target_bbox [batch x targets x 4]: The bounding box of each target as [x1, y1, x2, y2].
        Returns:
            [batch x targets x anchors]: True where the anchor point is inside the target box.
        """
        # [batch x targets x 1] per coordinate, broadcast against [1 x 1 x anchors].
        Xmin, Ymin, Xmax, Ymax = target_bbox[:, :, None].unbind(3)
        anchors = self.anchors[None, None]
        anchors_x, anchors_y = anchors.unbind(dim=3)
        target_in_x = (Xmin < anchors_x) & (anchors_x < Xmax)
        target_in_y = (Ymin < anchors_y) & (anchors_y < Ymax)
        target_on_anchor = target_in_x & target_in_y
        return target_on_anchor

    def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
        """
        Get the predicted class probabilities corresponding to the target classes across all anchors.

        Args:
            predict_cls [batch x anchors x class]: The predicted probabilities for each class at each anchor.
            target_cls [batch x targets x 1]: The class index (int64) for each target.

        Returns:
            [batch x targets x anchors]: The probabilities from `predict_cls` corresponding to the class indices specified in `target_cls`.
        """
        predict_cls = predict_cls.transpose(1, 2)  # -> [batch x class x anchors]
        # Repeat each target's class index once per anchor; the anchor count is
        # taken from the predictions instead of being hard-coded.
        target_cls = target_cls.expand(-1, -1, predict_cls.size(2))
        cls_probabilities = torch.gather(predict_cls, 1, target_cls)
        return cls_probabilities

    def get_iou_matrix(self, predict_bbox, target_bbox) -> Tensor:
        """
        Get the IoU between each target bounding box and each predicted bounding box.

        Args:
            predict_bbox [batch x predicts x 4]: Bounding box with [x1, y1, x2, y2].
            target_bbox [batch x targets x 4]: Bounding box with [x1, y1, x2, y2].
        Returns:
            [batch x targets x predicts]: The IoU scores between each target and predicted box,
            computed with the metric named by `self.iou` and clamped to [0, 1].
        """
        return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)

    def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
        """
        Keep, for each target, only its top-k highest-suitability anchors.

        Args:
            target_matrix [batch x targets x anchors]: The suitability for each target-anchor pair.
            topk (int, optional): Number of top scores to retain per target. Values larger
                than the number of anchors are clamped to the number of anchors.

        Returns:
            topk_targets [batch x targets x anchors]: `target_matrix` with everything outside the top-k set to zero.
            topk_masks [batch x targets x anchors]: A boolean mask of the retained entries that are greater than zero.
        """
        # torch.topk raises when k exceeds the dimension size, so clamp it.
        k = min(topk, target_matrix.size(-1))
        values, indices = target_matrix.topk(k, dim=-1)
        topk_targets = torch.zeros_like(target_matrix)
        topk_targets.scatter_(dim=-1, index=indices, src=values)
        topk_masks = topk_targets > 0
        return topk_targets, topk_masks

    def filter_duplicates(self, target_matrix: Tensor):
        """
        Pick, for each anchor, the index of the target with the highest suitability.

        Args:
            target_matrix [batch x targets x anchors]: The suitability for each target-anchor pair.

        Returns:
            unique_indices [batch x anchors x 1]: The index of the best target for each anchor.
        """
        unique_indices = target_matrix.argmax(dim=1)
        return unique_indices[..., None]

    def __call__(self, target: Tensor, predict: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Match targets to anchors.

        1. For each target-anchor pair, compute a suitability score from IoU and class probability.
        2. Keep the top-k anchors per target and, per anchor, select the most suitable target.
        3. Normalize the class probabilities of the selected targets.

        Args:
            target [batch x targets x 5]: Class index followed by the box [x1, y1, x2, y2].
                Negative class indices are clamped to 0.
            predict [batch x anchors x (class_num + 4)]: Class scores (a sigmoid is applied
                here) followed by the predicted box [x1, y1, x2, y2].

        Returns:
            [batch x anchors x (class_num + 4)]: Normalized one-hot class targets followed by the aligned box.
            [batch x anchors]: Boolean mask of anchors that received a target.
        """
        # Explicit section sizes: a bare `split(class_num)` only yields two chunks when class_num >= 4.
        predict_cls, predict_bbox = predict.split([self.class_num, 4], dim=-1)
        target_cls, target_bbox = target.split([1, 4], dim=-1)
        target_cls = target_cls.long().clamp(0)

        # Which anchor points fall inside which target boxes.
        grid_mask = self.get_valid_matrix(target_bbox)

        # IoU between every target box and every predicted box.
        iou_mat = self.get_iou_matrix(predict_bbox, target_bbox)

        # Predicted probability of each target's class at every anchor.
        cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)

        # Suitability is zero wherever the anchor is outside the target box.
        target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])

        # Keep the top-k anchors for each target.
        topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)

        # Resolve anchors claimed by several targets: keep the most suitable one.
        unique_indices = self.filter_duplicates(topk_targets)

        # An anchor is valid when it lies in some target box and survived some target's top-k.
        valid_mask = grid_mask.any(dim=-2) & topk_mask.any(dim=-2)

        align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
        align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
        align_cls = F.one_hot(align_cls, self.class_num)

        # Rescale each target's suitability so its best anchor scores that target's maximum IoU.
        max_target = target_matrix.amax(dim=-1, keepdim=True)
        max_iou = iou_mat.amax(dim=-1, keepdim=True)
        normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
        normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
        align_cls = align_cls * normalize_term * valid_mask[:, :, None]

        return torch.cat([align_cls, align_bbox], dim=-1), valid_mask
|
|
|
|
|
def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
    """Filter predictions by confidence and apply per-image non-maximum suppression.

    Args:
        predicts: [batch x anchors x (classes + 4)] tensor; class scores come first and
            the last four channels are the box passed to `batched_nms`.
        nms_cfg: Provides `min_confidence` (score threshold) and `min_iou`
            (IoU threshold handed to `batched_nms`).

    Returns:
        A list with one tensor per batch element, each [kept x 5] holding
        (class index, box[0], box[1], box[2], box[3]) for the boxes that survive NMS.
    """
    # Derive the class count from the input instead of assuming 80 classes.
    cls_dist, bbox = predicts.split([predicts.size(-1) - 4, 4], dim=-1)

    # Best class and its score for every anchor.
    cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
    valid_mask = cls_val > nms_cfg.min_confidence
    valid_cls = cls_idx[valid_mask].float()
    valid_con = cls_val[valid_mask].float()
    valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)

    batch_idx, *_ = torch.where(valid_mask)
    # batched_nms(boxes, scores, idxs, iou_threshold): rank by confidence (not by
    # class index) and group by image, so suppression never crosses batch elements.
    nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou)
    predicts_nms = []
    for idx in range(predicts.size(0)):
        # Kept detections that belong to this image.
        instance_idx = nms_idx[idx == batch_idx[nms_idx]]

        predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)

        predicts_nms.append(predict_nms)
    return predicts_nms
|
|