|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch.nn.functional as F |
|
from fairseq.data import BaseWrapperDataset |
|
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes |
|
|
|
|
|
class BucketPadLengthDataset(BaseWrapperDataset): |
|
""" |
|
Bucket and pad item lengths to the nearest bucket size. This can be used to |
|
reduce the number of unique batch shapes, which is important on TPUs since |
|
each new batch shape requires a recompilation. |
|
|
|
Args: |
|
dataset (FairseqDatset): dataset to bucket |
|
sizes (List[int]): all item sizes |
|
num_buckets (int): number of buckets to create |
|
pad_idx (int): padding symbol |
|
left_pad (bool): if True, pad on the left; otherwise right pad |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dataset, |
|
sizes, |
|
num_buckets, |
|
pad_idx, |
|
left_pad, |
|
tensor_key=None, |
|
): |
|
super().__init__(dataset) |
|
self.pad_idx = pad_idx |
|
self.left_pad = left_pad |
|
|
|
assert num_buckets > 0 |
|
self.buckets = get_buckets(sizes, num_buckets) |
|
self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) |
|
self._tensor_key = tensor_key |
|
|
|
def _set_tensor(self, item, val): |
|
if self._tensor_key is None: |
|
return val |
|
item[self._tensor_key] = val |
|
return item |
|
|
|
def _get_tensor(self, item): |
|
if self._tensor_key is None: |
|
return item |
|
return item[self._tensor_key] |
|
|
|
def _pad(self, tensor, bucket_size, dim=-1): |
|
num_pad = bucket_size - tensor.size(dim) |
|
return F.pad( |
|
tensor, |
|
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), |
|
value=self.pad_idx, |
|
) |
|
|
|
def __getitem__(self, index): |
|
item = self.dataset[index] |
|
bucket_size = self._bucketed_sizes[index] |
|
tensor = self._get_tensor(item) |
|
padded = self._pad(tensor, bucket_size) |
|
return self._set_tensor(item, padded) |
|
|
|
@property |
|
def sizes(self): |
|
return self._bucketed_sizes |
|
|
|
def num_tokens(self, index): |
|
return self._bucketed_sizes[index] |
|
|
|
def size(self, index): |
|
return self._bucketed_sizes[index] |
|
|