# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence, Union

import torch
from mmdet.structures.bbox.transforms import get_box_tensor
from torch import Tensor


def make_divisible(x: float,
                   widen_factor: float = 1.0,
                   divisor: int = 8) -> int:
    """Make sure that x*widen_factor is divisible by divisor."""
    return math.ceil(x * widen_factor / divisor) * divisor
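
# Illustrative values only (divisor defaults to 8, matching common
# channel-width constraints):
#   make_divisible(48, widen_factor=0.5)  # ceil(24 / 8) * 8 = 24
#   make_divisible(50, widen_factor=0.5)  # ceil(25 / 8) * 8 = 32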


def make_round(x: float, deepen_factor: float = 1.0) -> int:
    """Make sure that x*deepen_factor becomes an integer not less than 1.

    Values of ``x`` not greater than 1 are returned unchanged.
    """
    return max(round(x * deepen_factor), 1) if x > 1 else x
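
# Illustrative values only (deepen_factor values resemble typical YOLO
# depth multipliers):
#   make_round(3, deepen_factor=0.33)  # max(round(0.99), 1) = 1
#   make_round(9, deepen_factor=0.67)  # round(6.03) = 6
#   make_round(1)                      # 1, returned unchanged since x <= 1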


def gt_instances_preprocess(batch_gt_instances: Union[Tensor, Sequence],
                            batch_size: int) -> Tensor:
    """Split batch_gt_instances by batch size.

    Transforms a flat [all_gt_bboxes, box_dim+2] layout into
    [batch_size, number_gt, box_dim+1]. For horizontal boxes box_dim=4;
    for rotated boxes box_dim=5.

    If an image has fewer gt bboxes than the maximum number in the
    batch, the remainder is filled with zeros.

    Args:
        batch_gt_instances (Sequence[Tensor]): Ground truth
            instances for the whole batch, shape [all_gt_bboxes, box_dim+2].
        batch_size (int): Batch size.

    Returns:
        Tensor: Batch gt instances data, shape
        [batch_size, number_gt, box_dim+1].
    """
    if isinstance(batch_gt_instances, Sequence):
        max_gt_bbox_len = max(
            [len(gt_instances) for gt_instances in batch_gt_instances])
        # Pad with zero rows of length box_dim+1 whenever an image has
        # fewer gt bboxes than max_gt_bbox_len.
        batch_instance_list = []
        for index, gt_instance in enumerate(batch_gt_instances):
            # Convert box-type objects (e.g. HorizontalBoxes) to plain
            # tensors so they can be concatenated with the labels.
            bboxes = get_box_tensor(gt_instance.bboxes)
            labels = gt_instance.labels
            box_dim = bboxes.size(-1)
            batch_instance_list.append(
                torch.cat((labels[:, None], bboxes), dim=-1))

            if bboxes.shape[0] >= max_gt_bbox_len:
                continue

            fill_tensor = bboxes.new_full(
                [max_gt_bbox_len - bboxes.shape[0], box_dim + 1], 0)
            batch_instance_list[index] = torch.cat(
                (batch_instance_list[index], fill_tensor), dim=0)

        return torch.stack(batch_instance_list)
    else:
        # Faster path: batch_gt_instances is a single tensor whose rows
        # have the format [img_ind, cls_ind, (box)].
        # A horizontal box row is:
        #   [img_ind, cls_ind, x1, y1, x2, y2]
        # A rotated box row is:
        #   [img_ind, cls_ind, x, y, w, h, a]
        # Split batch gt instances: [all_gt_bboxes, box_dim+2] ->
        # [batch_size, max_gt_bbox_len, box_dim+1]
        assert isinstance(batch_gt_instances, Tensor)
        box_dim = batch_gt_instances.size(-1) - 2
        if len(batch_gt_instances) > 0:
            gt_images_indexes = batch_gt_instances[:, 0]
            max_gt_bbox_len = gt_images_indexes.unique(
                return_counts=True)[1].max()
            # Pad with zero rows of length box_dim+1 whenever an image
            # has fewer gt bboxes than max_gt_bbox_len.
            batch_instance = torch.zeros(
                (batch_size, max_gt_bbox_len, box_dim + 1),
                dtype=batch_gt_instances.dtype,
                device=batch_gt_instances.device)
            for i in range(batch_size):
                match_indexes = gt_images_indexes == i
                gt_num = match_indexes.sum()
                if gt_num:
                    batch_instance[i, :gt_num] = batch_gt_instances[
                        match_indexes, 1:]
        else:
            batch_instance = torch.zeros((batch_size, 0, box_dim + 1),
                                         dtype=batch_gt_instances.dtype,
                                         device=batch_gt_instances.device)

        return batch_instance
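

# A minimal usage sketch (illustrative only, not part of the module): build a
# flattened [img_ind, cls_ind, x1, y1, x2, y2] tensor for two images and
# split it into a padded [batch_size, max_gt, box_dim+1] tensor.
#   gts = torch.tensor([
#       [0., 1., 10., 10., 50., 50.],  # image 0, class 1
#       [0., 2., 20., 20., 60., 60.],  # image 0, class 2
#       [1., 0., 30., 30., 70., 70.],  # image 1, class 0
#   ])
#   out = gt_instances_preprocess(gts, batch_size=2)
#   # out.shape == (2, 2, 5); out[1, 1] is all zeros (padding).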