eriquesouza's picture
app v1
e831f85
raw history blame
No virus
11.5 kB
import os
import statistics
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio
class FastSpeechDataset(Dataset):
    """
    FastSpeech 2 training dataset built on top of a cached AlignerDataset.

    Each cached datapoint is a list of eight entries. The first four are copied
    unchanged from the aligner cache: index 0 is used as the phoneme text, index 2
    as the spectrogram and index 3 as its length; the meaning of index 1 is defined
    by AlignerDataset. They are followed by the phoneme durations, energy, pitch and
    prosodic condition computed here. The list of datapoints is cached as
    ``fast_train_cache.pt`` inside ``cache_dir``.
    """

    def __init__(self,
                 path_to_transcript_dict,
                 acoustic_checkpoint_path,
                 cache_dir,
                 lang,
                 loading_processes=40,
                 min_len_in_seconds=1,
                 max_len_in_seconds=20,
                 cut_silence=False,
                 reduction_factor=1,
                 device=torch.device("cpu"),
                 rebuild_cache=False,
                 ctc_selection=True,
                 save_imgs=False):
        """
        Load the FastSpeech cache from ``cache_dir``, or build it if it is missing.

        Args:
            path_to_transcript_dict: forwarded to AlignerDataset.
            acoustic_checkpoint_path: checkpoint whose ``"asr_model"`` entry is loaded
                into the Aligner used to extract phoneme durations.
            cache_dir: directory holding ``aligner_train_cache.pt``,
                ``fast_train_cache.pt`` and the ``duration_vis`` debug folder.
            lang: forwarded to AlignerDataset and mapped to ``self.language_id``.
            loading_processes, max_len_in_seconds, cut_silence: forwarded to AlignerDataset.
            min_len_in_seconds: forwarded to AlignerDataset; when ``ctc_selection`` is
                set, waves shorter than this are also skipped here.
            reduction_factor: passed to the pitch, energy and duration calculators.
            device: device used for the aligner model, the prosodic condition
                extractor and AlignerDataset.
            rebuild_cache: rebuild the caches even if they already exist.
            ctc_selection: skip too-short waves and drop datapoints whose CTC loss is
                more than one standard deviation above the mean.
            save_imgs: save an alignment debug image per datapoint in ``duration_vis``.

        Exits the process with status 1 if building produced no datapoints.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        fast_cache_path = os.path.join(cache_dir, "fast_train_cache.pt")
        aligner_cache_path = os.path.join(cache_dir, "aligner_train_cache.pt")

        if os.path.exists(fast_cache_path) and not rebuild_cache:
            # just load the datapoints from cache
            self.datapoints = torch.load(fast_cache_path, map_location='cpu')
        else:
            # Both AlignerDataset constructions below share the same settings
            # (including the device) so a regeneration behaves like a fresh build.
            aligner_settings = dict(path_to_transcript_dict=path_to_transcript_dict,
                                    cache_dir=cache_dir,
                                    lang=lang,
                                    loading_processes=loading_processes,
                                    min_len_in_seconds=min_len_in_seconds,
                                    max_len_in_seconds=max_len_in_seconds,
                                    cut_silences=cut_silence,
                                    device=device)
            if not os.path.exists(aligner_cache_path) or rebuild_cache:
                AlignerDataset(rebuild_cache=rebuild_cache, **aligner_settings)
            datapoints = torch.load(aligner_cache_path, map_location='cpu')
            # we use the aligner dataset as basis and augment it to contain the additional information we need for fastspeech.
            if not isinstance(datapoints, tuple):  # check for backwards compatibility
                print(f"It seems like the Aligner dataset in {cache_dir} is not a tuple. Regenerating it, since we need the preprocessed waves.")
                AlignerDataset(rebuild_cache=True, **aligner_settings)
                datapoints = torch.load(aligner_cache_path, map_location='cpu')
            dataset = datapoints[0]
            norm_waves = datapoints[1]

            # build cache
            print("... building dataset cache ...")
            self.datapoints = list()
            self.ctc_losses = list()

            acoustic_model = Aligner()
            acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])
            acoustic_model = acoustic_model.to(device)

            # ==========================================
            # actual creation of datapoints starts here
            # ==========================================

            # All calculators below are configured for 16 kHz; the waves from the
            # aligner cache are assumed to match (TODO confirm in AlignerDataset).
            sampling_rate = 16000
            dio = Dio(reduction_factor=reduction_factor, fs=sampling_rate)
            energy_calc = EnergyCalculator(reduction_factor=reduction_factor, fs=sampling_rate)
            dc = DurationCalculator(reduction_factor=reduction_factor)
            vis_dir = os.path.join(cache_dir, "duration_vis")
            os.makedirs(vis_dir, exist_ok=True)
            pros_cond_ext = ProsodicConditionExtractor(sr=sampling_rate, device=device)

            for index in tqdm(range(len(dataset))):
                norm_wave = norm_waves[index]
                norm_wave_length = torch.LongTensor([len(norm_wave)])

                if len(norm_wave) / sampling_rate < min_len_in_seconds and ctc_selection:
                    continue

                try:
                    prosodic_condition = pros_cond_ext.extract_condition_from_reference_wave(norm_wave, already_normalized=True).cpu()
                except RuntimeError:
                    # if there is an audio without any voiced segments whatsoever we have to skip it.
                    # Checked first so no alignment/energy/pitch work is wasted on such samples.
                    continue

                text = dataset[index][0]
                melspec = dataset[index][2]
                melspec_length = dataset[index][3]

                alignment_path, ctc_loss = acoustic_model.inference(mel=melspec.to(device),
                                                                    tokens=text.to(device),
                                                                    save_img_for_debug=os.path.join(vis_dir, f"{index}.png") if save_imgs else None,
                                                                    return_ctc=True)

                cached_duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
                # the viterbi decoding cannot separate repeated phonemes, so repair their durations
                self._fix_repeating_durations(text, cached_duration)

                wave_batch = norm_wave.unsqueeze(0)
                duration_batch = cached_duration.unsqueeze(0)
                duration_lengths = torch.LongTensor([len(cached_duration)])
                cached_energy = energy_calc(input_waves=wave_batch,
                                            input_waves_lengths=norm_wave_length,
                                            feats_lengths=melspec_length,
                                            durations=duration_batch,
                                            durations_lengths=duration_lengths)[0].squeeze(0).cpu()
                cached_pitch = dio(input_waves=wave_batch,
                                   input_waves_lengths=norm_wave_length,
                                   feats_lengths=melspec_length,
                                   durations=duration_batch,
                                   durations_lengths=duration_lengths)[0].squeeze(0).cpu()

                self.datapoints.append([dataset[index][0],
                                        dataset[index][1],
                                        dataset[index][2],
                                        dataset[index][3],
                                        cached_duration.cpu(),
                                        cached_energy,
                                        cached_pitch,
                                        prosodic_condition])
                self.ctc_losses.append(ctc_loss)

            # =============================
            # done with datapoint creation
            # =============================

            if ctc_selection and len(self.ctc_losses) > 1:
                # now we can filter out some bad datapoints based on the CTC scores we collected.
                # statistics.stdev needs at least two values, and with an empty list the mean
                # would divide by zero, so only filter when there is something to compare.
                self.datapoints, self.ctc_losses = self._filter_by_ctc(self.datapoints, self.ctc_losses)

            # save to cache
            if not self.datapoints:
                import sys
                print("No datapoints were prepared! Exiting...")
                sys.exit(1)  # non-zero status: an empty dataset is a failure, not a success
            torch.save(self.datapoints, fast_cache_path)

        self.language_id = get_language_id(lang)
        print(f"Prepared a FastSpeech dataset with {len(self.datapoints)} datapoints in {cache_dir}.")

    @staticmethod
    def _fix_repeating_durations(text, durations, verbose=False):
        """
        Redistribute the durations of directly repeated phonemes, in place.

        For every pair of consecutive, identical phoneme vectors in ``text`` the
        summed duration is split so that the first gets int(3/5 of the sum) and the
        second gets the rest. Longer runs are handled pairwise from left to right.

        Args:
            text: iterable of phoneme vectors (tensors), one per entry of ``durations``.
            durations: indexable per-phoneme durations (e.g. a LongTensor); modified in place.
            verbose: print "fix applied" for every repaired pair.

        Returns:
            ``durations`` (the same object that was modified).
        """
        last_vec = None
        for phoneme_index, vec in enumerate(text):
            if last_vec is not None and torch.equal(last_vec, vec):
                # we found a case of repeating phonemes!
                total_dur = durations[phoneme_index - 1] + durations[phoneme_index]
                new_dur_1 = int((total_dur / 5) * 3)
                durations[phoneme_index - 1] = new_dur_1
                durations[phoneme_index] = total_dur - new_dur_1
                if verbose:
                    print("fix applied")
            last_vec = vec
        return durations

    @staticmethod
    def _filter_by_ctc(datapoints, ctc_losses):
        """
        Drop datapoints whose CTC loss is more than one standard deviation above the mean.

        ``datapoints`` and ``ctc_losses`` must be parallel lists with at least two
        entries (required by ``statistics.stdev``). A removal message is printed for
        every dropped datapoint, from the highest index to the lowest.

        Returns:
            A ``(kept_datapoints, kept_losses)`` pair of new lists in the original order.
        """
        mean_ctc = sum(ctc_losses) / len(ctc_losses)
        std_dev = statistics.stdev(ctc_losses)
        threshold = mean_ctc + std_dev
        removed = [index for index, loss in enumerate(ctc_losses) if loss > threshold]
        for index in reversed(removed):
            print(
                f"Removing datapoint {index}, because the CTC loss is one standard deviation higher than the mean. \n ctc: {round(ctc_losses[index], 4)} vs. mean: {round(mean_ctc, 4)}")
        removed_set = set(removed)
        kept_datapoints = [datapoint for index, datapoint in enumerate(datapoints) if index not in removed_set]
        kept_losses = [loss for index, loss in enumerate(ctc_losses) if index not in removed_set]
        return kept_datapoints, kept_losses

    def __getitem__(self, index):
        """Return the eight cached entries of datapoint ``index`` followed by the language ID."""
        datapoint = self.datapoints[index]
        return (*datapoint[:8], self.language_id)

    def __len__(self):
        """Return the number of datapoints."""
        return len(self.datapoints)

    def remove_samples(self, list_of_samples_to_remove):
        """
        Delete the datapoints at the given indices and rewrite the cache file.

        Duplicate indices are removed only once; indices are processed from highest
        to lowest so earlier removals do not shift later ones.
        """
        for remove_id in sorted(set(list_of_samples_to_remove), reverse=True):
            self.datapoints.pop(remove_id)
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")

    def fix_repeating_phones(self):
        """
        The viterbi decoding of the durations cannot
        handle repetitions. This is now solved heuristically,
        but if you have a cache from before March 2022,
        use this method to postprocess those cases.

        Repairs every datapoint's durations in place and rewrites the cache file.
        """
        for datapoint in tqdm(self.datapoints):
            self._fix_repeating_durations(datapoint[0], datapoint[4], verbose=True)
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")