# freevc plugin — VC dataset utilities
# Source: Higobeatz (commit 0dabde8)
import pandas as pd
import os
import random
import ast
import numpy as np
import torch
from einops import repeat, rearrange
import librosa
from torch.utils.data import Dataset
import torchaudio
def log_f0(f0, f0_min=None, scales=4):
    """Normalize an F0 contour onto a log (semitone) scale in roughly [0, 1].

    Frames below ``f0_min`` (including unvoiced zeros) map to 0; voiced
    frames map to ``(12 * log2(f0 / f0_min) + 1) / (scales * 12)``, i.e.
    semitones above ``f0_min`` compressed into ``scales`` octaves.

    Args:
        f0: Tensor of F0 values in Hz.
        f0_min: Lowest valid F0 in Hz. Defaults to the frequency of note C2.
            Resolved lazily so the function can be defined (and called with
            an explicit value) without librosa being importable.
        scales: Number of octaves the normalized range spans.

    Returns:
        A new tensor with the same shape as ``f0``. Unlike the previous
        implementation, the input tensor is NOT modified in place.
    """
    if f0_min is None:
        f0_min = librosa.note_to_hz('C2')
    # Voiced mask: frames at or above the threshold. Everything else
    # (the frames the old code zeroed in the caller's tensor) stays 0.
    voiced = f0 >= f0_min
    f0_log = torch.zeros_like(f0)
    f0_log[voiced] = (12 * torch.log2(f0[voiced] / f0_min) + 1) / (scales * 12)
    return f0_log
