import hashlib
import logging
import math

import numpy as np

from fairseq.data import SampledMultiDataset

from .sampled_multi_dataset import CollateFormat, default_virtual_size_func


logger = logging.getLogger(__name__)


class SampledMultiEpochDataset(SampledMultiDataset):
"""Samples from multiple sub-datasets according to sampling ratios |
|
using virtual epoch sizes to speed up dataloading. |
|
Args: |
|
datasets ( |
|
List[~torch.utils.data.Dataset] |
|
or OrderedDict[str, ~torch.utils.data.Dataset] |
|
): datasets |
|
sampling_ratios (List[float]): list of probability of each dataset to be sampled |
|
(default: None, which corresponds to concating all dataset together). |
|
seed (int): RNG seed to use (default: 2). |
|
epoch (int): starting epoch number (default: 1). |
|
eval_key (str, optional): a key used at evaluation time that causes |
|
this instance to pass-through batches from *datasets[eval_key]*. |
|
collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or |
|
CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures |
|
the collater to output batches of data mixed from all sub-datasets, |
|
and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys |
|
of sub-datasets. |
|
Note that not all sub-datasets will present in a single batch in both formats. |
|
virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func). |
|
split (str): the split of the data, e.g. 'train', 'valid' or 'test'. |
|
virtual_epoch_size (int): virtual epoch size, the dataset will go through the data by |
|
this virtual epoch size one by one to speed up data loading, e.g. indicing and filtering |
|
can be performed whenever a virtual epoch is loaded without waiting for the whole dataset to be loaded. |
|
shared_collater (bool): whether or not to all sub-datasets have the same collater. |
|
shard_epoch (int): the real epoch number for shard selection. |
|
shuffle (bool): whether or not to shuffle data (default: True). |
|
""" |
|
|
|
def __init__( |
|
self, |
|
datasets, |
|
sampling_ratios=None, |
|
seed=2, |
|
epoch=1, |
|
eval_key=None, |
|
collate_format=CollateFormat.single, |
|
virtual_size=default_virtual_size_func, |
|
split="", |
|
virtual_epoch_size=None, |
|
shared_collater=False, |
|
shard_epoch=1, |
|
shuffle=True, |
|
): |
|
self.virtual_epoch_size = virtual_epoch_size |
|
self._current_epoch_start_index = None |
|
self._random_global_indices = None |
|
self.shard_epoch = shard_epoch if shard_epoch is not None else 1 |
|
self.load_next_shard = None |
|
self._epoch_sizes = None |
|
super().__init__( |
|
datasets=datasets, |
|
sampling_ratios=sampling_ratios, |
|
seed=seed, |
|
epoch=epoch, |
|
eval_key=eval_key, |
|
collate_format=collate_format, |
|
virtual_size=virtual_size, |
|
split=split, |
|
shared_collater=shared_collater, |
|
shuffle=shuffle, |
|
) |
|
|
|
def _setup(self, epoch): |
|
self.virtual_epoch_size = ( |
|
self.virtual_epoch_size |
|
if self.virtual_epoch_size is not None |
|
else self.virtual_size |
|
) |
|
if self.virtual_epoch_size > self.virtual_size: |
|
logger.warning( |
|
f"virtual epoch size {self.virtual_epoch_size} " |
|
f"is greater than virtual dataset size {self.virtual_size}" |
|
) |
|
self.virtual_epoch_size = self.virtual_size |
|
self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size) |
|
self._current_epoch_start_index = self._get_epoch_start_index(epoch) |
|
logger.info( |
|
f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}" |
|
) |
|
|
|
def _map_epoch_index_to_global(self, index): |
|
index = self._current_epoch_start_index + index |
|
|
|
return self._random_global_indices[index] |
|
|
|
@property |
|
def sizes(self): |
|
if self._epoch_sizes is not None: |
|
return self._epoch_sizes |
|
_sizes = super().sizes |
|
indices = self._random_global_indices[ |
|
self._current_epoch_start_index : self._current_epoch_start_index |
|
+ len(self) |
|
] |
|
self._epoch_sizes = _sizes[indices] |
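        # free the parent class's full-size cache; only the per-epoch
        # view computed above is kept around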
        del self._sizes
        self._sizes = None
        return self._epoch_sizes

    def _get_dataset_and_index(self, index):
        i = self._map_epoch_index_to_global(index)
        return super()._get_dataset_and_index(i)

    def __len__(self):
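        # the last virtual epoch may be shorter than virtual_epoch_size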
        return (
            self.virtual_epoch_size
            if self._current_epoch_start_index + self.virtual_epoch_size
            < self.virtual_size
            else self.virtual_size - self._current_epoch_start_index
        )

    def set_epoch(self, epoch):
        if self._current_epoch_start_index is None:
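            # first call: initialize the virtual epoch indices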
            self._setup(epoch)
            self._next_virtual_epoch(epoch)
        else:
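            # epoch indices have already been initialized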
            if epoch == self._cur_epoch:
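                # re-entered with the same epoch; nothing to do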
                return
            self._next_virtual_epoch(epoch)

    def _get_epoch_start_index(self, epoch):
        assert epoch >= 1  # fairseq epochs are 1-based
        return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size

    def _next_global_indices(self, epoch):
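        # seed from the class name, the global seed and the epoch so that the
        # permutation below is deterministic for a given (seed, epoch) pair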
        rng = np.random.RandomState(
            [
                int(
                    hashlib.sha1(
                        str(self.__class__.__name__).encode("utf-8")
                    ).hexdigest(),
                    16,
                )
                % (2 ** 32),
                self.seed % (2 ** 32),
                epoch,
            ]
        )
        del self._random_global_indices
        self._random_global_indices = rng.choice(
            self.virtual_size, self.virtual_size, replace=False
        )
        if self.load_next_shard is None:
            self.load_next_shard = False
        else:
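            # indices were regenerated; advance the shard epoch so the next
            # load_dataset call loads the next shard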
            self.shard_epoch += 1
            self.load_next_shard = True
            logger.info(
                "to load next epoch/shard in next load_dataset: "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )

    def _next_virtual_epoch(self, epoch):
        index = self._get_epoch_start_index(epoch)
        if index == 0 or self._random_global_indices is None:
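            # starting over from the beginning of the virtual dataset:
            # let the parent reshuffle, then draw a fresh global permutation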
            logger.info(
                "establishing a new set of global virtual indices for "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )
            super().set_epoch(epoch)
            self._next_global_indices(epoch)
        else:
            self._cur_epoch = epoch
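        # invalidate the cached per-epoch sizes after moving to a new virtual epoch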
        self._clean_if_not_none(
            [
                self._epoch_sizes,
            ]
        )
        self._epoch_sizes = None
        self._current_epoch_start_index = index
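
# A minimal usage sketch (illustrative only): `d1` and `d2` stand in for any
# torch-style datasets, and the ratios and virtual_epoch_size are made-up
# values. Construction mirrors SampledMultiDataset, plus the virtual epoch
# arguments documented above:
#
#     from collections import OrderedDict
#
#     dataset = SampledMultiEpochDataset(
#         OrderedDict([("en-de", d1), ("en-fr", d2)]),
#         sampling_ratios=[0.5, 0.5],
#         seed=2,
#         epoch=1,
#         split="train",
#         virtual_epoch_size=100000,
#     )
#     dataset.set_epoch(1)  # materializes the first virtual epoch
#     n = len(dataset)      # size of the current virtual epoch
#     sample = dataset[0]   # indexed within the current virtual epoch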