import os
import random
import warnings

import soundfile as sf
import torch
from numpy import trim_zeros
from speechbrain.pretrained import EncoderClassifier
from torch.multiprocessing import Manager
from torch.multiprocessing import Process
from torch.multiprocessing import set_start_method
from torch.utils.data import Dataset
from tqdm import tqdm

from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.AudioPreprocessor import AudioPreprocessor


class AlignerDataset(Dataset):

    def __init__(self,
                 path_to_transcript_dict,
                 cache_dir,
                 lang,
                 loading_processes=30,  # careful with the amount of processes if you use silence removal, only as many processes as you have cores
                 min_len_in_seconds=1,
                 max_len_in_seconds=20,
                 cut_silences=False,
                 rebuild_cache=False,
                 verbose=False,
                 device="cpu"):
        os.makedirs(cache_dir, exist_ok=True)
        if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
            if (device == "cuda" or device == torch.device("cuda")) and cut_silences:
                try:
                    set_start_method('spawn')  # in order to be able to make use of cuda in multiprocessing
                except RuntimeError:
                    pass
            elif cut_silences:
                torch.set_num_threads(1)
            if cut_silences:
                torch.hub.load(repo_or_dir='snakers4/silero-vad',
                               model='silero_vad',
                               force_reload=False,
                               onnx=False,
                               verbose=False)  # download and cache for it to be loaded and used later
                torch.set_grad_enabled(True)
            resource_manager = Manager()
            self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
            key_list = list(self.path_to_transcript_dict.keys())
            with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
                files_used_note.write(str(key_list))
            random.shuffle(key_list)

            # build cache
            print("... building dataset cache ...")
            self.datapoints = resource_manager.list()
            # make processes
            key_splits = list()
            process_list = list()
            for i in range(loading_processes):
                key_splits.append(key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
            for key_split in key_splits:
                process_list.append(Process(target=self.cache_builder_process,
                                            args=(key_split,
                                                  lang,
                                                  min_len_in_seconds,
                                                  max_len_in_seconds,
                                                  cut_silences,
                                                  verbose,
                                                  device),
                                            daemon=True))
                process_list[-1].start()
            for process in process_list:
                process.join()
            self.datapoints = list(self.datapoints)
            tensored_datapoints = list()
            # we had to turn all of the tensors to numpy arrays to avoid shared memory
            # issues. Now that the multi-processing is over, we can convert them back
            # to tensors to save on conversions in the future.
            print("Converting into convenient format...")
            norm_waves = list()
            for datapoint in tqdm(self.datapoints):
                tensored_datapoints.append([torch.Tensor(datapoint[0]),
                                            torch.LongTensor(datapoint[1]),
                                            torch.Tensor(datapoint[2]),
                                            torch.LongTensor(datapoint[3])])
                norm_waves.append(torch.Tensor(datapoint[-1]))
            self.datapoints = tensored_datapoints

            pop_indexes = list()
            for index, el in enumerate(self.datapoints):
                try:
                    if len(el[0][0]) != 66:
                        pop_indexes.append(index)
                except TypeError:
                    pop_indexes.append(index)
            for pop_index in sorted(pop_indexes, reverse=True):
                print(f"There seems to be a problem in the transcriptions. Deleting datapoint {pop_index}.")
                self.datapoints.pop(pop_index)

            # add speaker embeddings
            self.speaker_embeddings = list()
            speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                          run_opts={"device": str(device)},
                                                                          savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
            with torch.no_grad():
                for wave in tqdm(norm_waves):
                    self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())

            # save to cache
            torch.save((self.datapoints, norm_waves, self.speaker_embeddings),
                       os.path.join(cache_dir, "aligner_train_cache.pt"))
        else:
            # just load the datapoints from cache
            self.datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
            if len(self.datapoints) == 2:
                # speaker embeddings are still missing, have to add them here
                wave_datapoints = self.datapoints[1]
                self.datapoints = self.datapoints[0]
                self.speaker_embeddings = list()
                speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                              run_opts={"device": str(device)},
                                                                              savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
                with torch.no_grad():
                    for wave in tqdm(wave_datapoints):
                        self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
                torch.save((self.datapoints, wave_datapoints, self.speaker_embeddings),
                           os.path.join(cache_dir, "aligner_train_cache.pt"))
            else:
                self.speaker_embeddings = self.datapoints[2]
                self.datapoints = self.datapoints[0]

        self.tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=True)
        print(f"Prepared an Aligner dataset with {len(self.datapoints)} datapoints in {cache_dir}.")

    def cache_builder_process(self, path_list, lang, min_len, max_len, cut_silences, verbose, device):
        process_internal_dataset_chunk = list()
        tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=False)
        _, sr = sf.read(path_list[0])
        ap = AudioPreprocessor(input_sr=sr,
                               output_sr=16000,
                               melspec_buckets=80,
                               hop_length=256,
                               n_fft=1024,
                               cut_silence=cut_silences,
                               device=device)
        for path in tqdm(path_list):
            if self.path_to_transcript_dict[path].strip() == "":
                continue
            wave, sr = sf.read(path)
            dur_in_seconds = len(wave) / sr
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")  # otherwise we get tons of warnings about an RNN not being in contiguous chunks
                    norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave)
            except ValueError:
                continue
            dur_in_seconds = len(norm_wave) / 16000
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            norm_wave = torch.tensor(trim_zeros(norm_wave.numpy()))

            # raw audio preprocessing is done
            transcript = self.path_to_transcript_dict[path]
            try:
                cached_text = tf.string_to_tensor(transcript, handle_missing=False).squeeze(0).cpu().numpy()
            except KeyError:
                tf.string_to_tensor(transcript, handle_missing=True).squeeze(0).cpu().numpy()
                continue  # we skip sentences with unknown symbols
            try:
                if len(cached_text[0]) != 66:
                    print(f"There seems to be a problem with the following transcription: {transcript}")
                    continue
            except TypeError:
                print(f"There seems to be a problem with the following transcription: {transcript}")
                continue
            cached_text_len = torch.LongTensor([len(cached_text)]).numpy()
            cached_speech = ap.audio_to_mel_spec_tensor(audio=norm_wave,
                                                        normalize=False,
                                                        explicit_sampling_rate=16000).transpose(0, 1).cpu().numpy()
            cached_speech_len = torch.LongTensor([len(cached_speech)]).numpy()
            process_internal_dataset_chunk.append([cached_text,
                                                   cached_text_len,
                                                   cached_speech,
                                                   cached_speech_len,
                                                   norm_wave.cpu().detach().numpy()])
        self.datapoints += process_internal_dataset_chunk

    def __getitem__(self, index):
        text_vector = self.datapoints[index][0]
        tokens = list()
        for vector in text_vector:
            for phone in self.tf.phone_to_vector:
                if vector.numpy().tolist() == self.tf.phone_to_vector[phone]:
                    tokens.append(self.tf.phone_to_id[phone])
                    # this is terribly inefficient, but it's good enough for testing for now.
        tokens = torch.LongTensor(tokens)
        return tokens, \
               self.datapoints[index][1], \
               self.datapoints[index][2], \
               self.datapoints[index][3], \
               self.speaker_embeddings[index]

    def __len__(self):
        return len(self.datapoints)
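

# Minimal usage sketch (illustrative only, not part of the original module): builds the
# dataset cache from a path-to-transcript mapping and reads one sample back out to show
# what __getitem__ returns. The wav paths, the cache directory and the language code are
# hypothetical placeholders; in practice the dict would come from a corpus preparation step.
if __name__ == "__main__":
    example_path_to_transcript = {
        "/data/speech/example_0001.wav": "Hello world.",
        "/data/speech/example_0002.wav": "This is a second example sentence.",
    }
    dataset = AlignerDataset(path_to_transcript_dict=example_path_to_transcript,
                             cache_dir="Corpora/example_aligner_cache",
                             lang="en",
                             loading_processes=2,  # keep this small for a toy corpus
                             cut_silences=False,
                             device="cpu")
    # each sample is (phone id sequence, text length, mel spectrogram, mel length, speaker embedding)
    tokens, text_len, mel, mel_len, speaker_embedding = dataset[0]
    print(tokens.shape, text_len, mel.shape, mel_len, speaker_embedding.shape)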