# coding: utf-8
import random
import logging

import numpy as np
import librosa  # used if you load real audio in __getitem__
import torch
import torchaudio
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# fix seeds for reproducibility
np.random.seed(114514)
random.seed(114514)
SPECT_PARAMS = {
    "n_fft": 2048,
    "win_length": 1200,
    "hop_length": 300,
}
MEL_PARAMS = {
    "n_mels": 80,
}
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=MEL_PARAMS["n_mels"], **SPECT_PARAMS)
mean, std = -4, 4
def preprocess(wave):
    # wave: 1-D float numpy array in [-1, 1]
    # returns a normalized log-mel tensor of shape (n_mels, frames),
    # which is the layout collate() expects
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor) - mean) / std
    return mel_tensor
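# Example (for reference): with hop_length=300 and torchaudio's default
# center=True, a 1 s clip at 24 kHz maps to 1 + 24000 // 300 = 81 frames,
# so preprocess(np.zeros(24000, dtype=np.float32)) has shape (80, 81).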
class PseudoDataset(torch.utils.data.Dataset):
    def __init__(self,
                 list_path=None,
                 sr=24000,
                 duration_range=(1, 30),  # audio duration in seconds
                 ):
        self.data_list = []  # read your list_path here
        self.sr = sr
        self.duration_range = duration_range

    def __len__(self):
        # return len(self.data_list)
        return 100  # return a fixed number for testing

    def __getitem__(self, idx):
        # replace this with your own data loading, e.g.:
        # wave, sr = librosa.load(self.data_list[idx], sr=self.sr)
        wave = np.random.randn(self.sr * random.randint(*self.duration_range))
        wave = np.clip(wave, -1, 1).astype(np.float32)
        mel = preprocess(wave)
        return torch.from_numpy(wave), mel  # collate expects tensors
def collate(batch):
    # each item in batch is a (wave, mel) pair
    batch_size = len(batch)

    # sort by mel length, longest first
    lengths = [b[1].shape[1] for b in batch]
    batch_indexes = np.argsort(lengths)[::-1]
    batch = [batch[bid] for bid in batch_indexes]

    nmels = batch[0][1].size(0)
    max_mel_length = max([b[1].shape[1] for b in batch])
    max_wave_length = max([b[0].size(0) for b in batch])

    # pad mels with -10 (well below the normalized silence level)
    mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
    waves = torch.zeros((batch_size, max_wave_length)).float()
    mel_lengths = torch.zeros(batch_size).long()
    wave_lengths = torch.zeros(batch_size).long()

    for bid, (wave, mel) in enumerate(batch):
        mel_size = mel.size(1)
        mels[bid, :, :mel_size] = mel
        waves[bid, :wave.size(0)] = wave
        mel_lengths[bid] = mel_size
        wave_lengths[bid] = wave.size(0)

    return waves, mels, wave_lengths, mel_lengths
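# Example (for reference): two items with mel lengths 81 and 161 collate into
# mels of shape (2, 80, 161), sorted longest first; the shorter mel is padded
# with -10 and its wave is zero-padded on the right.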
def build_dataloader(
        rank=0,
        world_size=1,
        batch_size=32,
        num_workers=0,
        prefetch_factor=16,
):
    dataset = PseudoDataset()  # replace this with your own dataset
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,  # shuffling is handled by the sampler, not the DataLoader
        seed=114514,
    )
    loader_kwargs = dict(
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        drop_last=True,
        collate_fn=collate,
        pin_memory=True,
    )
    if num_workers > 0:
        # prefetch_factor is only valid when worker processes are used
        loader_kwargs["prefetch_factor"] = prefetch_factor
    return DataLoader(dataset, **loader_kwargs)
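if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline).
    # A single-process run needs no torch.distributed initialization because
    # num_replicas and rank are passed to DistributedSampler explicitly.
    loader = build_dataloader(rank=0, world_size=1, batch_size=4)
    waves, mels, wave_lengths, mel_lengths = next(iter(loader))
    print("waves:", tuple(waves.shape))    # (4, max_wave_length)
    print("mels:", tuple(mels.shape))      # (4, 80, max_mel_length)
    print("wave_lengths:", wave_lengths.tolist())
    print("mel_lengths:", mel_lengths.tolist())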