class VCData(Dataset):
    """Voice-conversion dataset.

    Each item pairs a fixed-length audio segment with its aligned content
    features and (optionally) normalized log-F0, plus a reference clip of
    the target speaker and one randomly chosen text prompt for that speaker.

    Args:
        data_dir: Root directory prepended to every relative path in the CSVs.
        meta_dir: Path to the metadata CSV; rows are filtered to those whose
            ``subset`` column equals ``subset``.
        subset: Split name, e.g. ``'train'`` or ``'val'``.
        prompt_dir: Path to a CSV with columns ``ID`` and ``prompts`` (the
            latter a Python-literal list of strings).
        seg_length: Clip length in seconds for audio/content/f0.
        speaker_length: Reference speaker clip length in seconds.
        sr: Expected sample rate of the main audio files.
        content_sr: Frame rate (Hz) of the content features.
        speaker_sr: Expected sample rate of the speaker reference files.
        plugin_mode: If True, skip loading audio/content/f0 and return
            placeholder ``['']`` values for them.
    """

    def __init__(self,
                 data_dir, meta_dir, subset, prompt_dir,
                 seg_length=1.92, speaker_length=4,
                 sr=24000, content_sr=50, speaker_sr=16000,
                 plugin_mode=False
                 ):
        self.datadir = data_dir
        meta = pd.read_csv(meta_dir)
        self.meta = meta[meta['subset'] == subset]
        self.subset = subset
        self.prompts = pd.read_csv(prompt_dir)
        self.seg_len = seg_length
        self.speaker_length = speaker_length
        self.sr = sr
        self.content_sr = content_sr
        self.speaker_sr = speaker_sr
        self.plugin_mode = plugin_mode

    def get_audio_content(self, audio_path, content_path, f0_path):
        """Load one aligned (audio, content, f0) clip of ``seg_len`` seconds.

        A random window is chosen when the utterance is longer than the
        segment; shorter utterances are padded (content with its last frame,
        audio and f0 with zeros).

        Returns:
            audio_clip: 1-D tensor with ``int(seg_len * sr)`` samples.
            content_clip: (T, C) tensor, T = int(seg_len * content_sr).
            f0_clip: (T, 1) normalized log-F0 tensor, or None when the row
                has no usable ``f0_path``.
        """
        audio_path = self.datadir + audio_path
        audio, sr = torchaudio.load(audio_path)
        assert sr == self.sr
        # content has shape (1, T, C)
        content = torch.load(self.datadir + content_path)
        total_length = content.shape[1]
        seg_frames = int(self.content_sr * self.seg_len)
        if total_length - seg_frames > 0:
            # NOTE: keep the original float expression inside int() so the
            # truncation behavior matches the previous implementation exactly.
            start = np.random.randint(0, int(total_length - self.content_sr * self.seg_len) + 1)
        else:
            start = 0
        end = min(start + seg_frames, total_length)
        # Pad with the last content frame so short clips carry a valid value
        # instead of zeros.
        content_clip = repeat(content[:, -1, :], "b c-> b t c", t=seg_frames).clone()
        content_clip[:, :end - start, :] = content[:, start: end, :]
        audio_clip = torch.zeros(int(self.seg_len * self.sr))
        audio_start = round(start * self.sr / self.content_sr)
        audio_end = round(end * self.sr / self.content_sr)
        audio_clip[:audio_end - audio_start] = audio[0, audio_start: audio_end].clone()
        # Guard against NaN from pandas for rows without an f0 file:
        # float('nan') is truthy, so a bare `if f0_path:` would try to load
        # an invalid path and crash.
        if isinstance(f0_path, str) and f0_path:
            f0 = torch.load(self.datadir + f0_path).float()
            f0_clip = torch.zeros(seg_frames)
            f0_clip[:end - start] = f0[start:end]
            f0_clip = log_f0(f0_clip)
            f0_clip = f0_clip.unsqueeze(-1)
        else:
            f0_clip = None
        return audio_clip, content_clip[0], f0_clip

    def get_speaker(self, speaker_path):
        """Load a zero-padded reference clip of ``speaker_length`` seconds.

        The file must already be at ``speaker_sr``; a random window is taken
        when the file is longer than the target length.
        """
        audio_path = self.datadir + speaker_path
        audio, sr = torchaudio.load(audio_path)
        assert sr == self.speaker_sr
        audio_clip = torch.zeros(int(self.speaker_length * self.speaker_sr))
        total_length = audio.shape[1]
        if int(total_length - self.speaker_sr * self.speaker_length) > 0:
            start = np.random.randint(0, int(total_length - self.speaker_sr * self.speaker_length) + 1)
        else:
            start = 0
        end = min(start + self.speaker_sr * self.speaker_length, total_length)
        audio_clip[:end - start] = audio[0, start: end]
        return audio_clip

    def __getitem__(self, index):
        """Return (audio_clip, content_clip, f0_clip, speaker_clip, prompt).

        In ``plugin_mode`` the first three elements are placeholder ``['']``
        lists; otherwise they are tensors loaded via ``get_audio_content``.
        """
        row = self.meta.iloc[index]
        if self.plugin_mode:
            audio_clip, content_clip, f0_clip = [''], [''], ['']
        else:
            audio_clip, content_clip, f0_clip = self.get_audio_content(
                row['audio_path'], row['content_path'], row['f0_path'])
        # Pick the reference speaker: the row's own speaker for training,
        # a pre-assigned validation speaker otherwise. The reference clip is
        # drawn from a random row of that speaker.
        if self.subset == 'train':
            speaker = row['speaker']
        else:
            speaker = row['speaker_val']
        speaker_row = self.meta[self.meta['speaker'] == speaker].sample(1)
        speaker_path = speaker_row.iloc[0]['speaker_path']
        speaker_clip = self.get_speaker(speaker_path)
        # Random text prompt for the chosen speaker; the CSV cell holds a
        # Python-literal list, e.g. "['a deep male voice', ...]".
        prompts = self.prompts[self.prompts['ID'] == speaker]['prompts'].iloc[0]
        prompts = ast.literal_eval(prompts)
        prompt = random.choice(prompts)
        return audio_clip, content_clip, f0_clip, speaker_clip, prompt

    def __len__(self):
        """Number of rows in the selected subset."""
        return len(self.meta)
if __name__ == '__main__':
    from tqdm import tqdm

    # Smoke-test: load the validation split and iterate it once end-to-end.
    dataset = VCData('../../features/', '../../data/meta_val.csv', 'val',
                     '../../data/speaker_gender.csv')
    for idx in tqdm(range(len(dataset))):
        _ = dataset[idx]