# Spaces: Running
import os | |
import statistics | |
import torch | |
from torch.utils.data import Dataset | |
from tqdm import tqdm | |
from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id | |
from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor | |
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner | |
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset | |
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator | |
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator | |
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio | |
class FastSpeechDataset(Dataset):
    """
    Dataset of FastSpeech training datapoints, built on top of the aligner cache.

    Every datapoint is a list of eight entries: the first four entries of the
    corresponding aligner datapoint (entry 0 is used as the text, entry 2 as
    the mel spectrogram and entry 3 as the mel spectrogram length), followed by
    the durations, the energy, the pitch and the prosodic condition.
    The finished list is cached in "fast_train_cache.pt" inside the cache
    directory and simply loaded from there on subsequent runs.
    """

    def __init__(self,
                 path_to_transcript_dict,
                 acoustic_checkpoint_path,
                 cache_dir,
                 lang,
                 loading_processes=40,
                 min_len_in_seconds=1,
                 max_len_in_seconds=20,
                 cut_silence=False,
                 reduction_factor=1,
                 device=torch.device("cpu"),
                 rebuild_cache=False,
                 ctc_selection=True,
                 save_imgs=False):
        """
        Load the FastSpeech cache from cache_dir, or build it if it is missing.

        Args:
            path_to_transcript_dict: forwarded to AlignerDataset.
            acoustic_checkpoint_path: checkpoint file whose "asr_model" entry is loaded into the Aligner.
            cache_dir: directory holding "aligner_train_cache.pt", "fast_train_cache.pt" and "duration_vis".
            lang: forwarded to AlignerDataset and used to look up the language ID.
            loading_processes: forwarded to AlignerDataset.
            min_len_in_seconds: forwarded to AlignerDataset; if ctc_selection is on, shorter waves are also skipped here.
            max_len_in_seconds: forwarded to AlignerDataset.
            cut_silence: forwarded to AlignerDataset as cut_silences.
            reduction_factor: forwarded to the pitch, energy and duration calculators.
            device: device for the aligner model and the prosodic condition extractor.
            rebuild_cache: rebuild both caches even if they already exist.
            ctc_selection: drop datapoints whose CTC loss is more than one standard deviation above the mean.
            save_imgs: save an alignment image per datapoint into "duration_vis".

        If no datapoint survives the preparation, a message is printed and the process exits via sys.exit().
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        fast_cache_path = os.path.join(cache_dir, "fast_train_cache.pt")
        aligner_cache_path = os.path.join(cache_dir, "aligner_train_cache.pt")
        sample_rate = 16000

        if not os.path.exists(fast_cache_path) or rebuild_cache:
            # both AlignerDataset invocations share these arguments, so they cannot drift apart
            aligner_kwargs = dict(path_to_transcript_dict=path_to_transcript_dict,
                                  cache_dir=cache_dir,
                                  lang=lang,
                                  loading_processes=loading_processes,
                                  min_len_in_seconds=min_len_in_seconds,
                                  max_len_in_seconds=max_len_in_seconds,
                                  cut_silences=cut_silence,
                                  device=device)
            if not os.path.exists(aligner_cache_path) or rebuild_cache:
                AlignerDataset(rebuild_cache=rebuild_cache, **aligner_kwargs)
            # NOTE(review): torch.load unpickles the file, only point cache_dir at caches you trust
            datapoints = torch.load(aligner_cache_path, map_location='cpu')
            # we use the aligner dataset as basis and augment it to contain the additional information we need for fastspeech.
            if not isinstance(datapoints, tuple):  # check for backwards compatibility
                print(f"It seems like the Aligner dataset in {cache_dir} is not a tuple. Regenerating it, since we need the preprocessed waves.")
                AlignerDataset(rebuild_cache=True, **aligner_kwargs)
                datapoints = torch.load(aligner_cache_path, map_location='cpu')
            dataset = datapoints[0]
            norm_waves = datapoints[1]

            # build cache
            print("... building dataset cache ...")
            self.datapoints = list()
            self.ctc_losses = list()

            acoustic_model = Aligner()
            acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])

            # ==========================================
            # actual creation of datapoints starts here
            # ==========================================

            acoustic_model = acoustic_model.to(device)
            dio = Dio(reduction_factor=reduction_factor, fs=sample_rate)
            energy_calc = EnergyCalculator(reduction_factor=reduction_factor, fs=sample_rate)
            dc = DurationCalculator(reduction_factor=reduction_factor)
            vis_dir = os.path.join(cache_dir, "duration_vis")
            os.makedirs(vis_dir, exist_ok=True)
            pros_cond_ext = ProsodicConditionExtractor(sr=sample_rate, device=device)

            for index in tqdm(range(len(dataset))):
                norm_wave = norm_waves[index]
                norm_wave_length = torch.LongTensor([len(norm_wave)])

                # the length check is only applied when CTC based selection is enabled
                if len(norm_wave) / sample_rate < min_len_in_seconds and ctc_selection:
                    continue

                aligner_datapoint = dataset[index]
                text = aligner_datapoint[0]
                melspec = aligner_datapoint[2]
                melspec_length = aligner_datapoint[3]

                alignment_path, ctc_loss = acoustic_model.inference(mel=melspec.to(device),
                                                                    tokens=text.to(device),
                                                                    save_img_for_debug=os.path.join(vis_dir, f"{index}.png") if save_imgs else None,
                                                                    return_ctc=True)

                cached_duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
                self._repair_repeated_phone_durations(text, cached_duration)
                durations_lengths = torch.LongTensor([len(cached_duration)])

                cached_energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
                                            input_waves_lengths=norm_wave_length,
                                            feats_lengths=melspec_length,
                                            durations=cached_duration.unsqueeze(0),
                                            durations_lengths=durations_lengths)[0].squeeze(0).cpu()

                cached_pitch = dio(input_waves=norm_wave.unsqueeze(0),
                                   input_waves_lengths=norm_wave_length,
                                   feats_lengths=melspec_length,
                                   durations=cached_duration.unsqueeze(0),
                                   durations_lengths=durations_lengths)[0].squeeze(0).cpu()

                try:
                    prosodic_condition = pros_cond_ext.extract_condition_from_reference_wave(norm_wave, already_normalized=True).cpu()
                except RuntimeError:
                    # if there is an audio without any voiced segments whatsoever we have to skip it.
                    continue

                # datapoints and ctc_losses are appended together, so their indexes line up
                self.datapoints.append([aligner_datapoint[0],
                                        aligner_datapoint[1],
                                        aligner_datapoint[2],
                                        aligner_datapoint[3],
                                        cached_duration,
                                        cached_energy,
                                        cached_pitch,
                                        prosodic_condition])
                self.ctc_losses.append(ctc_loss)

            # =============================
            # done with datapoint creation
            # =============================

            if ctc_selection:
                # now we can filter out some bad datapoints based on the CTC scores we collected
                self._drop_ctc_outliers()

            # save to cache
            if len(self.datapoints) > 0:
                torch.save(self.datapoints, fast_cache_path)
            else:
                import sys
                print("No datapoints were prepared! Exiting...")
                sys.exit()
        else:
            # just load the datapoints from cache
            self.datapoints = torch.load(fast_cache_path, map_location='cpu')

        self.language_id = get_language_id(lang)
        print(f"Prepared a FastSpeech dataset with {len(self.datapoints)} datapoints in {cache_dir}.")

    @staticmethod
    def _repair_repeated_phone_durations(phone_vectors, durations):
        """
        Redistribute the durations of directly repeated phonemes in place.

        Whenever two neighbouring entries of phone_vectors are equal, the first
        one receives 3/5 of their summed duration (truncated to an integer)
        and the second one receives the rest.

        Args:
            phone_vectors: sequence of tensors, one per phoneme.
            durations: indexable and writable durations, one per phoneme; modified in place.

        Returns:
            The number of neighbouring pairs that were repaired.
        """
        repairs = 0
        previous = None
        for position, current in enumerate(phone_vectors):
            if previous is not None and torch.equal(previous, current):
                # we found a case of repeating phonemes!
                total_dur = durations[position - 1] + durations[position]
                new_dur_1 = int((total_dur / 5) * 3)
                new_dur_2 = total_dur - new_dur_1
                durations[position - 1] = new_dur_1
                durations[position] = new_dur_2
                repairs += 1
            previous = current
        return repairs

    def _drop_ctc_outliers(self):
        """
        Remove every datapoint whose CTC loss exceeds the mean by more than one standard deviation.

        A standard deviation needs at least two values, so with fewer than two
        collected losses nothing is removed. self.ctc_losses itself is left
        untouched; only self.datapoints shrinks.
        """
        if len(self.ctc_losses) < 2:
            return
        mean_ctc = sum(self.ctc_losses) / len(self.ctc_losses)
        std_dev = statistics.stdev(self.ctc_losses)
        threshold = mean_ctc + std_dev
        # walk backwards so that popping does not shift the indexes that are still to be checked
        for index in range(len(self.ctc_losses) - 1, -1, -1):
            if self.ctc_losses[index] > threshold:
                self.datapoints.pop(index)
                print(
                    f"Removing datapoint {index}, because the CTC loss is one standard deviation higher than the mean. \n ctc: {round(self.ctc_losses[index], 4)} vs. mean: {round(mean_ctc, 4)}")

    def __getitem__(self, index):
        """Return the eight cached entries of the datapoint at index, followed by the language ID."""
        datapoint = self.datapoints[index]
        return (datapoint[0],
                datapoint[1],
                datapoint[2],
                datapoint[3],
                datapoint[4],
                datapoint[5],
                datapoint[6],
                datapoint[7],
                self.language_id)

    def __len__(self):
        """Return the number of datapoints."""
        return len(self.datapoints)

    def remove_samples(self, list_of_samples_to_remove):
        """
        Delete the datapoints at the given indexes and overwrite the cache file.

        Duplicate indexes are ignored, and removal runs from the highest index
        to the lowest so the remaining indexes stay valid.
        """
        for remove_id in sorted(set(list_of_samples_to_remove), reverse=True):
            self.datapoints.pop(remove_id)
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")

    def fix_repeating_phones(self):
        """
        The viterbi decoding of the durations cannot
        handle repetitions. This is now solved heuristically,
        but if you have a cache from before March 2022,
        use this method to postprocess those cases.
        """
        for datapoint in tqdm(self.datapoints):
            # entry 0 holds the phoneme vectors and entry 4 the durations
            repairs = self._repair_repeated_phone_durations(datapoint[0], datapoint[4])
            for _ in range(repairs):
                print("fix applied")
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")