Spaces:
Runtime error
Runtime error
File size: 789 Bytes
128757a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import bisect
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class ConcatDataset(_ConcatDataset):
    """
    Same as torch.utils.data.dataset.ConcatDataset, but exposes extra
    methods for locating which underlying dataset a global index belongs
    to and for forwarding ``get_img_info`` queries to that dataset.
    """

    def get_idxs(self, idx):
        """
        Translate a global index into a (dataset, sample) index pair.

        Arguments:
            idx (int): position in the concatenated dataset. Negative
                values count from the end, as in
                torch.utils.data.ConcatDataset.__getitem__.

        Returns:
            tuple (dataset_idx, sample_idx): the position of the owning
            dataset in ``self.datasets`` and the index local to it. For
            ``idx >= len(self)`` the returned ``dataset_idx`` equals
            ``len(self.cumulative_sizes)``, i.e. it is past the last dataset.

        Raises:
            ValueError: if ``idx`` is negative and its absolute value
                exceeds the total length.
        """
        if idx < 0:
            total = len(self)
            if -idx > total:
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            # Without this wrap, a negative idx would always bisect to
            # dataset 0 and be passed through unchanged as sample_idx.
            idx += total
        # cumulative_sizes[k] is the total length of datasets 0..k, so the
        # first entry strictly greater than idx marks the owning dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            return dataset_idx, idx
        return dataset_idx, idx - self.cumulative_sizes[dataset_idx - 1]

    def get_img_info(self, idx):
        """
        Return ``get_img_info(sample_idx)`` of the dataset that owns ``idx``.

        The underlying datasets are expected to implement ``get_img_info``;
        its return value is passed through untouched.
        """
        dataset_idx, sample_idx = self.get_idxs(idx)
        return self.datasets[dataset_idx].get_img_info(sample_idx)
|