Flux9665 committed on
Commit
70399da
•
1 Parent(s): e208c87

update to the current version

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. Architectures/ToucanTTS/StochasticToucanTTS/README.md +0 -1
  2. Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTS.py +0 -493
  3. Architectures/ToucanTTS/StochasticToucanTTS/StochasticVariancePredictor.py +0 -440
  4. Architectures/__init__.py +0 -0
  5. InferenceInterfaces/ControllableInterface.py +25 -18
  6. InferenceInterfaces/ToucanTTSInterface.py +73 -72
  7. InferenceInterfaces/UtteranceCloner.py +8 -6
  8. InferenceInterfaces/audioseal_wm_16bits.yaml +0 -39
  9. {Architectures → Modules}/Aligner/Aligner.py +27 -31
  10. {Architectures → Modules}/Aligner/CodecAlignerDataset.py +57 -14
  11. {Architectures → Modules}/Aligner/README.md +0 -0
  12. {Architectures → Modules}/Aligner/Reconstructor.py +8 -15
  13. {Architectures → Modules}/Aligner/__init__.py +0 -0
  14. {Architectures → Modules}/Aligner/autoaligner_train_loop.py +4 -2
  15. {Architectures → Modules}/ControllabilityGAN/GAN.py +23 -10
  16. {Architectures → Modules}/ControllabilityGAN/__init__.py +0 -0
  17. {Architectures → Modules}/ControllabilityGAN/dataset/__init__.py +0 -0
  18. {Architectures → Modules}/ControllabilityGAN/dataset/speaker_embeddings_dataset.py +0 -0
  19. {Architectures → Modules}/ControllabilityGAN/wgan/__init__.py +0 -0
  20. {Architectures → Modules}/ControllabilityGAN/wgan/init_weights.py +0 -0
  21. {Architectures → Modules}/ControllabilityGAN/wgan/init_wgan.py +2 -2
  22. {Architectures → Modules}/ControllabilityGAN/wgan/resnet_1.py +2 -2
  23. {Architectures → Modules}/ControllabilityGAN/wgan/resnet_init.py +4 -4
  24. {Architectures → Modules}/ControllabilityGAN/wgan/wgan_qc.py +6 -11
  25. {Architectures → Modules}/EmbeddingModel/GST.py +1 -1
  26. {Architectures → Modules}/EmbeddingModel/README.md +0 -0
  27. {Architectures → Modules}/EmbeddingModel/StyleEmbedding.py +2 -2
  28. {Architectures → Modules}/EmbeddingModel/StyleTTSEncoder.py +0 -0
  29. {Architectures → Modules}/EmbeddingModel/__init__.py +0 -0
  30. {Architectures → Modules}/GeneralLayers/Attention.py +0 -0
  31. {Architectures → Modules}/GeneralLayers/ConditionalLayerNorm.py +0 -0
  32. {Architectures → Modules}/GeneralLayers/Conformer.py +29 -18
  33. {Architectures → Modules}/GeneralLayers/Convolution.py +1 -1
  34. {Architectures → Modules}/GeneralLayers/DurationPredictor.py +3 -3
  35. {Architectures → Modules}/GeneralLayers/EncoderLayer.py +1 -1
  36. {Architectures → Modules}/GeneralLayers/LayerNorm.py +0 -0
  37. {Architectures → Modules}/GeneralLayers/LengthRegulator.py +0 -0
  38. {Architectures → Modules}/GeneralLayers/MultiLayeredConv1d.py +0 -0
  39. {Architectures → Modules}/GeneralLayers/MultiSequential.py +0 -0
  40. {Architectures → Modules}/GeneralLayers/PositionalEncoding.py +0 -0
  41. {Architectures → Modules}/GeneralLayers/PositionwiseFeedForward.py +0 -0
  42. {Architectures → Modules}/GeneralLayers/README.md +0 -0
  43. {Architectures → Modules}/GeneralLayers/ResidualBlock.py +0 -0
  44. {Architectures → Modules}/GeneralLayers/ResidualStack.py +0 -0
  45. {Architectures → Modules}/GeneralLayers/STFT.py +0 -0
  46. {Architectures → Modules}/GeneralLayers/Swish.py +0 -0
  47. {Architectures → Modules}/GeneralLayers/VariancePredictor.py +3 -3
  48. {Architectures → Modules}/GeneralLayers/__init__.py +0 -0
  49. {Architectures → Modules}/README.md +0 -0
  50. {Architectures → Modules}/ToucanTTS/CodecDiscriminator.py +0 -0
Architectures/ToucanTTS/StochasticToucanTTS/README.md DELETED
@@ -1 +0,0 @@
1
- This is an experimental version of the TTS that uses normalizing flows to predict prosody explicitly, so that we keep the controllability of explicit prosody predictors while getting much better naturalness and liveliness than a deterministic predictor provides.
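
For orientation, here is a minimal sketch of the idea described above, based on the `StochasticVariancePredictor` interface from the file removed further down in this diff. It uses the pre-commit `Architectures.*` import path, and the tensor shapes, flow settings and literal values are illustrative assumptions, not part of this commit:

```python
import torch

# NOTE: pre-commit import path; this commit removes the StochasticToucanTTS modules.
from Architectures.ToucanTTS.StochasticToucanTTS.StochasticVariancePredictor import StochasticVariancePredictor

# Hypothetical shapes: one utterance, 192-dim encoder states, 20 phones.
encoded_texts = torch.randn(1, 192, 20)   # [Batch, Channels, Phones]
mask = torch.ones(1, 1, 20)               # no padding in this toy example
utt_embed = torch.randn(1, 192, 1)        # speaker conditioning, like utterance_embedding.unsqueeze(-1)

# Settings mirror the pitch flow configuration in the deleted StochasticToucanTTS.
flow = StochasticVariancePredictor(in_channels=192, kernel_size=5, p_dropout=0.5,
                                   n_flows=6, conditioning_signal_channels=192)

# Training direction: returns a negative log-likelihood of the gold values w under the flow.
gold_pitch = torch.rand(1, 1, 20) + 0.5
nll = flow(encoded_texts, mask, w=gold_pitch, g=utt_embed, reverse=False)

# Inference direction: samples a prosody curve from noise, so repeated calls differ.
sampled_values = flow(encoded_texts, mask, w=None, g=utt_embed, reverse=True)
```

The deleted `StochasticToucanTTS._forward` runs one such predictor per prosody dimension: `reverse=False` during training to obtain a likelihood loss, `reverse=True` at inference to sample pitch, energy and durations.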
 
 
Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTS.py DELETED
@@ -1,493 +0,0 @@
1
- import torch
2
- from torch.nn import Linear
3
- from torch.nn import Sequential
4
- from torch.nn import Tanh
5
-
6
- from Architectures.GeneralLayers.Conformer import Conformer
7
- from Architectures.GeneralLayers.LengthRegulator import LengthRegulator
8
- from Architectures.ToucanTTS.Glow import Glow
9
- from Architectures.ToucanTTS.StochasticToucanTTS.StochasticToucanTTSLoss import StochasticToucanTTSLoss
10
- from Architectures.ToucanTTS.StochasticToucanTTS.StochasticVariancePredictor import StochasticVariancePredictor
11
- from Preprocessing.articulatory_features import get_feature_to_index_lookup
12
- from Utility.utils import initialize
13
- from Utility.utils import make_non_pad_mask
14
- from Utility.utils import make_pad_mask
15
-
16
-
17
- class StochasticToucanTTS(torch.nn.Module):
18
- """
19
- StochasticToucanTTS module, which is mostly just a FastSpeech 2 module,
20
- but with lots of designs from different architectures accumulated
21
- and some major components added to put a large focus on multilinguality.
22
-
23
- Original contributions:
24
- - Inputs are configurations of the articulatory tract
25
- Word boundaries are modeled explicitly in the encoder and removed before the decoder
26
- Speaker embedding conditioning is derived from GST and AdaSpeech 4
27
- - Responsiveness of variance predictors to utterance embedding is increased through conditional layer norm
28
- - The final output receives a GAN discriminator feedback signal
29
- - Stochastic Duration Prediction through a normalizing flow
30
- - Stochastic Pitch Prediction through a normalizing flow
31
- - Stochastic Energy prediction through a normalizing flow
32
-
33
- Contributions inspired from elsewhere:
34
- - The PostNet is also a normalizing flow, like in PortaSpeech
35
- - Pitch and energy values are averaged per-phone, as in FastPitch to enable great controllability
36
- - The encoder and decoder are Conformers
37
-
38
- """
39
-
40
- def __init__(self,
41
- # network structure related
42
- input_feature_dimensions=62,
43
- output_spectrogram_channels=80,
44
- attention_dimension=192,
45
- attention_heads=4,
46
- positionwise_conv_kernel_size=1,
47
- use_scaled_positional_encoding=True,
48
- init_type="xavier_uniform",
49
- use_macaron_style_in_conformer=True,
50
- use_cnn_in_conformer=True,
51
-
52
- # encoder
53
- encoder_layers=6,
54
- encoder_units=1536,
55
- encoder_normalize_before=True,
56
- encoder_concat_after=False,
57
- conformer_encoder_kernel_size=7,
58
- transformer_enc_dropout_rate=0.2,
59
- transformer_enc_positional_dropout_rate=0.2,
60
- transformer_enc_attn_dropout_rate=0.2,
61
-
62
- # decoder
63
- decoder_layers=6,
64
- decoder_units=1536,
65
- decoder_concat_after=False,
66
- conformer_decoder_kernel_size=31,
67
- decoder_normalize_before=True,
68
- transformer_dec_dropout_rate=0.2,
69
- transformer_dec_positional_dropout_rate=0.2,
70
- transformer_dec_attn_dropout_rate=0.2,
71
-
72
- # duration predictor
73
- duration_predictor_layers=3,
74
- duration_predictor_chans=256,
75
- duration_predictor_kernel_size=3,
76
- duration_predictor_dropout_rate=0.2,
77
-
78
- # pitch predictor
79
- pitch_embed_kernel_size=1,
80
- pitch_embed_dropout=0.0,
81
-
82
- # energy predictor
83
- energy_embed_kernel_size=1,
84
- energy_embed_dropout=0.0,
85
-
86
- # additional features
87
- utt_embed_dim=192,
88
- lang_embs=8000):
89
- super().__init__()
90
-
91
- self.input_feature_dimensions = input_feature_dimensions
92
- self.output_spectrogram_channels = output_spectrogram_channels
93
- self.attention_dimension = attention_dimension
94
- self.use_scaled_pos_enc = use_scaled_positional_encoding
95
- self.multilingual_model = lang_embs is not None
96
- self.multispeaker_model = utt_embed_dim is not None
97
-
98
- articulatory_feature_embedding = Sequential(Linear(input_feature_dimensions, 100), Tanh(), Linear(100, attention_dimension))
99
- self.encoder = Conformer(conformer_type="encoder",
100
- attention_dim=attention_dimension,
101
- attention_heads=attention_heads,
102
- linear_units=encoder_units,
103
- num_blocks=encoder_layers,
104
- input_layer=articulatory_feature_embedding,
105
- dropout_rate=transformer_enc_dropout_rate,
106
- positional_dropout_rate=transformer_enc_positional_dropout_rate,
107
- attention_dropout_rate=transformer_enc_attn_dropout_rate,
108
- normalize_before=encoder_normalize_before,
109
- concat_after=encoder_concat_after,
110
- positionwise_conv_kernel_size=positionwise_conv_kernel_size,
111
- macaron_style=use_macaron_style_in_conformer,
112
- use_cnn_module=use_cnn_in_conformer,
113
- cnn_module_kernel=conformer_encoder_kernel_size,
114
- zero_triu=False,
115
- utt_embed=utt_embed_dim,
116
- lang_embs=lang_embs,
117
- use_output_norm=True)
118
-
119
- self.duration_flow = StochasticVariancePredictor(in_channels=attention_dimension,
120
- kernel_size=3,
121
- p_dropout=0.5,
122
- n_flows=5,
123
- conditioning_signal_channels=utt_embed_dim)
124
-
125
- self.pitch_flow = StochasticVariancePredictor(in_channels=attention_dimension,
126
- kernel_size=5,
127
- p_dropout=0.5,
128
- n_flows=6,
129
- conditioning_signal_channels=utt_embed_dim)
130
-
131
- self.energy_flow = StochasticVariancePredictor(in_channels=attention_dimension,
132
- kernel_size=3,
133
- p_dropout=0.5,
134
- n_flows=3,
135
- conditioning_signal_channels=utt_embed_dim)
136
-
137
- self.pitch_embed = Sequential(torch.nn.Conv1d(in_channels=1,
138
- out_channels=attention_dimension,
139
- kernel_size=pitch_embed_kernel_size,
140
- padding=(pitch_embed_kernel_size - 1) // 2),
141
- torch.nn.Dropout(pitch_embed_dropout))
142
-
143
- self.energy_embed = Sequential(torch.nn.Conv1d(in_channels=1, out_channels=attention_dimension, kernel_size=energy_embed_kernel_size,
144
- padding=(energy_embed_kernel_size - 1) // 2),
145
- torch.nn.Dropout(energy_embed_dropout))
146
-
147
- self.length_regulator = LengthRegulator()
148
-
149
- self.decoder = Conformer(conformer_type="decoder",
150
- attention_dim=attention_dimension,
151
- attention_heads=attention_heads,
152
- linear_units=decoder_units,
153
- num_blocks=decoder_layers,
154
- input_layer=None,
155
- dropout_rate=transformer_dec_dropout_rate,
156
- positional_dropout_rate=transformer_dec_positional_dropout_rate,
157
- attention_dropout_rate=transformer_dec_attn_dropout_rate,
158
- normalize_before=decoder_normalize_before,
159
- concat_after=decoder_concat_after,
160
- positionwise_conv_kernel_size=positionwise_conv_kernel_size,
161
- macaron_style=use_macaron_style_in_conformer,
162
- use_cnn_module=use_cnn_in_conformer,
163
- cnn_module_kernel=conformer_decoder_kernel_size,
164
- use_output_norm=False,
165
- utt_embed=utt_embed_dim)
166
-
167
- self.feat_out = Linear(attention_dimension, output_spectrogram_channels)
168
-
169
- self.post_flow = Glow(
170
- in_channels=output_spectrogram_channels,
171
- hidden_channels=192, # post_glow_hidden
172
- kernel_size=3, # post_glow_kernel_size
173
- dilation_rate=1,
174
- n_blocks=12, # post_glow_n_blocks (original 12 in paper)
175
- n_layers=3, # post_glow_n_block_layers (original 3 in paper)
176
- n_split=4,
177
- n_sqz=2,
178
- text_condition_channels=attention_dimension,
179
- share_cond_layers=False, # post_share_cond_layers
180
- share_wn_layers=4,
181
- sigmoid_scale=False,
182
- condition_integration_projection=torch.nn.Conv1d(output_spectrogram_channels + attention_dimension, attention_dimension, 5, padding=2)
183
- )
184
-
185
- # initialize parameters
186
- self._reset_parameters(init_type=init_type)
187
- if lang_embs is not None:
188
- torch.nn.init.normal_(self.encoder.language_embedding.weight, mean=0, std=attention_dimension ** -0.5)
189
-
190
- self.criterion = StochasticToucanTTSLoss()
191
-
192
- def forward(self,
193
- text_tensors,
194
- text_lengths,
195
- gold_speech,
196
- speech_lengths,
197
- gold_durations,
198
- gold_pitch,
199
- gold_energy,
200
- utterance_embedding,
201
- return_feats=False,
202
- lang_ids=None,
203
- run_glow=True
204
- ):
205
- """
206
- Args:
207
- return_feats (Boolean): whether to return the predicted spectrogram
208
- text_tensors (LongTensor): Batch of padded text vectors (B, Tmax).
209
- text_lengths (LongTensor): Batch of lengths of each input (B,).
210
- gold_speech (Tensor): Batch of padded target features (B, Lmax, odim).
211
- speech_lengths (LongTensor): Batch of the lengths of each target (B,).
212
- gold_durations (LongTensor): Batch of padded durations (B, Tmax + 1).
213
- gold_pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1).
214
- gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1).
215
- run_glow (Boolean): Whether to run the PostNet. There should be a warmup phase in the beginning.
216
- lang_ids (LongTensor): The language IDs used to access the language embedding table, if the model is multilingual
217
- utterance_embedding (Tensor): Batch of embeddings to condition the TTS on, if the model is multispeaker
218
- """
219
- before_outs, \
220
- after_outs, \
221
- duration_loss, \
222
- pitch_loss, \
223
- energy_loss, \
224
- glow_loss = self._forward(text_tensors=text_tensors,
225
- text_lengths=text_lengths,
226
- gold_speech=gold_speech,
227
- speech_lengths=speech_lengths,
228
- gold_durations=gold_durations,
229
- gold_pitch=gold_pitch,
230
- gold_energy=gold_energy,
231
- utterance_embedding=utterance_embedding,
232
- is_inference=False,
233
- lang_ids=lang_ids,
234
- run_glow=run_glow)
235
-
236
- # calculate loss
237
- l1_loss = self.criterion(after_outs=after_outs,
238
- before_outs=before_outs,
239
- gold_spectrograms=gold_speech,
240
- spectrogram_lengths=speech_lengths,
241
- text_lengths=text_lengths)
242
-
243
- if return_feats:
244
- if after_outs is None:
245
- after_outs = before_outs
246
- return l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss, after_outs
247
- return l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss
248
-
249
- def _forward(self,
250
- text_tensors,
251
- text_lengths,
252
- gold_speech=None,
253
- speech_lengths=None,
254
- gold_durations=None,
255
- gold_pitch=None,
256
- gold_energy=None,
257
- is_inference=False,
258
- utterance_embedding=None,
259
- lang_ids=None,
260
- run_glow=True):
261
-
262
- if not self.multilingual_model:
263
- lang_ids = None
264
-
265
- if not self.multispeaker_model:
266
- utterance_embedding = None
267
-
268
- # encoding the texts
269
- text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2)
270
- padding_masks = make_pad_mask(text_lengths, device=text_lengths.device)
271
- encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)
272
-
273
- if is_inference:
274
- variance_mask = torch.ones(size=[text_tensors.size(1)], device=text_tensors.device)
275
-
276
- # predicting pitch
277
- pitch_predictions = self.pitch_flow(encoded_texts.transpose(1, 2), variance_mask, w=None, g=utterance_embedding.unsqueeze(-1), reverse=True).squeeze(-1).transpose(1, 2)
278
- for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
279
- if phoneme_vector[get_feature_to_index_lookup()["voiced"]] == 0:
280
- pitch_predictions[0][phoneme_index] = 0.0
281
- embedded_pitch_curve = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
282
- encoded_texts = encoded_texts + embedded_pitch_curve
283
-
284
- # predicting energy
285
- energy_predictions = self.energy_flow(encoded_texts.transpose(1, 2), variance_mask, w=None, g=utterance_embedding.unsqueeze(-1), reverse=True).squeeze(-1).transpose(1, 2)
286
- embedded_energy_curve = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
287
- encoded_texts = encoded_texts + embedded_energy_curve
288
-
289
- # predicting durations
290
- predicted_durations = self.duration_flow(encoded_texts.transpose(1, 2), variance_mask, w=None, g=utterance_embedding.unsqueeze(-1), reverse=True).squeeze(-1).transpose(1, 2).squeeze(-1)
291
- predicted_durations = torch.ceil(torch.exp(predicted_durations)).long()
292
- for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
293
- if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1:
294
- predicted_durations[0][phoneme_index] = 0
295
-
296
- # predicting durations for text and upsampling accordingly
297
- upsampled_enriched_encoded_texts = self.length_regulator(encoded_texts, predicted_durations)
298
-
299
- else:
300
- # learning to predict pitch
301
- idx = gold_pitch != 0
302
- pitch_mask = torch.logical_and(text_masks, idx.transpose(1, 2))
303
- scaled_pitch_targets = gold_pitch.detach().clone()
304
- scaled_pitch_targets[idx] = torch.exp(gold_pitch[idx]) # we scale up, so that the log in the flow can handle the value ranges better.
305
- pitch_flow_loss = torch.sum(self.pitch_flow(encoded_texts.transpose(1, 2).detach(), pitch_mask, w=scaled_pitch_targets.transpose(1, 2), g=utterance_embedding.unsqueeze(-1), reverse=False))
306
- pitch_flow_loss = torch.sum(pitch_flow_loss / torch.sum(pitch_mask)) # weighted masking
307
- embedded_pitch_curve = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
308
- encoded_texts = encoded_texts + embedded_pitch_curve
309
-
310
- # learning to predict energy
311
- idx = gold_energy != 0
312
- energy_mask = torch.logical_and(text_masks, idx.transpose(1, 2))
313
- scaled_energy_targets = gold_energy.detach().clone()
314
- scaled_energy_targets[idx] = torch.exp(gold_energy[idx]) # we scale up, so that the log in the flow can handle the value ranges better.
315
- energy_flow_loss = torch.sum(self.energy_flow(encoded_texts.transpose(1, 2).detach(), energy_mask, w=scaled_energy_targets.transpose(1, 2), g=utterance_embedding.unsqueeze(-1), reverse=False))
316
- energy_flow_loss = torch.sum(energy_flow_loss / torch.sum(energy_mask)) # weighted masking
317
- embedded_energy_curve = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
318
- encoded_texts = encoded_texts + embedded_energy_curve
319
-
320
- # learning to predict durations
321
- idx = gold_durations.unsqueeze(-1) != 0
322
- duration_mask = torch.logical_and(text_masks, idx.transpose(1, 2))
323
- duration_targets = gold_durations.unsqueeze(-1).detach().clone().float()
324
- duration_flow_loss = torch.sum(self.duration_flow(encoded_texts.transpose(1, 2).detach(), duration_mask, w=duration_targets.transpose(1, 2), g=utterance_embedding.unsqueeze(-1), reverse=False))
325
- duration_flow_loss = torch.sum(duration_flow_loss / torch.sum(duration_mask)) # weighted masking
326
-
327
- upsampled_enriched_encoded_texts = self.length_regulator(encoded_texts, gold_durations)
328
-
329
- # decoding spectrogram
330
- decoder_masks = make_non_pad_mask(speech_lengths, device=speech_lengths.device).unsqueeze(-2) if speech_lengths is not None and not is_inference else None
331
- decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, decoder_masks, utterance_embedding=utterance_embedding)
332
- decoded_spectrogram = self.feat_out(decoded_speech).view(decoded_speech.size(0), -1, self.output_spectrogram_channels)
333
-
334
- # refine spectrogram further with a normalizing flow (requires warmup, so it's not always on)
335
- glow_loss = None
336
- if run_glow:
337
- if is_inference:
338
- refined_spectrogram = self.post_flow(tgt_mels=None,
339
- infer=is_inference,
340
- mel_out=decoded_spectrogram,
341
- encoded_texts=upsampled_enriched_encoded_texts,
342
- tgt_nonpadding=None).squeeze()
343
- else:
344
- glow_loss = self.post_flow(tgt_mels=gold_speech,
345
- infer=is_inference,
346
- mel_out=decoded_spectrogram.detach().clone(),
347
- encoded_texts=upsampled_enriched_encoded_texts.detach().clone(),
348
- tgt_nonpadding=decoder_masks)
349
- if is_inference:
350
- return decoded_spectrogram.squeeze(), \
351
- refined_spectrogram.squeeze(), \
352
- predicted_durations.squeeze(), \
353
- pitch_predictions.squeeze(), \
354
- energy_predictions.squeeze()
355
- else:
356
- return decoded_spectrogram, \
357
- None, \
358
- duration_flow_loss, \
359
- pitch_flow_loss, \
360
- energy_flow_loss, \
361
- glow_loss
362
-
363
- @torch.inference_mode()
364
- def inference(self,
365
- text,
366
- speech=None,
367
- utterance_embedding=None,
368
- return_duration_pitch_energy=False,
369
- lang_id=None,
370
- run_postflow=True):
371
- """
372
- Args:
373
- text (LongTensor): Input sequence of characters (T,).
374
- speech (Tensor, optional): Feature sequence to extract style (N, idim).
375
- return_duration_pitch_energy (Boolean): whether to return the list of predicted durations for nicer plotting
376
- run_postflow (Boolean): Whether to run the PostNet. There should be a warmup phase in the beginning.
377
- lang_id (LongTensor): The language ID used to access the language embedding table, if the model is multilingual
378
- utterance_embedding (Tensor): Embedding to condition the TTS on, if the model is multispeaker
379
- """
380
- self.eval()
381
- x, y = text, speech
382
-
383
- # setup batch axis
384
- ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
385
- xs, ys = x.unsqueeze(0), None
386
- if y is not None:
387
- ys = y.unsqueeze(0)
388
- if lang_id is not None:
389
- lang_id = lang_id.unsqueeze(0)
390
- utterance_embeddings = utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None
391
-
392
- before_outs, \
393
- after_outs, \
394
- duration_predictions, \
395
- pitch_predictions, \
396
- energy_predictions = self._forward(xs,
397
- ilens,
398
- ys,
399
- is_inference=True,
400
- utterance_embedding=utterance_embeddings,
401
- lang_ids=lang_id,
402
- run_glow=run_postflow) # (1, L, odim)
403
- self.train()
404
- if after_outs is None:
405
- after_outs = before_outs
406
- if return_duration_pitch_energy:
407
- return before_outs, after_outs, duration_predictions, pitch_predictions, energy_predictions
408
- return after_outs
409
-
410
- def _reset_parameters(self, init_type):
411
- # initialize parameters
412
- if init_type != "pytorch":
413
- initialize(self, init_type)
414
-
415
-
416
- if __name__ == '__main__':
417
- print(sum(p.numel() for p in StochasticToucanTTS().parameters() if p.requires_grad))
418
-
419
- print(" TESTING TRAINING ")
420
-
421
- print(" batchsize 3 ")
422
- dummy_text_batch = torch.randint(low=0, high=2, size=[3, 3, 62]).float() # [Batch, Sequence Length, Features per Phone]
423
- dummy_text_lens = torch.LongTensor([2, 3, 3])
424
-
425
- dummy_speech_batch = torch.randn([3, 30, 80]) # [Batch, Sequence Length, Spectrogram Buckets]
426
- dummy_speech_lens = torch.LongTensor([10, 30, 20])
427
-
428
- dummy_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5], [5, 5, 10]])
429
- dummy_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]], [[1.1], [1.2], [0.8]]])
430
- dummy_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]], [[1.1], [1.2], [0.8]]])
431
-
432
- dummy_utterance_embed = torch.randn([3, 192]) # [Batch, Dimensions of Speaker Embedding]
433
- dummy_language_id = torch.LongTensor([5, 3, 2]).unsqueeze(1)
434
-
435
- model = StochasticToucanTTS()
436
- l1, dl, pl, el, gl = model(dummy_text_batch,
437
- dummy_text_lens,
438
- dummy_speech_batch,
439
- dummy_speech_lens,
440
- dummy_durations,
441
- dummy_pitch,
442
- dummy_energy,
443
- utterance_embedding=dummy_utterance_embed,
444
- lang_ids=dummy_language_id)
445
-
446
- loss = l1 + gl + dl + pl + el
447
- print(loss)
448
- loss.backward()
449
-
450
- # from Utility.utils import plot_grad_flow
451
-
452
- # plot_grad_flow(model.encoder.named_parameters())
453
- # plot_grad_flow(model.decoder.named_parameters())
454
- # plot_grad_flow(model.pitch_predictor.named_parameters())
455
- # plot_grad_flow(model.duration_predictor.named_parameters())
456
- # plot_grad_flow(model.post_flow.named_parameters())
457
-
458
- print(" batchsize 2 ")
459
- dummy_text_batch = torch.randint(low=0, high=2, size=[2, 3, 62]).float() # [Batch, Sequence Length, Features per Phone]
460
- dummy_text_lens = torch.LongTensor([2, 3])
461
-
462
- dummy_speech_batch = torch.randn([2, 30, 80]) # [Batch, Sequence Length, Spectrogram Buckets]
463
- dummy_speech_lens = torch.LongTensor([10, 30])
464
-
465
- dummy_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5]])
466
- dummy_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]]])
467
- dummy_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]]])
468
-
469
- dummy_utterance_embed = torch.randn([2, 192]) # [Batch, Dimensions of Speaker Embedding]
470
- dummy_language_id = torch.LongTensor([5, 3]).unsqueeze(1)
471
-
472
- model = StochasticToucanTTS()
473
- l1, dl, pl, el, gl = model(dummy_text_batch,
474
- dummy_text_lens,
475
- dummy_speech_batch,
476
- dummy_speech_lens,
477
- dummy_durations,
478
- dummy_pitch,
479
- dummy_energy,
480
- utterance_embedding=dummy_utterance_embed,
481
- lang_ids=dummy_language_id)
482
-
483
- loss = l1 + gl + dl + el + pl
484
- print(loss)
485
- loss.backward()
486
-
487
- print(" TESTING INFERENCE ")
488
- dummy_text_batch = torch.randint(low=0, high=2, size=[12, 62]).float() # [Sequence Length, Features per Phone]
489
- dummy_utterance_embed = torch.randn([192]) # [Dimensions of Speaker Embedding]
490
- dummy_language_id = torch.LongTensor([2])
491
- print(StochasticToucanTTS().inference(dummy_text_batch,
492
- utterance_embedding=dummy_utterance_embed,
493
- lang_id=dummy_language_id).shape)
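
The docstrings above note that the Glow PostNet "should have a warmup phase in the beginning" before `run_glow` / `run_postflow` is switched on. The training loop is not part of this diff, so the following is only a hedged sketch of how such a gate might look; the step threshold, the `dataloader`, and the batch keys are assumptions for illustration.

```python
from Architectures.ToucanTTS.StochasticToucanTTS.StochasticToucanTTS import StochasticToucanTTS

model = StochasticToucanTTS()
postflow_warmup_steps = 8000  # assumed value; pick per experiment

for step, batch in enumerate(dataloader):  # `dataloader` is a placeholder, not defined here
    l1, dl, pl, el, gl = model(batch["text"], batch["text_lens"],
                               batch["speech"], batch["speech_lens"],
                               batch["durations"], batch["pitch"], batch["energy"],
                               utterance_embedding=batch["utt_embed"],
                               lang_ids=batch["lang_ids"],
                               run_glow=step > postflow_warmup_steps)
    loss = l1 + dl + pl + el
    if gl is not None:  # the glow loss is only returned once the post-flow is active
        loss = loss + gl
    loss.backward()
    # optimizer.step() / optimizer.zero_grad() omitted in this sketch
```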
Architectures/ToucanTTS/StochasticToucanTTS/StochasticVariancePredictor.py DELETED
@@ -1,440 +0,0 @@
1
- """
2
- Code taken and adapted from https://github.com/jaywalnut310/vits
3
-
4
- MIT License
5
-
6
- Copyright (c) 2021 Jaehyeon Kim
7
-
8
- Permission is hereby granted, free of charge, to any person obtaining a copy
9
- of this software and associated documentation files (the "Software"), to deal
10
- in the Software without restriction, including without limitation the rights
11
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
- copies of the Software, and to permit persons to whom the Software is
13
- furnished to do so, subject to the following conditions:
14
-
15
- The above copyright notice and this permission notice shall be included in all
16
- copies or substantial portions of the Software.
17
-
18
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
- SOFTWARE.
25
- """
26
-
27
- import math
28
-
29
- import numpy as np
30
- import torch
31
- from torch import nn
32
- from torch.nn import functional as F
33
-
34
- DEFAULT_MIN_BIN_WIDTH = 1e-3
35
- DEFAULT_MIN_BIN_HEIGHT = 1e-3
36
- DEFAULT_MIN_DERIVATIVE = 1e-3
37
-
38
-
39
- class StochasticVariancePredictor(nn.Module):
40
- def __init__(self, in_channels, kernel_size, p_dropout, n_flows=4, conditioning_signal_channels=0):
41
- super().__init__()
42
- self.in_channels = in_channels
43
- self.filter_channels = in_channels
44
- self.kernel_size = kernel_size
45
- self.p_dropout = p_dropout
46
- self.n_flows = n_flows
47
- self.gin_channels = conditioning_signal_channels if conditioning_signal_channels is not None else 0
48
-
49
- self.log_flow = Log()
50
- self.flows = nn.ModuleList()
51
- self.flows.append(ElementwiseAffine(2))
52
- for i in range(n_flows):
53
- self.flows.append(ConvFlow(2, in_channels, kernel_size, n_layers=3))
54
- self.flows.append(Flip())
55
-
56
- self.post_pre = nn.Conv1d(1, in_channels, 1)
57
- self.post_proj = nn.Conv1d(in_channels, in_channels, 1)
58
- self.post_convs = DDSConv(in_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
59
- self.post_flows = nn.ModuleList()
60
- self.post_flows.append(ElementwiseAffine(2))
61
- for i in range(4):
62
- self.post_flows.append(ConvFlow(2, in_channels, kernel_size, n_layers=3))
63
- self.post_flows.append(Flip())
64
-
65
- self.pre = nn.Conv1d(in_channels, in_channels, 1)
66
- self.proj = nn.Conv1d(in_channels, in_channels, 1)
67
- self.convs = DDSConv(in_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
68
- if self.gin_channels != 0:
69
- self.cond = nn.Conv1d(self.gin_channels, in_channels, 1)
70
-
71
- def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=0.3):
72
- x = self.pre(x)
73
- if g is not None:
74
- g = torch.detach(g)
75
- x = x + self.cond(g)
76
- x = self.convs(x, x_mask)
77
- x = self.proj(x) * x_mask
78
-
79
- if not reverse:
80
- flows = self.flows
81
- assert w is not None
82
-
83
- logdet_tot_q = 0
84
- h_w = self.post_pre(w)
85
- h_w = self.post_convs(h_w, x_mask)
86
- h_w = self.post_proj(h_w) * x_mask
87
- e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
88
- z_q = e_q
89
- for flow in self.post_flows:
90
- z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
91
- logdet_tot_q += logdet_q
92
- z_u, z1 = torch.split(z_q, [1, 1], 1)
93
- u = torch.sigmoid(z_u) * x_mask
94
- z0 = (w - u) * x_mask
95
- logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
96
- logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q
97
-
98
- logdet_tot = 0
99
- z0, logdet = self.log_flow(z0, x_mask)
100
- logdet_tot += logdet
101
- z = torch.cat([z0, z1], 1)
102
- for flow in flows:
103
- z, logdet = flow(z, x_mask, g=x, reverse=reverse)
104
- logdet_tot = logdet_tot + logdet
105
- nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
106
- return nll + logq # [b]
107
- else:
108
- flows = list(reversed(self.flows))
109
- flows = flows[:-2] + [flows[-1]] # remove a useless vflow
110
- z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
111
- # noise scale 0.8 derived from coqui implementation, but dropped to 0.3 during testing. Might not be ideal yet.
112
- for flow in flows:
113
- z = flow(z, x_mask, g=x, reverse=reverse)
114
- z0, z1 = torch.split(z, [1, 1], 1)
115
- logw = z0
116
- return logw
117
-
118
-
119
- class Log(nn.Module):
120
- def forward(self, x, x_mask, reverse=False, **kwargs):
121
- if not reverse:
122
- y = torch.log(torch.clamp_min(x, 1e-6)) * x_mask
123
- logdet = torch.sum(-y, [1, 2])
124
- return y, logdet
125
- else:
126
- x = torch.exp(x) * x_mask
127
- return x
128
-
129
-
130
- class Flip(nn.Module):
131
- def forward(self, x, *args, reverse=False, **kwargs):
132
- x = torch.flip(x, [1])
133
- if not reverse:
134
- logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
135
- return x, logdet
136
- else:
137
- return x
138
-
139
-
140
- class DDSConv(nn.Module):
141
- """
142
- Dilated and Depth-Separable Convolution
143
- """
144
-
145
- def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
146
- super().__init__()
147
- self.channels = channels
148
- self.kernel_size = kernel_size
149
- self.n_layers = n_layers
150
- self.p_dropout = p_dropout
151
-
152
- self.drop = nn.Dropout(p_dropout)
153
- self.convs_sep = nn.ModuleList()
154
- self.convs_1x1 = nn.ModuleList()
155
- self.norms_1 = nn.ModuleList()
156
- self.norms_2 = nn.ModuleList()
157
- for i in range(n_layers):
158
- dilation = kernel_size ** i
159
- padding = (kernel_size * dilation - dilation) // 2
160
- self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
161
- groups=channels, dilation=dilation, padding=padding
162
- ))
163
- self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
164
- self.norms_1.append(LayerNorm(channels))
165
- self.norms_2.append(LayerNorm(channels))
166
-
167
- def forward(self, x, x_mask, g=None):
168
- if g is not None:
169
- x = x + g
170
- for i in range(self.n_layers):
171
- y = self.convs_sep[i](x * x_mask)
172
- y = self.norms_1[i](y)
173
- y = F.gelu(y)
174
- y = self.convs_1x1[i](y)
175
- y = self.norms_2[i](y)
176
- y = F.gelu(y)
177
- y = self.drop(y)
178
- x = x + y
179
- return x * x_mask
180
-
181
-
182
- class ConvFlow(nn.Module):
183
- def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
184
- super().__init__()
185
- self.in_channels = in_channels
186
- self.filter_channels = filter_channels
187
- self.kernel_size = kernel_size
188
- self.n_layers = n_layers
189
- self.num_bins = num_bins
190
- self.tail_bound = tail_bound
191
- self.half_channels = in_channels // 2
192
-
193
- self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
194
- self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
195
- self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
196
- self.proj.weight.data.zero_()
197
- self.proj.bias.data.zero_()
198
-
199
- def forward(self, x, x_mask, g=None, reverse=False):
200
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
201
- h = self.pre(x0)
202
- h = self.convs(h, x_mask, g=g)
203
- h = self.proj(h) * x_mask
204
-
205
- b, c, t = x0.shape
206
- h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
207
-
208
- unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
209
- unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels)
210
- unnormalized_derivatives = h[..., 2 * self.num_bins:]
211
-
212
- x1, logabsdet = piecewise_rational_quadratic_transform(x1,
213
- unnormalized_widths,
214
- unnormalized_heights,
215
- unnormalized_derivatives,
216
- inverse=reverse,
217
- tails='linear',
218
- tail_bound=self.tail_bound
219
- )
220
-
221
- x = torch.cat([x0, x1], 1) * x_mask
222
- logdet = torch.sum(logabsdet * x_mask, [1, 2])
223
- if not reverse:
224
- return x, logdet
225
- else:
226
- return x
227
-
228
-
229
- class ElementwiseAffine(nn.Module):
230
- def __init__(self, channels):
231
- super().__init__()
232
- self.channels = channels
233
- self.m = nn.Parameter(torch.zeros(channels, 1))
234
- self.logs = nn.Parameter(torch.zeros(channels, 1))
235
-
236
- def forward(self, x, x_mask, reverse=False, **kwargs):
237
- if not reverse:
238
- y = self.m + torch.exp(self.logs) * x
239
- y = y * x_mask
240
- logdet = torch.sum(self.logs * x_mask, [1, 2])
241
- return y, logdet
242
- else:
243
- x = (x - self.m) * torch.exp(-self.logs) * x_mask
244
- return x
245
-
246
-
247
- class LayerNorm(nn.Module):
248
- def __init__(self, channels, eps=1e-5):
249
- super().__init__()
250
- self.channels = channels
251
- self.eps = eps
252
-
253
- self.gamma = nn.Parameter(torch.ones(channels))
254
- self.beta = nn.Parameter(torch.zeros(channels))
255
-
256
- def forward(self, x):
257
- x = x.transpose(1, -1)
258
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
259
- return x.transpose(1, -1)
260
-
261
-
262
- def piecewise_rational_quadratic_transform(inputs,
263
- unnormalized_widths,
264
- unnormalized_heights,
265
- unnormalized_derivatives,
266
- inverse=False,
267
- tails=None,
268
- tail_bound=1.,
269
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
270
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
271
- min_derivative=DEFAULT_MIN_DERIVATIVE):
272
- if tails is None:
273
- spline_fn = rational_quadratic_spline
274
- spline_kwargs = {}
275
- else:
276
- spline_fn = unconstrained_rational_quadratic_spline
277
- spline_kwargs = {
278
- 'tails' : tails,
279
- 'tail_bound': tail_bound
280
- }
281
-
282
- outputs, logabsdet = spline_fn(
283
- inputs=inputs,
284
- unnormalized_widths=unnormalized_widths,
285
- unnormalized_heights=unnormalized_heights,
286
- unnormalized_derivatives=unnormalized_derivatives,
287
- inverse=inverse,
288
- min_bin_width=min_bin_width,
289
- min_bin_height=min_bin_height,
290
- min_derivative=min_derivative,
291
- **spline_kwargs
292
- )
293
- return outputs, logabsdet
294
-
295
-
296
- def rational_quadratic_spline(inputs,
297
- unnormalized_widths,
298
- unnormalized_heights,
299
- unnormalized_derivatives,
300
- inverse=False,
301
- left=0., right=1., bottom=0., top=1.,
302
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
303
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
304
- min_derivative=DEFAULT_MIN_DERIVATIVE):
305
- if torch.min(inputs) < left or torch.max(inputs) > right:
306
- raise ValueError('Input to a transform is not within its domain')
307
-
308
- num_bins = unnormalized_widths.shape[-1]
309
-
310
- if min_bin_width * num_bins > 1.0:
311
- raise ValueError('Minimal bin width too large for the number of bins')
312
- if min_bin_height * num_bins > 1.0:
313
- raise ValueError('Minimal bin height too large for the number of bins')
314
-
315
- widths = F.softmax(unnormalized_widths, dim=-1)
316
- widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
317
- cumwidths = torch.cumsum(widths, dim=-1)
318
- cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
319
- cumwidths = (right - left) * cumwidths + left
320
- cumwidths[..., 0] = left
321
- cumwidths[..., -1] = right
322
- widths = cumwidths[..., 1:] - cumwidths[..., :-1]
323
-
324
- derivatives = min_derivative + F.softplus(unnormalized_derivatives)
325
-
326
- heights = F.softmax(unnormalized_heights, dim=-1)
327
- heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
328
- cumheights = torch.cumsum(heights, dim=-1)
329
- cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
330
- cumheights = (top - bottom) * cumheights + bottom
331
- cumheights[..., 0] = bottom
332
- cumheights[..., -1] = top
333
- heights = cumheights[..., 1:] - cumheights[..., :-1]
334
-
335
- if inverse:
336
- bin_idx = searchsorted(cumheights, inputs)[..., None]
337
- else:
338
- bin_idx = searchsorted(cumwidths, inputs)[..., None]
339
-
340
- input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
341
- input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
342
-
343
- input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
344
- delta = heights / widths
345
- input_delta = delta.gather(-1, bin_idx)[..., 0]
346
-
347
- input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
348
- input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
349
-
350
- input_heights = heights.gather(-1, bin_idx)[..., 0]
351
-
352
- if inverse:
353
- a = (((inputs - input_cumheights) * (input_derivatives
354
- + input_derivatives_plus_one
355
- - 2 * input_delta)
356
- + input_heights * (input_delta - input_derivatives)))
357
- b = (input_heights * input_derivatives
358
- - (inputs - input_cumheights) * (input_derivatives
359
- + input_derivatives_plus_one
360
- - 2 * input_delta))
361
- c = - input_delta * (inputs - input_cumheights)
362
-
363
- discriminant = b.pow(2) - 4 * a * c
364
- assert (discriminant >= 0).all()
365
-
366
- root = (2 * c) / (-b - torch.sqrt(discriminant))
367
- outputs = root * input_bin_widths + input_cumwidths
368
-
369
- theta_one_minus_theta = root * (1 - root)
370
- denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
371
- * theta_one_minus_theta)
372
- derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
373
- + 2 * input_delta * theta_one_minus_theta
374
- + input_derivatives * (1 - root).pow(2))
375
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
376
-
377
- return outputs, -logabsdet
378
- else:
379
- theta = (inputs - input_cumwidths) / input_bin_widths
380
- theta_one_minus_theta = theta * (1 - theta)
381
-
382
- numerator = input_heights * (input_delta * theta.pow(2)
383
- + input_derivatives * theta_one_minus_theta)
384
- denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
385
- * theta_one_minus_theta)
386
- outputs = input_cumheights + numerator / denominator
387
-
388
- derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
389
- + 2 * input_delta * theta_one_minus_theta
390
- + input_derivatives * (1 - theta).pow(2))
391
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
392
-
393
- return outputs, logabsdet
394
-
395
-
396
- def searchsorted(bin_locations, inputs, eps=1e-6):
397
- bin_locations[..., -1] += eps
398
- return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
399
-
400
-
401
- def unconstrained_rational_quadratic_spline(inputs,
402
- unnormalized_widths,
403
- unnormalized_heights,
404
- unnormalized_derivatives,
405
- inverse=False,
406
- tails='linear',
407
- tail_bound=1.,
408
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
409
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
410
- min_derivative=DEFAULT_MIN_DERIVATIVE):
411
- inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
412
- outside_interval_mask = ~inside_interval_mask
413
-
414
- outputs = torch.zeros_like(inputs)
415
- logabsdet = torch.zeros_like(inputs)
416
-
417
- if tails == 'linear':
418
- unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
419
- constant = np.log(np.exp(1 - min_derivative) - 1)
420
- unnormalized_derivatives[..., 0] = constant
421
- unnormalized_derivatives[..., -1] = constant
422
-
423
- outputs[outside_interval_mask] = inputs[outside_interval_mask]
424
- logabsdet[outside_interval_mask] = 0
425
- else:
426
- raise RuntimeError('{} tails are not implemented.'.format(tails))
427
-
428
- outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
429
- inputs=inputs[inside_interval_mask],
430
- unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
431
- unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
432
- unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
433
- inverse=inverse,
434
- left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
435
- min_bin_width=min_bin_width,
436
- min_bin_height=min_bin_height,
437
- min_derivative=min_derivative
438
- )
439
-
440
- return outputs, logabsdet
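
For reference, the forward branch of `rational_quadratic_spline` above follows the monotonic rational-quadratic transform from Neural Spline Flows (Durkan et al., 2019), which the VITS code credited at the top of this file also uses. With bin knots $(x_k, y_k)$ (`cumwidths`/`cumheights`), bin slope $s_k$ (`input_delta`), knot derivatives $\delta_k$ (`input_derivatives`) and $\theta$ the relative position inside the bin, the code computes

$$
g(x) \;=\; y_k \;+\; \frac{(y_{k+1}-y_k)\left[\,s_k\,\theta^2 + \delta_k\,\theta(1-\theta)\,\right]}{s_k + \left[\,\delta_{k+1} + \delta_k - 2 s_k\,\right]\theta(1-\theta)},
\qquad \theta = \frac{x - x_k}{x_{k+1}-x_k},\quad s_k = \frac{y_{k+1}-y_k}{x_{k+1}-x_k},
$$

which is exactly `input_cumheights + numerator / denominator` in the code; the inverse branch solves the corresponding quadratic in $\theta$ using the numerically stable root `2c / (-b - sqrt(b^2 - 4ac))`.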
Architectures/__init__.py DELETED
File without changes
InferenceInterfaces/ControllableInterface.py CHANGED
@@ -2,8 +2,8 @@ import os
2
 
3
  import torch
4
 
5
- from Architectures.ControllabilityGAN.GAN import GanWrapper
6
  from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
 
7
  from Utility.storage_config import MODELS_DIR
8
 
9
 
@@ -15,7 +15,7 @@ class ControllableInterface:
15
  else:
16
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
17
  os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
18
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
  self.model = ToucanTTSInterface(device=self.device, tts_model_path="Meta")
20
  self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device=self.device)
21
  self.generated_speaker_embeds = list()
@@ -25,9 +25,11 @@ class ControllableInterface:
25
 
26
  def read(self,
27
  prompt,
 
28
  language,
29
  accent,
30
  voice_seed,
 
31
  duration_scaling_factor,
32
  pause_duration_scaling_factor,
33
  pitch_variance_scale,
@@ -37,24 +39,29 @@ class ControllableInterface:
37
  emb_slider_3,
38
  emb_slider_4,
39
  emb_slider_5,
40
- emb_slider_6
 
41
  ):
42
  if self.current_language != language:
43
  self.model.set_phonemizer_language(language)
 
44
  self.current_language = language
45
  if self.current_accent != accent:
46
  self.model.set_accent_language(accent)
 
47
  self.current_accent = accent
48
-
49
- self.wgan.set_latent(voice_seed)
50
- controllability_vector = torch.tensor([emb_slider_1,
51
- emb_slider_2,
52
- emb_slider_3,
53
- emb_slider_4,
54
- emb_slider_5,
55
- emb_slider_6], dtype=torch.float32)
56
- embedding = self.wgan.modify_embed(controllability_vector)
57
- self.model.set_utterance_embedding(embedding=embedding)
 
 
58
 
59
  phones = self.model.text2phone.get_phone_string(prompt)
60
  if len(phones) > 1800:
@@ -92,15 +99,15 @@ class ControllableInterface:
92
  if self.current_accent != "eng":
93
  self.model.set_accent_language("eng")
94
  self.current_accent = "eng"
95
- print("\n\n")
96
- print(prompt)
97
- print(language)
98
- print("\n\n")
99
  wav, sr, fig = self.model(prompt,
100
  input_is_phones=False,
101
  duration_scaling_factor=duration_scaling_factor,
102
  pitch_variance_scale=pitch_variance_scale,
103
  energy_variance_scale=energy_variance_scale,
104
  pause_duration_scaling_factor=pause_duration_scaling_factor,
105
- return_plot_as_filepath=True)
 
 
106
  return sr, wav, fig
 
2
 
3
  import torch
4
 
 
5
  from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
6
+ from Modules.ControllabilityGAN.GAN import GanWrapper
7
  from Utility.storage_config import MODELS_DIR
8
 
9
 
 
15
  else:
16
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
17
  os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
18
+ self.device = "cuda" if gpu_id != "cpu" else "cpu"
19
  self.model = ToucanTTSInterface(device=self.device, tts_model_path="Meta")
20
  self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device=self.device)
21
  self.generated_speaker_embeds = list()
 
25
 
26
  def read(self,
27
  prompt,
28
+ reference_audio,
29
  language,
30
  accent,
31
  voice_seed,
32
+ prosody_creativity,
33
  duration_scaling_factor,
34
  pause_duration_scaling_factor,
35
  pitch_variance_scale,
 
39
  emb_slider_3,
40
  emb_slider_4,
41
  emb_slider_5,
42
+ emb_slider_6,
43
+ loudness_in_db
44
  ):
45
  if self.current_language != language:
46
  self.model.set_phonemizer_language(language)
47
+ print(f"switched phonemizer language to {language}")
48
  self.current_language = language
49
  if self.current_accent != accent:
50
  self.model.set_accent_language(accent)
51
+ print(f"switched accent language to {accent}")
52
  self.current_accent = accent
53
+ if reference_audio is None:
54
+ self.wgan.set_latent(voice_seed)
55
+ controllability_vector = torch.tensor([emb_slider_1,
56
+ emb_slider_2,
57
+ emb_slider_3,
58
+ emb_slider_4,
59
+ emb_slider_5,
60
+ emb_slider_6], dtype=torch.float32)
61
+ embedding = self.wgan.modify_embed(controllability_vector)
62
+ self.model.set_utterance_embedding(embedding=embedding)
63
+ else:
64
+ self.model.set_utterance_embedding(reference_audio)
65
 
66
  phones = self.model.text2phone.get_phone_string(prompt)
67
  if len(phones) > 1800:
 
99
  if self.current_accent != "eng":
100
  self.model.set_accent_language("eng")
101
  self.current_accent = "eng"
102
+
103
+ print(prompt + "\n\n")
 
 
104
  wav, sr, fig = self.model(prompt,
105
  input_is_phones=False,
106
  duration_scaling_factor=duration_scaling_factor,
107
  pitch_variance_scale=pitch_variance_scale,
108
  energy_variance_scale=energy_variance_scale,
109
  pause_duration_scaling_factor=pause_duration_scaling_factor,
110
+ return_plot_as_filepath=True,
111
+ prosody_creativity=prosody_creativity,
112
+ loudness_in_db=loudness_in_db)
113
  return sr, wav, fig
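
To make the new `read()` signature concrete, here is a hedged usage sketch assembled from the updated parameter list and call sites shown above; the constructor argument `gpu_id` is taken from the variable used in the diff, and the seed and all literal values are assumptions for illustration, not part of this commit.

```python
from InferenceInterfaces.ControllableInterface import ControllableInterface

# gpu_id="cpu" is an assumed way to stay on CPU, mirroring the new device selection
# `self.device = "cuda" if gpu_id != "cpu" else "cpu"` introduced in this commit.
interface = ControllableInterface(gpu_id="cpu")

# With reference_audio=None the GAN-sampled voice path is taken (voice_seed + sliders);
# passing a filepath instead routes through set_utterance_embedding(reference_audio).
sr, wav, fig = interface.read(prompt="Hello world.",
                              reference_audio=None,
                              language="eng",
                              accent="eng",
                              voice_seed=42,
                              prosody_creativity=0.4,            # new parameter in this commit
                              duration_scaling_factor=1.0,
                              pause_duration_scaling_factor=1.0,
                              pitch_variance_scale=1.0,
                              energy_variance_scale=1.0,
                              emb_slider_1=0.0, emb_slider_2=0.0, emb_slider_3=0.0,
                              emb_slider_4=0.0, emb_slider_5=0.0, emb_slider_6=0.0,
                              loudness_in_db=-24.0)              # new parameter in this commit
```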
InferenceInterfaces/ToucanTTSInterface.py CHANGED
@@ -1,19 +1,17 @@
1
  import itertools
2
  import os
3
- import warnings
4
 
 
5
  import matplotlib.pyplot as plt
6
  import pyloudnorm
7
  import sounddevice
8
  import soundfile
9
  import torch
10
- with warnings.catch_warnings():
11
- warnings.simplefilter("ignore")
12
- from speechbrain.pretrained import EncoderClassifier
13
- from torchaudio.transforms import Resample
14
 
15
- from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
16
- from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
17
  from Preprocessing.AudioPreprocessor import AudioPreprocessor
18
  from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
19
  from Preprocessing.TextFrontend import get_language_id
@@ -29,7 +27,6 @@ class ToucanTTSInterface(torch.nn.Module):
29
  tts_model_path=os.path.join(MODELS_DIR, f"ToucanTTS_Meta", "best.pt"), # path to the ToucanTTS checkpoint or just a shorthand if run standalone
30
  vocoder_model_path=os.path.join(MODELS_DIR, f"Vocoder", "best.pt"), # path to the Vocoder checkpoint
31
  language="eng", # initial language of the model, can be changed later with the setter methods
32
- enhance=None # legacy argument
33
  ):
34
  super().__init__()
35
  self.device = device
@@ -40,7 +37,7 @@ class ToucanTTSInterface(torch.nn.Module):
40
  ################################
41
  # build text to phone #
42
  ################################
43
- self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)
44
 
45
  #####################################
46
  # load phone to features model #
@@ -92,8 +89,12 @@ class ToucanTTSInterface(torch.nn.Module):
92
  speaker_embs = list()
93
  for path in path_to_reference_audio:
94
  wave, sr = soundfile.read(path)
 
 
 
 
95
  wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(torch.tensor(wave, device=self.device, dtype=torch.float32))
96
- speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).unsqueeze(0)).squeeze()
97
  speaker_embs.append(speaker_embedding)
98
  self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)
99
 
@@ -105,10 +106,10 @@ class ToucanTTSInterface(torch.nn.Module):
105
  self.set_accent_language(lang_id=lang_id)
106
 
107
  def set_phonemizer_language(self, lang_id):
108
- self.text2phone.change_lang(language=lang_id, add_silence_to_end=True)
109
 
110
  def set_accent_language(self, lang_id):
111
- if lang_id in ['ajp', 'ajt', 'lak', 'lno', 'nul', 'pii', 'plj', 'slq', 'smd', 'snb', 'tpw', 'wya', 'zua', 'en-us', 'en-sc', 'fr-be', 'fr-sw', 'pt-br', 'spa-lat', 'vi-ctr', 'vi-so']:
112
  if lang_id == 'vi-so' or lang_id == 'vi-ctr':
113
  lang_id = 'vie'
114
  elif lang_id == 'spa-lat':
@@ -120,7 +121,7 @@ class ToucanTTSInterface(torch.nn.Module):
120
  elif lang_id == 'en-sc' or lang_id == 'en-us':
121
  lang_id = 'eng'
122
  else:
123
- # no clue where these others are even coming from, they are not in ISO 639-2
124
  lang_id = 'eng'
125
 
126
  self.lang_id = get_language_id(lang_id).to(self.device)
@@ -138,7 +139,7 @@ class ToucanTTSInterface(torch.nn.Module):
138
  input_is_phones=False,
139
  return_plot_as_filepath=False,
140
  loudness_in_db=-24.0,
141
- glow_sampling_temperature=0.2):
142
  """
143
  duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
144
  1.0 means no scaling happens, higher values increase durations for the whole
@@ -154,16 +155,16 @@ class ToucanTTSInterface(torch.nn.Module):
154
  phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
155
  mel, durations, pitch, energy = self.phone2mel(phones,
156
  return_duration_pitch_energy=True,
157
- utterance_embedding=self.default_utterance_embedding.to(self.device),
158
  durations=durations,
159
  pitch=pitch,
160
  energy=energy,
161
- lang_id=self.lang_id.to(self.device),
162
  duration_scaling_factor=duration_scaling_factor,
163
  pitch_variance_scale=pitch_variance_scale,
164
  energy_variance_scale=energy_variance_scale,
165
  pause_duration_scaling_factor=pause_duration_scaling_factor,
166
- glow_sampling_temperature=glow_sampling_temperature)
167
 
168
  wave, _, _ = self.vocoder(mel.unsqueeze(0))
169
  wave = wave.squeeze().cpu()
@@ -177,63 +178,56 @@ class ToucanTTSInterface(torch.nn.Module):
177
  pass
178
 
179
  if view or return_plot_as_filepath:
180
- try:
181
- fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
182
 
183
- ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
184
- ax.yaxis.set_visible(False)
185
- duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
186
- ax.xaxis.grid(True, which='minor')
187
- ax.set_xticks(label_positions, minor=False)
188
- if input_is_phones:
189
- phones = text.replace(" ", "|")
190
- else:
191
- phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
192
- try:
193
- ax.set_xticklabels(phones)
194
- except IndexError:
195
- pass
196
- word_boundaries = list()
197
- for label_index, phone in enumerate(phones):
198
- if phone == "|":
199
- word_boundaries.append(label_positions[label_index])
 
 
200
 
201
- try:
202
- prev_word_boundary = 0
203
- word_label_positions = list()
204
- for word_boundary in word_boundaries:
205
- word_label_positions.append((word_boundary + prev_word_boundary) / 2)
206
- prev_word_boundary = word_boundary
207
- word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)
208
 
209
- secondary_ax = ax.secondary_xaxis('bottom')
210
- secondary_ax.tick_params(axis="x", direction="out", pad=24)
211
- secondary_ax.set_xticks(word_label_positions, minor=False)
212
- secondary_ax.set_xticklabels(text.split())
213
- secondary_ax.tick_params(axis='x', colors='orange')
214
- secondary_ax.xaxis.label.set_color('orange')
215
- except ValueError:
216
- ax.set_title(text)
217
- except IndexError:
218
- ax.set_title(text)
219
- except RuntimeError:
220
- ax.set_title(text)
221
 
222
- ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
223
- ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
224
- plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
225
- ax.set_aspect("auto")
226
- except:
227
- pass
228
 
229
  if return_plot_as_filepath:
230
- try:
231
- plt.savefig("tmp.png")
232
- plt.close()
233
- except:
234
- pass
235
  return wave, sr, "tmp.png"
236
-
237
  return wave, sr
238
 
239
  def read_to_file(self,
@@ -247,7 +241,7 @@ class ToucanTTSInterface(torch.nn.Module):
247
  dur_list=None,
248
  pitch_list=None,
249
  energy_list=None,
250
- glow_sampling_temperature=0.2):
251
  """
252
  Args:
253
  silent: Whether to be verbose about the process
@@ -259,12 +253,19 @@ class ToucanTTSInterface(torch.nn.Module):
259
  duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
260
  1.0 means no scaling happens, higher values increase durations for the whole
261
  utterance, lower values decrease durations for the whole utterance.
 
 
 
262
  pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
263
  1.0 means no scaling happens, higher values increase variance of the pitch curve,
264
  lower values decrease variance of the pitch curve.
265
  energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
266
  1.0 means no scaling happens, higher values increase variance of the energy curve,
267
  lower values decrease variance of the energy curve.
 
 
 
 
268
  """
269
  if not dur_list:
270
  dur_list = []
@@ -272,7 +273,7 @@ class ToucanTTSInterface(torch.nn.Module):
272
  pitch_list = []
273
  if not energy_list:
274
  energy_list = []
275
- silence = torch.zeros([14300])
276
  wav = silence.clone()
277
  for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
278
  if text.strip() != "":
@@ -286,7 +287,7 @@ class ToucanTTSInterface(torch.nn.Module):
286
  pitch_variance_scale=pitch_variance_scale,
287
  energy_variance_scale=energy_variance_scale,
288
  pause_duration_scaling_factor=pause_duration_scaling_factor,
289
- glow_sampling_temperature=glow_sampling_temperature)
290
  spoken_sentence = torch.tensor(spoken_sentence).cpu()
291
  wav = torch.cat((wav, spoken_sentence, silence), 0)
292
  soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16")
@@ -298,7 +299,7 @@ class ToucanTTSInterface(torch.nn.Module):
298
  pitch_variance_scale=1.0,
299
  energy_variance_scale=1.0,
300
  blocking=False,
301
- glow_sampling_temperature=0.2):
302
  if text.strip() == "":
303
  return
304
  wav, sr = self(text,
@@ -306,7 +307,7 @@ class ToucanTTSInterface(torch.nn.Module):
306
  duration_scaling_factor=duration_scaling_factor,
307
  pitch_variance_scale=pitch_variance_scale,
308
  energy_variance_scale=energy_variance_scale,
309
- glow_sampling_temperature=glow_sampling_temperature)
310
  silence = torch.zeros([sr // 2])
311
  wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
312
  sounddevice.play(float2pcm(wav), samplerate=sr)
 
1
  import itertools
2
  import os
 
3
 
4
+ import librosa
5
  import matplotlib.pyplot as plt
6
  import pyloudnorm
7
  import sounddevice
8
  import soundfile
9
  import torch
10
+ from speechbrain.pretrained import EncoderClassifier
11
+ from torchaudio.transforms import Resample
 
 
12
 
13
+ from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS
14
+ from Modules.Vocoder.HiFiGAN_Generator import HiFiGAN
15
  from Preprocessing.AudioPreprocessor import AudioPreprocessor
16
  from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
17
  from Preprocessing.TextFrontend import get_language_id
 
27
  tts_model_path=os.path.join(MODELS_DIR, f"ToucanTTS_Meta", "best.pt"), # path to the ToucanTTS checkpoint or just a shorthand if run standalone
28
  vocoder_model_path=os.path.join(MODELS_DIR, f"Vocoder", "best.pt"), # path to the Vocoder checkpoint
29
  language="eng", # initial language of the model, can be changed later with the setter methods
 
30
  ):
31
  super().__init__()
32
  self.device = device
 
37
  ################################
38
  # build text to phone #
39
  ################################
40
+ self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True, device=device)
41
 
42
  #####################################
43
  # load phone to features model #
 
89
  speaker_embs = list()
90
  for path in path_to_reference_audio:
91
  wave, sr = soundfile.read(path)
92
+ if len(wave.shape) > 1: # oh no, we found a stereo audio!
93
+ if len(wave[0]) == 2: # let's figure out whether we need to switch the axes
94
+ wave = wave.transpose() # if yes, we switch the axes.
95
+ wave = librosa.to_mono(wave)
96
  wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(torch.tensor(wave, device=self.device, dtype=torch.float32))
97
+ speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).squeeze().unsqueeze(0)).squeeze()
98
  speaker_embs.append(speaker_embedding)
99
  self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)
100
 
 
106
  self.set_accent_language(lang_id=lang_id)
107
 
108
  def set_phonemizer_language(self, lang_id):
109
+ self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True, device=self.device)
110
 
111
  def set_accent_language(self, lang_id):
112
+ if lang_id in {'ajp', 'ajt', 'lak', 'lno', 'nul', 'pii', 'plj', 'slq', 'smd', 'snb', 'tpw', 'wya', 'zua', 'en-us', 'en-sc', 'fr-be', 'fr-sw', 'pt-br', 'spa-lat', 'vi-ctr', 'vi-so'}:
113
  if lang_id == 'vi-so' or lang_id == 'vi-ctr':
114
  lang_id = 'vie'
115
  elif lang_id == 'spa-lat':
 
121
  elif lang_id == 'en-sc' or lang_id == 'en-us':
122
  lang_id = 'eng'
123
  else:
124
+ # no clue where these others are even coming from, they are not in ISO 639-3
125
  lang_id = 'eng'
126
 
127
  self.lang_id = get_language_id(lang_id).to(self.device)
 
139
  input_is_phones=False,
140
  return_plot_as_filepath=False,
141
  loudness_in_db=-24.0,
142
+ prosody_creativity=0.1):
143
  """
144
  duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
145
  1.0 means no scaling happens, higher values increase durations for the whole
 
155
  phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
156
  mel, durations, pitch, energy = self.phone2mel(phones,
157
  return_duration_pitch_energy=True,
158
+ utterance_embedding=self.default_utterance_embedding,
159
  durations=durations,
160
  pitch=pitch,
161
  energy=energy,
162
+ lang_id=self.lang_id,
163
  duration_scaling_factor=duration_scaling_factor,
164
  pitch_variance_scale=pitch_variance_scale,
165
  energy_variance_scale=energy_variance_scale,
166
  pause_duration_scaling_factor=pause_duration_scaling_factor,
167
+ prosody_creativity=prosody_creativity)
168
 
169
  wave, _, _ = self.vocoder(mel.unsqueeze(0))
170
  wave = wave.squeeze().cpu()
 
178
  pass
179
 
180
  if view or return_plot_as_filepath:
181
+ fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
 
182
 
183
+ ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
184
+ ax.yaxis.set_visible(False)
185
+ duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
186
+ ax.xaxis.grid(True, which='minor')
187
+ ax.set_xticks(label_positions, minor=False)
188
+ if input_is_phones:
189
+ phones = text.replace(" ", "|")
190
+ else:
191
+ phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
192
+ try:
193
+ ax.set_xticklabels(phones)
194
+ except IndexError:
195
+ pass
196
+ except ValueError:
197
+ pass
198
+ word_boundaries = list()
199
+ for label_index, phone in enumerate(phones):
200
+ if phone == "|":
201
+ word_boundaries.append(label_positions[label_index])
202
 
203
+ try:
204
+ prev_word_boundary = 0
205
+ word_label_positions = list()
206
+ for word_boundary in word_boundaries:
207
+ word_label_positions.append((word_boundary + prev_word_boundary) / 2)
208
+ prev_word_boundary = word_boundary
209
+ word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)
210
 
211
+ secondary_ax = ax.secondary_xaxis('bottom')
212
+ secondary_ax.tick_params(axis="x", direction="out", pad=24)
213
+ secondary_ax.set_xticks(word_label_positions, minor=False)
214
+ secondary_ax.set_xticklabels(text.split())
215
+ secondary_ax.tick_params(axis='x', colors='orange')
216
+ secondary_ax.xaxis.label.set_color('orange')
217
+ except ValueError:
218
+ ax.set_title(text)
219
+ except IndexError:
220
+ ax.set_title(text)
 
 
221
 
222
+ ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
223
+ ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
224
+ plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
225
+ ax.set_aspect("auto")
 
 
226
 
227
  if return_plot_as_filepath:
228
+ plt.savefig("tmp.png")
229
+ plt.close()
 
 
 
230
  return wave, sr, "tmp.png"
 
231
  return wave, sr
232
 
233
  def read_to_file(self,
 
241
  dur_list=None,
242
  pitch_list=None,
243
  energy_list=None,
244
+ prosody_creativity=0.1):
245
  """
246
  Args:
247
  silent: Whether to be verbose about the process
 
253
  duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
254
  1.0 means no scaling happens, higher values increase durations for the whole
255
  utterance, lower values decrease durations for the whole utterance.
256
+ pause_duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
257
+ 1.0 means no scaling happens, higher values increase durations for the pauses,
258
+ lower values decrease durations for the pauses.
259
  pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
260
  1.0 means no scaling happens, higher values increase variance of the pitch curve,
261
  lower values decrease variance of the pitch curve.
262
  energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
263
  1.0 means no scaling happens, higher values increase variance of the energy curve,
264
  lower values decrease variance of the energy curve.
265
+ prosody_creativity: sampling temperature of the generative model that comes up with the pitch, energy and
266
+ durations. Higher values mean more variance, lower values mean less variance across
267
+ generations. Reasonable values are between 0.0 and 1.2; anything higher makes the voice
268
+ sound very weird.
269
  """
270
  if not dur_list:
271
  dur_list = []
 
273
  pitch_list = []
274
  if not energy_list:
275
  energy_list = []
276
+ silence = torch.zeros([400])
277
  wav = silence.clone()
278
  for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
279
  if text.strip() != "":
 
287
  pitch_variance_scale=pitch_variance_scale,
288
  energy_variance_scale=energy_variance_scale,
289
  pause_duration_scaling_factor=pause_duration_scaling_factor,
290
+ prosody_creativity=prosody_creativity)
291
  spoken_sentence = torch.tensor(spoken_sentence).cpu()
292
  wav = torch.cat((wav, spoken_sentence, silence), 0)
293
  soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16")
 
299
  pitch_variance_scale=1.0,
300
  energy_variance_scale=1.0,
301
  blocking=False,
302
+ prosody_creativity=0.1):
303
  if text.strip() == "":
304
  return
305
  wav, sr = self(text,
 
307
  duration_scaling_factor=duration_scaling_factor,
308
  pitch_variance_scale=pitch_variance_scale,
309
  energy_variance_scale=energy_variance_scale,
310
+ prosody_creativity=prosody_creativity)
311
  silence = torch.zeros([sr // 2])
312
  wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
313
  sounddevice.play(float2pcm(wav), samplerate=sr)
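Aside: a minimal usage sketch of the interface after this change, assuming the default checkpoint paths shown above resolve to downloaded models and that read_to_file takes text_list and file_location as its first arguments, as implied by its body; none of this is part of the commit itself.

# Hypothetical usage sketch: prosody_creativity replaces the former glow_sampling_temperature.
import torch

from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface

device = "cuda" if torch.cuda.is_available() else "cpu"
tts = ToucanTTSInterface(device=device)  # default ToucanTTS_Meta and Vocoder checkpoints, language "eng"

# low creativity keeps pitch, energy and durations stable across runs; values towards 1.0 vary more
tts.read_to_file(text_list=["Hello there, this is a short test sentence."],
                 file_location="test.wav",
                 duration_scaling_factor=1.0,
                 pitch_variance_scale=1.0,
                 energy_variance_scale=1.0,
                 prosody_creativity=0.1)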
InferenceInterfaces/UtteranceCloner.py CHANGED
@@ -4,11 +4,11 @@ import numpy
4
  import soundfile as sf
5
  import torch
6
 
7
- from Architectures.Aligner.Aligner import Aligner
8
- from Architectures.ToucanTTS.DurationCalculator import DurationCalculator
9
- from Architectures.ToucanTTS.EnergyCalculator import EnergyCalculator
10
- from Architectures.ToucanTTS.PitchCalculator import Parselmouth
11
  from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
 
 
 
 
12
  from Preprocessing.AudioPreprocessor import AudioPreprocessor
13
  from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
14
  from Preprocessing.articulatory_features import get_feature_to_index_lookup
@@ -26,7 +26,7 @@ class UtteranceCloner:
26
  def __init__(self, model_id, device, language="eng"):
27
  self.tts = ToucanTTSInterface(device=device, tts_model_path=model_id)
28
  self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, cut_silence=False)
29
- self.tf = ArticulatoryCombinedTextFrontend(language=language)
30
  self.device = device
31
  acoustic_checkpoint_path = os.path.join(MODELS_DIR, "Aligner", "aligner.pt")
32
  self.aligner_weights = torch.load(acoustic_checkpoint_path, map_location=device)["asr_model"]
@@ -43,6 +43,7 @@ class UtteranceCloner:
43
  self.acoustic_model = Aligner()
44
  self.acoustic_model = self.acoustic_model.to(self.device)
45
  self.acoustic_model.load_state_dict(self.aligner_weights)
 
46
  self.parsel = Parselmouth(reduction_factor=1, fs=16000)
47
  self.energy_calc = EnergyCalculator(reduction_factor=1, fs=16000)
48
  self.dc = DurationCalculator(reduction_factor=1)
@@ -50,10 +51,11 @@ class UtteranceCloner:
50
  def extract_prosody(self, transcript, ref_audio_path, lang="eng", on_line_fine_tune=True):
51
  if on_line_fine_tune:
52
  self.acoustic_model.load_state_dict(self.aligner_weights)
 
53
 
54
  wave, sr = sf.read(ref_audio_path)
55
  if self.tf.language != lang:
56
- self.tf = ArticulatoryCombinedTextFrontend(language=lang)
57
  if self.ap.input_sr != sr:
58
  self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=False)
59
  try:
 
4
  import soundfile as sf
5
  import torch
6
 
 
 
 
 
7
  from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
8
+ from Modules.Aligner.Aligner import Aligner
9
+ from Modules.ToucanTTS.DurationCalculator import DurationCalculator
10
+ from Modules.ToucanTTS.EnergyCalculator import EnergyCalculator
11
+ from Modules.ToucanTTS.PitchCalculator import Parselmouth
12
  from Preprocessing.AudioPreprocessor import AudioPreprocessor
13
  from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
14
  from Preprocessing.articulatory_features import get_feature_to_index_lookup
 
26
  def __init__(self, model_id, device, language="eng"):
27
  self.tts = ToucanTTSInterface(device=device, tts_model_path=model_id)
28
  self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, cut_silence=False)
29
+ self.tf = ArticulatoryCombinedTextFrontend(language=language, device=device)
30
  self.device = device
31
  acoustic_checkpoint_path = os.path.join(MODELS_DIR, "Aligner", "aligner.pt")
32
  self.aligner_weights = torch.load(acoustic_checkpoint_path, map_location=device)["asr_model"]
 
43
  self.acoustic_model = Aligner()
44
  self.acoustic_model = self.acoustic_model.to(self.device)
45
  self.acoustic_model.load_state_dict(self.aligner_weights)
46
+ self.acoustic_model.eval()
47
  self.parsel = Parselmouth(reduction_factor=1, fs=16000)
48
  self.energy_calc = EnergyCalculator(reduction_factor=1, fs=16000)
49
  self.dc = DurationCalculator(reduction_factor=1)
 
51
  def extract_prosody(self, transcript, ref_audio_path, lang="eng", on_line_fine_tune=True):
52
  if on_line_fine_tune:
53
  self.acoustic_model.load_state_dict(self.aligner_weights)
54
+ self.acoustic_model.eval()
55
 
56
  wave, sr = sf.read(ref_audio_path)
57
  if self.tf.language != lang:
58
+ self.tf = ArticulatoryCombinedTextFrontend(language=lang, device=self.device)
59
  if self.ap.input_sr != sr:
60
  self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=False)
61
  try:
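Aside: a small sketch (not part of the commit) of driving the refactored UtteranceCloner; only the constructor and extract_prosody are exercised, since their signatures are visible above, and the model id and reference path are placeholders.

# Hypothetical sketch: extract prosody values from a reference recording.
import torch

from InferenceInterfaces.UtteranceCloner import UtteranceCloner

device = "cuda" if torch.cuda.is_available() else "cpu"
cloner = UtteranceCloner(model_id="Meta", device=device, language="eng")

# on_line_fine_tune=True reloads the pretrained aligner weights (now also set to eval mode),
# so each call starts from the same aligner state.
prosody = cloner.extract_prosody(transcript="Hello there, this is a test sentence.",
                                 ref_audio_path="reference.wav",
                                 lang="eng",
                                 on_line_fine_tune=True)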
InferenceInterfaces/audioseal_wm_16bits.yaml DELETED
@@ -1,39 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- name: audioseal_wm_16bits
8
- model_type: seanet
9
- checkpoint: "https://dl.fbaipublicfiles.com/audioseal/6edcf62f/generator.pth"
10
- nbits: 16
11
- seanet:
12
- activation: ELU
13
- activation_params:
14
- alpha: 1.0
15
- causal: false
16
- channels: 1
17
- compress: 2
18
- dilation_base: 2
19
- dimension: 128
20
- disable_norm_outer_blocks: 0
21
- kernel_size: 7
22
- last_kernel_size: 7
23
- lstm: 2
24
- n_filters: 32
25
- n_residual_layers: 1
26
- norm: weight_norm
27
- norm_params: { }
28
- pad_mode: constant
29
- ratios:
30
- - 8
31
- - 5
32
- - 4
33
- - 2
34
- residual_kernel_size: 3
35
- true_skip: true
36
- decoder:
37
- final_activation: null
38
- final_activation_params: null
39
- trim_right_ratio: 1.0

{Architectures → Modules}/Aligner/Aligner.py RENAMED
@@ -1,27 +1,31 @@
1
  """
2
  taken and adapted from https://github.com/as-ideas/DeepForcedAligner
 
 
 
 
3
  """
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  import torch
7
  import torch.multiprocessing
8
- import torch.nn as nn
9
  from torch.nn import CTCLoss
10
  from torch.nn.utils.rnn import pack_padded_sequence
11
  from torch.nn.utils.rnn import pad_packed_sequence
12
 
13
  from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
 
14
 
15
 
16
- class BatchNormConv(nn.Module):
17
 
18
  def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
19
  super().__init__()
20
- self.conv = nn.Conv1d(
21
  in_channels, out_channels, kernel_size,
22
  stride=1, padding=kernel_size // 2, bias=False)
23
- self.bnorm = nn.BatchNorm1d(out_channels)
24
- self.relu = nn.ReLU()
25
 
26
  def forward(self, x):
27
  x = x.transpose(1, 2)
@@ -37,22 +41,23 @@ class Aligner(torch.nn.Module):
37
  def __init__(self,
38
  n_features=128,
39
  num_symbols=145,
40
- lstm_dim=512,
41
- conv_dim=512):
42
  super().__init__()
43
- self.convs = nn.ModuleList([
44
  BatchNormConv(n_features, conv_dim, 3),
45
- nn.Dropout(p=0.5),
46
  BatchNormConv(conv_dim, conv_dim, 3),
47
- nn.Dropout(p=0.5),
48
  BatchNormConv(conv_dim, conv_dim, 3),
49
- nn.Dropout(p=0.5),
50
  BatchNormConv(conv_dim, conv_dim, 3),
51
- nn.Dropout(p=0.5),
52
  BatchNormConv(conv_dim, conv_dim, 3),
53
- nn.Dropout(p=0.5),
54
  ])
55
- self.rnn = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
 
56
  self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
57
  self.tf = ArticulatoryCombinedTextFrontend(language="eng")
58
  self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
@@ -61,14 +66,17 @@ class Aligner(torch.nn.Module):
61
  def forward(self, x, lens=None):
62
  for conv in self.convs:
63
  x = conv(x)
64
-
65
  if lens is not None:
66
  x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
67
- x, _ = self.rnn(x)
 
68
  if lens is not None:
69
  x, _ = pad_packed_sequence(x, batch_first=True)
70
 
71
  x = self.proj(x)
 
 
 
72
 
73
  return x
74
 
@@ -88,15 +96,12 @@ class Aligner(torch.nn.Module):
88
  pred_max = pred[:, tokens]
89
 
90
  # run monotonic alignment search
91
-
92
  alignment_matrix = binarize_alignment(pred_max)
93
 
94
  if save_img_for_debug is not None:
95
  phones = list()
96
  for index in tokens:
97
- for phone in self.tf.phone_to_id:
98
- if self.tf.phone_to_id[phone] == index:
99
- phones.append(phone)
100
  fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
101
 
102
  ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
@@ -115,7 +120,6 @@ class Aligner(torch.nn.Module):
115
  return alignment_matrix
116
 
117
 
118
-
119
  def binarize_alignment(alignment_prob):
120
  """
121
  # Implementation by:
@@ -152,13 +156,5 @@ def binarize_alignment(alignment_prob):
152
 
153
 
154
  if __name__ == '__main__':
155
- tf = ArticulatoryCombinedTextFrontend(language="eng")
156
- from Preprocessing.HiFiCodecAudioPreprocessor import CodecAudioPreprocessor
157
-
158
- cap = CodecAudioPreprocessor(input_sr=-1)
159
- dummy_codebook_indexes = torch.randint(low=0, high=1023, size=[9, 20])
160
- codebook_frames = cap.indexes_to_codec_frames(dummy_codebook_indexes)
161
- alignment = Aligner().inference(codebook_frames.transpose(0, 1), tokens=tf.string_to_tensor("Hello world"))
162
- print(alignment.shape)
163
- plt.imshow(alignment, origin="lower", cmap="GnBu")
164
- plt.show()
 
1
  """
2
  taken and adapted from https://github.com/as-ideas/DeepForcedAligner
3
+
4
+ refined with insights from https://www.audiolabs-erlangen.de/resources/NLUI/2023-ICASSP-eval-alignment-tts
5
+ EVALUATING SPEECH–PHONEME ALIGNMENT AND ITS IMPACT ON NEURAL TEXT-TO-SPEECH SYNTHESIS
6
+ by Frank Zalkow, Prachi Govalkar, Meinard Muller, Emanuel A. P. Habets, Christian Dittmar
7
  """
8
  import matplotlib.pyplot as plt
9
  import numpy as np
10
  import torch
11
  import torch.multiprocessing
 
12
  from torch.nn import CTCLoss
13
  from torch.nn.utils.rnn import pack_padded_sequence
14
  from torch.nn.utils.rnn import pad_packed_sequence
15
 
16
  from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
17
+ from Utility.utils import make_non_pad_mask
18
 
19
 
20
+ class BatchNormConv(torch.nn.Module):
21
 
22
  def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
23
  super().__init__()
24
+ self.conv = torch.nn.Conv1d(
25
  in_channels, out_channels, kernel_size,
26
  stride=1, padding=kernel_size // 2, bias=False)
27
+ self.bnorm = torch.nn.SyncBatchNorm.convert_sync_batchnorm(torch.nn.BatchNorm1d(out_channels))
28
+ self.relu = torch.nn.ReLU()
29
 
30
  def forward(self, x):
31
  x = x.transpose(1, 2)
 
41
  def __init__(self,
42
  n_features=128,
43
  num_symbols=145,
44
+ conv_dim=512,
45
+ lstm_dim=512):
46
  super().__init__()
47
+ self.convs = torch.nn.ModuleList([
48
  BatchNormConv(n_features, conv_dim, 3),
49
+ torch.nn.Dropout(p=0.5),
50
  BatchNormConv(conv_dim, conv_dim, 3),
51
+ torch.nn.Dropout(p=0.5),
52
  BatchNormConv(conv_dim, conv_dim, 3),
53
+ torch.nn.Dropout(p=0.5),
54
  BatchNormConv(conv_dim, conv_dim, 3),
55
+ torch.nn.Dropout(p=0.5),
56
  BatchNormConv(conv_dim, conv_dim, 3),
57
+ torch.nn.Dropout(p=0.5),
58
  ])
59
+ self.rnn1 = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
60
+ self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
61
  self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
62
  self.tf = ArticulatoryCombinedTextFrontend(language="eng")
63
  self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
 
66
  def forward(self, x, lens=None):
67
  for conv in self.convs:
68
  x = conv(x)
 
69
  if lens is not None:
70
  x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
71
+ x, _ = self.rnn1(x)
72
+ x, _ = self.rnn2(x)
73
  if lens is not None:
74
  x, _ = pad_packed_sequence(x, batch_first=True)
75
 
76
  x = self.proj(x)
77
+ if lens is not None:
78
+ out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(x.device)
79
+ x = x * out_masks.float()
80
 
81
  return x
82
 
 
96
  pred_max = pred[:, tokens]
97
 
98
  # run monotonic alignment search
 
99
  alignment_matrix = binarize_alignment(pred_max)
100
 
101
  if save_img_for_debug is not None:
102
  phones = list()
103
  for index in tokens:
104
+ phones.append(self.tf.id_to_phone[index])
 
 
105
  fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
106
 
107
  ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
 
120
  return alignment_matrix
121
 
122
 
 
123
  def binarize_alignment(alignment_prob):
124
  """
125
  # Implementation by:
 
156
 
157
 
158
  if __name__ == '__main__':
159
+ print(sum(p.numel() for p in Aligner().parameters() if p.requires_grad))
160
+ print(Aligner()(x=torch.randn(size=[3, 30, 128]), lens=torch.LongTensor([20, 30, 10])).shape)
 
 
 
 
 
 
 
 
{Architectures → Modules}/Aligner/CodecAlignerDataset.py RENAMED
@@ -32,6 +32,7 @@ class CodecAlignerDataset(Dataset):
32
  allow_unknown_symbols=False,
33
  gpu_count=1,
34
  rank=0):
 
35
  self.gpu_count = gpu_count
36
  self.rank = rank
37
  if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
@@ -50,9 +51,10 @@ class CodecAlignerDataset(Dataset):
50
  self.lang = lang
51
  self.device = device
52
  self.cache_dir = cache_dir
53
- self.tf = ArticulatoryCombinedTextFrontend(language=self.lang)
54
  cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu')
55
  self.speaker_embeddings = cache[2]
 
56
  self.datapoints = cache[0]
57
  if self.gpu_count > 1:
58
  # we only keep a chunk of the dataset in memory to avoid redundancy. Which chunk, we figure out using the rank.
@@ -85,6 +87,7 @@ class CodecAlignerDataset(Dataset):
85
  if type(path_to_transcript_dict) != dict:
86
  path_to_transcript_dict = path_to_transcript_dict() # in this case we passed a function instead of the dict, so that the function isn't executed if not necessary.
87
  torch.multiprocessing.set_start_method('spawn', force=True)
 
88
  resource_manager = Manager()
89
  self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
90
  key_list = list(self.path_to_transcript_dict.keys())
@@ -93,6 +96,13 @@ class CodecAlignerDataset(Dataset):
93
  fisher_yates_shuffle(key_list)
94
  # build cache
95
  print("... building dataset cache ...")
 
 
 
 
 
 
 
96
  self.result_pool = resource_manager.list()
97
  # make processes
98
  key_splits = list()
@@ -176,8 +186,8 @@ class CodecAlignerDataset(Dataset):
176
  torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets
177
  # this to false globally during model loading rather than using inference mode or no_grad
178
  silero_model = silero_model.to(device)
179
- silence = torch.zeros([16000 // 4], device=device)
180
- tf = ArticulatoryCombinedTextFrontend(language=lang)
181
  _, sr = sf.read(path_list[0])
182
  assumed_sr = sr
183
  ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
@@ -186,13 +196,15 @@ class CodecAlignerDataset(Dataset):
186
  for path in tqdm(path_list):
187
  if self.path_to_transcript_dict[path].strip() == "":
188
  continue
189
-
190
  try:
191
  wave, sr = sf.read(path)
192
  except:
193
  print(f"Problem with an audio file: {path}")
194
  continue
195
 
 
 
 
196
  wave = librosa.to_mono(wave)
197
 
198
  if sr != assumed_sr:
@@ -210,16 +222,19 @@ class CodecAlignerDataset(Dataset):
210
  if verbose:
211
  print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
212
  continue
213
-
214
- # remove silences from front and back, then add constant 1/4th second silences back to front and back
215
- with torch.no_grad():
216
  speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
217
  try:
 
 
 
 
 
218
  result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
219
  except IndexError:
220
  print("Audio might be too short to cut silences from front and back.")
221
  continue
222
- wave = torch.cat([silence, result, silence])
223
 
224
  # raw audio preprocessing is done
225
  transcript = self.path_to_transcript_dict[path]
@@ -238,10 +253,10 @@ class CodecAlignerDataset(Dataset):
238
  # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
239
  continue
240
 
241
- cached_speech = ap.audio_to_codebook_indexes(audio=wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
242
  process_internal_dataset_chunk.append([cached_text,
243
  cached_speech,
244
- result.cpu().detach().numpy(),
245
  path])
246
  self.result_pool.append(process_internal_dataset_chunk)
247
 
@@ -256,16 +271,44 @@ class CodecAlignerDataset(Dataset):
256
  codes = codes.transpose(0, 1)
257
 
258
  return tokens, \
259
- token_len, \
260
- codes, \
261
- None, \
262
- self.speaker_embeddings[index]
263
 
264
  def __len__(self):
265
  return len(self.datapoints)
266
 
 
 
 
 
 
 
 
 
 
267
 
268
  def fisher_yates_shuffle(lst):
269
  for i in range(len(lst) - 1, 0, -1):
270
  j = random.randint(0, i)
271
  lst[i], lst[j] = lst[j], lst[i]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  allow_unknown_symbols=False,
33
  gpu_count=1,
34
  rank=0):
35
+
36
  self.gpu_count = gpu_count
37
  self.rank = rank
38
  if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
 
51
  self.lang = lang
52
  self.device = device
53
  self.cache_dir = cache_dir
54
+ self.tf = ArticulatoryCombinedTextFrontend(language=self.lang, device=device)
55
  cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu')
56
  self.speaker_embeddings = cache[2]
57
+ self.filepaths = cache[3]
58
  self.datapoints = cache[0]
59
  if self.gpu_count > 1:
60
  # we only keep a chunk of the dataset in memory to avoid redundancy. Which chunk, we figure out using the rank.
 
87
  if type(path_to_transcript_dict) != dict:
88
  path_to_transcript_dict = path_to_transcript_dict() # in this case we passed a function instead of the dict, so that the function isn't executed if not necessary.
89
  torch.multiprocessing.set_start_method('spawn', force=True)
90
+ torch.multiprocessing.set_sharing_strategy('file_system')
91
  resource_manager = Manager()
92
  self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
93
  key_list = list(self.path_to_transcript_dict.keys())
 
96
  fisher_yates_shuffle(key_list)
97
  # build cache
98
  print("... building dataset cache ...")
99
+ torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
100
+ # careful: assumes 16kHz or 8kHz audio
101
+ _, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', # make sure it gets downloaded during single-processing first, if it's not already downloaded
102
+ model='silero_vad',
103
+ force_reload=False,
104
+ onnx=False,
105
+ verbose=False)
106
  self.result_pool = resource_manager.list()
107
  # make processes
108
  key_splits = list()
 
186
  torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets
187
  # this to false globally during model loading rather than using inference mode or no_grad
188
  silero_model = silero_model.to(device)
189
+ silence = torch.zeros([16000 // 8]).to(device)
190
+ tf = ArticulatoryCombinedTextFrontend(language=lang, device=device)
191
  _, sr = sf.read(path_list[0])
192
  assumed_sr = sr
193
  ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
 
196
  for path in tqdm(path_list):
197
  if self.path_to_transcript_dict[path].strip() == "":
198
  continue
 
199
  try:
200
  wave, sr = sf.read(path)
201
  except:
202
  print(f"Problem with an audio file: {path}")
203
  continue
204
 
205
+ if len(wave.shape) > 1: # oh no, we found a stereo audio!
206
+ if len(wave[0]) == 2: # let's figure out whether we need to switch the axes
207
+ wave = wave.transpose() # if yes, we switch the axes.
208
  wave = librosa.to_mono(wave)
209
 
210
  if sr != assumed_sr:
 
222
  if verbose:
223
  print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
224
  continue
225
+ with torch.inference_mode():
 
 
226
  speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
227
  try:
228
+ silence_timestamps = invert_segments(speech_timestamps, len(norm_wave))
229
+ for silence_timestamp in silence_timestamps:
230
+ begin = silence_timestamp['start']
231
+ end = silence_timestamp['end']
232
+ norm_wave = torch.cat([norm_wave[:begin], torch.zeros([end - begin], device=device), norm_wave[end:]])
233
  result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
234
  except IndexError:
235
  print("Audio might be too short to cut silences from front and back.")
236
  continue
237
+ norm_wave = torch.cat([silence, result, silence])
238
 
239
  # raw audio preprocessing is done
240
  transcript = self.path_to_transcript_dict[path]
 
253
  # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
254
  continue
255
 
256
+ cached_speech = ap.audio_to_codebook_indexes(audio=norm_wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
257
  process_internal_dataset_chunk.append([cached_text,
258
  cached_speech,
259
+ norm_wave.cpu().detach().numpy(),
260
  path])
261
  self.result_pool.append(process_internal_dataset_chunk)
262
 
 
271
  codes = codes.transpose(0, 1)
272
 
273
  return tokens, \
274
+ token_len, \
275
+ codes, \
276
+ None, \
277
+ self.speaker_embeddings[index]
278
 
279
  def __len__(self):
280
  return len(self.datapoints)
281
 
282
+ def remove_samples(self, list_of_samples_to_remove):
283
+ for remove_id in sorted(list_of_samples_to_remove, reverse=True):
284
+ self.datapoints.pop(remove_id)
285
+ self.speaker_embeddings.pop(remove_id)
286
+ self.filepaths.pop(remove_id)
287
+ torch.save((self.datapoints, None, self.speaker_embeddings, self.filepaths),
288
+ os.path.join(self.cache_dir, "aligner_train_cache.pt"))
289
+ print("Dataset updated!")
290
+
291
 
292
  def fisher_yates_shuffle(lst):
293
  for i in range(len(lst) - 1, 0, -1):
294
  j = random.randint(0, i)
295
  lst[i], lst[j] = lst[j], lst[i]
296
+
297
+
298
+ def invert_segments(segments, total_duration):
299
+ if not segments:
300
+ return [{'start': 0, 'end': total_duration}]
301
+
302
+ inverted_segments = []
303
+ previous_end = 0
304
+
305
+ for segment in segments:
306
+ start = segment['start']
307
+ if previous_end < start:
308
+ inverted_segments.append({'start': previous_end, 'end': start})
309
+ previous_end = segment['end']
310
+
311
+ if previous_end < total_duration:
312
+ inverted_segments.append({'start': previous_end, 'end': total_duration})
313
+
314
+ return inverted_segments
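Aside: a quick hypothetical check of the invert_segments helper added above; it returns the complement of the detected speech regions, which the loading loop above then overwrites with zeros.

# Hypothetical example (module path follows the rename in this commit).
from Modules.Aligner.CodecAlignerDataset import invert_segments

speech_regions = [{'start': 100, 'end': 400}, {'start': 700, 'end': 900}]
print(invert_segments(speech_regions, total_duration=1000))
# [{'start': 0, 'end': 100}, {'start': 400, 'end': 700}, {'start': 900, 'end': 1000}]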
{Architectures → Modules}/Aligner/README.md RENAMED
File without changes
{Architectures → Modules}/Aligner/Reconstructor.py RENAMED
@@ -1,7 +1,5 @@
1
  import torch
2
  import torch.multiprocessing
3
- from torch.nn.utils.rnn import pack_padded_sequence
4
- from torch.nn.utils.rnn import pad_packed_sequence
5
 
6
  from Utility.utils import make_non_pad_mask
7
 
@@ -12,28 +10,23 @@ class Reconstructor(torch.nn.Module):
12
  n_features=128,
13
  num_symbols=145,
14
  speaker_embedding_dim=192,
15
- lstm_dim=256):
16
  super().__init__()
17
- self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, lstm_dim)
18
- self.rnn1 = torch.nn.LSTM(lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
19
- self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
20
- self.out_proj = torch.nn.Linear(2 * lstm_dim, n_features)
21
  self.l1_criterion = torch.nn.L1Loss(reduction="none")
22
- self.l2_criterion = torch.nn.MSELoss(reduction="none")
23
 
24
  def forward(self, x, lens, ys):
25
  x = self.in_proj(x)
26
- x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
27
- x, _ = self.rnn1(x)
28
- x, _ = self.rnn2(x)
29
- x, _ = pad_packed_sequence(x, batch_first=True)
30
  x = self.out_proj(x)
31
  out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device)
32
  out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
33
  out_weights /= ys.size(0) * ys.size(2)
34
- l1_loss = self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
35
- l2_loss = self.l2_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
36
- return l1_loss + l2_loss
37
 
38
 
39
  if __name__ == '__main__':
 
1
  import torch
2
  import torch.multiprocessing
 
 
3
 
4
  from Utility.utils import make_non_pad_mask
5
 
 
10
  n_features=128,
11
  num_symbols=145,
12
  speaker_embedding_dim=192,
13
+ hidden_dim=256):
14
  super().__init__()
15
+ self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, hidden_dim)
16
+ self.hidden_proj = torch.nn.Linear(hidden_dim, hidden_dim)
17
+ self.out_proj = torch.nn.Linear(hidden_dim, n_features)
 
18
  self.l1_criterion = torch.nn.L1Loss(reduction="none")
 
19
 
20
  def forward(self, x, lens, ys):
21
  x = self.in_proj(x)
22
+ x = torch.nn.functional.leaky_relu(x)
23
+ x = self.hidden_proj(x)
24
+ x = torch.nn.functional.leaky_relu(x)
 
25
  x = self.out_proj(x)
26
  out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device)
27
  out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
28
  out_weights /= ys.size(0) * ys.size(2)
29
+ return self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
 
 
30
 
31
 
32
  if __name__ == '__main__':
{Architectures → Modules}/Aligner/__init__.py RENAMED
File without changes
{Architectures → Modules}/Aligner/autoaligner_train_loop.py RENAMED
@@ -8,8 +8,8 @@ from torch.optim import RAdam
8
  from torch.utils.data.dataloader import DataLoader
9
  from tqdm import tqdm
10
 
11
- from Architectures.Aligner.Aligner import Aligner
12
- from Architectures.Aligner.Reconstructor import Reconstructor
13
  from Preprocessing.AudioPreprocessor import AudioPreprocessor
14
  from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
15
 
@@ -152,6 +152,8 @@ def train_loop(train_dataset,
152
  optim_asr.zero_grad()
153
  if use_reconstruction:
154
  optim_tts.zero_grad()
 
 
155
  loss.backward()
156
  torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0)
157
  if use_reconstruction:
 
8
  from torch.utils.data.dataloader import DataLoader
9
  from tqdm import tqdm
10
 
11
+ from Modules.Aligner.Aligner import Aligner
12
+ from Modules.Aligner.Reconstructor import Reconstructor
13
  from Preprocessing.AudioPreprocessor import AudioPreprocessor
14
  from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
15
 
 
152
  optim_asr.zero_grad()
153
  if use_reconstruction:
154
  optim_tts.zero_grad()
155
+ if gpu_count > 1:
156
+ torch.distributed.barrier()
157
  loss.backward()
158
  torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0)
159
  if use_reconstruction:
{Architectures → Modules}/ControllabilityGAN/GAN.py RENAMED
@@ -1,12 +1,11 @@
1
  import torch
2
 
3
- from Architectures.ControllabilityGAN.wgan.init_wgan import create_wgan
4
 
5
 
6
- class GanWrapper(torch.nn.Module):
7
 
8
- def __init__(self, path_wgan, device, *args, **kwargs):
9
- super().__init__(*args, **kwargs)
10
  self.device = device
11
  self.path_wgan = path_wgan
12
 
@@ -20,27 +19,41 @@ class GanWrapper(torch.nn.Module):
20
  self.U = self.compute_controllability()
21
 
22
  self.z_list = list()
 
23
  for _ in range(1100):
24
- self.z_list.append(self.wgan.G.module.sample_latent(1, 32).to("cpu"))
25
  self.z = self.z_list[0]
26
 
27
  def set_latent(self, seed):
28
  self.z = self.z = self.z_list[seed]
29
 
30
  def reset_default_latent(self):
31
- self.z = self.wgan.G.module.sample_latent(1, 32).to("cpu")
32
 
33
  def load_model(self, path):
34
  gan_checkpoint = torch.load(path, map_location="cpu")
35
 
36
  self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device)
37
- self.wgan.G.load_state_dict(gan_checkpoint['generator_state_dict'])
38
- self.wgan.D.load_state_dict(gan_checkpoint['critic_state_dict'])
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  self.mean = gan_checkpoint["dataset_mean"]
41
  self.std = gan_checkpoint["dataset_std"]
42
 
43
- def compute_controllability(self, n_samples=50000):
44
  _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
45
  intermediate = intermediate.cpu()
46
  z = z.cpu()
@@ -69,7 +82,7 @@ class GanWrapper(torch.nn.Module):
69
  def modify_embed(self, x):
70
  self.wgan.G.eval()
71
  z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x)
72
- embed_modified = self.wgan.G.module.forward(z_new.unsqueeze(0).to(self.device))
73
  if self.normalize:
74
  embed_modified = inverse_normalize(
75
  embed_modified.cpu(),
 
1
  import torch
2
 
3
+ from Modules.ControllabilityGAN.wgan.init_wgan import create_wgan
4
 
5
 
6
+ class GanWrapper:
7
 
8
+ def __init__(self, path_wgan, device):
 
9
  self.device = device
10
  self.path_wgan = path_wgan
11
 
 
19
  self.U = self.compute_controllability()
20
 
21
  self.z_list = list()
22
+
23
  for _ in range(1100):
24
+ self.z_list.append(self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8))
25
  self.z = self.z_list[0]
26
 
27
  def set_latent(self, seed):
28
  self.z = self.z = self.z_list[seed]
29
 
30
  def reset_default_latent(self):
31
+ self.z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)
32
 
33
  def load_model(self, path):
34
  gan_checkpoint = torch.load(path, map_location="cpu")
35
 
36
  self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device)
37
+ # Create a new state dict without 'module.' prefix
38
+ new_state_dict_G = {}
39
+ for key, value in gan_checkpoint['generator_state_dict'].items():
40
+ # Remove 'module.' prefix
41
+ new_key = key.replace('module.', '')
42
+ new_state_dict_G[new_key] = value
43
+
44
+ new_state_dict_D = {}
45
+ for key, value in gan_checkpoint['critic_state_dict'].items():
46
+ # Remove 'module.' prefix
47
+ new_key = key.replace('module.', '')
48
+ new_state_dict_D[new_key] = value
49
+
50
+ self.wgan.G.load_state_dict(new_state_dict_G)
51
+ self.wgan.D.load_state_dict(new_state_dict_D)
52
 
53
  self.mean = gan_checkpoint["dataset_mean"]
54
  self.std = gan_checkpoint["dataset_std"]
55
 
56
+ def compute_controllability(self, n_samples=100000):
57
  _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
58
  intermediate = intermediate.cpu()
59
  z = z.cpu()
 
82
  def modify_embed(self, x):
83
  self.wgan.G.eval()
84
  z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x)
85
+ embed_modified = self.wgan.G.forward(z_new.unsqueeze(0).to(self.device))
86
  if self.normalize:
87
  embed_modified = inverse_normalize(
88
  embed_modified.cpu(),
{Architectures → Modules}/ControllabilityGAN/__init__.py RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/dataset/__init__.py RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/dataset/speaker_embeddings_dataset.py RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/wgan/__init__.py RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/wgan/init_weights.py RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/wgan/init_wgan.py RENAMED
@@ -1,7 +1,7 @@
1
  import torch
2
 
3
- from Architectures.ControllabilityGAN.wgan.resnet_init import init_resnet
4
- from Architectures.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost
5
 
6
 
7
  def create_wgan(parameters, device, optimizer='adam'):
 
1
  import torch
2
 
3
+ from Modules.ControllabilityGAN.wgan.resnet_init import init_resnet
4
+ from Modules.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost
5
 
6
 
7
  def create_wgan(parameters, device, optimizer='adam'):
{Architectures → Modules}/ControllabilityGAN/wgan/resnet_1.py RENAMED
@@ -76,8 +76,8 @@ class ResNet_G(nn.Module):
76
  return out, l_1
77
  return out
78
 
79
- def sample_latent(self, n_samples, z_size):
80
- return torch.randn((n_samples, z_size))
81
 
82
 
83
  class ResNet_D(nn.Module):
 
76
  return out, l_1
77
  return out
78
 
79
+ def sample_latent(self, n_samples, z_size, temperature=0.7):
80
+ return torch.randn((n_samples, z_size)) * temperature
81
 
82
 
83
  class ResNet_D(nn.Module):
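Aside: the new temperature argument of sample_latent simply rescales the standard-normal draw, so lower temperatures yield latents closer to the center of the GAN prior. A hypothetical illustration, not part of the commit:

# Standalone copy of the sampling logic shown above, for illustration only.
import torch

def sample_latent(n_samples, z_size, temperature=0.7):
    return torch.randn((n_samples, z_size)) * temperature

print(sample_latent(10000, 32, temperature=0.7).std())  # roughly 0.7
print(sample_latent(10000, 32, temperature=1.0).std())  # roughly 1.0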
{Architectures → Modules}/ControllabilityGAN/wgan/resnet_init.py RENAMED
@@ -1,7 +1,7 @@
1
- from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_D
2
- from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_G
3
- from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_D
4
- from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_G
5
 
6
 
7
  def init_resnet(parameters):
 
1
+ from Modules.ControllabilityGAN.wgan.init_weights import weights_init_D
2
+ from Modules.ControllabilityGAN.wgan.init_weights import weights_init_G
3
+ from Modules.ControllabilityGAN.wgan.resnet_1 import ResNet_D
4
+ from Modules.ControllabilityGAN.wgan.resnet_1 import ResNet_G
5
 
6
 
7
  def init_resnet(parameters):
{Architectures → Modules}/ControllabilityGAN/wgan/wgan_qc.py RENAMED
@@ -3,7 +3,6 @@ import time
3
 
4
  import numpy as np
5
  import torch
6
- import torch.nn as nn
7
  import torch.optim as optim
8
  from cvxopt import matrix
9
  from cvxopt import solvers
@@ -11,13 +10,12 @@ from cvxopt import sparse
11
  from cvxopt import spmatrix
12
  from torch.autograd import grad as torch_grad
13
  from tqdm import tqdm
14
- import spaces
15
 
16
 
17
- class WassersteinGanQuadraticCost(torch.nn.Module):
18
 
19
- def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations, data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=[150000, 250000], lr_anneal=1.0, *args, **kwargs):
20
- super().__init__(*args, **kwargs)
21
  self.G = generator
22
  self.G_opt = gen_optimizer
23
  self.D = discriminator
@@ -46,8 +44,8 @@ class WassersteinGanQuadraticCost(torch.nn.Module):
46
  self.Kr = np.sqrt(self.K)
47
  self.LAMBDA = 2 * self.Kr * gamma * 2
48
 
49
- self.G = nn.DataParallel(self.G.to(self.device))
50
- self.D = nn.DataParallel(self.D.to(self.device))
51
 
52
  self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal)
53
  self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal)
@@ -245,10 +243,7 @@ class WassersteinGanQuadraticCost(torch.nn.Module):
245
  latent_samples = latent_samples.to(self.device)
246
  if nograd:
247
  with torch.no_grad():
248
- if isinstance(self.G, torch.nn.parallel.DataParallel):
249
- generated_data = self.G.module(latent_samples, return_intermediate=return_intermediate)
250
- else:
251
- generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
252
  else:
253
  generated_data = self.G(latent_samples)
254
  self.G.train()
 
3
 
4
  import numpy as np
5
  import torch
 
6
  import torch.optim as optim
7
  from cvxopt import matrix
8
  from cvxopt import solvers
 
10
  from cvxopt import spmatrix
11
  from torch.autograd import grad as torch_grad
12
  from tqdm import tqdm
 
13
 
14
 
15
+ class WassersteinGanQuadraticCost:
16
 
17
+ def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations,
18
+ data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=[150000, 250000], lr_anneal=1.0):
19
  self.G = generator
20
  self.G_opt = gen_optimizer
21
  self.D = discriminator
 
44
  self.Kr = np.sqrt(self.K)
45
  self.LAMBDA = 2 * self.Kr * gamma * 2
46
 
47
+ self.G = self.G.to(self.device)
48
+ self.D = self.D.to(self.device)
49
 
50
  self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal)
51
  self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal)
 
243
  latent_samples = latent_samples.to(self.device)
244
  if nograd:
245
  with torch.no_grad():
246
+ generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
 
 
 
247
  else:
248
  generated_data = self.G(latent_samples)
249
  self.G.train()
{Architectures → Modules}/EmbeddingModel/GST.py RENAMED
@@ -3,7 +3,7 @@
3
 
4
  import torch
5
 
6
- from Architectures.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention
7
 
8
 
9
  class GSTStyleEncoder(torch.nn.Module):
 
3
 
4
  import torch
5
 
6
+ from Modules.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention
7
 
8
 
9
  class GSTStyleEncoder(torch.nn.Module):
{Architectures → Modules}/EmbeddingModel/README.md RENAMED
File without changes
{Architectures → Modules}/EmbeddingModel/StyleEmbedding.py RENAMED
@@ -1,7 +1,7 @@
1
  import torch
2
 
3
- from Architectures.EmbeddingModel.GST import GSTStyleEncoder
4
- from Architectures.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder
5
 
6
 
7
  class StyleEmbedding(torch.nn.Module):
 
1
  import torch
2
 
3
+ from Modules.EmbeddingModel.GST import GSTStyleEncoder
4
+ from Modules.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder
5
 
6
 
7
  class StyleEmbedding(torch.nn.Module):
{Architectures → Modules}/EmbeddingModel/StyleTTSEncoder.py RENAMED
File without changes
{Architectures → Modules}/EmbeddingModel/__init__.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/Attention.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/ConditionalLayerNorm.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/Conformer.py RENAMED
@@ -4,16 +4,16 @@ Taken from ESPNet, but heavily modified
4
 
5
  import torch
6
 
7
- from Architectures.GeneralLayers.Attention import RelPositionMultiHeadedAttention
8
- from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
9
- from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
10
- from Architectures.GeneralLayers.Convolution import ConvolutionModule
11
- from Architectures.GeneralLayers.EncoderLayer import EncoderLayer
12
- from Architectures.GeneralLayers.LayerNorm import LayerNorm
13
- from Architectures.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d
14
- from Architectures.GeneralLayers.MultiSequential import repeat
15
- from Architectures.GeneralLayers.PositionalEncoding import RelPositionalEncoding
16
- from Architectures.GeneralLayers.Swish import Swish
17
  from Utility.utils import integrate_with_utt_embed
18
 
19
 
@@ -84,8 +84,12 @@ class Conformer(torch.nn.Module):
84
  self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: torch.nn.Linear(attention_dim + utt_embed, attention_dim))
85
  if lang_embs is not None:
86
  self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=lang_emb_size)
87
- self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim)
 
 
 
88
  self.language_emb_norm = LayerNorm(attention_dim)
 
89
  # self-attention module definition
90
  encoder_selfattn_layer = RelPositionMultiHeadedAttention
91
  encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)
@@ -138,21 +142,28 @@ class Conformer(torch.nn.Module):
138
  if isinstance(xs, tuple):
139
  x, pos_emb = xs[0], xs[1]
140
  if self.conformer_type != "encoder":
141
- x = integrate_with_utt_embed(hs=x, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration)
 
 
 
142
  xs = (x, pos_emb)
143
  else:
144
  if self.conformer_type != "encoder":
145
- xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration)
 
 
 
146
  xs, masks = encoder(xs, masks)
147
 
148
  if isinstance(xs, tuple):
149
  xs = xs[0]
150
 
151
- if self.use_output_norm and not (self.utt_embed and self.conformer_type == "encoder"):
152
- xs = self.output_norm(xs)
153
-
154
  if self.utt_embed and self.conformer_type == "encoder":
155
- xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding,
156
- projection=self.encoder_embedding_projection, embedding_training=self.use_conditional_layernorm_embedding_integration)
 
 
 
 
157
 
158
  return xs, masks
 
4
 
5
  import torch
6
 
7
+ from Modules.GeneralLayers.Attention import RelPositionMultiHeadedAttention
8
+ from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d
9
+ from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
10
+ from Modules.GeneralLayers.Convolution import ConvolutionModule
11
+ from Modules.GeneralLayers.EncoderLayer import EncoderLayer
12
+ from Modules.GeneralLayers.LayerNorm import LayerNorm
13
+ from Modules.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d
14
+ from Modules.GeneralLayers.MultiSequential import repeat
15
+ from Modules.GeneralLayers.PositionalEncoding import RelPositionalEncoding
16
+ from Modules.GeneralLayers.Swish import Swish
17
  from Utility.utils import integrate_with_utt_embed
18
 
19
 
 
84
  self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: torch.nn.Linear(attention_dim + utt_embed, attention_dim))
85
  if lang_embs is not None:
86
  self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=lang_emb_size)
87
+ if lang_emb_size == attention_dim:
88
+ self.language_embedding_projection = lambda x: x
89
+ else:
90
+ self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim)
91
  self.language_emb_norm = LayerNorm(attention_dim)
92
+
93
  # self-attention module definition
94
  encoder_selfattn_layer = RelPositionMultiHeadedAttention
95
  encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)
 
142
  if isinstance(xs, tuple):
143
  x, pos_emb = xs[0], xs[1]
144
  if self.conformer_type != "encoder":
145
+ x = integrate_with_utt_embed(hs=x,
146
+ utt_embeddings=utterance_embedding,
147
+ projection=self.decoder_embedding_projections[encoder_index],
148
+ embedding_training=self.use_conditional_layernorm_embedding_integration)
149
  xs = (x, pos_emb)
150
  else:
151
  if self.conformer_type != "encoder":
152
+ xs = integrate_with_utt_embed(hs=xs,
153
+ utt_embeddings=utterance_embedding,
154
+ projection=self.decoder_embedding_projections[encoder_index],
155
+ embedding_training=self.use_conditional_layernorm_embedding_integration)
156
  xs, masks = encoder(xs, masks)
157
 
158
  if isinstance(xs, tuple):
159
  xs = xs[0]
160
 
 
 
 
161
  if self.utt_embed and self.conformer_type == "encoder":
162
+ xs = integrate_with_utt_embed(hs=xs,
163
+ utt_embeddings=utterance_embedding,
164
+ projection=self.encoder_embedding_projection,
165
+ embedding_training=self.use_conditional_layernorm_embedding_integration)
166
+ elif self.use_output_norm:
167
+ xs = self.output_norm(xs)
168
 
169
  return xs, masks
{Architectures → Modules}/GeneralLayers/Convolution.py RENAMED
@@ -24,7 +24,7 @@ class ConvolutionModule(nn.Module):
24
 
25
  self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, )
26
  self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, )
27
- self.norm = nn.BatchNorm1d(channels)
28
  self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, )
29
  self.activation = activation
30
 
 
24
 
25
  self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, )
26
  self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, )
27
+ self.norm = nn.SyncBatchNorm.convert_sync_batchnorm(nn.BatchNorm1d(channels))
28
  self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, )
29
  self.activation = activation
30
 
{Architectures → Modules}/GeneralLayers/DurationPredictor.py RENAMED
@@ -5,9 +5,9 @@
5
 
6
  import torch
7
 
8
- from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
9
- from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
10
- from Architectures.GeneralLayers.LayerNorm import LayerNorm
11
  from Utility.utils import integrate_with_utt_embed
12
 
13
 
 
5
 
6
  import torch
7
 
8
+ from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d
9
+ from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
10
+ from Modules.GeneralLayers.LayerNorm import LayerNorm
11
  from Utility.utils import integrate_with_utt_embed
12
 
13
 
{Architectures → Modules}/GeneralLayers/EncoderLayer.py RENAMED
@@ -7,7 +7,7 @@
7
  import torch
8
  from torch import nn
9
 
10
- from Architectures.GeneralLayers.LayerNorm import LayerNorm
11
 
12
 
13
  class EncoderLayer(nn.Module):
 
7
  import torch
8
  from torch import nn
9
 
10
+ from Modules.GeneralLayers.LayerNorm import LayerNorm
11
 
12
 
13
  class EncoderLayer(nn.Module):
{Architectures → Modules}/GeneralLayers/LayerNorm.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/LengthRegulator.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/MultiLayeredConv1d.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/MultiSequential.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/PositionalEncoding.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/PositionwiseFeedForward.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/README.md RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/ResidualBlock.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/ResidualStack.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/STFT.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/Swish.py RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/VariancePredictor.py RENAMED
@@ -6,9 +6,9 @@ from abc import ABC
6
 
7
  import torch
8
 
9
- from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
10
- from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
11
- from Architectures.GeneralLayers.LayerNorm import LayerNorm
12
  from Utility.utils import integrate_with_utt_embed
13
 
14
 
 
6
 
7
  import torch
8
 
9
+ from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d
10
+ from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
11
+ from Modules.GeneralLayers.LayerNorm import LayerNorm
12
  from Utility.utils import integrate_with_utt_embed
13
 
14
 
{Architectures → Modules}/GeneralLayers/__init__.py RENAMED
File without changes
{Architectures → Modules}/README.md RENAMED
File without changes
{Architectures → Modules}/ToucanTTS/CodecDiscriminator.py RENAMED
File without changes