|
import math |
|
|
|
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler |
|
|
|
|
|
def identity(x):
    """Return ``x`` unchanged; used as the default ``sort_key`` of the samplers in this module."""
    return x
|
|
|
|
|
class SortedSampler(Sampler):
    """
    Samples the indices of ``data`` in ascending ``sort_key`` order, always in the same order.

    The ordering is computed once, at construction time. Elements whose keys compare equal
    keep their original relative order because the sort is stable.

    Args:
        data (`obj`: `Iterable`):
            Iterable data. It is iterated exactly once, during construction.
        sort_key (`obj`: `Callable`):
            Specifies a function of one argument that is used to
            extract a numerical comparison key from each list element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

    """

    def __init__(self, data, sort_key=identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # Compute every key once, then sort the positions by their key.
        keys = [self.sort_key(row) for row in self.data]
        self.sorted_indexes = sorted(range(len(keys)), key=keys.__getitem__)

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        # Count the precomputed indices rather than ``self.data``: ``data`` is only
        # required to be iterable and may not implement ``__len__``.
        return len(self.sorted_indexes)
|
|
|
|
|
class BucketBatchSampler(BatchSampler):
    """
    `BucketBatchSampler` toggles between `sampler` batches and sorted batches.

    Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between
    random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted and vice
    versa.

    Indices are drawn from `sampler` in buckets of `batch_size * bucket_size_multiplier`
    items. Each bucket is sorted with `sort_key`, split into batches of `batch_size`, and
    the batches of that bucket are then yielded in random order.

    Background:
        ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries like
        ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples with a similar
        size length to reduce the padding required for each batch while maintaining some noise
        through bucketing.

        **AllenNLP Implementation:**
        https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py

        **torchtext Implementation:**
        https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225

    Args:
        sampler (`obj`: `torch.utils.data.sampler.Sampler`):
            Sampler that supplies the items to be bucketed and batched.
        batch_size (`int`):
            Size of mini-batch.
        drop_last (`bool`, optional, defaults to `False`):
            If `True` the sampler will drop the last batch if its size would be less than `batch_size`.
        sort_key (`obj`: `Callable`, optional, defaults to `identity`):
            Callable to specify a comparison key for sorting.
        bucket_size_multiplier (`int`, optional, defaults to `100`):
            Buckets are of size `batch_size * bucket_size_multiplier`.

    Example:
        >>> from torchnlp.random import set_seed
        >>> set_seed(123)
        >>>
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> sampler = SequentialSampler(list(range(10)))
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False))
        [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]]
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    """

    def __init__(
        self,
        sampler,
        batch_size,
        drop_last: bool = False,
        sort_key=identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        _bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            # Never make a bucket larger than the sampler itself. The length is floored
            # at 1 so that an empty sampler does not produce a bucket size of 0, which
            # `BatchSampler` rejects; an empty sampler then simply yields no batches.
            _bucket_size = min(_bucket_size, max(len(sampler), 1))
        # Buckets are never dropped here; `drop_last` is applied per batch in `__iter__`.
        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            # Positions within this bucket, ordered by `sort_key`.
            sorted_sampler = SortedSampler(bucket, self.sort_key)
            batches = list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))
            # Shuffle the order of the sorted batches within the bucket.
            for batch in SubsetRandomSampler(batches):
                # Map bucket-local positions back to the items drawn from `sampler`.
                yield [bucket[i] for i in batch]

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)
|
|