# Source: OFA-Image_Caption/fairseq/fairseq/data/resampling_dataset.py
# (last updated by JustinLin610 in commit 8437114, message "update").
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
from fairseq.data import BaseWrapperDataset, plasma_utils
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class ResamplingDataset(BaseWrapperDataset):
    """Randomly samples from a given dataset at each epoch.

    Sampling is done with or without replacement, depending on the "replace"
    parameter.

    Optionally, the epoch size can be rescaled. This is potentially desirable
    to increase per-epoch coverage of the base dataset (since sampling with
    replacement means that many items in the dataset will be left out). In the
    case of sampling without replacement, size_ratio must be strictly less
    than 1.

    A new sample is drawn whenever :meth:`set_epoch` is called with an epoch
    number different from the current one. The draw is deterministic given
    ``seed`` and the epoch number.

    Args:
        dataset (~torch.utils.data.Dataset): dataset on which to sample.
        weights (List[float]): list of non-negative probability weights, one
            per item of *dataset*; they are normalized to sum to 1
            (default: None, which corresponds to uniform sampling).
        replace (bool): sampling mode; True for "with replacement", or False
            for "without replacement" (default: True)
        size_ratio (float): the ratio to subsample to; must be positive
            (default: 1.0).
        batch_by_size (bool): whether or not to batch by sequence length
            (default: True).
        seed (int): RNG seed to use (default: 0).
        epoch (int): starting epoch number (default: 1).

    Raises:
        ValueError: if *weights* contains negative or non-finite values, or
            sums to zero.
    """

    def __init__(
        self,
        dataset,
        weights=None,
        replace=True,
        size_ratio=1.0,
        batch_by_size=True,
        seed=0,
        epoch=1,
    ):
        super().__init__(dataset)

        if weights is None:
            self.weights = None
        else:
            assert len(weights) == len(dataset), (
                "weights must have exactly one entry per dataset item"
            )
            # np.array copies, so normalizing in place never mutates the
            # caller's weights.
            weights_arr = np.array(weights, dtype=np.float64)
            # Validate up front: otherwise numpy only fails later inside
            # rng.choice, e.g. with "probabilities contain NaN" after a
            # division by a zero sum.
            if not np.all(np.isfinite(weights_arr)) or np.any(weights_arr < 0):
                raise ValueError("weights must be finite and non-negative")
            total = weights_arr.sum()
            if total <= 0:
                raise ValueError("weights must not sum to zero")
            weights_arr /= total
            self.weights = plasma_utils.PlasmaArray(weights_arr)

        self.replace = replace

        assert size_ratio > 0.0, "size_ratio must be positive"
        if not self.replace:
            # Without replacement the per-epoch sample is required to be a
            # strict subset of the dataset.
            assert size_ratio < 1.0, (
                "size_ratio must be < 1.0 when sampling without replacement"
            )
        self.size_ratio = float(size_ratio)
        # Plain Python int (np.ceil returns a float scalar).
        self.actual_size = int(np.ceil(len(dataset) * self.size_ratio))

        self.batch_by_size = batch_by_size
        self.seed = seed

        self._cur_epoch = None
        self._cur_indices = None

        self.set_epoch(epoch)

    def __getitem__(self, index):
        # Map the position in the current sample to the underlying item.
        return self.dataset[self._cur_indices.array[index]]

    def __len__(self):
        return self.actual_size

    @property
    def sizes(self):
        """Sizes of the currently sampled items, in sampled order.

        Mirrors the wrapped dataset's ``sizes``: if it is a list of arrays,
        each array is indexed separately; otherwise it is indexed directly.
        """
        base_sizes = self.dataset.sizes
        indices = self._cur_indices.array
        if isinstance(base_sizes, list):
            return [s[indices] for s in base_sizes]
        return base_sizes[indices]

    def num_tokens(self, index):
        return self.dataset.num_tokens(self._cur_indices.array[index])

    def size(self, index):
        return self.dataset.size(self._cur_indices.array[index])

    def ordered_indices(self):
        """Return an ordering of the positions ``0 .. len(self) - 1``.

        With ``batch_by_size`` the positions are sorted by item size.
        Otherwise they are returned in sampled order.

        When the wrapped dataset reports several sizes per item (a list of
        arrays, or a 2-D array with one column per size), the first size is
        the primary sort key and the following ones break ties in order.
        """
        positions = np.arange(len(self))
        if not self.batch_by_size:
            return positions

        sizes = self.sizes
        if isinstance(sizes, list):
            size_keys = [np.asarray(s) for s in sizes]
        else:
            sizes = np.asarray(sizes)
            # A 1-D array is a single key (same call as before); a 2-D array
            # contributes one key per column.
            size_keys = [sizes] if sizes.ndim == 1 else list(sizes.T)

        # np.lexsort treats the LAST key as primary, so the size keys are
        # reversed and the position key goes first (lowest priority). Ties on
        # size fall back to sampled position; rng.choice already returns the
        # sample in random order, so no separate shuffle is needed.
        return np.lexsort([positions] + size_keys[::-1])

    def prefetch(self, indices):
        # Translate sampled positions to underlying indices before prefetching.
        self.dataset.prefetch(self._cur_indices.array[indices])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # The sample is redrawn per epoch, so iterators cannot be reused.
        return False

    def set_epoch(self, epoch):
        """Propagate *epoch* to the wrapped dataset and resample if it changed."""
        logger.debug("ResamplingDataset.set_epoch: %s", epoch)
        super().set_epoch(epoch)

        if epoch == self._cur_epoch:
            return

        self._cur_epoch = epoch

        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
        rng = np.random.RandomState(
            [
                42,  # magic number
                self.seed % (2 ** 32),  # global seed
                self._cur_epoch,  # epoch index
            ]
        )
        self._cur_indices = plasma_utils.PlasmaArray(
            rng.choice(
                len(self.dataset),
                self.actual_size,
                replace=self.replace,
                p=(None if self.weights is None else self.weights.array),
            )
        )