# D-FINE: src/data/_misc.py
# Uploaded by developer0hye (commit e85fecb, "Upload 76 files").
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import importlib.metadata
from torch import Tensor
def _torchvision_version():
    """Return the installed torchvision version as an integer tuple ``(major, minor, patch)``.

    Version strings must not be compared lexicographically: as strings,
    ``"0.9.0" >= "0.17"`` is True. Only the leading numeric components are
    parsed, so local/build suffixes (``+cu121``) and pre-release or dev tags
    (``a0``, ``.dev2024...``) are ignored. A missing patch component counts as 0.

    Raises:
        RuntimeError: if the version string does not start with ``<int>.<int>``.
    """
    import re

    raw = importlib.metadata.version("torchvision")
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", raw)
    if match is None:
        raise RuntimeError(f"Unrecognized torchvision version string: {raw!r}")
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


_TV_VERSION = _torchvision_version()

# Each branch exports the same names: torchvision, BoundingBoxes, BoundingBoxFormat,
# Image, Mask, Video, SanitizeBoundingBoxes and _boxes_keys. _boxes_keys holds the
# keyword names BoundingBoxes expects for (format, size), which differ by version.
if (0, 15, 2) <= _TV_VERSION < (0, 16):
    # torchvision 0.15.x: types live in `torchvision.datapoints`, with singular names.
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision.datapoints import BoundingBox as BoundingBoxes
    from torchvision.datapoints import BoundingBoxFormat, Image, Mask, Video
    from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes

    _boxes_keys = ["format", "spatial_size"]
elif (0, 16) <= _TV_VERSION < (0, 17):
    # torchvision 0.16.x: `tv_tensors` API; the beta-transforms warning is silenced here.
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    _boxes_keys = ["format", "canvas_size"]
elif _TV_VERSION >= (0, 17):
    # torchvision >= 0.17: same imports as 0.16, without the beta-warning opt-out.
    import torchvision
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    _boxes_keys = ["format", "canvas_size"]
else:
    raise RuntimeError("Please make sure torchvision version >= 0.15.2")
def convert_to_tv_tensor(tensor: Tensor, key: str, box_format="xyxy", spatial_size=None) -> Tensor:
    """Wrap a plain tensor in the torchvision tv_tensor type selected by ``key``.

    Args:
        tensor (Tensor): input tensor to wrap.
        key (str): target kind, either ``"boxes"`` or ``"masks"``.
        box_format (str | BoundingBoxFormat): box layout, used only when
            ``key == "boxes"``. Either a case-insensitive ``BoundingBoxFormat``
            member name (e.g. ``"xyxy"``) or a ``BoundingBoxFormat`` member itself.
        spatial_size: image size, used only when ``key == "boxes"``. It is passed
            to ``BoundingBoxes`` under the version-specific size keyword
            (``spatial_size`` or ``canvas_size``, per ``_boxes_keys``).

    Returns:
        Tensor: a ``BoundingBoxes`` instance for ``"boxes"``, or a ``Mask`` for ``"masks"``.

    Raises:
        ValueError: if ``key`` is neither ``"boxes"`` nor ``"masks"``. Raised
            explicitly rather than via ``assert``, so it also applies under ``python -O``.
    """
    if key == "boxes":
        # Allow callers to pass an enum member directly; strings are looked up by name.
        if isinstance(box_format, str):
            box_format = getattr(BoundingBoxFormat, box_format.upper())
        # _boxes_keys names the (format, size) keywords for the installed torchvision.
        box_kwargs = dict(zip(_boxes_keys, [box_format, spatial_size]))
        return BoundingBoxes(tensor, **box_kwargs)
    if key == "masks":
        return Mask(tensor)
    raise ValueError("Only support 'boxes' and 'masks'")