# bytetrack/yolox/data/datasets/datasets_wrapper.py
# AK391
# all files
# 7734d5b
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
from torch.utils.data.dataset import Dataset as torchDataset
import bisect
from functools import wraps
class ConcatDataset(torchConcatDataset):
    """Concatenation of datasets that also exposes ``pull_item``.

    If the first wrapped dataset has an ``input_dim`` attribute, its value is
    copied onto this object as both ``_input_dim`` and ``input_dim``.
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        head = self.datasets[0]
        if hasattr(head, "input_dim"):
            self._input_dim = head.input_dim
            self.input_dim = head.input_dim

    def pull_item(self, idx):
        """Forward ``pull_item`` to the sub-dataset that owns global index ``idx``.

        Args:
            idx: Position in the concatenated dataset. Negative values count
                from the end.

        Returns:
            Whatever the owning sub-dataset's ``pull_item`` returns for the
            translated, dataset-local index.

        Raises:
            ValueError: If ``idx`` is negative and its magnitude exceeds the
                total length.
        """
        if idx < 0:
            total = len(self)
            if -idx > total:
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = total + idx

        # cumulative_sizes[k] is the running length through dataset k, so the
        # insertion point to the right of idx selects the owning dataset.
        owner = bisect.bisect_right(self.cumulative_sizes, idx)
        preceding = self.cumulative_sizes[owner - 1] if owner != 0 else 0
        return self.datasets[owner].pull_item(idx - preceding)
class MixConcatDataset(torchConcatDataset):
    """Concatenation of datasets indexable by an int or a 3-element index.

    If the first wrapped dataset has an ``input_dim`` attribute, its value is
    copied onto this object as both ``_input_dim`` and ``input_dim``.
    """

    def __init__(self, datasets):
        super(MixConcatDataset, self).__init__(datasets)
        if hasattr(self.datasets[0], "input_dim"):
            self._input_dim = self.datasets[0].input_dim
            self.input_dim = self.datasets[0].input_dim

    def __getitem__(self, index):
        """Fetch an item from the sub-dataset that owns the requested position.

        Args:
            index: Either a plain ``int`` position in the concatenated
                dataset, or a 3-element sequence whose element ``1`` is that
                position. Elements ``0`` and ``2`` are forwarded unchanged.
                Negative positions count from the end.

        Returns:
            The owning sub-dataset's item, looked up with the dataset-local
            position (for an ``int``) or with
            ``(index[0], local_position, index[2])`` otherwise.

        Raises:
            ValueError: If the position is negative and its magnitude exceeds
                the total length.
        """
        is_plain_int = isinstance(index, int)
        # The position must be bound for both index forms; previously a plain
        # int left it undefined and the comparison below raised.
        idx = index if is_plain_int else index[1]

        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx

        # cumulative_sizes[k] is the running length through dataset k.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        # Sub-datasets are always addressed with their local position, never
        # with the global one.
        if is_plain_int:
            local_index = sample_idx
        else:
            local_index = (index[0], sample_idx, index[2])
        return self.datasets[dataset_idx][local_index]
class Dataset(torchDataset):
    """Base dataset whose ``input_dim`` can be overridden per item request.

    This is a subclass of :class:`torch.utils.data.Dataset` that enables
    on-the-fly resizing of ``input_dim``.

    Args:
        input_dimension (tuple): Default dimensions of the network; only the
            first two entries are kept. The original documentation describes
            them as (width, height); this class does not interpret the order.
        mosaic (bool): Initial value of ``enable_mosaic``.
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """Dimension that transforms can use to set the correct image size.

        This gives transforms a single source of truth for the input
        dimension of the network.

        Returns:
            tuple: The temporary ``_input_dim`` override when one is set,
            otherwise the default given at construction time.
        """
        if hasattr(self, "_input_dim"):
            return self._input_dim
        return self.__input_dim

    @staticmethod
    def resize_getitem(getitem_fn):
        """Decorate ``__getitem__`` so an index may carry a temporary ``input_dim``.

        The wrapped method accepts either a plain ``int`` index, which is
        passed through unchanged, or a sequence
        ``(input_dim, index[, enable_mosaic])``. For the sequence form
        ``input_dim`` is exposed through :attr:`input_dim` only while the
        wrapped method runs. When the third element is present it is stored
        in ``self.enable_mosaic`` and stays set afterwards.

        Example:
            >>> class CustomSet(Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @Dataset.resize_getitem
            ...     def __getitem__(self, index):
            ...         # Should return (image, anno) but here we return input_dim
            ...         return self.input_dim
            >>> data = CustomSet((200, 200))
            >>> data[0]
            (200, 200)
            >>> data[(480, 320), 0]
            (480, 320)
            >>> data[(480, 320), 0, False]
            (480, 320)
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            if isinstance(index, int):
                return getitem_fn(self, index)

            # Unpack before touching any state so a malformed index cannot
            # leave a half-applied override behind.
            input_dim, sample_index = index[0], index[1]
            if len(index) > 2:
                self.enable_mosaic = index[2]

            self._input_dim = input_dim
            try:
                return getitem_fn(self, sample_index)
            finally:
                # Drop the override even when the wrapped call raises, so
                # later calls see the default dimension again.
                del self._input_dim

        return wrapper