import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from utils.commons.dataset_utils import BaseDataset, collate_1d, collate_2d
from utils.commons.hparams import hparams
from utils.commons.indexed_datasets import IndexedDataset
class EndlessDistributedSampler(DistributedSampler):
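    """DistributedSampler variant that pre-builds ~1000 epochs' worth of indices up front, so iteration effectively never ends."""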
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.shuffle = shuffle
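        # Seed with the epoch so every rank draws the same permutation;
        # the rank-based slicing below then yields disjoint, equal-sized shards.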
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = [i for _ in range(1000) for i in torch.randperm(
len(self.dataset), generator=g).tolist()]
else:
indices = [i for _ in range(1000) for i in list(range(len(self.dataset)))]
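        # Truncate to a multiple of num_replicas, then take this rank's interleaved shard.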
indices = indices[:len(indices) // self.num_replicas * self.num_replicas]
indices = indices[self.rank::self.num_replicas]
self.indices = indices
def __iter__(self):
return iter(self.indices)
def __len__(self):
return len(self.indices)
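
# A minimal usage sketch (hypothetical values; assumes torch.distributed has
# already been initialized, e.g. via dist.init_process_group):
#
#   sampler = EndlessDistributedSampler(dataset, shuffle=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)
#   for batch in loader:  # runs through ~1000 epochs' worth of indices
#       ...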
class VocoderDataset(BaseDataset):
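    """Binarized vocoder dataset yielding mel, waveform, pitch and f0 per item; items no longer than batch_max_frames are filtered out at init."""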
def __init__(self, prefix, shuffle=False):
super().__init__(shuffle)
self.hparams = hparams
self.prefix = prefix
self.data_dir = hparams['binary_data_dir']
self.is_infer = prefix == 'test'
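        # Max mel frames per training segment (max_samples audio samples / hop_size);
        # 0 at inference time, i.e. the collater keeps (nearly) the whole clip.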
self.batch_max_frames = 0 if self.is_infer else hparams['max_samples'] // hparams['hop_size']
self.hop_size = hparams['hop_size']
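        # Opened lazily in _get_item so each DataLoader worker gets its own handle.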
self.indexed_ds = None
self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
self.avail_idxs = [idx for idx, s in enumerate(self.sizes) if s > self.batch_max_frames]
print(f"| {len(self.sizes) - len(self.avail_idxs)} short items are skipped in {prefix} set.")
self.sizes = [s for idx, s in enumerate(self.sizes) if s > self.batch_max_frames]
def _get_item(self, index):
if self.indexed_ds is None:
self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
item = self.indexed_ds[index]
return item
def __getitem__(self, index):
index = self.avail_idxs[index]
item = self._get_item(index)
sample = {
"id": index,
"item_name": item['item_name'],
"mel": torch.FloatTensor(item['mel']),
"wav": torch.FloatTensor(item['wav'].astype(np.float32)),
"pitch": torch.LongTensor(item['pitch']),
"f0": torch.FloatTensor(item['f0'])
}
return sample
def collater(self, batch):
if len(batch) == 0:
return {}
y_batch, c_batch, p_batch, f0_batch = [], [], [], []
item_name = []
for idx in range(len(batch)):
item_name.append(batch[idx]['item_name'])
x, c = batch[idx]['wav'], batch[idx]['mel']
p, f0 = batch[idx]['pitch'], batch[idx]['f0']
self._assert_ready_for_upsampling(x, c, self.hop_size)
if len(c) > self.batch_max_frames:
                # Randomly pick a segment of batch_max_steps audio samples
                # (batch_max_frames mel frames); at inference time batch_max_frames
                # is 0, so fall back to the full clip minus one frame.
batch_max_frames = self.batch_max_frames if self.batch_max_frames != 0 else len(c) - 1
batch_max_steps = batch_max_frames * self.hop_size
interval_start = 0
interval_end = len(c) - batch_max_frames
start_frame = np.random.randint(interval_start, interval_end)
start_step = start_frame * self.hop_size
y = x[start_step: start_step + batch_max_steps]
c = c[start_frame: start_frame + batch_max_frames]
p = p[start_frame: start_frame + batch_max_frames]
f0 = f0[start_frame: start_frame + batch_max_frames]
self._assert_ready_for_upsampling(y, c, self.hop_size)
else:
print(f"Removed short sample from batch (length={len(x)}).")
continue
y_batch += [y.reshape(-1, 1)] # [(T, 1), (T, 1), ...]
            c_batch += [c]  # [(T', C), (T', C), ...]
            p_batch += [p]  # [(T',), (T',), ...]
            f0_batch += [f0]  # [(T',), (T',), ...]
        # convert each list to a batch tensor; items in a batch are assumed to have the same length
y_batch = collate_2d(y_batch, 0).transpose(2, 1) # (B, 1, T)
c_batch = collate_2d(c_batch, 0).transpose(2, 1) # (B, C, T')
p_batch = collate_1d(p_batch, 0) # (B, T')
f0_batch = collate_1d(f0_batch, 0) # (B, T')
# make input noise signal batch tensor
z_batch = torch.randn(y_batch.size()) # (B, 1, T)
return {
'z': z_batch,
'mels': c_batch,
'wavs': y_batch,
'pitches': p_batch,
'f0': f0_batch,
'item_name': item_name
}
@staticmethod
def _assert_ready_for_upsampling(x, c, hop_size):
"""Assert the audio and feature lengths are correctly adjusted for upsamping."""
assert len(x) == (len(c)) * hop_size
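
# A rough end-to-end sketch (hypothetical; assumes hparams has been populated
# with 'binary_data_dir', 'max_samples' and 'hop_size', that the binarized
# train files exist under binary_data_dir, and that torch.distributed is
# initialized if the sampler is used):
#
#   train_ds = VocoderDataset('train', shuffle=True)
#   loader = torch.utils.data.DataLoader(
#       train_ds, batch_size=4, collate_fn=train_ds.collater,
#       sampler=EndlessDistributedSampler(train_ds))
#   batch = next(iter(loader))
#   # batch['wavs']: (B, 1, T), batch['mels']: (B, C, T'), batch['z']: (B, 1, T)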