|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import numpy as np |
|
import torch.utils.data |
|
from fairseq.data import data_utils |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class EpochListening:
    """Mixin for objects that want to be notified when the epoch changes."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Whether the :class:`fairseq.data.EpochBatchIterator` built for this
        dataset may be reused from one epoch to the next.

        Return ``False`` when sample sizes can differ between epochs, since
        batches may then have to be regenerated at every epoch. Datasets that
        rely on ``set_epoch`` should consider returning ``False``.
        """
        return True

    def set_epoch(self, epoch):
        """Receive the new epoch number at the start of each epoch.

        The default implementation ignores *epoch* and does nothing.
        """
|
|
|
|
|
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching.

    Most methods raise ``NotImplementedError`` and are meant to be supplied
    by subclasses; the remaining ones provide default behavior built on top
    of them.
    """

    def __getitem__(self, index):
        """Return the sample at *index*. Must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples. Must be implemented by subclasses."""
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order.

        The default order is simply ``0 .. len(self) - 1`` as an ``int64``
        array.
        """
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        """Return the attribute named *attr* of this dataset, or ``None``.

        *index* is accepted for interface compatibility but is not used by
        this base implementation.
        """
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first dimension of each tuple is the batch size and can be ``None``
        to automatically infer the max batch size based on ``--max-tokens``.
        The second dimension of each tuple is the max supported length as given
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
        to restrict batch shapes. This is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).

        The default implementation returns ``None`` (no restriction).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices, return batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.

        When :meth:`get_batch_shapes` returns shapes, their batch sizes are
        first resolved/adjusted and passed along as ``fixed_shapes``. The
        actual batching is delegated to ``data_utils.batch_by_size`` and its
        result is returned unchanged.

        Raises:
            AssertionError: if a fixed shape leaves the batch size as ``None``
                and *max_tokens* is not given.
        """
        fixed_shapes = self.get_batch_shapes()
        if fixed_shapes is not None:

            def adjust_bsz(bsz, num_tokens):
                # An unspecified batch size is derived from the token budget.
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // num_tokens
                if max_sentences is not None:
                    # Cap by the sentence limit; note that no rounding to the
                    # required multiple happens on this path.
                    bsz = min(bsz, max_sentences)
                elif (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    # Round down to the nearest required multiple.
                    bsz -= bsz % required_batch_size_multiple
                return bsz

            fixed_shapes = np.array(
                [
                    [adjust_bsz(bsz, num_tokens), num_tokens]
                    for (bsz, num_tokens) in fixed_shapes
                ]
            )

        # The vectorized token counts are optional: subclasses that do not
        # implement ``num_tokens_vec`` fall back to the per-index function.
        try:
            num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            num_tokens_vec = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=fixed_shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Filter a list of sample indices. Remove those that are longer than
        specified in *max_sizes*.

        WARNING: do not modify this method; override it in child classes
        instead.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if isinstance(max_sizes, (float, int)):
            sizes = getattr(self, "sizes", None)
            if isinstance(sizes, list) and len(sizes) == 1:
                # A one-element list wraps the actual per-sample size array.
                sizes = sizes[0]
                vectorized = True
            else:
                vectorized = isinstance(sizes, np.ndarray)

            if vectorized:
                # Look the sizes up once and reuse them for both masks.
                selected_sizes = sizes[indices]
                ignored = indices[selected_sizes > max_sizes].tolist()
                indices = indices[selected_sizes <= max_sizes]
                return indices, ignored

        # Non-scalar limits, or no usable ``sizes`` attribute: fall back to
        # filtering through ``self.size``.
        indices, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return indices, ignored

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True
|
|
|
|
|
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    Base class for datasets that must be consumed sequentially, typically
    because the data is streamed or otherwise can't be manipulated on a
    single machine.
    """

    def __iter__(self):
        """Iterate over the samples. Must be implemented by subclasses."""
        raise NotImplementedError
|
|