JustinLin610
update
10b0761
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
from . import BaseWrapperDataset
logger = logging.getLogger(__name__)
class SubsampleDataset(BaseWrapperDataset):
"""Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
Args:
dataset (~torch.utils.data.Dataset): dataset to subsample
size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
"""
def __init__(self, dataset, size_ratio, shuffle=False):
super().__init__(dataset)
assert size_ratio < 1
self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
self.indices = np.random.choice(
list(range(len(self.dataset))), self.actual_size, replace=False
)
self.shuffle = shuffle
logger.info(
"subsampled dataset from {} to {} (ratio={})".format(
len(self.dataset), self.actual_size, size_ratio
)
)
def __getitem__(self, index):
return self.dataset[self.indices[index]]
def __len__(self):
return self.actual_size
def collater(self, samples):
return self.dataset.collater(samples)
@property
def sizes(self):
return self.dataset.sizes[self.indices]
@property
def name(self):
return self.dataset.name
def num_tokens(self, index):
return self.dataset.num_tokens(self.indices[index])
def size(self, index):
return self.dataset.size(self.indices[index])
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)
def prefetch(self, indices):
self.dataset.prefetch(self.indices[indices])