"""
This module is meant to find potentially problematic samples in the data you are using. There are two types: The alignment scorer and the TTS scorer. The alignment scorer can help you find mispronunciations or errors in the labels. The TTS scorer can help you find outliers in the audio part of text-audio pairs.
"""

import math

import torch
import torch.multiprocessing
from tqdm import tqdm

from Architectures.ToucanTTS.ToucanTTS import ToucanTTS
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
from Utility.corpus_preparation import prepare_tts_corpus


class TTSScorer:
    """
    Scores every sample of a ToucanTTS dataset with a pretrained ToucanTTS model.

    Each datapoint is passed through the model's loss computation (regression, duration,
    pitch and energy losses; the glow loss is omitted). Samples with the highest loss, and
    samples whose loss is NaN, are candidates for removal from the dataset cache.
    """

    def __init__(self,
                 path_to_model,
                 device,
                 ):
        """
        Args:
            path_to_model: path to a checkpoint whose "model" entry holds the ToucanTTS state dict.
            device: device on which the model and both audio preprocessors run.
        """
        self.device = device
        self.path_to_score = dict()  # filepath -> scalar loss as a Python float (may be NaN)
        self.path_to_id = dict()  # filepath -> index into the scored dataset's datapoints
        self.nans = list()  # filepaths whose loss was NaN
        self.nan_indexes = list()  # datapoint indexes whose loss was NaN
        self.tts = ToucanTTS()
        checkpoint = torch.load(path_to_model, map_location='cpu')
        weights = checkpoint["model"]
        self.tts.load_state_dict(weights)
        self.tts.to(self.device)
        # NOTE(review): the model is never switched to eval mode. If ToucanTTS contains dropout
        # or other train-mode-dependent layers, scores may vary between runs — confirm.
        self.nans_removed = False  # set once samples were removed; stored indexes are then stale
        self.current_dset = None  # dataset scored by the most recent call to score()
        self.ap = CodecAudioPreprocessor(input_sr=-1, device=device)
        self.spec_extractor = AudioPreprocessor(input_sr=16000, output_sr=16000, device=device)

    def score(self, path_to_toucantts_dataset, lang_id):
        """
        call this to update the path_to_score dict with scores for this dataset

        Previous scores, NaN records and the previously scored dataset are discarded first.

        Args:
            path_to_toucantts_dataset: dataset location passed to prepare_tts_corpus.
            lang_id: language id passed to prepare_tts_corpus.
        """
        dataset = prepare_tts_corpus(dict(), path_to_toucantts_dataset, lang_id)
        self.current_dset = dataset
        self.nans = list()
        self.nan_indexes = list()
        self.path_to_score = dict()
        self.path_to_id = dict()
        _ = dataset[0]  # NOTE(review): result is unused; presumably touches the dataset once before scoring
        for index, datapoint in enumerate(tqdm(dataset.datapoints)):
            text_tensors = datapoint[0].to(self.device).unsqueeze(0).float()
            text_lengths = datapoint[1].squeeze().to(self.device).unsqueeze(0)
            speech_indexes = datapoint[2]
            speech_lengths = datapoint[3].squeeze().to(self.device).unsqueeze(0)
            gold_durations = datapoint[4].to(self.device).unsqueeze(0)
            gold_pitch = datapoint[6].to(self.device).unsqueeze(0)  # mind the switched order
            gold_energy = datapoint[5].to(self.device).unsqueeze(0)  # mind the switched order
            lang_ids = dataset.language_id.to(self.device)
            filepath = datapoint[8]
            with torch.inference_mode():
                # decode the stored codec indexes back to a waveform and extract its spectrogram
                wave = self.ap.indexes_to_audio(speech_indexes.int().to(self.device)).detach()
                mel = self.spec_extractor.audio_to_mel_spec_tensor(wave, explicit_sampling_rate=16000).transpose(0, 1).detach().cpu()
            # cloned outside inference_mode, presumably so the model receives a regular (non-inference) tensor
            gold_speech_sample = mel.clone().to(self.device).unsqueeze(0)
            utterance_embedding = datapoint[7].unsqueeze(0).to(self.device)
            try:
                # only the loss values are needed, so no autograd graph is built
                with torch.no_grad():
                    regression_loss, _, duration_loss, pitch_loss, energy_loss = self.tts(text_tensors=text_tensors,
                                                                                          text_lengths=text_lengths,
                                                                                          gold_speech=gold_speech_sample,
                                                                                          speech_lengths=speech_lengths,
                                                                                          gold_durations=gold_durations,
                                                                                          gold_pitch=gold_pitch,
                                                                                          gold_energy=gold_energy,
                                                                                          utterance_embedding=utterance_embedding,
                                                                                          lang_ids=lang_ids,
                                                                                          return_feats=False,
                                                                                          run_glow=False)
                loss = regression_loss + duration_loss + pitch_loss + energy_loss  # we omit the glow loss
            except TypeError:
                # a TypeError from the forward pass or the unpacking is recorded as a NaN loss
                loss = torch.tensor(torch.nan)
            if torch.isnan(loss):
                self.nans.append(filepath)
                self.nan_indexes.append(index)
            self.path_to_score[filepath] = loss.cpu().item()
            self.path_to_id[filepath] = index
        if len(self.nans) > 0:
            print("NaNs detected during scoring!")
            for path in self.nans:
                print(path)
            print("\n\n")
        self.nans_removed = False

    def _paths_ranked_by_loss(self):
        """
        Return the scored filepaths ordered by descending loss, excluding NaN scores.

        NaN compares False against every number, so leaving it in the sort key can scramble
        the order of the finite losses. NaN samples are tracked separately in self.nans and
        self.nan_indexes.
        """
        finite_paths = [path for path, loss in self.path_to_score.items() if not math.isnan(loss)]
        return sorted(finite_paths, key=self.path_to_score.get, reverse=True)

    def show_samples_with_highest_loss(self, n=-1):
        """
        NaN samples will always be shown. To see all samples, pass -1, otherwise n samples will be shown.

        The ranked list covers only samples with a non-NaN loss; NaN samples are listed separately first.
        """
        if len(self.nans) > 0:
            print("The following filepaths had an infinite loss:")
            for path in self.nans:
                print(path)
            print("\n\n")
        ranked_paths = self._paths_ranked_by_loss()
        if n != -1:
            ranked_paths = ranked_paths[:max(n, 0)]  # negative n other than -1 shows nothing
        for path in ranked_paths:
            print(f"Loss: {round(self.path_to_score[path], 3)} - Path: {path}")
        print("\n\n")

    def remove_samples_with_highest_loss(self, n=10):
        """
        Remove all NaN samples plus the n samples with the highest non-NaN loss from the dataset cache.

        Only possible once per scoring run: removing samples makes the stored indexes stale,
        so score() must be called again before the next removal.
        """
        if self.current_dset is None:
            print("Please run the scoring first.")
        else:
            if self.nans_removed:
                print("Indexes are no longer accurate. Please re-run the scoring. \n\n"
                      "This function also removes NaNs, so if you want to remove the NaN samples and the n samples "
                      "with the highest loss, only call this function.")
            else:
                remove_ids = list()
                remove_ids.extend(self.nan_indexes)
                # NaN samples are excluded from the ranking, so they neither occupy top-n slots
                # nor get appended to remove_ids a second time
                for path in self._paths_ranked_by_loss()[:max(n, 0)]:
                    remove_ids.append(self.path_to_id[path])
                self.current_dset.remove_samples(remove_ids)
                self.nans_removed = True

    def remove_nans(self):
        """
        Remove only the samples whose loss was NaN from the dataset cache of the last scored dataset.
        """
        if self.nans_removed:
            print("NaNs have already been removed!")
        else:
            if self.current_dset is None:
                print("Please run the scoring first to find NaNs.")
            else:
                if len(self.nans) > 0:
                    print("The following filepaths had an infinite loss and are being removed from the dataset cache:")
                    for path in self.nans:
                        print(path)
                    self.current_dset.remove_samples(self.nan_indexes)
                    self.nans_removed = True
                else:
                    print("No NaNs detected in this dataset.")