# Aspect-ratio-grouping batch samplers for video and multi-source
# detection datasets (mmdet extension).
from typing import Sequence

import torch
import torch.distributed as torch_dist
from torch._C._distributed_c10d import ReduceOp
from torch.utils.data import BatchSampler, Sampler

from mmengine.dist import get_comm_device, get_default_group, get_dist_info

from mmdet.datasets.samplers.batch_sampler import AspectRatioBatchSampler
from mmdet.registry import DATA_SAMPLERS
class VideoSegAspectRatioBatchSampler(AspectRatioBatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio (< 1
    or >= 1) into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch
            if its size would be less than ``batch_size``.
    """

    def __iter__(self) -> Sequence[int]:
        """Yield lists of dataset indices grouped by aspect ratio."""
        for idx in self.sampler:
            # hard code to solve TrackImgSampler
            video_idx = idx
            data_info = self.sampler.dataset.get_data_info(video_idx)
            # Video-style data infos keep per-frame dicts under 'images'
            # (i.e. {video_id, images, video_length}); use the first frame.
            # Plain image datasets carry width/height at the top level.
            if 'images' in data_info:
                img_data_info = data_info['images'][0]
            else:
                img_data_info = data_info
            width, height = img_data_info['width'], img_data_info['height']
            # Bucket 0: portrait (w < h); bucket 1: landscape (w >= h).
            bucket_id = 0 if width < height else 1
            bucket = self._aspect_ratio_buckets[bucket_id]
            bucket.append(idx)
            # Yield a batch of indices from the same aspect-ratio group.
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # Yield the rest of the data (the two aspect-ratio groups may be
        # mixed here) and reset the buckets for the next epoch.
        left_data = self._aspect_ratio_buckets[0] + \
            self._aspect_ratio_buckets[1]
        self._aspect_ratio_buckets = [[] for _ in range(2)]
        while len(left_data) > 0:
            # Fix: use strict `<` (was `<=`). With `<=` a leftover chunk of
            # exactly ``batch_size`` was silently dropped when ``drop_last``
            # is True, even though it forms a complete batch. This also
            # matches MultiDataAspectRatioBatchSampler's leftover handling.
            if len(left_data) < self.batch_size:
                if not self.drop_last:
                    yield left_data[:]
                left_data = []
            else:
                yield left_data[:self.batch_size]
                left_data = left_data[self.batch_size:]
class MultiDataAspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio (< 1
    or >= 1) into a same batch for multi-source datasets.

    Each source dataset gets its own pair of aspect-ratio buckets, so a
    batch never mixes samples from different source datasets.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (Sequence[int] | int): Size of mini-batch for each
            source dataset. An ``int`` is broadcast to all datasets.
        num_datasets (int): Number of multi-source datasets.
        drop_last (bool): If ``True``, the sampler will drop the last batch
            of a dataset if its size would be less than its batch size.
    """

    def __init__(self,
                 sampler: Sampler,
                 batch_size: Sequence[int],
                 num_datasets: int,
                 drop_last: bool = True) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        self.sampler = sampler
        # Allow a single int batch size shared by every source dataset.
        if isinstance(batch_size, int):
            self.batch_size = [batch_size] * num_datasets
        else:
            self.batch_size = batch_size
        self.num_datasets = num_datasets
        self.drop_last = drop_last
        # Two groups (w < h and w >= h) per dataset --> 2 * num_datasets.
        self._buckets = [[] for _ in range(2 * self.num_datasets)]

    def __iter__(self) -> Sequence[int]:
        """Yield per-dataset batches, truncated to the minimum batch count
        across distributed ranks so all ranks stay in lockstep."""
        # Every rank must yield the same number of batches, otherwise the
        # distributed job desynchronizes; take the MIN of len(self).
        num_batch = torch.tensor(len(self), device='cpu')
        rank, world_size = get_dist_info()
        if world_size > 1:
            group = get_default_group()
            backend_device = get_comm_device(group)
            num_batch = num_batch.to(device=backend_device)
            torch_dist.all_reduce(num_batch, op=ReduceOp.MIN, group=group)
        num_batch = num_batch.to('cpu').item()

        for idx in self.sampler:
            data_info = self.sampler.dataset.get_data_info(idx)
            width, height = data_info.get('width', 0), data_info.get(
                'height', 0)
            dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
            # Bucket 0: portrait (w < h); bucket 1: landscape (w >= h).
            aspect_ratio_bucket_id = 0 if width < height else 1
            bucket_id = dataset_source_idx * 2 + aspect_ratio_bucket_id
            bucket = self._buckets[bucket_id]
            bucket.append(idx)
            # Yield a batch of indices from the same aspect-ratio group.
            if len(bucket) == self.batch_size[dataset_source_idx]:
                batch = bucket[:]
                # Fix: clear the bucket BEFORE the quota check below.
                # Previously the early ``return`` skipped ``del bucket[:]``,
                # leaving the just-yielded indices in ``self._buckets`` and
                # re-yielding them on the next epoch.
                del bucket[:]
                yield batch
                num_batch -= 1
                if num_batch <= 0:
                    self._buckets = [[] for _ in range(2 * self.num_datasets)]
                    return

        # Flush the per-dataset leftovers. Reset the buckets first (fix:
        # previously they were reset only after this loop completed, so an
        # early return left stale indices behind for the next epoch).
        left_per_dataset = [
            self._buckets[i * 2 + 0] + self._buckets[i * 2 + 1]
            for i in range(self.num_datasets)
        ]
        self._buckets = [[] for _ in range(2 * self.num_datasets)]
        for i, left_data in enumerate(left_per_dataset):
            while len(left_data) > 0:
                if len(left_data) < self.batch_size[i]:
                    if not self.drop_last:
                        yield left_data[:]
                        num_batch -= 1
                        if num_batch <= 0:
                            return
                    left_data = []
                else:
                    yield left_data[:self.batch_size[i]]
                    num_batch -= 1
                    if num_batch <= 0:
                        return
                    left_data = left_data[self.batch_size[i]:]

    def __len__(self) -> int:
        """Number of batches this rank would yield on its own (before the
        cross-rank MIN truncation applied in ``__iter__``)."""
        # Count the samples belonging to each source dataset.
        sizes = [0 for _ in range(self.num_datasets)]
        for idx in self.sampler:
            dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
            sizes[dataset_source_idx] += 1
        if self.drop_last:
            # Only complete batches count per dataset.
            return sum(s // b for s, b in zip(sizes, self.batch_size))
        # Ceiling division: the incomplete tail batch is kept.
        return sum((s + b - 1) // b for s, b in zip(sizes, self.batch_size))