Florian Lux commited on
Commit
2cb106d
1 Parent(s): f9463cb

implement the cloning demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +16 -0
  2. InferenceInterfaces/InferenceArchitectures/InferenceFastSpeech2.py +256 -0
  3. InferenceInterfaces/InferenceArchitectures/InferenceHiFiGAN.py +91 -0
  4. InferenceInterfaces/InferenceArchitectures/__init__.py +0 -0
  5. InferenceInterfaces/Meta_FastSpeech2.py +75 -0
  6. InferenceInterfaces/__init__.py +0 -0
  7. Layers/Attention.py +324 -0
  8. Layers/Conformer.py +144 -0
  9. Layers/Convolution.py +55 -0
  10. Layers/DurationPredictor.py +139 -0
  11. Layers/EncoderLayer.py +144 -0
  12. Layers/LayerNorm.py +36 -0
  13. Layers/LengthRegulator.py +62 -0
  14. Layers/MultiLayeredConv1d.py +87 -0
  15. Layers/MultiSequential.py +33 -0
  16. Layers/PositionalEncoding.py +166 -0
  17. Layers/PositionwiseFeedForward.py +26 -0
  18. Layers/PostNet.py +74 -0
  19. Layers/ResidualBlock.py +98 -0
  20. Layers/ResidualStack.py +51 -0
  21. Layers/STFT.py +118 -0
  22. Layers/Swish.py +18 -0
  23. Layers/VariancePredictor.py +65 -0
  24. Layers/__init__.py +0 -0
  25. Models/Aligner/__init__.py +0 -0
  26. Models/FastSpeech2_Meta/__init__.py +0 -0
  27. Models/HiFiGAN_combined/__init__.py +0 -0
  28. Preprocessing/ArticulatoryCombinedTextFrontend.py +323 -0
  29. Preprocessing/AudioPreprocessor.py +166 -0
  30. Preprocessing/ProsodicConditionExtractor.py +40 -0
  31. Preprocessing/__init__.py +0 -0
  32. Preprocessing/papercup_features.py +637 -0
  33. README.md +3 -3
  34. TrainingInterfaces/Text_to_Spectrogram/AutoAligner/Aligner.py +287 -0
  35. TrainingInterfaces/Text_to_Spectrogram/AutoAligner/AlignerDataset.py +211 -0
  36. TrainingInterfaces/Text_to_Spectrogram/AutoAligner/TinyTTS.py +36 -0
  37. TrainingInterfaces/Text_to_Spectrogram/AutoAligner/__init__.py +0 -0
  38. TrainingInterfaces/Text_to_Spectrogram/AutoAligner/autoaligner_train_loop.py +145 -0
  39. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/DurationCalculator.py +31 -0
  40. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/EnergyCalculator.py +86 -0
  41. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeech2.py +379 -0
  42. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeech2Loss.py +96 -0
  43. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeechDatasetLanguageID.py +217 -0
  44. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/PitchCalculator.py +121 -0
  45. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/__init__.py +0 -0
  46. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/fastspeech2_train_loop.py +201 -0
  47. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/fastspeech2_train_loop_ctc.py +191 -0
  48. TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/meta_train_loop.py +211 -0
  49. TrainingInterfaces/Text_to_Spectrogram/__init__.py +0 -0
  50. TrainingInterfaces/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .idea
