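# Page residue from the web file viewer this file was copied from
# (kept as a comment so the module parses):
# author: Flux9665 | commit 70399da: "update to the current version" | raw | history blame | No virus | 9.2 kB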
import os
import numpy
import soundfile as sf
import torch
from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
from Modules.Aligner.Aligner import Aligner
from Modules.ToucanTTS.DurationCalculator import DurationCalculator
from Modules.ToucanTTS.EnergyCalculator import EnergyCalculator
from Modules.ToucanTTS.PitchCalculator import Parselmouth
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.articulatory_features import get_feature_to_index_lookup
from Utility.storage_config import MODELS_DIR
from Utility.utils import float2pcm
class UtteranceCloner:
    """
    Clone the prosody of an utterance, but exchange the speaker (or don't).
    Useful for privacy applications.

    Phone durations come from a CTC aligner (optionally fine-tuned on the
    reference utterance for a few steps). Pitch and energy are computed from the
    reference wave given those durations. Everything is then passed to a
    ToucanTTS model that speaks the same text in the voice of a second reference.
    """

    def __init__(self, model_id, device, language="eng"):
        """
        Args:
            model_id: passed to ToucanTTSInterface as ``tts_model_path``.
            device: torch device used for the TTS model, text frontend and aligner.
            language: initial language code for the text frontend.
        """
        self.tts = ToucanTTSInterface(device=device, tts_model_path=model_id)
        # placeholder input rate; extract_prosody rebuilds the preprocessor for the real rate
        self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, cut_silence=False)
        self.tf = ArticulatoryCombinedTextFrontend(language=language, device=device)
        self.device = device
        acoustic_checkpoint_path = os.path.join(MODELS_DIR, "Aligner", "aligner.pt")
        # pristine aligner weights; reloaded before every prosody extraction
        self.aligner_weights = torch.load(acoustic_checkpoint_path, map_location=device)["asr_model"]
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
        # careful: assumes 16kHz or 8kHz audio
        self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                  model='silero_vad',
                                                  force_reload=False,
                                                  onnx=False,
                                                  verbose=False)
        (self.get_speech_timestamps, _, _, _, _) = utils
        torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
        # this to false globally during model loading rather than using inference_mode or no_grad
        self.acoustic_model = Aligner()
        self.acoustic_model = self.acoustic_model.to(self.device)
        self.acoustic_model.load_state_dict(self.aligner_weights)
        self.acoustic_model.eval()
        self.parsel = Parselmouth(reduction_factor=1, fs=16000)
        self.energy_calc = EnergyCalculator(reduction_factor=1, fs=16000)
        self.dc = DurationCalculator(reduction_factor=1)

    def extract_prosody(self, transcript, ref_audio_path, lang="eng", on_line_fine_tune=True):
        """
        Extract durations, pitch and energy from a reference recording.

        Args:
            transcript: text spoken in the reference audio; it has to match the audio exactly.
            ref_audio_path: path of the reference audio file (read with soundfile).
            lang: language code for the text frontend (the frontend is rebuilt if it differs).
            on_line_fine_tune: if True, adapt the aligner to this utterance with a few
                optimisation steps before extracting durations (slower, better alignments).

        Returns:
            Tuple ``(duration, pitch, energy, start_silence, end_silence)``. The two silences
            are the number of samples of the 16 kHz normalized wave that voice activity
            detection trimmed from the start and from the end.

        Raises:
            RuntimeError: if the wave cannot be normalized (e.g. it is too short).
        """
        # Always start from the pristine checkpoint, so that weights adapted to an earlier
        # utterance never leak into this one -- also when fine-tuning is disabled for this call.
        self.acoustic_model.load_state_dict(self.aligner_weights)
        self.acoustic_model.eval()
        wave, sr = sf.read(ref_audio_path)
        if self.tf.language != lang:
            self.tf = ArticulatoryCombinedTextFrontend(language=lang, device=self.device)
        if self.ap.input_sr != sr:
            self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=False)
        try:
            norm_wave = self.ap.normalize_audio(audio=wave)
        except ValueError as err:
            print('Something went wrong, the reference wave might be too short.')
            raise RuntimeError('Something went wrong, the reference wave might be too short.') from err

        # Trim leading/trailing non-speech found by silero VAD and remember how much was cut.
        with torch.inference_mode():
            speech_timestamps = self.get_speech_timestamps(norm_wave, self.silero_model, sampling_rate=16000)
        if len(speech_timestamps) == 0:
            speech_timestamps = [{'start': 0, 'end': len(norm_wave)}]
        start_silence = speech_timestamps[0]['start']
        end_silence = len(norm_wave) - speech_timestamps[-1]['end']
        norm_wave = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]

        norm_wave_length = torch.LongTensor([len(norm_wave)])
        text = self.tf.string_to_tensor(transcript, handle_missing=False).squeeze(0)
        features = self.ap.audio_to_mel_spec_tensor(audio=norm_wave, explicit_sampling_rate=16000).transpose(0, 1)
        feature_length = torch.LongTensor([len(features)]).numpy()

        if on_line_fine_tune:
            self._fine_tune_aligner(text, features)

        # We deal with the word boundaries by having 2 versions of text: with and without word boundaries.
        # We note the index of word boundaries and insert durations of 0 afterwards.
        word_boundary_feature = get_feature_to_index_lookup()["word-boundary"]
        text_without_word_boundaries = list()
        indexes_of_word_boundaries = list()
        for phoneme_index, vector in enumerate(text):
            if vector[word_boundary_feature] == 0:
                text_without_word_boundaries.append(vector.numpy().tolist())
            else:
                indexes_of_word_boundaries.append(phoneme_index)
        matrix_without_word_boundaries = torch.Tensor(text_without_word_boundaries)

        alignment_path = self.acoustic_model.inference(features=features.to(self.device),
                                                       tokens=matrix_without_word_boundaries.to(self.device),
                                                       return_ctc=False)
        duration = self.dc(torch.LongTensor(alignment_path), vis=None).cpu()
        duration = self._insert_word_boundary_durations(duration, indexes_of_word_boundaries)
        durations_lengths = torch.LongTensor([len(duration)])

        energy = self.energy_calc(input_waves=norm_wave.unsqueeze(0),
                                  input_waves_lengths=norm_wave_length,
                                  feats_lengths=feature_length,
                                  text=text,
                                  durations=duration.unsqueeze(0),
                                  durations_lengths=durations_lengths)[0].squeeze(0).cpu()
        pitch = self.parsel(input_waves=norm_wave.unsqueeze(0),
                            input_waves_lengths=norm_wave_length,
                            feats_lengths=feature_length,
                            text=text,
                            durations=duration.unsqueeze(0),
                            durations_lengths=durations_lengths)[0].squeeze(0).cpu()
        return duration, pitch, energy, start_silence, end_silence

    def _fine_tune_aligner(self, text, features, steps=4):
        """Adapt the aligner to one utterance with a few CTC-loss Adam steps; leaves it in eval mode."""
        # This makes cloning pretty slow, but the resulting alignments are greatly improved.
        # Training needs an ID sequence rather than a sequence of phonological feature vectors.
        tokens = self.tf.text_vectors_to_id_sequence(text_vector=text)
        # atleast_1d keeps a single-token transcript 1-D; a bare squeeze would make it 0-D and break len()
        tokens = torch.atleast_1d(torch.LongTensor(tokens).squeeze()).to(self.device)
        tokens_len = torch.LongTensor([len(tokens)]).to(self.device)
        mel = features.unsqueeze(0).to(self.device)
        mel_len = torch.LongTensor([len(mel[0])]).to(self.device)
        optim_asr = torch.optim.Adam(self.acoustic_model.parameters(), lr=0.00001)
        self.acoustic_model.train()
        for _ in range(steps):
            pred = self.acoustic_model(mel.clone())
            loss = self.acoustic_model.ctc_loss(pred.transpose(0, 1).log_softmax(2), tokens, mel_len, tokens_len)
            print(loss.item())
            optim_asr.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.acoustic_model.parameters(), 1.0)
            optim_asr.step()
        self.acoustic_model.eval()

    @staticmethod
    def _insert_word_boundary_durations(duration, indexes_of_word_boundaries):
        """Re-insert a duration of 0 at each word-boundary position removed before alignment."""
        # Indexes refer to the original text (with boundaries) and are visited in ascending
        # order, so each insertion shifts the later positions back onto the original indexing.
        for index_of_word_boundary in indexes_of_word_boundaries:
            duration = torch.cat([duration[:index_of_word_boundary],
                                  torch.LongTensor([0]),
                                  duration[index_of_word_boundary:]])
        return duration

    def clone_utterance(self,
                        path_to_reference_audio_for_intonation,
                        path_to_reference_audio_for_voice,
                        transcription_of_intonation_reference,
                        filename_of_result=None,
                        lang="eng"):
        """
        Speak the intonation reference's text with its prosody in the voice of the voice reference.

        What is said in path_to_reference_audio_for_intonation has to match the text in the
        reference transcription exactly!

        Args:
            path_to_reference_audio_for_intonation: recording whose prosody is copied.
            path_to_reference_audio_for_voice: recording whose voice is used.
            transcription_of_intonation_reference: exact transcript of the intonation reference.
            filename_of_result: if given, the result is also written there as 16-bit PCM.
            lang: language code used for prosody extraction and synthesis.

        Returns:
            ``(cloned_utt, sr)``: the synthesized wave padded with the reference's leading and
            trailing silence, and the sampling rate reported by the TTS.
        """
        self.tts.set_utterance_embedding(path_to_reference_audio=path_to_reference_audio_for_voice)
        duration, pitch, energy, silence_frames_start, silence_frames_end = self.extract_prosody(transcription_of_intonation_reference,
                                                                                                 path_to_reference_audio_for_intonation,
                                                                                                 lang=lang)
        self.tts.set_language(lang)
        cloned_speech, sr = self.tts(transcription_of_intonation_reference, view=False, durations=duration, pitch=pitch, energy=energy)
        # The silence lengths were measured in samples of the 16 kHz reference; rescale them to the
        # rate the TTS actually produced (1.5 for 24 kHz output) instead of assuming a fixed rate.
        upsampling_factor = sr / 16000
        start_sil = numpy.zeros([int(silence_frames_start * upsampling_factor)])
        end_sil = numpy.zeros([int(silence_frames_end * upsampling_factor)])
        cloned_utt = numpy.concatenate([start_sil, cloned_speech, end_sil], axis=0)
        if filename_of_result is not None:
            sf.write(file=filename_of_result, data=float2pcm(cloned_utt), samplerate=sr, subtype="PCM_16")
        return cloned_utt, sr