keithhon committed
Commit e7f3680
Parent: 4dd64e8

Upload vocoder/vocoder_dataset.py with huggingface_hub

Files changed (1)
  1. vocoder/vocoder_dataset.py +84 -0
vocoder/vocoder_dataset.py ADDED
@@ -0,0 +1,84 @@
from torch.utils.data import Dataset
from pathlib import Path
from vocoder import audio
import vocoder.hparams as hp
import numpy as np
import torch


class VocoderDataset(Dataset):
    def __init__(self, metadata_fpath: Path, mel_dir: Path, wav_dir: Path):
        print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, wav_dir))

        with metadata_fpath.open("r") as metadata_file:
            metadata = [line.split("|") for line in metadata_file]

        # Filter out entries whose metadata field 4 (frame count) is zero
        gta_fnames = [x[1] for x in metadata if int(x[4])]
        gta_fpaths = [mel_dir.joinpath(fname) for fname in gta_fnames]
        wav_fnames = [x[0] for x in metadata if int(x[4])]
        wav_fpaths = [wav_dir.joinpath(fname) for fname in wav_fnames]
        self.samples_fpaths = list(zip(gta_fpaths, wav_fpaths))

        print("Found %d samples" % len(self.samples_fpaths))

    def __getitem__(self, index):
        mel_path, wav_path = self.samples_fpaths[index]

        # Load the mel spectrogram and adjust its range to [-1, 1]
        mel = np.load(mel_path).T.astype(np.float32) / hp.mel_max_abs_value

        # Load the wav
        wav = np.load(wav_path)
        if hp.apply_preemphasis:
            wav = audio.pre_emphasis(wav)
        wav = np.clip(wav, -1, 1)

        # Fix for missing padding (TODO: settle on whether this is of any use)
        r_pad = (len(wav) // hp.hop_length + 1) * hp.hop_length - len(wav)
        wav = np.pad(wav, (0, r_pad), mode='constant')
        assert len(wav) >= mel.shape[1] * hp.hop_length
        wav = wav[:mel.shape[1] * hp.hop_length]
        assert len(wav) % hp.hop_length == 0

        # Quantize the wav
        if hp.voc_mode == 'RAW':
            if hp.mu_law:
                quant = audio.encode_mu_law(wav, mu=2 ** hp.bits)
            else:
                quant = audio.float_2_label(wav, bits=hp.bits)
        elif hp.voc_mode == 'MOL':
            quant = audio.float_2_label(wav, bits=16)
        else:
            raise RuntimeError("Unknown vocoder mode: %s" % hp.voc_mode)

        return mel.astype(np.float32), quant.astype(np.int64)

    def __len__(self):
        return len(self.samples_fpaths)

def collate_vocoder(batch):
    # Number of mel frames to crop: enough to cover voc_seq_len samples,
    # plus voc_pad padding frames on each side
    mel_win = hp.voc_seq_len // hp.hop_length + 2 * hp.voc_pad
    # Highest valid start frame for a random mel crop in each sample
    max_offsets = [x[0].shape[-1] - 2 - (mel_win + 2 * hp.voc_pad) for x in batch]
    mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
    # Matching start positions in the quantized waveforms
    sig_offsets = [(offset + hp.voc_pad) * hp.hop_length for offset in mel_offsets]

    mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)]

    # One extra sample so that inputs and targets can be shifted by one step
    labels = [x[1][sig_offsets[i]:sig_offsets[i] + hp.voc_seq_len + 1] for i, x in enumerate(batch)]

    mels = np.stack(mels).astype(np.float32)
    labels = np.stack(labels).astype(np.int64)

    mels = torch.tensor(mels)
    labels = torch.tensor(labels).long()

    # Autoregressive input/target pairs, offset by one sample
    x = labels[:, :hp.voc_seq_len]
    y = labels[:, 1:]

    bits = 16 if hp.voc_mode == 'MOL' else hp.bits

    x = audio.label_2_float(x.float(), bits)

    if hp.voc_mode == 'MOL':
        y = audio.label_2_float(y.float(), bits)

    return x, y, mels
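
For reference, the quantization helpers called above (encode_mu_law, float_2_label, label_2_float) live in vocoder/audio.py, which is not part of this commit. A minimal sketch of what WaveRNN-style implementations of them typically look like, given here as an assumption rather than the repository's exact code:

import numpy as np

def float_2_label(x, bits):
    # Sketch (assumed): map floats in [-1, 1] to integer labels in [0, 2**bits - 1]
    return np.clip((x + 1.0) * (2 ** bits - 1) / 2, 0, 2 ** bits - 1)

def label_2_float(x, bits):
    # Sketch (assumed): inverse mapping, labels back to floats in [-1, 1]
    return 2 * x / (2 ** bits - 1.0) - 1.0

def encode_mu_law(x, mu):
    # Sketch (assumed): mu-law companding, which spends more of the label
    # range on quiet samples before quantizing
    mu = mu - 1
    fx = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return np.floor((fx + 1) / 2 * mu + 0.5)

Note that collate_vocoder applies label_2_float to torch tensors; the version sketched above is pure elementwise arithmetic, so it works on tensors as well as numpy arrays.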
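
Likewise, a sketch of how the dataset and collate function could be wired into a training loop; the paths and batch size below are placeholders, not values taken from this commit:

from pathlib import Path
from torch.utils.data import DataLoader
from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder

# Hypothetical output locations of the GTA preprocessing step
run_dir = Path("synthesizer/saved_models/logs/ground_truth")
dataset = VocoderDataset(metadata_fpath=run_dir.joinpath("synthesized.txt"),
                         mel_dir=run_dir.joinpath("mels_gta"),
                         wav_dir=run_dir.joinpath("audio"))

loader = DataLoader(dataset,
                    batch_size=32,  # placeholder value
                    shuffle=True,
                    collate_fn=collate_vocoder)

for x, y, mels in loader:
    # x: float waveform inputs in [-1, 1], shape (batch, voc_seq_len)
    # y: next-sample targets shifted by one step (integer labels in RAW
    #    mode, floats in MOL mode), shape (batch, voc_seq_len)
    # mels: conditioning mel windows, shape (batch, n_mels, mel_win)
    break

Because collate_vocoder draws a random crop from each sample, every GTA mel must be long enough for max_offsets to stay positive, i.e. more than mel_win + 2 * hp.voc_pad + 2 frames.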