"""Original sampling logic of MQTTS. | |
Copyright PolyAI Limited. | |
""" | |
import math
import random

import numpy as np
from torch.utils import data


def StandardSampler(dataset, shuffle, distributed=False,
                    world_size=None, rank=None):
    """Return a per-sample sampler: distributed, random, or sequential."""
    if distributed:
        return data.distributed.DistributedSampler(
            dataset, shuffle=shuffle, num_replicas=world_size, rank=rank)
    if shuffle:
        return data.RandomSampler(dataset)
    return data.SequentialSampler(dataset)
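
# Usage sketch, added for illustration (not from the original repo): a sampler
# returned by StandardSampler yields individual indices, so it is typically
# passed to a DataLoader via the `sampler` argument, e.g.
#   data.DataLoader(dataset, batch_size=16,
#                   sampler=StandardSampler(dataset, shuffle=True))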


def RandomBucketSampler(
        nbuckets, length, batch_size, drop_last, distributed=False,
        world_size=None, rank=None):
    """Return a length-bucketed batch sampler, distributed or single-process."""
    if distributed:
        return DistributedRandomBucketSampler(
            nbuckets, length, batch_size, drop_last, world_size, rank)
    return SingleRandomBucketSampler(nbuckets, length, batch_size, drop_last)
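
# In the distributed branch, `world_size` and `rank` would typically come from
# an initialized process group, e.g. torch.distributed.get_world_size() and
# torch.distributed.get_rank() (assumed wiring; callers may also pass these
# values explicitly, as the original training scripts presumably do).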


class SingleRandomBucketSampler(data.Sampler):
    def __init__(self, nbuckets, length, batch_size, drop_last):
        self.length = length
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Sort indices by length (longest first) and split them into
        # `nbuckets` buckets of items with similar length.
        indices = np.argsort([-x for x in length])
        split = len(indices) // nbuckets
        self.indices = []
        for i in range(nbuckets):
            self.indices.append(indices[i*split:(i+1)*split])
        if nbuckets * split < len(length):
            self.indices.append(indices[nbuckets*split:])

    def __iter__(self):
        # Shuffle the bucket order and the items inside each bucket.
        random.shuffle(self.indices)
        for x in self.indices:
            random.shuffle(x)
        idxs = [i for x in self.indices for i in x]
        # Greedy batching: `batch_size` is a budget on the padded batch size,
        # i.e. the longest item in the batch times the number of items.
        batches, batch, sum_len, max_len = [], [], 0, 0
        for idx in idxs:
            batch.append(idx)
            sum_len += self.length[idx]
            max_len = max(self.length[idx], max_len)
            if max_len * len(batch) > self.batch_size:
                batches.append(batch[:-1])
                batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx]  # noqa
        if len(batch) > 0 and not self.drop_last:
            batches.append(batch)
        random.shuffle(batches)
        return iter(batches)


class DistributedRandomBucketSampler(data.Sampler):
    def __init__(self, nbuckets, length, batch_size,
                 drop_last, num_replicas, rank, seed=1234):
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        indices = np.argsort(length)
        split = len(indices) // nbuckets
        self.length = length
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.indices = []
        for i in range(nbuckets):
            self.indices.append(indices[i*split:(i+1)*split])
        if nbuckets * split < len(length):
            self.indices.append(indices[nbuckets*split:])
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed

    def __iter__(self):
        # Deterministic shuffling: every rank derives the same order from
        # (seed, epoch), so the per-rank split below stays consistent.
        random.Random(self.epoch + self.seed).shuffle(self.indices)
        for i, x in enumerate(self.indices):
            seed = self.epoch + self.seed + i * 5
            random.Random(seed).shuffle(x)
        indices = [i for x in self.indices for i in x]
        # Batching
        batches, batch, sum_len, max_len = [], [], 0, 0
        for idx in indices:
            batch.append(idx)
            sum_len += self.length[idx]
            max_len = max(self.length[idx], max_len)
            if max_len * len(batch) > self.batch_size:
                batches.append(batch[:-1])
                batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx]  # noqa
        # Subsample: keep an equal number of batches per replica and take
        # this rank's contiguous slice.
        num_samples = math.ceil(
            (len(batches) - self.num_replicas) / self.num_replicas)
        total_size = num_samples * self.num_replicas
        batches = batches[:total_size]
        batches = batches[self.rank*num_samples: (self.rank+1)*num_samples]
        assert len(batches) == num_samples
        # Stochastic shuffling
        random.shuffle(batches)
        return iter(batches)

    def set_epoch(self, epoch):
        self.epoch = epoch
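

if __name__ == "__main__":
    # Minimal usage sketch, added for illustration only (not part of the
    # original MQTTS file): exercise the single-process bucket sampler on a
    # toy dataset of variable-length items. The dataset, lengths and batch
    # budget below are all made up.
    toy_lengths = [random.randint(50, 500) for _ in range(1000)]
    toy_dataset = list(range(len(toy_lengths)))  # stands in for a real Dataset

    bucket_sampler = RandomBucketSampler(
        nbuckets=10, length=toy_lengths, batch_size=4000, drop_last=False)

    # The bucket samplers yield lists of indices, so they plug into DataLoader
    # through `batch_sampler` (assumed wiring; the real training code may wire
    # things differently). StandardSampler yields single indices and would go
    # through `sampler` instead.
    loader = data.DataLoader(toy_dataset, batch_sampler=bucket_sampler)
    for batch in loader:
        # Each kept batch roughly respects
        # max(length in batch) * len(batch) <= batch_size.
        pass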