2
+ *.pyc
3
+ *.png
4
+ *.pdf
5
+ *.pt
6
+ tensorboard_logs
7
+ Corpora
8
+ *_graph
9
+ *.out
10
+ *.wav
11
+ *.flac
12
+ audios/
13
+ *playground*
14
+ *.json
15
+ .tmp/
16
+ .vscode/
InferenceInterfaces/InferenceArchitectures/InferenceFastSpeech2.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+ import torch
4
+
5
+ from Layers.Conformer import Conformer
6
+ from Layers.DurationPredictor import DurationPredictor
7
+ from Layers.LengthRegulator import LengthRegulator
8
+ from Layers.PostNet import PostNet
9
+ from Layers.VariancePredictor import VariancePredictor
10
+ from Utility.utils import make_non_pad_mask
11
+ from Utility.utils import make_pad_mask
12
+
13
+
14
+ class FastSpeech2(torch.nn.Module, ABC):
15
+
16
+ def __init__(self, # network structure related
17
+ weights,
18
+ idim=66,
19
+ odim=80,
20
+ adim=384,
21
+ aheads=4,
22
+ elayers=6,
23
+ eunits=1536,
24
+ dlayers=6,
25
+ dunits=1536,
26
+ postnet_layers=5,
27
+ postnet_chans=256,
28
+ postnet_filts=5,
29
+ positionwise_conv_kernel_size=1,
30
+ use_scaled_pos_enc=True,
31
+ use_batch_norm=True,
32
+ encoder_normalize_before=True,
33
+ decoder_normalize_before=True,
34
+ encoder_concat_after=False,
35
+ decoder_concat_after=False,
36
+ reduction_factor=1,
37
+ # encoder / decoder
38
+ use_macaron_style_in_conformer=True,
39
+ use_cnn_in_conformer=True,
40
+ conformer_enc_kernel_size=7,
41
+ conformer_dec_kernel_size=31,
42
+ # duration predictor
43
+ duration_predictor_layers=2,
44
+ duration_predictor_chans=256,
45
+ duration_predictor_kernel_size=3,
46
+ # energy predictor
47
+ energy_predictor_layers=2,
48
+ energy_predictor_chans=256,
49
+ energy_predictor_kernel_size=3,
50
+ energy_predictor_dropout=0.5,
51
+ energy_embed_kernel_size=1,
52
+ energy_embed_dropout=0.0,
53
+ stop_gradient_from_energy_predictor=True,
54
+ # pitch predictor
55
+ pitch_predictor_layers=5,
56
+ pitch_predictor_chans=256,
57
+ pitch_predictor_kernel_size=5,
58
+ pitch_predictor_dropout=0.5,
59
+ pitch_embed_kernel_size=1,
60
+ pitch_embed_dropout=0.0,
61
+ stop_gradient_from_pitch_predictor=True,
62
+ # training related
63
+ transformer_enc_dropout_rate=0.2,
64
+ transformer_enc_positional_dropout_rate=0.2,
65
+ transformer_enc_attn_dropout_rate=0.2,
66
+ transformer_dec_dropout_rate=0.2,
67
+ transformer_dec_positional_dropout_rate=0.2,
68
+ transformer_dec_attn_dropout_rate=0.2,
69
+ duration_predictor_dropout_rate=0.2,
70
+ postnet_dropout_rate=0.5,
71
+ # additional features
72
+ utt_embed_dim=704,
73
+ connect_utt_emb_at_encoder_out=True,
74
+ lang_embs=100):
75
+ super().__init__()
76
+ self.idim = idim
77
+ self.odim = odim
78
+ self.reduction_factor = reduction_factor
79
+ self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
80
+ self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
81
+ self.use_scaled_pos_enc = use_scaled_pos_enc
82
+ embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
83
+ torch.nn.Tanh(),
84
+ torch.nn.Linear(100, adim))
85
+ self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
86
+ input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
87
+ positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
88
+ normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
89
+ positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
90
+ use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
91
+ utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)
92
+ self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers,
93
+ n_chans=duration_predictor_chans,
94
+ kernel_size=duration_predictor_kernel_size,
95
+ dropout_rate=duration_predictor_dropout_rate, )
96
+ self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers,
97
+ n_chans=pitch_predictor_chans,
98
+ kernel_size=pitch_predictor_kernel_size,
99
+ dropout_rate=pitch_predictor_dropout)
100
+ self.pitch_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
101
+ kernel_size=pitch_embed_kernel_size,
102
+ padding=(pitch_embed_kernel_size - 1) // 2),
103
+ torch.nn.Dropout(pitch_embed_dropout))
104
+ self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers,
105
+ n_chans=energy_predictor_chans,
106
+ kernel_size=energy_predictor_kernel_size,
107
+ dropout_rate=energy_predictor_dropout)
108
+ self.energy_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
109
+ kernel_size=energy_embed_kernel_size,
110
+ padding=(energy_embed_kernel_size - 1) // 2),
111
+ torch.nn.Dropout(energy_embed_dropout))
112
+ self.length_regulator = LengthRegulator()
113
+ self.decoder = Conformer(idim=0,
114
+ attention_dim=adim,
115
+ attention_heads=aheads,
116
+ linear_units=dunits,
117
+ num_blocks=dlayers,
118
+ input_layer=None,
119
+ dropout_rate=transformer_dec_dropout_rate,
120
+ positional_dropout_rate=transformer_dec_positional_dropout_rate,
121
+ attention_dropout_rate=transformer_dec_attn_dropout_rate,
122
+ normalize_before=decoder_normalize_before,
123
+ concat_after=decoder_concat_after,
124
+ positionwise_conv_kernel_size=positionwise_conv_kernel_size,
125
+ macaron_style=use_macaron_style_in_conformer,
126
+ use_cnn_module=use_cnn_in_conformer,
127
+ cnn_module_kernel=conformer_dec_kernel_size)
128
+ self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
129
+ self.postnet = PostNet(idim=idim,
130
+ odim=odim,
131
+ n_layers=postnet_layers,
132
+ n_chans=postnet_chans,
133
+ n_filts=postnet_filts,
134
+ use_batch_norm=use_batch_norm,
135
+ dropout_rate=postnet_dropout_rate)
136
+ self.load_state_dict(weights)
137
+
138
+ def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
139
+ gold_durations=None, gold_pitch=None, gold_energy=None,
140
+ is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
141
+ # forward encoder
142
+ text_masks = self._source_mask(text_lens)
143
+
144
+ encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) # (B, Tmax, adim)
145
+
146
+ # forward duration predictor and variance predictors
147
+ duration_masks = make_pad_mask(text_lens, device=text_lens.device)
148
+
149
+ if self.stop_gradient_from_pitch_predictor:
150
+ pitch_predictions = self.pitch_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
151
+ else:
152
+ pitch_predictions = self.pitch_predictor(encoded_texts, duration_masks.unsqueeze(-1))
153
+
154
+ if self.stop_gradient_from_energy_predictor:
155
+ energy_predictions = self.energy_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
156
+ else:
157
+ energy_predictions = self.energy_predictor(encoded_texts, duration_masks.unsqueeze(-1))
158
+
159
+ if is_inference:
160
+ if gold_durations is not None:
161
+ duration_predictions = gold_durations
162
+ else:
163
+ duration_predictions = self.duration_predictor.inference(encoded_texts, duration_masks)
164
+ if gold_pitch is not None:
165
+ pitch_predictions = gold_pitch
166
+ if gold_energy is not None:
167
+ energy_predictions = gold_energy
168
+ pitch_embeddings = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
169
+ energy_embeddings = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
170
+ encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
171
+ encoded_texts = self.length_regulator(encoded_texts, duration_predictions, alpha)
172
+ else:
173
+ duration_predictions = self.duration_predictor(encoded_texts, duration_masks)
174
+
175
+ # use groundtruth in training
176
+ pitch_embeddings = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
177
+ energy_embeddings = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
178
+ encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
179
+ encoded_texts = self.length_regulator(encoded_texts, gold_durations) # (B, Lmax, adim)
180
+
181
+ # forward decoder
182
+ if speech_lens is not None and not is_inference:
183
+ if self.reduction_factor > 1:
184
+ olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
185
+ else:
186
+ olens_in = speech_lens
187
+ h_masks = self._source_mask(olens_in)
188
+ else:
189
+ h_masks = None
190
+ zs, _ = self.decoder(encoded_texts, h_masks) # (B, Lmax, adim)
191
+ before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim)
192
+
193
+ # postnet -> (B, Lmax//r * r, odim)
194
+ after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
195
+
196
+ return before_outs, after_outs, duration_predictions, pitch_predictions, energy_predictions
197
+
198
+ @torch.no_grad()
199
+ def forward(self,
200
+ text,
201
+ speech=None,
202
+ durations=None,
203
+ pitch=None,
204
+ energy=None,
205
+ utterance_embedding=None,
206
+ return_duration_pitch_energy=False,
207
+ lang_id=None):
208
+ """
209
+ Generate the sequence of features given the sequences of characters.
210
+
211
+ Args:
212
+ text: Input sequence of characters
213
+ speech: Feature sequence to extract style
214
+ durations: Groundtruth of duration
215
+ pitch: Groundtruth of token-averaged pitch
216
+ energy: Groundtruth of token-averaged energy
217
+ return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting
218
+ utterance_embedding: embedding of utterance wide parameters
219
+
220
+ Returns:
221
+ Mel Spectrogram
222
+
223
+ """
224
+ self.eval()
225
+ # setup batch axis
226
+ ilens = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
227
+ if speech is not None:
228
+ gold_speech = speech.unsqueeze(0)
229
+ else:
230
+ gold_speech = None
231
+ if durations is not None:
232
+ durations = durations.unsqueeze(0)
233
+ if pitch is not None:
234
+ pitch = pitch.unsqueeze(0)
235
+ if energy is not None:
236
+ energy = energy.unsqueeze(0)
237
+ if lang_id is not None:
238
+ lang_id = lang_id.unsqueeze(0)
239
+
240
+ before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(text.unsqueeze(0),
241
+ ilens,
242
+ gold_speech=gold_speech,
243
+ gold_durations=durations,
244
+ is_inference=True,
245
+ gold_pitch=pitch,
246
+ gold_energy=energy,
247
+ utterance_embedding=utterance_embedding.unsqueeze(0),
248
+ lang_ids=lang_id)
249
+ self.train()
250
+ if return_duration_pitch_energy:
251
+ return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
252
+ return after_outs[0]
253
+
254
+ def _source_mask(self, ilens):
255
+ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
256
+ return x_masks.unsqueeze(-2)
InferenceInterfaces/InferenceArchitectures/InferenceHiFiGAN.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from Layers.ResidualBlock import HiFiGANResidualBlock as ResidualBlock
4
+
5
+
6
+ class HiFiGANGenerator(torch.nn.Module):
7
+
8
+ def __init__(self,
9
+ path_to_weights,
10
+ in_channels=80,
11
+ out_channels=1,
12
+ channels=512,
13
+ kernel_size=7,
14
+ upsample_scales=(8, 6, 4, 4),
15
+ upsample_kernel_sizes=(16, 12, 8, 8),
16
+ resblock_kernel_sizes=(3, 7, 11),
17
+ resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
18
+ use_additional_convs=True,
19
+ bias=True,
20
+ nonlinear_activation="LeakyReLU",
21
+ nonlinear_activation_params={"negative_slope": 0.1},
22
+ use_weight_norm=True, ):
23
+ super().__init__()
24
+ assert kernel_size % 2 == 1, "Kernal size must be odd number."
25
+ assert len(upsample_scales) == len(upsample_kernel_sizes)
26
+ assert len(resblock_dilations) == len(resblock_kernel_sizes)
27
+ self.num_upsamples = len(upsample_kernel_sizes)
28
+ self.num_blocks = len(resblock_kernel_sizes)
29
+ self.input_conv = torch.nn.Conv1d(in_channels,
30
+ channels,
31
+ kernel_size,
32
+ 1,
33
+ padding=(kernel_size - 1) // 2, )
34
+ self.upsamples = torch.nn.ModuleList()
35
+ self.blocks = torch.nn.ModuleList()
36
+ for i in range(len(upsample_kernel_sizes)):
37
+ self.upsamples += [
38
+ torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
39
+ torch.nn.ConvTranspose1d(channels // (2 ** i),
40
+ channels // (2 ** (i + 1)),
41
+ upsample_kernel_sizes[i],
42
+ upsample_scales[i],
43
+ padding=(upsample_kernel_sizes[i] - upsample_scales[i]) // 2, ), )]
44
+ for j in range(len(resblock_kernel_sizes)):
45
+ self.blocks += [ResidualBlock(kernel_size=resblock_kernel_sizes[j],
46
+ channels=channels // (2 ** (i + 1)),
47
+ dilations=resblock_dilations[j],
48
+ bias=bias,
49
+ use_additional_convs=use_additional_convs,
50
+ nonlinear_activation=nonlinear_activation,
51
+ nonlinear_activation_params=nonlinear_activation_params, )]
52
+ self.output_conv = torch.nn.Sequential(
53
+ torch.nn.LeakyReLU(),
54
+ torch.nn.Conv1d(channels // (2 ** (i + 1)),
55
+ out_channels,
56
+ kernel_size,
57
+ 1,
58
+ padding=(kernel_size - 1) // 2, ),
59
+ torch.nn.Tanh(), )
60
+ if use_weight_norm:
61
+ self.apply_weight_norm()
62
+ self.load_state_dict(torch.load(path_to_weights, map_location='cpu')["generator"])
63
+
64
+ def forward(self, c, normalize_before=False):
65
+ if normalize_before:
66
+ c = (c - self.mean) / self.scale
67
+ c = self.input_conv(c.unsqueeze(0))
68
+ for i in range(self.num_upsamples):
69
+ c = self.upsamples[i](c)
70
+ cs = 0.0 # initialize
71
+ for j in range(self.num_blocks):
72
+ cs = cs + self.blocks[i * self.num_blocks + j](c)
73
+ c = cs / self.num_blocks
74
+ c = self.output_conv(c)
75
+ return c.squeeze(0).squeeze(0)
76
+
77
+ def remove_weight_norm(self):
78
+ def _remove_weight_norm(m):
79
+ try:
80
+ torch.nn.utils.remove_weight_norm(m)
81
+ except ValueError:
82
+ return
83
+
84
+ self.apply(_remove_weight_norm)
85
+
86
+ def apply_weight_norm(self):
87
+ def _apply_weight_norm(m):
88
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
89
+ torch.nn.utils.weight_norm(m)
90
+
91
+ self.apply(_apply_weight_norm)
InferenceInterfaces/InferenceArchitectures/__init__.py ADDED
File without changes
InferenceInterfaces/Meta_FastSpeech2.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import librosa.display as lbd
4
+ import matplotlib.pyplot as plt
5
+ import soundfile
6
+ import torch
7
+
8
+ from InferenceInterfaces.InferenceArchitectures.InferenceFastSpeech2 import FastSpeech2
9
+ from InferenceInterfaces.InferenceArchitectures.InferenceHiFiGAN import HiFiGANGenerator
10
+ from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
11
+ from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
12
+ from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor
13
+
14
+
15
+ class Meta_FastSpeech2(torch.nn.Module):
16
+
17
+ def __init__(self, device="cpu"):
18
+ super().__init__()
19
+ model_name = "Meta"
20
+ language = "en"
21
+ self.device = device
22
+ self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)
23
+ checkpoint = torch.load(os.path.join("Models", f"FastSpeech2_{model_name}", "best.pt"), map_location='cpu')
24
+ self.phone2mel = FastSpeech2(weights=checkpoint["model"]).to(torch.device(device))
25
+ self.mel2wav = HiFiGANGenerator(path_to_weights=os.path.join("Models", "HiFiGAN_combined", "best.pt")).to(torch.device(device))
26
+ self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
27
+ self.phone2mel.eval()
28
+ self.mel2wav.eval()
29
+ self.lang_id = get_language_id(language)
30
+ self.to(torch.device(device))
31
+
32
+ def set_utterance_embedding(self, path_to_reference_audio):
33
+ wave, sr = soundfile.read(path_to_reference_audio)
34
+ self.default_utterance_embedding = ProsodicConditionExtractor(sr=sr).extract_condition_from_reference_wave(wave).to(self.device)
35
+
36
+ def set_language(self, lang_id):
37
+ """
38
+ The id parameter actually refers to the shorthand. This has become ambiguous with the introduction of the actual language IDs
39
+ """
40
+ self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True, silent=True)
41
+ self.lang_id = get_language_id(lang_id).to(self.device)
42
+
43
+ def forward(self, text, view=False, durations=None, pitch=None, energy=None):
44
+ with torch.no_grad():
45
+ phones = self.text2phone.string_to_tensor(text).to(torch.device(self.device))
46
+ mel, durations, pitch, energy = self.phone2mel(phones,
47
+ return_duration_pitch_energy=True,
48
+ utterance_embedding=self.default_utterance_embedding,
49
+ durations=durations,
50
+ pitch=pitch,
51
+ energy=energy)
52
+ mel = mel.transpose(0, 1)
53
+ wave = self.mel2wav(mel)
54
+ if view:
55
+ from Utility.utils import cumsum_durations
56
+ fig, ax = plt.subplots(nrows=2, ncols=1)
57
+ ax[0].plot(wave.cpu().numpy())
58
+ lbd.specshow(mel.cpu().numpy(),
59
+ ax=ax[1],
60
+ sr=16000,
61
+ cmap='GnBu',
62
+ y_axis='mel',
63
+ x_axis=None,
64
+ hop_length=256)
65
+ ax[0].yaxis.set_visible(False)
66
+ ax[1].yaxis.set_visible(False)
67
+ duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
68
+ ax[1].set_xticks(duration_splits, minor=True)
69
+ ax[1].xaxis.grid(True, which='minor')
70
+ ax[1].set_xticks(label_positions, minor=False)
71
+ ax[1].set_xticklabels(self.text2phone.get_phone_string(text))
72
+ ax[0].set_title(text)
73
+ plt.subplots_adjust(left=0.05, bottom=0.1, right=0.95, top=.9, wspace=0.0, hspace=0.0)
74
+ plt.show()
75
+ return wave
InferenceInterfaces/__init__.py ADDED
File without changes
Layers/Attention.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+ """Multi-Head Attention layer definition."""
6
+
7
+ import math
8
+
9
+ import numpy
10
+ import torch
11
+ from torch import nn
12
+
13
+ from Utility.utils import make_non_pad_mask
14
+
15
+
16
+ class MultiHeadedAttention(nn.Module):
17
+ """
18
+ Multi-Head Attention layer.
19
+
20
+ Args:
21
+ n_head (int): The number of heads.
22
+ n_feat (int): The number of features.
23
+ dropout_rate (float): Dropout rate.
24
+ """
25
+
26
+ def __init__(self, n_head, n_feat, dropout_rate):
27
+ """
28
+ Construct an MultiHeadedAttention object.
29
+ """
30
+ super(MultiHeadedAttention, self).__init__()
31
+ assert n_feat % n_head == 0
32
+ # We assume d_v always equals d_k
33
+ self.d_k = n_feat // n_head
34
+ self.h = n_head
35
+ self.linear_q = nn.Linear(n_feat, n_feat)
36
+ self.linear_k = nn.Linear(n_feat, n_feat)
37
+ self.linear_v = nn.Linear(n_feat, n_feat)
38
+ self.linear_out = nn.Linear(n_feat, n_feat)
39
+ self.attn = None
40
+ self.dropout = nn.Dropout(p=dropout_rate)
41
+
42
+ def forward_qkv(self, query, key, value):
43
+ """
44
+ Transform query, key and value.
45
+
46
+ Args:
47
+ query (torch.Tensor): Query tensor (#batch, time1, size).
48
+ key (torch.Tensor): Key tensor (#batch, time2, size).
49
+ value (torch.Tensor): Value tensor (#batch, time2, size).
50
+
51
+ Returns:
52
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
53
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
54
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
55
+ """
56
+ n_batch = query.size(0)
57
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
58
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
59
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
60
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
61
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
62
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
63
+
64
+ return q, k, v
65
+
66
+ def forward_attention(self, value, scores, mask):
67
+ """
68
+ Compute attention context vector.
69
+
70
+ Args:
71
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
72
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
73
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
74
+
75
+ Returns:
76
+ torch.Tensor: Transformed value (#batch, time1, d_model)
77
+ weighted by the attention score (#batch, time1, time2).
78
+ """
79
+ n_batch = value.size(0)
80
+ if mask is not None:
81
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
82
+ min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
83
+ scores = scores.masked_fill(mask, min_value)
84
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
85
+ else:
86
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
87
+
88
+ p_attn = self.dropout(self.attn)
89
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
90
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)) # (batch, time1, d_model)
91
+
92
+ return self.linear_out(x) # (batch, time1, d_model)
93
+
94
+ def forward(self, query, key, value, mask):
95
+ """
96
+ Compute scaled dot product attention.
97
+
98
+ Args:
99
+ query (torch.Tensor): Query tensor (#batch, time1, size).
100
+ key (torch.Tensor): Key tensor (#batch, time2, size).
101
+ value (torch.Tensor): Value tensor (#batch, time2, size).
102
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
103
+ (#batch, time1, time2).
104
+
105
+ Returns:
106
+ torch.Tensor: Output tensor (#batch, time1, d_model).
107
+ """
108
+ q, k, v = self.forward_qkv(query, key, value)
109
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
110
+ return self.forward_attention(v, scores, mask)
111
+
112
+
113
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
114
+ """
115
+ Multi-Head Attention layer with relative position encoding.
116
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
117
+ Paper: https://arxiv.org/abs/1901.02860
118
+ Args:
119
+ n_head (int): The number of heads.
120
+ n_feat (int): The number of features.
121
+ dropout_rate (float): Dropout rate.
122
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
123
+ """
124
+
125
+ def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
126
+ """Construct an RelPositionMultiHeadedAttention object."""
127
+ super().__init__(n_head, n_feat, dropout_rate)
128
+ self.zero_triu = zero_triu
129
+ # linear transformation for positional encoding
130
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
131
+ # these two learnable bias are used in matrix c and matrix d
132
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
133
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
134
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
135
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
136
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
137
+
138
+ def rel_shift(self, x):
139
+ """
140
+ Compute relative positional encoding.
141
+ Args:
142
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
143
+ time1 means the length of query vector.
144
+ Returns:
145
+ torch.Tensor: Output tensor.
146
+ """
147
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
148
+ x_padded = torch.cat([zero_pad, x], dim=-1)
149
+
150
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
151
+ x = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1] # only keep the positions from 0 to time2
152
+
153
+ if self.zero_triu:
154
+ ones = torch.ones((x.size(2), x.size(3)), device=x.device)
155
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
156
+
157
+ return x
158
+
159
+ def forward(self, query, key, value, pos_emb, mask):
160
+ """
161
+ Compute 'Scaled Dot Product Attention' with rel. positional encoding.
162
+ Args:
163
+ query (torch.Tensor): Query tensor (#batch, time1, size).
164
+ key (torch.Tensor): Key tensor (#batch, time2, size).
165
+ value (torch.Tensor): Value tensor (#batch, time2, size).
166
+ pos_emb (torch.Tensor): Positional embedding tensor
167
+ (#batch, 2*time1-1, size).
168
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
169
+ (#batch, time1, time2).
170
+ Returns:
171
+ torch.Tensor: Output tensor (#batch, time1, d_model).
172
+ """
173
+ q, k, v = self.forward_qkv(query, key, value)
174
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
175
+
176
+ n_batch_pos = pos_emb.size(0)
177
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
178
+ p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
179
+
180
+ # (batch, head, time1, d_k)
181
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
182
+ # (batch, head, time1, d_k)
183
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
184
+
185
+ # compute attention score
186
+ # first compute matrix a and matrix c
187
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
188
+ # (batch, head, time1, time2)
189
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
190
+
191
+ # compute matrix b and matrix d
192
+ # (batch, head, time1, 2*time1-1)
193
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
194
+ matrix_bd = self.rel_shift(matrix_bd)
195
+
196
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
197
+
198
+ return self.forward_attention(v, scores, mask)
199
+
200
+
201
+ class GuidedAttentionLoss(torch.nn.Module):
202
+ """
203
+ Guided attention loss function module.
204
+
205
+ This module calculates the guided attention loss described
206
+ in `Efficiently Trainable Text-to-Speech System Based
207
+ on Deep Convolutional Networks with Guided Attention`_,
208
+ which forces the attention to be diagonal.
209
+
210
+ .. _`Efficiently Trainable Text-to-Speech System
211
+ Based on Deep Convolutional Networks with Guided Attention`:
212
+ https://arxiv.org/abs/1710.08969
213
+ """
214
+
215
+ def __init__(self, sigma=0.4, alpha=1.0):
216
+ """
217
+ Initialize guided attention loss module.
218
+
219
+ Args:
220
+ sigma (float, optional): Standard deviation to control
221
+ how close attention to a diagonal.
222
+ alpha (float, optional): Scaling coefficient (lambda).
223
+ reset_always (bool, optional): Whether to always reset masks.
224
+ """
225
+ super(GuidedAttentionLoss, self).__init__()
226
+ self.sigma = sigma
227
+ self.alpha = alpha
228
+ self.guided_attn_masks = None
229
+ self.masks = None
230
+
231
+ def _reset_masks(self):
232
+ self.guided_attn_masks = None
233
+ self.masks = None
234
+
235
+ def forward(self, att_ws, ilens, olens):
236
+ """
237
+ Calculate forward propagation.
238
+
239
+ Args:
240
+ att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
241
+ ilens (LongTensor): Batch of input lenghts (B,).
242
+ olens (LongTensor): Batch of output lenghts (B,).
243
+
244
+ Returns:
245
+ Tensor: Guided attention loss value.
246
+ """
247
+ self._reset_masks()
248
+ self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device)
249
+ self.masks = self._make_masks(ilens, olens).to(att_ws.device)
250
+ losses = self.guided_attn_masks * att_ws
251
+ loss = torch.mean(losses.masked_select(self.masks))
252
+ self._reset_masks()
253
+ return self.alpha * loss
254
+
255
+ def _make_guided_attention_masks(self, ilens, olens):
256
+ n_batches = len(ilens)
257
+ max_ilen = max(ilens)
258
+ max_olen = max(olens)
259
+ guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=ilens.device)
260
+ for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
261
+ guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma)
262
+ return guided_attn_masks
263
+
264
+ @staticmethod
265
+ def _make_guided_attention_mask(ilen, olen, sigma):
266
+ """
267
+ Make guided attention mask.
268
+ """
269
+ grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device).float(), torch.arange(ilen, device=ilen.device).float())
270
+ return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))
271
+
272
+ @staticmethod
273
+ def _make_masks(ilens, olens):
274
+ """
275
+ Make masks indicating non-padded part.
276
+
277
+ Args:
278
+ ilens (LongTensor or List): Batch of lengths (B,).
279
+ olens (LongTensor or List): Batch of lengths (B,).
280
+
281
+ Returns:
282
+ Tensor: Mask tensor indicating non-padded part.
283
+ dtype=torch.uint8 in PyTorch 1.2-
284
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
285
+ """
286
+ in_masks = make_non_pad_mask(ilens, device=ilens.device) # (B, T_in)
287
+ out_masks = make_non_pad_mask(olens, device=olens.device) # (B, T_out)
288
+ return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)
289
+
290
+
291
+ class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
292
+ """
293
+ Guided attention loss function module for multi head attention.
294
+
295
+ Args:
296
+ sigma (float, optional): Standard deviation to control
297
+ how close attention to a diagonal.
298
+ alpha (float, optional): Scaling coefficient (lambda).
299
+ reset_always (bool, optional): Whether to always reset masks.
300
+ """
301
+
302
+ def forward(self, att_ws, ilens, olens):
303
+ """
304
+ Calculate forward propagation.
305
+
306
+ Args:
307
+ att_ws (Tensor):
308
+ Batch of multi head attention weights (B, H, T_max_out, T_max_in).
309
+ ilens (LongTensor): Batch of input lenghts (B,).
310
+ olens (LongTensor): Batch of output lenghts (B,).
311
+
312
+ Returns:
313
+ Tensor: Guided attention loss value.
314
+ """
315
+ if self.guided_attn_masks is None:
316
+ self.guided_attn_masks = (self._make_guided_attention_masks(ilens, olens).to(att_ws.device).unsqueeze(1))
317
+ if self.masks is None:
318
+ self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
319
+ losses = self.guided_attn_masks * att_ws
320
+ loss = torch.mean(losses.masked_select(self.masks))
321
+ if self.reset_always:
322
+ self._reset_masks()
323
+
324
+ return self.alpha * loss
Layers/Conformer.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from Layers.Attention import RelPositionMultiHeadedAttention
9
+ from Layers.Convolution import ConvolutionModule
10
+ from Layers.EncoderLayer import EncoderLayer
11
+ from Layers.LayerNorm import LayerNorm
12
+ from Layers.MultiLayeredConv1d import MultiLayeredConv1d
13
+ from Layers.MultiSequential import repeat
14
+ from Layers.PositionalEncoding import RelPositionalEncoding
15
+ from Layers.Swish import Swish
16
+
17
+
18
+ class Conformer(torch.nn.Module):
19
+ """
20
+ Conformer encoder module.
21
+
22
+ Args:
23
+ idim (int): Input dimension.
24
+ attention_dim (int): Dimension of attention.
25
+ attention_heads (int): The number of heads of multi head attention.
26
+ linear_units (int): The number of units of position-wise feed forward.
27
+ num_blocks (int): The number of decoder blocks.
28
+ dropout_rate (float): Dropout rate.
29
+ positional_dropout_rate (float): Dropout rate after adding positional encoding.
30
+ attention_dropout_rate (float): Dropout rate in attention.
31
+ input_layer (Union[str, torch.nn.Module]): Input layer type.
32
+ normalize_before (bool): Whether to use layer_norm before the first block.
33
+ concat_after (bool): Whether to concat attention layer's input and output.
34
+ if True, additional linear will be applied.
35
+ i.e. x -> x + linear(concat(x, att(x)))
36
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
37
+ positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
38
+ positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
39
+ macaron_style (bool): Whether to use macaron style for positionwise layer.
40
+ pos_enc_layer_type (str): Conformer positional encoding layer type.
41
+ selfattention_layer_type (str): Conformer attention layer type.
42
+ activation_type (str): Conformer activation function type.
43
+ use_cnn_module (bool): Whether to use convolution module.
44
+ cnn_module_kernel (int): Kernerl size of convolution module.
45
+ padding_idx (int): Padding idx for input_layer=embed.
46
+
47
+ """
48
+
49
+ def __init__(self, idim, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1,
50
+ attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1,
51
+ macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, connect_utt_emb_at_encoder_out=True,
52
+ spk_emb_bottleneck_size=128, lang_embs=None):
53
+ super(Conformer, self).__init__()
54
+
55
+ activation = Swish()
56
+ self.conv_subsampling_factor = 1
57
+
58
+ if isinstance(input_layer, torch.nn.Module):
59
+ self.embed = input_layer
60
+ self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate)
61
+ elif input_layer is None:
62
+ self.embed = None
63
+ self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate))
64
+ else:
65
+ raise ValueError("unknown input_layer: " + input_layer)
66
+
67
+ self.normalize_before = normalize_before
68
+
69
+ self.connect_utt_emb_at_encoder_out = connect_utt_emb_at_encoder_out
70
+ if utt_embed is not None:
71
+ self.hs_emb_projection = torch.nn.Linear(attention_dim + spk_emb_bottleneck_size, attention_dim)
72
+ # embedding projection derived from https://arxiv.org/pdf/1705.08947.pdf
73
+ self.embedding_projection = torch.nn.Sequential(torch.nn.Linear(utt_embed, spk_emb_bottleneck_size),
74
+ torch.nn.Softsign())
75
+ if lang_embs is not None:
76
+ self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=attention_dim)
77
+
78
+ # self-attention module definition
79
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
80
+ encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)
81
+
82
+ # feed-forward module definition
83
+ positionwise_layer = MultiLayeredConv1d
84
+ positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,)
85
+
86
+ # convolution module definition
87
+ convolution_layer = ConvolutionModule
88
+ convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
89
+
90
+ self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args),
91
+ positionwise_layer(*positionwise_layer_args),
92
+ positionwise_layer(*positionwise_layer_args) if macaron_style else None,
93
+ convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate,
94
+ normalize_before, concat_after))
95
+ if self.normalize_before:
96
+ self.after_norm = LayerNorm(attention_dim)
97
+
98
+ def forward(self, xs, masks, utterance_embedding=None, lang_ids=None):
99
+ """
100
+ Encode input sequence.
101
+
102
+ Args:
103
+ utterance_embedding: embedding containing lots of conditioning signals
104
+ step: indicator for when to start updating the embedding function
105
+ xs (torch.Tensor): Input tensor (#batch, time, idim).
106
+ masks (torch.Tensor): Mask tensor (#batch, time).
107
+
108
+ Returns:
109
+ torch.Tensor: Output tensor (#batch, time, attention_dim).
110
+ torch.Tensor: Mask tensor (#batch, time).
111
+
112
+ """
113
+
114
+ if self.embed is not None:
115
+ xs = self.embed(xs)
116
+
117
+ if lang_ids is not None:
118
+ lang_embs = self.language_embedding(lang_ids)
119
+ xs = xs + lang_embs # offset the phoneme distribution of a language
120
+
121
+ if utterance_embedding is not None and not self.connect_utt_emb_at_encoder_out:
122
+ xs = self._integrate_with_utt_embed(xs, utterance_embedding)
123
+
124
+ xs = self.pos_enc(xs)
125
+
126
+ xs, masks = self.encoders(xs, masks)
127
+ if isinstance(xs, tuple):
128
+ xs = xs[0]
129
+
130
+ if self.normalize_before:
131
+ xs = self.after_norm(xs)
132
+
133
+ if utterance_embedding is not None and self.connect_utt_emb_at_encoder_out:
134
+ xs = self._integrate_with_utt_embed(xs, utterance_embedding)
135
+
136
+ return xs, masks
137
+
138
+ def _integrate_with_utt_embed(self, hs, utt_embeddings):
139
+ # project embedding into smaller space
140
+ speaker_embeddings_projected = self.embedding_projection(utt_embeddings)
141
+ # concat hidden states with spk embeds and then apply projection
142
+ speaker_embeddings_expanded = F.normalize(speaker_embeddings_projected).unsqueeze(1).expand(-1, hs.size(1), -1)
143
+ hs = self.hs_emb_projection(torch.cat([hs, speaker_embeddings_expanded], dim=-1))
144
+ return hs
Layers/Convolution.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # Northwestern Polytechnical University (Pengcheng Guo)
3
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4
+ # Adapted by Florian Lux 2021
5
+
6
+
7
+ from torch import nn
8
+
9
+
10
+ class ConvolutionModule(nn.Module):
11
+ """
12
+ ConvolutionModule in Conformer model.
13
+
14
+ Args:
15
+ channels (int): The number of channels of conv layers.
16
+ kernel_size (int): Kernel size of conv layers.
17
+
18
+ """
19
+
20
+ def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
21
+ super(ConvolutionModule, self).__init__()
22
+ # kernel_size should be an odd number for 'SAME' padding
23
+ assert (kernel_size - 1) % 2 == 0
24
+
25
+ self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, )
26
+ self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, )
27
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)
28
+ self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, )
29
+ self.activation = activation
30
+
31
+ def forward(self, x):
32
+ """
33
+ Compute convolution module.
34
+
35
+ Args:
36
+ x (torch.Tensor): Input tensor (#batch, time, channels).
37
+
38
+ Returns:
39
+ torch.Tensor: Output tensor (#batch, time, channels).
40
+
41
+ """
42
+ # exchange the temporal dimension and the feature dimension
43
+ x = x.transpose(1, 2)
44
+
45
+ # GLU mechanism
46
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
47
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
48
+
49
+ # 1D Depthwise Conv
50
+ x = self.depthwise_conv(x)
51
+ x = self.activation(self.norm(x))
52
+
53
+ x = self.pointwise_conv2(x)
54
+
55
+ return x.transpose(1, 2)
Layers/DurationPredictor.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+
6
+ import torch
7
+
8
+ from Layers.LayerNorm import LayerNorm
9
+
10
+
11
+ class DurationPredictor(torch.nn.Module):
12
+ """
13
+ Duration predictor module.
14
+
15
+ This is a module of duration predictor described
16
+ in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
17
+ The duration predictor predicts a duration of each frame in log domain
18
+ from the hidden embeddings of encoder.
19
+
20
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
21
+ https://arxiv.org/pdf/1905.09263.pdf
22
+
23
+ Note:
24
+ The calculation domain of outputs is different
25
+ between in `forward` and in `inference`. In `forward`,
26
+ the outputs are calculated in log domain but in `inference`,
27
+ those are calculated in linear domain.
28
+
29
+ """
30
+
31
+ def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0):
32
+ """
33
+ Initialize duration predictor module.
34
+
35
+ Args:
36
+ idim (int): Input dimension.
37
+ n_layers (int, optional): Number of convolutional layers.
38
+ n_chans (int, optional): Number of channels of convolutional layers.
39
+ kernel_size (int, optional): Kernel size of convolutional layers.
40
+ dropout_rate (float, optional): Dropout rate.
41
+ offset (float, optional): Offset value to avoid nan in log domain.
42
+
43
+ """
44
+ super(DurationPredictor, self).__init__()
45
+ self.offset = offset
46
+ self.conv = torch.nn.ModuleList()
47
+ for idx in range(n_layers):
48
+ in_chans = idim if idx == 0 else n_chans
49
+ self.conv += [torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ), torch.nn.ReLU(),
50
+ LayerNorm(n_chans, dim=1), torch.nn.Dropout(dropout_rate), )]
51
+ self.linear = torch.nn.Linear(n_chans, 1)
52
+
53
+ def _forward(self, xs, x_masks=None, is_inference=False):
54
+ xs = xs.transpose(1, -1) # (B, idim, Tmax)
55
+ for f in self.conv:
56
+ xs = f(xs) # (B, C, Tmax)
57
+
58
+ # NOTE: calculate in log domain
59
+ xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax)
60
+
61
+ if is_inference:
62
+ # NOTE: calculate in linear domain
63
+ xs = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value
64
+
65
+ if x_masks is not None:
66
+ xs = xs.masked_fill(x_masks, 0.0)
67
+
68
+ return xs
69
+
70
+ def forward(self, xs, x_masks=None):
71
+ """
72
+ Calculate forward propagation.
73
+
74
+ Args:
75
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
76
+ x_masks (ByteTensor, optional):
77
+ Batch of masks indicating padded part (B, Tmax).
78
+
79
+ Returns:
80
+ Tensor: Batch of predicted durations in log domain (B, Tmax).
81
+
82
+ """
83
+ return self._forward(xs, x_masks, False)
84
+
85
+ def inference(self, xs, x_masks=None):
86
+ """
87
+ Inference duration.
88
+
89
+ Args:
90
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
91
+ x_masks (ByteTensor, optional):
92
+ Batch of masks indicating padded part (B, Tmax).
93
+
94
+ Returns:
95
+ LongTensor: Batch of predicted durations in linear domain (B, Tmax).
96
+
97
+ """
98
+ return self._forward(xs, x_masks, True)
99
+
100
+
101
+ class DurationPredictorLoss(torch.nn.Module):
102
+ """
103
+ Loss function module for duration predictor.
104
+
105
+ The loss value is Calculated in log domain to make it Gaussian.
106
+
107
+ """
108
+
109
+ def __init__(self, offset=1.0, reduction="mean"):
110
+ """
111
+ Args:
112
+ offset (float, optional): Offset value to avoid nan in log domain.
113
+ reduction (str): Reduction type in loss calculation.
114
+
115
+ """
116
+ super(DurationPredictorLoss, self).__init__()
117
+ self.criterion = torch.nn.MSELoss(reduction=reduction)
118
+ self.offset = offset
119
+
120
+ def forward(self, outputs, targets):
121
+ """
122
+ Calculate forward propagation.
123
+
124
+ Args:
125
+ outputs (Tensor): Batch of prediction durations in log domain (B, T)
126
+ targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)
127
+
128
+ Returns:
129
+ Tensor: Mean squared error loss value.
130
+
131
+ Note:
132
+ `outputs` is in log domain but `targets` is in linear domain.
133
+
134
+ """
135
+ # NOTE: outputs is in log domain while targets in linear
136
+ targets = torch.log(targets.float() + self.offset)
137
+ loss = self.criterion(outputs, targets)
138
+
139
+ return loss
Layers/EncoderLayer.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # Northwestern Polytechnical University (Pengcheng Guo)
3
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4
+ # Adapted by Florian Lux 2021
5
+
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from Layers.LayerNorm import LayerNorm
11
+
12
+
13
+ class EncoderLayer(nn.Module):
14
+ """
15
+ Encoder layer module.
16
+
17
+ Args:
18
+ size (int): Input dimension.
19
+ self_attn (torch.nn.Module): Self-attention module instance.
20
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
21
+ can be used as the argument.
22
+ feed_forward (torch.nn.Module): Feed-forward module instance.
23
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
24
+ can be used as the argument.
25
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
26
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
27
+ can be used as the argument.
28
+ conv_module (torch.nn.Module): Convolution module instance.
29
+ `ConvlutionModule` instance can be used as the argument.
30
+ dropout_rate (float): Dropout rate.
31
+ normalize_before (bool): Whether to use layer_norm before the first block.
32
+ concat_after (bool): Whether to concat attention layer's input and output.
33
+ if True, additional linear will be applied.
34
+ i.e. x -> x + linear(concat(x, att(x)))
35
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
36
+
37
+ """
38
+
39
+ def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False, ):
40
+ super(EncoderLayer, self).__init__()
41
+ self.self_attn = self_attn
42
+ self.feed_forward = feed_forward
43
+ self.feed_forward_macaron = feed_forward_macaron
44
+ self.conv_module = conv_module
45
+ self.norm_ff = LayerNorm(size) # for the FNN module
46
+ self.norm_mha = LayerNorm(size) # for the MHA module
47
+ if feed_forward_macaron is not None:
48
+ self.norm_ff_macaron = LayerNorm(size)
49
+ self.ff_scale = 0.5
50
+ else:
51
+ self.ff_scale = 1.0
52
+ if self.conv_module is not None:
53
+ self.norm_conv = LayerNorm(size) # for the CNN module
54
+ self.norm_final = LayerNorm(size) # for the final output of the block
55
+ self.dropout = nn.Dropout(dropout_rate)
56
+ self.size = size
57
+ self.normalize_before = normalize_before
58
+ self.concat_after = concat_after
59
+ if self.concat_after:
60
+ self.concat_linear = nn.Linear(size + size, size)
61
+
62
+ def forward(self, x_input, mask, cache=None):
63
+ """
64
+ Compute encoded features.
65
+
66
+ Args:
67
+ x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
68
+ - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
69
+ - w/o pos emb: Tensor (#batch, time, size).
70
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
71
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
72
+
73
+ Returns:
74
+ torch.Tensor: Output tensor (#batch, time, size).
75
+ torch.Tensor: Mask tensor (#batch, time).
76
+
77
+ """
78
+ if isinstance(x_input, tuple):
79
+ x, pos_emb = x_input[0], x_input[1]
80
+ else:
81
+ x, pos_emb = x_input, None
82
+
83
+ # whether to use macaron style
84
+ if self.feed_forward_macaron is not None:
85
+ residual = x
86
+ if self.normalize_before:
87
+ x = self.norm_ff_macaron(x)
88
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
89
+ if not self.normalize_before:
90
+ x = self.norm_ff_macaron(x)
91
+
92
+ # multi-headed self-attention module
93
+ residual = x
94
+ if self.normalize_before:
95
+ x = self.norm_mha(x)
96
+
97
+ if cache is None:
98
+ x_q = x
99
+ else:
100
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
101
+ x_q = x[:, -1:, :]
102
+ residual = residual[:, -1:, :]
103
+ mask = None if mask is None else mask[:, -1:, :]
104
+
105
+ if pos_emb is not None:
106
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
107
+ else:
108
+ x_att = self.self_attn(x_q, x, x, mask)
109
+
110
+ if self.concat_after:
111
+ x_concat = torch.cat((x, x_att), dim=-1)
112
+ x = residual + self.concat_linear(x_concat)
113
+ else:
114
+ x = residual + self.dropout(x_att)
115
+ if not self.normalize_before:
116
+ x = self.norm_mha(x)
117
+
118
+ # convolution module
119
+ if self.conv_module is not None:
120
+ residual = x
121
+ if self.normalize_before:
122
+ x = self.norm_conv(x)
123
+ x = residual + self.dropout(self.conv_module(x))
124
+ if not self.normalize_before:
125
+ x = self.norm_conv(x)
126
+
127
+ # feed forward module
128
+ residual = x
129
+ if self.normalize_before:
130
+ x = self.norm_ff(x)
131
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
132
+ if not self.normalize_before:
133
+ x = self.norm_ff(x)
134
+
135
+ if self.conv_module is not None:
136
+ x = self.norm_final(x)
137
+
138
+ if cache is not None:
139
+ x = torch.cat([cache, x], dim=1)
140
+
141
+ if pos_emb is not None:
142
+ return (x, pos_emb), mask
143
+
144
+ return x, mask
Layers/LayerNorm.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+ import torch
6
+
7
+
8
+ class LayerNorm(torch.nn.LayerNorm):
9
+ """
10
+ Layer normalization module.
11
+
12
+ Args:
13
+ nout (int): Output dim size.
14
+ dim (int): Dimension to be normalized.
15
+ """
16
+
17
+ def __init__(self, nout, dim=-1):
18
+ """
19
+ Construct an LayerNorm object.
20
+ """
21
+ super(LayerNorm, self).__init__(nout, eps=1e-12)
22
+ self.dim = dim
23
+
24
+ def forward(self, x):
25
+ """
26
+ Apply layer normalization.
27
+
28
+ Args:
29
+ x (torch.Tensor): Input tensor.
30
+
31
+ Returns:
32
+ torch.Tensor: Normalized tensor.
33
+ """
34
+ if self.dim == -1:
35
+ return super(LayerNorm, self).forward(x)
36
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
Layers/LengthRegulator.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ from abc import ABC
6
+
7
+ import torch
8
+
9
+ from Utility.utils import pad_list
10
+
11
+
12
+ class LengthRegulator(torch.nn.Module, ABC):
13
+ """
14
+ Length regulator module for feed-forward Transformer.
15
+
16
+ This is a module of length regulator described in
17
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
18
+ The length regulator expands char or
19
+ phoneme-level embedding features to frame-level by repeating each
20
+ feature based on the corresponding predicted durations.
21
+
22
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
23
+ https://arxiv.org/pdf/1905.09263.pdf
24
+
25
+ """
26
+
27
+ def __init__(self, pad_value=0.0):
28
+ """
29
+ Initialize length regulator module.
30
+
31
+ Args:
32
+ pad_value (float, optional): Value used for padding.
33
+ """
34
+ super(LengthRegulator, self).__init__()
35
+ self.pad_value = pad_value
36
+
37
+ def forward(self, xs, ds, alpha=1.0):
38
+ """
39
+ Calculate forward propagation.
40
+
41
+ Args:
42
+ xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
43
+ ds (LongTensor): Batch of durations of each frame (B, T).
44
+ alpha (float, optional): Alpha value to control speed of speech.
45
+
46
+ Returns:
47
+ Tensor: replicated input tensor based on durations (B, T*, D).
48
+ """
49
+ if alpha != 1.0:
50
+ assert alpha > 0
51
+ ds = torch.round(ds.float() * alpha).long()
52
+
53
+ if ds.sum() == 0:
54
+ ds[ds.sum(dim=1).eq(0)] = 1
55
+
56
+ return pad_list([self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)], self.pad_value)
57
+
58
+ def _repeat_one_sequence(self, x, d):
59
+ """
60
+ Repeat each frame according to duration
61
+ """
62
+ return torch.repeat_interleave(x, d, dim=0)
Layers/MultiLayeredConv1d.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ """
6
+ Layer modules for FFT block in FastSpeech (Feed-forward Transformer).
7
+ """
8
+
9
+ import torch
10
+
11
+
12
+ class MultiLayeredConv1d(torch.nn.Module):
13
+ """
14
+ Multi-layered conv1d for Transformer block.
15
+
16
+ This is a module of multi-layered conv1d designed
17
+ to replace positionwise feed-forward network
18
+ in Transformer block, which is introduced in
19
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
20
+
21
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
22
+ https://arxiv.org/pdf/1905.09263.pdf
23
+ """
24
+
25
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
26
+ """
27
+ Initialize MultiLayeredConv1d module.
28
+
29
+ Args:
30
+ in_chans (int): Number of input channels.
31
+ hidden_chans (int): Number of hidden channels.
32
+ kernel_size (int): Kernel size of conv1d.
33
+ dropout_rate (float): Dropout rate.
34
+ """
35
+ super(MultiLayeredConv1d, self).__init__()
36
+ self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
37
+ self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
38
+ self.dropout = torch.nn.Dropout(dropout_rate)
39
+
40
+ def forward(self, x):
41
+ """
42
+ Calculate forward propagation.
43
+
44
+ Args:
45
+ x (torch.Tensor): Batch of input tensors (B, T, in_chans).
46
+
47
+ Returns:
48
+ torch.Tensor: Batch of output tensors (B, T, hidden_chans).
49
+ """
50
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
51
+ return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
52
+
53
+
54
+ class Conv1dLinear(torch.nn.Module):
55
+ """
56
+ Conv1D + Linear for Transformer block.
57
+
58
+ A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
59
+ """
60
+
61
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
62
+ """
63
+ Initialize Conv1dLinear module.
64
+
65
+ Args:
66
+ in_chans (int): Number of input channels.
67
+ hidden_chans (int): Number of hidden channels.
68
+ kernel_size (int): Kernel size of conv1d.
69
+ dropout_rate (float): Dropout rate.
70
+ """
71
+ super(Conv1dLinear, self).__init__()
72
+ self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
73
+ self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
74
+ self.dropout = torch.nn.Dropout(dropout_rate)
75
+
76
+ def forward(self, x):
77
+ """
78
+ Calculate forward propagation.
79
+
80
+ Args:
81
+ x (torch.Tensor): Batch of input tensors (B, T, in_chans).
82
+
83
+ Returns:
84
+ torch.Tensor: Batch of output tensors (B, T, hidden_chans).
85
+ """
86
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
87
+ return self.w_2(self.dropout(x))
Layers/MultiSequential.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+ import torch
6
+
7
+
8
+ class MultiSequential(torch.nn.Sequential):
9
+ """
10
+ Multi-input multi-output torch.nn.Sequential.
11
+ """
12
+
13
+ def forward(self, *args):
14
+ """
15
+ Repeat.
16
+ """
17
+ for m in self:
18
+ args = m(*args)
19
+ return args
20
+
21
+
22
+ def repeat(N, fn):
23
+ """
24
+ Repeat module N times.
25
+
26
+ Args:
27
+ N (int): Number of repeat time.
28
+ fn (Callable): Function to generate module.
29
+
30
+ Returns:
31
+ MultiSequential: Repeated model instance.
32
+ """
33
+ return MultiSequential(*[fn(n) for n in range(N)])
Layers/PositionalEncoding.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import math
6
+
7
+ import torch
8
+
9
+
10
+ class PositionalEncoding(torch.nn.Module):
11
+ """
12
+ Positional encoding.
13
+
14
+ Args:
15
+ d_model (int): Embedding dimension.
16
+ dropout_rate (float): Dropout rate.
17
+ max_len (int): Maximum input length.
18
+ reverse (bool): Whether to reverse the input position.
19
+ """
20
+
21
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
22
+ """
23
+ Construct an PositionalEncoding object.
24
+ """
25
+ super(PositionalEncoding, self).__init__()
26
+ self.d_model = d_model
27
+ self.reverse = reverse
28
+ self.xscale = math.sqrt(self.d_model)
29
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
30
+ self.pe = None
31
+ self.extend_pe(torch.tensor(0.0, device=d_model.device).expand(1, max_len))
32
+
33
+ def extend_pe(self, x):
34
+ """
35
+ Reset the positional encodings.
36
+ """
37
+ if self.pe is not None:
38
+ if self.pe.size(1) >= x.size(1):
39
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
40
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
41
+ return
42
+ pe = torch.zeros(x.size(1), self.d_model)
43
+ if self.reverse:
44
+ position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
45
+ else:
46
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
47
+ div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model))
48
+ pe[:, 0::2] = torch.sin(position * div_term)
49
+ pe[:, 1::2] = torch.cos(position * div_term)
50
+ pe = pe.unsqueeze(0)
51
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
52
+
53
+ def forward(self, x):
54
+ """
55
+ Add positional encoding.
56
+
57
+ Args:
58
+ x (torch.Tensor): Input tensor (batch, time, `*`).
59
+
60
+ Returns:
61
+ torch.Tensor: Encoded tensor (batch, time, `*`).
62
+ """
63
+ self.extend_pe(x)
64
+ x = x * self.xscale + self.pe[:, : x.size(1)]
65
+ return self.dropout(x)
66
+
67
+
68
+ class RelPositionalEncoding(torch.nn.Module):
69
+ """
70
+ Relative positional encoding module (new implementation).
71
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
72
+ See : Appendix B in https://arxiv.org/abs/1901.02860
73
+ Args:
74
+ d_model (int): Embedding dimension.
75
+ dropout_rate (float): Dropout rate.
76
+ max_len (int): Maximum input length.
77
+ """
78
+
79
+ def __init__(self, d_model, dropout_rate, max_len=5000):
80
+ """
81
+ Construct an PositionalEncoding object.
82
+ """
83
+ super(RelPositionalEncoding, self).__init__()
84
+ self.d_model = d_model
85
+ self.xscale = math.sqrt(self.d_model)
86
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
87
+ self.pe = None
88
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
89
+
90
+ def extend_pe(self, x):
91
+ """Reset the positional encodings."""
92
+ if self.pe is not None:
93
+ # self.pe contains both positive and negative parts
94
+ # the length of self.pe is 2 * input_len - 1
95
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
96
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
97
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
98
+ return
99
+ # Suppose `i` means to the position of query vecotr and `j` means the
100
+ # position of key vector. We use position relative positions when keys
101
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
102
+ pe_positive = torch.zeros(x.size(1), self.d_model, device=x.device)
103
+ pe_negative = torch.zeros(x.size(1), self.d_model, device=x.device)
104
+ position = torch.arange(0, x.size(1), dtype=torch.float32, device=x.device).unsqueeze(1)
105
+ div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32, device=x.device) * -(math.log(10000.0) / self.d_model))
106
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
107
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
108
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
109
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
110
+
111
+ # Reserve the order of positive indices and concat both positive and
112
+ # negative indices. This is used to support the shifting trick
113
+ # as in https://arxiv.org/abs/1901.02860
114
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
115
+ pe_negative = pe_negative[1:].unsqueeze(0)
116
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
117
+ self.pe = pe.to(dtype=x.dtype)
118
+
119
+ def forward(self, x):
120
+ """
121
+ Add positional encoding.
122
+ Args:
123
+ x (torch.Tensor): Input tensor (batch, time, `*`).
124
+ Returns:
125
+ torch.Tensor: Encoded tensor (batch, time, `*`).
126
+ """
127
+ self.extend_pe(x)
128
+ x = x * self.xscale
129
+ pos_emb = self.pe[:, self.pe.size(1) // 2 - x.size(1) + 1: self.pe.size(1) // 2 + x.size(1), ]
130
+ return self.dropout(x), self.dropout(pos_emb)
131
+
132
+
133
+ class ScaledPositionalEncoding(PositionalEncoding):
134
+ """
135
+ Scaled positional encoding module.
136
+
137
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
138
+
139
+ Args:
140
+ d_model (int): Embedding dimension.
141
+ dropout_rate (float): Dropout rate.
142
+ max_len (int): Maximum input length.
143
+
144
+ """
145
+
146
+ def __init__(self, d_model, dropout_rate, max_len=5000):
147
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
148
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
149
+
150
+ def reset_parameters(self):
151
+ self.alpha.data = torch.tensor(1.0)
152
+
153
+ def forward(self, x):
154
+ """
155
+ Add positional encoding.
156
+
157
+ Args:
158
+ x (torch.Tensor): Input tensor (batch, time, `*`).
159
+
160
+ Returns:
161
+ torch.Tensor: Encoded tensor (batch, time, `*`).
162
+
163
+ """
164
+ self.extend_pe(x)
165
+ x = x + self.alpha * self.pe[:, : x.size(1)]
166
+ return self.dropout(x)
Layers/PositionwiseFeedForward.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+
6
+ import torch
7
+
8
+
9
+ class PositionwiseFeedForward(torch.nn.Module):
10
+ """
11
+ Args:
12
+ idim (int): Input dimenstion.
13
+ hidden_units (int): The number of hidden units.
14
+ dropout_rate (float): Dropout rate.
15
+
16
+ """
17
+
18
+ def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
19
+ super(PositionwiseFeedForward, self).__init__()
20
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
21
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
22
+ self.dropout = torch.nn.Dropout(dropout_rate)
23
+ self.activation = activation
24
+
25
+ def forward(self, x):
26
+ return self.w_2(self.dropout(self.activation(self.w_1(x))))
Layers/PostNet.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import torch
6
+
7
+
8
+ class PostNet(torch.nn.Module):
9
+ """
10
+ From Tacotron2
11
+
12
+ Postnet module for Spectrogram prediction network.
13
+
14
+ This is a module of Postnet in Spectrogram prediction network,
15
+ which described in `Natural TTS Synthesis by
16
+ Conditioning WaveNet on Mel Spectrogram Predictions`_.
17
+ The Postnet refines the predicted
18
+ Mel-filterbank of the decoder,
19
+ which helps to compensate the detail sturcture of spectrogram.
20
+
21
+ .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
22
+ https://arxiv.org/abs/1712.05884
23
+ """
24
+
25
+ def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True):
26
+ """
27
+ Initialize postnet module.
28
+
29
+ Args:
30
+ idim (int): Dimension of the inputs.
31
+ odim (int): Dimension of the outputs.
32
+ n_layers (int, optional): The number of layers.
33
+ n_filts (int, optional): The number of filter size.
34
+ n_units (int, optional): The number of filter channels.
35
+ use_batch_norm (bool, optional): Whether to use batch normalization..
36
+ dropout_rate (float, optional): Dropout rate..
37
+ """
38
+ super(PostNet, self).__init__()
39
+ self.postnet = torch.nn.ModuleList()
40
+ for layer in range(n_layers - 1):
41
+ ichans = odim if layer == 0 else n_chans
42
+ ochans = odim if layer == n_layers - 1 else n_chans
43
+ if use_batch_norm:
44
+ self.postnet += [torch.nn.Sequential(torch.nn.Conv1d(ichans, ochans, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ),
45
+ torch.nn.GroupNorm(num_groups=32, num_channels=ochans), torch.nn.Tanh(),
46
+ torch.nn.Dropout(dropout_rate), )]
47
+
48
+ else:
49
+ self.postnet += [
50
+ torch.nn.Sequential(torch.nn.Conv1d(ichans, ochans, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ), torch.nn.Tanh(),
51
+ torch.nn.Dropout(dropout_rate), )]
52
+ ichans = n_chans if n_layers != 1 else odim
53
+ if use_batch_norm:
54
+ self.postnet += [torch.nn.Sequential(torch.nn.Conv1d(ichans, odim, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ),
55
+ torch.nn.GroupNorm(num_groups=20, num_channels=odim),
56
+ torch.nn.Dropout(dropout_rate), )]
57
+
58
+ else:
59
+ self.postnet += [torch.nn.Sequential(torch.nn.Conv1d(ichans, odim, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ),
60
+ torch.nn.Dropout(dropout_rate), )]
61
+
62
+ def forward(self, xs):
63
+ """
64
+ Calculate forward propagation.
65
+
66
+ Args:
67
+ xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax).
68
+
69
+ Returns:
70
+ Tensor: Batch of padded output tensor. (B, odim, Tmax).
71
+ """
72
+ for i in range(len(self.postnet)):
73
+ xs = self.postnet[i](xs)
74
+ return xs
Layers/ResidualBlock.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ References:
5
+ - https://github.com/jik876/hifi-gan
6
+ - https://github.com/kan-bayashi/ParallelWaveGAN
7
+ """
8
+
9
+ import torch
10
+
11
+
12
+ class Conv1d(torch.nn.Conv1d):
13
+ """
14
+ Conv1d module with customized initialization.
15
+ """
16
+
17
+ def __init__(self, *args, **kwargs):
18
+ super(Conv1d, self).__init__(*args, **kwargs)
19
+
20
+ def reset_parameters(self):
21
+ torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
22
+ if self.bias is not None:
23
+ torch.nn.init.constant_(self.bias, 0.0)
24
+
25
+
26
+ class Conv1d1x1(Conv1d):
27
+ """
28
+ 1x1 Conv1d with customized initialization.
29
+ """
30
+
31
+ def __init__(self, in_channels, out_channels, bias):
32
+ super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias)
33
+
34
+
35
+ class HiFiGANResidualBlock(torch.nn.Module):
36
+ """Residual block module in HiFiGAN."""
37
+
38
+ def __init__(self,
39
+ kernel_size=3,
40
+ channels=512,
41
+ dilations=(1, 3, 5),
42
+ bias=True,
43
+ use_additional_convs=True,
44
+ nonlinear_activation="LeakyReLU",
45
+ nonlinear_activation_params={"negative_slope": 0.1}, ):
46
+ """
47
+ Initialize HiFiGANResidualBlock module.
48
+
49
+ Args:
50
+ kernel_size (int): Kernel size of dilation convolution layer.
51
+ channels (int): Number of channels for convolution layer.
52
+ dilations (List[int]): List of dilation factors.
53
+ use_additional_convs (bool): Whether to use additional convolution layers.
54
+ bias (bool): Whether to add bias parameter in convolution layers.
55
+ nonlinear_activation (str): Activation function module name.
56
+ nonlinear_activation_params (dict): Hyperparameters for activation function.
57
+ """
58
+ super().__init__()
59
+ self.use_additional_convs = use_additional_convs
60
+ self.convs1 = torch.nn.ModuleList()
61
+ if use_additional_convs:
62
+ self.convs2 = torch.nn.ModuleList()
63
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
64
+ for dilation in dilations:
65
+ self.convs1 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
66
+ torch.nn.Conv1d(channels,
67
+ channels,
68
+ kernel_size,
69
+ 1,
70
+ dilation=dilation,
71
+ bias=bias,
72
+ padding=(kernel_size - 1) // 2 * dilation, ), )]
73
+ if use_additional_convs:
74
+ self.convs2 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
75
+ torch.nn.Conv1d(channels,
76
+ channels,
77
+ kernel_size,
78
+ 1,
79
+ dilation=1,
80
+ bias=bias,
81
+ padding=(kernel_size - 1) // 2, ), )]
82
+
83
+ def forward(self, x):
84
+ """
85
+ Calculate forward propagation.
86
+
87
+ Args:
88
+ x (Tensor): Input tensor (B, channels, T).
89
+
90
+ Returns:
91
+ Tensor: Output tensor (B, channels, T).
92
+ """
93
+ for idx in range(len(self.convs1)):
94
+ xt = self.convs1[idx](x)
95
+ if self.use_additional_convs:
96
+ xt = self.convs2[idx](xt)
97
+ x = xt + x
98
+ return x
Layers/ResidualStack.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+
6
+ import torch
7
+
8
+
9
+ class ResidualStack(torch.nn.Module):
10
+
11
+ def __init__(self, kernel_size=3, channels=32, dilation=1, bias=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2},
12
+ pad="ReflectionPad1d", pad_params={}, ):
13
+ """
14
+ Initialize ResidualStack module.
15
+
16
+ Args:
17
+ kernel_size (int): Kernel size of dilation convolution layer.
18
+ channels (int): Number of channels of convolution layers.
19
+ dilation (int): Dilation factor.
20
+ bias (bool): Whether to add bias parameter in convolution layers.
21
+ nonlinear_activation (str): Activation function module name.
22
+ nonlinear_activation_params (dict): Hyperparameters for activation function.
23
+ pad (str): Padding function module name before dilated convolution layer.
24
+ pad_params (dict): Hyperparameters for padding function.
25
+
26
+ """
27
+ super(ResidualStack, self).__init__()
28
+
29
+ # defile residual stack part
30
+ assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
31
+ self.stack = torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
32
+ getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
33
+ torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
34
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
35
+ torch.nn.Conv1d(channels, channels, 1, bias=bias), )
36
+
37
+ # defile extra layer for skip connection
38
+ self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)
39
+
40
+ def forward(self, c):
41
+ """
42
+ Calculate forward propagation.
43
+
44
+ Args:
45
+ c (Tensor): Input tensor (B, channels, T).
46
+
47
+ Returns:
48
+ Tensor: Output tensor (B, chennels, T).
49
+
50
+ """
51
+ return self.stack(c) + self.skip_layer(c)
Layers/STFT.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import torch
6
+ from torch.functional import stft as torch_stft
7
+ from torch_complex.tensor import ComplexTensor
8
+
9
+ from Utility.utils import make_pad_mask
10
+
11
+
12
+ class STFT(torch.nn.Module):
13
+
14
+ def __init__(self, n_fft=512, win_length=None, hop_length=128, window="hann", center=True, normalized=False,
15
+ onesided=True):
16
+ super().__init__()
17
+ self.n_fft = n_fft
18
+ if win_length is None:
19
+ self.win_length = n_fft
20
+ else:
21
+ self.win_length = win_length
22
+ self.hop_length = hop_length
23
+ self.center = center
24
+ self.normalized = normalized
25
+ self.onesided = onesided
26
+ self.window = window
27
+
28
+ def extra_repr(self):
29
+ return (f"n_fft={self.n_fft}, "
30
+ f"win_length={self.win_length}, "
31
+ f"hop_length={self.hop_length}, "
32
+ f"center={self.center}, "
33
+ f"normalized={self.normalized}, "
34
+ f"onesided={self.onesided}")
35
+
36
+ def forward(self, input_wave, ilens=None):
37
+ """
38
+ STFT forward function.
39
+ Args:
40
+ input_wave: (Batch, Nsamples) or (Batch, Nsample, Channels)
41
+ ilens: (Batch)
42
+ Returns:
43
+ output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
44
+ """
45
+ bs = input_wave.size(0)
46
+
47
+ if input_wave.dim() == 3:
48
+ multi_channel = True
49
+ # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
50
+ input_wave = input_wave.transpose(1, 2).reshape(-1, input_wave.size(1))
51
+ else:
52
+ multi_channel = False
53
+
54
+ # output: (Batch, Freq, Frames, 2=real_imag)
55
+ # or (Batch, Channel, Freq, Frames, 2=real_imag)
56
+ if self.window is not None:
57
+ window_func = getattr(torch, f"{self.window}_window")
58
+ window = window_func(self.win_length, dtype=input_wave.dtype, device=input_wave.device)
59
+ else:
60
+ window = None
61
+
62
+ complex_output = torch_stft(input=input_wave,
63
+ n_fft=self.n_fft,
64
+ win_length=self.win_length,
65
+ hop_length=self.hop_length,
66
+ center=self.center,
67
+ window=window,
68
+ normalized=self.normalized,
69
+ onesided=self.onesided,
70
+ return_complex=True)
71
+ output = torch.view_as_real(complex_output)
72
+ # output: (Batch, Freq, Frames, 2=real_imag)
73
+ # -> (Batch, Frames, Freq, 2=real_imag)
74
+ output = output.transpose(1, 2)
75
+ if multi_channel:
76
+ # output: (Batch * Channel, Frames, Freq, 2=real_imag)
77
+ # -> (Batch, Frame, Channel, Freq, 2=real_imag)
78
+ output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)
79
+
80
+ if ilens is not None:
81
+ if self.center:
82
+ pad = self.win_length // 2
83
+ ilens = ilens + 2 * pad
84
+
85
+ olens = torch.div((ilens - self.win_length), self.hop_length, rounding_mode="trunc") + 1
86
+ output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
87
+ else:
88
+ olens = None
89
+
90
+ return output, olens
91
+
92
+ def inverse(self, input, ilens=None):
93
+ """
94
+ Inverse STFT.
95
+ Args:
96
+ input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
97
+ ilens: (batch,)
98
+ Returns:
99
+ wavs: (batch, samples)
100
+ ilens: (batch,)
101
+ """
102
+ istft = torch.functional.istft
103
+
104
+ if self.window is not None:
105
+ window_func = getattr(torch, f"{self.window}_window")
106
+ window = window_func(self.win_length, dtype=input.dtype, device=input.device)
107
+ else:
108
+ window = None
109
+
110
+ if isinstance(input, ComplexTensor):
111
+ input = torch.stack([input.real, input.imag], dim=-1)
112
+ assert input.shape[-1] == 2
113
+ input = input.transpose(1, 2)
114
+
115
+ wavs = istft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=window, center=self.center,
116
+ normalized=self.normalized, onesided=self.onesided, length=ilens.max() if ilens is not None else ilens)
117
+
118
+ return wavs, ilens
Layers/Swish.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # Northwestern Polytechnical University (Pengcheng Guo)
3
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4
+ # Adapted by Florian Lux 2021
5
+
6
+ import torch
7
+
8
+
9
+ class Swish(torch.nn.Module):
10
+ """
11
+ Construct an Swish activation function for Conformer.
12
+ """
13
+
14
+ def forward(self, x):
15
+ """
16
+ Return Swish activation function.
17
+ """
18
+ return x * torch.sigmoid(x)
Layers/VariancePredictor.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ from abc import ABC
6
+
7
+ import torch
8
+
9
+ from Layers.LayerNorm import LayerNorm
10
+
11
+
12
+ class VariancePredictor(torch.nn.Module, ABC):
13
+ """
14
+ Variance predictor module.
15
+
16
+ This is a module of variance predictor described in `FastSpeech 2:
17
+ Fast and High-Quality End-to-End Text to Speech`_.
18
+
19
+ .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
20
+ https://arxiv.org/abs/2006.04558
21
+
22
+ """
23
+
24
+ def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, bias=True, dropout_rate=0.5, ):
25
+ """
26
+ Initilize duration predictor module.
27
+
28
+ Args:
29
+ idim (int): Input dimension.
30
+ n_layers (int, optional): Number of convolutional layers.
31
+ n_chans (int, optional): Number of channels of convolutional layers.
32
+ kernel_size (int, optional): Kernel size of convolutional layers.
33
+ dropout_rate (float, optional): Dropout rate.
34
+ """
35
+ super().__init__()
36
+ self.conv = torch.nn.ModuleList()
37
+ for idx in range(n_layers):
38
+ in_chans = idim if idx == 0 else n_chans
39
+ self.conv += [
40
+ torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias, ), torch.nn.ReLU(),
41
+ LayerNorm(n_chans, dim=1), torch.nn.Dropout(dropout_rate), )]
42
+ self.linear = torch.nn.Linear(n_chans, 1)
43
+
44
+ def forward(self, xs, x_masks=None):
45
+ """
46
+ Calculate forward propagation.
47
+
48
+ Args:
49
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
50
+ x_masks (ByteTensor, optional):
51
+ Batch of masks indicating padded part (B, Tmax).
52
+
53
+ Returns:
54
+ Tensor: Batch of predicted sequences (B, Tmax, 1).
55
+ """
56
+ xs = xs.transpose(1, -1) # (B, idim, Tmax)
57
+ for f in self.conv:
58
+ xs = f(xs) # (B, C, Tmax)
59
+
60
+ xs = self.linear(xs.transpose(1, 2)) # (B, Tmax, 1)
61
+
62
+ if x_masks is not None:
63
+ xs = xs.masked_fill(x_masks, 0.0)
64
+
65
+ return xs
Layers/__init__.py ADDED
File without changes
Models/Aligner/__init__.py ADDED
File without changes
Models/FastSpeech2_Meta/__init__.py ADDED
File without changes
Models/HiFiGAN_combined/__init__.py ADDED
File without changes
Preprocessing/ArticulatoryCombinedTextFrontend.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import sys
3
+
4
+ import panphon
5
+ import phonemizer
6
+ import torch
7
+
8
+ from Preprocessing.papercup_features import generate_feature_table
9
+
10
+
11
+ class ArticulatoryCombinedTextFrontend:
12
+
13
+ def __init__(self,
14
+ language,
15
+ use_word_boundaries=False, # goes together well with
16
+ # parallel models and an aligner. Doesn't go together
17
+ # well with autoregressive models.
18
+ use_explicit_eos=True,
19
+ use_prosody=False, # unfortunately the non-segmental
20
+ # nature of prosodic markers mixed with the sequential
21
+ # phonemes hurts the performance of end-to-end models a
22
+ # lot, even though one might think enriching the input
23
+ # with such information would help.
24
+ use_lexical_stress=False,
25
+ silent=True,
26
+ allow_unknown=False,
27
+ add_silence_to_end=True,
28
+ strip_silence=True):
29
+ """
30
+ Mostly preparing ID lookups
31
+ """
32
+ self.strip_silence = strip_silence
33
+ self.use_word_boundaries = use_word_boundaries
34
+ self.allow_unknown = allow_unknown
35
+ self.use_explicit_eos = use_explicit_eos
36
+ self.use_prosody = use_prosody
37
+ self.use_stress = use_lexical_stress
38
+ self.add_silence_to_end = add_silence_to_end
39
+ self.feature_table = panphon.FeatureTable()
40
+
41
+ if language == "en":
42
+ self.g2p_lang = "en-us"
43
+ self.expand_abbreviations = english_text_expansion
44
+ if not silent:
45
+ print("Created an English Text-Frontend")
46
+
47
+ elif language == "de":
48
+ self.g2p_lang = "de"
49
+ self.expand_abbreviations = lambda x: x
50
+ if not silent:
51
+ print("Created a German Text-Frontend")
52
+
53
+ elif language == "el":
54
+ self.g2p_lang = "el"
55
+ self.expand_abbreviations = lambda x: x
56
+ if not silent:
57
+ print("Created a Greek Text-Frontend")
58
+
59
+ elif language == "es":
60
+ self.g2p_lang = "es"
61
+ self.expand_abbreviations = lambda x: x
62
+ if not silent:
63
+ print("Created a Spanish Text-Frontend")
64
+
65
+ elif language == "fi":
66
+ self.g2p_lang = "fi"
67
+ self.expand_abbreviations = lambda x: x
68
+ if not silent:
69
+ print("Created a Finnish Text-Frontend")
70
+
71
+ elif language == "ru":
72
+ self.g2p_lang = "ru"
73
+ self.expand_abbreviations = lambda x: x
74
+ if not silent:
75
+ print("Created a Russian Text-Frontend")
76
+
77
+ elif language == "hu":
78
+ self.g2p_lang = "hu"
79
+ self.expand_abbreviations = lambda x: x
80
+ if not silent:
81
+ print("Created a Hungarian Text-Frontend")
82
+
83
+ elif language == "nl":
84
+ self.g2p_lang = "nl"
85
+ self.expand_abbreviations = lambda x: x
86
+ if not silent:
87
+ print("Created a Dutch Text-Frontend")
88
+
89
+ elif language == "fr":
90
+ self.g2p_lang = "fr-fr"
91
+ self.expand_abbreviations = lambda x: x
92
+ if not silent:
93
+ print("Created a French Text-Frontend")
94
+
95
+ elif language == "it":
96
+ self.g2p_lang = "it"
97
+ self.expand_abbreviations = lambda x: x
98
+ if not silent:
99
+ print("Created a Italian Text-Frontend")
100
+
101
+ elif language == "pt":
102
+ self.g2p_lang = "pt"
103
+ self.expand_abbreviations = lambda x: x
104
+ if not silent:
105
+ print("Created a Portuguese Text-Frontend")
106
+
107
+ elif language == "pl":
108
+ self.g2p_lang = "pl"
109
+ self.expand_abbreviations = lambda x: x
110
+ if not silent:
111
+ print("Created a Polish Text-Frontend")
112
+
113
+ # remember to also update get_language_id() when adding something here
114
+
115
+ else:
116
+ print("Language not supported yet")
117
+ sys.exit()
118
+
119
+ self.phone_to_vector_papercup = generate_feature_table()
120
+
121
+ self.phone_to_vector = dict()
122
+ for phone in self.phone_to_vector_papercup:
123
+ panphon_features = self.feature_table.word_to_vector_list(phone, numeric=True)
124
+ if panphon_features == []:
125
+ panphon_features = [[0] * 24]
126
+ papercup_features = self.phone_to_vector_papercup[phone]
127
+ self.phone_to_vector[phone] = papercup_features + panphon_features[0]
128
+
129
+ self.phone_to_id = { # this lookup must be updated manually, because the only
130
+ # other way would be extracting them from a set, which can be non-deterministic
131
+ '~': 0,
132
+ '#': 1,
133
+ '?': 2,
134
+ '!': 3,
135
+ '.': 4,
136
+ 'ɜ': 5,
137
+ 'ɫ': 6,
138
+ 'ə': 7,
139
+ 'ɚ': 8,
140
+ 'a': 9,
141
+ 'ð': 10,
142
+ 'ɛ': 11,
143
+ 'ɪ': 12,
144
+ 'ᵻ': 13,
145
+ 'ŋ': 14,
146
+ 'ɔ': 15,
147
+ 'ɒ': 16,
148
+ 'ɾ': 17,
149
+ 'ʃ': 18,
150
+ 'θ': 19,
151
+ 'ʊ': 20,
152
+ 'ʌ': 21,
153
+ 'ʒ': 22,
154
+ 'æ': 23,
155
+ 'b': 24,
156
+ 'ʔ': 25,
157
+ 'd': 26,
158
+ 'e': 27,
159
+ 'f': 28,
160
+ 'g': 29,
161
+ 'h': 30,
162
+ 'i': 31,
163
+ 'j': 32,
164
+ 'k': 33,
165
+ 'l': 34,
166
+ 'm': 35,
167
+ 'n': 36,
168
+ 'ɳ': 37,
169
+ 'o': 38,
170
+ 'p': 39,
171
+ 'ɡ': 40,
172
+ 'ɹ': 41,
173
+ 'r': 42,
174
+ 's': 43,
175
+ 't': 44,
176
+ 'u': 45,
177
+ 'v': 46,
178
+ 'w': 47,
179
+ 'x': 48,
180
+ 'z': 49,
181
+ 'ʀ': 50,
182
+ 'ø': 51,
183
+ 'ç': 52,
184
+ 'ɐ': 53,
185
+ 'œ': 54,
186
+ 'y': 55,
187
+ 'ʏ': 56,
188
+ 'ɑ': 57,
189
+ 'c': 58,
190
+ 'ɲ': 59,
191
+ 'ɣ': 60,
192
+ 'ʎ': 61,
193
+ 'β': 62,
194
+ 'ʝ': 63,
195
+ 'ɟ': 64,
196
+ 'q': 65,
197
+ 'ɕ': 66,
198
+ 'ʲ': 67,
199
+ 'ɭ': 68,
200
+ 'ɵ': 69,
201
+ 'ʑ': 70,
202
+ 'ʋ': 71,
203
+ 'ʁ': 72,
204
+ 'ɨ': 73,
205
+ 'ʂ': 74,
206
+ 'ɬ': 75,
207
+ } # for the states of the ctc loss and dijkstra/mas in the aligner
208
+
209
+ self.id_to_phone = {v: k for k, v in self.phone_to_id.items()}
210
+
211
+ def string_to_tensor(self, text, view=False, device="cpu", handle_missing=True, input_phonemes=False):
212
+ """
213
+ Fixes unicode errors, expands some abbreviations,
214
+ turns graphemes into phonemes and then vectorizes
215
+ the sequence as articulatory features
216
+ """
217
+ if input_phonemes:
218
+ phones = text
219
+ else:
220
+ phones = self.get_phone_string(text=text, include_eos_symbol=True)
221
+ if view:
222
+ print("Phonemes: \n{}\n".format(phones))
223
+ phones_vector = list()
224
+ # turn into numeric vectors
225
+ for char in phones:
226
+ if handle_missing:
227
+ try:
228
+ phones_vector.append(self.phone_to_vector[char])
229
+ except KeyError:
230
+ print("unknown phoneme: {}".format(char))
231
+ else:
232
+ phones_vector.append(self.phone_to_vector[char]) # leave error handling to elsewhere
233
+
234
+ return torch.Tensor(phones_vector, device=device)
235
+
236
+ def get_phone_string(self, text, include_eos_symbol=True):
237
+ # expand abbreviations
238
+ utt = self.expand_abbreviations(text)
239
+ # phonemize
240
+ phones = phonemizer.phonemize(utt,
241
+ language_switch='remove-flags',
242
+ backend="espeak",
243
+ language=self.g2p_lang,
244
+ preserve_punctuation=True,
245
+ strip=True,
246
+ punctuation_marks=';:,.!?¡¿—…"«»“”~/',
247
+ with_stress=self.use_stress).replace(";", ",").replace("/", " ").replace("—", "") \
248
+ .replace(":", ",").replace('"', ",").replace("-", ",").replace("...", ",").replace("-", ",").replace("\n", " ") \
249
+ .replace("\t", " ").replace("¡", "").replace("¿", "").replace(",", "~").replace(" ̃", "").replace('̩', "").replace("̃", "").replace("̪", "")
250
+ # less than 1 wide characters hidden here
251
+ phones = re.sub("~+", "~", phones)
252
+ if not self.use_prosody:
253
+ # retain ~ as heuristic pause marker, even though all other symbols are removed with this option.
254
+ # also retain . ? and ! since they can be indicators for the stop token
255
+ phones = phones.replace("ˌ", "").replace("ː", "").replace("ˑ", "") \
256
+ .replace("˘", "").replace("|", "").replace("‖", "")
257
+ if not self.use_word_boundaries:
258
+ phones = phones.replace(" ", "")
259
+ else:
260
+ phones = re.sub(r"\s+", " ", phones)
261
+ phones = re.sub(" ", "~", phones)
262
+ if self.strip_silence:
263
+ phones = phones.lstrip("~").rstrip("~")
264
+ if self.add_silence_to_end:
265
+ phones += "~" # adding a silence in the end during add_silence_to_end produces more natural sounding prosody
266
+ if include_eos_symbol:
267
+ phones += "#"
268
+
269
+ phones = "~" + phones
270
+ phones = re.sub("~+", "~", phones)
271
+
272
+ return phones
273
+
274
+
275
+ def english_text_expansion(text):
276
+ """
277
+ Apply as small part of the tacotron style text cleaning pipeline, suitable for e.g. LJSpeech.
278
+ See https://github.com/keithito/tacotron/
279
+ Careful: Only apply to english datasets. Different languages need different cleaners.
280
+ """
281
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in
282
+ [('Mrs.', 'misess'), ('Mr.', 'mister'), ('Dr.', 'doctor'), ('St.', 'saint'), ('Co.', 'company'), ('Jr.', 'junior'), ('Maj.', 'major'),
283
+ ('Gen.', 'general'), ('Drs.', 'doctors'), ('Rev.', 'reverend'), ('Lt.', 'lieutenant'), ('Hon.', 'honorable'), ('Sgt.', 'sergeant'),
284
+ ('Capt.', 'captain'), ('Esq.', 'esquire'), ('Ltd.', 'limited'), ('Col.', 'colonel'), ('Ft.', 'fort')]]
285
+ for regex, replacement in _abbreviations:
286
+ text = re.sub(regex, replacement, text)
287
+ return text
288
+
289
+
290
+ def get_language_id(language):
291
+ if language == "en":
292
+ return torch.LongTensor([0])
293
+ elif language == "de":
294
+ return torch.LongTensor([1])
295
+ elif language == "el":
296
+ return torch.LongTensor([2])
297
+ elif language == "es":
298
+ return torch.LongTensor([3])
299
+ elif language == "fi":
300
+ return torch.LongTensor([4])
301
+ elif language == "ru":
302
+ return torch.LongTensor([5])
303
+ elif language == "hu":
304
+ return torch.LongTensor([6])
305
+ elif language == "nl":
306
+ return torch.LongTensor([7])
307
+ elif language == "fr":
308
+ return torch.LongTensor([8])
309
+ elif language == "pt":
310
+ return torch.LongTensor([9])
311
+ elif language == "pl":
312
+ return torch.LongTensor([10])
313
+ elif language == "it":
314
+ return torch.LongTensor([11])
315
+
316
+
317
+ if __name__ == '__main__':
318
+ # test an English utterance
319
+ tfr_en = ArticulatoryCombinedTextFrontend(language="en")
320
+ print(tfr_en.string_to_tensor("This is a complex sentence, it even has a pause! But can it do this? Nice.", view=True))
321
+
322
+ tfr_en = ArticulatoryCombinedTextFrontend(language="de")
323
+ print(tfr_en.string_to_tensor("Alles klar, jetzt testen wir einen deutschen Satz. Ich hoffe es gibt nicht mehr viele unspezifizierte Phoneme.", view=True))
Preprocessing/AudioPreprocessor.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.core as lb
3
+ import librosa.display as lbd
4
+ import matplotlib.pyplot as plt
5
+ import numpy
6
+ import numpy as np
7
+ import pyloudnorm as pyln
8
+ import torch
9
+ from torchaudio.transforms import Resample
10
+
11
+
12
+ class AudioPreprocessor:
13
+
14
+ def __init__(self, input_sr, output_sr=None, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False, device="cpu"):
15
+ """
16
+ The parameters are by default set up to do well
17
+ on a 16kHz signal. A different sampling rate may
18
+ require different hop_length and n_fft (e.g.
19
+ doubling frequency --> doubling hop_length and
20
+ doubling n_fft)
21
+ """
22
+ self.cut_silence = cut_silence
23
+ self.device = device
24
+ self.sr = input_sr
25
+ self.new_sr = output_sr
26
+ self.hop_length = hop_length
27
+ self.n_fft = n_fft
28
+ self.mel_buckets = melspec_buckets
29
+ self.meter = pyln.Meter(input_sr)
30
+ self.final_sr = input_sr
31
+ if cut_silence:
32
+ torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
33
+ # careful: assumes 16kHz or 8kHz audio
34
+ self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
35
+ model='silero_vad',
36
+ force_reload=False,
37
+ onnx=False,
38
+ verbose=False)
39
+ (self.get_speech_timestamps,
40
+ self.save_audio,
41
+ self.read_audio,
42
+ self.VADIterator,
43
+ self.collect_chunks) = utils
44
+ self.silero_model = self.silero_model.to(self.device)
45
+ if output_sr is not None and output_sr != input_sr:
46
+ self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
47
+ self.final_sr = output_sr
48
+ else:
49
+ self.resample = lambda x: x
50
+
51
+ def cut_silence_from_audio(self, audio):
52
+ """
53
+ https://github.com/snakers4/silero-vad
54
+ """
55
+ return self.collect_chunks(self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr), audio)
56
+
57
+ def to_mono(self, x):
58
+ """
59
+ make sure we deal with a 1D array
60
+ """
61
+ if len(x.shape) == 2:
62
+ return lb.to_mono(numpy.transpose(x))
63
+ else:
64
+ return x
65
+
66
+ def normalize_loudness(self, audio):
67
+ """
68
+ normalize the amplitudes according to
69
+ their decibels, so this should turn any
70
+ signal with different magnitudes into
71
+ the same magnitude by analysing loudness
72
+ """
73
+ loudness = self.meter.integrated_loudness(audio)
74
+ loud_normed = pyln.normalize.loudness(audio, loudness, -30.0)
75
+ peak = numpy.amax(numpy.abs(loud_normed))
76
+ peak_normed = numpy.divide(loud_normed, peak)
77
+ return peak_normed
78
+
79
+ def logmelfilterbank(self, audio, sampling_rate, fmin=40, fmax=8000, eps=1e-10):
80
+ """
81
+ Compute log-Mel filterbank
82
+
83
+ one day this could be replaced by torchaudio's internal log10(melspec(audio)), but
84
+ for some reason it gives slightly different results, so in order not to break backwards
85
+ compatibility, this is kept for now. If there is ever a reason to completely re-train
86
+ all models, this would be a good opportunity to make the switch.
87
+ """
88
+ if isinstance(audio, torch.Tensor):
89
+ audio = audio.numpy()
90
+ # get amplitude spectrogram
91
+ x_stft = librosa.stft(audio, n_fft=self.n_fft, hop_length=self.hop_length, win_length=None, window="hann", pad_mode="reflect")
92
+ spc = np.abs(x_stft).T
93
+ # get mel basis
94
+ fmin = 0 if fmin is None else fmin
95
+ fmax = sampling_rate / 2 if fmax is None else fmax
96
+ mel_basis = librosa.filters.mel(sampling_rate, self.n_fft, self.mel_buckets, fmin, fmax)
97
+ # apply log and return
98
+ return torch.Tensor(np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))).transpose(0, 1)
99
+
100
+ def normalize_audio(self, audio):
101
+ """
102
+ one function to apply them all in an
103
+ order that makes sense.
104
+ """
105
+ audio = self.to_mono(audio)
106
+ audio = self.normalize_loudness(audio)
107
+ audio = torch.Tensor(audio).to(self.device)
108
+ audio = self.resample(audio)
109
+ if self.cut_silence:
110
+ audio = self.cut_silence_from_audio(audio)
111
+ return audio.to("cpu")
112
+
113
+ def visualize_cleaning(self, unclean_audio):
114
+ """
115
+ displays Mel Spectrogram of unclean audio
116
+ and then displays Mel Spectrogram of the
117
+ cleaned version.
118
+ """
119
+ fig, ax = plt.subplots(nrows=2, ncols=1)
120
+ unclean_audio_mono = self.to_mono(unclean_audio)
121
+ unclean_spec = self.audio_to_mel_spec_tensor(unclean_audio_mono, normalize=False).numpy()
122
+ clean_spec = self.audio_to_mel_spec_tensor(unclean_audio_mono, normalize=True).numpy()
123
+ lbd.specshow(unclean_spec, sr=self.sr, cmap='GnBu', y_axis='mel', ax=ax[0], x_axis='time')
124
+ ax[0].set(title='Uncleaned Audio')
125
+ ax[0].label_outer()
126
+ if self.new_sr is not None:
127
+ lbd.specshow(clean_spec, sr=self.new_sr, cmap='GnBu', y_axis='mel', ax=ax[1], x_axis='time')
128
+ else:
129
+ lbd.specshow(clean_spec, sr=self.sr, cmap='GnBu', y_axis='mel', ax=ax[1], x_axis='time')
130
+ ax[1].set(title='Cleaned Audio')
131
+ ax[1].label_outer()
132
+ plt.show()
133
+
134
+ def audio_to_wave_tensor(self, audio, normalize=True):
135
+ if normalize:
136
+ return self.normalize_audio(audio)
137
+ else:
138
+ if isinstance(audio, torch.Tensor):
139
+ return audio
140
+ else:
141
+ return torch.Tensor(audio)
142
+
143
+ def audio_to_mel_spec_tensor(self, audio, normalize=True, explicit_sampling_rate=None):
144
+ """
145
+ explicit_sampling_rate is for when
146
+ normalization has already been applied
147
+ and that included resampling. No way
148
+ to detect the current sr of the incoming
149
+ audio
150
+ """
151
+ if explicit_sampling_rate is None:
152
+ if normalize:
153
+ audio = self.normalize_audio(audio)
154
+ return self.logmelfilterbank(audio=audio, sampling_rate=self.final_sr)
155
+ return self.logmelfilterbank(audio=audio, sampling_rate=self.sr)
156
+ if normalize:
157
+ audio = self.normalize_audio(audio)
158
+ return self.logmelfilterbank(audio=audio, sampling_rate=explicit_sampling_rate)
159
+
160
+
161
+ if __name__ == '__main__':
162
+ import soundfile
163
+
164
+ wav, sr = soundfile.read("../audios/test.wav")
165
+ ap = AudioPreprocessor(input_sr=sr, output_sr=16000)
166
+ ap.visualize_cleaning(wav)
Preprocessing/ProsodicConditionExtractor.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import soundfile as sf
2
+ import torch
3
+ import torch.multiprocessing
4
+ import torch.multiprocessing
5
+ from numpy import trim_zeros
6
+ from speechbrain.pretrained import EncoderClassifier
7
+
8
+ from Preprocessing.AudioPreprocessor import AudioPreprocessor
9
+
10
+
11
+ class ProsodicConditionExtractor:
12
+
13
+ def __init__(self, sr, device=torch.device("cpu")):
14
+ self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False)
15
+ # https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb
16
+ self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
17
+ run_opts={"device": str(device)},
18
+ savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
19
+ # https://huggingface.co/speechbrain/spkrec-xvect-voxceleb
20
+ self.speaker_embedding_func_xvector = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb",
21
+ run_opts={"device": str(device)},
22
+ savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_xvector")
23
+
24
+ def extract_condition_from_reference_wave(self, wave, already_normalized=False):
25
+ if already_normalized:
26
+ norm_wave = wave
27
+ else:
28
+ norm_wave = self.ap.audio_to_wave_tensor(normalize=True, audio=wave)
29
+ norm_wave = torch.tensor(trim_zeros(norm_wave.numpy()))
30
+ spk_emb_ecapa = self.speaker_embedding_func_ecapa.encode_batch(wavs=norm_wave.unsqueeze(0)).squeeze()
31
+ spk_emb_xvector = self.speaker_embedding_func_xvector.encode_batch(wavs=norm_wave.unsqueeze(0)).squeeze()
32
+ combined_utt_condition = torch.cat([spk_emb_ecapa.cpu(),
33
+ spk_emb_xvector.cpu()], dim=0)
34
+ return combined_utt_condition
35
+
36
+
37
+ if __name__ == '__main__':
38
+ wave, sr = sf.read("../audios/1.wav")
39
+ ext = ProsodicConditionExtractor(sr=sr)
40
+ print(ext.extract_condition_from_reference_wave(wave=wave).shape)
Preprocessing/__init__.py ADDED
File without changes
Preprocessing/papercup_features.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Derived from an open-source resource provided by Papercup Technologies Limited
2
+ # Resource-Author: Marlene Staib
3
+ # Modified by Florian Lux, 2021
4
+
5
+ def generate_feature_lookup():
6
+ return {
7
+ '~': {'symbol_type': 'silence'},
8
+ '#': {'symbol_type': 'end of sentence'},
9
+ '?': {'symbol_type': 'questionmark'},
10
+ '!': {'symbol_type': 'exclamationmark'},
11
+ '.': {'symbol_type': 'fullstop'},
12
+ 'ɜ': {
13
+ 'symbol_type' : 'phoneme',
14
+ 'vowel_consonant' : 'vowel',
15
+ 'VUV' : 'voiced',
16
+ 'vowel_frontness' : 'central',
17
+ 'vowel_openness' : 'open-mid',
18
+ 'vowel_roundedness': 'unrounded',
19
+ },
20
+ 'ɫ': {
21
+ 'symbol_type' : 'phoneme',
22
+ 'vowel_consonant' : 'consonant',
23
+ 'VUV' : 'voiced',
24
+ 'consonant_place' : 'alveolar',
25
+ 'consonant_manner': 'lateral-approximant',
26
+ },
27
+ 'ə': {
28
+ 'symbol_type' : 'phoneme',
29
+ 'vowel_consonant' : 'vowel',
30
+ 'VUV' : 'voiced',
31
+ 'vowel_frontness' : 'central',
32
+ 'vowel_openness' : 'mid',
33
+ 'vowel_roundedness': 'unrounded',
34
+ },
35
+ 'ɚ': {
36
+ 'symbol_type' : 'phoneme',
37
+ 'vowel_consonant' : 'vowel',
38
+ 'VUV' : 'voiced',
39
+ 'vowel_frontness' : 'central',
40
+ 'vowel_openness' : 'mid',
41
+ 'vowel_roundedness': 'unrounded',
42
+ },
43
+ 'a': {
44
+ 'symbol_type' : 'phoneme',
45
+ 'vowel_consonant' : 'vowel',
46
+ 'VUV' : 'voiced',
47
+ 'vowel_frontness' : 'front',
48
+ 'vowel_openness' : 'open',
49
+ 'vowel_roundedness': 'unrounded',
50
+ },
51
+ 'ð': {
52
+ 'symbol_type' : 'phoneme',
53
+ 'vowel_consonant' : 'consonant',
54
+ 'VUV' : 'voiced',
55
+ 'consonant_place' : 'dental',
56
+ 'consonant_manner': 'fricative'
57
+ },
58
+ 'ɛ': {
59
+ 'symbol_type' : 'phoneme',
60
+ 'vowel_consonant' : 'vowel',
61
+ 'VUV' : 'voiced',
62
+ 'vowel_frontness' : 'front',
63
+ 'vowel_openness' : 'open-mid',
64
+ 'vowel_roundedness': 'unrounded',
65
+ },
66
+ 'ɪ': {
67
+ 'symbol_type' : 'phoneme',
68
+ 'vowel_consonant' : 'vowel',
69
+ 'VUV' : 'voiced',
70
+ 'vowel_frontness' : 'front_central',
71
+ 'vowel_openness' : 'close_close-mid',
72
+ 'vowel_roundedness': 'unrounded',
73
+ },
74
+ 'ᵻ': {
75
+ 'symbol_type' : 'phoneme',
76
+ 'vowel_consonant' : 'vowel',
77
+ 'VUV' : 'voiced',
78
+ 'vowel_frontness' : 'central',
79
+ 'vowel_openness' : 'close',
80
+ 'vowel_roundedness': 'unrounded',
81
+ },
82
+ 'ŋ': {
83
+ 'symbol_type' : 'phoneme',
84
+ 'vowel_consonant' : 'consonant',
85
+ 'VUV' : 'voiced',
86
+ 'consonant_place' : 'velar',
87
+ 'consonant_manner': 'nasal'
88
+ },
89
+ 'ɔ': {
90
+ 'symbol_type' : 'phoneme',
91
+ 'vowel_consonant' : 'vowel',
92
+ 'VUV' : 'voiced',
93
+ 'vowel_frontness' : 'back',
94
+ 'vowel_openness' : 'open-mid',
95
+ 'vowel_roundedness': 'rounded',
96
+ },
97
+ 'ɒ': {
98
+ 'symbol_type' : 'phoneme',
99
+ 'vowel_consonant' : 'vowel',
100
+ 'VUV' : 'voiced',
101
+ 'vowel_frontness' : 'back',
102
+ 'vowel_openness' : 'open',
103
+ 'vowel_roundedness': 'rounded',
104
+ },
105
+ 'ɾ': {
106
+ 'symbol_type' : 'phoneme',
107
+ 'vowel_consonant' : 'consonant',
108
+ 'VUV' : 'voiced',
109
+ 'consonant_place' : 'alveolar',
110
+ 'consonant_manner': 'tap'
111
+ },
112
+ 'ʃ': {
113
+ 'symbol_type' : 'phoneme',
114
+ 'vowel_consonant' : 'consonant',
115
+ 'VUV' : 'unvoiced',
116
+ 'consonant_place' : 'postalveolar',
117
+ 'consonant_manner': 'fricative'
118
+ },
119
+ 'θ': {
120
+ 'symbol_type' : 'phoneme',
121
+ 'vowel_consonant' : 'consonant',
122
+ 'VUV' : 'unvoiced',
123
+ 'consonant_place' : 'dental',
124
+ 'consonant_manner': 'fricative'
125
+ },
126
+ 'ʊ': {
127
+ 'symbol_type' : 'phoneme',
128
+ 'vowel_consonant' : 'vowel',
129
+ 'VUV' : 'voiced',
130
+ 'vowel_frontness' : 'central_back',
131
+ 'vowel_openness' : 'close_close-mid',
132
+ 'vowel_roundedness': 'unrounded'
133
+ },
134
+ 'ʌ': {
135
+ 'symbol_type' : 'phoneme',
136
+ 'vowel_consonant' : 'vowel',
137
+ 'VUV' : 'voiced',
138
+ 'vowel_frontness' : 'back',
139
+ 'vowel_openness' : 'open-mid',
140
+ 'vowel_roundedness': 'unrounded'
141
+ },
142
+ 'ʒ': {
143
+ 'symbol_type' : 'phoneme',
144
+ 'vowel_consonant' : 'consonant',
145
+ 'VUV' : 'voiced',
146
+ 'consonant_place' : 'postalveolar',
147
+ 'consonant_manner': 'fricative'
148
+ },
149
+ 'æ': {
150
+ 'symbol_type' : 'phoneme',
151
+ 'vowel_consonant' : 'vowel',
152
+ 'VUV' : 'voiced',
153
+ 'vowel_frontness' : 'front',
154
+ 'vowel_openness' : 'open-mid_open',
155
+ 'vowel_roundedness': 'unrounded'
156
+ },
157
+ 'b': {
158
+ 'symbol_type' : 'phoneme',
159
+ 'vowel_consonant' : 'consonant',
160
+ 'VUV' : 'voiced',
161
+ 'consonant_place' : 'bilabial',
162
+ 'consonant_manner': 'stop'
163
+ },
164
+ 'ʔ': {
165
+ 'symbol_type' : 'phoneme',
166
+ 'vowel_consonant' : 'consonant',
167
+ 'VUV' : 'unvoiced',
168
+ 'consonant_place' : 'glottal',
169
+ 'consonant_manner': 'stop'
170
+ },
171
+ 'd': {
172
+ 'symbol_type' : 'phoneme',
173
+ 'vowel_consonant' : 'consonant',
174
+ 'VUV' : 'voiced',
175
+ 'consonant_place' : 'alveolar',
176
+ 'consonant_manner': 'stop'
177
+ },
178
+ 'e': {
179
+ 'symbol_type' : 'phoneme',
180
+ 'vowel_consonant' : 'vowel',
181
+ 'VUV' : 'voiced',
182
+ 'vowel_frontness' : 'front',
183
+ 'vowel_openness' : 'close-mid',
184
+ 'vowel_roundedness': 'unrounded'
185
+ },
186
+ 'f': {
187
+ 'symbol_type' : 'phoneme',
188
+ 'vowel_consonant' : 'consonant',
189
+ 'VUV' : 'unvoiced',
190
+ 'consonant_place' : 'labiodental',
191
+ 'consonant_manner': 'fricative'
192
+ },
193
+ 'g': {
194
+ 'symbol_type' : 'phoneme',
195
+ 'vowel_consonant' : 'consonant',
196
+ 'VUV' : 'voiced',
197
+ 'consonant_place' : 'velar',
198
+ 'consonant_manner': 'stop'
199
+ },
200
+ 'h': {
201
+ 'symbol_type' : 'phoneme',
202
+ 'vowel_consonant' : 'consonant',
203
+ 'VUV' : 'unvoiced',
204
+ 'consonant_place' : 'glottal',
205
+ 'consonant_manner': 'fricative'
206
+ },
207
+ 'i': {
208
+ 'symbol_type' : 'phoneme',
209
+ 'vowel_consonant' : 'vowel',
210
+ 'VUV' : 'voiced',
211
+ 'vowel_frontness' : 'front',
212
+ 'vowel_openness' : 'close',
213
+ 'vowel_roundedness': 'unrounded'
214
+ },
215
+ 'j': {
216
+ 'symbol_type' : 'phoneme',
217
+ 'vowel_consonant' : 'consonant',
218
+ 'VUV' : 'voiced',
219
+ 'consonant_place' : 'palatal',
220
+ 'consonant_manner': 'approximant'
221
+ },
222
+ 'k': {
223
+ 'symbol_type' : 'phoneme',
224
+ 'vowel_consonant' : 'consonant',
225
+ 'VUV' : 'unvoiced',
226
+ 'consonant_place' : 'velar',
227
+ 'consonant_manner': 'stop'
228
+ },
229
+ 'l': {
230
+ 'symbol_type' : 'phoneme',
231
+ 'vowel_consonant' : 'consonant',
232
+ 'VUV' : 'voiced',
233
+ 'consonant_place' : 'alveolar',
234
+ 'consonant_manner': 'lateral-approximant'
235
+ },
236
+ 'm': {
237
+ 'symbol_type' : 'phoneme',
238
+ 'vowel_consonant' : 'consonant',
239
+ 'VUV' : 'voiced',
240
+ 'consonant_place' : 'bilabial',
241
+ 'consonant_manner': 'nasal'
242
+ },
243
+ 'n': {
244
+ 'symbol_type' : 'phoneme',
245
+ 'vowel_consonant' : 'consonant',
246
+ 'VUV' : 'voiced',
247
+ 'consonant_place' : 'alveolar',
248
+ 'consonant_manner': 'nasal'
249
+ },
250
+ 'ɳ': {
251
+ 'symbol_type' : 'phoneme',
252
+ 'vowel_consonant' : 'consonant',
253
+ 'VUV' : 'voiced',
254
+ 'consonant_place' : 'palatal',
255
+ 'consonant_manner': 'nasal'
256
+ },
257
+ 'o': {
258
+ 'symbol_type' : 'phoneme',
259
+ 'vowel_consonant' : 'vowel',
260
+ 'VUV' : 'voiced',
261
+ 'vowel_frontness' : 'back',
262
+ 'vowel_openness' : 'close-mid',
263
+ 'vowel_roundedness': 'rounded'
264
+ },
265
+ 'p': {
266
+ 'symbol_type' : 'phoneme',
267
+ 'vowel_consonant' : 'consonant',
268
+ 'VUV' : 'unvoiced',
269
+ 'consonant_place' : 'bilabial',
270
+ 'consonant_manner': 'stop'
271
+ },
272
+ 'ɡ': {
273
+ 'symbol_type' : 'phoneme',
274
+ 'vowel_consonant' : 'consonant',
275
+ 'VUV' : 'voiced',
276
+ 'consonant_place' : 'velar',
277
+ 'consonant_manner': 'stop'
278
+ },
279
+ 'ɹ': {
280
+ 'symbol_type' : 'phoneme',
281
+ 'vowel_consonant' : 'consonant',
282
+ 'VUV' : 'voiced',
283
+ 'consonant_place' : 'alveolar',
284
+ 'consonant_manner': 'approximant'
285
+ },
286
+ 'r': {
287
+ 'symbol_type' : 'phoneme',
288
+ 'vowel_consonant' : 'consonant',
289
+ 'VUV' : 'voiced',
290
+ 'consonant_place' : 'alveolar',
291
+ 'consonant_manner': 'trill'
292
+ },
293
+ 's': {
294
+ 'symbol_type' : 'phoneme',
295
+ 'vowel_consonant' : 'consonant',
296
+ 'VUV' : 'unvoiced',
297
+ 'consonant_place' : 'alveolar',
298
+ 'consonant_manner': 'fricative'
299
+ },
300
+ 't': {
301
+ 'symbol_type' : 'phoneme',
302
+ 'vowel_consonant' : 'consonant',
303
+ 'VUV' : 'unvoiced',
304
+ 'consonant_place' : 'alveolar',
305
+ 'consonant_manner': 'stop'
306
+ },
307
+ 'u': {
308
+ 'symbol_type' : 'phoneme',
309
+ 'vowel_consonant' : 'vowel',
310
+ 'VUV' : 'voiced',
311
+ 'vowel_frontness' : 'back',
312
+ 'vowel_openness' : 'close',
313
+ 'vowel_roundedness': 'rounded',
314
+ },
315
+ 'v': {
316
+ 'symbol_type' : 'phoneme',
317
+ 'vowel_consonant' : 'consonant',
318
+ 'VUV' : 'voiced',
319
+ 'consonant_place' : 'labiodental',
320
+ 'consonant_manner': 'fricative'
321
+ },
322
+ 'w': {
323
+ 'symbol_type' : 'phoneme',
324
+ 'vowel_consonant' : 'consonant',
325
+ 'VUV' : 'voiced',
326
+ 'consonant_place' : 'labial-velar',
327
+ 'consonant_manner': 'approximant'
328
+ },
329
+ 'x': {
330
+ 'symbol_type' : 'phoneme',
331
+ 'vowel_consonant' : 'consonant',
332
+ 'VUV' : 'unvoiced',
333
+ 'consonant_place' : 'velar',
334
+ 'consonant_manner': 'fricative'
335
+ },
336
+ 'z': {
337
+ 'symbol_type' : 'phoneme',
338
+ 'vowel_consonant' : 'consonant',
339
+ 'VUV' : 'voiced',
340
+ 'consonant_place' : 'alveolar',
341
+ 'consonant_manner': 'fricative'
342
+ },
343
+ 'ʀ': {
344
+ 'symbol_type' : 'phoneme',
345
+ 'vowel_consonant' : 'consonant',
346
+ 'VUV' : 'voiced',
347
+ 'consonant_place' : 'uvular',
348
+ 'consonant_manner': 'trill'
349
+ },
350
+ 'ø': {
351
+ 'symbol_type' : 'phoneme',
352
+ 'vowel_consonant' : 'vowel',
353
+ 'VUV' : 'voiced',
354
+ 'vowel_frontness' : 'front',
355
+ 'vowel_openness' : 'close-mid',
356
+ 'vowel_roundedness': 'rounded'
357
+ },
358
+ 'ç': {
359
+ 'symbol_type' : 'phoneme',
360
+ 'vowel_consonant' : 'consonant',
361
+ 'VUV' : 'unvoiced',
362
+ 'consonant_place' : 'palatal',
363
+ 'consonant_manner': 'fricative'
364
+ },
365
+ 'ɐ': {
366
+ 'symbol_type' : 'phoneme',
367
+ 'vowel_consonant' : 'vowel',
368
+ 'VUV' : 'voiced',
369
+ 'vowel_frontness' : 'central',
370
+ 'vowel_openness' : 'open',
371
+ 'vowel_roundedness': 'unrounded'
372
+ },
373
+ 'œ': {
374
+ 'symbol_type' : 'phoneme',
375
+ 'vowel_consonant' : 'vowel',
376
+ 'VUV' : 'voiced',
377
+ 'vowel_frontness' : 'front',
378
+ 'vowel_openness' : 'open-mid',
379
+ 'vowel_roundedness': 'rounded'
380
+ },
381
+ 'y': {
382
+ 'symbol_type' : 'phoneme',
383
+ 'vowel_consonant' : 'vowel',
384
+ 'VUV' : 'voiced',
385
+ 'vowel_frontness' : 'front',
386
+ 'vowel_openness' : 'close',
387
+ 'vowel_roundedness': 'rounded'
388
+ },
389
+ 'ʏ': {
390
+ 'symbol_type' : 'phoneme',
391
+ 'vowel_consonant' : 'vowel',
392
+ 'VUV' : 'voiced',
393
+ 'vowel_frontness' : 'front_central',
394
+ 'vowel_openness' : 'close_close-mid',
395
+ 'vowel_roundedness': 'rounded'
396
+ },
397
+ 'ɑ': {
398
+ 'symbol_type' : 'phoneme',
399
+ 'vowel_consonant' : 'vowel',
400
+ 'VUV' : 'voiced',
401
+ 'vowel_frontness' : 'back',
402
+ 'vowel_openness' : 'open',
403
+ 'vowel_roundedness': 'unrounded'
404
+ },
405
+ 'c': {
406
+ 'symbol_type' : 'phoneme',
407
+ 'vowel_consonant' : 'consonant',
408
+ 'VUV' : 'unvoiced',
409
+ 'consonant_place' : 'palatal',
410
+ 'consonant_manner': 'stop'
411
+ },
412
+ 'ɲ': {
413
+ 'symbol_type' : 'phoneme',
414
+ 'vowel_consonant' : 'consonant',
415
+ 'VUV' : 'voiced',
416
+ 'consonant_place' : 'palatal',
417
+ 'consonant_manner': 'nasal'
418
+ },
419
+ 'ɣ': {
420
+ 'symbol_type' : 'phoneme',
421
+ 'vowel_consonant' : 'consonant',
422
+ 'VUV' : 'voiced',
423
+ 'consonant_place' : 'velar',
424
+ 'consonant_manner': 'fricative'
425
+ },
426
+ 'ʎ': {
427
+ 'symbol_type' : 'phoneme',
428
+ 'vowel_consonant' : 'consonant',
429
+ 'VUV' : 'voiced',
430
+ 'consonant_place' : 'palatal',
431
+ 'consonant_manner': 'lateral-approximant'
432
+ },
433
+ 'β': {
434
+ 'symbol_type' : 'phoneme',
435
+ 'vowel_consonant' : 'consonant',
436
+ 'VUV' : 'voiced',
437
+ 'consonant_place' : 'bilabial',
438
+ 'consonant_manner': 'fricative'
439
+ },
440
+ 'ʝ': {
441
+ 'symbol_type' : 'phoneme',
442
+ 'vowel_consonant' : 'consonant',
443
+ 'VUV' : 'voiced',
444
+ 'consonant_place' : 'palatal',
445
+ 'consonant_manner': 'fricative'
446
+ },
447
+ 'ɟ': {
448
+ 'symbol_type' : 'phoneme',
449
+ 'vowel_consonant' : 'consonant',
450
+ 'VUV' : 'voiced',
451
+ 'consonant_place' : 'palatal',
452
+ 'consonant_manner': 'stop'
453
+ },
454
+ 'q': {
455
+ 'symbol_type' : 'phoneme',
456
+ 'vowel_consonant' : 'consonant',
457
+ 'VUV' : 'unvoiced',
458
+ 'consonant_place' : 'uvular',
459
+ 'consonant_manner': 'stop'
460
+ },
461
+ 'ɕ': {
462
+ 'symbol_type' : 'phoneme',
463
+ 'vowel_consonant' : 'consonant',
464
+ 'VUV' : 'unvoiced',
465
+ 'consonant_place' : 'alveolopalatal',
466
+ 'consonant_manner': 'fricative'
467
+ },
468
+ 'ʲ': {
469
+ 'symbol_type' : 'phoneme',
470
+ 'vowel_consonant' : 'consonant',
471
+ 'VUV' : 'voiced',
472
+ 'consonant_place' : 'palatal',
473
+ 'consonant_manner': 'approximant'
474
+ },
475
+ 'ɭ': {
476
+ 'symbol_type' : 'phoneme',
477
+ 'vowel_consonant' : 'consonant',
478
+ 'VUV' : 'voiced',
479
+ 'consonant_place' : 'palatal', # should be retroflex, but palatal should be close enough
480
+ 'consonant_manner': 'lateral-approximant'
481
+ },
482
+ 'ɵ': {
483
+ 'symbol_type' : 'phoneme',
484
+ 'vowel_consonant' : 'vowel',
485
+ 'VUV' : 'voiced',
486
+ 'vowel_frontness' : 'central',
487
+ 'vowel_openness' : 'open-mid',
488
+ 'vowel_roundedness': 'rounded'
489
+ },
490
+ 'ʑ': {
491
+ 'symbol_type' : 'phoneme',
492
+ 'vowel_consonant' : 'consonant',
493
+ 'VUV' : 'voiced',
494
+ 'consonant_place' : 'alveolopalatal',
495
+ 'consonant_manner': 'fricative'
496
+ },
497
+ 'ʋ': {
498
+ 'symbol_type' : 'phoneme',
499
+ 'vowel_consonant' : 'consonant',
500
+ 'VUV' : 'voiced',
501
+ 'consonant_place' : 'labiodental',
502
+ 'consonant_manner': 'approximant'
503
+ },
504
+ 'ʁ': {
505
+ 'symbol_type' : 'phoneme',
506
+ 'vowel_consonant' : 'consonant',
507
+ 'VUV' : 'voiced',
508
+ 'consonant_place' : 'uvular',
509
+ 'consonant_manner': 'fricative'
510
+ },
511
+ 'ɨ': {
512
+ 'symbol_type' : 'phoneme',
513
+ 'vowel_consonant' : 'vowel',
514
+ 'VUV' : 'voiced',
515
+ 'vowel_frontness' : 'central',
516
+ 'vowel_openness' : 'close',
517
+ 'vowel_roundedness': 'unrounded'
518
+ },
519
+ 'ʂ': {
520
+ 'symbol_type' : 'phoneme',
521
+ 'vowel_consonant' : 'consonant',
522
+ 'VUV' : 'unvoiced',
523
+ 'consonant_place' : 'palatal', # should be retroflex, but palatal should be close enough
524
+ 'consonant_manner': 'fricative'
525
+ },
526
+ 'ɬ': {
527
+ 'symbol_type' : 'phoneme',
528
+ 'vowel_consonant' : 'consonant',
529
+ 'VUV' : 'unvoiced',
530
+ 'consonant_place' : 'alveolar', # should be noted it's also lateral, but should be close enough
531
+ 'consonant_manner': 'fricative'
532
+ },
533
+ } # REMEMBER to also add the phonemes added here to the ID lookup table in the TextFrontend as the new highest ID
534
+
535
+
536
+ def generate_feature_table():
537
+ ipa_to_phonemefeats = generate_feature_lookup()
538
+
539
+ feat_types = set()
540
+ for ipa in ipa_to_phonemefeats:
541
+ if len(ipa) == 1:
542
+ [feat_types.add(feat) for feat in ipa_to_phonemefeats[ipa].keys()]
543
+
544
+ feat_to_val_set = dict()
545
+ for feat in feat_types:
546
+ feat_to_val_set[feat] = set()
547
+ for ipa in ipa_to_phonemefeats:
548
+ if len(ipa) == 1:
549
+ for feat in ipa_to_phonemefeats[ipa]:
550
+ feat_to_val_set[feat].add(ipa_to_phonemefeats[ipa][feat])
551
+
552
+ # print(feat_to_val_set)
553
+
554
+ value_list = set()
555
+ for val_set in [feat_to_val_set[feat] for feat in feat_to_val_set]:
556
+ for value in val_set:
557
+ value_list.add(value)
558
+ # print("{")
559
+ # for index, value in enumerate(list(value_list)):
560
+ # print('"{}":{},'.format(value,index))
561
+ # print("}")
562
+
563
+ value_to_index = {
564
+ "dental" : 0,
565
+ "postalveolar" : 1,
566
+ "mid" : 2,
567
+ "close-mid" : 3,
568
+ "vowel" : 4,
569
+ "silence" : 5,
570
+ "consonant" : 6,
571
+ "close" : 7,
572
+ "velar" : 8,
573
+ "stop" : 9,
574
+ "palatal" : 10,
575
+ "nasal" : 11,
576
+ "glottal" : 12,
577
+ "central" : 13,
578
+ "back" : 14,
579
+ "approximant" : 15,
580
+ "uvular" : 16,
581
+ "open-mid" : 17,
582
+ "front_central" : 18,
583
+ "front" : 19,
584
+ "end of sentence" : 20,
585
+ "labiodental" : 21,
586
+ "close_close-mid" : 22,
587
+ "labial-velar" : 23,
588
+ "unvoiced" : 24,
589
+ "central_back" : 25,
590
+ "trill" : 26,
591
+ "rounded" : 27,
592
+ "open-mid_open" : 28,
593
+ "tap" : 29,
594
+ "alveolar" : 30,
595
+ "bilabial" : 31,
596
+ "phoneme" : 32,
597
+ "open" : 33,
598
+ "fricative" : 34,
599
+ "unrounded" : 35,
600
+ "lateral-approximant": 36,
601
+ "voiced" : 37,
602
+ "questionmark" : 38,
603
+ "exclamationmark" : 39,
604
+ "fullstop" : 40,
605
+ "alveolopalatal" : 41
606
+ }
607
+
608
+ phone_to_vector = dict()
609
+ for ipa in ipa_to_phonemefeats:
610
+ if len(ipa) == 1:
611
+ phone_to_vector[ipa] = [0] * sum([len(values) for values in [feat_to_val_set[feat] for feat in feat_to_val_set]])
612
+ for feat in ipa_to_phonemefeats[ipa]:
613
+ if ipa_to_phonemefeats[ipa][feat] in value_to_index:
614
+ phone_to_vector[ipa][value_to_index[ipa_to_phonemefeats[ipa][feat]]] = 1
615
+
616
+ for feat in feat_to_val_set:
617
+ for value in feat_to_val_set[feat]:
618
+ if value not in value_to_index:
619
+ print(f"Unknown feature value in featureset! {value}")
620
+
621
+ # print(f"{sum([len(values) for values in [feat_to_val_set[feat] for feat in feat_to_val_set]])} should be 42")
622
+
623
+ return phone_to_vector
624
+
625
+
626
+ def generate_phone_to_id_lookup():
627
+ ipa_to_phonemefeats = generate_feature_lookup()
628
+ count = 0
629
+ phone_to_id = dict()
630
+ for key in sorted(list(ipa_to_phonemefeats)): # careful: non-deterministic
631
+ phone_to_id[key] = count
632
+ count += 1
633
+ return phone_to_id
634
+
635
+
636
+ if __name__ == '__main__':
637
+ print(generate_phone_to_id_lookup())
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: SpeechCloning
3
- emoji: 📈
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: false
1
  ---
2
  title: SpeechCloning
3
+ emoji: 🦜
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: false
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/Aligner.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ taken and adapted from https://github.com/as-ideas/DeepForcedAligner
3
+ """
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+ import torch.multiprocessing
8
+ import torch.nn as nn
9
+ from scipy.sparse import coo_matrix
10
+ from scipy.sparse.csgraph import dijkstra
11
+ from torch.nn import CTCLoss
12
+ from torch.nn.utils.rnn import pack_padded_sequence
13
+ from torch.nn.utils.rnn import pad_packed_sequence
14
+
15
+ from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
16
+
17
+
18
+ class BatchNormConv(nn.Module):
19
+
20
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
21
+ super().__init__()
22
+ self.conv = nn.Conv1d(
23
+ in_channels, out_channels, kernel_size,
24
+ stride=1, padding=kernel_size // 2, bias=False)
25
+ self.bnorm = nn.BatchNorm1d(out_channels)
26
+ self.relu = nn.ReLU()
27
+
28
+ def forward(self, x):
29
+ x = x.transpose(1, 2)
30
+ x = self.conv(x)
31
+ x = self.relu(x)
32
+ x = self.bnorm(x)
33
+ x = x.transpose(1, 2)
34
+ return x
35
+
36
+
37
+ class Aligner(torch.nn.Module):
38
+
39
+ def __init__(self,
40
+ n_mels=80,
41
+ num_symbols=145,
42
+ lstm_dim=512,
43
+ conv_dim=512):
44
+ super().__init__()
45
+ self.convs = nn.ModuleList([
46
+ BatchNormConv(n_mels, conv_dim, 3),
47
+ nn.Dropout(p=0.5),
48
+ BatchNormConv(conv_dim, conv_dim, 3),
49
+ nn.Dropout(p=0.5),
50
+ BatchNormConv(conv_dim, conv_dim, 3),
51
+ nn.Dropout(p=0.5),
52
+ BatchNormConv(conv_dim, conv_dim, 3),
53
+ nn.Dropout(p=0.5),
54
+ BatchNormConv(conv_dim, conv_dim, 3),
55
+ nn.Dropout(p=0.5),
56
+ ])
57
+ self.rnn = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
58
+ self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
59
+ self.tf = ArticulatoryCombinedTextFrontend(language="en")
60
+ self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
61
+ self.vector_to_id = dict()
62
+ for phone in self.tf.phone_to_vector:
63
+ self.vector_to_id[tuple(self.tf.phone_to_vector[phone])] = self.tf.phone_to_id[phone]
64
+
65
+ def forward(self, x, lens=None):
66
+ for conv in self.convs:
67
+ x = conv(x)
68
+
69
+ if lens is not None:
70
+ x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
71
+ x, _ = self.rnn(x)
72
+ if lens is not None:
73
+ x, _ = pad_packed_sequence(x, batch_first=True)
74
+
75
+ x = self.proj(x)
76
+
77
+ return x
78
+
79
+ @torch.no_grad()
80
+ def label_speech(self, speech):
81
+ # theoretically possible, but doesn't work well at all. Would probably require a beamsearch
82
+ probabilities_of_phones_over_frames = self(speech.unsqueeze(0)).squeeze()[:, :73]
83
+ smoothed_phone_probs_over_frames = list()
84
+ for index, _ in enumerate(probabilities_of_phones_over_frames):
85
+ access_safe_prev_index = max(0, index - 1)
86
+ access_safe_next_index = min(index + 1, len(probabilities_of_phones_over_frames) - 1)
87
+ smoothed_probs = (probabilities_of_phones_over_frames[access_safe_prev_index] +
88
+ probabilities_of_phones_over_frames[access_safe_next_index] +
89
+ probabilities_of_phones_over_frames[index]) / 3
90
+ smoothed_phone_probs_over_frames.append(smoothed_probs.unsqueeze(0))
91
+ print(torch.cat(smoothed_phone_probs_over_frames))
92
+ _, phone_ids_over_frames = torch.max(torch.cat(smoothed_phone_probs_over_frames), dim=1)
93
+ phone_ids = torch.unique_consecutive(phone_ids_over_frames)
94
+ phones = list()
95
+ for id_of_phone in phone_ids:
96
+ phones.append(self.tf.id_to_phone[int(id_of_phone)])
97
+ return "".join(phones)
98
+
99
+ @torch.inference_mode()
100
+ def inference(self, mel, tokens, save_img_for_debug=None, train=False, pathfinding="MAS", return_ctc=False):
101
+ if not train:
102
+ tokens_indexed = list() # first we need to convert the articulatory vectors to IDs, so we can apply dijkstra or viterbi
103
+ for vector in tokens:
104
+ tokens_indexed.append(self.vector_to_id[tuple(vector.cpu().detach().numpy().tolist())])
105
+ tokens = np.asarray(tokens_indexed)
106
+ else:
107
+ tokens = tokens.cpu().detach().numpy()
108
+
109
+ pred = self(mel.unsqueeze(0))
110
+ if return_ctc:
111
+ ctc_loss = self.ctc_loss(pred.transpose(0, 1).log_softmax(2), torch.LongTensor(tokens), torch.LongTensor([len(pred[0])]),
112
+ torch.LongTensor([len(tokens)])).item()
113
+ pred = pred.squeeze().cpu().detach().numpy()
114
+ pred_max = pred[:, tokens]
115
+ path_probs = 1. - pred_max
116
+ adj_matrix = to_adj_matrix(path_probs)
117
+
118
+ if pathfinding == "MAS":
119
+
120
+ alignment_matrix = binarize_alignment(pred_max)
121
+
122
+ if save_img_for_debug is not None:
123
+ phones = list()
124
+ for index in tokens:
125
+ for phone in self.tf.phone_to_id:
126
+ if self.tf.phone_to_id[phone] == index:
127
+ phones.append(phone)
128
+ fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 4))
129
+
130
+ ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
131
+
132
+ ax.set_ylabel("Mel-Frames")
133
+
134
+ ax.set_xticks(range(len(pred_max[0])))
135
+ ax.set_xticklabels(labels=phones)
136
+
137
+ ax.set_title("MAS Path")
138
+
139
+ plt.tight_layout()
140
+ fig.savefig(save_img_for_debug)
141
+ fig.clf()
142
+ plt.close()
143
+
144
+ if return_ctc:
145
+ return alignment_matrix, ctc_loss
146
+ return alignment_matrix
147
+
148
+ elif pathfinding == "dijkstra":
149
+
150
+ dist_matrix, predecessors, *_ = dijkstra(csgraph=adj_matrix,
151
+ directed=True,
152
+ indices=0,
153
+ return_predecessors=True)
154
+ path = []
155
+ pr_index = predecessors[-1]
156
+ while pr_index != 0:
157
+ path.append(pr_index)
158
+ pr_index = predecessors[pr_index]
159
+ path.reverse()
160
+
161
+ # append first and last node
162
+ path = [0] + path + [dist_matrix.size - 1]
163
+ cols = path_probs.shape[1]
164
+ mel_text = {}
165
+
166
+ # collect indices (mel, text) along the path
167
+ for node_index in path:
168
+ i, j = from_node_index(node_index, cols)
169
+ mel_text[i] = j
170
+
171
+ path_plot = np.zeros_like(pred_max)
172
+ for i in mel_text:
173
+ path_plot[i][mel_text[i]] = 1.0
174
+
175
+ if save_img_for_debug is not None:
176
+
177
+ phones = list()
178
+ for index in tokens:
179
+ for phone in self.tf.phone_to_id:
180
+ if self.tf.phone_to_id[phone] == index:
181
+ phones.append(phone)
182
+ fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(10, 9))
183
+
184
+ ax[0].imshow(pred_max, interpolation='nearest', aspect='auto', origin="lower")
185
+ ax[1].imshow(path_plot, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
186
+
187
+ ax[0].set_ylabel("Mel-Frames")
188
+ ax[1].set_ylabel("Mel-Frames")
189
+
190
+ ax[0].set_xticks(range(len(pred_max[0])))
191
+ ax[0].set_xticklabels(labels=phones)
192
+
193
+ ax[1].set_xticks(range(len(pred_max[0])))
194
+ ax[1].set_xticklabels(labels=phones)
195
+
196
+ ax[0].set_title("Path Probabilities")
197
+ ax[1].set_title("Dijkstra Path")
198
+
199
+ plt.tight_layout()
200
+ fig.savefig(save_img_for_debug)
201
+ fig.clf()
202
+ plt.close()
203
+
204
+ if return_ctc:
205
+ return path_plot, ctc_loss
206
+ return path_plot
207
+
208
+
209
+ def binarize_alignment(alignment_prob):
210
+ """
211
+ # Implementation by:
212
+ # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py
213
+ # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py
214
+
215
+ Binarizes alignment with MAS.
216
+ """
217
+ # assumes mel x text
218
+ opt = np.zeros_like(alignment_prob)
219
+ alignment_prob = alignment_prob + (np.abs(alignment_prob).max() + 1.0) # make all numbers positive and add an offset to avoid log of 0 later
220
+ alignment_prob * alignment_prob * (1.0 / alignment_prob.max()) # normalize to (0, 1]
221
+ attn_map = np.log(alignment_prob)
222
+ attn_map[0, 1:] = -np.inf
223
+ log_p = np.zeros_like(attn_map)
224
+ log_p[0, :] = attn_map[0, :]
225
+ prev_ind = np.zeros_like(attn_map, dtype=np.int64)
226
+ for i in range(1, attn_map.shape[0]):
227
+ for j in range(attn_map.shape[1]): # for each text dim
228
+ prev_log = log_p[i - 1, j]
229
+ prev_j = j
230
+ if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
231
+ prev_log = log_p[i - 1, j - 1]
232
+ prev_j = j - 1
233
+ log_p[i, j] = attn_map[i, j] + prev_log
234
+ prev_ind[i, j] = prev_j
235
+ # now backtrack
236
+ curr_text_idx = attn_map.shape[1] - 1
237
+ for i in range(attn_map.shape[0] - 1, -1, -1):
238
+ opt[i, curr_text_idx] = 1
239
+ curr_text_idx = prev_ind[i, curr_text_idx]
240
+ opt[0, curr_text_idx] = 1
241
+ return opt
242
+
243
+
244
+ def to_node_index(i, j, cols):
245
+ return cols * i + j
246
+
247
+
248
+ def from_node_index(node_index, cols):
249
+ return node_index // cols, node_index % cols
250
+
251
+
252
+ def to_adj_matrix(mat):
253
+ rows = mat.shape[0]
254
+ cols = mat.shape[1]
255
+
256
+ row_ind = []
257
+ col_ind = []
258
+ data = []
259
+
260
+ for i in range(rows):
261
+ for j in range(cols):
262
+
263
+ node = to_node_index(i, j, cols)
264
+
265
+ if j < cols - 1:
266
+ right_node = to_node_index(i, j + 1, cols)
267
+ weight_right = mat[i, j + 1]
268
+ row_ind.append(node)
269
+ col_ind.append(right_node)
270
+ data.append(weight_right)
271
+
272
+ if i < rows - 1 and j < cols:
273
+ bottom_node = to_node_index(i + 1, j, cols)
274
+ weight_bottom = mat[i + 1, j]
275
+ row_ind.append(node)
276
+ col_ind.append(bottom_node)
277
+ data.append(weight_bottom)
278
+
279
+ if i < rows - 1 and j < cols - 1:
280
+ bottom_right_node = to_node_index(i + 1, j + 1, cols)
281
+ weight_bottom_right = mat[i + 1, j + 1]
282
+ row_ind.append(node)
283
+ col_ind.append(bottom_right_node)
284
+ data.append(weight_bottom_right)
285
+
286
+ adj_mat = coo_matrix((data, (row_ind, col_ind)), shape=(rows * cols, rows * cols))
287
+ return adj_mat.tocsr()
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/AlignerDataset.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import warnings
4
+
5
+ import soundfile as sf
6
+ import torch
7
+ from numpy import trim_zeros
8
+ from speechbrain.pretrained import EncoderClassifier
9
+ from torch.multiprocessing import Manager
10
+ from torch.multiprocessing import Process
11
+ from torch.multiprocessing import set_start_method
12
+ from torch.utils.data import Dataset
13
+ from tqdm import tqdm
14
+
15
+ from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
16
+ from Preprocessing.AudioPreprocessor import AudioPreprocessor
17
+
18
+
19
+ class AlignerDataset(Dataset):
20
+
21
+ def __init__(self,
22
+ path_to_transcript_dict,
23
+ cache_dir,
24
+ lang,
25
+ loading_processes=30, # careful with the amount of processes if you use silence removal, only as many processes as you have cores
26
+ min_len_in_seconds=1,
27
+ max_len_in_seconds=20,
28
+ cut_silences=False,
29
+ rebuild_cache=False,
30
+ verbose=False,
31
+ device="cpu"):
32
+ os.makedirs(cache_dir, exist_ok=True)
33
+ if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
34
+ if (device == "cuda" or device == torch.device("cuda")) and cut_silences:
35
+ try:
36
+ set_start_method('spawn') # in order to be able to make use of cuda in multiprocessing
37
+ except RuntimeError:
38
+ pass
39
+ elif cut_silences:
40
+ torch.set_num_threads(1)
41
+ if cut_silences:
42
+ torch.hub.load(repo_or_dir='snakers4/silero-vad',
43
+ model='silero_vad',
44
+ force_reload=False,
45
+ onnx=False,
46
+ verbose=False) # download and cache for it to be loaded and used later
47
+ torch.set_grad_enabled(True)
48
+ resource_manager = Manager()
49
+ self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
50
+ key_list = list(self.path_to_transcript_dict.keys())
51
+ with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
52
+ files_used_note.write(str(key_list))
53
+ random.shuffle(key_list)
54
+ # build cache
55
+ print("... building dataset cache ...")
56
+ self.datapoints = resource_manager.list()
57
+ # make processes
58
+ key_splits = list()
59
+ process_list = list()
60
+ for i in range(loading_processes):
61
+ key_splits.append(key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
62
+ for key_split in key_splits:
63
+ process_list.append(
64
+ Process(target=self.cache_builder_process,
65
+ args=(key_split,
66
+ lang,
67
+ min_len_in_seconds,
68
+ max_len_in_seconds,
69
+ cut_silences,
70
+ verbose,
71
+ device),
72
+ daemon=True))
73
+ process_list[-1].start()
74
+ for process in process_list:
75
+ process.join()
76
+ self.datapoints = list(self.datapoints)
77
+ tensored_datapoints = list()
78
+ # we had to turn all of the tensors to numpy arrays to avoid shared memory
79
+ # issues. Now that the multi-processing is over, we can convert them back
80
+ # to tensors to save on conversions in the future.
81
+ print("Converting into convenient format...")
82
+ norm_waves = list()
83
+ for datapoint in tqdm(self.datapoints):
84
+ tensored_datapoints.append([torch.Tensor(datapoint[0]),
85
+ torch.LongTensor(datapoint[1]),
86
+ torch.Tensor(datapoint[2]),
87
+ torch.LongTensor(datapoint[3])])
88
+ norm_waves.append(torch.Tensor(datapoint[-1]))
89
+
90
+ self.datapoints = tensored_datapoints
91
+
92
+ pop_indexes = list()
93
+ for index, el in enumerate(self.datapoints):
94
+ try:
95
+ if len(el[0][0]) != 66:
96
+ pop_indexes.append(index)
97
+ except TypeError:
98
+ pop_indexes.append(index)
99
+ for pop_index in sorted(pop_indexes, reverse=True):
100
+ print(f"There seems to be a problem in the transcriptions. Deleting datapoint {pop_index}.")
101
+ self.datapoints.pop(pop_index)
102
+
103
+ # add speaker embeddings
104
+ self.speaker_embeddings = list()
105
+ speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
106
+ run_opts={"device": str(device)},
107
+ savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
108
+ with torch.no_grad():
109
+ for wave in tqdm(norm_waves):
110
+ self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
111
+
112
+ # save to cache
113
+ torch.save((self.datapoints, norm_waves, self.speaker_embeddings), os.path.join(cache_dir, "aligner_train_cache.pt"))
114
+ else:
115
+ # just load the datapoints from cache
116
+ self.datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
117
+ if len(self.datapoints) == 2:
118
+ # speaker embeddings are still missing, have to add them here
119
+ wave_datapoints = self.datapoints[1]
120
+ self.datapoints = self.datapoints[0]
121
+ self.speaker_embeddings = list()
122
+ speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
123
+ run_opts={"device": str(device)},
124
+ savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
125
+ with torch.no_grad():
126
+ for wave in tqdm(wave_datapoints):
127
+ self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
128
+ torch.save((self.datapoints, wave_datapoints, self.speaker_embeddings), os.path.join(cache_dir, "aligner_train_cache.pt"))
129
+ else:
130
+ self.speaker_embeddings = self.datapoints[2]
131
+ self.datapoints = self.datapoints[0]
132
+
133
+ self.tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=True)
134
+ print(f"Prepared an Aligner dataset with {len(self.datapoints)} datapoints in {cache_dir}.")
135
+
136
+ def cache_builder_process(self,
137
+ path_list,
138
+ lang,
139
+ min_len,
140
+ max_len,
141
+ cut_silences,
142
+ verbose,
143
+ device):
144
+ process_internal_dataset_chunk = list()
145
+ tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=False)
146
+ _, sr = sf.read(path_list[0])
147
+ ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=cut_silences, device=device)
148
+
149
+ for path in tqdm(path_list):
150
+ if self.path_to_transcript_dict[path].strip() == "":
151
+ continue
152
+
153
+ wave, sr = sf.read(path)
154
+ dur_in_seconds = len(wave) / sr
155
+ if not (min_len <= dur_in_seconds <= max_len):
156
+ if verbose:
157
+ print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
158
+ continue
159
+ try:
160
+ with warnings.catch_warnings():
161
+ warnings.simplefilter("ignore") # otherwise we get tons of warnings about an RNN not being in contiguous chunks
162
+ norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave)
163
+ except ValueError:
164
+ continue
165
+ dur_in_seconds = len(norm_wave) / 16000
166
+ if not (min_len <= dur_in_seconds <= max_len):
167
+ if verbose:
168
+ print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
169
+ continue
170
+ norm_wave = torch.tensor(trim_zeros(norm_wave.numpy()))
171
+ # raw audio preprocessing is done
172
+ transcript = self.path_to_transcript_dict[path]
173
+ try:
174
+ cached_text = tf.string_to_tensor(transcript, handle_missing=False).squeeze(0).cpu().numpy()
175
+ except KeyError:
176
+ tf.string_to_tensor(transcript, handle_missing=True).squeeze(0).cpu().numpy()
177
+ continue # we skip sentences with unknown symbols
178
+ try:
179
+ if len(cached_text[0]) != 66:
180
+ print(f"There seems to be a problem with the following transcription: {transcript}")
181
+ continue
182
+ except TypeError:
183
+ print(f"There seems to be a problem with the following transcription: {transcript}")
184
+ continue
185
+ cached_text_len = torch.LongTensor([len(cached_text)]).numpy()
186
+ cached_speech = ap.audio_to_mel_spec_tensor(audio=norm_wave, normalize=False, explicit_sampling_rate=16000).transpose(0, 1).cpu().numpy()
187
+ cached_speech_len = torch.LongTensor([len(cached_speech)]).numpy()
188
+ process_internal_dataset_chunk.append([cached_text,
189
+ cached_text_len,
190
+ cached_speech,
191
+ cached_speech_len,
192
+ norm_wave.cpu().detach().numpy()])
193
+ self.datapoints += process_internal_dataset_chunk
194
+
195
+ def __getitem__(self, index):
196
+ text_vector = self.datapoints[index][0]
197
+ tokens = list()
198
+ for vector in text_vector:
199
+ for phone in self.tf.phone_to_vector:
200
+ if vector.numpy().tolist() == self.tf.phone_to_vector[phone]:
201
+ tokens.append(self.tf.phone_to_id[phone])
202
+ # this is terribly inefficient, but it's good enough for testing for now.
203
+ tokens = torch.LongTensor(tokens)
204
+ return tokens, \
205
+ self.datapoints[index][1], \
206
+ self.datapoints[index][2], \
207
+ self.datapoints[index][3], \
208
+ self.speaker_embeddings[index]
209
+
210
+ def __len__(self):
211
+ return len(self.datapoints)
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/TinyTTS.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.multiprocessing
3
+ from torch.nn.utils.rnn import pack_padded_sequence
4
+ from torch.nn.utils.rnn import pad_packed_sequence
5
+
6
+ from Utility.utils import make_non_pad_mask
7
+
8
+
9
+ class TinyTTS(torch.nn.Module):
10
+
11
+ def __init__(self,
12
+ n_mels=80,
13
+ num_symbols=145,
14
+ speaker_embedding_dim=192,
15
+ lstm_dim=512):
16
+ super().__init__()
17
+ self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, lstm_dim)
18
+ self.rnn1 = torch.nn.LSTM(lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
19
+ self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
20
+ self.out_proj = torch.nn.Linear(2 * lstm_dim, n_mels)
21
+ self.l1_criterion = torch.nn.L1Loss(reduction="none")
22
+ self.l2_criterion = torch.nn.MSELoss(reduction="none")
23
+
24
+ def forward(self, x, lens, ys):
25
+ x = self.in_proj(x)
26
+ x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
27
+ x, _ = self.rnn1(x)
28
+ x, _ = self.rnn2(x)
29
+ x, _ = pad_packed_sequence(x, batch_first=True)
30
+ x = self.out_proj(x)
31
+ out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device)
32
+ out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
33
+ out_weights /= ys.size(0) * ys.size(2)
34
+ l1_loss = self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
35
+ l2_loss = self.l2_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
36
+ return l1_loss + l2_loss
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/__init__.py ADDED
File without changes
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/autoaligner_train_loop.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import torch
5
+ import torch.multiprocessing
6
+ from torch.nn.utils.rnn import pad_sequence
7
+ from torch.optim import RAdam
8
+ from torch.utils.data.dataloader import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
12
+ from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.TinyTTS import TinyTTS
13
+
14
+
15
+ def collate_and_pad(batch):
16
+ # text, text_len, speech, speech_len
17
+ return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
18
+ torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
19
+ pad_sequence([datapoint[2] for datapoint in batch], batch_first=True),
20
+ torch.stack([datapoint[3] for datapoint in batch]).squeeze(1),
21
+ torch.stack([datapoint[4] for datapoint in batch]).squeeze())
22
+
23
+
24
+ def train_loop(train_dataset,
25
+ device,
26
+ save_directory,
27
+ batch_size,
28
+ steps,
29
+ path_to_checkpoint=None,
30
+ fine_tune=False,
31
+ resume=False,
32
+ debug_img_path=None,
33
+ use_reconstruction=True):
34
+ """
35
+ Args:
36
+ resume: whether to resume from the most recent checkpoint
37
+ steps: How many steps to train
38
+ path_to_checkpoint: reloads a checkpoint to continue training from there
39
+ fine_tune: whether to load everything from a checkpoint, or only the model parameters
40
+ train_dataset: Pytorch Dataset Object for train data
41
+ device: Device to put the loaded tensors on
42
+ save_directory: Where to save the checkpoints
43
+ batch_size: How many elements should be loaded at once
44
+ """
45
+ os.makedirs(save_directory, exist_ok=True)
46
+ train_loader = DataLoader(batch_size=batch_size,
47
+ dataset=train_dataset,
48
+ drop_last=True,
49
+ num_workers=8,
50
+ pin_memory=False,
51
+ shuffle=True,
52
+ prefetch_factor=16,
53
+ collate_fn=collate_and_pad,
54
+ persistent_workers=True)
55
+
56
+ asr_model = Aligner().to(device)
57
+ optim_asr = RAdam(asr_model.parameters(), lr=0.0001)
58
+
59
+ tiny_tts = TinyTTS().to(device)
60
+ optim_tts = RAdam(tiny_tts.parameters(), lr=0.0001)
61
+
62
+ step_counter = 0
63
+ if resume:
64
+ previous_checkpoint = os.path.join(save_directory, "aligner.pt")
65
+ path_to_checkpoint = previous_checkpoint
66
+ fine_tune = False
67
+
68
+ if path_to_checkpoint is not None:
69
+ check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device)
70
+ asr_model.load_state_dict(check_dict["asr_model"])
71
+ tiny_tts.load_state_dict(check_dict["tts_model"])
72
+ if not fine_tune:
73
+ optim_asr.load_state_dict(check_dict["optimizer"])
74
+ optim_tts.load_state_dict(check_dict["tts_optimizer"])
75
+ step_counter = check_dict["step_counter"]
76
+ if step_counter > steps:
77
+ print("Desired steps already reached in loaded checkpoint.")
78
+ return
79
+ start_time = time.time()
80
+
81
+ while True:
82
+ loss_sum = list()
83
+
84
+ asr_model.train()
85
+ tiny_tts.train()
86
+ for batch in tqdm(train_loader):
87
+ tokens = batch[0].to(device)
88
+ tokens_len = batch[1].to(device)
89
+ mel = batch[2].to(device)
90
+ mel_len = batch[3].to(device)
91
+ speaker_embeddings = batch[4].to(device)
92
+
93
+ pred = asr_model(mel, mel_len)
94
+
95
+ ctc_loss = asr_model.ctc_loss(pred.transpose(0, 1).log_softmax(2),
96
+ tokens,
97
+ mel_len,
98
+ tokens_len)
99
+
100
+ if use_reconstruction:
101
+ speaker_embeddings_expanded = torch.nn.functional.normalize(speaker_embeddings).unsqueeze(1).expand(-1, pred.size(1), -1)
102
+ tts_lambda = min([5, step_counter / 2000]) # super simple schedule
103
+ reconstruction_loss = tiny_tts(x=torch.cat([pred, speaker_embeddings_expanded], dim=-1),
104
+ # combine ASR prediction with speaker embeddings to allow for reconstruction loss on multiple speakers
105
+ lens=mel_len,
106
+ ys=mel) * tts_lambda # reconstruction loss to make the states more distinct
107
+ loss = ctc_loss + reconstruction_loss
108
+ else:
109
+ loss = ctc_loss
110
+
111
+ optim_asr.zero_grad()
112
+ if use_reconstruction:
113
+ optim_tts.zero_grad()
114
+ loss.backward()
115
+ torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0)
116
+ if use_reconstruction:
117
+ torch.nn.utils.clip_grad_norm_(tiny_tts.parameters(), 1.0)
118
+ optim_asr.step()
119
+ if use_reconstruction:
120
+ optim_tts.step()
121
+
122
+ step_counter += 1
123
+
124
+ loss_sum.append(loss.item())
125
+
126
+ asr_model.eval()
127
+ loss_this_epoch = sum(loss_sum) / len(loss_sum)
128
+ torch.save({
129
+ "asr_model" : asr_model.state_dict(),
130
+ "optimizer" : optim_asr.state_dict(),
131
+ "tts_model" : tiny_tts.state_dict(),
132
+ "tts_optimizer": optim_tts.state_dict(),
133
+ "step_counter" : step_counter,
134
+ },
135
+ os.path.join(save_directory, "aligner.pt"))
136
+ print("Total Loss: {}".format(round(loss_this_epoch, 3)))
137
+ print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
138
+ print("Steps: {}".format(step_counter))
139
+ if debug_img_path is not None:
140
+ asr_model.inference(mel=mel[0][:mel_len[0]],
141
+ tokens=tokens[0][:tokens_len[0]],
142
+ save_img_for_debug=debug_img_path + f"/{step_counter}.png",
143
+ train=True) # for testing
144
+ if step_counter > steps:
145
+ return
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/DurationCalculator.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Nagoya University (Tomoki Hayashi)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ import matplotlib.pyplot as plt
6
+
7
+ import torch
8
+
9
+
10
+ class DurationCalculator(torch.nn.Module):
11
+
12
+ def __init__(self, reduction_factor):
13
+ self.reduction_factor = reduction_factor
14
+ super().__init__()
15
+
16
+ @torch.no_grad()
17
+ def forward(self, att_ws, vis=None):
18
+ """
19
+ Convert alignment matrix to durations.
20
+ """
21
+ if vis is not None:
22
+ plt.figure(figsize=(8, 4))
23
+ plt.imshow(att_ws.cpu().numpy(), interpolation='nearest', aspect='auto', origin="lower")
24
+ plt.xlabel("Inputs")
25
+ plt.ylabel("Outputs")
26
+ plt.tight_layout()
27
+ plt.savefig(vis)
28
+ plt.close()
29
+ # calculate duration from 2d alignment matrix
30
+ durations = torch.stack([att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])])
31
+ return durations.view(-1) * self.reduction_factor
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/EnergyCalculator.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Nagoya University (Tomoki Hayashi)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from Layers.STFT import STFT
9
+ from Utility.utils import pad_list
10
+
11
+
12
+ class EnergyCalculator(torch.nn.Module):
13
+
14
+ def __init__(self, fs=16000, n_fft=1024, win_length=None, hop_length=256, window="hann", center=True,
15
+ normalized=False, onesided=True, use_token_averaged_energy=True, reduction_factor=1):
16
+ super().__init__()
17
+
18
+ self.fs = fs
19
+ self.n_fft = n_fft
20
+ self.hop_length = hop_length
21
+ self.win_length = win_length
22
+ self.window = window
23
+ self.use_token_averaged_energy = use_token_averaged_energy
24
+ if use_token_averaged_energy:
25
+ assert reduction_factor >= 1
26
+ self.reduction_factor = reduction_factor
27
+
28
+ self.stft = STFT(n_fft=n_fft, win_length=win_length, hop_length=hop_length, window=window, center=center, normalized=normalized, onesided=onesided)
29
+
30
+ def output_size(self):
31
+ return 1
32
+
33
+ def get_parameters(self):
34
+ return dict(fs=self.fs, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, win_length=self.win_length, center=self.stft.center,
35
+ normalized=self.stft.normalized, use_token_averaged_energy=self.use_token_averaged_energy, reduction_factor=self.reduction_factor)
36
+
37
+ def forward(self, input_waves, input_waves_lengths=None, feats_lengths=None, durations=None,
38
+ durations_lengths=None, norm_by_average=True):
39
+ # If not provided, we assume that the inputs have the same length
40
+ if input_waves_lengths is None:
41
+ input_waves_lengths = (input_waves.new_ones(input_waves.shape[0], dtype=torch.long) * input_waves.shape[1])
42
+
43
+ # Domain-conversion: e.g. Stft: time -> time-freq
44
+ input_stft, energy_lengths = self.stft(input_waves, input_waves_lengths)
45
+
46
+ assert input_stft.dim() >= 4, input_stft.shape
47
+ assert input_stft.shape[-1] == 2, input_stft.shape
48
+
49
+ # input_stft: (..., F, 2) -> (..., F)
50
+ input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
51
+ # sum over frequency (B, N, F) -> (B, N)
52
+ energy = torch.sqrt(torch.clamp(input_power.sum(dim=2), min=1.0e-10))
53
+
54
+ # (Optional): Adjust length to match with the mel-spectrogram
55
+ if feats_lengths is not None:
56
+ energy = [self._adjust_num_frames(e[:el].view(-1), fl) for e, el, fl in zip(energy, energy_lengths, feats_lengths)]
57
+ energy_lengths = feats_lengths
58
+
59
+ # (Optional): Average by duration to calculate token-wise energy
60
+ if self.use_token_averaged_energy:
61
+ energy = [self._average_by_duration(e[:el].view(-1), d) for e, el, d in zip(energy, energy_lengths, durations)]
62
+ energy_lengths = durations_lengths
63
+
64
+ # Padding
65
+ if isinstance(energy, list):
66
+ energy = pad_list(energy, 0.0)
67
+
68
+ # Return with the shape (B, T, 1)
69
+ if norm_by_average:
70
+ average = energy[0][energy[0] != 0.0].mean()
71
+ energy = energy / average
72
+ return energy.unsqueeze(-1), energy_lengths
73
+
74
+ def _average_by_duration(self, x, d):
75
+ assert 0 <= len(x) - d.sum() < self.reduction_factor
76
+ d_cumsum = F.pad(d.cumsum(dim=0), (1, 0))
77
+ x_avg = [x[start:end].mean() if len(x[start:end]) != 0 else x.new_tensor(0.0) for start, end in zip(d_cumsum[:-1], d_cumsum[1:])]
78
+ return torch.stack(x_avg)
79
+
80
+ @staticmethod
81
+ def _adjust_num_frames(x, num_frames):
82
+ if num_frames > len(x):
83
+ x = F.pad(x, (0, num_frames - len(x)))
84
+ elif num_frames < len(x):
85
+ x = x[:num_frames]
86
+ return x
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeech2.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ from abc import ABC
6
+
7
+ import torch
8
+
9
+ from Layers.Conformer import Conformer
10
+ from Layers.DurationPredictor import DurationPredictor
11
+ from Layers.LengthRegulator import LengthRegulator
12
+ from Layers.PostNet import PostNet
13
+ from Layers.VariancePredictor import VariancePredictor
14
+ from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.FastSpeech2Loss import FastSpeech2Loss
15
+ from Utility.SoftDTW.sdtw_cuda_loss import SoftDTW
16
+ from Utility.utils import initialize
17
+ from Utility.utils import make_non_pad_mask
18
+ from Utility.utils import make_pad_mask
19
+
20
+
21
+ class FastSpeech2(torch.nn.Module, ABC):
22
+ """
23
+ FastSpeech 2 module.
24
+
25
+ This is a module of FastSpeech 2 described in FastSpeech 2: Fast and
26
+ High-Quality End-to-End Text to Speech. Instead of quantized pitch and
27
+ energy, we use token-averaged value introduced in FastPitch: Parallel
28
+ Text-to-speech with Pitch Prediction. The encoder and decoder are Conformers
29
+ instead of regular Transformers.
30
+
31
+ https://arxiv.org/abs/2006.04558
32
+ https://arxiv.org/abs/2006.06873
33
+ https://arxiv.org/pdf/2005.08100
34
+
35
+ """
36
+
37
+ def __init__(self,
38
+ # network structure related
39
+ idim=66,
40
+ odim=80,
41
+ adim=384,
42
+ aheads=4,
43
+ elayers=6,
44
+ eunits=1536,
45
+ dlayers=6,
46
+ dunits=1536,
47
+ postnet_layers=5,
48
+ postnet_chans=256,
49
+ postnet_filts=5,
50
+ positionwise_layer_type="conv1d",
51
+ positionwise_conv_kernel_size=1,
52
+ use_scaled_pos_enc=True,
53
+ use_batch_norm=True,
54
+ encoder_normalize_before=True,
55
+ decoder_normalize_before=True,
56
+ encoder_concat_after=False,
57
+ decoder_concat_after=False,
58
+ reduction_factor=1,
59
+ # encoder / decoder
60
+ use_macaron_style_in_conformer=True,
61
+ use_cnn_in_conformer=True,
62
+ conformer_enc_kernel_size=7,
63
+ conformer_dec_kernel_size=31,
64
+ # duration predictor
65
+ duration_predictor_layers=2,
66
+ duration_predictor_chans=256,
67
+ duration_predictor_kernel_size=3,
68
+ # energy predictor
69
+ energy_predictor_layers=2,
70
+ energy_predictor_chans=256,
71
+ energy_predictor_kernel_size=3,
72
+ energy_predictor_dropout=0.5,
73
+ energy_embed_kernel_size=1,
74
+ energy_embed_dropout=0.0,
75
+ stop_gradient_from_energy_predictor=False,
76
+ # pitch predictor
77
+ pitch_predictor_layers=5,
78
+ pitch_predictor_chans=256,
79
+ pitch_predictor_kernel_size=5,
80
+ pitch_predictor_dropout=0.5,
81
+ pitch_embed_kernel_size=1,
82
+ pitch_embed_dropout=0.0,
83
+ stop_gradient_from_pitch_predictor=True,
84
+ # training related
85
+ transformer_enc_dropout_rate=0.2,
86
+ transformer_enc_positional_dropout_rate=0.2,
87
+ transformer_enc_attn_dropout_rate=0.2,
88
+ transformer_dec_dropout_rate=0.2,
89
+ transformer_dec_positional_dropout_rate=0.2,
90
+ transformer_dec_attn_dropout_rate=0.2,
91
+ duration_predictor_dropout_rate=0.2,
92
+ postnet_dropout_rate=0.5,
93
+ init_type="xavier_uniform",
94
+ init_enc_alpha=1.0,
95
+ init_dec_alpha=1.0,
96
+ use_masking=False,
97
+ use_weighted_masking=True,
98
+ # additional features
99
+ use_dtw_loss=False,
100
+ utt_embed_dim=704,
101
+ connect_utt_emb_at_encoder_out=True,
102
+ lang_embs=100):
103
+ super().__init__()
104
+
105
+ # store hyperparameters
106
+ self.idim = idim
107
+ self.odim = odim
108
+ self.use_dtw_loss = use_dtw_loss
109
+ self.eos = 1
110
+ self.reduction_factor = reduction_factor
111
+ self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
112
+ self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
113
+ self.use_scaled_pos_enc = use_scaled_pos_enc
114
+ self.multilingual_model = lang_embs is not None
115
+ self.multispeaker_model = utt_embed_dim is not None
116
+
117
+ # define encoder
118
+ embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
119
+ torch.nn.Tanh(),
120
+ torch.nn.Linear(100, adim))
121
+ self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
122
+ input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
123
+ positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
124
+ normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
125
+ positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
126
+ use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
127
+ utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)
128
+
129
+ # define duration predictor
130
+ self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers, n_chans=duration_predictor_chans,
131
+ kernel_size=duration_predictor_kernel_size, dropout_rate=duration_predictor_dropout_rate, )
132
+
133
+ # define pitch predictor
134
+ self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers, n_chans=pitch_predictor_chans,
135
+ kernel_size=pitch_predictor_kernel_size, dropout_rate=pitch_predictor_dropout)
136
+ # continuous pitch + FastPitch style avg
137
+ self.pitch_embed = torch.nn.Sequential(
138
+ torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=pitch_embed_kernel_size, padding=(pitch_embed_kernel_size - 1) // 2),
139
+ torch.nn.Dropout(pitch_embed_dropout))
140
+
141
+ # define energy predictor
142
+ self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers, n_chans=energy_predictor_chans,
143
+ kernel_size=energy_predictor_kernel_size, dropout_rate=energy_predictor_dropout)
144
+ # continuous energy + FastPitch style avg
145
+ self.energy_embed = torch.nn.Sequential(
146
+ torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=energy_embed_kernel_size, padding=(energy_embed_kernel_size - 1) // 2),
147
+ torch.nn.Dropout(energy_embed_dropout))
148
+
149
+ # define length regulator
150
+ self.length_regulator = LengthRegulator()
151
+
152
+ self.decoder = Conformer(idim=0, attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, input_layer=None,
153
+ dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate,
154
+ attention_dropout_rate=transformer_dec_attn_dropout_rate, normalize_before=decoder_normalize_before,
155
+ concat_after=decoder_concat_after, positionwise_conv_kernel_size=positionwise_conv_kernel_size,
156
+ macaron_style=use_macaron_style_in_conformer, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_dec_kernel_size)
157
+
158
+ # define final projection
159
+ self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
160
+
161
+ # define postnet
162
+ self.postnet = PostNet(idim=idim, odim=odim, n_layers=postnet_layers, n_chans=postnet_chans, n_filts=postnet_filts, use_batch_norm=use_batch_norm,
163
+ dropout_rate=postnet_dropout_rate)
164
+
165
+ # initialize parameters
166
+ self._reset_parameters(init_type=init_type, init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha)
167
+
168
+ # define criterions
169
+ self.criterion = FastSpeech2Loss(use_masking=use_masking, use_weighted_masking=use_weighted_masking)
170
+ self.dtw_criterion = SoftDTW(use_cuda=True, gamma=0.1)
171
+
172
+ def forward(self,
173
+ text_tensors,
174
+ text_lengths,
175
+ gold_speech,
176
+ speech_lengths,
177
+ gold_durations,
178
+ gold_pitch,
179
+ gold_energy,
180
+ utterance_embedding,
181
+ return_mels=False,
182
+ lang_ids=None):
183
+ """
184
+ Calculate forward propagation.
185
+
186
+ Args:
187
+ return_mels: whether to return the predicted spectrogram
188
+ text_tensors (LongTensor): Batch of padded text vectors (B, Tmax).
189
+ text_lengths (LongTensor): Batch of lengths of each input (B,).
190
+ gold_speech (Tensor): Batch of padded target features (B, Lmax, odim).
191
+ speech_lengths (LongTensor): Batch of the lengths of each target (B,).
192
+ gold_durations (LongTensor): Batch of padded durations (B, Tmax + 1).
193
+ gold_pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1).
194
+ gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1).
195
+
196
+ Returns:
197
+ Tensor: Loss scalar value.
198
+ Dict: Statistics to be monitored.
199
+ Tensor: Weight value.
200
+ """
201
+ # Texts include EOS token from the teacher model already in this version
202
+
203
+ # forward propagation
204
+ before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(text_tensors, text_lengths, gold_speech, speech_lengths,
205
+ gold_durations, gold_pitch, gold_energy, utterance_embedding=utterance_embedding,
206
+ is_inference=False, lang_ids=lang_ids)
207
+
208
+ # modify mod part of groundtruth (speaking pace)
209
+ if self.reduction_factor > 1:
210
+ speech_lengths = speech_lengths.new([olen - olen % self.reduction_factor for olen in speech_lengths])
211
+
212
+ # calculate loss
213
+ l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, p_outs=p_outs,
214
+ e_outs=e_outs, ys=gold_speech, ds=gold_durations, ps=gold_pitch, es=gold_energy,
215
+ ilens=text_lengths, olens=speech_lengths)
216
+ loss = l1_loss + duration_loss + pitch_loss + energy_loss
217
+
218
+ if self.use_dtw_loss:
219
+ # print("Regular Loss: {}".format(loss))
220
+ dtw_loss = self.dtw_criterion(after_outs, gold_speech).mean() / 2000.0 # division to balance orders of magnitude
221
+ # print("DTW Loss: {}".format(dtw_loss))
222
+ loss = loss + dtw_loss
223
+
224
+ if return_mels:
225
+ return loss, after_outs
226
+ return loss
227
+
228
+ def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
229
+ gold_durations=None, gold_pitch=None, gold_energy=None,
230
+ is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
231
+
232
+ if not self.multilingual_model:
233
+ lang_ids = None
234
+
235
+ if not self.multispeaker_model:
236
+ utterance_embedding = None
237
+
238
+ # forward encoder
239
+ text_masks = self._source_mask(text_lens)
240
+
241
+ encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) # (B, Tmax, adim)
242
+
243
+ # forward duration predictor and variance predictors
244
+ d_masks = make_pad_mask(text_lens, device=text_lens.device)
245
+
246
+ if self.stop_gradient_from_pitch_predictor:
247
+ pitch_predictions = self.pitch_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
248
+ else:
249
+ pitch_predictions = self.pitch_predictor(encoded_texts, d_masks.unsqueeze(-1))
250
+
251
+ if self.stop_gradient_from_energy_predictor:
252
+ energy_predictions = self.energy_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
253
+ else:
254
+ energy_predictions = self.energy_predictor(encoded_texts, d_masks.unsqueeze(-1))
255
+
256
+ if is_inference:
257
+ d_outs = self.duration_predictor.inference(encoded_texts, d_masks) # (B, Tmax)
258
+ # use prediction in inference
259
+ p_embs = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
260
+ e_embs = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
261
+ encoded_texts = encoded_texts + e_embs + p_embs
262
+ encoded_texts = self.length_regulator(encoded_texts, d_outs, alpha) # (B, Lmax, adim)
263
+ else:
264
+ d_outs = self.duration_predictor(encoded_texts, d_masks)
265
+
266
+ # use groundtruth in training
267
+ p_embs = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
268
+ e_embs = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
269
+ encoded_texts = encoded_texts + e_embs + p_embs
270
+ encoded_texts = self.length_regulator(encoded_texts, gold_durations) # (B, Lmax, adim)
271
+
272
+ # forward decoder
273
+ if speech_lens is not None and not is_inference:
274
+ if self.reduction_factor > 1:
275
+ olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
276
+ else:
277
+ olens_in = speech_lens
278
+ h_masks = self._source_mask(olens_in)
279
+ else:
280
+ h_masks = None
281
+ zs, _ = self.decoder(encoded_texts, h_masks) # (B, Lmax, adim)
282
+ before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim)
283
+
284
+ # postnet -> (B, Lmax//r * r, odim)
285
+ after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
286
+
287
+ return before_outs, after_outs, d_outs, pitch_predictions, energy_predictions
288
+
289
+ def batch_inference(self, texts, text_lens, utt_emb):
290
+ _, after_outs, d_outs, _, _ = self._forward(texts,
291
+ text_lens,
292
+ None,
293
+ is_inference=True,
294
+ alpha=1.0)
295
+ return after_outs, d_outs
296
+
297
+ def inference(self,
298
+ text,
299
+ speech=None,
300
+ durations=None,
301
+ pitch=None,
302
+ energy=None,
303
+ alpha=1.0,
304
+ use_teacher_forcing=False,
305
+ utterance_embedding=None,
306
+ return_duration_pitch_energy=False,
307
+ lang_id=None):
308
+ """
309
+ Generate the sequence of features given the sequences of characters.
310
+
311
+ Args:
312
+ text (LongTensor): Input sequence of characters (T,).
313
+ speech (Tensor, optional): Feature sequence to extract style (N, idim).
314
+ durations (LongTensor, optional): Groundtruth of duration (T + 1,).
315
+ pitch (Tensor, optional): Groundtruth of token-averaged pitch (T + 1, 1).
316
+ energy (Tensor, optional): Groundtruth of token-averaged energy (T + 1, 1).
317
+ alpha (float, optional): Alpha to control the speed.
318
+ use_teacher_forcing (bool, optional): Whether to use teacher forcing.
319
+ If true, groundtruth of duration, pitch and energy will be used.
320
+ return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting
321
+
322
+ Returns:
323
+ Tensor: Output sequence of features (L, odim).
324
+
325
+ """
326
+ self.eval()
327
+ x, y = text, speech
328
+ d, p, e = durations, pitch, energy
329
+
330
+ # setup batch axis
331
+ ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
332
+ xs, ys = x.unsqueeze(0), None
333
+ if y is not None:
334
+ ys = y.unsqueeze(0)
335
+ if lang_id is not None:
336
+ lang_id = lang_id.unsqueeze(0)
337
+
338
+ if use_teacher_forcing:
339
+ # use groundtruth of duration, pitch, and energy
340
+ ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0)
341
+ before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
342
+ ilens,
343
+ ys,
344
+ gold_durations=ds,
345
+ gold_pitch=ps,
346
+ gold_energy=es,
347
+ utterance_embedding=utterance_embedding.unsqueeze(0),
348
+ lang_ids=lang_id) # (1, L, odim)
349
+ else:
350
+ before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
351
+ ilens,
352
+ ys,
353
+ is_inference=True,
354
+ alpha=alpha,
355
+ utterance_embedding=utterance_embedding.unsqueeze(0),
356
+ lang_ids=lang_id) # (1, L, odim)
357
+ self.train()
358
+ if return_duration_pitch_energy:
359
+ return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
360
+ return after_outs[0]
361
+
362
+ def _source_mask(self, ilens):
363
+ """
364
+ Make masks for self-attention.
365
+
366
+ Args:
367
+ ilens (LongTensor): Batch of lengths (B,).
368
+
369
+ Returns:
370
+ Tensor: Mask tensor for self-attention.
371
+
372
+ """
373
+ x_masks = make_non_pad_mask(ilens, device=ilens.device)
374
+ return x_masks.unsqueeze(-2)
375
+
376
+ def _reset_parameters(self, init_type, init_enc_alpha, init_dec_alpha):
377
+ # initialize parameters
378
+ if init_type != "pytorch":
379
+ initialize(self, init_type)
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeech2Loss.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import torch
6
+
7
+ from Layers.DurationPredictor import DurationPredictorLoss
8
+ from Utility.utils import make_non_pad_mask
9
+
10
+
11
+ class FastSpeech2Loss(torch.nn.Module):
12
+
13
+ def __init__(self, use_masking=True, use_weighted_masking=False):
14
+ """
15
+ use_masking (bool):
16
+ Whether to apply masking for padded part in loss calculation.
17
+ use_weighted_masking (bool):
18
+ Whether to weighted masking in loss calculation.
19
+ """
20
+ super().__init__()
21
+
22
+ assert (use_masking != use_weighted_masking) or not use_masking
23
+ self.use_masking = use_masking
24
+ self.use_weighted_masking = use_weighted_masking
25
+
26
+ # define criterions
27
+ reduction = "none" if self.use_weighted_masking else "mean"
28
+ self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
29
+ self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
30
+ self.duration_criterion = DurationPredictorLoss(reduction=reduction)
31
+
32
+ def forward(self, after_outs, before_outs, d_outs, p_outs, e_outs, ys,
33
+ ds, ps, es, ilens, olens, ):
34
+ """
35
+ Args:
36
+ after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
37
+ before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
38
+ d_outs (LongTensor): Batch of outputs of duration predictor (B, Tmax).
39
+ p_outs (Tensor): Batch of outputs of pitch predictor (B, Tmax, 1).
40
+ e_outs (Tensor): Batch of outputs of energy predictor (B, Tmax, 1).
41
+ ys (Tensor): Batch of target features (B, Lmax, odim).
42
+ ds (LongTensor): Batch of durations (B, Tmax).
43
+ ps (Tensor): Batch of target token-averaged pitch (B, Tmax, 1).
44
+ es (Tensor): Batch of target token-averaged energy (B, Tmax, 1).
45
+ ilens (LongTensor): Batch of the lengths of each input (B,).
46
+ olens (LongTensor): Batch of the lengths of each target (B,).
47
+
48
+ Returns:
49
+ Tensor: L1 loss value.
50
+ Tensor: Duration predictor loss value.
51
+ Tensor: Pitch predictor loss value.
52
+ Tensor: Energy predictor loss value.
53
+
54
+ """
55
+ # apply mask to remove padded part
56
+ if self.use_masking:
57
+ out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
58
+ before_outs = before_outs.masked_select(out_masks)
59
+ if after_outs is not None:
60
+ after_outs = after_outs.masked_select(out_masks)
61
+ ys = ys.masked_select(out_masks)
62
+ duration_masks = make_non_pad_mask(ilens).to(ys.device)
63
+ d_outs = d_outs.masked_select(duration_masks)
64
+ ds = ds.masked_select(duration_masks)
65
+ pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ys.device)
66
+ p_outs = p_outs.masked_select(pitch_masks)
67
+ e_outs = e_outs.masked_select(pitch_masks)
68
+ ps = ps.masked_select(pitch_masks)
69
+ es = es.masked_select(pitch_masks)
70
+
71
+ # calculate loss
72
+ l1_loss = self.l1_criterion(before_outs, ys)
73
+ if after_outs is not None:
74
+ l1_loss += self.l1_criterion(after_outs, ys)
75
+ duration_loss = self.duration_criterion(d_outs, ds)
76
+ pitch_loss = self.mse_criterion(p_outs, ps)
77
+ energy_loss = self.mse_criterion(e_outs, es)
78
+
79
+ # make weighted mask and apply it
80
+ if self.use_weighted_masking:
81
+ out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
82
+ out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
83
+ out_weights /= ys.size(0) * ys.size(2)
84
+ duration_masks = make_non_pad_mask(ilens).to(ys.device)
85
+ duration_weights = (duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float())
86
+ duration_weights /= ds.size(0)
87
+
88
+ # apply weight
89
+ l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
90
+ duration_loss = (duration_loss.mul(duration_weights).masked_select(duration_masks).sum())
91
+ pitch_masks = duration_masks.unsqueeze(-1)
92
+ pitch_weights = duration_weights.unsqueeze(-1)
93
+ pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
94
+ energy_loss = (energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum())
95
+
96
+ return l1_loss, duration_loss, pitch_loss, energy_loss
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeechDatasetLanguageID.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import statistics
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from tqdm import tqdm
7
+
8
+ from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
9
+ from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor
10
+ from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
11
+ from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset
12
+ from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator
13
+ from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator
14
+ from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio
15
+
16
+
17
+ class FastSpeechDataset(Dataset):
18
+
19
+ def __init__(self,
20
+ path_to_transcript_dict,
21
+ acoustic_checkpoint_path,
22
+ cache_dir,
23
+ lang,
24
+ loading_processes=40,
25
+ min_len_in_seconds=1,
26
+ max_len_in_seconds=20,
27
+ cut_silence=False,
28
+ reduction_factor=1,
29
+ device=torch.device("cpu"),
30
+ rebuild_cache=False,
31
+ ctc_selection=True,
32
+ save_imgs=False):
33
+ self.cache_dir = cache_dir
34
+ os.makedirs(cache_dir, exist_ok=True)
35
+ if not os.path.exists(os.path.join(cache_dir, "fast_train_cache.pt")) or rebuild_cache:
36
+ if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
37
+ AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
38
+ cache_dir=cache_dir,
39
+ lang=lang,
40
+ loading_processes=loading_processes,
41
+ min_len_in_seconds=min_len_in_seconds,
42
+ max_len_in_seconds=max_len_in_seconds,
43
+ cut_silences=cut_silence,
44
+ rebuild_cache=rebuild_cache,
45
+ device=device)
46
+ datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
47
+ # we use the aligner dataset as basis and augment it to contain the additional information we need for fastspeech.
48
+ if not isinstance(datapoints, tuple): # check for backwards compatibility
49
+ print(f"It seems like the Aligner dataset in {cache_dir} is not a tuple. Regenerating it, since we need the preprocessed waves.")
50
+ AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
51
+ cache_dir=cache_dir,
52
+ lang=lang,
53
+ loading_processes=loading_processes,
54
+ min_len_in_seconds=min_len_in_seconds,
55
+ max_len_in_seconds=max_len_in_seconds,
56
+ cut_silences=cut_silence,
57
+ rebuild_cache=True)
58
+ datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
59
+ dataset = datapoints[0]
60
+ norm_waves = datapoints[1]
61
+
62
+ # build cache
63
+ print("... building dataset cache ...")
64
+ self.datapoints = list()
65
+ self.ctc_losses = list()
66
+
67
+ acoustic_model = Aligner()
68
+ acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])
69
+
70
+ # ==========================================
71
+ # actual creation of datapoints starts here
72
+ # ==========================================
73
+
74
+ acoustic_model = acoustic_model.to(device)
75
+ dio = Dio(reduction_factor=reduction_factor, fs=16000)
76
+ energy_calc = EnergyCalculator(reduction_factor=reduction_factor, fs=16000)
77
+ dc = DurationCalculator(reduction_factor=reduction_factor)
78
+ vis_dir = os.path.join(cache_dir, "duration_vis")
79
+ os.makedirs(vis_dir, exist_ok=True)
80
+ pros_cond_ext = ProsodicConditionExtractor(sr=16000, device=device)
81
+
82
+ for index in tqdm(range(len(dataset))):
83
+ norm_wave = norm_waves[index]
84
+ norm_wave_length = torch.LongTensor([len(norm_wave)])
85
+
86
+ if len(norm_wave) / 16000 < min_len_in_seconds and ctc_selection:
87
+ continue
88
+
89
+ text = dataset[index][0]
90
+ melspec = dataset[index][2]
91
+ melspec_length = dataset[index][3]
92
+
93
+ alignment_path, ctc_loss = acoustic_model.inference(mel=melspec.to(device),
94
+ tokens=text.to(device),
95
+ save_img_for_debug=os.path.join(vis_dir, f"{index}.png") if save_imgs else None,
96
+ return_ctc=True)
97
+
98
+ cached_duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
99
+
100
+ last_vec = None
101
+ for phoneme_index, vec in enumerate(text):
102
+ if last_vec is not None:
103
+ if last_vec.numpy().tolist() == vec.numpy().tolist():
104
+ # we found a case of repeating phonemes!
105
+ # now we must repair their durations by giving the first one 3/5 of their sum and the second one 2/5 (i.e. the rest)
106
+ dur_1 = cached_duration[phoneme_index - 1]
107
+ dur_2 = cached_duration[phoneme_index]
108
+ total_dur = dur_1 + dur_2
109
+ new_dur_1 = int((total_dur / 5) * 3)
110
+ new_dur_2 = total_dur - new_dur_1
111
+ cached_duration[phoneme_index - 1] = new_dur_1
112
+ cached_duration[phoneme_index] = new_dur_2
113
+ last_vec = vec
114
+
115
+ cached_energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
116
+ input_waves_lengths=norm_wave_length,
117
+ feats_lengths=melspec_length,
118
+ durations=cached_duration.unsqueeze(0),
119
+ durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()
120
+
121
+ cached_pitch = dio(input_waves=norm_wave.unsqueeze(0),
122
+ input_waves_lengths=norm_wave_length,
123
+ feats_lengths=melspec_length,
124
+ durations=cached_duration.unsqueeze(0),
125
+ durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()
126
+
127
+ try:
128
+ prosodic_condition = pros_cond_ext.extract_condition_from_reference_wave(norm_wave, already_normalized=True).cpu()
129
+ except RuntimeError:
130
+ # if there is an audio without any voiced segments whatsoever we have to skip it.
131
+ continue
132
+
133
+ self.datapoints.append([dataset[index][0],
134
+ dataset[index][1],
135
+ dataset[index][2],
136
+ dataset[index][3],
137
+ cached_duration.cpu(),
138
+ cached_energy,
139
+ cached_pitch,
140
+ prosodic_condition])
141
+ self.ctc_losses.append(ctc_loss)
142
+
143
+ # =============================
144
+ # done with datapoint creation
145
+ # =============================
146
+
147
+ if ctc_selection:
148
+ # now we can filter out some bad datapoints based on the CTC scores we collected
149
+ mean_ctc = sum(self.ctc_losses) / len(self.ctc_losses)
150
+ std_dev = statistics.stdev(self.ctc_losses)
151
+ threshold = mean_ctc + std_dev
152
+ for index in range(len(self.ctc_losses), 0, -1):
153
+ if self.ctc_losses[index - 1] > threshold:
154
+ self.datapoints.pop(index - 1)
155
+ print(
156
+ f"Removing datapoint {index - 1}, because the CTC loss is one standard deviation higher than the mean. \n ctc: {round(self.ctc_losses[index - 1], 4)} vs. mean: {round(mean_ctc, 4)}")
157
+
158
+ # save to cache
159
+ if len(self.datapoints) > 0:
160
+ torch.save(self.datapoints, os.path.join(cache_dir, "fast_train_cache.pt"))
161
+ else:
162
+ import sys
163
+ print("No datapoints were prepared! Exiting...")
164
+ sys.exit()
165
+ else:
166
+ # just load the datapoints from cache
167
+ self.datapoints = torch.load(os.path.join(cache_dir, "fast_train_cache.pt"), map_location='cpu')
168
+
169
+ self.cache_dir = cache_dir
170
+ self.language_id = get_language_id(lang)
171
+ print(f"Prepared a FastSpeech dataset with {len(self.datapoints)} datapoints in {cache_dir}.")
172
+
173
+ def __getitem__(self, index):
174
+ return self.datapoints[index][0], \
175
+ self.datapoints[index][1], \
176
+ self.datapoints[index][2], \
177
+ self.datapoints[index][3], \
178
+ self.datapoints[index][4], \
179
+ self.datapoints[index][5], \
180
+ self.datapoints[index][6], \
181
+ self.datapoints[index][7], \
182
+ self.language_id
183
+
184
+ def __len__(self):
185
+ return len(self.datapoints)
186
+
187
+ def remove_samples(self, list_of_samples_to_remove):
188
+ for remove_id in sorted(list_of_samples_to_remove, reverse=True):
189
+ self.datapoints.pop(remove_id)
190
+ torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
191
+ print("Dataset updated!")
192
+
193
+ def fix_repeating_phones(self):
194
+ """
195
+ The viterbi decoding of the durations cannot
196
+ handle repetitions. This is now solved heuristically,
197
+ but if you have a cache from before March 2022,
198
+ use this method to postprocess those cases.
199
+ """
200
+ for datapoint_index in tqdm(list(range(len(self.datapoints)))):
201
+ last_vec = None
202
+ for phoneme_index, vec in enumerate(self.datapoints[datapoint_index][0]):
203
+ if last_vec is not None:
204
+ if last_vec.numpy().tolist() == vec.numpy().tolist():
205
+ # we found a case of repeating phonemes!
206
+ # now we must repair their durations by giving the first one 3/5 of their sum and the second one 2/5 (i.e. the rest)
207
+ dur_1 = self.datapoints[datapoint_index][4][phoneme_index - 1]
208
+ dur_2 = self.datapoints[datapoint_index][4][phoneme_index]
209
+ total_dur = dur_1 + dur_2
210
+ new_dur_1 = int((total_dur / 5) * 3)
211
+ new_dur_2 = total_dur - new_dur_1
212
+ self.datapoints[datapoint_index][4][phoneme_index - 1] = new_dur_1
213
+ self.datapoints[datapoint_index][4][phoneme_index] = new_dur_2
214
+ print("fix applied")
215
+ last_vec = vec
216
+ torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
217
+ print("Dataset updated!")
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/PitchCalculator.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Nagoya University (Tomoki Hayashi)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ import numpy as np
6
+ import pyworld
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from scipy.interpolate import interp1d
10
+
11
+ from Utility.utils import pad_list
12
+
13
+
14
+ class Dio(torch.nn.Module):
15
+ """
16
+ F0 estimation with dio + stonemask algortihm.
17
+ This is f0 extractor based on dio + stonemask algorithm
18
+ introduced in https://doi.org/10.1587/transinf.2015EDP7457
19
+ """
20
+
21
+ def __init__(self, fs=16000, n_fft=1024, hop_length=256, f0min=40, f0max=400, use_token_averaged_f0=True,
22
+ use_continuous_f0=True, use_log_f0=True, reduction_factor=1):
23
+ super().__init__()
24
+ self.fs = fs
25
+ self.n_fft = n_fft
26
+ self.hop_length = hop_length
27
+ self.frame_period = 1000 * hop_length / fs
28
+ self.f0min = f0min
29
+ self.f0max = f0max
30
+ self.use_token_averaged_f0 = use_token_averaged_f0
31
+ self.use_continuous_f0 = use_continuous_f0
32
+ self.use_log_f0 = use_log_f0
33
+ if use_token_averaged_f0:
34
+ assert reduction_factor >= 1
35
+ self.reduction_factor = reduction_factor
36
+
37
+ def output_size(self):
38
+ return 1
39
+
40
+ def get_parameters(self):
41
+ return dict(fs=self.fs, n_fft=self.n_fft, hop_length=self.hop_length, f0min=self.f0min, f0max=self.f0max,
42
+ use_token_averaged_f0=self.use_token_averaged_f0, use_continuous_f0=self.use_continuous_f0, use_log_f0=self.use_log_f0,
43
+ reduction_factor=self.reduction_factor)
44
+
45
+ def forward(self, input_waves, input_waves_lengths=None, feats_lengths=None, durations=None,
46
+ durations_lengths=None, norm_by_average=True):
47
+ # If not provided, we assume that the inputs have the same length
48
+ if input_waves_lengths is None:
49
+ input_waves_lengths = (input_waves.new_ones(input_waves.shape[0], dtype=torch.long) * input_waves.shape[1])
50
+
51
+ # F0 extraction
52
+ pitch = [self._calculate_f0(x[:xl]) for x, xl in zip(input_waves, input_waves_lengths)]
53
+
54
+ # (Optional): Adjust length to match with the mel-spectrogram
55
+ if feats_lengths is not None:
56
+ pitch = [self._adjust_num_frames(p, fl).view(-1) for p, fl in zip(pitch, feats_lengths)]
57
+
58
+ # (Optional): Average by duration to calculate token-wise f0
59
+ if self.use_token_averaged_f0:
60
+ pitch = [self._average_by_duration(p, d).view(-1) for p, d in zip(pitch, durations)]
61
+ pitch_lengths = durations_lengths
62
+ else:
63
+ pitch_lengths = input_waves.new_tensor([len(p) for p in pitch], dtype=torch.long)
64
+
65
+ # Padding
66
+ pitch = pad_list(pitch, 0.0)
67
+
68
+ # Return with the shape (B, T, 1)
69
+ if norm_by_average:
70
+ average = pitch[0][pitch[0] != 0.0].mean()
71
+ pitch = pitch / average
72
+ return pitch.unsqueeze(-1), pitch_lengths
73
+
74
+ def _calculate_f0(self, input):
75
+ x = input.cpu().numpy().astype(np.double)
76
+ f0, timeaxis = pyworld.dio(x, self.fs, f0_floor=self.f0min, f0_ceil=self.f0max, frame_period=self.frame_period)
77
+ f0 = pyworld.stonemask(x, f0, timeaxis, self.fs)
78
+ if self.use_continuous_f0:
79
+ f0 = self._convert_to_continuous_f0(f0)
80
+ if self.use_log_f0:
81
+ nonzero_idxs = np.where(f0 != 0)[0]
82
+ f0[nonzero_idxs] = np.log(f0[nonzero_idxs])
83
+ return input.new_tensor(f0.reshape(-1), dtype=torch.float)
84
+
85
+ @staticmethod
86
+ def _adjust_num_frames(x, num_frames):
87
+ if num_frames > len(x):
88
+ x = F.pad(x, (0, num_frames - len(x)))
89
+ elif num_frames < len(x):
90
+ x = x[:num_frames]
91
+ return x
92
+
93
+ @staticmethod
94
+ def _convert_to_continuous_f0(f0: np.array):
95
+ if (f0 == 0).all():
96
+ return f0
97
+
98
+ # padding start and end of f0 sequence
99
+ start_f0 = f0[f0 != 0][0]
100
+ end_f0 = f0[f0 != 0][-1]
101
+ start_idx = np.where(f0 == start_f0)[0][0]
102
+ end_idx = np.where(f0 == end_f0)[0][-1]
103
+ f0[:start_idx] = start_f0
104
+ f0[end_idx:] = end_f0
105
+
106
+ # get non-zero frame index
107
+ nonzero_idxs = np.where(f0 != 0)[0]
108
+
109
+ # perform linear interpolation
110
+ interp_fn = interp1d(nonzero_idxs, f0[nonzero_idxs])
111
+ f0 = interp_fn(np.arange(0, f0.shape[0]))
112
+
113
+ return f0
114
+
115
+ def _average_by_duration(self, x, d):
116
+ assert 0 <= len(x) - d.sum() < self.reduction_factor
117
+ d_cumsum = F.pad(d.cumsum(dim=0), (1, 0))
118
+ x_avg = [
119
+ x[start:end].masked_select(x[start:end].gt(0.0)).mean(dim=0) if len(x[start:end].masked_select(x[start:end].gt(0.0))) != 0 else x.new_tensor(0.0)
120
+ for start, end in zip(d_cumsum[:-1], d_cumsum[1:])]
121
+ return torch.stack(x_avg)
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/__init__.py ADDED
File without changes
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/fastspeech2_train_loop.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import librosa.display as lbd
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ import torch.multiprocessing
8
+ import torch.multiprocessing
9
+ from torch.cuda.amp import GradScaler
10
+ from torch.cuda.amp import autocast
11
+ from torch.nn.utils.rnn import pad_sequence
12
+ from torch.utils.data.dataloader import DataLoader
13
+ from tqdm import tqdm
14
+
15
+ from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
16
+ from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
17
+ from Utility.WarmupScheduler import WarmupScheduler
18
+ from Utility.utils import cumsum_durations
19
+ from Utility.utils import delete_old_checkpoints
20
+ from Utility.utils import get_most_recent_checkpoint
21
+
22
+
23
+ @torch.no_grad()
24
+ def plot_progress_spec(net, device, save_dir, step, lang, default_emb):
25
+ tf = ArticulatoryCombinedTextFrontend(language=lang)
26
+ sentence = ""
27
+ if lang == "en":
28
+ sentence = "This is a complex sentence, it even has a pause!"
29
+ elif lang == "de":
30
+ sentence = "Dies ist ein komplexer Satz, er hat sogar eine Pause!"
31
+ elif lang == "el":
32
+ sentence = "Αυτή είναι μια σύνθετη πρόταση, έχει ακόμη και παύση!"
33
+ elif lang == "es":
34
+ sentence = "Esta es una oración compleja, ¡incluso tiene una pausa!"
35
+ elif lang == "fi":
36
+ sentence = "Tämä on monimutkainen lause, sillä on jopa tauko!"
37
+ elif lang == "ru":
38
+ sentence = "Это сложное предложение, в нем даже есть пауза!"
39
+ elif lang == "hu":
40
+ sentence = "Ez egy összetett mondat, még szünet is van benne!"
41
+ elif lang == "nl":
42
+ sentence = "Dit is een complexe zin, er zit zelfs een pauze in!"
43
+ elif lang == "fr":
44
+ sentence = "C'est une phrase complexe, elle a même une pause !"
45
+ phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device)
46
+ spec, durations, *_ = net.inference(text=phoneme_vector,
47
+ return_duration_pitch_energy=True,
48
+ utterance_embedding=default_emb,
49
+ lang_id=get_language_id(lang).to(device))
50
+ spec = spec.transpose(0, 1).to("cpu").numpy()
51
+ duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
52
+ if not os.path.exists(os.path.join(save_dir, "spec")):
53
+ os.makedirs(os.path.join(save_dir, "spec"))
54
+ fig, ax = plt.subplots(nrows=1, ncols=1)
55
+ lbd.specshow(spec,
56
+ ax=ax,
57
+ sr=16000,
58
+ cmap='GnBu',
59
+ y_axis='mel',
60
+ x_axis=None,
61
+ hop_length=256)
62
+ ax.yaxis.set_visible(False)
63
+ ax.set_xticks(duration_splits, minor=True)
64
+ ax.xaxis.grid(True, which='minor')
65
+ ax.set_xticks(label_positions, minor=False)
66
+ ax.set_xticklabels(tf.get_phone_string(sentence))
67
+ ax.set_title(sentence)
68
+ plt.savefig(os.path.join(os.path.join(save_dir, "spec"), str(step) + ".png"))
69
+ plt.clf()
70
+ plt.close()
71
+
72
+
73
+ def collate_and_pad(batch):
74
+ # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id
75
+ return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
76
+ torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
77
+ pad_sequence([datapoint[2] for datapoint in batch], batch_first=True),
78
+ torch.stack([datapoint[3] for datapoint in batch]).squeeze(1),
79
+ pad_sequence([datapoint[4] for datapoint in batch], batch_first=True),
80
+ pad_sequence([datapoint[5] for datapoint in batch], batch_first=True),
81
+ pad_sequence([datapoint[6] for datapoint in batch], batch_first=True),
82
+ torch.stack([datapoint[7] for datapoint in batch]).squeeze(),
83
+ torch.stack([datapoint[8] for datapoint in batch]))
84
+
85
+
86
+ def train_loop(net,
87
+ train_dataset,
88
+ device,
89
+ save_directory,
90
+ batch_size=32,
91
+ steps=300000,
92
+ epochs_per_save=1,
93
+ lang="en",
94
+ lr=0.0001,
95
+ warmup_steps=4000,
96
+ path_to_checkpoint=None,
97
+ fine_tune=False,
98
+ resume=False):
99
+ """
100
+ Args:
101
+ resume: whether to resume from the most recent checkpoint
102
+ warmup_steps: how long the learning rate should increase before it reaches the specified value
103
+ steps: How many steps to train
104
+ lr: The initial learning rate for the optimiser
105
+ path_to_checkpoint: reloads a checkpoint to continue training from there
106
+ fine_tune: whether to load everything from a checkpoint, or only the model parameters
107
+ lang: language of the synthesis
108
+ net: Model to train
109
+ train_dataset: Pytorch Dataset Object for train data
110
+ device: Device to put the loaded tensors on
111
+ save_directory: Where to save the checkpoints
112
+ batch_size: How many elements should be loaded at once
113
+ epochs_per_save: how many epochs to train in between checkpoints
114
+
115
+ """
116
+ net = net.to(device)
117
+
118
+ torch.multiprocessing.set_sharing_strategy('file_system')
119
+ train_loader = DataLoader(batch_size=batch_size,
120
+ dataset=train_dataset,
121
+ drop_last=True,
122
+ num_workers=8,
123
+ pin_memory=True,
124
+ shuffle=True,
125
+ prefetch_factor=8,
126
+ collate_fn=collate_and_pad,
127
+ persistent_workers=True)
128
+ default_embedding = None
129
+ for index in range(20): # slicing is not implemented for datasets, so this detour is needed.
130
+ if default_embedding is None:
131
+ default_embedding = train_dataset[index][7].squeeze()
132
+ else:
133
+ default_embedding = default_embedding + train_dataset[index][7].squeeze()
134
+ default_embedding = (default_embedding / len(train_dataset)).to(device)
135
+ # default speaker embedding for inference is the average of the first 20 speaker embeddings. So if you use multiple datasets combined,
136
+ # put a single speaker one with the nicest voice first into the concat dataset.
137
+ step_counter = 0
138
+ optimizer = torch.optim.Adam(net.parameters(), lr=lr)
139
+ scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps)
140
+ scaler = GradScaler()
141
+ epoch = 0
142
+ if resume:
143
+ path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory)
144
+ if path_to_checkpoint is not None:
145
+ check_dict = torch.load(path_to_checkpoint, map_location=device)
146
+ net.load_state_dict(check_dict["model"])
147
+ if not fine_tune:
148
+ optimizer.load_state_dict(check_dict["optimizer"])
149
+ scheduler.load_state_dict(check_dict["scheduler"])
150
+ step_counter = check_dict["step_counter"]
151
+ scaler.load_state_dict(check_dict["scaler"])
152
+ start_time = time.time()
153
+ while True:
154
+ net.train()
155
+ epoch += 1
156
+ optimizer.zero_grad()
157
+ train_losses_this_epoch = list()
158
+ for batch in tqdm(train_loader):
159
+ with autocast():
160
+ train_loss = net(text_tensors=batch[0].to(device),
161
+ text_lengths=batch[1].to(device),
162
+ gold_speech=batch[2].to(device),
163
+ speech_lengths=batch[3].to(device),
164
+ gold_durations=batch[4].to(device),
165
+ gold_pitch=batch[6].to(device), # mind the switched order
166
+ gold_energy=batch[5].to(device), # mind the switched order
167
+ utterance_embedding=batch[7].to(device),
168
+ lang_ids=batch[8].to(device),
169
+ return_mels=False)
170
+ train_losses_this_epoch.append(train_loss.item())
171
+
172
+ optimizer.zero_grad()
173
+ scaler.scale(train_loss).backward()
174
+ del train_loss
175
+ step_counter += 1
176
+ scaler.unscale_(optimizer)
177
+ torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0, error_if_nonfinite=False)
178
+ scaler.step(optimizer)
179
+ scaler.update()
180
+ scheduler.step()
181
+
182
+ net.eval()
183
+ if epoch % epochs_per_save == 0:
184
+ torch.save({
185
+ "model" : net.state_dict(),
186
+ "optimizer" : optimizer.state_dict(),
187
+ "step_counter": step_counter,
188
+ "scaler" : scaler.state_dict(),
189
+ "scheduler" : scheduler.state_dict(),
190
+ "default_emb" : default_embedding,
191
+ }, os.path.join(save_directory, "checkpoint_{}.pt".format(step_counter)))
192
+ delete_old_checkpoints(save_directory, keep=5)
193
+ plot_progress_spec(net, device, save_dir=save_directory, step=step_counter, lang=lang, default_emb=default_embedding)
194
+ if step_counter > steps:
195
+ # DONE
196
+ return
197
+ print("Epoch: {}".format(epoch))
198
+ print("Train Loss: {}".format(sum(train_losses_this_epoch) / len(train_losses_this_epoch)))
199
+ print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
200
+ print("Steps: {}".format(step_counter))
201
+ net.train()
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/fastspeech2_train_loop_ctc.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+
5
+ import librosa.display as lbd
6
+ import matplotlib.pyplot as plt
7
+ import torch
8
+ import torch.multiprocessing
9
+ import torch.multiprocessing
10
+ from torch.cuda.amp import GradScaler
11
+ from torch.nn.utils.rnn import pad_sequence
12
+ from tqdm import tqdm
13
+
14
+ from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
15
+ from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
16
+ from Utility.WarmupScheduler import WarmupScheduler
17
+ from Utility.utils import cumsum_durations
18
+ from Utility.utils import delete_old_checkpoints
19
+ from Utility.utils import get_most_recent_checkpoint
20
+
21
+
22
+ def plot_progress_spec(net, device, save_dir, step, lang):
23
+ tf = ArticulatoryCombinedTextFrontend(language=lang)
24
+ sentence = ""
25
+ if lang == "en":
26
+ sentence = "This is a complex sentence, it even has a pause!"
27
+ elif lang == "de":
28
+ sentence = "Dies ist ein komplexer Satz, er hat sogar eine Pause!"
29
+ elif lang == "el":
30
+ sentence = "Αυτή είναι μια σύνθετη πρόταση, έχει ακόμη και παύση!"
31
+ elif lang == "es":
32
+ sentence = "Esta es una oración compleja, ¡incluso tiene una pausa!"
33
+ elif lang == "fi":
34
+ sentence = "Tämä on monimutkainen lause, sillä on jopa tauko!"
35
+ elif lang == "ru":
36
+ sentence = "Это сложное предложение, в нем даже есть пауза!"
37
+ elif lang == "hu":
38
+ sentence = "Ez egy összetett mondat, még szünet is van benne!"
39
+ elif lang == "nl":
40
+ sentence = "Dit is een complexe zin, er zit zelfs een pauze in!"
41
+ elif lang == "fr":
42
+ sentence = "C'est une phrase complexe, elle a même une pause !"
43
+ phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device)
44
+ spec, durations, *_ = net.inference(text=phoneme_vector, return_duration_pitch_energy=True)
45
+ spec = spec.transpose(0, 1).to("cpu").numpy()
46
+ duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
47
+ if not os.path.exists(os.path.join(save_dir, "spec")):
48
+ os.makedirs(os.path.join(save_dir, "spec"))
49
+ fig, ax = plt.subplots(nrows=1, ncols=1)
50
+ lbd.specshow(spec,
51
+ ax=ax,
52
+ sr=16000,
53
+ cmap='GnBu',
54
+ y_axis='mel',
55
+ x_axis=None,
56
+ hop_length=256)
57
+ ax.yaxis.set_visible(False)
58
+ ax.set_xticks(duration_splits, minor=True)
59
+ ax.xaxis.grid(True, which='minor')
60
+ ax.set_xticks(label_positions, minor=False)
61
+ ax.set_xticklabels(tf.get_phone_string(sentence))
62
+ ax.set_title(sentence)
63
+ plt.savefig(os.path.join(os.path.join(save_dir, "spec"), str(step) + ".png"))
64
+ plt.clf()
65
+ plt.close()
66
+
67
+
68
+ def train_loop(net,
69
+ train_sentences,
70
+ device,
71
+ save_directory,
72
+ aligner_checkpoint,
73
+ batch_size=32,
74
+ steps=300000,
75
+ epochs_per_save=5,
76
+ lang="en",
77
+ lr=0.0001,
78
+ warmup_steps=4000,
79
+ path_to_checkpoint=None,
80
+ fine_tune=False,
81
+ resume=False):
82
+ """
83
+ Args:
84
+ resume: whether to resume from the most recent checkpoint
85
+ warmup_steps: how long the learning rate should increase before it reaches the specified value
86
+ steps: How many steps to train
87
+ lr: The initial learning rate for the optimiser
88
+ path_to_checkpoint: reloads a checkpoint to continue training from there
89
+ fine_tune: whether to load everything from a checkpoint, or only the model parameters
90
+ lang: language of the synthesis and of the train sentences
91
+ net: Model to train
92
+ train_sentences: list of (string) sentences the CTC objective should be learned on
93
+ device: Device to put the loaded tensors on
94
+ save_directory: Where to save the checkpoints
95
+ batch_size: How many elements should be loaded at once
96
+ epochs_per_save: how many epochs to train in between checkpoints
97
+
98
+ """
99
+ net = net.to(device)
100
+
101
+ torch.multiprocessing.set_sharing_strategy('file_system')
102
+ text_to_art_vec = ArticulatoryCombinedTextFrontend(language=lang)
103
+ asr_aligner = Aligner().to(device)
104
+ check_dict = torch.load(os.path.join(aligner_checkpoint), map_location=device)
105
+ asr_aligner.load_state_dict(check_dict["asr_model"])
106
+ net.stop_gradient_from_energy_predictor = False
107
+ net.stop_gradient_from_pitch_predictor = False
108
+ step_counter = 0
109
+ optimizer = torch.optim.Adam(net.parameters(), lr=lr)
110
+ scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps)
111
+ scaler = GradScaler()
112
+ epoch = 0
113
+ if resume:
114
+ path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory)
115
+ if path_to_checkpoint is not None:
116
+ check_dict = torch.load(path_to_checkpoint, map_location=device)
117
+ net.load_state_dict(check_dict["model"])
118
+ if not fine_tune:
119
+ optimizer.load_state_dict(check_dict["optimizer"])
120
+ scheduler.load_state_dict(check_dict["scheduler"])
121
+ step_counter = check_dict["step_counter"]
122
+ scaler.load_state_dict(check_dict["scaler"])
123
+ start_time = time.time()
124
+ while True:
125
+ net.train()
126
+ epoch += 1
127
+ optimizer.zero_grad()
128
+ train_losses_this_epoch = list()
129
+ random.shuffle(train_sentences)
130
+ batch_of_text_vecs = list()
131
+ batch_of_tokens = list()
132
+
133
+ for sentence in tqdm(train_sentences):
134
+ if sentence.strip() == "":
135
+ continue
136
+
137
+ phonemes = text_to_art_vec.get_phone_string(sentence)
138
+ # collect batch of texts
139
+ batch_of_text_vecs.append(text_to_art_vec.string_to_tensor(phonemes, input_phonemes=True).squeeze(0).to(device))
140
+
141
+ # collect batch of tokens
142
+ tokens = list()
143
+ for phone in phonemes:
144
+ tokens.append(text_to_art_vec.phone_to_id[phone])
145
+ tokens = torch.LongTensor(tokens).to(device)
146
+ batch_of_tokens.append(tokens)
147
+
148
+ if len(batch_of_tokens) == batch_size:
149
+ token_batch = pad_sequence(batch_of_tokens, batch_first=True)
150
+ token_lens = torch.LongTensor([len(x) for x in batch_of_tokens]).to(device)
151
+ text_batch = pad_sequence(batch_of_text_vecs, batch_first=True)
152
+ spec_batch, d_outs = net.batch_inference(texts=text_batch, text_lens=token_lens)
153
+ spec_lens = torch.LongTensor([sum(x) for x in d_outs]).to(device)
154
+
155
+ asr_pred = asr_aligner(spec_batch, spec_lens)
156
+ train_loss = asr_aligner.ctc_loss(asr_pred.transpose(0, 1).log_softmax(2), token_batch, spec_lens, token_lens)
157
+ train_losses_this_epoch.append(train_loss.item())
158
+
159
+ optimizer.zero_grad()
160
+ asr_aligner.zero_grad()
161
+ scaler.scale(train_loss).backward()
162
+ del train_loss
163
+ step_counter += 1
164
+ scaler.unscale_(optimizer)
165
+ torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0, error_if_nonfinite=False)
166
+ scaler.step(optimizer)
167
+ scaler.update()
168
+ scheduler.step()
169
+ batch_of_tokens = list()
170
+ batch_of_text_vecs = list()
171
+
172
+ net.eval()
173
+ if epoch % epochs_per_save == 0:
174
+ torch.save({
175
+ "model" : net.state_dict(),
176
+ "optimizer" : optimizer.state_dict(),
177
+ "step_counter": step_counter,
178
+ "scaler" : scaler.state_dict(),
179
+ "scheduler" : scheduler.state_dict(),
180
+ }, os.path.join(save_directory, "checkpoint_{}.pt".format(step_counter)))
181
+ delete_old_checkpoints(save_directory, keep=5)
182
+ with torch.no_grad():
183
+ plot_progress_spec(net, device, save_dir=save_directory, step=step_counter, lang=lang)
184
+ if step_counter > steps:
185
+ # DONE
186
+ return
187
+ print("Epoch: {}".format(epoch))
188
+ print("Train Loss: {}".format(sum(train_losses_this_epoch) / len(train_losses_this_epoch)))
189
+ print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
190
+ print("Steps: {}".format(step_counter))
191
+ net.train()
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/meta_train_loop.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa.display as lbd
2
+ import matplotlib.pyplot as plt
3
+ import torch
4
+ import torch.multiprocessing
5
+ from torch.cuda.amp import GradScaler
6
+ from torch.cuda.amp import autocast
7
+ from torch.nn.utils.rnn import pad_sequence
8
+ from torch.utils.data.dataloader import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
12
+ from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
13
+ from Utility.WarmupScheduler import WarmupScheduler
14
+ from Utility.path_to_transcript_dicts import *
15
+ from Utility.utils import cumsum_durations
16
+ from Utility.utils import delete_old_checkpoints
17
+ from Utility.utils import get_most_recent_checkpoint
18
+
19
+
20
+ def train_loop(net,
21
+ datasets,
22
+ device,
23
+ save_directory,
24
+ batch_size,
25
+ steps,
26
+ steps_per_checkpoint,
27
+ lr,
28
+ path_to_checkpoint,
29
+ resume=False,
30
+ warmup_steps=4000):
31
+ # ============
32
+ # Preparations
33
+ # ============
34
+ net = net.to(device)
35
+ torch.multiprocessing.set_sharing_strategy('file_system')
36
+ train_loaders = list()
37
+ train_iters = list()
38
+ for dataset in datasets:
39
+ train_loaders.append(DataLoader(batch_size=batch_size,
40
+ dataset=dataset,
41
+ drop_last=True,
42
+ num_workers=2,
43
+ pin_memory=True,
44
+ shuffle=True,
45
+ prefetch_factor=5,
46
+ collate_fn=collate_and_pad,
47
+ persistent_workers=True))
48
+ train_iters.append(iter(train_loaders[-1]))
49
+ default_embeddings = {"en": None, "de": None, "el": None, "es": None, "fi": None, "ru": None, "hu": None, "nl": None, "fr": None}
50
+ for index, lang in enumerate(["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]):
51
+ default_embedding = None
52
+ for datapoint in datasets[index]:
53
+ if default_embedding is None:
54
+ default_embedding = datapoint[7].squeeze()
55
+ else:
56
+ default_embedding = default_embedding + datapoint[7].squeeze()
57
+ default_embeddings[lang] = (default_embedding / len(datasets[index])).to(device)
58
+ optimizer = torch.optim.RAdam(net.parameters(), lr=lr, eps=1.0e-06, weight_decay=0.0)
59
+ grad_scaler = GradScaler()
60
+ scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps)
61
+ if resume:
62
+ previous_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory)
63
+ if previous_checkpoint is not None:
64
+ path_to_checkpoint = previous_checkpoint
65
+ else:
66
+ raise RuntimeError(f"No checkpoint found that can be resumed from in {save_directory}")
67
+ step_counter = 0
68
+ train_losses_total = list()
69
+ if path_to_checkpoint is not None:
70
+ check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device)
71
+ net.load_state_dict(check_dict["model"])
72
+ if resume:
73
+ optimizer.load_state_dict(check_dict["optimizer"])
74
+ step_counter = check_dict["step_counter"]
75
+ grad_scaler.load_state_dict(check_dict["scaler"])
76
+ scheduler.load_state_dict(check_dict["scheduler"])
77
+ if step_counter > steps:
78
+ print("Desired steps already reached in loaded checkpoint.")
79
+ return
80
+
81
+ net.train()
82
+ # =============================
83
+ # Actual train loop starts here
84
+ # =============================
85
+ for step in tqdm(range(step_counter, steps)):
86
+ batches = []
87
+ for index in range(len(datasets)):
88
+ # we get one batch for each task (i.e. language in this case)
89
+ try:
90
+ batch = next(train_iters[index])
91
+ batches.append(batch)
92
+ except StopIteration:
93
+ train_iters[index] = iter(train_loaders[index])
94
+ batch = next(train_iters[index])
95
+ batches.append(batch)
96
+ train_loss = 0.0
97
+ for batch in batches:
98
+ with autocast():
99
+ # we sum the loss for each task, as we would do for the
100
+ # second order regular MAML, but we do it only over one
101
+ # step (i.e. iterations of inner loop = 1)
102
+ train_loss = train_loss + net(text_tensors=batch[0].to(device),
103
+ text_lengths=batch[1].to(device),
104
+ gold_speech=batch[2].to(device),
105
+ speech_lengths=batch[3].to(device),
106
+ gold_durations=batch[4].to(device),
107
+ gold_pitch=batch[6].to(device), # mind the switched order
108
+ gold_energy=batch[5].to(device), # mind the switched order
109
+ utterance_embedding=batch[7].to(device),
110
+ lang_ids=batch[8].to(device),
111
+ return_mels=False)
112
+ # then we directly update our meta-parameters without
113
+ # the need for any task specific parameters
114
+ train_losses_total.append(train_loss.item())
115
+ optimizer.zero_grad()
116
+ grad_scaler.scale(train_loss).backward()
117
+ grad_scaler.unscale_(optimizer)
118
+ torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0, error_if_nonfinite=False)
119
+ grad_scaler.step(optimizer)
120
+ grad_scaler.update()
121
+ scheduler.step()
122
+
123
+ if step % steps_per_checkpoint == 0:
124
+ # ==============================
125
+ # Enough steps for some insights
126
+ # ==============================
127
+ net.eval()
128
+ print(f"Total Loss: {round(sum(train_losses_total) / len(train_losses_total), 3)}")
129
+ train_losses_total = list()
130
+ torch.save({
131
+ "model" : net.state_dict(),
132
+ "optimizer" : optimizer.state_dict(),
133
+ "scaler" : grad_scaler.state_dict(),
134
+ "scheduler" : scheduler.state_dict(),
135
+ "step_counter": step,
136
+ "default_emb" : default_embeddings["en"]
137
+ },
138
+ os.path.join(save_directory, "checkpoint_{}.pt".format(step)))
139
+ delete_old_checkpoints(save_directory, keep=5)
140
+ for lang in ["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]:
141
+ plot_progress_spec(net=net,
142
+ device=device,
143
+ lang=lang,
144
+ save_dir=save_directory,
145
+ step=step,
146
+ utt_embeds=default_embeddings)
147
+ net.train()
148
+
149
+
150
+ @torch.inference_mode()
151
+ def plot_progress_spec(net, device, save_dir, step, lang, utt_embeds):
152
+ tf = ArticulatoryCombinedTextFrontend(language=lang)
153
+ sentence = ""
154
+ default_embed = utt_embeds[lang]
155
+ if lang == "en":
156
+ sentence = "This is a complex sentence, it even has a pause!"
157
+ elif lang == "de":
158
+ sentence = "Dies ist ein komplexer Satz, er hat sogar eine Pause!"
159
+ elif lang == "el":
160
+ sentence = "Αυτή είναι μια σύνθετη πρόταση, έχει ακόμη και παύση!"
161
+ elif lang == "es":
162
+ sentence = "Esta es una oración compleja, ¡incluso tiene una pausa!"
163
+ elif lang == "fi":
164
+ sentence = "Tämä on monimutkainen lause, sillä on jopa tauko!"
165
+ elif lang == "ru":
166
+ sentence = "Это сложное предложение, в нем даже есть пауза!"
167
+ elif lang == "hu":
168
+ sentence = "Ez egy összetett mondat, még szünet is van benne!"
169
+ elif lang == "nl":
170
+ sentence = "Dit is een complexe zin, er zit zelfs een pauze in!"
171
+ elif lang == "fr":
172
+ sentence = "C'est une phrase complexe, elle a même une pause !"
173
+ phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device)
174
+ spec, durations, *_ = net.inference(text=phoneme_vector,
175
+ return_duration_pitch_energy=True,
176
+ utterance_embedding=default_embed,
177
+ lang_id=get_language_id(lang).to(device))
178
+ spec = spec.transpose(0, 1).to("cpu").numpy()
179
+ duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
180
+ if not os.path.exists(os.path.join(save_dir, "spec")):
181
+ os.makedirs(os.path.join(save_dir, "spec"))
182
+ fig, ax = plt.subplots(nrows=1, ncols=1)
183
+ lbd.specshow(spec,
184
+ ax=ax,
185
+ sr=16000,
186
+ cmap='GnBu',
187
+ y_axis='mel',
188
+ x_axis=None,
189
+ hop_length=256)
190
+ ax.yaxis.set_visible(False)
191
+ ax.set_xticks(duration_splits, minor=True)
192
+ ax.xaxis.grid(True, which='minor')
193
+ ax.set_xticks(label_positions, minor=False)
194
+ ax.set_xticklabels(tf.get_phone_string(sentence))
195
+ ax.set_title(sentence)
196
+ plt.savefig(os.path.join(os.path.join(save_dir, "spec"), f"{step}_{lang}.png"))
197
+ plt.clf()
198
+ plt.close()
199
+
200
+
201
+ def collate_and_pad(batch):
202
+ # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id
203
+ return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
204
+ torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
205
+ pad_sequence([datapoint[2] for datapoint in batch], batch_first=True),
206
+ torch.stack([datapoint[3] for datapoint in batch]).squeeze(1),
207
+ pad_sequence([datapoint[4] for datapoint in batch], batch_first=True),
208
+ pad_sequence([datapoint[5] for datapoint in batch], batch_first=True),
209
+ pad_sequence([datapoint[6] for datapoint in batch], batch_first=True),
210
+ torch.stack([datapoint[7] for datapoint in batch]).squeeze(),
211
+ torch.stack([datapoint[8] for datapoint in batch]))
TrainingInterfaces/Text_to_Spectrogram/__init__.py ADDED
File without changes
TrainingInterfaces/__init__.py ADDED
File without changes