Spaces:
Runtime error
Runtime error
import time | |
import os | |
import random | |
import numpy as np | |
import torch | |
import torch.utils.data | |
import commons | |
from mel_processing import spectrogram_torch, spec_to_mel_torch | |
from utils import load_wav_to_torch, load_filepaths_and_text, transform | |
# import h5py | |
"""Multi speaker version""" | |
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """Multi-speaker training dataset.

    Each entry of the file list names an audio file. For every item this loads:

    * the waveform, scaled by ``1 / max_wav_value``;
    * a linear spectrogram, read from ``<audio stem>.spec.pt`` when that cache
      file exists, otherwise computed with ``spectrogram_torch`` and saved there;
    * content features from ``<audio path>.soft.pt``;
    * an f0 track from ``<audio path>.f0.npy``;
    * the speaker id, looked up in ``hparams.spk`` by the name of the
      directory that contains the audio file.

    The features are trimmed to a common frame count, tiled when shorter than
    ``hparams.train.max_speclen`` frames, and a random window of exactly that
    many frames is returned.
    """

    def __init__(self, audiopaths, hparams):
        """
        Args:
            audiopaths: file-list path passed to ``load_filepaths_and_text``;
                the first field of each entry is used as the audio path.
            hparams: config object; reads ``data.*``, ``train.use_sr``,
                ``train.max_speclen`` and ``spk``.
        """
        self.audiopaths = load_filepaths_and_text(audiopaths)
        self.max_wav_value = hparams.data.max_wav_value
        self.sampling_rate = hparams.data.sampling_rate
        self.filter_length = hparams.data.filter_length
        self.hop_length = hparams.data.hop_length
        self.win_length = hparams.data.win_length
        self.use_sr = hparams.train.use_sr  # not read inside this class; kept as a public attribute
        self.spec_len = hparams.train.max_speclen
        self.spk_map = hparams.spk
        # Reproducible file order. NOTE: this seeds the *global* ``random``
        # module, whose state the random crop in ``get_audio`` continues from.
        random.seed(1234)
        random.shuffle(self.audiopaths)

    @staticmethod
    def _spec_path(filename):
        """Return the spectrogram cache path for ``filename``.

        Only the extension is swapped for ``.spec.pt``. A plain
        ``str.replace(".wav", ...)`` would also rewrite directory names that
        contain ".wav", and would return the audio path itself for non-.wav files.
        """
        return os.path.splitext(filename)[0] + ".spec.pt"

    @staticmethod
    def _speaker_name(filename):
        """Return the name of the directory containing ``filename``.

        ``os.path`` understands every separator valid on the current platform,
        so "/"-separated file lists also work on Windows.
        """
        return os.path.basename(os.path.dirname(filename))

    def _load_spec(self, filename, audio_norm):
        """Load the cached spectrogram for ``filename``, or compute and cache it."""
        spec_filename = self._spec_path(filename)
        if os.path.exists(spec_filename):
            return torch.load(spec_filename)
        spec = spectrogram_torch(audio_norm, self.filter_length,
                                 self.sampling_rate, self.hop_length, self.win_length,
                                 center=False)
        spec = torch.squeeze(spec, 0)
        torch.save(spec, spec_filename)
        return spec

    def _random_crop(self, spec, c, f0, audio_norm):
        """Tile the aligned features to at least ``spec_len`` frames, then cut a random window.

        ``spec`` and ``c`` are sliced along their last axis, ``f0`` along its only
        axis, and ``audio_norm`` in samples (``hop_length`` samples per frame).
        Exactly one ``random.randint`` call is made.

        Raises:
            ValueError: if the clip has 0 frames (tiling could never reach ``spec_len``).
        """
        n_frames = spec.size(-1)
        if n_frames == 0:
            raise ValueError(
                "cannot crop a clip with 0 frames to {} frames".format(self.spec_len))
        # Smallest repeat count (at least 1) with n_frames * repeats >= spec_len;
        # -(-a // b) is ceil(a / b) for positive b.
        repeats = max(1, -(-self.spec_len // n_frames))
        if repeats > 1:
            spec = torch.cat([spec] * repeats, -1)
            c = torch.cat([c] * repeats, -1)
            f0 = torch.cat([f0] * repeats, -1)
            audio_norm = torch.cat([audio_norm] * repeats, -1)
        start = random.randint(0, spec.size(-1) - self.spec_len)
        end = start + self.spec_len
        return (spec[:, start:end], c[:, start:end], f0[start:end],
                audio_norm[:, start * self.hop_length:end * self.hop_length])

    def get_audio(self, filename):
        """Load, align and randomly crop all features of one utterance.

        Returns:
            ``(c, f0, spec, audio_norm, spk)``. ``c`` and ``spec`` keep their
            leading axis and have ``spec_len`` frames. ``f0`` has ``spec_len``
            values. ``audio_norm`` has a leading axis of 1 and at most
            ``spec_len * hop_length`` samples. ``spk`` is a 1-element LongTensor.

        Raises:
            ValueError: on a sample-rate mismatch, when the feature lengths
                differ by 4 or more frames, or when the aligned clip is empty.
            KeyError: if the speaker directory is not in ``hparams.spk``.
        """
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))
        audio_norm = (audio / self.max_wav_value).unsqueeze(0)
        spec = self._load_spec(filename, audio_norm)

        spk = torch.LongTensor([self.spk_map[self._speaker_name(filename)]])

        # squeeze(0) presumably drops a leading batch axis of size 1 -- confirm
        # against the feature extractor. The 3x upsampling along dim 1 assumes
        # the content features have one frame per 3 spectrogram frames.
        c = torch.load(filename + ".soft.pt").squeeze(0)
        c = torch.repeat_interleave(c, repeats=3, dim=1)
        f0 = torch.FloatTensor(np.load(filename + ".f0.npy"))

        n_c, n_spec, n_f0 = c.size(-1), spec.size(-1), f0.shape[0]
        lmin = min(n_c, n_spec, n_f0)
        # Up to 3 frames of slack between the feature streams are tolerated.
        if (abs(n_c - n_spec) >= 4 or abs(lmin - n_spec) >= 4
                or abs(lmin - n_c) >= 4):
            raise ValueError(
                "feature lengths out of sync for {}: content={}, spec={}, f0={}".format(
                    filename, n_c, n_spec, n_f0))
        if lmin == 0:
            raise ValueError("{} has no usable frames".format(filename))

        spec, c, f0 = spec[:, :lmin], c[:, :lmin], f0[:lmin]
        audio_norm = audio_norm[:, :lmin * self.hop_length]
        spec, c, f0, audio_norm = self._random_crop(spec, c, f0, audio_norm)
        return c, f0, spec, audio_norm, spk

    def __getitem__(self, index):
        return self.get_audio(self.audiopaths[index][0])

    def __len__(self):
        return len(self.audiopaths)
