eriquesouza's picture
app v1
e831f85
raw history blame
No virus
11.1 kB
import os
import random
import warnings
import soundfile as sf
import torch
from numpy import trim_zeros
from speechbrain.pretrained import EncoderClassifier
from torch.multiprocessing import Manager
from torch.multiprocessing import Process
from torch.multiprocessing import set_start_method
from torch.utils.data import Dataset
from tqdm import tqdm
from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.AudioPreprocessor import AudioPreprocessor
class AlignerDataset(Dataset):
    """
    Dataset that prepares datapoints for aligner training and caches them on disk.

    Each stored datapoint is a list of four tensors:
    [text feature vectors, text length, spectrogram frames, spectrogram length].
    A speaker embedding is kept per datapoint in ``self.speaker_embeddings``
    at the same index. Everything is persisted in
    ``<cache_dir>/aligner_train_cache.pt`` as the tuple
    ``(datapoints, normalized waves, speaker embeddings)``.
    """

    def __init__(self,
                 path_to_transcript_dict,
                 cache_dir,
                 lang,
                 loading_processes=30,  # careful with the amount of processes if you use silence removal, only as many processes as you have cores
                 min_len_in_seconds=1,
                 max_len_in_seconds=20,
                 cut_silences=False,
                 rebuild_cache=False,
                 verbose=False,
                 device="cpu"):
        """
        Build the dataset cache, or load it if it already exists.

        Args:
            path_to_transcript_dict: mapping from audio file path to its transcript.
            cache_dir: directory holding ``aligner_train_cache.pt`` and ``files_used.txt``;
                created if missing.
            lang: language identifier handed to the text frontend.
            loading_processes: number of worker processes used to build the cache.
            min_len_in_seconds: audio shorter than this is excluded.
            max_len_in_seconds: audio longer than this is excluded.
            cut_silences: forwarded to the audio preprocessor; also triggers
                pre-loading of the silero-vad model.
            rebuild_cache: rebuild even if a cache file exists.
            verbose: print the reason whenever a file is excluded for its duration.
            device: device used for audio preprocessing and speaker embedding extraction.
        """
        os.makedirs(cache_dir, exist_ok=True)
        cache_path = os.path.join(cache_dir, "aligner_train_cache.pt")
        if not os.path.exists(cache_path) or rebuild_cache:
            if (device == "cuda" or device == torch.device("cuda")) and cut_silences:
                try:
                    set_start_method('spawn')  # in order to be able to make use of cuda in multiprocessing
                except RuntimeError:
                    # the start method has already been set, nothing to do
                    pass
            elif cut_silences:
                torch.set_num_threads(1)
            if cut_silences:
                torch.hub.load(repo_or_dir='snakers4/silero-vad',
                               model='silero_vad',
                               force_reload=False,
                               onnx=False,
                               verbose=False)  # download and cache for it to be loaded and used later
                # NOTE(review): presumably loading silero-vad switches autograd off
                # globally, hence it is switched back on here - confirm.
                torch.set_grad_enabled(True)
            resource_manager = Manager()
            self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
            key_list = list(self.path_to_transcript_dict.keys())
            with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
                files_used_note.write(str(key_list))
            random.shuffle(key_list)
            # build cache
            print("... building dataset cache ...")
            self.datapoints = resource_manager.list()
            # make processes
            key_splits = list()
            process_list = list()
            for i in range(loading_processes):
                key_splits.append(key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
            for key_split in key_splits:
                if not key_split:
                    # fewer files than processes: an empty split has no work to do
                    continue
                process_list.append(
                    Process(target=self.cache_builder_process,
                            args=(key_split,
                                  lang,
                                  min_len_in_seconds,
                                  max_len_in_seconds,
                                  cut_silences,
                                  verbose,
                                  device),
                            daemon=True))
                process_list[-1].start()
            for process in process_list:
                process.join()
            self.datapoints = list(self.datapoints)
            tensored_datapoints = list()
            # we had to turn all of the tensors to numpy arrays to avoid shared memory
            # issues. Now that the multi-processing is over, we can convert them back
            # to tensors to save on conversions in the future.
            print("Converting into convenient format...")
            norm_waves = list()
            for datapoint in tqdm(self.datapoints):
                tensored_datapoints.append([torch.Tensor(datapoint[0]),
                                            torch.LongTensor(datapoint[1]),
                                            torch.Tensor(datapoint[2]),
                                            torch.LongTensor(datapoint[3])])
                norm_waves.append(torch.Tensor(datapoint[-1]))
            self.datapoints = tensored_datapoints
            pop_indexes = list()
            for index, el in enumerate(self.datapoints):
                try:
                    # presumably 66 is the size of one text feature vector - TODO confirm
                    if len(el[0][0]) != 66:
                        pop_indexes.append(index)
                except TypeError:
                    pop_indexes.append(index)
            for pop_index in sorted(pop_indexes, reverse=True):
                print(f"There seems to be a problem in the transcriptions. Deleting datapoint {pop_index}.")
                self.datapoints.pop(pop_index)
                # the wave has to go as well, otherwise the speaker embeddings
                # computed below would no longer line up with the datapoints
                norm_waves.pop(pop_index)
            # add speaker embeddings
            self.speaker_embeddings = self._compute_speaker_embeddings(norm_waves, device)
            # save to cache
            torch.save((self.datapoints, norm_waves, self.speaker_embeddings), cache_path)
        else:
            # just load the datapoints from cache
            cached = torch.load(cache_path, map_location='cpu')
            if len(cached) == 2:
                # speaker embeddings are still missing, have to add them here
                self.datapoints = cached[0]
                wave_datapoints = cached[1]
                self.speaker_embeddings = self._compute_speaker_embeddings(wave_datapoints, device)
                torch.save((self.datapoints, wave_datapoints, self.speaker_embeddings), cache_path)
            else:
                self.speaker_embeddings = cached[2]
                self.datapoints = cached[0]
        self.tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=True)
        print(f"Prepared an Aligner dataset with {len(self.datapoints)} datapoints in {cache_dir}.")

    @staticmethod
    def _compute_speaker_embeddings(waves, device):
        """Return one ECAPA speaker embedding (moved to the CPU) per wave tensor in ``waves``."""
        speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                      run_opts={"device": str(device)},
                                                                      savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
        embeddings = list()
        with torch.no_grad():
            for wave in tqdm(waves):
                embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
        return embeddings

    def cache_builder_process(self,
                              path_list,
                              lang,
                              min_len,
                              max_len,
                              cut_silences,
                              verbose,
                              device):
        """
        Worker that turns the audio files in ``path_list`` into datapoints.

        Every accepted file contributes
        [text vectors, text length, spectrogram, spectrogram length, normalized wave]
        as numpy arrays, which are appended to the shared ``self.datapoints``.
        Files are skipped if the transcript is empty, the duration (before or
        after preprocessing) is outside ``[min_len, max_len]`` seconds, the
        preprocessing raises a ValueError, or the transcript cannot be
        converted into well-formed text vectors.
        """
        if not path_list:
            # nothing to do; also avoids an IndexError on path_list[0] below
            return
        process_internal_dataset_chunk = list()
        tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=False)
        # the sampling rate of the first file configures the preprocessor for the whole split
        _, sr = sf.read(path_list[0])
        ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=cut_silences, device=device)
        for path in tqdm(path_list):
            # look the transcript up once, every access goes through the manager proxy
            transcript = self.path_to_transcript_dict[path]
            if transcript.strip() == "":
                continue
            wave, sr = sf.read(path)
            dur_in_seconds = len(wave) / sr
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")  # otherwise we get tons of warnings about an RNN not being in contiguous chunks
                    norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave)
            except ValueError:
                continue
            # the duration is checked again, because preprocessing can change the length
            dur_in_seconds = len(norm_wave) / 16000
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            norm_wave = torch.tensor(trim_zeros(norm_wave.numpy()))
            # raw audio preprocessing is done
            try:
                cached_text = tf.string_to_tensor(transcript, handle_missing=False).squeeze(0).cpu().numpy()
            except KeyError:
                # the result is not needed; presumably this second pass only
                # reports the unknown symbols - TODO confirm
                tf.string_to_tensor(transcript, handle_missing=True)
                continue  # we skip sentences with unknown symbols
            try:
                if len(cached_text[0]) != 66:
                    print(f"There seems to be a problem with the following transcription: {transcript}")
                    continue
            except TypeError:
                print(f"There seems to be a problem with the following transcription: {transcript}")
                continue
            cached_text_len = torch.LongTensor([len(cached_text)]).numpy()
            cached_speech = ap.audio_to_mel_spec_tensor(audio=norm_wave, normalize=False, explicit_sampling_rate=16000).transpose(0, 1).cpu().numpy()
            cached_speech_len = torch.LongTensor([len(cached_speech)]).numpy()
            process_internal_dataset_chunk.append([cached_text,
                                                   cached_text_len,
                                                   cached_speech,
                                                   cached_speech_len,
                                                   norm_wave.cpu().detach().numpy()])
        self.datapoints += process_internal_dataset_chunk

    def __getitem__(self, index):
        """
        Return (token ids, text length, spectrogram, spectrogram length, speaker embedding)
        for the datapoint at ``index``.

        The token ids are recovered by comparing every stored text vector with
        the vectors in the text frontend's ``phone_to_vector`` table.
        """
        text_vector = self.datapoints[index][0]
        tokens = list()
        for vector in text_vector:
            # convert once per vector instead of once per phone comparison
            vector_as_list = vector.numpy().tolist()
            for phone, phone_vector in self.tf.phone_to_vector.items():
                if vector_as_list == phone_vector:
                    tokens.append(self.tf.phone_to_id[phone])
        # this is still a linear scan per vector, but it's good enough for testing for now.
        tokens = torch.LongTensor(tokens)
        return tokens, \
               self.datapoints[index][1], \
               self.datapoints[index][2], \
               self.datapoints[index][3], \
               self.speaker_embeddings[index]

    def __len__(self):
        """Return the number of datapoints."""
        return len(self.datapoints)