Flux9665's picture
use explicit code instead of relying on release download
9e275b8
raw
history blame
No virus
7.05 kB
"""
This module is meant to find potentially problematic samples
in the data you are using. There are two types: The alignment
scorer and the TTS scorer. The alignment scorer can help you
find mispronunciations or errors in the labels. The TTS scorer
can help you find outliers in the audio part of text-audio pairs.
"""
import torch
import torch.multiprocessing
from tqdm import tqdm
from Architectures.ToucanTTS.ToucanTTS import ToucanTTS
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
from Utility.corpus_preparation import prepare_tts_corpus
class TTSScorer:
    """
    Scores every sample of a ToucanTTS dataset with the losses of a trained
    ToucanTTS model, so that the samples with the highest loss (and those with
    a NaN loss) can be inspected or removed from the dataset.

    Typical usage: call score() once, then show_samples_with_highest_loss()
    and either remove_samples_with_highest_loss() or remove_nans().
    """

    def __init__(self,
                 path_to_model,
                 device,
                 ):
        """
        Load the ToucanTTS checkpoint and set up the audio preprocessors.

        Args:
            path_to_model: path to a checkpoint file that contains the model
                weights under the key "model".
            device: torch device on which the model and preprocessors run.
        """
        self.device = device
        self.path_to_score = dict()  # filepath -> loss from the latest score() run
        self.path_to_id = dict()  # filepath -> index into dataset.datapoints
        self.nans = list()  # filepaths whose loss was NaN
        self.nan_indexes = list()  # dataset indexes whose loss was NaN
        self.tts = ToucanTTS()
        # NOTE(review): torch.load unpickles the file, so only pass checkpoints from a trusted source.
        checkpoint = torch.load(path_to_model, map_location='cpu')
        weights = checkpoint["model"]
        self.tts.load_state_dict(weights)
        self.tts.to(self.device)
        self.nans_removed = False  # True once indexes no longer match the dataset
        self.current_dset = None  # dataset of the latest score() run
        self.ap = CodecAudioPreprocessor(input_sr=-1, device=device)
        self.spec_extractor = AudioPreprocessor(input_sr=16000, output_sr=16000, device=device)

    def score(self, path_to_toucantts_dataset, lang_id):
        """
        call this to update the path_to_score dict with scores for this dataset

        All results of a previous run are discarded. Samples whose loss is NaN
        (or for which the model call raises a TypeError) are additionally
        recorded in self.nans and self.nan_indexes.
        """
        dataset = prepare_tts_corpus(dict(), path_to_toucantts_dataset, lang_id)
        self.current_dset = dataset
        self.nans = list()
        self.nan_indexes = list()
        self.path_to_score = dict()
        self.path_to_id = dict()
        # presumably accessing an item makes the dataset populate .datapoints -- TODO confirm
        _ = dataset[0]
        lang_ids = dataset.language_id.to(self.device)  # identical for every sample, so computed once
        for index, datapoint in enumerate(tqdm(dataset.datapoints)):
            text_tensors = datapoint[0].to(self.device).unsqueeze(0).float()
            text_lengths = datapoint[1].squeeze().to(self.device).unsqueeze(0)
            speech_indexes = datapoint[2]
            speech_lengths = datapoint[3].squeeze().to(self.device).unsqueeze(0)
            gold_durations = datapoint[4].to(self.device).unsqueeze(0)
            gold_pitch = datapoint[6].to(self.device).unsqueeze(0)  # mind the switched order
            gold_energy = datapoint[5].to(self.device).unsqueeze(0)  # mind the switched order
            filepath = datapoint[8]
            # decode the stored codec indexes back to audio and derive the spectrogram from it
            with torch.inference_mode():
                wave = self.ap.indexes_to_audio(speech_indexes.int().to(self.device)).detach()
                mel = self.spec_extractor.audio_to_mel_spec_tensor(wave, explicit_sampling_rate=16000).transpose(0, 1).detach().cpu()
            # NOTE(review): the clone presumably turns the inference-mode tensor into a regular one -- confirm
            gold_speech_sample = mel.clone().to(self.device).unsqueeze(0)
            utterance_embedding = datapoint[7].unsqueeze(0).to(self.device)
            try:
                regression_loss, _, duration_loss, pitch_loss, energy_loss = self.tts(text_tensors=text_tensors,
                                                                                      text_lengths=text_lengths,
                                                                                      gold_speech=gold_speech_sample,
                                                                                      speech_lengths=speech_lengths,
                                                                                      gold_durations=gold_durations,
                                                                                      gold_pitch=gold_pitch,
                                                                                      gold_energy=gold_energy,
                                                                                      utterance_embedding=utterance_embedding,
                                                                                      lang_ids=lang_ids,
                                                                                      return_feats=False,
                                                                                      run_glow=False)
                loss = regression_loss + duration_loss + pitch_loss + energy_loss  # we omit the glow loss
            except TypeError:
                # a sample the model cannot process is treated like a NaN sample
                loss = torch.tensor(torch.nan)
            if torch.isnan(loss):
                self.nans.append(filepath)
                self.nan_indexes.append(index)
            self.path_to_score[filepath] = loss.cpu().item()
            self.path_to_id[filepath] = index
        if self.nans:
            print("NaNs detected during scoring!")
            for path in self.nans:
                print(path)
            print("\n\n")
        self.nans_removed = False

    def _paths_by_descending_loss(self):
        """
        Return the scored filepaths ordered from highest to lowest loss.

        Paths with a NaN loss are left out: NaN compares False against every
        number, so keeping them would make the sort order meaningless. They are
        tracked separately in self.nans / self.nan_indexes.
        """
        comparable = {path: loss for path, loss in self.path_to_score.items()
                      if loss == loss}  # only NaN is unequal to itself
        return sorted(comparable, key=comparable.get, reverse=True)

    def show_samples_with_highest_loss(self, n=-1):
        """
        NaN samples will always be shown.
        To see all samples, pass -1, otherwise n samples will be shown.
        """
        if self.nans:
            print("The following filepaths had an infinite loss:")
            for path in self.nans:
                print(path)
            print("\n\n")
        ranked = self._paths_by_descending_loss()
        if n != -1:
            ranked = ranked[:max(n, 0)]  # any other negative n shows nothing
        for path in ranked:
            print(f"Loss: {round(self.path_to_score[path], 3)} - Path: {path}")
        print("\n\n")

    def remove_samples_with_highest_loss(self, n=10):
        """
        Remove all NaN samples plus the n samples with the highest loss from
        the dataset that was scored last.

        Afterwards the stored indexes no longer match the dataset, so score()
        has to be run again before anything else can be removed.
        """
        if self.current_dset is None:
            print("Please run the scoring first.")
            return
        if self.nans_removed:
            print("Indexes are no longer accurate. Please re-run the scoring. \n\n"
                  "This function also removes NaNs, so if you want to remove the NaN samples and the n samples "
                  "with the highest loss, only call this function.")
            return
        remove_ids = list(self.nan_indexes)
        remove_ids.extend(self.path_to_id[path] for path in self._paths_by_descending_loss()[:max(n, 0)])
        # every index must be handed over exactly once
        remove_ids = list(dict.fromkeys(remove_ids))
        self.current_dset.remove_samples(remove_ids)
        self.nans_removed = True

    def remove_nans(self):
        """
        Remove only the samples whose loss was NaN from the dataset that was
        scored last.
        """
        if self.nans_removed:
            print("NaNs have already been removed!")
            return
        if self.current_dset is None:
            print("Please run the scoring first to find NaNs.")
            return
        if not self.nans:
            print("No NaNs detected in this dataset.")
            return
        print("The following filepaths had an infinite loss and are being removed from the dataset cache:")
        for path in self.nans:
            print(path)
        self.current_dset.remove_samples(self.nan_indexes)
        self.nans_removed = True