File size: 4,030 Bytes
2f044c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import math
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
def identity(x):
    """Return ``x`` unchanged.

    Used as the default ``sort_key`` of ``SortedSampler`` and
    ``BucketBatchSampler`` so that elements are compared by their own value.
    """
    unchanged = x
    return unchanged
class SortedSampler(Sampler):
    """
    Yields the positions of ``data``'s elements in ascending ``sort_key`` order.

    The order is computed once, at construction, so every iteration yields the
    same sequence. Elements with equal keys keep their original relative order
    because Python's sort is stable.

    Args:
        data (`obj`: `Iterable`):
            Iterable data; ``len(data)`` is used as the sampler length.
        sort_key (`obj`: `Callable`):
            Function of one argument used to extract a comparison key from
            each element of ``data``.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self, data, sort_key=identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # Evaluate the key exactly once per element, in iteration order.
        keys = [self.sort_key(element) for element in self.data]
        # Order positions by their key alone; ties stay in original order.
        self.sorted_indexes = sorted(range(len(keys)), key=keys.__getitem__)

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)
class BucketBatchSampler(BatchSampler):
    """
    `BucketBatchSampler` toggles between `sampler` batches and sorted batches.

    Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between
    random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted and vice
    versa.

    Background:
        ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries like
        ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples with a similar
        size length to reduce the padding required for each batch while maintaining some noise
        through bucketing.

        **AllenNLP Implementation:**
        https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py

        **torchtext Implementation:**
        https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225

    Algorithm:
        The elements yielded by ``sampler`` are grouped, in sampler order, into buckets of
        ``batch_size * bucket_size_multiplier`` elements (capped at ``len(sampler)`` when the
        sampler has a length). Each bucket is sorted by ``sort_key``, cut into batches of
        ``batch_size``, and the batches of that bucket are yielded in random order.

    Args:
        sampler (`obj`: `torch.utils.data.sampler.Sampler`):
            Source of the elements (typically dataset indices) to batch.
        batch_size (`int`):
            Size of mini-batch.
        drop_last (`bool`, optional, defaults to `False`):
            If `True`, a batch smaller than `batch_size` is dropped instead of yielded.
        sort_key (`obj`: `Callable`, optional, defaults to `identity`):
            Callable to specify a comparison key for sorting. It is called with each element
            yielded by ``sampler`` (e.g. an index), not with the underlying dataset item.
        bucket_size_multiplier (`int`, optional, defaults to `100`):
            Buckets are of size `batch_size * bucket_size_multiplier`.

    Example:
        >>> from torchnlp.random import set_seed
        >>> set_seed(123)
        >>>
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> sampler = SequentialSampler(list(range(10)))
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False))
        [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]]
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(
        self,
        sampler,
        batch_size,
        drop_last: bool = False,
        sort_key=identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        _bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            # A bucket never needs to be larger than the whole sampler. Keep it
            # at least 1 so an empty sampler simply yields no batches instead of
            # making ``BatchSampler`` reject a bucket size of 0.
            _bucket_size = min(_bucket_size, max(len(sampler), 1))
        # Buckets are plain lists of consecutive sampler elements; the final
        # partial bucket is kept (drop_last=False) so no element is lost here.
        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            # ``SortedSampler`` yields positions within ``bucket``, ordered by
            # ``sort_key`` applied to the bucket's elements.
            sorted_sampler = SortedSampler(bucket, self.sort_key)
            batches = list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))
            # Shuffle the order of the batches (not their contents) so batches
            # are not always emitted in sorted order.
            for batch in SubsetRandomSampler(batches):
                # Map bucket positions back to the sampler's own elements.
                yield [bucket[i] for i in batch]

    def __len__(self):
        # Every bucket except possibly the last holds a multiple of
        # ``batch_size`` elements, so only one partial batch can exist overall
        # and the count matches a plain ``BatchSampler`` over ``sampler``.
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)
|