File size: 7,007 Bytes
2cb106d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os

import soundfile as sf
import torch
from torch.optim import SGD
from tqdm import tqdm

from InferenceInterfaces.Meta_FastSpeech2 import Meta_FastSpeech2
from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio


class UtteranceCloner:
    """
    Re-synthesises a reference utterance with the TTS model while copying the
    phone durations, pitch and energy (and optionally the speaker embedding)
    measured on the reference recording.
    """

    def __init__(self, device):
        """
        Load the TTS interface and the silero voice activity detector.

        Args:
            device: the torch device on which the TTS model, the VAD model and
                the aligner are placed.
        """
        self.tts = Meta_FastSpeech2(device=device)
        self.device = device
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
        # careful: assumes 16kHz or 8kHz audio
        self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                  model='silero_vad',
                                                  force_reload=False,
                                                  onnx=False,
                                                  verbose=False)
        (self.get_speech_timestamps, _, _, _, _) = utils
        torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
        # this to false globally during model loading rather than using inference mode or no_grad
        self.silero_model = self.silero_model.to(self.device)

    def _fine_tune_aligner(self, acoustic_model, text_frontend, text, melspec, steps=10):
        """
        Fine-tune the aligner on a single utterance for a few SGD steps with the CTC loss.

        Args:
            acoustic_model: the aligner to be adapted; it is modified in place
                and left in eval mode.
            text_frontend: the frontend whose phone_to_vector and phone_to_id
                lookups are used to turn feature vectors back into IDs.
            text: sequence of phonological feature vectors of the transcript.
            melspec: spectrogram of the trimmed reference audio (frames first).
            steps: number of optimisation steps.
        """
        # we need an ID sequence for training rather than a sequence of phonological features
        tokens = list()
        for vector in text:
            vector_as_list = vector.numpy().tolist()  # converted once per vector, not once per comparison
            for phone in text_frontend.phone_to_vector:
                if vector_as_list == text_frontend.phone_to_vector[phone]:
                    tokens.append(text_frontend.phone_to_id[phone])
        # no squeeze here: the tensor is already one-dimensional and squeezing a
        # single-token sequence would yield a 0-dim tensor on which len() fails.
        tokens = torch.LongTensor(tokens).to(self.device)
        tokens_len = torch.LongTensor([len(tokens)]).to(self.device)
        mel = melspec.unsqueeze(0).to(self.device)
        mel.requires_grad = True
        mel_len = torch.LongTensor([len(mel[0])]).to(self.device)
        # actual fine-tuning starts here
        optim_asr = SGD(acoustic_model.parameters(), lr=0.1)
        acoustic_model.train()
        for _ in tqdm(range(steps)):
            pred = acoustic_model(mel)
            loss = acoustic_model.ctc_loss(pred.transpose(0, 1).log_softmax(2), tokens, mel_len, tokens_len)
            optim_asr.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(acoustic_model.parameters(), 1.0)
            optim_asr.step()
        acoustic_model.eval()

    def extract_prosody(self, transcript, ref_audio_path, lang="de", on_line_fine_tune=False):
        """
        Measure durations, pitch and energy of a reference recording.

        The audio is resampled to 16kHz, leading and trailing non-speech is cut
        off using the voice activity detector, and the remaining audio is
        aligned to the transcript with the aligner loaded from
        Models/Aligner/aligner.pt.

        Args:
            transcript: the text that is spoken in the reference audio.
            ref_audio_path: path to the reference audio file.
            lang: language ID passed to the text frontend.
            on_line_fine_tune: if True, the aligner is fine-tuned on this
                utterance for a few steps before aligning, and the adapted
                weights are written back to the aligner checkpoint.

        Returns:
            A tuple (duration, pitch, energy, start_silence, end_silence).
            The first three are CPU tensors. start_silence and end_silence are
            the numbers of samples (at 16kHz) before the first and after the
            last detected speech segment in the normalized reference wave.

        Raises:
            RuntimeError: if the audio cannot be normalized (e.g. it is too
                short) or if no speech is detected in it.
        """
        acoustic_checkpoint_path = os.path.join("Models", "Aligner", "aligner.pt")
        acoustic_model = Aligner()
        acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])
        acoustic_model = acoustic_model.to(self.device)
        dio = Dio(reduction_factor=1, fs=16000)
        energy_calc = EnergyCalculator(reduction_factor=1, fs=16000)
        dc = DurationCalculator(reduction_factor=1)
        wave, sr = sf.read(ref_audio_path)
        tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=False)
        ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False)
        try:
            norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave)
        except ValueError as error:
            print('Something went wrong, the reference wave might be too short.')
            raise RuntimeError('Something went wrong, the reference wave might be too short.') from error

        with torch.inference_mode():
            speech_timestamps = self.get_speech_timestamps(norm_wave, self.silero_model, sampling_rate=16000)
        if not speech_timestamps:
            # without this check the indexing below fails with an uninformative IndexError
            raise RuntimeError('No speech was detected in the reference audio.')
        speech_start = speech_timestamps[0]['start']
        speech_end = speech_timestamps[-1]['end']
        # The silence lengths have to be taken before trimming. The trailing
        # silence is what remains after the end of speech, not the end index itself.
        start_silence = speech_start
        end_silence = max(len(norm_wave) - speech_end, 0)
        norm_wave = norm_wave[speech_start:speech_end]

        norm_wave_length = torch.LongTensor([len(norm_wave)])
        text = tf.string_to_tensor(transcript, handle_missing=False).squeeze(0)
        melspec = ap.audio_to_mel_spec_tensor(audio=norm_wave, normalize=False, explicit_sampling_rate=16000).transpose(0, 1)
        melspec_length = torch.LongTensor([len(melspec)]).numpy()

        if on_line_fine_tune:
            # we fine-tune the aligner for a couple steps using SGD. This makes cloning pretty slow, but the results are greatly improved.
            self._fine_tune_aligner(acoustic_model, tf, text, melspec, steps=10)
            # Only write the checkpoint when the weights actually changed. Saving
            # unconditionally would rewrite the file on every call with a dict that
            # contains nothing but the "asr_model" entry.
            # NOTE(review): this overwrites the checkpoint that is loaded above, so
            # repeated fine-tuning accumulates in the stored aligner - confirm this is intended.
            torch.save({"asr_model": acoustic_model.state_dict()}, acoustic_checkpoint_path)

        alignment_path = acoustic_model.inference(mel=melspec.to(self.device),
                                                  tokens=text.to(self.device),
                                                  return_ctc=False)

        duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
        duration_length = torch.LongTensor([len(duration)])
        energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
                             input_waves_lengths=norm_wave_length,
                             feats_lengths=melspec_length,
                             durations=duration.unsqueeze(0),
                             durations_lengths=duration_length)[0].squeeze(0).cpu()
        pitch = dio(input_waves=norm_wave.unsqueeze(0),
                    input_waves_lengths=norm_wave_length,
                    feats_lengths=melspec_length,
                    durations=duration.unsqueeze(0),
                    durations_lengths=duration_length)[0].squeeze(0).cpu()

        return duration, pitch, energy, start_silence, end_silence

    def clone_utterance(self,
                        path_to_reference_audio,
                        reference_transcription,
                        clone_speaker_identity=True,
                        lang="en"):
        """
        Synthesise the transcription with the prosody of the reference audio.

        Args:
            path_to_reference_audio: path to the audio file to be cloned.
            reference_transcription: the text spoken in the reference audio.
            clone_speaker_identity: if True, the utterance embedding of the TTS
                is set from the reference audio as well.
            lang: language ID used for prosody extraction and for synthesis.

        Returns:
            A one-dimensional CPU tensor holding the synthesised wave, padded
            with zeros in front and at the end according to the silence that
            was measured around the speech in the reference.
        """
        if clone_speaker_identity:
            self.tts.set_utterance_embedding(path_to_reference_audio=path_to_reference_audio)
        duration, pitch, energy, silence_frames_start, silence_frames_end = self.extract_prosody(reference_transcription,
                                                                                                 path_to_reference_audio,
                                                                                                 lang=lang)
        self.tts.set_language(lang)
        # NOTE(review): the silence lengths are counted in samples of the 16kHz
        # reference. If the TTS produces audio at a different sampling rate they
        # would have to be rescaled accordingly - confirm against the TTS interface.
        start_sil = torch.zeros([silence_frames_start]).to(self.device)
        end_sil = torch.zeros([silence_frames_end]).to(self.device)
        cloned_speech = self.tts(reference_transcription, view=False, durations=duration, pitch=pitch, energy=energy)
        cloned_utt = torch.cat((start_sil, cloned_speech, end_sil), dim=0)
        return cloned_utt.cpu()