import os
import statistics
import sys

import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio


class FastSpeechDataset(Dataset):
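
    """
    Dataset that prepares everything needed to train a FastSpeech 2 style model:
    it takes an AlignerDataset cache as its basis and augments every datapoint
    with phoneme durations (extracted by the aligner), energy and pitch values,
    and an utterance-level prosodic condition vector, then caches the result
    to disk.
    """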

    def __init__(self,
                 path_to_transcript_dict,
                 acoustic_checkpoint_path,
                 cache_dir,
                 lang,
                 loading_processes=40,
                 min_len_in_seconds=1,
                 max_len_in_seconds=20,
                 cut_silence=False,
                 reduction_factor=1,
                 device=torch.device("cpu"),
                 rebuild_cache=False,
                 ctc_selection=True,
                 save_imgs=False):
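        # Caching strategy: if no FastSpeech cache exists yet (or a rebuild is
        # forced), first make sure an Aligner cache exists, then derive the
        # FastSpeech cache from it; otherwise just load the FastSpeech cache.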
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        if not os.path.exists(os.path.join(cache_dir, "fast_train_cache.pt")) or rebuild_cache:
            if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
                AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
                               cache_dir=cache_dir,
                               lang=lang,
                               loading_processes=loading_processes,
                               min_len_in_seconds=min_len_in_seconds,
                               max_len_in_seconds=max_len_in_seconds,
                               cut_silences=cut_silence,
                               rebuild_cache=rebuild_cache,
                               device=device)
            datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
            # we use the aligner dataset as the basis and augment it to contain the additional information we need for fastspeech
            if not isinstance(datapoints, tuple):  # check for backwards compatibility
                print(f"It seems like the Aligner dataset in {cache_dir} is not a tuple. Regenerating it, since we need the preprocessed waves.")
                AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
                               cache_dir=cache_dir,
                               lang=lang,
                               loading_processes=loading_processes,
                               min_len_in_seconds=min_len_in_seconds,
                               max_len_in_seconds=max_len_in_seconds,
                               cut_silences=cut_silence,
                               rebuild_cache=True)
                datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
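            # the aligner cache is a tuple: index 0 holds the text and spectrogram
            # datapoints, index 1 the corresponding normalized waveforms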
            dataset = datapoints[0]
            norm_waves = datapoints[1]

            # build cache
            print("... building dataset cache ...")
            self.datapoints = list()
            self.ctc_losses = list()

            acoustic_model = Aligner()
            acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])

            # ==========================================
            # actual creation of datapoints starts here
            # ==========================================

            acoustic_model = acoustic_model.to(device)
            dio = Dio(reduction_factor=reduction_factor, fs=16000)
            energy_calc = EnergyCalculator(reduction_factor=reduction_factor, fs=16000)
            dc = DurationCalculator(reduction_factor=reduction_factor)
            vis_dir = os.path.join(cache_dir, "duration_vis")
            os.makedirs(vis_dir, exist_ok=True)
            pros_cond_ext = ProsodicConditionExtractor(sr=16000, device=device)
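
            # for every sample we align text to spectrogram, derive per-phoneme
            # durations from the alignment, compute energy and pitch on the wave,
            # and extract a prosodic condition vector from the wave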
            for index in tqdm(range(len(dataset))):
                norm_wave = norm_waves[index]
                norm_wave_length = torch.LongTensor([len(norm_wave)])
                if len(norm_wave) / 16000 < min_len_in_seconds and ctc_selection:
                    continue
                text = dataset[index][0]
                melspec = dataset[index][2]
                melspec_length = dataset[index][3]
                alignment_path, ctc_loss = acoustic_model.inference(mel=melspec.to(device),
                                                                    tokens=text.to(device),
                                                                    save_img_for_debug=os.path.join(vis_dir, f"{index}.png") if save_imgs else None,
                                                                    return_ctc=True)
                cached_duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
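
                # the Viterbi decoding of the alignment cannot separate two identical
                # phonemes in a row (see fix_repeating_phones below), so their combined
                # duration is redistributed heuristically here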
                last_vec = None
                for phoneme_index, vec in enumerate(text):
                    if last_vec is not None:
                        if last_vec.numpy().tolist() == vec.numpy().tolist():
                            # we found a case of repeating phonemes!
                            # we repair their durations by giving the first one 3/5 of their sum and the second one 2/5 (i.e. the rest)
                            dur_1 = cached_duration[phoneme_index - 1]
                            dur_2 = cached_duration[phoneme_index]
                            total_dur = dur_1 + dur_2
                            new_dur_1 = int((total_dur / 5) * 3)
                            new_dur_2 = total_dur - new_dur_1
                            cached_duration[phoneme_index - 1] = new_dur_1
                            cached_duration[phoneme_index] = new_dur_2
                    last_vec = vec
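
                # the durations are passed along so the calculators can map the
                # frame-level pitch and energy curves onto the phoneme sequence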
                cached_energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
                                            input_waves_lengths=norm_wave_length,
                                            feats_lengths=melspec_length,
                                            durations=cached_duration.unsqueeze(0),
                                            durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()
                cached_pitch = dio(input_waves=norm_wave.unsqueeze(0),
                                   input_waves_lengths=norm_wave_length,
                                   feats_lengths=melspec_length,
                                   durations=cached_duration.unsqueeze(0),
                                   durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()
                try:
                    prosodic_condition = pros_cond_ext.extract_condition_from_reference_wave(norm_wave, already_normalized=True).cpu()
                except RuntimeError:
                    # if there is an audio without any voiced segments whatsoever, we have to skip it
                    continue

                self.datapoints.append([dataset[index][0],
                                        dataset[index][1],
                                        dataset[index][2],
                                        dataset[index][3],
                                        cached_duration.cpu(),
                                        cached_energy,
                                        cached_pitch,
                                        prosodic_condition])
                self.ctc_losses.append(ctc_loss)

            # =============================
            # done with datapoint creation
            # =============================

            if ctc_selection:
                # now we can filter out some bad datapoints based on the CTC losses we collected
                mean_ctc = sum(self.ctc_losses) / len(self.ctc_losses)
                std_dev = statistics.stdev(self.ctc_losses)
                threshold = mean_ctc + std_dev
                # iterate in reverse, so popping a datapoint does not shift the indices we still have to visit
                for index in range(len(self.ctc_losses), 0, -1):
                    if self.ctc_losses[index - 1] > threshold:
                        self.datapoints.pop(index - 1)
                        print(f"Removing datapoint {index - 1}, because its CTC loss is more than one standard deviation above the mean. \n"
                              f"ctc: {round(self.ctc_losses[index - 1], 4)} vs. mean: {round(mean_ctc, 4)}")

            # save to cache
            if len(self.datapoints) > 0:
                torch.save(self.datapoints, os.path.join(cache_dir, "fast_train_cache.pt"))
            else:
                print("No datapoints were prepared! Exiting...")
                sys.exit()
        else:
            # just load the datapoints from cache
            self.datapoints = torch.load(os.path.join(cache_dir, "fast_train_cache.pt"), map_location='cpu')

        self.cache_dir = cache_dir
        self.language_id = get_language_id(lang)
        print(f"Prepared a FastSpeech dataset with {len(self.datapoints)} datapoints in {cache_dir}.")

    def __getitem__(self, index):
        return self.datapoints[index][0], \
               self.datapoints[index][1], \
               self.datapoints[index][2], \
               self.datapoints[index][3], \
               self.datapoints[index][4], \
               self.datapoints[index][5], \
               self.datapoints[index][6], \
               self.datapoints[index][7], \
               self.language_id

    def __len__(self):
        return len(self.datapoints)

    def remove_samples(self, list_of_samples_to_remove):
        # remove in descending index order, so that popping one sample does not
        # invalidate the indices of the samples that still need to be removed
        for remove_id in sorted(list_of_samples_to_remove, reverse=True):
            self.datapoints.pop(remove_id)
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")

    def fix_repeating_phones(self):
        """
        The Viterbi decoding of the durations cannot
        handle repetitions. This is now solved heuristically
        during cache creation, but if you have a cache from
        before March 2022, use this method to postprocess
        those cases.
        """
        for datapoint_index in tqdm(list(range(len(self.datapoints)))):
            last_vec = None
            for phoneme_index, vec in enumerate(self.datapoints[datapoint_index][0]):
                if last_vec is not None:
                    if last_vec.numpy().tolist() == vec.numpy().tolist():
                        # we found a case of repeating phonemes!
                        # we repair their durations by giving the first one 3/5 of their sum and the second one 2/5 (i.e. the rest)
                        dur_1 = self.datapoints[datapoint_index][4][phoneme_index - 1]
                        dur_2 = self.datapoints[datapoint_index][4][phoneme_index]
                        total_dur = dur_1 + dur_2
                        new_dur_1 = int((total_dur / 5) * 3)
                        new_dur_2 = total_dur - new_dur_1
                        self.datapoints[datapoint_index][4][phoneme_index - 1] = new_dur_1
                        self.datapoints[datapoint_index][4][phoneme_index] = new_dur_2
                        print("fix applied")
                last_vec = vec
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")