| """ |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
| import importlib.metadata |
| from torch import Tensor |
|
|
def _parse_torchvision_version(version):
    """Return the first three release numbers of ``version`` as a tuple of ints.

    Local build tags (``+cu118``) and non-numeric suffixes (``0a0``, ``rc1``)
    are ignored and missing components count as ``0``. The result is meant to
    be compared numerically: comparing the raw strings is lexicographic, which
    ranks ``'0.9.1'`` above ``'0.17'``.
    """
    release = version.split('+', 1)[0]
    numbers = []
    for part in release.split('.')[:3]:
        digits = ''
        for char in part:
            if not '0' <= char <= '9':
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    # Pad short versions such as '0.17' so every result has three components.
    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers)


# Query the installed torchvision version once and reuse it for every branch.
_tv_version = _parse_torchvision_version(importlib.metadata.version('torchvision'))


if (0, 15, 2) <= _tv_version < (0, 16, 0):
    # 0.15.x (>= 0.15.2): the classes live in ``torchvision.datapoints`` under
    # singular names; alias them to the plural names the other branches import.
    import torchvision
    torchvision.disable_beta_transforms_warning()

    from torchvision.datapoints import BoundingBox as BoundingBoxes
    from torchvision.datapoints import BoundingBoxFormat, Mask, Image, Video
    from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes
    # Keyword names accepted by the box constructor in this version.
    _boxes_keys = ['format', 'spatial_size']

elif (0, 16, 0) <= _tv_version < (0, 17, 0):
    # 0.16.x: classes moved to ``torchvision.tv_tensors``; the beta warning
    # is still silenced explicitly for this release line.
    import torchvision
    torchvision.disable_beta_transforms_warning()

    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import (
        BoundingBoxes, BoundingBoxFormat, Mask, Image, Video)
    _boxes_keys = ['format', 'canvas_size']

elif _tv_version >= (0, 17, 0):
    # 0.17 and later: same import locations as 0.16, without the warning call.
    import torchvision
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import (
        BoundingBoxes, BoundingBoxFormat, Mask, Image, Video)
    _boxes_keys = ['format', 'canvas_size']

else:
    raise RuntimeError('Please make sure torchvision version >= 0.15.2')
|
|
|
|
|
|
def convert_to_tv_tensor(tensor: Tensor, key: str, box_format='xyxy', spatial_size=None) -> Tensor:
    """Wrap ``tensor`` in the torchvision tensor subclass selected by ``key``.

    Args:
        tensor (Tensor): input tensor to wrap.
        key (str): target kind, either ``'boxes'`` or ``'masks'``.
        box_format (str | BoundingBoxFormat): coordinate format of the boxes.
            A string is upper-cased and looked up as a ``BoundingBoxFormat``
            member (``'xyxy'`` -> ``BoundingBoxFormat.XYXY``); any other value
            is passed through unchanged. Only used when ``key == 'boxes'``.
        spatial_size: value forwarded to ``BoundingBoxes`` under the second
            keyword name listed in ``_boxes_keys``. Only used when
            ``key == 'boxes'``.

    Return:
        ``BoundingBoxes`` when ``key == 'boxes'``, ``Mask`` when
        ``key == 'masks'``.

    Raises:
        AssertionError: if ``key`` is neither ``'boxes'`` nor ``'masks'``.
        AttributeError: if ``box_format`` is a string that does not name a
            ``BoundingBoxFormat`` member.
    """
    if key == 'boxes':
        if isinstance(box_format, str):
            box_format = getattr(BoundingBoxFormat, box_format.upper())
        # ``_boxes_keys`` holds the constructor keyword names for the installed
        # torchvision version, in (format, size) order.
        _kwargs = dict(zip(_boxes_keys, [box_format, spatial_size]))
        return BoundingBoxes(tensor, **_kwargs)

    if key == 'masks':
        return Mask(tensor)

    # Raised explicitly rather than via ``assert`` so the check still runs
    # under ``python -O`` instead of silently returning None.
    raise AssertionError("Only support 'boxes' and 'masks'")
|
|