Flux9665 committed
Commit 49572e7
1 parent: 046fea7

Delete run_utterance_cloner.py

Files changed (1):
  1. run_utterance_cloner.py +0 -121
run_utterance_cloner.py DELETED
@@ -1,121 +0,0 @@
- import os
-
- import soundfile as sf
- import torch
- from torch.optim import SGD
- from tqdm import tqdm
-
- from InferenceInterfaces.Meta_FastSpeech2 import Meta_FastSpeech2
- from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
- from Preprocessing.AudioPreprocessor import AudioPreprocessor
- from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
- from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator
- from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator
- from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio
-
-
- class UtteranceCloner:
-
-     def __init__(self, device):
-         self.tts = Meta_FastSpeech2(device=device)
-         self.device = device
-         torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
-         # careful: assumes 16kHz or 8kHz audio
-         self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
-                                                   model='silero_vad',
-                                                   force_reload=False,
-                                                   onnx=False,
-                                                   verbose=False)
-         (self.get_speech_timestamps, _, _, _, _) = utils
-         torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets this
-         # to False globally during model loading rather than using inference mode or no_grad
-         self.silero_model = self.silero_model.to(self.device)
-
-     def extract_prosody(self, transcript, ref_audio_path, lang="de", on_line_fine_tune=False):
-         # extracts duration, pitch, and energy curves from the reference audio, aligned to the transcript
-         acoustic_model = Aligner()
-         acoustic_checkpoint_path = os.path.join("Models", "Aligner", "aligner.pt")
-         acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])
-         acoustic_model = acoustic_model.to(self.device)
-         dio = Dio(reduction_factor=1, fs=16000)
-         energy_calc = EnergyCalculator(reduction_factor=1, fs=16000)
-         dc = DurationCalculator(reduction_factor=1)
-         wave, sr = sf.read(ref_audio_path)
-         tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=False)
-         ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False)
-         try:
-             norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave)
-         except ValueError:
-             raise RuntimeError('Something went wrong, the reference wave might be too short.')
-
-         with torch.inference_mode():
-             speech_timestamps = self.get_speech_timestamps(norm_wave, self.silero_model, sampling_rate=16000)
-         norm_wave = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
-
-         norm_wave_length = torch.LongTensor([len(norm_wave)])
-         text = tf.string_to_tensor(transcript, handle_missing=False).squeeze(0)
-         melspec = ap.audio_to_mel_spec_tensor(audio=norm_wave, normalize=False, explicit_sampling_rate=16000).transpose(0, 1)
-         melspec_length = torch.LongTensor([len(melspec)]).numpy()
-
-         if on_line_fine_tune:
-             # we fine-tune the aligner for a couple of steps using SGD. This makes cloning pretty slow, but the results are greatly improved.
-             steps = 10
-             tokens = list()  # we need an ID sequence for training rather than a sequence of phonological features
-             for vector in text:
-                 for phone in tf.phone_to_vector:
-                     if vector.numpy().tolist() == tf.phone_to_vector[phone]:
-                         tokens.append(tf.phone_to_id[phone])
-             tokens = torch.LongTensor(tokens).squeeze().to(self.device)
-             tokens_len = torch.LongTensor([len(tokens)]).to(self.device)
-             mel = melspec.unsqueeze(0).to(self.device)
-             mel.requires_grad = True
-             mel_len = torch.LongTensor([len(mel[0])]).to(self.device)
-             # actual fine-tuning starts here
-             optim_asr = SGD(acoustic_model.parameters(), lr=0.1)
-             acoustic_model.train()
-             for _ in tqdm(list(range(steps))):
-                 pred = acoustic_model(mel)
-                 loss = acoustic_model.ctc_loss(pred.transpose(0, 1).log_softmax(2), tokens, mel_len, tokens_len)
-                 optim_asr.zero_grad()
-                 loss.backward()
-                 torch.nn.utils.clip_grad_norm_(acoustic_model.parameters(), 1.0)
-                 optim_asr.step()
-             acoustic_model.eval()
-             torch.save({"asr_model": acoustic_model.state_dict()},
-                        os.path.join("Models", "Aligner", "aligner.pt"))
-
-         alignment_path = acoustic_model.inference(mel=melspec.to(self.device),
-                                                   tokens=text.to(self.device),
-                                                   return_ctc=False)
-
-         duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
-         energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
-                              input_waves_lengths=norm_wave_length,
-                              feats_lengths=melspec_length,
-                              durations=duration.unsqueeze(0),
-                              durations_lengths=torch.LongTensor([len(duration)]))[0].squeeze(0).cpu()
-         pitch = dio(input_waves=norm_wave.unsqueeze(0),
-                     input_waves_lengths=norm_wave_length,
-                     feats_lengths=melspec_length,
-                     durations=duration.unsqueeze(0),
-                     durations_lengths=torch.LongTensor([len(duration)]))[0].squeeze(0).cpu()
-
-         return duration, pitch, energy, speech_timestamps[0]['start'], speech_timestamps[-1]['end']
-
-     def clone_utterance(self,
-                         path_to_reference_audio,
-                         reference_transcription,
-                         clone_speaker_identity=True,
-                         lang="en"):
-         if clone_speaker_identity:
-             self.tts.set_utterance_embedding(path_to_reference_audio=path_to_reference_audio)
-         duration, pitch, energy, silence_frames_start, silence_frames_end = self.extract_prosody(reference_transcription,
-                                                                                                  path_to_reference_audio,
-                                                                                                  lang=lang)
-         self.tts.set_language(lang)
-         # pad the cloned speech with the same amount of leading and trailing silence as the reference
-         start_sil = torch.zeros([silence_frames_start]).to(self.device)
-         end_sil = torch.zeros([silence_frames_end]).to(self.device)
-         cloned_speech = self.tts(reference_transcription, view=False, durations=duration, pitch=pitch, energy=energy)
-         cloned_utt = torch.cat((start_sil, cloned_speech, end_sil), dim=0)
-         return cloned_utt.cpu()
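
For context, a minimal driver for the class removed in this commit might have looked like the sketch below. It is not part of the repository: the file paths, the transcript, the device selection, and the output sampling rate are all illustrative assumptions (the real output rate depends on the vocoder that Meta_FastSpeech2 wraps).

import soundfile as sf
import torch

from run_utterance_cloner import UtteranceCloner  # the module deleted in this commit

# assumption: check the vocoder configuration for the actual output rate
ASSUMED_OUTPUT_SR = 24000

device = "cuda" if torch.cuda.is_available() else "cpu"
cloner = UtteranceCloner(device=device)

# reproduce the durations, pitch, and energy of a reference recording,
# optionally also copying its speaker identity via an utterance embedding
cloned_wave = cloner.clone_utterance(path_to_reference_audio="audios/reference.wav",
                                     reference_transcription="Hello there, this is a test sentence.",
                                     clone_speaker_identity=True,
                                     lang="en")

# clone_utterance returns a 1-D waveform tensor on the CPU
sf.write("cloned.wav", cloned_wave.numpy(), ASSUMED_OUTPUT_SR)

Passing clone_speaker_identity=False would keep the model's current voice while still transferring the reference prosody, since only extract_prosody's duration, pitch, and energy outputs condition the synthesis in that case.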