TheComputerMan commited on
Commit
dcc9625
·
1 Parent(s): 77e60d7

Upload Meta_FastSpeech2.py

Browse files
Files changed (1) hide show
  1. Meta_FastSpeech2.py +81 -0
Meta_FastSpeech2.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import librosa.display as lbd
4
+ import matplotlib.pyplot as plt
5
+ import soundfile
6
+ import torch
7
+
8
+ from InferenceInterfaces.InferenceArchitectures.InferenceFastSpeech2 import FastSpeech2
9
+ from InferenceInterfaces.InferenceArchitectures.InferenceHiFiGAN import HiFiGANGenerator
10
+ from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
11
+ from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
12
+ from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor
13
+
14
+
15
+ class Meta_FastSpeech2(torch.nn.Module):
16
+
17
+ def __init__(self, device="cpu"):
18
+ super().__init__()
19
+ model_name = "Meta"
20
+ language = "en"
21
+ self.device = device
22
+ self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)
23
+ checkpoint = torch.load(os.path.join("Models", f"FastSpeech2_{model_name}", "best.pt"), map_location='cpu')
24
+ self.phone2mel = FastSpeech2(weights=checkpoint["model"]).to(torch.device(device))
25
+ self.mel2wav = HiFiGANGenerator(path_to_weights=os.path.join("Models", "HiFiGAN_combined", "best.pt")).to(torch.device(device))
26
+ self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
27
+ self.phone2mel.eval()
28
+ self.mel2wav.eval()
29
+ self.lang_id = get_language_id(language)
30
+ self.to(torch.device(device))
31
+
32
+ def set_utterance_embedding(self, path_to_reference_audio):
33
+ wave, sr = soundfile.read(path_to_reference_audio)
34
+ self.default_utterance_embedding = ProsodicConditionExtractor(sr=sr).extract_condition_from_reference_wave(wave).to(self.device)
35
+
36
+ def set_phonemizer_language(self, lang_id):
37
+ """
38
+ The id parameter actually refers to the shorthand. This has become ambiguous with the introduction of the actual language IDs
39
+ """
40
+ self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True, silent=False)
41
+
42
+ def set_accent_language(self, lang_id):
43
+ """
44
+ The id parameter actually refers to the shorthand. This has become ambiguous with the introduction of the actual language IDs
45
+ """
46
+ self.lang_id = get_language_id(lang_id).to(self.device)
47
+
48
+ def forward(self, text, view=False, durations=None, pitch=None, energy=None):
49
+ with torch.no_grad():
50
+ phones = self.text2phone.string_to_tensor(text, input_phonemes=True).to(torch.device(self.device))
51
+ mel, durations, pitch, energy = self.phone2mel(phones,
52
+ return_duration_pitch_energy=True,
53
+ utterance_embedding=self.default_utterance_embedding,
54
+ durations=durations,
55
+ pitch=pitch,
56
+ energy=energy,
57
+ lang_id=self.lang_id)
58
+ mel = mel.transpose(0, 1)
59
+ wave = self.mel2wav(mel)
60
+ if view:
61
+ from Utility.utils import cumsum_durations
62
+ fig, ax = plt.subplots(nrows=2, ncols=1)
63
+ ax[0].plot(wave.cpu().numpy())
64
+ lbd.specshow(mel.cpu().numpy(),
65
+ ax=ax[1],
66
+ sr=16000,
67
+ cmap='GnBu',
68
+ y_axis='mel',
69
+ x_axis=None,
70
+ hop_length=256)
71
+ ax[0].yaxis.set_visible(False)
72
+ ax[1].yaxis.set_visible(False)
73
+ duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
74
+ ax[1].set_xticks(duration_splits, minor=True)
75
+ ax[1].xaxis.grid(True, which='minor')
76
+ ax[1].set_xticks(label_positions, minor=False)
77
+ ax[1].set_xticklabels(self.text2phone.get_phone_string(text))
78
+ ax[0].set_title(text)
79
+ plt.subplots_adjust(left=0.05, bottom=0.1, right=0.95, top=.9, wspace=0.0, hspace=0.0)
80
+ plt.show()
81
+ return wave