class EvalDataLoader(torch.utils.data.Dataset):
    """Multi-speaker evaluation dataset.

    Loads the same per-utterance features as the training dataset:

    * the waveform, scaled by ``1 / max_wav_value``;
    * the spectrogram, cached as ``<audio stem>.spec.pt``;
    * content features from ``<audio path>.soft.pt``;
    * f0 from ``<audio path>.f0.npy``;
    * the speaker id.

    Each utterance is returned whole, trimmed to a common frame count but not
    cropped. Only the first ``max_items`` entries of the file list are kept.
    """

    def __init__(self, audiopaths, hparams, max_items=5):
        """
        Args:
            audiopaths: file-list path passed to ``load_filepaths_and_text``;
                the first field of each entry is used as the audio path.
            hparams: config object; reads ``data.*``, ``train.use_sr`` and ``spk``.
            max_items: number of leading entries to keep (default 5, the value
                previously hard-coded); ``None`` keeps the whole list.
        """
        self.audiopaths = load_filepaths_and_text(audiopaths)[:max_items]
        self.max_wav_value = hparams.data.max_wav_value
        self.sampling_rate = hparams.data.sampling_rate
        self.filter_length = hparams.data.filter_length
        self.hop_length = hparams.data.hop_length
        self.win_length = hparams.data.win_length
        self.use_sr = hparams.train.use_sr  # not read inside this class; kept as a public attribute
        self.spk_map = hparams.spk

    @staticmethod
    def _spec_path(filename):
        """Return the spectrogram cache path for ``filename``.

        Only the extension is swapped for ``.spec.pt``. A plain
        ``str.replace(".wav", ...)`` would also rewrite directory names that
        contain ".wav", and would return the audio path itself for non-.wav files.
        """
        return os.path.splitext(filename)[0] + ".spec.pt"

    @staticmethod
    def _speaker_name(filename):
        """Return the name of the directory containing ``filename`` (portable across separators)."""
        return os.path.basename(os.path.dirname(filename))

    def _load_spec(self, filename, audio_norm):
        """Load the cached spectrogram for ``filename``, or compute and cache it."""
        spec_filename = self._spec_path(filename)
        if os.path.exists(spec_filename):
            return torch.load(spec_filename)
        spec = spectrogram_torch(audio_norm, self.filter_length,
                                 self.sampling_rate, self.hop_length, self.win_length,
                                 center=False)
        spec = torch.squeeze(spec, 0)
        torch.save(spec, spec_filename)
        return spec

    def get_audio(self, filename):
        """Load and align all features of one utterance without cropping.

        Returns:
            ``(c, f0, spec, audio_norm, spk)``. ``c``, ``f0`` and ``spec`` are
            trimmed to the shortest of their frame counts. ``audio_norm`` is cut
            to at most ``frames * hop_length`` samples. ``spk`` is a 1-element
            LongTensor.

        Raises:
            ValueError: on a sample-rate mismatch, or when the content or f0
                length differs from the spectrogram length by 4 or more frames.
            KeyError: if the speaker directory is not in ``hparams.spk``.
        """
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))
        audio_norm = (audio / self.max_wav_value).unsqueeze(0)
        spec = self._load_spec(filename, audio_norm)

        spk = torch.LongTensor([self.spk_map[self._speaker_name(filename)]])

        # squeeze(0) presumably drops a leading batch axis of size 1; the 3x
        # upsampling along dim 1 assumes a 3:1 frame ratio to the spectrogram.
        c = torch.load(filename + ".soft.pt").squeeze(0)
        c = torch.repeat_interleave(c, repeats=3, dim=1)
        f0 = torch.FloatTensor(np.load(filename + ".f0.npy"))

        n_c, n_spec, n_f0 = c.size(-1), spec.size(-1), f0.shape[0]
        # Up to 3 frames of slack relative to the spectrogram are tolerated.
        if abs(n_c - n_spec) >= 4 or abs(n_f0 - n_spec) >= 4:
            raise ValueError(
                "feature lengths out of sync for {}: content={}, spec={}, f0={}".format(
                    filename, n_c, n_spec, n_f0))
        lmin = min(n_c, n_spec, n_f0)
        spec, c, f0 = spec[:, :lmin], c[:, :lmin], f0[:lmin]
        audio_norm = audio_norm[:, :lmin * self.hop_length]
        return c, f0, spec, audio_norm, spk

    def __getitem__(self, index):
        return self.get_audio(self.audiopaths[index][0])

    def __len__(self):
        return len(self.audiopaths)