|
import pandas as pd
|
|
import os
|
|
import random
|
|
import ast
|
|
import numpy as np
|
|
import torch
|
|
from einops import repeat, rearrange
|
|
import librosa
|
|
|
|
from torch.utils.data import Dataset
|
|
import torchaudio
|
|
|
|
|
|
def log_f0(f0, f0_min=65.40639132514966, scales=4):
    """Convert raw F0 values (Hz) to a normalized log-pitch representation.

    Frames below ``f0_min`` are treated as unvoiced and mapped to 0. Voiced
    frames are converted to semitones above ``f0_min`` plus 1 (the offset
    keeps voiced values strictly above the unvoiced sentinel 0), then the
    result is divided by ``scales`` octaves' worth of semitones.

    Args:
        f0: Tensor of F0 values in Hz (any shape). Not modified in place.
        f0_min: Unvoiced threshold / log reference in Hz. The default is
            ``librosa.note_to_hz('C2')`` written out as a constant.
        scales: Number of octaves spanned by the normalization range.

    Returns:
        Tensor with the same shape as ``f0`` holding normalized log-F0.
    """
    # Work on a copy: the previous implementation zeroed sub-threshold
    # entries directly in the caller's tensor as a side effect.
    f0 = f0.clone()
    f0[f0 < f0_min] = 0.0
    f0_log = torch.zeros_like(f0)
    voiced = f0 != 0
    # 12 semitones per octave; torch.log2 avoids the numpy round-trip
    # the original np.log2 call performed on a torch tensor.
    f0_log[voiced] = 12 * torch.log2(f0[voiced] / f0_min) + 1
    f0_log /= (scales * 12)
    return f0_log
