# coding: utf-8
import os
import os.path as osp
import time
import random
import numpy as np
import soundfile as sf
import librosa
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader
import math

import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

from torch.utils.data.distributed import DistributedSampler

np.random.seed(114514)
random.seed(114514)
SPECT_PARAMS = {
    "n_fft": 2048,
    "win_length": 1200,
    "hop_length": 300,
}
MEL_PARAMS = {
    "n_mels": 80,
}
# NOTE: torchaudio's MelSpectrogram defaults to sample_rate=16000, so the mel
# filterbank is built for 16 kHz even though the dataset loads audio at 24 kHz;
# pass sample_rate explicitly if that mismatch is unintended.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=MEL_PARAMS['n_mels'], **SPECT_PARAMS)

# global mean/std used to normalize the log-mel spectrogram
mean, std = -4, 4

def preprocess(wave):
    # wave: 1-D float numpy array in [-1, 1]
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor  # (1, n_mels, n_frames)
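
# Example (a sketch, not in the original file): with hop_length=300, one
# second of 24 kHz audio spans 24000 / 300 = 80 hops, and torchaudio's
# centered STFT adds one boundary frame, so:
#
#   wave = np.random.randn(24000).astype(np.float32)
#   mel = preprocess(wave)   # -> torch.Size([1, 80, 81])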

class PseudoDataset(torch.utils.data.Dataset):
    def __init__(self,
                 list_path=None,  # default added so PseudoDataset() works below
                 sr=24000,
                 duration_range=(1, 30),  # min/max audio duration in seconds
                 ):
        self.data_list = []  # read your file list from list_path here
        self.sr = sr
        self.duration_range = duration_range

    def __len__(self):
        # return len(self.data_list)
        return 100  # return a fixed number for testing

    def __getitem__(self, idx):
        # replace this with your own data loading, e.g.:
        # wave, sr = librosa.load(self.data_list[idx], sr=self.sr)
        wave = np.random.randn(self.sr * random.randint(*self.duration_range))
        wave = np.clip(wave, -1, 1).astype(np.float32)  # numpy has no .clamp
        mel = preprocess(wave).squeeze(0)  # (n_mels, n_frames), as collate expects
        return torch.from_numpy(wave), mel  # collate expects a wave tensor
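
# Example (a sketch): one item is a 1-D wave tensor plus its mel, e.g.
#
#   ds = PseudoDataset()
#   wave, mel = ds[0]   # wave: (n_samples,), mel: (80, n_samples // 300 + 1)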

def collate(batch):
    # batch[i] = (wave, mel)
    batch_size = len(batch)

    # sort by mel length, longest first
    lengths = [b[1].shape[1] for b in batch]
    batch_indexes = np.argsort(lengths)[::-1]
    batch = [batch[bid] for bid in batch_indexes]

    nmels = batch[0][1].size(0)
    max_mel_length = max([b[1].shape[1] for b in batch])
    max_wave_length = max([b[0].size(0) for b in batch])

    # pad mels with -10, far below silence in normalized log-mel space
    mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
    waves = torch.zeros((batch_size, max_wave_length)).float()

    mel_lengths = torch.zeros(batch_size).long()
    wave_lengths = torch.zeros(batch_size).long()

    for bid, (wave, mel) in enumerate(batch):
        mel_size = mel.size(1)
        mels[bid, :, :mel_size] = mel
        waves[bid, :wave.size(0)] = wave
        mel_lengths[bid] = mel_size
        wave_lengths[bid] = wave.size(0)

    return waves, mels, wave_lengths, mel_lengths
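
# Example (a sketch): collate pads each item to the longest wave/mel in the
# batch and returns the true lengths alongside, e.g.
#
#   batch = [PseudoDataset()[i] for i in range(2)]
#   waves, mels, wave_lengths, mel_lengths = collate(batch)
#   # waves: (2, max_wave_length), mels: (2, 80, max_mel_length)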

def build_dataloader(
    rank=0,
    world_size=1,
    batch_size=32,
    num_workers=0,
    prefetch_factor=16,
):
    dataset = PseudoDataset()  # replace this with your own dataset
    collate_fn = collate
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=114514,
    )
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        drop_last=True,
        collate_fn=collate_fn,
        pin_memory=True,
        # prefetch_factor must be None when num_workers == 0 (PyTorch >= 2.0)
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        # shuffle must stay unset here: the sampler already shuffles
    )
    return data_loader
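
# Minimal smoke test (a sketch): DistributedSampler runs without an initialized
# process group because rank and num_replicas are passed explicitly.
if __name__ == "__main__":
    loader = build_dataloader(rank=0, world_size=1, batch_size=4)
    waves, mels, wave_lengths, mel_lengths = next(iter(loader))
    print(waves.shape)   # (4, max_wave_length)
    print(mels.shape)    # (4, 80, max_mel_length)
    print(wave_lengths, mel_lengths)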