Pinwheel's picture
HF Demo
128757a
raw
history blame
789 Bytes
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import bisect
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class ConcatDataset(_ConcatDataset):
    """
    Same as torch.utils.data.dataset.ConcatDataset, but exposes extra
    methods for mapping a global index to its (sub-dataset, sample) pair and
    for querying per-image metadata (such as image sizes) through the owning
    sub-dataset's ``get_img_info``.
    """

    def get_idxs(self, idx):
        """
        Map a global index into ``(dataset_idx, sample_idx)``.

        Negative indices are resolved relative to the end of the
        concatenation, the same way ``ConcatDataset.__getitem__`` resolves
        them, so ``get_idxs(i)`` always names the sample that ``self[i]``
        would return.

        Args:
            idx (int): index into the concatenated dataset.

        Returns:
            tuple[int, int]: the position of the owning sub-dataset in
            ``self.datasets`` and the index of the sample inside it.

        Raises:
            ValueError: if ``idx`` is negative and ``-idx > len(self)``
                (mirrors the parent's ``__getitem__``).
            IndexError: if ``idx >= len(self)``.
        """
        total = len(self)
        if idx < 0:
            if -idx > total:
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            # Without this, bisect would map e.g. -1 into the first
            # sub-dataset instead of the last sample overall.
            idx += total
        elif idx >= total:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {total}"
            )
        # cumulative_sizes holds running totals of the sub-dataset lengths,
        # so bisect_right finds the first sub-dataset whose running total
        # exceeds idx (skipping any empty sub-datasets).
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Subtract the number of samples in all preceding sub-datasets.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx

    def get_img_info(self, idx):
        """
        Return the image metadata for global index ``idx``.

        The call is forwarded to ``get_img_info`` of the sub-dataset that owns
        ``idx``, with the index translated to that sub-dataset's local index.
        The structure of the returned value is whatever the sub-dataset
        returns (presumably image height/width — depends on the dataset).

        Args:
            idx (int): index into the concatenated dataset; negative values
                count from the end.

        Raises:
            ValueError, IndexError: see ``get_idxs``.
        """
        dataset_idx, sample_idx = self.get_idxs(idx)
        return self.datasets[dataset_idx].get_img_info(sample_idx)