Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
# Copyright (c) Megvii, Inc. and its affiliates. | |
from torch.utils.data.dataset import ConcatDataset as torchConcatDataset | |
from torch.utils.data.dataset import Dataset as torchDataset | |
import bisect | |
from functools import wraps | |
class ConcatDataset(torchConcatDataset):
    """Concatenation of several datasets that also forwards ``pull_item``.

    If the first wrapped dataset exposes an ``input_dim`` attribute, its value
    is copied onto this wrapper as both ``_input_dim`` and ``input_dim``.
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        first = self.datasets[0]
        if hasattr(first, "input_dim"):
            self._input_dim = first.input_dim
            self.input_dim = first.input_dim

    def pull_item(self, idx):
        """Call ``pull_item`` on the sub-dataset that owns global index *idx*.

        Negative indices count from the end of the concatenation.

        Raises:
            ValueError: if ``idx`` is negative and its absolute value exceeds
                the total length.
        """
        if idx < 0:
            total = len(self)
            if total < -idx:
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx += total
        which = bisect.bisect_right(self.cumulative_sizes, idx)
        # Offset of the chosen sub-dataset inside the concatenation (0 for the first).
        offset = self.cumulative_sizes[which - 1] if which else 0
        return self.datasets[which].pull_item(idx - offset)
class MixConcatDataset(torchConcatDataset):
    """Concatenation of datasets whose items may be addressed either by a plain
    integer or by an ``(input_dim, index, ...)`` tuple.

    The global position (the integer itself, or element ``[1]`` of the tuple)
    is translated into a position inside the owning sub-dataset. That local
    position is what the sub-dataset receives, either alone or substituted
    back into the tuple.
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        if hasattr(self.datasets[0], "input_dim"):
            self._input_dim = self.datasets[0].input_dim
            self.input_dim = self.datasets[0].input_dim

    def __getitem__(self, index):
        """Fetch an item from the sub-dataset that owns the requested position.

        Args:
            index: an ``int``, or a sequence whose element ``[1]`` is the global
                position and whose other elements are forwarded unchanged.

        Returns:
            Whatever the owning sub-dataset's ``__getitem__`` returns.

        Raises:
            ValueError: if the position is negative and its absolute value
                exceeds the total length.
        """
        is_plain = isinstance(index, int)
        # Bind the position for both forms; previously an int left it unbound.
        idx = index if is_plain else index[1]
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        if is_plain:
            # The sub-dataset must see its local position, not the global one.
            index = sample_idx
        else:
            # Keep input_dim and any trailing fields (e.g. the mosaic flag).
            index = (index[0], sample_idx) + tuple(index[2:])
        return self.datasets[dataset_idx][index]
class Dataset(torchDataset):
    """This class is a subclass of the base :class:`torch.utils.data.Dataset`,
    that enables on the fly resizing of the ``input_dim``.

    Args:
        input_dimension (tuple): (width,height) tuple with default dimensions of the network
        mosaic (bool): initial value of ``enable_mosaic``
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        # Name-mangled default; a temporary ``_input_dim`` can shadow it.
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """
        Dimension that can be used by transforms to set the correct image size, etc.
        This allows transforms to have a single source of truth
        for the input dimension of the network.

        Return:
            The current ``(width, height)``: the temporary ``_input_dim`` override
            if one is set, otherwise the first two entries of ``input_dimension``.
        """
        if hasattr(self, "_input_dim"):
            return self._input_dim
        return self.__input_dim

    @staticmethod
    def resize_getitem(getitem_fn):
        """
        Decorator method that needs to be used around the ``__getitem__`` method.

        It enables on the fly resizing of ``input_dim``. A non-``int`` index is
        unpacked as ``(input_dim, index[, mosaic])``, as yielded by a sampler
        that drives resizing. ``input_dim`` is overridden only for the duration
        of the call. ``mosaic``, when present, is stored on ``enable_mosaic``
        and persists.

        Example:
            >>> class CustomSet(Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @Dataset.resize_getitem
            ...     def __getitem__(self, index):
            ...         # Should return (image, anno) but here we return input_dim
            ...         return self.input_dim
            >>> data = CustomSet((200,200))
            >>> data[0]
            (200, 200)
            >>> data[(480,320), 0]
            (480, 320)
            >>> data[(480,320), 0, False]
            (480, 320)
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            if isinstance(index, int):
                return getitem_fn(self, index)
            self._input_dim = index[0]
            try:
                # The mosaic flag is optional so 2-tuples also work.
                if len(index) > 2:
                    self.enable_mosaic = index[2]
                return getitem_fn(self, index[1])
            finally:
                # Always drop the temporary override, even if getitem_fn raised.
                del self._input_dim

        return wrapper