|
|
|
|
|
|
class VCData(Dataset):
    """Voice-conversion dataset backed by precomputed features on disk.

    Each item returns a tuple
    ``(audio_clip, content_clip, f0_clip, speaker_clip, prompt)``:

    - ``audio_clip``: fixed-length waveform segment at ``sr`` Hz.
    - ``content_clip``: matching content-feature segment at ``content_sr``
      frames per second.
    - ``f0_clip``: normalized log-F0 per content frame (or ``None`` if the
      row has no f0 path).
    - ``speaker_clip``: a reference waveform at ``speaker_sr`` Hz drawn from
      a random utterance of the same speaker.
    - ``prompt``: a random text prompt for that speaker.

    The metadata CSV is expected to provide at least the columns used below:
    ``subset``, ``audio_path``, ``content_path``, ``f0_path``, ``speaker``,
    ``speaker_val`` and ``speaker_path``; the prompt CSV provides ``ID`` and
    ``prompts`` (a stringified Python list).

    NOTE(review): paths are joined by plain string concatenation, so
    ``data_dir`` must end with a path separator (the demo at the bottom of
    the file passes ``'../../features/'``).
    """

    def __init__(self,
                 data_dir, meta_dir, subset, prompt_dir,
                 seg_length=1.92, speaker_length=4,
                 sr=24000, content_sr=50, speaker_sr=16000,
                 plugin_mode=False
                 ):
        """
        Args:
            data_dir: Root directory prepended to every relative path in the
                metadata CSV (must end with a separator, see class note).
            meta_dir: Path to the metadata CSV; rows are filtered to
                ``subset``.
            subset: Split name, e.g. ``'train'`` or ``'val'``; also selects
                which speaker column is used in ``__getitem__``.
            prompt_dir: Path to the CSV mapping speaker ``ID`` to a
                stringified list of text ``prompts``.
            seg_length: Length of the audio/content segment in seconds.
            speaker_length: Length of the speaker reference clip in seconds.
            sr: Sample rate of the main audio files.
            content_sr: Frame rate (frames/second) of the content features.
            speaker_sr: Sample rate of the speaker reference audio files.
            plugin_mode: If True, skip loading audio/content/f0 and return
                placeholder values for them (speaker clip and prompt are
                still produced).
        """
        self.datadir = data_dir
        meta = pd.read_csv(meta_dir)
        # Keep only the rows belonging to the requested split.
        self.meta = meta[meta['subset'] == subset]
        self.subset = subset
        self.prompts = pd.read_csv(prompt_dir)
        self.seg_len = seg_length
        self.speaker_length = speaker_length
        self.sr = sr
        self.content_sr = content_sr
        self.speaker_sr = speaker_sr
        self.plugin_mode = plugin_mode

    def get_audio_content(self, audio_path, content_path, f0_path):
        """Load one utterance and crop aligned audio/content/f0 segments.

        A random window of ``seg_len`` seconds is chosen in content-frame
        units; the same window (converted to samples) is cut from the
        waveform. Short utterances are padded: content by repeating its last
        frame, audio and f0 with zeros.

        Args:
            audio_path: Waveform path relative to ``self.datadir``.
            content_path: Content-feature tensor path relative to
                ``self.datadir``. Indexed as ``content[:, t, :]`` below, so
                dim 1 is time at ``content_sr`` fps — presumably shape
                (batch/1, frames, channels); TODO confirm against the
                feature-extraction code.
            f0_path: Optional 1-D F0 tensor path (one value per content
                frame); falsy value skips f0 loading.

        Returns:
            Tuple ``(audio_clip, content_clip, f0_clip)`` where
            ``audio_clip`` has ``seg_len * sr`` samples, ``content_clip`` is
            the first batch entry of the cropped features, and ``f0_clip``
            is a (frames, 1) normalized log-F0 tensor or ``None``.
        """
        audio_path = self.datadir + audio_path
        audio, sr = torchaudio.load(audio_path)
        assert sr == self.sr

        content = torch.load(self.datadir + content_path)

        # Pick a random crop start (in content frames). If the utterance is
        # shorter than one segment, start at 0 and pad below.
        total_length = content.shape[1]
        if int(total_length - int(self.content_sr * self.seg_len)) > 0:
            start = np.random.randint(0, int(total_length - self.content_sr * self.seg_len) + 1)
        else:
            start = 0
        end = min(start + int(self.seg_len * self.content_sr), content.shape[1])

        # Pre-fill the clip by repeating the LAST content frame across time,
        # then overwrite the valid region — i.e. short utterances are padded
        # with their final frame rather than zeros.
        content_clip = repeat(content[:, -1, :], "b c-> b t c", t=int(self.content_sr * self.seg_len)).clone()
        content_clip[:, :end - start, :] = content[:, start: end, :]

        # Audio, by contrast, is zero-padded to the fixed segment length.
        audio_clip = torch.zeros(int(self.seg_len * self.sr))

        # Convert the content-frame window to sample indices.
        audio_start = round(start * self.sr / self.content_sr)
        audio_end = round(end * self.sr / self.content_sr)

        audio_clip[:audio_end - audio_start] = audio[0, audio_start: audio_end].clone()

        if f0_path:
            f0 = torch.load(self.datadir + f0_path).float()
            # Same window as the content features; zero-padded like audio.
            f0_clip = torch.zeros(int(self.content_sr * self.seg_len))
            f0_clip[:end-start] = f0[start:end]
            f0_clip = log_f0(f0_clip)
            # Add a trailing channel dim: (frames,) -> (frames, 1).
            f0_clip = f0_clip.unsqueeze(-1)
        else:
            f0_clip = None

        return audio_clip, content_clip[0], f0_clip

    def get_speaker(self, speaker_path):
        """Load a speaker reference waveform and crop/pad it to a fixed size.

        Args:
            speaker_path: Waveform path relative to ``self.datadir``;
                expected to be sampled at ``self.speaker_sr``.

        Returns:
            A 1-D tensor of ``speaker_length * speaker_sr`` samples: a random
            crop of the (first channel of the) file, zero-padded at the end
            if the file is shorter than the requested length.
        """
        audio_path = self.datadir + speaker_path
        audio, sr = torchaudio.load(audio_path)
        assert sr == self.speaker_sr

        audio_clip = torch.zeros(self.speaker_length * self.speaker_sr)

        # Random crop start in samples; 0 (then zero-pad) for short files.
        total_length = audio.shape[1]
        if int(total_length - self.speaker_sr * self.speaker_length) > 0:
            start = np.random.randint(0, int(total_length - self.speaker_sr * self.speaker_length) + 1)
        else:
            start = 0
        end = min(start + self.speaker_sr * self.speaker_length, total_length)

        audio_clip[:end-start] = audio[0, start: end]

        return audio_clip

    def __getitem__(self, index):
        """Assemble one training/validation example.

        Returns:
            ``(audio_clip, content_clip, f0_clip, speaker_clip, prompt)``.
            In ``plugin_mode`` the first three are placeholder ``['']``
            lists (only speaker clip and prompt are real).
        """
        row = self.meta.iloc[index]

        if self.plugin_mode:
            # Placeholders keep the tuple structure without touching disk.
            audio_clip, content_clip, f0_clip = [''], [''], ['']
        else:
            audio_path = row['audio_path']
            content_path = row['content_path']
            f0_path = row['f0_path']
            audio_clip, content_clip, f0_clip = self.get_audio_content(audio_path, content_path, f0_path)

        # Validation rows carry their reference speaker in a separate column
        # — presumably so val pairs are fixed; confirm against the metadata
        # generation script.
        if self.subset == 'train':
            speaker = row['speaker']
        else:
            speaker = row['speaker_val']

        # Reference clip comes from a RANDOM utterance of the same speaker,
        # not necessarily the current row.
        speaker_row = self.meta[self.meta['speaker'] == speaker].sample(1)
        speaker_path = speaker_row.iloc[0]['speaker_path']
        speaker_clip = self.get_speaker(speaker_path)

        # 'prompts' is stored as a stringified Python list; parse it safely
        # and pick one prompt at random.
        prompts = self.prompts[self.prompts['ID'] == speaker]['prompts'].iloc[0]
        prompts = ast.literal_eval(prompts)
        prompt = random.choice(prompts)

        return audio_clip, content_clip, f0_clip, speaker_clip, prompt

    def __len__(self):
        # One example per metadata row of the selected subset.
        return len(self.meta)
|
|
|
|
|
|
if __name__ == '__main__':
    from tqdm import tqdm

    # Smoke test: walk the whole validation split once so broken paths or
    # shape mismatches surface immediately.
    dataset = VCData('../../features/', '../../data/meta_val.csv', 'val', '../../data/speaker_gender.csv')
    for idx in tqdm(range(len(dataset))):
        dataset[idx]
|