Upload vocoder/vocoder_dataset.py with huggingface_hub
vocoder/vocoder_dataset.py  (ADDED, +84 -0)
from torch.utils.data import Dataset
from pathlib import Path
from vocoder import audio
import vocoder.hparams as hp
import numpy as np
import torch

class VocoderDataset(Dataset):
    def __init__(self, metadata_fpath: Path, mel_dir: Path, wav_dir: Path):
        print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, wav_dir))

        with metadata_fpath.open("r") as metadata_file:
            metadata = [line.split("|") for line in metadata_file]

        # Pipe-delimited metadata: field 0 is the wav filename, field 1 the
        # GTA mel filename; a nonzero fifth field marks the entry as usable
        gta_fnames = [x[1] for x in metadata if int(x[4])]
        gta_fpaths = [mel_dir.joinpath(fname) for fname in gta_fnames]
        wav_fnames = [x[0] for x in metadata if int(x[4])]
        wav_fpaths = [wav_dir.joinpath(fname) for fname in wav_fnames]
        self.samples_fpaths = list(zip(gta_fpaths, wav_fpaths))

        print("Found %d samples" % len(self.samples_fpaths))

    def __getitem__(self, index):
        mel_path, wav_path = self.samples_fpaths[index]

        # Load the mel spectrogram and adjust its range to [-1, 1]
        mel = np.load(mel_path).T.astype(np.float32) / hp.mel_max_abs_value

        # Load the wav
        wav = np.load(wav_path)
        if hp.apply_preemphasis:
            wav = audio.pre_emphasis(wav)
        wav = np.clip(wav, -1, 1)

        # Pad the wav up to a whole number of hop lengths, then trim it to
        # exactly hop_length samples per mel frame
        # TODO: settle on whether this padding fix is useful
        r_pad = (len(wav) // hp.hop_length + 1) * hp.hop_length - len(wav)
        wav = np.pad(wav, (0, r_pad), mode='constant')
        assert len(wav) >= mel.shape[1] * hp.hop_length
        wav = wav[:mel.shape[1] * hp.hop_length]
        assert len(wav) % hp.hop_length == 0

        # Quantize the wav
        if hp.voc_mode == 'RAW':
            if hp.mu_law:
                quant = audio.encode_mu_law(wav, mu=2 ** hp.bits)
            else:
                quant = audio.float_2_label(wav, bits=hp.bits)
        elif hp.voc_mode == 'MOL':
            quant = audio.float_2_label(wav, bits=16)

        return mel.astype(np.float32), quant.astype(np.int64)

    def __len__(self):
        return len(self.samples_fpaths)
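
The quantization in __getitem__ relies on helpers from vocoder/audio.py, which this commit does not include. As a rough sketch, WaveRNN-style codebases conventionally define them along these lines (an illustrative assumption, not the file shipped in this repo):

    import numpy as np

    def float_2_label(x, bits):
        # Map floats in [-1, 1] onto labels in [0, 2**bits - 1]
        assert abs(x).max() <= 1.0
        x = (x + 1.0) * (2 ** bits - 1) / 2
        return x.clip(0, 2 ** bits - 1)

    def label_2_float(x, bits):
        # Inverse mapping: labels in [0, 2**bits - 1] back to floats in [-1, 1]
        return 2 * x / (2 ** bits - 1.0) - 1.0

    def encode_mu_law(x, mu):
        # Mu-law companding: compress the dynamic range before quantizing,
        # giving low-amplitude samples finer resolution
        mu = mu - 1
        fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
        return np.floor((fx + 1) / 2 * mu + 0.5)

The collate function below then crops aligned mel/sample windows out of each (mel, quant) pair: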

def collate_vocoder(batch):
    # Width of each mel crop, including voc_pad context frames on both sides
    mel_win = hp.voc_seq_len // hp.hop_length + 2 * hp.voc_pad
    # Latest frame at which a crop can start without running off the end
    max_offsets = [x[0].shape[-1] - 2 - (mel_win + 2 * hp.voc_pad) for x in batch]
    mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
    # Matching sample offsets into the quantized wavs
    sig_offsets = [(offset + hp.voc_pad) * hp.hop_length for offset in mel_offsets]

    mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)]

    # One extra sample so inputs and next-sample targets can be offset by one
    labels = [x[1][sig_offsets[i]:sig_offsets[i] + hp.voc_seq_len + 1] for i, x in enumerate(batch)]

    mels = np.stack(mels).astype(np.float32)
    labels = np.stack(labels).astype(np.int64)

    mels = torch.tensor(mels)
    labels = torch.tensor(labels).long()

    # Teacher forcing: the model sees samples [0, T) and predicts samples [1, T]
    x = labels[:, :hp.voc_seq_len]
    y = labels[:, 1:]

    bits = 16 if hp.voc_mode == 'MOL' else hp.bits

    x = audio.label_2_float(x.float(), bits)

    if hp.voc_mode == 'MOL':
        y = audio.label_2_float(y.float(), bits)

    return x, y, mels
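
For reference, a minimal sketch of how the dataset and collate function are typically wired into a training loop (the directory layout and the hp.voc_batch_size name are illustrative assumptions, not part of this commit):

    from pathlib import Path
    from torch.utils.data import DataLoader

    import vocoder.hparams as hp
    from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder

    # Hypothetical paths; the real ones are supplied by the training script
    voc_dir = Path("vocoder/data")
    dataset = VocoderDataset(
        metadata_fpath=voc_dir.joinpath("synthesized.txt"),
        mel_dir=voc_dir.joinpath("mels_gta"),
        wav_dir=voc_dir.joinpath("audio"),
    )
    loader = DataLoader(
        dataset,
        collate_fn=collate_vocoder,    # yields (x, y, mels) batches
        batch_size=hp.voc_batch_size,  # assumed hyperparameter name
        shuffle=True,
        pin_memory=True,
    )

    for x, y, mels in loader:
        ...  # forward the vocoder on (x, mels) and score it against targets y

Note that x always comes back as floats in [-1, 1], while y stays as integer labels in RAW mode (suiting a cross-entropy loss over quantization levels) and is converted back to floats in MOL mode.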