import math

from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler


def identity(x):
    # Default sort key: use the element itself as its own comparison key.
    return x
class SortedSampler(Sampler):
    """
    Samples elements sequentially, always in the same order.

    Args:
        data (`obj`: `Iterable`):
            Iterable data.
        sort_key (`obj`: `Callable`):
            Specifies a function of one argument that is used to extract a
            numerical comparison key from each list element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self, data, sort_key=identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # Precompute the element order: pair each index with its sort key,
        # sort by the key, and keep only the reordered indexes.
        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
        zip_ = sorted(zip_, key=lambda r: r[1])
        self.sorted_indexes = [item[0] for item in zip_]

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)
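

def _sorted_sampler_length_example():
    """
    Hedged usage sketch (not part of the original module): ``SortedSampler`` is
    typically paired with ``sort_key=len`` so that variable-length examples are
    visited from shortest to longest. The tokenized dataset below is made up
    purely for illustration.

    >>> _sorted_sampler_length_example()
    [1, 0, 2]
    """
    tokenized = [["a", "b"], ["c"], ["d", "e", "f"]]
    return list(SortedSampler(tokenized, sort_key=len))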


class BucketBatchSampler(BatchSampler):
    """
    `BucketBatchSampler` toggles between `sampler` batches and sorted batches.

    Typically, the `sampler` will be a `RandomSampler`, allowing the user to toggle between
    random batches and sorted batches. A larger `bucket_size_multiplier` produces batches that
    are closer to fully sorted; a smaller one keeps batches closer to the original `sampler`
    order.

    Background:
        ``BucketBatchSampler`` is similar to the ``BucketIterator`` found in popular libraries
        such as ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples of
        similar length to reduce the padding required for each batch while maintaining some
        noise through bucketing. An end-to-end usage sketch with a padding ``DataLoader``
        appears at the bottom of this module.

        **AllenNLP Implementation:**
        https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py

        **torchtext Implementation:**
        https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225

    Args:
        sampler (`obj`: `torch.utils.data.sampler.Sampler`):
            Sampler that yields the elements (typically dataset indices) to bucket and batch.
        batch_size (`int`):
            Size of mini-batch.
        drop_last (`bool`, optional, defaults to `False`):
            If `True`, the sampler will drop the last batch if its size would be less than
            `batch_size`.
        sort_key (`obj`: `Callable`, optional, defaults to `identity`):
            Callable that maps each element yielded by `sampler` to the comparison key used
            for sorting within a bucket.
        bucket_size_multiplier (`int`, optional, defaults to `100`):
            Buckets are of size `batch_size * bucket_size_multiplier`.

    Example:
        >>> from torchnlp.random import set_seed
        >>> set_seed(123)
        >>>
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> sampler = SequentialSampler(list(range(10)))
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False))
        [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]]
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(
        self,
        sampler,
        batch_size,
        drop_last: bool = False,
        sort_key=identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        # Each bucket holds `bucket_size_multiplier` batches worth of samples, capped at
        # the total number of samples when the sampler exposes a length.
        _bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            _bucket_size = min(_bucket_size, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            # Sort the elements inside each bucket by the sort key, split them into
            # mini-batches, and yield those mini-batches in random order, so batches
            # are length-homogeneous while the batch order stays shuffled.
            sorted_sampler = SortedSampler(bucket, self.sort_key)
            for batch in SubsetRandomSampler(
                list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))
            ):
                yield [bucket[i] for i in batch]

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)
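

if __name__ == "__main__":
    # Hedged usage sketch (not part of the original module): one way to wire
    # `BucketBatchSampler` into a `DataLoader` via its `batch_sampler` argument so
    # that each batch groups sequences of similar length and needs little padding.
    # The toy dataset, the `sort_key` lambda, and the collate function are all
    # made up for illustration.
    import torch
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import RandomSampler

    # Variable-length "token id" sequences with lengths 1 through 16.
    dataset = [torch.arange(length) for length in range(1, 17)]

    batch_sampler = BucketBatchSampler(
        RandomSampler(dataset),
        batch_size=4,
        drop_last=False,
        # The sampler yields dataset indices, so the sort key maps an index to
        # the length of the corresponding sequence.
        sort_key=lambda index: len(dataset[index]),
        bucket_size_multiplier=2,
    )

    loader = DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        # Pad each batch only up to the longest sequence inside that batch.
        collate_fn=lambda batch: pad_sequence(batch, batch_first=True),
    )

    for padded_batch in loader:
        # Because sequences in a batch have similar lengths, the padded width
        # stays close to the true sequence lengths.
        print(padded_batch.shape)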