# SpeechCloning/run_utterance_cloner.py
# Author: Florian Lux
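"""
Demo of utterance-level cloning: the prosody (phone durations, pitch and energy) of a
reference recording is extracted with an aligner and imposed on a FastSpeech2 synthesis
of the same transcript, optionally together with the speaker identity of the reference.
"""
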
import os
import soundfile as sf
import torch
from torch.optim import SGD
from tqdm import tqdm
from InferenceInterfaces.Meta_FastSpeech2 import Meta_FastSpeech2
from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio
class UtteranceCloner:
def __init__(self, device):
self.tts = Meta_FastSpeech2(device=device)
self.device = device
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
# careful: assumes 16kHz or 8kHz audio
self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
onnx=False,
verbose=False)
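        # of the helper functions returned by torch.hub, only get_speech_timestamps is needed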
(self.get_speech_timestamps, _, _, _, _) = utils
        torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero globally
        # disables gradients during model loading, rather than using inference_mode or no_grad
self.silero_model = self.silero_model.to(self.device)
def extract_prosody(self, transcript, ref_audio_path, lang="de", on_line_fine_tune=False):
acoustic_model = Aligner()
acoustic_checkpoint_path = os.path.join("Models", "Aligner", "aligner.pt")
acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])
acoustic_model = acoustic_model.to(self.device)
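        # DIO serves as the pitch extractor, the EnergyCalculator derives per-frame energy from
        # the STFT, and the DurationCalculator turns the aligner's alignment path into phone durations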
dio = Dio(reduction_factor=1, fs=16000)
energy_calc = EnergyCalculator(reduction_factor=1, fs=16000)
dc = DurationCalculator(reduction_factor=1)
wave, sr = sf.read(ref_audio_path)
tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=False)
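        # resample the reference to the 16 kHz that the aligner, VAD, and extractors expect;
        # cut_silence is False because trimming is done explicitly with the VAD below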
ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False)
try:
norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave)
except ValueError:
            raise RuntimeError('Something went wrong, the reference wave might be too short.')
with torch.inference_mode():
speech_timestamps = self.get_speech_timestamps(norm_wave, self.silero_model, sampling_rate=16000)
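        # trim the reference to the span between the first detected speech onset and the last
        # detected speech offset, so leading and trailing silence do not distort the prosody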
norm_wave = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
norm_wave_length = torch.LongTensor([len(norm_wave)])
text = tf.string_to_tensor(transcript, handle_missing=False).squeeze(0)
melspec = ap.audio_to_mel_spec_tensor(audio=norm_wave, normalize=False, explicit_sampling_rate=16000).transpose(0, 1)
melspec_length = torch.LongTensor([len(melspec)]).numpy()
if on_line_fine_tune:
# we fine-tune the aligner for a couple steps using SGD. This makes cloning pretty slow, but the results are greatly improved.
steps = 10
tokens = list() # we need an ID sequence for training rather than a sequence of phonological features
            for vector in text:
                for phone in tf.phone_to_vector:
                    if vector.numpy().tolist() == tf.phone_to_vector[phone]:
                        tokens.append(tf.phone_to_id[phone])
                        break  # take the first matching phone to avoid duplicate IDs
tokens = torch.LongTensor(tokens)
tokens = tokens.squeeze().to(self.device)
tokens_len = torch.LongTensor([len(tokens)]).to(self.device)
mel = melspec.unsqueeze(0).to(self.device)
mel.requires_grad = True
mel_len = torch.LongTensor([len(mel[0])]).to(self.device)
# actual fine-tuning starts here
optim_asr = SGD(acoustic_model.parameters(), lr=0.1)
acoustic_model.train()
            for _ in tqdm(range(steps)):
pred = acoustic_model(mel)
loss = acoustic_model.ctc_loss(pred.transpose(0, 1).log_softmax(2), tokens, mel_len, tokens_len)
optim_asr.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(acoustic_model.parameters(), 1.0)
optim_asr.step()
acoustic_model.eval()
            torch.save({"asr_model": acoustic_model.state_dict()},
                       acoustic_checkpoint_path)  # note: this overwrites the aligner checkpoint on disk
alignment_path = acoustic_model.inference(mel=melspec.to(self.device),
tokens=text.to(self.device),
return_ctc=False)
duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
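        # with the phone durations fixed, pitch and energy are extracted from the reference wave;
        # passing the durations to the calculators suggests the values are averaged per phone,
        # yielding one pitch and one energy value per input token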
energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
input_waves_lengths=norm_wave_length,
feats_lengths=melspec_length,
durations=duration.unsqueeze(0),
durations_lengths=torch.LongTensor([len(duration)]))[0].squeeze(0).cpu()
pitch = dio(input_waves=norm_wave.unsqueeze(0),
input_waves_lengths=norm_wave_length,
feats_lengths=melspec_length,
durations=duration.unsqueeze(0),
durations_lengths=torch.LongTensor([len(duration)]))[0].squeeze(0).cpu()
return duration, pitch, energy, speech_timestamps[0]['start'], speech_timestamps[-1]['end']
def clone_utterance(self,
path_to_reference_audio,
reference_transcription,
clone_speaker_identity=True,
lang="en"):
if clone_speaker_identity:
self.tts.set_utterance_embedding(path_to_reference_audio=path_to_reference_audio)
duration, pitch, energy, silence_frames_start, silence_frames_end = self.extract_prosody(reference_transcription,
path_to_reference_audio,
lang=lang)
self.tts.set_language(lang)
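        # re-create the leading and trailing silence that the VAD trimmed from the reference,
        # so the cloned utterance keeps the timing of the original recording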
start_sil = torch.zeros([silence_frames_start]).to(self.device)
end_sil = torch.zeros([silence_frames_end]).to(self.device)
cloned_speech = self.tts(reference_transcription, view=False, durations=duration, pitch=pitch, energy=energy)
cloned_utt = torch.cat((start_sil, cloned_speech, end_sil), dim=0)
return cloned_utt.cpu()
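

# A minimal usage sketch (not part of the original script): the reference path, transcript,
# output filename, and the 24 kHz output rate below are assumptions that may need adjusting.
if __name__ == '__main__':
    cloner = UtteranceCloner(device="cuda" if torch.cuda.is_available() else "cpu")
    cloned = cloner.clone_utterance(path_to_reference_audio="audios/reference.wav",  # placeholder path
                                    reference_transcription="Hello world.",  # placeholder transcript
                                    clone_speaker_identity=True,
                                    lang="en")
    # detach in case the synthesis graph still tracks gradients, then write to disk
    sf.write("cloned_utterance.wav", cloned.detach().numpy(), samplerate=24000)  # assumed vocoder rate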