# Spaces:
# Sleeping
# Sleeping
import numbers

import torch
class Sampler(object):
    """Abstract base that every sampler derives from.

    A concrete sampler must override :meth:`__iter__`, which produces the
    indices of dataset elements, and :meth:`__len__`, which reports how many
    indices one such iterator yields.

    Arguments:
        data_source (Dataset): accepted so that subclasses share a common
            constructor shape; the base class neither stores nor inspects it.
    """

    def __init__(self, data_source):
        """Accept ``data_source`` and do nothing with it."""

    def __iter__(self):
        # Subclasses are required to supply the iteration order.
        raise NotImplementedError()

    def __len__(self):
        # Subclasses are required to supply the number of indices.
        raise NotImplementedError()
class SequentialSampler(Sampler):
    """Yield the indices ``0 .. len(data_source) - 1`` in ascending order.

    The order is identical on every iteration.

    Arguments:
        data_source (Dataset): dataset to sample from; only its length is used
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        # The length is re-read on each call, so a data source whose size
        # changed between epochs is picked up.
        total = len(self.data_source)
        return iter(range(total))
class RandomSampler(Sampler):
    """Samples elements randomly, without replacement.

    Each iteration draws a fresh random permutation of
    ``0 .. len(data_source) - 1`` from torch's global random number generator.

    Arguments:
        data_source (Dataset): dataset to sample from; only its length is used
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Iterating a tensor directly yields 0-dim tensors; ``tolist()``
        # converts the permutation to plain Python ints, which is what index
        # consumers (lists, dicts, datasets) expect. ``randperm`` already
        # returns int64, so no ``.long()`` cast is needed.
        return iter(torch.randperm(len(self.data_source)).tolist())

    def __len__(self):
        return len(self.data_source)
class SubsetRandomSampler(Sampler):
    """Yield the entries of a fixed index collection in random order.

    Every entry of ``indices`` is produced exactly once per iteration
    (sampling without replacement).

    Arguments:
        indices (list): a list of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        # The permutation is drawn right here, when __iter__ is called; the
        # lookups into ``self.indices`` happen lazily as the generator runs.
        shuffled_positions = torch.randperm(len(self.indices))
        return (self.indices[position] for position in shuffled_positions)
class WeightedRandomSampler(Sampler):
    """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

    Arguments:
        weights (list): a list of weights, not necessarily summing up to one.
            Any value ``torch.as_tensor`` accepts (list, tuple, tensor of any
            floating dtype, ...) is converted to a double-precision tensor.
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
    """

    def __init__(self, weights, num_samples, replacement=True):
        # ``as_tensor`` converts lists as well as tensors of other dtypes;
        # the legacy ``torch.DoubleTensor(...)`` constructor rejects the latter.
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        drawn = torch.multinomial(self.weights, self.num_samples, self.replacement)
        # Hand out plain Python ints rather than 0-dim tensors so the indices
        # can be used directly with lists, dicts and datasets.
        return iter(drawn.tolist())

    def __len__(self):
        return self.num_samples
class BatchSampler(object):
    """Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler. Any iterable of indices works for
            iteration; ``len(sampler)`` must be supported for ``len()``.
        batch_size (int): Size of mini-batch. Must be a positive integer.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Raises:
        ValueError: if ``batch_size`` is not a positive integer.

    Example:
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        # A zero batch size would make __len__ divide by zero, and a
        # non-integer size would never equal len(batch), silently producing
        # one giant batch. Fail early instead. ``bool`` is excluded
        # explicitly because it is a subclass of ``int``.
        if (
            isinstance(batch_size, bool)
            or not isinstance(batch_size, numbers.Integral)
            or batch_size <= 0
        ):
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={!r}".format(batch_size)
            )
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                # Start a new list so the batch just yielded is not mutated.
                batch = []
        # Emit the trailing partial batch unless the caller asked to drop it.
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            # Only complete batches are counted.
            return len(self.sampler) // self.batch_size
        # Ceiling division: a trailing partial batch counts as one batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size