Spaces:
Runtime error
Runtime error
NorHsangPha
commited on
Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +36 -35
- .gitignore +35 -0
- Architectures/Aligner/Aligner.py +164 -0
- Architectures/Aligner/CodecAlignerDataset.py +274 -0
- Architectures/Aligner/README.md +1 -0
- Architectures/Aligner/Reconstructor.py +40 -0
- Architectures/Aligner/__init__.py +0 -0
- Architectures/Aligner/autoaligner_train_loop.py +188 -0
- Architectures/ControllabilityGAN/GAN.py +82 -0
- Architectures/ControllabilityGAN/__init__.py +0 -0
- Architectures/ControllabilityGAN/dataset/__init__.py +0 -0
- Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py +94 -0
- Architectures/ControllabilityGAN/wgan/__init__.py +0 -0
- Architectures/ControllabilityGAN/wgan/init_weights.py +21 -0
- Architectures/ControllabilityGAN/wgan/init_wgan.py +34 -0
- Architectures/ControllabilityGAN/wgan/resnet_1.py +181 -0
- Architectures/ControllabilityGAN/wgan/resnet_init.py +15 -0
- Architectures/ControllabilityGAN/wgan/wgan_qc.py +272 -0
- Architectures/EmbeddingModel/GST.py +235 -0
- Architectures/EmbeddingModel/README.md +1 -0
- Architectures/EmbeddingModel/StyleEmbedding.py +73 -0
- Architectures/EmbeddingModel/StyleTTSEncoder.py +156 -0
- Architectures/EmbeddingModel/__init__.py +0 -0
- Architectures/GeneralLayers/Attention.py +324 -0
- Architectures/GeneralLayers/ConditionalLayerNorm.py +118 -0
- Architectures/GeneralLayers/Conformer.py +158 -0
- Architectures/GeneralLayers/Convolution.py +55 -0
- Architectures/GeneralLayers/DurationPredictor.py +171 -0
- Architectures/GeneralLayers/EncoderLayer.py +144 -0
- Architectures/GeneralLayers/LayerNorm.py +36 -0
- Architectures/GeneralLayers/LengthRegulator.py +61 -0
- Architectures/GeneralLayers/MultiLayeredConv1d.py +87 -0
- Architectures/GeneralLayers/MultiSequential.py +33 -0
- Architectures/GeneralLayers/PositionalEncoding.py +166 -0
- Architectures/GeneralLayers/PositionwiseFeedForward.py +26 -0
- Architectures/GeneralLayers/README.md +2 -0
- Architectures/GeneralLayers/ResidualBlock.py +98 -0
- Architectures/GeneralLayers/ResidualStack.py +51 -0
- Architectures/GeneralLayers/STFT.py +123 -0
- Architectures/GeneralLayers/Swish.py +18 -0
- Architectures/GeneralLayers/VariancePredictor.py +98 -0
- Architectures/GeneralLayers/__init__.py +0 -0
- Architectures/README.md +2 -0
- Architectures/ToucanTTS/CodecDiscriminator.py +94 -0
- Architectures/ToucanTTS/CodecRefinementTransformer.py +199 -0
- Architectures/ToucanTTS/DurationCalculator.py +30 -0
- Architectures/ToucanTTS/EnergyCalculator.py +94 -0
- Architectures/ToucanTTS/Glow.py +402 -0
- Architectures/ToucanTTS/InferenceToucanTTS.py +375 -0
- Architectures/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py +74 -0
.gitattributes
CHANGED
@@ -1,35 +1,36 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
InferenceInterfaces/src/fonts/Shan.ttf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.idea/
|
2 |
+
tensorboard_logs/
|
3 |
+
Corpora/
|
4 |
+
Models/
|
5 |
+
audios/
|
6 |
+
Preprocessing/glottolog/
|
7 |
+
Preprocessing/multilinguality/datasets/
|
8 |
+
apex/
|
9 |
+
pretrained_models/
|
10 |
+
.tmp/
|
11 |
+
.vscode/
|
12 |
+
split/
|
13 |
+
singing/
|
14 |
+
toucan_conda_venv/
|
15 |
+
venv/
|
16 |
+
vis/
|
17 |
+
Preprocessing/multilinguality/distance_datasets
|
18 |
+
|
19 |
+
|
20 |
+
*_graph
|
21 |
+
gradio*
|
22 |
+
*playground*
|
23 |
+
run_phonemizer.py
|
24 |
+
|
25 |
+
*.pt
|
26 |
+
*.out
|
27 |
+
*.wav
|
28 |
+
*.flac
|
29 |
+
*.json
|
30 |
+
*.pyc
|
31 |
+
*.png
|
32 |
+
*.pdf
|
33 |
+
*.pkl
|
34 |
+
|
35 |
+
labs.ipynb
|
Architectures/Aligner/Aligner.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
taken and adapted from https://github.com/as-ideas/DeepForcedAligner
|
3 |
+
"""
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.multiprocessing
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.nn import CTCLoss
|
10 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
11 |
+
from torch.nn.utils.rnn import pad_packed_sequence
|
12 |
+
|
13 |
+
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
|
14 |
+
|
15 |
+
|
16 |
+
class BatchNormConv(nn.Module):
|
17 |
+
|
18 |
+
def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
|
19 |
+
super().__init__()
|
20 |
+
self.conv = nn.Conv1d(
|
21 |
+
in_channels, out_channels, kernel_size,
|
22 |
+
stride=1, padding=kernel_size // 2, bias=False)
|
23 |
+
self.bnorm = nn.BatchNorm1d(out_channels)
|
24 |
+
self.relu = nn.ReLU()
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = x.transpose(1, 2)
|
28 |
+
x = self.conv(x)
|
29 |
+
x = self.relu(x)
|
30 |
+
x = self.bnorm(x)
|
31 |
+
x = x.transpose(1, 2)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class Aligner(torch.nn.Module):
|
36 |
+
|
37 |
+
def __init__(self,
|
38 |
+
n_features=128,
|
39 |
+
num_symbols=145,
|
40 |
+
lstm_dim=512,
|
41 |
+
conv_dim=512):
|
42 |
+
super().__init__()
|
43 |
+
self.convs = nn.ModuleList([
|
44 |
+
BatchNormConv(n_features, conv_dim, 3),
|
45 |
+
nn.Dropout(p=0.5),
|
46 |
+
BatchNormConv(conv_dim, conv_dim, 3),
|
47 |
+
nn.Dropout(p=0.5),
|
48 |
+
BatchNormConv(conv_dim, conv_dim, 3),
|
49 |
+
nn.Dropout(p=0.5),
|
50 |
+
BatchNormConv(conv_dim, conv_dim, 3),
|
51 |
+
nn.Dropout(p=0.5),
|
52 |
+
BatchNormConv(conv_dim, conv_dim, 3),
|
53 |
+
nn.Dropout(p=0.5),
|
54 |
+
])
|
55 |
+
self.rnn = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
|
56 |
+
self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
|
57 |
+
self.tf = ArticulatoryCombinedTextFrontend(language="eng")
|
58 |
+
self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
|
59 |
+
self.vector_to_id = dict()
|
60 |
+
|
61 |
+
def forward(self, x, lens=None):
|
62 |
+
for conv in self.convs:
|
63 |
+
x = conv(x)
|
64 |
+
|
65 |
+
if lens is not None:
|
66 |
+
x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
|
67 |
+
x, _ = self.rnn(x)
|
68 |
+
if lens is not None:
|
69 |
+
x, _ = pad_packed_sequence(x, batch_first=True)
|
70 |
+
|
71 |
+
x = self.proj(x)
|
72 |
+
|
73 |
+
return x
|
74 |
+
|
75 |
+
@torch.inference_mode()
|
76 |
+
def inference(self, features, tokens, save_img_for_debug=None, train=False, pathfinding="MAS", return_ctc=False):
|
77 |
+
if not train:
|
78 |
+
tokens_indexed = self.tf.text_vectors_to_id_sequence(text_vector=tokens) # first we need to convert the articulatory vectors to IDs, so we can apply dijkstra or viterbi
|
79 |
+
tokens = np.asarray(tokens_indexed)
|
80 |
+
else:
|
81 |
+
tokens = tokens.cpu().detach().numpy()
|
82 |
+
|
83 |
+
pred = self(features.unsqueeze(0))
|
84 |
+
if return_ctc:
|
85 |
+
ctc_loss = self.ctc_loss(pred.transpose(0, 1).log_softmax(2), torch.LongTensor(tokens), torch.LongTensor([len(pred[0])]),
|
86 |
+
torch.LongTensor([len(tokens)])).item()
|
87 |
+
pred = pred.squeeze().cpu().detach().numpy()
|
88 |
+
pred_max = pred[:, tokens]
|
89 |
+
|
90 |
+
# run monotonic alignment search
|
91 |
+
|
92 |
+
alignment_matrix = binarize_alignment(pred_max)
|
93 |
+
|
94 |
+
if save_img_for_debug is not None:
|
95 |
+
phones = list()
|
96 |
+
for index in tokens:
|
97 |
+
for phone in self.tf.phone_to_id:
|
98 |
+
if self.tf.phone_to_id[phone] == index:
|
99 |
+
phones.append(phone)
|
100 |
+
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
|
101 |
+
|
102 |
+
ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
|
103 |
+
ax.set_ylabel("Mel-Frames")
|
104 |
+
ax.set_xticks(range(len(pred_max[0])))
|
105 |
+
ax.set_xticklabels(labels=phones)
|
106 |
+
ax.set_title("MAS Path")
|
107 |
+
|
108 |
+
plt.tight_layout()
|
109 |
+
fig.savefig(save_img_for_debug)
|
110 |
+
fig.clf()
|
111 |
+
plt.close()
|
112 |
+
|
113 |
+
if return_ctc:
|
114 |
+
return alignment_matrix, ctc_loss
|
115 |
+
return alignment_matrix
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
def binarize_alignment(alignment_prob):
|
120 |
+
"""
|
121 |
+
# Implementation by:
|
122 |
+
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py
|
123 |
+
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py
|
124 |
+
|
125 |
+
Binarizes alignment with MAS.
|
126 |
+
"""
|
127 |
+
# assumes features x text
|
128 |
+
opt = np.zeros_like(alignment_prob)
|
129 |
+
alignment_prob = alignment_prob + (np.abs(alignment_prob).max() + 1.0) # make all numbers positive and add an offset to avoid log of 0 later
|
130 |
+
alignment_prob * alignment_prob * (1.0 / alignment_prob.max()) # normalize to (0, 1]
|
131 |
+
attn_map = np.log(alignment_prob)
|
132 |
+
attn_map[0, 1:] = -np.inf
|
133 |
+
log_p = np.zeros_like(attn_map)
|
134 |
+
log_p[0, :] = attn_map[0, :]
|
135 |
+
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
|
136 |
+
for i in range(1, attn_map.shape[0]):
|
137 |
+
for j in range(attn_map.shape[1]): # for each text dim
|
138 |
+
prev_log = log_p[i - 1, j]
|
139 |
+
prev_j = j
|
140 |
+
if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
|
141 |
+
prev_log = log_p[i - 1, j - 1]
|
142 |
+
prev_j = j - 1
|
143 |
+
log_p[i, j] = attn_map[i, j] + prev_log
|
144 |
+
prev_ind[i, j] = prev_j
|
145 |
+
# now backtrack
|
146 |
+
curr_text_idx = attn_map.shape[1] - 1
|
147 |
+
for i in range(attn_map.shape[0] - 1, -1, -1):
|
148 |
+
opt[i, curr_text_idx] = 1
|
149 |
+
curr_text_idx = prev_ind[i, curr_text_idx]
|
150 |
+
opt[0, curr_text_idx] = 1
|
151 |
+
return opt
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == '__main__':
|
155 |
+
tf = ArticulatoryCombinedTextFrontend(language="eng")
|
156 |
+
from Preprocessing.HiFiCodecAudioPreprocessor import CodecAudioPreprocessor
|
157 |
+
|
158 |
+
cap = CodecAudioPreprocessor(input_sr=-1)
|
159 |
+
dummy_codebook_indexes = torch.randint(low=0, high=1023, size=[9, 20])
|
160 |
+
codebook_frames = cap.indexes_to_codec_frames(dummy_codebook_indexes)
|
161 |
+
alignment = Aligner().inference(codebook_frames.transpose(0, 1), tokens=tf.string_to_tensor("Hello world"))
|
162 |
+
print(alignment.shape)
|
163 |
+
plt.imshow(alignment, origin="lower", cmap="GnBu")
|
164 |
+
plt.show()
|
Architectures/Aligner/CodecAlignerDataset.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import librosa
|
5 |
+
import soundfile as sf
|
6 |
+
import torch
|
7 |
+
from speechbrain.pretrained import EncoderClassifier
|
8 |
+
from torch.multiprocessing import Manager
|
9 |
+
from torch.multiprocessing import Process
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from torchaudio.transforms import Resample
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
|
15 |
+
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
|
16 |
+
from Utility.storage_config import MODELS_DIR
|
17 |
+
|
18 |
+
|
19 |
+
class CodecAlignerDataset(Dataset):
|
20 |
+
|
21 |
+
def __init__(self,
|
22 |
+
path_to_transcript_dict,
|
23 |
+
cache_dir,
|
24 |
+
lang,
|
25 |
+
loading_processes,
|
26 |
+
device,
|
27 |
+
min_len_in_seconds=1,
|
28 |
+
max_len_in_seconds=15,
|
29 |
+
rebuild_cache=False,
|
30 |
+
verbose=False,
|
31 |
+
phone_input=False,
|
32 |
+
allow_unknown_symbols=False,
|
33 |
+
gpu_count=1,
|
34 |
+
rank=0):
|
35 |
+
self.gpu_count = gpu_count
|
36 |
+
self.rank = rank
|
37 |
+
if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
|
38 |
+
self._build_dataset_cache(path_to_transcript_dict=path_to_transcript_dict,
|
39 |
+
cache_dir=cache_dir,
|
40 |
+
lang=lang,
|
41 |
+
loading_processes=loading_processes,
|
42 |
+
device=device,
|
43 |
+
min_len_in_seconds=min_len_in_seconds,
|
44 |
+
max_len_in_seconds=max_len_in_seconds,
|
45 |
+
verbose=verbose,
|
46 |
+
phone_input=phone_input,
|
47 |
+
allow_unknown_symbols=allow_unknown_symbols,
|
48 |
+
gpu_count=gpu_count,
|
49 |
+
rank=rank)
|
50 |
+
self.lang = lang
|
51 |
+
self.device = device
|
52 |
+
self.cache_dir = cache_dir
|
53 |
+
self.tf = ArticulatoryCombinedTextFrontend(language=self.lang)
|
54 |
+
cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu')
|
55 |
+
self.speaker_embeddings = cache[2]
|
56 |
+
self.datapoints = cache[0]
|
57 |
+
if self.gpu_count > 1:
|
58 |
+
# we only keep a chunk of the dataset in memory to avoid redundancy. Which chunk, we figure out using the rank.
|
59 |
+
while len(self.datapoints) % self.gpu_count != 0:
|
60 |
+
self.datapoints.pop(-1) # a bit unfortunate, but if you're using multiple GPUs, you probably have a ton of datapoints anyway.
|
61 |
+
chunksize = int(len(self.datapoints) / self.gpu_count)
|
62 |
+
self.datapoints = self.datapoints[chunksize * self.rank:chunksize * (self.rank + 1)]
|
63 |
+
self.speaker_embeddings = self.speaker_embeddings[chunksize * self.rank:chunksize * (self.rank + 1)]
|
64 |
+
print(f"Loaded an Aligner dataset with {len(self.datapoints)} datapoints from {cache_dir}.")
|
65 |
+
|
66 |
+
def _build_dataset_cache(self,
|
67 |
+
path_to_transcript_dict,
|
68 |
+
cache_dir,
|
69 |
+
lang,
|
70 |
+
loading_processes,
|
71 |
+
device,
|
72 |
+
min_len_in_seconds=1,
|
73 |
+
max_len_in_seconds=15,
|
74 |
+
verbose=False,
|
75 |
+
phone_input=False,
|
76 |
+
allow_unknown_symbols=False,
|
77 |
+
gpu_count=1,
|
78 |
+
rank=0
|
79 |
+
):
|
80 |
+
if gpu_count != 1:
|
81 |
+
import sys
|
82 |
+
print("Please run the feature extraction using only a single GPU. Multi-GPU is only supported for training.")
|
83 |
+
sys.exit()
|
84 |
+
os.makedirs(cache_dir, exist_ok=True)
|
85 |
+
if type(path_to_transcript_dict) != dict:
|
86 |
+
path_to_transcript_dict = path_to_transcript_dict() # in this case we passed a function instead of the dict, so that the function isn't executed if not necessary.
|
87 |
+
torch.multiprocessing.set_start_method('spawn', force=True)
|
88 |
+
resource_manager = Manager()
|
89 |
+
self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
|
90 |
+
key_list = list(self.path_to_transcript_dict.keys())
|
91 |
+
with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
|
92 |
+
files_used_note.write(str(key_list))
|
93 |
+
fisher_yates_shuffle(key_list)
|
94 |
+
# build cache
|
95 |
+
print("... building dataset cache ...")
|
96 |
+
self.result_pool = resource_manager.list()
|
97 |
+
# make processes
|
98 |
+
key_splits = list()
|
99 |
+
process_list = list()
|
100 |
+
for i in range(loading_processes):
|
101 |
+
key_splits.append(
|
102 |
+
key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
|
103 |
+
for key_split in key_splits:
|
104 |
+
process_list.append(
|
105 |
+
Process(target=self._cache_builder_process,
|
106 |
+
args=(key_split,
|
107 |
+
lang,
|
108 |
+
min_len_in_seconds,
|
109 |
+
max_len_in_seconds,
|
110 |
+
verbose,
|
111 |
+
device,
|
112 |
+
phone_input,
|
113 |
+
allow_unknown_symbols),
|
114 |
+
daemon=True))
|
115 |
+
process_list[-1].start()
|
116 |
+
for process in process_list:
|
117 |
+
process.join()
|
118 |
+
print("pooling results...")
|
119 |
+
pooled_datapoints = list()
|
120 |
+
for chunk in self.result_pool:
|
121 |
+
for datapoint in chunk:
|
122 |
+
pooled_datapoints.append(datapoint) # unpack into a joint list
|
123 |
+
self.result_pool = pooled_datapoints
|
124 |
+
del pooled_datapoints
|
125 |
+
print("converting text to tensors...")
|
126 |
+
text_tensors = [torch.ShortTensor(x[0]) for x in self.result_pool] # turn everything back to tensors (had to turn it to np arrays to avoid multiprocessing issues)
|
127 |
+
print("converting speech to tensors...")
|
128 |
+
speech_tensors = [torch.ShortTensor(x[1]) for x in self.result_pool]
|
129 |
+
print("converting waves to tensors...")
|
130 |
+
norm_waves = [torch.Tensor(x[2]) for x in self.result_pool]
|
131 |
+
print("unpacking file list...")
|
132 |
+
filepaths = [x[3] for x in self.result_pool]
|
133 |
+
del self.result_pool
|
134 |
+
self.datapoints = list(zip(text_tensors, speech_tensors))
|
135 |
+
del text_tensors
|
136 |
+
del speech_tensors
|
137 |
+
print("done!")
|
138 |
+
|
139 |
+
# add speaker embeddings
|
140 |
+
self.speaker_embeddings = list()
|
141 |
+
speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
|
142 |
+
run_opts={"device": str(device)},
|
143 |
+
savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))
|
144 |
+
with torch.inference_mode():
|
145 |
+
for wave in tqdm(norm_waves):
|
146 |
+
self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
|
147 |
+
|
148 |
+
# save to cache
|
149 |
+
if len(self.datapoints) == 0:
|
150 |
+
raise RuntimeError # something went wrong and there are no datapoints
|
151 |
+
torch.save((self.datapoints, None, self.speaker_embeddings, filepaths),
|
152 |
+
os.path.join(cache_dir, "aligner_train_cache.pt"))
|
153 |
+
|
154 |
+
def _cache_builder_process(self,
|
155 |
+
path_list,
|
156 |
+
lang,
|
157 |
+
min_len,
|
158 |
+
max_len,
|
159 |
+
verbose,
|
160 |
+
device,
|
161 |
+
phone_input,
|
162 |
+
allow_unknown_symbols):
|
163 |
+
process_internal_dataset_chunk = list()
|
164 |
+
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
|
165 |
+
# careful: assumes 16kHz or 8kHz audio
|
166 |
+
silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
167 |
+
model='silero_vad',
|
168 |
+
force_reload=False,
|
169 |
+
onnx=False,
|
170 |
+
verbose=False)
|
171 |
+
(get_speech_timestamps,
|
172 |
+
save_audio,
|
173 |
+
read_audio,
|
174 |
+
VADIterator,
|
175 |
+
collect_chunks) = utils
|
176 |
+
torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets
|
177 |
+
# this to false globally during model loading rather than using inference mode or no_grad
|
178 |
+
silero_model = silero_model.to(device)
|
179 |
+
silence = torch.zeros([16000 // 4], device=device)
|
180 |
+
tf = ArticulatoryCombinedTextFrontend(language=lang)
|
181 |
+
_, sr = sf.read(path_list[0])
|
182 |
+
assumed_sr = sr
|
183 |
+
ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
|
184 |
+
resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
|
185 |
+
|
186 |
+
for path in tqdm(path_list):
|
187 |
+
if self.path_to_transcript_dict[path].strip() == "":
|
188 |
+
continue
|
189 |
+
|
190 |
+
try:
|
191 |
+
wave, sr = sf.read(path)
|
192 |
+
except:
|
193 |
+
print(f"Problem with an audio file: {path}")
|
194 |
+
continue
|
195 |
+
|
196 |
+
if len(wave.shape) > 1: # the audio is in stereo, so we need to merge the channels.
|
197 |
+
if len(wave[0]) == 2: # let's figure out whether the axes are switched, which seems to be the case sometimes
|
198 |
+
wave = wave.transpose() # if yes, we switch the axes into the order librosa's to_mono function expects.
|
199 |
+
wave = librosa.to_mono(wave)
|
200 |
+
|
201 |
+
if sr != assumed_sr:
|
202 |
+
assumed_sr = sr
|
203 |
+
ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
|
204 |
+
resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
|
205 |
+
print(f"{path} has a different sampling rate --> adapting the codec processor")
|
206 |
+
|
207 |
+
try:
|
208 |
+
norm_wave = resample(torch.tensor(wave).float().to(device))
|
209 |
+
except ValueError:
|
210 |
+
continue
|
211 |
+
dur_in_seconds = len(norm_wave) / 16000
|
212 |
+
if not (min_len <= dur_in_seconds <= max_len):
|
213 |
+
if verbose:
|
214 |
+
print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
|
215 |
+
continue
|
216 |
+
|
217 |
+
# remove silences from front and back, then add constant 1/4th second silences back to front and back
|
218 |
+
with torch.no_grad():
|
219 |
+
speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
|
220 |
+
try:
|
221 |
+
result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
|
222 |
+
except IndexError:
|
223 |
+
print("Audio might be too short to cut silences from front and back.")
|
224 |
+
continue
|
225 |
+
wave = torch.cat([silence, result, silence])
|
226 |
+
|
227 |
+
# raw audio preprocessing is done
|
228 |
+
transcript = self.path_to_transcript_dict[path]
|
229 |
+
|
230 |
+
try:
|
231 |
+
try:
|
232 |
+
cached_text = tf.string_to_tensor(transcript, handle_missing=False, input_phonemes=phone_input).squeeze(0).cpu().numpy()
|
233 |
+
except KeyError:
|
234 |
+
cached_text = tf.string_to_tensor(transcript, handle_missing=True, input_phonemes=phone_input).squeeze(0).cpu().numpy()
|
235 |
+
if not allow_unknown_symbols:
|
236 |
+
continue # we skip sentences with unknown symbols
|
237 |
+
except ValueError:
|
238 |
+
# this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
|
239 |
+
continue
|
240 |
+
except KeyError:
|
241 |
+
# this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
|
242 |
+
continue
|
243 |
+
|
244 |
+
cached_speech = ap.audio_to_codebook_indexes(audio=wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
|
245 |
+
process_internal_dataset_chunk.append([cached_text,
|
246 |
+
cached_speech,
|
247 |
+
result.cpu().detach().numpy(),
|
248 |
+
path])
|
249 |
+
self.result_pool.append(process_internal_dataset_chunk)
|
250 |
+
|
251 |
+
def __getitem__(self, index):
|
252 |
+
text_vector = self.datapoints[index][0]
|
253 |
+
tokens = self.tf.text_vectors_to_id_sequence(text_vector=text_vector)
|
254 |
+
tokens = torch.LongTensor(tokens)
|
255 |
+
token_len = torch.LongTensor([len(tokens)])
|
256 |
+
|
257 |
+
codes = self.datapoints[index][1]
|
258 |
+
if codes.size()[0] != 24: # no clue why this is sometimes the case
|
259 |
+
codes = codes.transpose(0, 1)
|
260 |
+
|
261 |
+
return tokens, \
|
262 |
+
token_len, \
|
263 |
+
codes, \
|
264 |
+
None, \
|
265 |
+
self.speaker_embeddings[index]
|
266 |
+
|
267 |
+
def __len__(self):
|
268 |
+
return len(self.datapoints)
|
269 |
+
|
270 |
+
|
271 |
+
def fisher_yates_shuffle(lst):
|
272 |
+
for i in range(len(lst) - 1, 0, -1):
|
273 |
+
j = random.randint(0, i)
|
274 |
+
lst[i], lst[j] = lst[j], lst[i]
|
Architectures/Aligner/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Everything that is concerned with training and using the aligner model is contained in this directory. It is recommended to use the universal aligner model that we supply in the GitHub releases.
|
Architectures/Aligner/Reconstructor.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.multiprocessing
|
3 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
4 |
+
from torch.nn.utils.rnn import pad_packed_sequence
|
5 |
+
|
6 |
+
from Utility.utils import make_non_pad_mask
|
7 |
+
|
8 |
+
|
9 |
+
class Reconstructor(torch.nn.Module):
|
10 |
+
|
11 |
+
def __init__(self,
|
12 |
+
n_features=128,
|
13 |
+
num_symbols=145,
|
14 |
+
speaker_embedding_dim=192,
|
15 |
+
lstm_dim=256):
|
16 |
+
super().__init__()
|
17 |
+
self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, lstm_dim)
|
18 |
+
self.rnn1 = torch.nn.LSTM(lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
|
19 |
+
self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
|
20 |
+
self.out_proj = torch.nn.Linear(2 * lstm_dim, n_features)
|
21 |
+
self.l1_criterion = torch.nn.L1Loss(reduction="none")
|
22 |
+
self.l2_criterion = torch.nn.MSELoss(reduction="none")
|
23 |
+
|
24 |
+
def forward(self, x, lens, ys):
|
25 |
+
x = self.in_proj(x)
|
26 |
+
x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
|
27 |
+
x, _ = self.rnn1(x)
|
28 |
+
x, _ = self.rnn2(x)
|
29 |
+
x, _ = pad_packed_sequence(x, batch_first=True)
|
30 |
+
x = self.out_proj(x)
|
31 |
+
out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device)
|
32 |
+
out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
|
33 |
+
out_weights /= ys.size(0) * ys.size(2)
|
34 |
+
l1_loss = self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
|
35 |
+
l2_loss = self.l2_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
|
36 |
+
return l1_loss + l2_loss
|
37 |
+
|
38 |
+
|
39 |
+
if __name__ == '__main__':
|
40 |
+
print(sum(p.numel() for p in Reconstructor().parameters() if p.requires_grad))
|
Architectures/Aligner/__init__.py
ADDED
File without changes
|
Architectures/Aligner/autoaligner_train_loop.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.multiprocessing
|
6 |
+
from torch.nn.utils.rnn import pad_sequence
|
7 |
+
from torch.optim import RAdam
|
8 |
+
from torch.utils.data.dataloader import DataLoader
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from Architectures.Aligner.Aligner import Aligner
|
12 |
+
from Architectures.Aligner.Reconstructor import Reconstructor
|
13 |
+
from Preprocessing.AudioPreprocessor import AudioPreprocessor
|
14 |
+
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
|
15 |
+
|
16 |
+
|
17 |
+
def collate_and_pad(batch):
|
18 |
+
# text, text_len, speech, speech_len, embed
|
19 |
+
return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
|
20 |
+
torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
|
21 |
+
[datapoint[2] for datapoint in batch],
|
22 |
+
None,
|
23 |
+
torch.stack([datapoint[4] for datapoint in batch]).squeeze())
|
24 |
+
|
25 |
+
|
26 |
+
def train_loop(train_dataset,
|
27 |
+
device,
|
28 |
+
save_directory,
|
29 |
+
batch_size,
|
30 |
+
steps,
|
31 |
+
path_to_checkpoint=None,
|
32 |
+
fine_tune=False,
|
33 |
+
resume=False,
|
34 |
+
debug_img_path=None,
|
35 |
+
use_reconstruction=True,
|
36 |
+
gpu_count=1,
|
37 |
+
rank=0,
|
38 |
+
steps_per_checkpoint=None):
|
39 |
+
"""
|
40 |
+
Args:
|
41 |
+
resume: whether to resume from the most recent checkpoint
|
42 |
+
steps: How many steps to train
|
43 |
+
path_to_checkpoint: reloads a checkpoint to continue training from there
|
44 |
+
fine_tune: whether to load everything from a checkpoint, or only the model parameters
|
45 |
+
train_dataset: Pytorch Dataset Object for train data
|
46 |
+
device: Device to put the loaded tensors on
|
47 |
+
save_directory: Where to save the checkpoints
|
48 |
+
batch_size: How many elements should be loaded at once
|
49 |
+
debug_img_path: where to put images of the training progress if desired
|
50 |
+
use_reconstruction: whether to use the auxiliary reconstruction procedure/loss, which can make the alignment sharper
|
51 |
+
"""
|
52 |
+
os.makedirs(save_directory, exist_ok=True)
|
53 |
+
torch.multiprocessing.set_sharing_strategy('file_system')
|
54 |
+
torch.multiprocessing.set_start_method('spawn', force=True)
|
55 |
+
|
56 |
+
if steps_per_checkpoint is None:
|
57 |
+
steps_per_checkpoint = len(train_dataset) // batch_size
|
58 |
+
ap = CodecAudioPreprocessor(input_sr=-1, device=device) # only used to transform features into continuous matrices
|
59 |
+
spectrogram_extractor = AudioPreprocessor(input_sr=16000, output_sr=16000, device=device)
|
60 |
+
|
61 |
+
asr_model = Aligner().to(device)
|
62 |
+
optim_asr = RAdam(asr_model.parameters(), lr=0.0001)
|
63 |
+
|
64 |
+
tiny_tts = Reconstructor().to(device)
|
65 |
+
optim_tts = RAdam(tiny_tts.parameters(), lr=0.0001)
|
66 |
+
|
67 |
+
if gpu_count > 1:
|
68 |
+
asr_model.to(rank)
|
69 |
+
tiny_tts.to(rank)
|
70 |
+
asr_model = torch.nn.parallel.DistributedDataParallel(
|
71 |
+
asr_model,
|
72 |
+
device_ids=[rank],
|
73 |
+
output_device=rank,
|
74 |
+
find_unused_parameters=True,
|
75 |
+
).module
|
76 |
+
tiny_tts = torch.nn.parallel.DistributedDataParallel(
|
77 |
+
tiny_tts,
|
78 |
+
device_ids=[rank],
|
79 |
+
output_device=rank,
|
80 |
+
find_unused_parameters=True,
|
81 |
+
).module
|
82 |
+
torch.distributed.barrier()
|
83 |
+
train_sampler = torch.utils.data.RandomSampler(train_dataset)
|
84 |
+
batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
|
85 |
+
|
86 |
+
train_loader = DataLoader(dataset=train_dataset,
|
87 |
+
num_workers=0, # unfortunately necessary for big data due to mmap errors
|
88 |
+
batch_sampler=batch_sampler_train,
|
89 |
+
prefetch_factor=None,
|
90 |
+
collate_fn=collate_and_pad)
|
91 |
+
|
92 |
+
step_counter = 0
|
93 |
+
loss_sum = list()
|
94 |
+
|
95 |
+
if resume:
|
96 |
+
previous_checkpoint = os.path.join(save_directory, "aligner.pt")
|
97 |
+
path_to_checkpoint = previous_checkpoint
|
98 |
+
fine_tune = False
|
99 |
+
|
100 |
+
if path_to_checkpoint is not None:
|
101 |
+
check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device)
|
102 |
+
asr_model.load_state_dict(check_dict["asr_model"])
|
103 |
+
tiny_tts.load_state_dict(check_dict["tts_model"])
|
104 |
+
if not fine_tune:
|
105 |
+
optim_asr.load_state_dict(check_dict["optimizer"])
|
106 |
+
optim_tts.load_state_dict(check_dict["tts_optimizer"])
|
107 |
+
step_counter = check_dict["step_counter"]
|
108 |
+
if step_counter > steps:
|
109 |
+
print("Desired steps already reached in loaded checkpoint.")
|
110 |
+
return
|
111 |
+
start_time = time.time()
|
112 |
+
|
113 |
+
while True:
|
114 |
+
asr_model.train()
|
115 |
+
tiny_tts.train()
|
116 |
+
for batch in tqdm(train_loader):
|
117 |
+
tokens = batch[0].to(device)
|
118 |
+
tokens_len = batch[1].to(device)
|
119 |
+
speaker_embeddings = batch[4].to(device)
|
120 |
+
|
121 |
+
mels = list()
|
122 |
+
mel_lengths = list()
|
123 |
+
for datapoint in batch[2]:
|
124 |
+
with torch.inference_mode():
|
125 |
+
# extremely unfortunate that we have to do this over here, but multiprocessing and this don't go together well
|
126 |
+
speech = ap.indexes_to_audio(datapoint.int().to(device))
|
127 |
+
mel = spectrogram_extractor.audio_to_mel_spec_tensor(speech, explicit_sampling_rate=16000).transpose(0, 1).cpu()
|
128 |
+
speech_len = torch.LongTensor([len(mel)])
|
129 |
+
mels.append(mel.clone())
|
130 |
+
mel_lengths.append(speech_len)
|
131 |
+
mel = pad_sequence(mels, batch_first=True).to(device)
|
132 |
+
mel_len = torch.stack(mel_lengths).squeeze(1).to(device)
|
133 |
+
|
134 |
+
pred = asr_model(mel, mel_len)
|
135 |
+
|
136 |
+
ctc_loss = asr_model.ctc_loss(pred.transpose(0, 1).log_softmax(2),
|
137 |
+
tokens,
|
138 |
+
mel_len,
|
139 |
+
tokens_len)
|
140 |
+
|
141 |
+
if use_reconstruction:
|
142 |
+
speaker_embeddings_expanded = torch.nn.functional.normalize(speaker_embeddings).unsqueeze(1).expand(-1, pred.size(1), -1)
|
143 |
+
tts_lambda = min([0.1, step_counter / 10000]) # super simple schedule
|
144 |
+
reconstruction_loss = tiny_tts(x=torch.cat([pred, speaker_embeddings_expanded], dim=-1),
|
145 |
+
# combine ASR prediction with speaker embeddings to allow for reconstruction loss on multiple speakers
|
146 |
+
lens=mel_len,
|
147 |
+
ys=mel) * tts_lambda # reconstruction loss to make the states more distinct
|
148 |
+
loss = ctc_loss + reconstruction_loss
|
149 |
+
else:
|
150 |
+
loss = ctc_loss
|
151 |
+
|
152 |
+
optim_asr.zero_grad()
|
153 |
+
if use_reconstruction:
|
154 |
+
optim_tts.zero_grad()
|
155 |
+
loss.backward()
|
156 |
+
torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0)
|
157 |
+
if use_reconstruction:
|
158 |
+
torch.nn.utils.clip_grad_norm_(tiny_tts.parameters(), 1.0)
|
159 |
+
optim_asr.step()
|
160 |
+
if use_reconstruction:
|
161 |
+
optim_tts.step()
|
162 |
+
|
163 |
+
loss_sum.append(loss.item())
|
164 |
+
step_counter += 1
|
165 |
+
|
166 |
+
if step_counter % steps_per_checkpoint == 0 and rank == 0:
|
167 |
+
asr_model.eval()
|
168 |
+
torch.save({
|
169 |
+
"asr_model" : asr_model.state_dict(),
|
170 |
+
"optimizer" : optim_asr.state_dict(),
|
171 |
+
"tts_model" : tiny_tts.state_dict(),
|
172 |
+
"tts_optimizer": optim_tts.state_dict(),
|
173 |
+
"step_counter" : step_counter,
|
174 |
+
},
|
175 |
+
os.path.join(save_directory, "aligner.pt"))
|
176 |
+
print("Total Loss: {}".format(round(sum(loss_sum) / len(loss_sum), 3)))
|
177 |
+
print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
|
178 |
+
print("Steps: {}".format(step_counter))
|
179 |
+
if debug_img_path is not None:
|
180 |
+
asr_model.inference(features=mel[0][:mel_len[0]],
|
181 |
+
tokens=tokens[0][:tokens_len[0]],
|
182 |
+
save_img_for_debug=debug_img_path + f"/{step_counter}.png",
|
183 |
+
train=True) # for testing
|
184 |
+
asr_model.train()
|
185 |
+
loss_sum = list()
|
186 |
+
|
187 |
+
if step_counter > steps and step_counter % steps_per_checkpoint == 0:
|
188 |
+
return
|
Architectures/ControllabilityGAN/GAN.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from Architectures.ControllabilityGAN.wgan.init_wgan import create_wgan
|
4 |
+
|
5 |
+
|
6 |
+
class GanWrapper:
|
7 |
+
|
8 |
+
def __init__(self, path_wgan, device):
|
9 |
+
self.device = device
|
10 |
+
self.path_wgan = path_wgan
|
11 |
+
|
12 |
+
self.mean = None
|
13 |
+
self.std = None
|
14 |
+
self.wgan = None
|
15 |
+
self.normalize = True
|
16 |
+
|
17 |
+
self.load_model(path_wgan)
|
18 |
+
|
19 |
+
self.U = self.compute_controllability()
|
20 |
+
|
21 |
+
self.z_list = list()
|
22 |
+
for _ in range(1100):
|
23 |
+
self.z_list.append(self.wgan.G.module.sample_latent(1, 32))
|
24 |
+
self.z = self.z_list[0]
|
25 |
+
|
26 |
+
def set_latent(self, seed):
|
27 |
+
self.z = self.z = self.z_list[seed]
|
28 |
+
|
29 |
+
def reset_default_latent(self):
|
30 |
+
self.z = self.wgan.G.module.sample_latent(1, 32)
|
31 |
+
|
32 |
+
def load_model(self, path):
|
33 |
+
gan_checkpoint = torch.load(path, map_location="cpu")
|
34 |
+
|
35 |
+
self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device)
|
36 |
+
self.wgan.G.load_state_dict(gan_checkpoint['generator_state_dict'])
|
37 |
+
self.wgan.D.load_state_dict(gan_checkpoint['critic_state_dict'])
|
38 |
+
|
39 |
+
self.mean = gan_checkpoint["dataset_mean"]
|
40 |
+
self.std = gan_checkpoint["dataset_std"]
|
41 |
+
|
42 |
+
def compute_controllability(self, n_samples=50000):
|
43 |
+
_, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
|
44 |
+
intermediate = intermediate.cpu()
|
45 |
+
z = z.cpu()
|
46 |
+
U = self.controllable_speakers(intermediate, z)
|
47 |
+
return U
|
48 |
+
|
49 |
+
def controllable_speakers(self, intermediate, z):
|
50 |
+
pca = torch.pca_lowrank(intermediate)
|
51 |
+
mu = intermediate.mean()
|
52 |
+
X = torch.matmul((intermediate - mu), pca[2])
|
53 |
+
U = torch.linalg.lstsq(X, z)
|
54 |
+
return U
|
55 |
+
|
56 |
+
def get_original_embed(self):
|
57 |
+
self.wgan.G.eval()
|
58 |
+
embed_original = self.wgan.G.module.forward(self.z.to(self.device))
|
59 |
+
|
60 |
+
if self.normalize:
|
61 |
+
embed_original = inverse_normalize(
|
62 |
+
embed_original.cpu(),
|
63 |
+
self.mean.cpu().unsqueeze(0),
|
64 |
+
self.std.cpu().unsqueeze(0)
|
65 |
+
)
|
66 |
+
return embed_original
|
67 |
+
|
68 |
+
def modify_embed(self, x):
|
69 |
+
self.wgan.G.eval()
|
70 |
+
z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x)
|
71 |
+
embed_modified = self.wgan.G.module.forward(z_new.unsqueeze(0).to(self.device))
|
72 |
+
if self.normalize:
|
73 |
+
embed_modified = inverse_normalize(
|
74 |
+
embed_modified.cpu(),
|
75 |
+
self.mean.cpu().unsqueeze(0),
|
76 |
+
self.std.cpu().unsqueeze(0)
|
77 |
+
)
|
78 |
+
return embed_modified
|
79 |
+
|
80 |
+
|
81 |
+
def inverse_normalize(tensor, mean, std):
|
82 |
+
return tensor * std + mean
|
Architectures/ControllabilityGAN/__init__.py
ADDED
File without changes
|
Architectures/ControllabilityGAN/dataset/__init__.py
ADDED
File without changes
|
Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class SpeakerEmbeddingsDataset(torch.utils.data.Dataset):
|
8 |
+
|
9 |
+
def __init__(self, feature_path, device, mode='utterance'):
|
10 |
+
super(SpeakerEmbeddingsDataset, self).__init__()
|
11 |
+
|
12 |
+
modes = ['utterance', 'speaker']
|
13 |
+
assert mode in modes, f'mode: {mode} is not supported'
|
14 |
+
if mode == 'utterance':
|
15 |
+
self.mode = 'utt'
|
16 |
+
elif mode == 'speaker':
|
17 |
+
self.mode = 'spk'
|
18 |
+
|
19 |
+
self.device = device
|
20 |
+
|
21 |
+
self.x, self.speakers = self._load_features(feature_path)
|
22 |
+
# unique_speakers = set(self.speakers)
|
23 |
+
# spk2class = dict(zip(unique_speakers, range(len(unique_speakers))))
|
24 |
+
# #self.x = self._reformat_features(self.x)
|
25 |
+
# self.y = torch.tensor([spk2class[spk] for spk in self.speakers]).to(self.device)
|
26 |
+
# self.class2spk = dict(zip(spk2class.values(), spk2class.keys()))
|
27 |
+
|
28 |
+
def __len__(self):
|
29 |
+
return len(self.speakers)
|
30 |
+
|
31 |
+
def __getitem__(self, index):
|
32 |
+
embedding = self.normalize_embedding(self.x[index])
|
33 |
+
# speaker_id = self.y[index]
|
34 |
+
return embedding, torch.zeros([0])
|
35 |
+
|
36 |
+
def normalize_embedding(self, vector):
|
37 |
+
return torch.sub(vector, self.mean) / self.std
|
38 |
+
|
39 |
+
def get_speaker(self, label):
|
40 |
+
return self.class2spk[label]
|
41 |
+
|
42 |
+
def get_embedding_dim(self):
|
43 |
+
return self.x.shape[-1]
|
44 |
+
|
45 |
+
def get_num_speaker(self):
|
46 |
+
return len(torch.unique((self.y)))
|
47 |
+
|
48 |
+
def set_labels(self, labels):
|
49 |
+
self.y_old = self.y
|
50 |
+
self.y = torch.full(size=(len(self),), fill_value=labels).to(self.device)
|
51 |
+
# if isinstance(labels, int) or isinstance(labels, float):
|
52 |
+
# self.y = torch.full(size=len(self), fill_value=labels)
|
53 |
+
# elif len(labels) == len(self):
|
54 |
+
# self.y = torch.tensor(labels)
|
55 |
+
|
56 |
+
def _load_features(self, feature_path):
|
57 |
+
if os.path.isfile(feature_path):
|
58 |
+
vectors = torch.load(feature_path, map_location=self.device)
|
59 |
+
if isinstance(vectors, list):
|
60 |
+
vectors = torch.stack(vectors)
|
61 |
+
|
62 |
+
self.mean = torch.mean(vectors)
|
63 |
+
self.std = torch.std(vectors)
|
64 |
+
return vectors, torch.zeros(vectors.size(0))
|
65 |
+
else:
|
66 |
+
vectors = torch.load(feature_path, map_location=self.device)
|
67 |
+
|
68 |
+
self.mean = torch.mean(vectors)
|
69 |
+
self.std = torch.std(vectors)
|
70 |
+
|
71 |
+
spk2idx = {}
|
72 |
+
with open(feature_path / f'{self.mode}2idx', 'r') as f:
|
73 |
+
for line in f:
|
74 |
+
split_line = line.strip().split()
|
75 |
+
if len(split_line) == 2:
|
76 |
+
spk2idx[split_line[0].strip()] = int(split_line[1])
|
77 |
+
|
78 |
+
speakers, indices = zip(*spk2idx.items())
|
79 |
+
|
80 |
+
if (feature_path / 'utt2spk').exists(): # spk2idx contains utt_ids not speaker_ids
|
81 |
+
utt2spk = {}
|
82 |
+
with open(feature_path / 'utt2spk', 'r') as f:
|
83 |
+
for line in f:
|
84 |
+
split_line = line.strip().split()
|
85 |
+
if len(split_line) == 2:
|
86 |
+
utt2spk[split_line[0].strip()] = split_line[1].strip()
|
87 |
+
|
88 |
+
speakers = [utt2spk[utt] for utt in speakers]
|
89 |
+
|
90 |
+
return vectors[np.array(indices)], speakers
|
91 |
+
|
92 |
+
def _reformat_features(self, features):
|
93 |
+
if len(features.shape) == 2:
|
94 |
+
return features.reshape(features.shape[0], 1, 1, features.shape[1])
|
Architectures/ControllabilityGAN/wgan/__init__.py
ADDED
File without changes
|
Architectures/ControllabilityGAN/wgan/init_weights.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def weights_init_D(m):
|
5 |
+
classname = m.__class__.__name__
|
6 |
+
if classname.find('Conv') != -1:
|
7 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
|
8 |
+
# nn.init.constant_(m.bias, 0)
|
9 |
+
elif classname.find('BatchNorm') != -1:
|
10 |
+
nn.init.constant_(m.weight, 1)
|
11 |
+
nn.init.constant_(m.bias, 0)
|
12 |
+
|
13 |
+
|
14 |
+
def weights_init_G(m):
|
15 |
+
classname = m.__class__.__name__
|
16 |
+
if classname.find('Conv') != -1:
|
17 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
|
18 |
+
# nn.init.constant_(m.bias, 0)
|
19 |
+
elif classname.find('BatchNorm') != -1:
|
20 |
+
nn.init.constant_(m.weight, 1)
|
21 |
+
nn.init.constant_(m.bias, 0)
|
Architectures/ControllabilityGAN/wgan/init_wgan.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from Architectures.ControllabilityGAN.wgan.resnet_init import init_resnet
|
4 |
+
from Architectures.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost
|
5 |
+
|
6 |
+
|
7 |
+
def create_wgan(parameters, device, optimizer='adam'):
|
8 |
+
if parameters['model'] == "resnet":
|
9 |
+
generator, discriminator = init_resnet(parameters)
|
10 |
+
else:
|
11 |
+
raise NotImplementedError
|
12 |
+
|
13 |
+
if optimizer == 'adam':
|
14 |
+
optimizer_g = torch.optim.Adam(generator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas'])
|
15 |
+
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas'])
|
16 |
+
elif optimizer == 'rmsprop':
|
17 |
+
optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=parameters['learning_rate'])
|
18 |
+
optimizer_d = torch.optim.RMSprop(generator.parameters(), lr=parameters['learning_rate'])
|
19 |
+
|
20 |
+
criterion = torch.nn.MSELoss()
|
21 |
+
|
22 |
+
gan = WassersteinGanQuadraticCost(generator,
|
23 |
+
discriminator,
|
24 |
+
optimizer_g,
|
25 |
+
optimizer_d,
|
26 |
+
criterion=criterion,
|
27 |
+
data_dimensions=parameters['data_dim'],
|
28 |
+
epochs=parameters['epochs'],
|
29 |
+
batch_size=parameters['batch_size'],
|
30 |
+
device=device,
|
31 |
+
n_max_iterations=parameters['n_max_iterations'],
|
32 |
+
gamma=parameters['gamma'])
|
33 |
+
|
34 |
+
return gan
|
Architectures/ControllabilityGAN/wgan/resnet_1.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.utils.data
|
4 |
+
import torch.utils.data.distributed
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
|
8 |
+
class ResNet_G(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, data_dim, z_dim, size, nfilter=64, nfilter_max=512, bn=True, res_ratio=0.1, **kwargs):
|
11 |
+
super().__init__()
|
12 |
+
self.input_dim = z_dim
|
13 |
+
self.output_dim = z_dim
|
14 |
+
self.dropout_rate = 0
|
15 |
+
|
16 |
+
s0 = self.s0 = 4
|
17 |
+
nf = self.nf = nfilter
|
18 |
+
nf_max = self.nf_max = nfilter_max
|
19 |
+
self.bn = bn
|
20 |
+
self.z_dim = z_dim
|
21 |
+
|
22 |
+
# Submodules
|
23 |
+
nlayers = int(np.log2(size / s0))
|
24 |
+
self.nf0 = min(nf_max, nf * 2 ** (nlayers + 1))
|
25 |
+
|
26 |
+
self.fc = nn.Linear(z_dim, self.nf0 * s0 * s0)
|
27 |
+
if self.bn:
|
28 |
+
self.bn1d = nn.BatchNorm1d(self.nf0 * s0 * s0)
|
29 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
30 |
+
|
31 |
+
blocks = []
|
32 |
+
for i in range(nlayers, 0, -1):
|
33 |
+
nf0 = min(nf * 2 ** (i + 1), nf_max)
|
34 |
+
nf1 = min(nf * 2 ** i, nf_max)
|
35 |
+
blocks += [
|
36 |
+
ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio),
|
37 |
+
nn.Upsample(scale_factor=2)
|
38 |
+
]
|
39 |
+
|
40 |
+
nf0 = min(nf * 2, nf_max)
|
41 |
+
nf1 = min(nf, nf_max)
|
42 |
+
blocks += [
|
43 |
+
ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio),
|
44 |
+
ResNetBlock(nf1, nf1, bn=self.bn, res_ratio=res_ratio)
|
45 |
+
]
|
46 |
+
|
47 |
+
self.resnet = nn.Sequential(*blocks)
|
48 |
+
self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
|
49 |
+
|
50 |
+
self.fc_out = nn.Linear(3 * size * size, data_dim)
|
51 |
+
|
52 |
+
def forward(self, z, return_intermediate=False):
|
53 |
+
# print(z.shape)
|
54 |
+
batch_size = z.size(0)
|
55 |
+
# z = z.view(batch_size, -1)
|
56 |
+
out = self.fc(z)
|
57 |
+
if self.bn:
|
58 |
+
out = self.bn1d(out)
|
59 |
+
out = self.relu(out)
|
60 |
+
if return_intermediate:
|
61 |
+
l_1 = out.detach().clone()
|
62 |
+
out = out.view(batch_size, self.nf0, self.s0, self.s0)
|
63 |
+
# print(out.shape)
|
64 |
+
|
65 |
+
out = self.resnet(out)
|
66 |
+
|
67 |
+
# print(out.shape)
|
68 |
+
# out = out.view(batch_size, self.nf0*self.s0*self.s0*2)
|
69 |
+
|
70 |
+
out = self.conv_img(out)
|
71 |
+
out = self.relu(out)
|
72 |
+
out.flatten(1)
|
73 |
+
out = self.fc_out(out.flatten(1))
|
74 |
+
|
75 |
+
if return_intermediate:
|
76 |
+
return out, l_1
|
77 |
+
return out
|
78 |
+
|
79 |
+
def sample_latent(self, n_samples, z_size):
|
80 |
+
return torch.randn((n_samples, z_size))
|
81 |
+
|
82 |
+
|
83 |
+
class ResNet_D(nn.Module):
|
84 |
+
|
85 |
+
def __init__(self, data_dim, size, nfilter=64, nfilter_max=512, res_ratio=0.1):
|
86 |
+
super().__init__()
|
87 |
+
s0 = self.s0 = 4
|
88 |
+
nf = self.nf = nfilter
|
89 |
+
nf_max = self.nf_max = nfilter_max
|
90 |
+
self.size = size
|
91 |
+
|
92 |
+
# Submodules
|
93 |
+
nlayers = int(np.log2(size / s0))
|
94 |
+
self.nf0 = min(nf_max, nf * 2 ** nlayers)
|
95 |
+
|
96 |
+
nf0 = min(nf, nf_max)
|
97 |
+
nf1 = min(nf * 2, nf_max)
|
98 |
+
blocks = [
|
99 |
+
ResNetBlock(nf0, nf0, bn=False, res_ratio=res_ratio),
|
100 |
+
ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio)
|
101 |
+
]
|
102 |
+
|
103 |
+
self.fc_input = nn.Linear(data_dim, 3 * size * size)
|
104 |
+
|
105 |
+
for i in range(1, nlayers + 1):
|
106 |
+
nf0 = min(nf * 2 ** i, nf_max)
|
107 |
+
nf1 = min(nf * 2 ** (i + 1), nf_max)
|
108 |
+
blocks += [
|
109 |
+
nn.AvgPool2d(3, stride=2, padding=1),
|
110 |
+
ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio),
|
111 |
+
]
|
112 |
+
|
113 |
+
self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1)
|
114 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
115 |
+
self.resnet = nn.Sequential(*blocks)
|
116 |
+
|
117 |
+
self.fc = nn.Linear(self.nf0 * s0 * s0, 1)
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
batch_size = x.size(0)
|
121 |
+
|
122 |
+
out = self.fc_input(x)
|
123 |
+
out = self.relu(out).view(batch_size, 3, self.size, self.size)
|
124 |
+
|
125 |
+
out = self.relu((self.conv_img(out)))
|
126 |
+
out = self.resnet(out)
|
127 |
+
out = out.view(batch_size, self.nf0 * self.s0 * self.s0)
|
128 |
+
out = self.fc(out)
|
129 |
+
|
130 |
+
return out
|
131 |
+
|
132 |
+
|
133 |
+
class ResNetBlock(nn.Module):
|
134 |
+
|
135 |
+
def __init__(self, fin, fout, fhidden=None, bn=True, res_ratio=0.1):
|
136 |
+
super().__init__()
|
137 |
+
# Attributes
|
138 |
+
self.bn = bn
|
139 |
+
self.is_bias = not bn
|
140 |
+
self.learned_shortcut = (fin != fout)
|
141 |
+
self.fin = fin
|
142 |
+
self.fout = fout
|
143 |
+
if fhidden is None:
|
144 |
+
self.fhidden = min(fin, fout)
|
145 |
+
else:
|
146 |
+
self.fhidden = fhidden
|
147 |
+
self.res_ratio = res_ratio
|
148 |
+
|
149 |
+
# Submodules
|
150 |
+
self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1, bias=self.is_bias)
|
151 |
+
if self.bn:
|
152 |
+
self.bn2d_0 = nn.BatchNorm2d(self.fhidden)
|
153 |
+
self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=self.is_bias)
|
154 |
+
if self.bn:
|
155 |
+
self.bn2d_1 = nn.BatchNorm2d(self.fout)
|
156 |
+
if self.learned_shortcut:
|
157 |
+
self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)
|
158 |
+
if self.bn:
|
159 |
+
self.bn2d_s = nn.BatchNorm2d(self.fout)
|
160 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
161 |
+
|
162 |
+
def forward(self, x):
|
163 |
+
x_s = self._shortcut(x)
|
164 |
+
dx = self.conv_0(x)
|
165 |
+
if self.bn:
|
166 |
+
dx = self.bn2d_0(dx)
|
167 |
+
dx = self.relu(dx)
|
168 |
+
dx = self.conv_1(dx)
|
169 |
+
if self.bn:
|
170 |
+
dx = self.bn2d_1(dx)
|
171 |
+
out = self.relu(x_s + self.res_ratio * dx)
|
172 |
+
return out
|
173 |
+
|
174 |
+
def _shortcut(self, x):
|
175 |
+
if self.learned_shortcut:
|
176 |
+
x_s = self.conv_s(x)
|
177 |
+
if self.bn:
|
178 |
+
x_s = self.bn2d_s(x_s)
|
179 |
+
else:
|
180 |
+
x_s = x
|
181 |
+
return x_s
|
Architectures/ControllabilityGAN/wgan/resnet_init.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_D
|
2 |
+
from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_G
|
3 |
+
from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_D
|
4 |
+
from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_G
|
5 |
+
|
6 |
+
|
7 |
+
def init_resnet(parameters):
|
8 |
+
critic = ResNet_D(parameters['data_dim'][-1], parameters['size'], nfilter=parameters['nfilter'], nfilter_max=parameters['nfilter_max'])
|
9 |
+
generator = ResNet_G(parameters['data_dim'][-1], parameters['z_dim'], parameters['size'], nfilter=parameters['nfilter'],
|
10 |
+
nfilter_max=parameters['nfilter_max'])
|
11 |
+
|
12 |
+
generator.apply(weights_init_G)
|
13 |
+
critic.apply(weights_init_D)
|
14 |
+
|
15 |
+
return generator, critic
|
Architectures/ControllabilityGAN/wgan/wgan_qc.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.optim as optim
|
8 |
+
from cvxopt import matrix
|
9 |
+
from cvxopt import solvers
|
10 |
+
from cvxopt import sparse
|
11 |
+
from cvxopt import spmatrix
|
12 |
+
from torch.autograd import grad as torch_grad
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
class WassersteinGanQuadraticCost:
|
17 |
+
|
18 |
+
def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations,
|
19 |
+
data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=[150000, 250000], lr_anneal=1.0):
|
20 |
+
self.G = generator
|
21 |
+
self.G_opt = gen_optimizer
|
22 |
+
self.D = discriminator
|
23 |
+
self.D_opt = dis_optimizer
|
24 |
+
self.losses = {
|
25 |
+
'D' : [],
|
26 |
+
'WD': [],
|
27 |
+
'G' : []
|
28 |
+
}
|
29 |
+
self.num_steps = 0
|
30 |
+
self.gen_steps = 0
|
31 |
+
self.epochs = epochs
|
32 |
+
self.n_max_iterations = n_max_iterations
|
33 |
+
# put in the shape of a dataset sample
|
34 |
+
self.data_dim = data_dimensions[0] * data_dimensions[1] * data_dimensions[2]
|
35 |
+
self.batch_size = batch_size
|
36 |
+
self.device = device
|
37 |
+
self.criterion = criterion
|
38 |
+
self.mone = torch.FloatTensor([-1]).to(device)
|
39 |
+
self.tensorboard_counter = 0
|
40 |
+
|
41 |
+
if K <= 0:
|
42 |
+
self.K = 1 / self.data_dim
|
43 |
+
else:
|
44 |
+
self.K = K
|
45 |
+
self.Kr = np.sqrt(self.K)
|
46 |
+
self.LAMBDA = 2 * self.Kr * gamma * 2
|
47 |
+
|
48 |
+
self.G = nn.DataParallel(self.G.to(self.device))
|
49 |
+
self.D = nn.DataParallel(self.D.to(self.device))
|
50 |
+
|
51 |
+
self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal)
|
52 |
+
self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal)
|
53 |
+
|
54 |
+
self.c, self.A, self.pStart = self._prepare_linear_programming_solver_(self.batch_size)
|
55 |
+
|
56 |
+
def _build_lr_scheduler_(self, optimizer, milestones, lr_anneal, last_epoch=-1):
|
57 |
+
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=lr_anneal, last_epoch=-1)
|
58 |
+
return scheduler
|
59 |
+
|
60 |
+
def _quadratic_wasserstein_distance_(self, real, generated):
|
61 |
+
num_r = real.size(0)
|
62 |
+
num_f = generated.size(0)
|
63 |
+
real_flat = real.view(num_r, -1)
|
64 |
+
fake_flat = generated.view(num_f, -1)
|
65 |
+
|
66 |
+
real3D = real_flat.unsqueeze(1).expand(num_r, num_f, self.data_dim)
|
67 |
+
fake3D = fake_flat.unsqueeze(0).expand(num_r, num_f, self.data_dim)
|
68 |
+
# compute squared L2 distance
|
69 |
+
dif = real3D - fake3D
|
70 |
+
dist = 0.5 * dif.pow(2).sum(2).squeeze()
|
71 |
+
|
72 |
+
return self.K * dist
|
73 |
+
|
74 |
+
def _prepare_linear_programming_solver_(self, batch_size):
|
75 |
+
A = spmatrix(1.0, range(batch_size), [0] * batch_size, (batch_size, batch_size))
|
76 |
+
for i in range(1, batch_size):
|
77 |
+
Ai = spmatrix(1.0, range(batch_size), [i] * batch_size, (batch_size, batch_size))
|
78 |
+
A = sparse([A, Ai])
|
79 |
+
|
80 |
+
D = spmatrix(-1.0, range(batch_size), range(batch_size), (batch_size, batch_size))
|
81 |
+
DM = D
|
82 |
+
for i in range(1, batch_size):
|
83 |
+
DM = sparse([DM, D])
|
84 |
+
|
85 |
+
A = sparse([[A], [DM]])
|
86 |
+
|
87 |
+
cr = matrix([-1.0 / batch_size] * batch_size)
|
88 |
+
cf = matrix([1.0 / batch_size] * batch_size)
|
89 |
+
c = matrix([cr, cf])
|
90 |
+
|
91 |
+
pStart = {}
|
92 |
+
pStart['x'] = matrix([matrix([1.0] * batch_size), matrix([-1.0] * batch_size)])
|
93 |
+
pStart['s'] = matrix([1.0] * (2 * batch_size))
|
94 |
+
|
95 |
+
return c, A, pStart
|
96 |
+
|
97 |
+
def _linear_programming_(self, distance, batch_size):
|
98 |
+
b = matrix(distance.cpu().double().detach().numpy().flatten())
|
99 |
+
sol = solvers.lp(self.c, self.A, b, primalstart=self.pStart, solver='glpk',
|
100 |
+
options={'glpk': {'msg_lev': 'GLP_MSG_OFF'}})
|
101 |
+
offset = 0.5 * (sum(sol['x'])) / batch_size
|
102 |
+
sol['x'] = sol['x'] - offset
|
103 |
+
self.pStart['x'] = sol['x']
|
104 |
+
self.pStart['s'] = sol['s']
|
105 |
+
|
106 |
+
return sol
|
107 |
+
|
108 |
+
def _approx_OT_(self, sol):
|
109 |
+
# Compute the OT mapping for each fake dataset
|
110 |
+
ResMat = np.array(sol['z']).reshape((self.batch_size, self.batch_size))
|
111 |
+
mapping = torch.from_numpy(np.argmax(ResMat, axis=0)).long().to(self.device)
|
112 |
+
|
113 |
+
return mapping
|
114 |
+
|
115 |
+
def _optimal_transport_regularization_(self, output_fake, fake, real_fake_diff):
|
116 |
+
output_fake_grad = torch.ones(output_fake.size()).to(self.device)
|
117 |
+
gradients = torch_grad(outputs=output_fake, inputs=fake,
|
118 |
+
grad_outputs=output_fake_grad,
|
119 |
+
create_graph=True, retain_graph=True, only_inputs=True)[0]
|
120 |
+
n = gradients.size(0)
|
121 |
+
RegLoss = 0.5 * ((gradients.view(n, -1).norm(dim=1) / (2 * self.Kr) - self.Kr / 2 * real_fake_diff.view(n,
|
122 |
+
-1).norm(
|
123 |
+
dim=1)).pow(2)).mean()
|
124 |
+
fake.requires_grad = False
|
125 |
+
|
126 |
+
return RegLoss
|
127 |
+
|
128 |
+
def _critic_deep_regression_(self, images, opt_iterations=1):
|
129 |
+
images = images.to(self.device)
|
130 |
+
|
131 |
+
for p in self.D.parameters(): # reset requires_grad
|
132 |
+
p.requires_grad = True # they are set to False below in netG update
|
133 |
+
|
134 |
+
self.G.train()
|
135 |
+
self.D.train()
|
136 |
+
|
137 |
+
# Get generated fake dataset
|
138 |
+
generated_data = self.sample_generator(self.batch_size)
|
139 |
+
|
140 |
+
# compute wasserstein distance
|
141 |
+
distance = self._quadratic_wasserstein_distance_(images, generated_data)
|
142 |
+
# solve linear programming problem
|
143 |
+
sol = self._linear_programming_(distance, self.batch_size)
|
144 |
+
# approximate optimal transport
|
145 |
+
mapping = self._approx_OT_(sol)
|
146 |
+
real_ordered = images[mapping] # match real and fake
|
147 |
+
real_fake_diff = real_ordered - generated_data
|
148 |
+
|
149 |
+
# construct target
|
150 |
+
target = torch.from_numpy(np.array(sol['x'])).float()
|
151 |
+
target = target.squeeze().to(self.device)
|
152 |
+
|
153 |
+
for i in range(opt_iterations):
|
154 |
+
self.D.zero_grad() # ???
|
155 |
+
self.D_opt.zero_grad()
|
156 |
+
generated_data.requires_grad_()
|
157 |
+
if generated_data.grad is not None:
|
158 |
+
generated_data.grad.data.zero_()
|
159 |
+
output_real = self.D(images)
|
160 |
+
output_fake = self.D(generated_data)
|
161 |
+
output_real, output_fake = output_real.squeeze(), output_fake.squeeze()
|
162 |
+
output_R_mean = output_real.mean(0).view(1)
|
163 |
+
output_F_mean = output_fake.mean(0).view(1)
|
164 |
+
|
165 |
+
L2LossD_real = self.criterion(output_R_mean[0], target[:self.batch_size].mean())
|
166 |
+
L2LossD_fake = self.criterion(output_fake, target[self.batch_size:])
|
167 |
+
L2LossD = 0.5 * L2LossD_real + 0.5 * L2LossD_fake
|
168 |
+
|
169 |
+
reg_loss_D = self._optimal_transport_regularization_(output_fake, generated_data, real_fake_diff)
|
170 |
+
|
171 |
+
total_loss = L2LossD + self.LAMBDA * reg_loss_D
|
172 |
+
|
173 |
+
self.losses['D'].append(float(total_loss.data))
|
174 |
+
|
175 |
+
total_loss.backward()
|
176 |
+
self.D_opt.step()
|
177 |
+
|
178 |
+
# this is supposed to be the wasserstein distance
|
179 |
+
wasserstein_distance = output_R_mean - output_F_mean
|
180 |
+
self.losses['WD'].append(float(wasserstein_distance.data))
|
181 |
+
|
182 |
+
def _generator_train_iteration(self, batch_size):
|
183 |
+
for p in self.D.parameters():
|
184 |
+
p.requires_grad = False # freeze critic
|
185 |
+
|
186 |
+
self.G.zero_grad()
|
187 |
+
self.G_opt.zero_grad()
|
188 |
+
|
189 |
+
if isinstance(self.G, torch.nn.parallel.DataParallel):
|
190 |
+
z = self.G.module.sample_latent(batch_size, self.G.module.z_dim)
|
191 |
+
else:
|
192 |
+
z = self.G.sample_latent(batch_size, self.G.z_dim)
|
193 |
+
z.requires_grad = True
|
194 |
+
|
195 |
+
fake = self.G(z)
|
196 |
+
output_fake = self.D(fake)
|
197 |
+
output_F_mean_after = output_fake.mean(0).view(1)
|
198 |
+
|
199 |
+
self.losses['G'].append(float(output_F_mean_after.data))
|
200 |
+
|
201 |
+
output_F_mean_after.backward(self.mone)
|
202 |
+
self.G_opt.step()
|
203 |
+
|
204 |
+
self.schedulerD.step()
|
205 |
+
self.schedulerG.step()
|
206 |
+
|
207 |
+
def _train_epoch(self, data_loader, writer, experiment):
|
208 |
+
for i, data in enumerate(tqdm(data_loader)):
|
209 |
+
images = data[0]
|
210 |
+
speaker_ids = data[1]
|
211 |
+
self.num_steps += 1
|
212 |
+
# self.tensorboard_counter += 1
|
213 |
+
if self.gen_steps >= self.n_max_iterations:
|
214 |
+
return
|
215 |
+
self._critic_deep_regression_(images)
|
216 |
+
self._generator_train_iteration(images.size(0))
|
217 |
+
|
218 |
+
D_loss_avg = np.average(self.losses['D'])
|
219 |
+
G_loss_avg = np.average(self.losses['G'])
|
220 |
+
wd_avg = np.average(self.losses['WD'])
|
221 |
+
|
222 |
+
def train(self, data_loader, writer, experiment=None):
|
223 |
+
self.G.train()
|
224 |
+
self.D.train()
|
225 |
+
|
226 |
+
for epoch in range(self.epochs):
|
227 |
+
if self.gen_steps >= self.n_max_iterations:
|
228 |
+
return
|
229 |
+
time_start_epoch = time.time()
|
230 |
+
self._train_epoch(data_loader, writer, experiment)
|
231 |
+
|
232 |
+
D_loss_avg = np.average(self.losses['D'])
|
233 |
+
|
234 |
+
time_end_epoch = time.time()
|
235 |
+
|
236 |
+
return self
|
237 |
+
|
238 |
+
def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
|
239 |
+
self.G.eval()
|
240 |
+
if isinstance(self.G, torch.nn.parallel.DataParallel):
|
241 |
+
latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim)
|
242 |
+
else:
|
243 |
+
latent_samples = self.G.sample_latent(num_samples, self.G.z_dim)
|
244 |
+
latent_samples = latent_samples.to(self.device)
|
245 |
+
if nograd:
|
246 |
+
with torch.no_grad():
|
247 |
+
generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
|
248 |
+
else:
|
249 |
+
generated_data = self.G(latent_samples)
|
250 |
+
self.G.train()
|
251 |
+
if return_intermediate:
|
252 |
+
return generated_data[0].detach(), generated_data[1], latent_samples
|
253 |
+
return generated_data.detach()
|
254 |
+
|
255 |
+
def sample(self, num_samples):
|
256 |
+
generated_data = self.sample_generator(num_samples)
|
257 |
+
# Remove color channel
|
258 |
+
return generated_data.data.cpu().numpy()[:, 0, :, :]
|
259 |
+
|
260 |
+
def save_model_checkpoint(self, model_path, model_parameters, timestampStr):
|
261 |
+
# dateTimeObj = datetime.now()
|
262 |
+
# timestampStr = dateTimeObj.strftime("%d-%m-%Y-%H-%M-%S")
|
263 |
+
name = '%s_%s' % (timestampStr, 'wgan')
|
264 |
+
model_filename = os.path.join(model_path, name)
|
265 |
+
torch.save({
|
266 |
+
'generator_state_dict' : self.G.state_dict(),
|
267 |
+
'critic_state_dict' : self.D.state_dict(),
|
268 |
+
'gen_optimizer_state_dict' : self.G_opt.state_dict(),
|
269 |
+
'critic_optimizer_state_dict': self.D_opt.state_dict(),
|
270 |
+
'model_parameters' : model_parameters,
|
271 |
+
'iterations' : self.num_steps
|
272 |
+
}, model_filename)
|
Architectures/EmbeddingModel/GST.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Nagoya University (Tomoki Hayashi)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from Architectures.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention
|
7 |
+
|
8 |
+
|
9 |
+
class GSTStyleEncoder(torch.nn.Module):
|
10 |
+
"""Style encoder.
|
11 |
+
This module is style encoder introduced in `Style Tokens: Unsupervised Style
|
12 |
+
Modeling, Control and Transfer in End-to-End Speech Synthesis`.
|
13 |
+
.. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
|
14 |
+
Speech Synthesis`: https://arxiv.org/abs/1803.09017
|
15 |
+
Args:
|
16 |
+
idim (int, optional): Dimension of the input features.
|
17 |
+
gst_tokens (int, optional): The number of GST embeddings.
|
18 |
+
gst_token_dim (int, optional): Dimension of each GST embedding.
|
19 |
+
gst_heads (int, optional): The number of heads in GST multihead attention.
|
20 |
+
conv_layers (int, optional): The number of conv layers in the reference encoder.
|
21 |
+
conv_chans_list: (Sequence[int], optional):
|
22 |
+
List of the number of channels of conv layers in the reference encoder.
|
23 |
+
conv_kernel_size (int, optional):
|
24 |
+
Kernel size of conv layers in the reference encoder.
|
25 |
+
conv_stride (int, optional):
|
26 |
+
Stride size of conv layers in the reference encoder.
|
27 |
+
gst_layers (int, optional): The number of GRU layers in the reference encoder.
|
28 |
+
gst_units (int, optional): The number of GRU units in the reference encoder.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
idim: int = 128,
|
34 |
+
gst_tokens: int = 512, # adaspeech suggests to use many more "basis vectors", but I believe that this is already sufficient
|
35 |
+
gst_token_dim: int = 64,
|
36 |
+
gst_heads: int = 8,
|
37 |
+
conv_layers: int = 8,
|
38 |
+
conv_chans_list=(32, 32, 64, 64, 128, 128, 256, 256),
|
39 |
+
conv_kernel_size: int = 3,
|
40 |
+
conv_stride: int = 2,
|
41 |
+
gst_layers: int = 2,
|
42 |
+
gst_units: int = 256,
|
43 |
+
):
|
44 |
+
"""Initialize global style encoder module."""
|
45 |
+
super(GSTStyleEncoder, self).__init__()
|
46 |
+
|
47 |
+
self.num_tokens = gst_tokens
|
48 |
+
self.ref_enc = ReferenceEncoder(idim=idim,
|
49 |
+
conv_layers=conv_layers,
|
50 |
+
conv_chans_list=conv_chans_list,
|
51 |
+
conv_kernel_size=conv_kernel_size,
|
52 |
+
conv_stride=conv_stride,
|
53 |
+
gst_layers=gst_layers,
|
54 |
+
gst_units=gst_units, )
|
55 |
+
self.stl = StyleTokenLayer(ref_embed_dim=gst_units,
|
56 |
+
gst_tokens=gst_tokens,
|
57 |
+
gst_token_dim=gst_token_dim,
|
58 |
+
gst_heads=gst_heads, )
|
59 |
+
|
60 |
+
def forward(self, speech):
|
61 |
+
"""Calculate forward propagation.
|
62 |
+
Args:
|
63 |
+
speech (Tensor): Batch of padded target features (B, Lmax, odim).
|
64 |
+
Returns:
|
65 |
+
Tensor: Style token embeddings (B, token_dim).
|
66 |
+
"""
|
67 |
+
ref_embs = self.ref_enc(speech)
|
68 |
+
style_embs = self.stl(ref_embs)
|
69 |
+
|
70 |
+
return style_embs
|
71 |
+
|
72 |
+
def calculate_ada4_regularization_loss(self):
|
73 |
+
losses = list()
|
74 |
+
for emb1_index in range(self.num_tokens):
|
75 |
+
for emb2_index in range(emb1_index + 1, self.num_tokens):
|
76 |
+
if emb1_index != emb2_index:
|
77 |
+
losses.append(torch.nn.functional.cosine_similarity(self.stl.gst_embs[emb1_index],
|
78 |
+
self.stl.gst_embs[emb2_index], dim=0))
|
79 |
+
return sum(losses)
|
80 |
+
|
81 |
+
|
82 |
+
class ReferenceEncoder(torch.nn.Module):
|
83 |
+
"""Reference encoder module.
|
84 |
+
This module is reference encoder introduced in `Style Tokens: Unsupervised Style
|
85 |
+
Modeling, Control and Transfer in End-to-End Speech Synthesis`.
|
86 |
+
.. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
|
87 |
+
Speech Synthesis`: https://arxiv.org/abs/1803.09017
|
88 |
+
Args:
|
89 |
+
idim (int, optional): Dimension of the input features.
|
90 |
+
conv_layers (int, optional): The number of conv layers in the reference encoder.
|
91 |
+
conv_chans_list: (Sequence[int], optional):
|
92 |
+
List of the number of channels of conv layers in the reference encoder.
|
93 |
+
conv_kernel_size (int, optional):
|
94 |
+
Kernel size of conv layers in the reference encoder.
|
95 |
+
conv_stride (int, optional):
|
96 |
+
Stride size of conv layers in the reference encoder.
|
97 |
+
gst_layers (int, optional): The number of GRU layers in the reference encoder.
|
98 |
+
gst_units (int, optional): The number of GRU units in the reference encoder.
|
99 |
+
"""
|
100 |
+
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
idim=80,
|
104 |
+
conv_layers: int = 6,
|
105 |
+
conv_chans_list=(32, 32, 64, 64, 128, 128),
|
106 |
+
conv_kernel_size: int = 3,
|
107 |
+
conv_stride: int = 2,
|
108 |
+
gst_layers: int = 1,
|
109 |
+
gst_units: int = 128,
|
110 |
+
):
|
111 |
+
"""Initialize reference encoder module."""
|
112 |
+
super(ReferenceEncoder, self).__init__()
|
113 |
+
|
114 |
+
# check hyperparameters are valid
|
115 |
+
assert conv_kernel_size % 2 == 1, "kernel size must be odd."
|
116 |
+
assert (
|
117 |
+
len(conv_chans_list) == conv_layers), "the number of conv layers and length of channels list must be the same."
|
118 |
+
|
119 |
+
convs = []
|
120 |
+
padding = (conv_kernel_size - 1) // 2
|
121 |
+
for i in range(conv_layers):
|
122 |
+
conv_in_chans = 1 if i == 0 else conv_chans_list[i - 1]
|
123 |
+
conv_out_chans = conv_chans_list[i]
|
124 |
+
convs += [torch.nn.Conv2d(conv_in_chans,
|
125 |
+
conv_out_chans,
|
126 |
+
kernel_size=conv_kernel_size,
|
127 |
+
stride=conv_stride,
|
128 |
+
padding=padding,
|
129 |
+
# Do not use bias due to the following batch norm
|
130 |
+
bias=False, ),
|
131 |
+
torch.nn.BatchNorm2d(conv_out_chans),
|
132 |
+
torch.nn.ReLU(inplace=True), ]
|
133 |
+
self.convs = torch.nn.Sequential(*convs)
|
134 |
+
|
135 |
+
self.conv_layers = conv_layers
|
136 |
+
self.kernel_size = conv_kernel_size
|
137 |
+
self.stride = conv_stride
|
138 |
+
self.padding = padding
|
139 |
+
|
140 |
+
# get the number of GRU input units
|
141 |
+
gst_in_units = idim
|
142 |
+
for i in range(conv_layers):
|
143 |
+
gst_in_units = (gst_in_units - conv_kernel_size + 2 * padding) // conv_stride + 1
|
144 |
+
gst_in_units *= conv_out_chans
|
145 |
+
self.gst = torch.nn.GRU(gst_in_units, gst_units, gst_layers, batch_first=True)
|
146 |
+
|
147 |
+
def forward(self, speech):
|
148 |
+
"""Calculate forward propagation.
|
149 |
+
Args:
|
150 |
+
speech (Tensor): Batch of padded target features (B, Lmax, idim).
|
151 |
+
Returns:
|
152 |
+
Tensor: Reference embedding (B, gst_units)
|
153 |
+
"""
|
154 |
+
batch_size = speech.size(0)
|
155 |
+
xs = speech.unsqueeze(1) # (B, 1, Lmax, idim)
|
156 |
+
hs = self.convs(xs).transpose(1, 2) # (B, Lmax', conv_out_chans, idim')
|
157 |
+
time_length = hs.size(1)
|
158 |
+
hs = hs.contiguous().view(batch_size, time_length, -1) # (B, Lmax', gst_units)
|
159 |
+
self.gst.flatten_parameters()
|
160 |
+
# pack_padded_sequence(hs, speech_lens, enforce_sorted=False, batch_first=True)
|
161 |
+
_, ref_embs = self.gst(hs) # (gst_layers, batch_size, gst_units)
|
162 |
+
ref_embs = ref_embs[-1] # (batch_size, gst_units)
|
163 |
+
|
164 |
+
return ref_embs
|
165 |
+
|
166 |
+
|
167 |
+
class StyleTokenLayer(torch.nn.Module):
|
168 |
+
"""Style token layer module.
|
169 |
+
This module is style token layer introduced in `Style Tokens: Unsupervised Style
|
170 |
+
Modeling, Control and Transfer in End-to-End Speech Synthesis`.
|
171 |
+
.. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
|
172 |
+
Speech Synthesis`: https://arxiv.org/abs/1803.09017
|
173 |
+
Args:
|
174 |
+
ref_embed_dim (int, optional): Dimension of the input reference embedding.
|
175 |
+
gst_tokens (int, optional): The number of GST embeddings.
|
176 |
+
gst_token_dim (int, optional): Dimension of each GST embedding.
|
177 |
+
gst_heads (int, optional): The number of heads in GST multihead attention.
|
178 |
+
dropout_rate (float, optional): Dropout rate in multi-head attention.
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
ref_embed_dim: int = 128,
|
184 |
+
gst_tokens: int = 10,
|
185 |
+
gst_token_dim: int = 128,
|
186 |
+
gst_heads: int = 4,
|
187 |
+
dropout_rate: float = 0.0,
|
188 |
+
):
|
189 |
+
"""Initialize style token layer module."""
|
190 |
+
super(StyleTokenLayer, self).__init__()
|
191 |
+
|
192 |
+
gst_embs = torch.randn(gst_tokens, gst_token_dim // gst_heads)
|
193 |
+
self.register_parameter("gst_embs", torch.nn.Parameter(gst_embs))
|
194 |
+
self.mha = MultiHeadedAttention(q_dim=ref_embed_dim,
|
195 |
+
k_dim=gst_token_dim // gst_heads,
|
196 |
+
v_dim=gst_token_dim // gst_heads,
|
197 |
+
n_head=gst_heads,
|
198 |
+
n_feat=gst_token_dim,
|
199 |
+
dropout_rate=dropout_rate, )
|
200 |
+
|
201 |
+
def forward(self, ref_embs):
|
202 |
+
"""Calculate forward propagation.
|
203 |
+
Args:
|
204 |
+
ref_embs (Tensor): Reference embeddings (B, ref_embed_dim).
|
205 |
+
Returns:
|
206 |
+
Tensor: Style token embeddings (B, gst_token_dim).
|
207 |
+
"""
|
208 |
+
batch_size = ref_embs.size(0)
|
209 |
+
# (num_tokens, token_dim) -> (batch_size, num_tokens, token_dim)
|
210 |
+
gst_embs = torch.tanh(self.gst_embs).unsqueeze(0).expand(batch_size, -1, -1)
|
211 |
+
# NOTE(kan-bayashi): Shoule we apply Tanh?
|
212 |
+
ref_embs = ref_embs.unsqueeze(1) # (batch_size, 1 ,ref_embed_dim)
|
213 |
+
style_embs = self.mha(ref_embs, gst_embs, gst_embs, None)
|
214 |
+
|
215 |
+
return style_embs.squeeze(1)
|
216 |
+
|
217 |
+
|
218 |
+
class MultiHeadedAttention(BaseMultiHeadedAttention):
|
219 |
+
"""Multi head attention module with different input dimension."""
|
220 |
+
|
221 |
+
def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0):
|
222 |
+
"""Initialize multi head attention module."""
|
223 |
+
# NOTE(kan-bayashi): Do not use super().__init__() here since we want to
|
224 |
+
# overwrite BaseMultiHeadedAttention.__init__() method.
|
225 |
+
torch.nn.Module.__init__(self)
|
226 |
+
assert n_feat % n_head == 0
|
227 |
+
# We assume d_v always equals d_k
|
228 |
+
self.d_k = n_feat // n_head
|
229 |
+
self.h = n_head
|
230 |
+
self.linear_q = torch.nn.Linear(q_dim, n_feat)
|
231 |
+
self.linear_k = torch.nn.Linear(k_dim, n_feat)
|
232 |
+
self.linear_v = torch.nn.Linear(v_dim, n_feat)
|
233 |
+
self.linear_out = torch.nn.Linear(n_feat, n_feat)
|
234 |
+
self.attn = None
|
235 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
Architectures/EmbeddingModel/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Everything that is concerned with the embedding model is contained in this directory. The embedding function does not have its own train loop, because it is always trained jointly with the TTS. Most of the time however, it is used in a frozen state. We recommend using the embedding function that we publish in the GitHub releases.
|
Architectures/EmbeddingModel/StyleEmbedding.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from Architectures.EmbeddingModel.GST import GSTStyleEncoder
|
4 |
+
from Architectures.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder
|
5 |
+
|
6 |
+
|
7 |
+
class StyleEmbedding(torch.nn.Module):
|
8 |
+
"""
|
9 |
+
The style embedding should provide information of the speaker and their speaking style
|
10 |
+
|
11 |
+
The feedback signal for the module will come from the TTS objective, so it doesn't have a dedicated train loop.
|
12 |
+
The train loop does however supply supervision in the form of a barlow twins objective.
|
13 |
+
|
14 |
+
See the git history for some other approaches for style embedding, like the SWIN transformer
|
15 |
+
and a simple LSTM baseline. GST turned out to be the best.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, embedding_dim=16, style_tts_encoder=False):
|
19 |
+
super().__init__()
|
20 |
+
self.embedding_dim = embedding_dim
|
21 |
+
self.use_gst = not style_tts_encoder
|
22 |
+
if style_tts_encoder:
|
23 |
+
self.style_encoder = StyleTTSEncoder(style_dim=embedding_dim)
|
24 |
+
else:
|
25 |
+
self.style_encoder = GSTStyleEncoder(gst_token_dim=embedding_dim)
|
26 |
+
|
27 |
+
def forward(self,
|
28 |
+
batch_of_feature_sequences,
|
29 |
+
batch_of_feature_sequence_lengths):
|
30 |
+
"""
|
31 |
+
Args:
|
32 |
+
batch_of_feature_sequences: b is the batch axis, 128 features per timestep
|
33 |
+
and l time-steps, which may include padding
|
34 |
+
for most elements in the batch (b, l, 128)
|
35 |
+
batch_of_feature_sequence_lengths: indicate for every element in the batch,
|
36 |
+
what the true length is, since they are
|
37 |
+
all padded to the length of the longest
|
38 |
+
element in the batch (b, 1)
|
39 |
+
Returns:
|
40 |
+
batch of n dimensional embeddings (b,n)
|
41 |
+
"""
|
42 |
+
|
43 |
+
minimum_sequence_length = 512
|
44 |
+
specs = list()
|
45 |
+
for index, spec_length in enumerate(batch_of_feature_sequence_lengths):
|
46 |
+
spec = batch_of_feature_sequences[index][:spec_length]
|
47 |
+
# double the length at least once, then check
|
48 |
+
spec = spec.repeat((2, 1))
|
49 |
+
current_spec_length = len(spec)
|
50 |
+
while current_spec_length < minimum_sequence_length:
|
51 |
+
# make it longer
|
52 |
+
spec = spec.repeat((2, 1))
|
53 |
+
current_spec_length = len(spec)
|
54 |
+
specs.append(spec[:minimum_sequence_length])
|
55 |
+
|
56 |
+
spec_batch = torch.stack(specs, dim=0)
|
57 |
+
return self.style_encoder(speech=spec_batch)
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == '__main__':
|
61 |
+
style_emb = StyleEmbedding(style_tts_encoder=False)
|
62 |
+
print(f"GST parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")
|
63 |
+
|
64 |
+
seq_length = 398
|
65 |
+
print(style_emb(torch.randn(5, seq_length, 512),
|
66 |
+
torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)
|
67 |
+
|
68 |
+
style_emb = StyleEmbedding(style_tts_encoder=True)
|
69 |
+
print(f"StyleTTS encoder parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")
|
70 |
+
|
71 |
+
seq_length = 398
|
72 |
+
print(style_emb(torch.randn(5, seq_length, 512),
|
73 |
+
torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)
|
Architectures/EmbeddingModel/StyleTTSEncoder.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
MIT Licensed Code
|
3 |
+
|
4 |
+
Copyright (c) 2022 Aaron (Yinghao) Li
|
5 |
+
|
6 |
+
https://github.com/yl4579/StyleTTS/blob/main/models.py
|
7 |
+
"""
|
8 |
+
|
9 |
+
import math
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn.utils import spectral_norm
|
15 |
+
|
16 |
+
|
17 |
+
class StyleEncoder(nn.Module):
|
18 |
+
def __init__(self, dim_in=128, style_dim=64, max_conv_dim=384):
|
19 |
+
super().__init__()
|
20 |
+
blocks = []
|
21 |
+
blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
|
22 |
+
|
23 |
+
repeat_num = 4
|
24 |
+
for _ in range(repeat_num):
|
25 |
+
dim_out = min(dim_in * 2, max_conv_dim)
|
26 |
+
blocks += [ResBlk(dim_in, dim_out, downsample='half')]
|
27 |
+
dim_in = dim_out
|
28 |
+
|
29 |
+
blocks += [nn.LeakyReLU(0.2)]
|
30 |
+
blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
|
31 |
+
blocks += [nn.AdaptiveAvgPool2d(1)]
|
32 |
+
blocks += [nn.LeakyReLU(0.2)]
|
33 |
+
self.shared = nn.Sequential(*blocks)
|
34 |
+
|
35 |
+
self.unshared = nn.Linear(dim_out, style_dim)
|
36 |
+
|
37 |
+
def forward(self, speech):
|
38 |
+
h = self.shared(speech.unsqueeze(1))
|
39 |
+
h = h.view(h.size(0), -1)
|
40 |
+
s = self.unshared(h)
|
41 |
+
|
42 |
+
return s
|
43 |
+
|
44 |
+
|
45 |
+
class ResBlk(nn.Module):
|
46 |
+
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
|
47 |
+
normalize=False, downsample='none'):
|
48 |
+
super().__init__()
|
49 |
+
self.actv = actv
|
50 |
+
self.normalize = normalize
|
51 |
+
self.downsample = DownSample(downsample)
|
52 |
+
self.downsample_res = LearnedDownSample(downsample, dim_in)
|
53 |
+
self.learned_sc = dim_in != dim_out
|
54 |
+
self._build_weights(dim_in, dim_out)
|
55 |
+
|
56 |
+
def _build_weights(self, dim_in, dim_out):
|
57 |
+
self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
|
58 |
+
self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
|
59 |
+
if self.normalize:
|
60 |
+
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
61 |
+
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
|
62 |
+
if self.learned_sc:
|
63 |
+
self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
|
64 |
+
|
65 |
+
def _shortcut(self, x):
|
66 |
+
if self.learned_sc:
|
67 |
+
x = self.conv1x1(x)
|
68 |
+
if self.downsample:
|
69 |
+
x = self.downsample(x)
|
70 |
+
return x
|
71 |
+
|
72 |
+
def _residual(self, x):
|
73 |
+
if self.normalize:
|
74 |
+
x = self.norm1(x)
|
75 |
+
x = self.actv(x)
|
76 |
+
x = self.conv1(x)
|
77 |
+
x = self.downsample_res(x)
|
78 |
+
if self.normalize:
|
79 |
+
x = self.norm2(x)
|
80 |
+
x = self.actv(x)
|
81 |
+
x = self.conv2(x)
|
82 |
+
return x
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
x = self._shortcut(x) + self._residual(x)
|
86 |
+
return x / math.sqrt(2) # unit variance
|
87 |
+
|
88 |
+
|
89 |
+
class LearnedDownSample(nn.Module):
|
90 |
+
def __init__(self, layer_type, dim_in):
|
91 |
+
super().__init__()
|
92 |
+
self.layer_type = layer_type
|
93 |
+
|
94 |
+
if self.layer_type == 'none':
|
95 |
+
self.conv = nn.Identity()
|
96 |
+
elif self.layer_type == 'timepreserve':
|
97 |
+
self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
|
98 |
+
elif self.layer_type == 'half':
|
99 |
+
self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
|
100 |
+
else:
|
101 |
+
raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
return self.conv(x)
|
105 |
+
|
106 |
+
|
107 |
+
class LearnedUpSample(nn.Module):
|
108 |
+
def __init__(self, layer_type, dim_in):
|
109 |
+
super().__init__()
|
110 |
+
self.layer_type = layer_type
|
111 |
+
|
112 |
+
if self.layer_type == 'none':
|
113 |
+
self.conv = nn.Identity()
|
114 |
+
elif self.layer_type == 'timepreserve':
|
115 |
+
self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
|
116 |
+
elif self.layer_type == 'half':
|
117 |
+
self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
|
118 |
+
else:
|
119 |
+
raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
return self.conv(x)
|
123 |
+
|
124 |
+
|
125 |
+
class DownSample(nn.Module):
|
126 |
+
def __init__(self, layer_type):
|
127 |
+
super().__init__()
|
128 |
+
self.layer_type = layer_type
|
129 |
+
|
130 |
+
def forward(self, x):
|
131 |
+
if self.layer_type == 'none':
|
132 |
+
return x
|
133 |
+
elif self.layer_type == 'timepreserve':
|
134 |
+
return F.avg_pool2d(x, (2, 1))
|
135 |
+
elif self.layer_type == 'half':
|
136 |
+
if x.shape[-1] % 2 != 0:
|
137 |
+
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
|
138 |
+
return F.avg_pool2d(x, 2)
|
139 |
+
else:
|
140 |
+
raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
|
141 |
+
|
142 |
+
|
143 |
+
class UpSample(nn.Module):
|
144 |
+
def __init__(self, layer_type):
|
145 |
+
super().__init__()
|
146 |
+
self.layer_type = layer_type
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
if self.layer_type == 'none':
|
150 |
+
return x
|
151 |
+
elif self.layer_type == 'timepreserve':
|
152 |
+
return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
|
153 |
+
elif self.layer_type == 'half':
|
154 |
+
return F.interpolate(x, scale_factor=2, mode='nearest')
|
155 |
+
else:
|
156 |
+
raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
|
Architectures/EmbeddingModel/__init__.py
ADDED
File without changes
|
Architectures/GeneralLayers/Attention.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Written by Shigeki Karita, 2019
|
2 |
+
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
# Adapted by Florian Lux, 2021
|
4 |
+
|
5 |
+
"""Multi-Head Attention layer definition."""
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import numpy
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
from Utility.utils import make_non_pad_mask
|
14 |
+
|
15 |
+
|
16 |
+
class MultiHeadedAttention(nn.Module):
|
17 |
+
"""
|
18 |
+
Multi-Head Attention layer.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
n_head (int): The number of heads.
|
22 |
+
n_feat (int): The number of features.
|
23 |
+
dropout_rate (float): Dropout rate.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, n_head, n_feat, dropout_rate):
|
27 |
+
"""
|
28 |
+
Construct an MultiHeadedAttention object.
|
29 |
+
"""
|
30 |
+
super(MultiHeadedAttention, self).__init__()
|
31 |
+
assert n_feat % n_head == 0
|
32 |
+
# We assume d_v always equals d_k
|
33 |
+
self.d_k = n_feat // n_head
|
34 |
+
self.h = n_head
|
35 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
36 |
+
self.linear_k = nn.Linear(n_feat, n_feat)
|
37 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
38 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
39 |
+
self.attn = None
|
40 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
41 |
+
|
42 |
+
def forward_qkv(self, query, key, value):
|
43 |
+
"""
|
44 |
+
Transform query, key and value.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
48 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
49 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
|
53 |
+
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
|
54 |
+
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
|
55 |
+
"""
|
56 |
+
n_batch = query.size(0)
|
57 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
58 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
59 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
60 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
61 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
62 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
63 |
+
|
64 |
+
return q, k, v
|
65 |
+
|
66 |
+
def forward_attention(self, value, scores, mask):
|
67 |
+
"""
|
68 |
+
Compute attention context vector.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
|
72 |
+
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
|
73 |
+
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
77 |
+
weighted by the attention score (#batch, time1, time2).
|
78 |
+
"""
|
79 |
+
n_batch = value.size(0)
|
80 |
+
if mask is not None:
|
81 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
82 |
+
min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
83 |
+
scores = scores.masked_fill(mask, min_value)
|
84 |
+
self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
|
85 |
+
else:
|
86 |
+
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
87 |
+
|
88 |
+
p_attn = self.dropout(self.attn)
|
89 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
90 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)) # (batch, time1, d_model)
|
91 |
+
|
92 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
93 |
+
|
94 |
+
def forward(self, query, key, value, mask):
|
95 |
+
"""
|
96 |
+
Compute scaled dot product attention.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
100 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
101 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
102 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
103 |
+
(#batch, time1, time2).
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
107 |
+
"""
|
108 |
+
q, k, v = self.forward_qkv(query, key, value)
|
109 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
110 |
+
return self.forward_attention(v, scores, mask)
|
111 |
+
|
112 |
+
|
113 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
114 |
+
"""
|
115 |
+
Multi-Head Attention layer with relative position encoding.
|
116 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
117 |
+
Paper: https://arxiv.org/abs/1901.02860
|
118 |
+
Args:
|
119 |
+
n_head (int): The number of heads.
|
120 |
+
n_feat (int): The number of features.
|
121 |
+
dropout_rate (float): Dropout rate.
|
122 |
+
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
|
126 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
127 |
+
super().__init__(n_head, n_feat, dropout_rate)
|
128 |
+
self.zero_triu = zero_triu
|
129 |
+
# linear transformation for positional encoding
|
130 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
131 |
+
# these two learnable bias are used in matrix c and matrix d
|
132 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
133 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
134 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
135 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
136 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
137 |
+
|
138 |
+
def rel_shift(self, x):
|
139 |
+
"""
|
140 |
+
Compute relative positional encoding.
|
141 |
+
Args:
|
142 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
143 |
+
time1 means the length of query vector.
|
144 |
+
Returns:
|
145 |
+
torch.Tensor: Output tensor.
|
146 |
+
"""
|
147 |
+
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
148 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
149 |
+
|
150 |
+
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
151 |
+
x = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1] # only keep the positions from 0 to time2
|
152 |
+
|
153 |
+
if self.zero_triu:
|
154 |
+
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
|
155 |
+
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
156 |
+
|
157 |
+
return x
|
158 |
+
|
159 |
+
def forward(self, query, key, value, pos_emb, mask):
|
160 |
+
"""
|
161 |
+
Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
162 |
+
Args:
|
163 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
164 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
165 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
166 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
167 |
+
(#batch, 2*time1-1, size).
|
168 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
169 |
+
(#batch, time1, time2).
|
170 |
+
Returns:
|
171 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
172 |
+
"""
|
173 |
+
q, k, v = self.forward_qkv(query, key, value)
|
174 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
175 |
+
|
176 |
+
n_batch_pos = pos_emb.size(0)
|
177 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
178 |
+
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
179 |
+
|
180 |
+
# (batch, head, time1, d_k)
|
181 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
182 |
+
# (batch, head, time1, d_k)
|
183 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
184 |
+
|
185 |
+
# compute attention score
|
186 |
+
# first compute matrix a and matrix c
|
187 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
188 |
+
# (batch, head, time1, time2)
|
189 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
190 |
+
|
191 |
+
# compute matrix b and matrix d
|
192 |
+
# (batch, head, time1, 2*time1-1)
|
193 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
194 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
195 |
+
|
196 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
|
197 |
+
|
198 |
+
return self.forward_attention(v, scores, mask)
|
199 |
+
|
200 |
+
|
201 |
+
class GuidedAttentionLoss(torch.nn.Module):
|
202 |
+
"""
|
203 |
+
Guided attention loss function module.
|
204 |
+
|
205 |
+
This module calculates the guided attention loss described
|
206 |
+
in `Efficiently Trainable Text-to-Speech System Based
|
207 |
+
on Deep Convolutional Networks with Guided Attention`_,
|
208 |
+
which forces the attention to be diagonal.
|
209 |
+
|
210 |
+
.. _`Efficiently Trainable Text-to-Speech System
|
211 |
+
Based on Deep Convolutional Networks with Guided Attention`:
|
212 |
+
https://arxiv.org/abs/1710.08969
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self, sigma=0.4, alpha=1.0):
|
216 |
+
"""
|
217 |
+
Initialize guided attention loss module.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
sigma (float, optional): Standard deviation to control
|
221 |
+
how close attention to a diagonal.
|
222 |
+
alpha (float, optional): Scaling coefficient (lambda).
|
223 |
+
reset_always (bool, optional): Whether to always reset masks.
|
224 |
+
"""
|
225 |
+
super(GuidedAttentionLoss, self).__init__()
|
226 |
+
self.sigma = sigma
|
227 |
+
self.alpha = alpha
|
228 |
+
self.guided_attn_masks = None
|
229 |
+
self.masks = None
|
230 |
+
|
231 |
+
def _reset_masks(self):
|
232 |
+
self.guided_attn_masks = None
|
233 |
+
self.masks = None
|
234 |
+
|
235 |
+
def forward(self, att_ws, ilens, olens):
|
236 |
+
"""
|
237 |
+
Calculate forward propagation.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
|
241 |
+
ilens (LongTensor): Batch of input lenghts (B,).
|
242 |
+
olens (LongTensor): Batch of output lenghts (B,).
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
Tensor: Guided attention loss value.
|
246 |
+
"""
|
247 |
+
self._reset_masks()
|
248 |
+
self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device)
|
249 |
+
self.masks = self._make_masks(ilens, olens).to(att_ws.device)
|
250 |
+
losses = self.guided_attn_masks * att_ws
|
251 |
+
loss = torch.mean(losses.masked_select(self.masks))
|
252 |
+
self._reset_masks()
|
253 |
+
return self.alpha * loss
|
254 |
+
|
255 |
+
def _make_guided_attention_masks(self, ilens, olens):
|
256 |
+
n_batches = len(ilens)
|
257 |
+
max_ilen = max(ilens)
|
258 |
+
max_olen = max(olens)
|
259 |
+
guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=ilens.device)
|
260 |
+
for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
|
261 |
+
guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma)
|
262 |
+
return guided_attn_masks
|
263 |
+
|
264 |
+
@staticmethod
|
265 |
+
def _make_guided_attention_mask(ilen, olen, sigma):
|
266 |
+
"""
|
267 |
+
Make guided attention mask.
|
268 |
+
"""
|
269 |
+
grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device).float(), torch.arange(ilen, device=ilen.device).float())
|
270 |
+
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))
|
271 |
+
|
272 |
+
@staticmethod
|
273 |
+
def _make_masks(ilens, olens):
|
274 |
+
"""
|
275 |
+
Make masks indicating non-padded part.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
ilens (LongTensor or List): Batch of lengths (B,).
|
279 |
+
olens (LongTensor or List): Batch of lengths (B,).
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
Tensor: Mask tensor indicating non-padded part.
|
283 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
284 |
+
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
285 |
+
"""
|
286 |
+
in_masks = make_non_pad_mask(ilens, device=ilens.device) # (B, T_in)
|
287 |
+
out_masks = make_non_pad_mask(olens, device=olens.device) # (B, T_out)
|
288 |
+
return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)
|
289 |
+
|
290 |
+
|
291 |
+
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
|
292 |
+
"""
|
293 |
+
Guided attention loss function module for multi head attention.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
sigma (float, optional): Standard deviation to control
|
297 |
+
how close attention to a diagonal.
|
298 |
+
alpha (float, optional): Scaling coefficient (lambda).
|
299 |
+
reset_always (bool, optional): Whether to always reset masks.
|
300 |
+
"""
|
301 |
+
|
302 |
+
def forward(self, att_ws, ilens, olens):
|
303 |
+
"""
|
304 |
+
Calculate forward propagation.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
att_ws (Tensor):
|
308 |
+
Batch of multi head attention weights (B, H, T_max_out, T_max_in).
|
309 |
+
ilens (LongTensor): Batch of input lenghts (B,).
|
310 |
+
olens (LongTensor): Batch of output lenghts (B,).
|
311 |
+
|
312 |
+
Returns:
|
313 |
+
Tensor: Guided attention loss value.
|
314 |
+
"""
|
315 |
+
if self.guided_attn_masks is None:
|
316 |
+
self.guided_attn_masks = (self._make_guided_attention_masks(ilens, olens).to(att_ws.device).unsqueeze(1))
|
317 |
+
if self.masks is None:
|
318 |
+
self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
|
319 |
+
losses = self.guided_attn_masks * att_ws
|
320 |
+
loss = torch.mean(losses.masked_select(self.masks))
|
321 |
+
if self.reset_always:
|
322 |
+
self._reset_masks()
|
323 |
+
|
324 |
+
return self.alpha * loss
|
Architectures/GeneralLayers/ConditionalLayerNorm.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code taken from https://github.com/tuanh123789/AdaSpeech/blob/main/model/adaspeech_modules.py
|
3 |
+
By https://github.com/tuanh123789
|
4 |
+
No license specified
|
5 |
+
|
6 |
+
Implemented as outlined in AdaSpeech https://arxiv.org/pdf/2103.00993.pdf
|
7 |
+
Used in this toolkit similar to how it is done in AdaSpeech 4 https://arxiv.org/pdf/2204.00436.pdf
|
8 |
+
|
9 |
+
"""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
class ConditionalLayerNorm(nn.Module):
|
16 |
+
|
17 |
+
def __init__(self,
|
18 |
+
hidden_dim,
|
19 |
+
speaker_embedding_dim,
|
20 |
+
dim=-1):
|
21 |
+
super(ConditionalLayerNorm, self).__init__()
|
22 |
+
self.dim = dim
|
23 |
+
if isinstance(hidden_dim, int):
|
24 |
+
self.normal_shape = hidden_dim
|
25 |
+
self.speaker_embedding_dim = speaker_embedding_dim
|
26 |
+
self.W_scale = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
|
27 |
+
nn.Tanh(),
|
28 |
+
nn.Linear(self.normal_shape, self.normal_shape))
|
29 |
+
self.W_bias = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
|
30 |
+
nn.Tanh(),
|
31 |
+
nn.Linear(self.normal_shape, self.normal_shape))
|
32 |
+
self.reset_parameters()
|
33 |
+
|
34 |
+
def reset_parameters(self):
|
35 |
+
torch.nn.init.constant_(self.W_scale[0].weight, 0.0)
|
36 |
+
torch.nn.init.constant_(self.W_scale[2].weight, 0.0)
|
37 |
+
torch.nn.init.constant_(self.W_scale[0].bias, 1.0)
|
38 |
+
torch.nn.init.constant_(self.W_scale[2].bias, 1.0)
|
39 |
+
torch.nn.init.constant_(self.W_bias[0].weight, 0.0)
|
40 |
+
torch.nn.init.constant_(self.W_bias[2].weight, 0.0)
|
41 |
+
torch.nn.init.constant_(self.W_bias[0].bias, 0.0)
|
42 |
+
torch.nn.init.constant_(self.W_bias[2].bias, 0.0)
|
43 |
+
|
44 |
+
def forward(self, x, speaker_embedding):
|
45 |
+
|
46 |
+
if self.dim != -1:
|
47 |
+
x = x.transpose(-1, self.dim)
|
48 |
+
|
49 |
+
mean = x.mean(dim=-1, keepdim=True)
|
50 |
+
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
|
51 |
+
scale = self.W_scale(speaker_embedding)
|
52 |
+
bias = self.W_bias(speaker_embedding)
|
53 |
+
|
54 |
+
y = scale.unsqueeze(1) * ((x - mean) / var) + bias.unsqueeze(1)
|
55 |
+
|
56 |
+
if self.dim != -1:
|
57 |
+
y = y.transpose(-1, self.dim)
|
58 |
+
|
59 |
+
return y
|
60 |
+
|
61 |
+
|
62 |
+
class SequentialWrappableConditionalLayerNorm(nn.Module):
|
63 |
+
|
64 |
+
def __init__(self,
|
65 |
+
hidden_dim,
|
66 |
+
speaker_embedding_dim):
|
67 |
+
super(SequentialWrappableConditionalLayerNorm, self).__init__()
|
68 |
+
if isinstance(hidden_dim, int):
|
69 |
+
self.normal_shape = hidden_dim
|
70 |
+
self.speaker_embedding_dim = speaker_embedding_dim
|
71 |
+
self.W_scale = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
|
72 |
+
nn.Tanh(),
|
73 |
+
nn.Linear(self.normal_shape, self.normal_shape))
|
74 |
+
self.W_bias = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
|
75 |
+
nn.Tanh(),
|
76 |
+
nn.Linear(self.normal_shape, self.normal_shape))
|
77 |
+
self.reset_parameters()
|
78 |
+
|
79 |
+
def reset_parameters(self):
|
80 |
+
torch.nn.init.constant_(self.W_scale[0].weight, 0.0)
|
81 |
+
torch.nn.init.constant_(self.W_scale[2].weight, 0.0)
|
82 |
+
torch.nn.init.constant_(self.W_scale[0].bias, 1.0)
|
83 |
+
torch.nn.init.constant_(self.W_scale[2].bias, 1.0)
|
84 |
+
torch.nn.init.constant_(self.W_bias[0].weight, 0.0)
|
85 |
+
torch.nn.init.constant_(self.W_bias[2].weight, 0.0)
|
86 |
+
torch.nn.init.constant_(self.W_bias[0].bias, 0.0)
|
87 |
+
torch.nn.init.constant_(self.W_bias[2].bias, 0.0)
|
88 |
+
|
89 |
+
def forward(self, packed_input):
|
90 |
+
x, speaker_embedding = packed_input
|
91 |
+
mean = x.mean(dim=-1, keepdim=True)
|
92 |
+
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
|
93 |
+
scale = self.W_scale(speaker_embedding)
|
94 |
+
bias = self.W_bias(speaker_embedding)
|
95 |
+
|
96 |
+
y = scale.unsqueeze(1) * ((x - mean) / var) + bias.unsqueeze(1)
|
97 |
+
|
98 |
+
return y
|
99 |
+
|
100 |
+
|
101 |
+
class AdaIN1d(nn.Module):
|
102 |
+
"""
|
103 |
+
MIT Licensed
|
104 |
+
|
105 |
+
Copyright (c) 2022 Aaron (Yinghao) Li
|
106 |
+
https://github.com/yl4579/StyleTTS/blob/main/models.py
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(self, style_dim, num_features):
|
110 |
+
super().__init__()
|
111 |
+
self.norm = nn.InstanceNorm1d(num_features, affine=False)
|
112 |
+
self.fc = nn.Linear(style_dim, num_features * 2)
|
113 |
+
|
114 |
+
def forward(self, x, s):
|
115 |
+
h = self.fc(s)
|
116 |
+
h = h.view(h.size(0), h.size(1), 1)
|
117 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
118 |
+
return (1 + gamma.transpose(1, 2)) * self.norm(x.transpose(1, 2)).transpose(1, 2) + beta.transpose(1, 2)
|
Architectures/GeneralLayers/Conformer.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Taken from ESPNet, but heavily modified
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from Architectures.GeneralLayers.Attention import RelPositionMultiHeadedAttention
|
8 |
+
from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
|
9 |
+
from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
|
10 |
+
from Architectures.GeneralLayers.Convolution import ConvolutionModule
|
11 |
+
from Architectures.GeneralLayers.EncoderLayer import EncoderLayer
|
12 |
+
from Architectures.GeneralLayers.LayerNorm import LayerNorm
|
13 |
+
from Architectures.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d
|
14 |
+
from Architectures.GeneralLayers.MultiSequential import repeat
|
15 |
+
from Architectures.GeneralLayers.PositionalEncoding import RelPositionalEncoding
|
16 |
+
from Architectures.GeneralLayers.Swish import Swish
|
17 |
+
from Utility.utils import integrate_with_utt_embed
|
18 |
+
|
19 |
+
|
20 |
+
class Conformer(torch.nn.Module):
|
21 |
+
"""
|
22 |
+
Conformer encoder module.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
idim (int): Input dimension.
|
26 |
+
attention_dim (int): Dimension of attention.
|
27 |
+
attention_heads (int): The number of heads of multi head attention.
|
28 |
+
linear_units (int): The number of units of position-wise feed forward.
|
29 |
+
num_blocks (int): The number of decoder blocks.
|
30 |
+
dropout_rate (float): Dropout rate.
|
31 |
+
positional_dropout_rate (float): Dropout rate after adding positional encoding.
|
32 |
+
attention_dropout_rate (float): Dropout rate in attention.
|
33 |
+
input_layer (Union[str, torch.nn.Module]): Input layer type.
|
34 |
+
normalize_before (bool): Whether to use layer_norm before the first block.
|
35 |
+
concat_after (bool): Whether to concat attention layer's input and output.
|
36 |
+
if True, additional linear will be applied.
|
37 |
+
i.e. x -> x + linear(concat(x, att(x)))
|
38 |
+
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
39 |
+
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
|
40 |
+
macaron_style (bool): Whether to use macaron style for positionwise layer.
|
41 |
+
use_cnn_module (bool): Whether to use convolution module.
|
42 |
+
cnn_module_kernel (int): Kernel size of convolution module.
|
43 |
+
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self, conformer_type, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1,
|
47 |
+
attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1,
|
48 |
+
macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, lang_embs=None, lang_emb_size=16, use_output_norm=True, embedding_integration="AdaIN"):
|
49 |
+
super(Conformer, self).__init__()
|
50 |
+
|
51 |
+
activation = Swish()
|
52 |
+
self.conv_subsampling_factor = 1
|
53 |
+
self.use_output_norm = use_output_norm
|
54 |
+
|
55 |
+
if isinstance(input_layer, torch.nn.Module):
|
56 |
+
self.embed = input_layer
|
57 |
+
self.art_embed_norm = LayerNorm(attention_dim)
|
58 |
+
self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate)
|
59 |
+
elif input_layer is None:
|
60 |
+
self.embed = None
|
61 |
+
self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate))
|
62 |
+
else:
|
63 |
+
raise ValueError("unknown input_layer: " + input_layer)
|
64 |
+
|
65 |
+
if self.use_output_norm:
|
66 |
+
self.output_norm = LayerNorm(attention_dim)
|
67 |
+
self.utt_embed = utt_embed
|
68 |
+
self.conformer_type = conformer_type
|
69 |
+
self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]
|
70 |
+
if utt_embed is not None:
|
71 |
+
if conformer_type == "encoder": # the encoder gets an additional conditioning signal added to its output
|
72 |
+
if embedding_integration == "AdaIN":
|
73 |
+
self.encoder_embedding_projection = AdaIN1d(style_dim=utt_embed, num_features=attention_dim)
|
74 |
+
elif embedding_integration == "ConditionalLayerNorm":
|
75 |
+
self.encoder_embedding_projection = ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim)
|
76 |
+
else:
|
77 |
+
self.encoder_embedding_projection = torch.nn.Linear(attention_dim + utt_embed, attention_dim)
|
78 |
+
else:
|
79 |
+
if embedding_integration == "AdaIN":
|
80 |
+
self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: AdaIN1d(style_dim=utt_embed, num_features=attention_dim))
|
81 |
+
elif embedding_integration == "ConditionalLayerNorm":
|
82 |
+
self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim))
|
83 |
+
else:
|
84 |
+
self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: torch.nn.Linear(attention_dim + utt_embed, attention_dim))
|
85 |
+
if lang_embs is not None:
|
86 |
+
self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=lang_emb_size)
|
87 |
+
self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim)
|
88 |
+
self.language_emb_norm = LayerNorm(attention_dim)
|
89 |
+
# self-attention module definition
|
90 |
+
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
91 |
+
encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)
|
92 |
+
|
93 |
+
# feed-forward module definition
|
94 |
+
positionwise_layer = MultiLayeredConv1d
|
95 |
+
positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,)
|
96 |
+
|
97 |
+
# convolution module definition
|
98 |
+
convolution_layer = ConvolutionModule
|
99 |
+
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
|
100 |
+
|
101 |
+
self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
102 |
+
positionwise_layer(*positionwise_layer_args),
|
103 |
+
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
104 |
+
convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate,
|
105 |
+
normalize_before, concat_after))
|
106 |
+
|
107 |
+
def forward(self,
|
108 |
+
xs,
|
109 |
+
masks,
|
110 |
+
utterance_embedding=None,
|
111 |
+
lang_ids=None):
|
112 |
+
"""
|
113 |
+
Encode input sequence.
|
114 |
+
Args:
|
115 |
+
utterance_embedding: embedding containing lots of conditioning signals
|
116 |
+
lang_ids: ids of the languages per sample in the batch
|
117 |
+
xs (torch.Tensor): Input tensor (#batch, time, idim).
|
118 |
+
masks (torch.Tensor): Mask tensor (#batch, time).
|
119 |
+
Returns:
|
120 |
+
torch.Tensor: Output tensor (#batch, time, attention_dim).
|
121 |
+
torch.Tensor: Mask tensor (#batch, time).
|
122 |
+
"""
|
123 |
+
|
124 |
+
if self.embed is not None:
|
125 |
+
xs = self.embed(xs)
|
126 |
+
xs = self.art_embed_norm(xs)
|
127 |
+
|
128 |
+
if lang_ids is not None:
|
129 |
+
lang_embs = self.language_embedding(lang_ids)
|
130 |
+
projected_lang_embs = self.language_embedding_projection(lang_embs).unsqueeze(-1).transpose(1, 2)
|
131 |
+
projected_lang_embs = self.language_emb_norm(projected_lang_embs)
|
132 |
+
xs = xs + projected_lang_embs # offset phoneme representation by language specific offset
|
133 |
+
|
134 |
+
xs = self.pos_enc(xs)
|
135 |
+
|
136 |
+
for encoder_index, encoder in enumerate(self.encoders):
|
137 |
+
if self.utt_embed:
|
138 |
+
if isinstance(xs, tuple):
|
139 |
+
x, pos_emb = xs[0], xs[1]
|
140 |
+
if self.conformer_type != "encoder":
|
141 |
+
x = integrate_with_utt_embed(hs=x, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration)
|
142 |
+
xs = (x, pos_emb)
|
143 |
+
else:
|
144 |
+
if self.conformer_type != "encoder":
|
145 |
+
xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration)
|
146 |
+
xs, masks = encoder(xs, masks)
|
147 |
+
|
148 |
+
if isinstance(xs, tuple):
|
149 |
+
xs = xs[0]
|
150 |
+
|
151 |
+
if self.use_output_norm and not (self.utt_embed and self.conformer_type == "encoder"):
|
152 |
+
xs = self.output_norm(xs)
|
153 |
+
|
154 |
+
if self.utt_embed and self.conformer_type == "encoder":
|
155 |
+
xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding,
|
156 |
+
projection=self.encoder_embedding_projection, embedding_training=self.use_conditional_layernorm_embedding_integration)
|
157 |
+
|
158 |
+
return xs, masks
|
Architectures/GeneralLayers/Convolution.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# Northwestern Polytechnical University (Pengcheng Guo)
|
3 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
4 |
+
# Adapted by Florian Lux 2021
|
5 |
+
|
6 |
+
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
class ConvolutionModule(nn.Module):
|
11 |
+
"""
|
12 |
+
ConvolutionModule in Conformer model.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
channels (int): The number of channels of conv layers.
|
16 |
+
kernel_size (int): Kernel size of conv layers.
|
17 |
+
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
|
21 |
+
super(ConvolutionModule, self).__init__()
|
22 |
+
# kernel_size should be an odd number for 'SAME' padding
|
23 |
+
assert (kernel_size - 1) % 2 == 0
|
24 |
+
|
25 |
+
self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, )
|
26 |
+
self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, )
|
27 |
+
self.norm = nn.BatchNorm1d(channels)
|
28 |
+
self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, )
|
29 |
+
self.activation = activation
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
"""
|
33 |
+
Compute convolution module.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
x (torch.Tensor): Input tensor (#batch, time, channels).
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
torch.Tensor: Output tensor (#batch, time, channels).
|
40 |
+
|
41 |
+
"""
|
42 |
+
# exchange the temporal dimension and the feature dimension
|
43 |
+
x = x.transpose(1, 2)
|
44 |
+
|
45 |
+
# GLU mechanism
|
46 |
+
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
47 |
+
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
48 |
+
|
49 |
+
# 1D Depthwise Conv
|
50 |
+
x = self.depthwise_conv(x)
|
51 |
+
x = self.activation(self.norm(x))
|
52 |
+
|
53 |
+
x = self.pointwise_conv2(x)
|
54 |
+
|
55 |
+
return x.transpose(1, 2)
|
Architectures/GeneralLayers/DurationPredictor.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 Tomoki Hayashi
|
2 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
3 |
+
# Adapted by Florian Lux 2021
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
|
9 |
+
from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
|
10 |
+
from Architectures.GeneralLayers.LayerNorm import LayerNorm
|
11 |
+
from Utility.utils import integrate_with_utt_embed
|
12 |
+
|
13 |
+
|
14 |
+
class DurationPredictor(torch.nn.Module):
|
15 |
+
"""
|
16 |
+
Duration predictor module.
|
17 |
+
|
18 |
+
This is a module of duration predictor described
|
19 |
+
in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
20 |
+
The duration predictor predicts a duration of each frame in log domain
|
21 |
+
from the hidden embeddings of encoder.
|
22 |
+
|
23 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
24 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
25 |
+
|
26 |
+
Note:
|
27 |
+
The calculation domain of outputs is different
|
28 |
+
between in `forward` and in `inference`. In `forward`,
|
29 |
+
the outputs are calculated in log domain but in `inference`,
|
30 |
+
those are calculated in linear domain.
|
31 |
+
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, idim,
|
35 |
+
n_layers=2,
|
36 |
+
n_chans=384,
|
37 |
+
kernel_size=3,
|
38 |
+
dropout_rate=0.1,
|
39 |
+
offset=1.0,
|
40 |
+
utt_embed_dim=None,
|
41 |
+
embedding_integration="AdaIN"):
|
42 |
+
"""
|
43 |
+
Initialize duration predictor module.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
idim (int): Input dimension.
|
47 |
+
n_layers (int, optional): Number of convolutional layers.
|
48 |
+
n_chans (int, optional): Number of channels of convolutional layers.
|
49 |
+
kernel_size (int, optional): Kernel size of convolutional layers.
|
50 |
+
dropout_rate (float, optional): Dropout rate.
|
51 |
+
offset (float, optional): Offset value to avoid nan in log domain.
|
52 |
+
|
53 |
+
"""
|
54 |
+
super(DurationPredictor, self).__init__()
|
55 |
+
self.offset = offset
|
56 |
+
self.conv = torch.nn.ModuleList()
|
57 |
+
self.dropouts = torch.nn.ModuleList()
|
58 |
+
self.norms = torch.nn.ModuleList()
|
59 |
+
self.embedding_projections = torch.nn.ModuleList()
|
60 |
+
self.utt_embed_dim = utt_embed_dim
|
61 |
+
self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]
|
62 |
+
|
63 |
+
for idx in range(n_layers):
|
64 |
+
if utt_embed_dim is not None:
|
65 |
+
if embedding_integration == "AdaIN":
|
66 |
+
self.embedding_projections += [AdaIN1d(style_dim=utt_embed_dim, num_features=idim)]
|
67 |
+
elif embedding_integration == "ConditionalLayerNorm":
|
68 |
+
self.embedding_projections += [ConditionalLayerNorm(speaker_embedding_dim=utt_embed_dim, hidden_dim=idim)]
|
69 |
+
else:
|
70 |
+
self.embedding_projections += [torch.nn.Linear(utt_embed_dim + idim, idim)]
|
71 |
+
else:
|
72 |
+
self.embedding_projections += [lambda x: x]
|
73 |
+
in_chans = idim if idx == 0 else n_chans
|
74 |
+
self.conv += [torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ),
|
75 |
+
torch.nn.ReLU())]
|
76 |
+
self.norms += [LayerNorm(n_chans, dim=1)]
|
77 |
+
self.dropouts += [torch.nn.Dropout(dropout_rate)]
|
78 |
+
|
79 |
+
self.linear = torch.nn.Linear(n_chans, 1)
|
80 |
+
|
81 |
+
def _forward(self, xs, x_masks=None, is_inference=False, utt_embed=None):
|
82 |
+
xs = xs.transpose(1, -1) # (B, idim, Tmax)
|
83 |
+
|
84 |
+
for f, c, d, p in zip(self.conv, self.norms, self.dropouts, self.embedding_projections):
|
85 |
+
xs = f(xs) # (B, C, Tmax)
|
86 |
+
if self.utt_embed_dim is not None:
|
87 |
+
xs = integrate_with_utt_embed(hs=xs.transpose(1, 2), utt_embeddings=utt_embed, projection=p, embedding_training=self.use_conditional_layernorm_embedding_integration).transpose(1, 2)
|
88 |
+
xs = c(xs)
|
89 |
+
xs = d(xs)
|
90 |
+
|
91 |
+
# NOTE: targets are transformed to log domain in the loss calculation, so this will learn to predict in the log space, which makes the value range easier to handle.
|
92 |
+
xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax)
|
93 |
+
|
94 |
+
if is_inference:
|
95 |
+
# NOTE: since we learned to predict in the log domain, we have to invert the log during inference.
|
96 |
+
xs = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value
|
97 |
+
else:
|
98 |
+
xs = xs.masked_fill(x_masks, 0.0)
|
99 |
+
|
100 |
+
return xs
|
101 |
+
|
102 |
+
def forward(self, xs, padding_mask=None, utt_embed=None):
|
103 |
+
"""
|
104 |
+
Calculate forward propagation.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
xs (Tensor): Batch of input sequences (B, Tmax, idim).
|
108 |
+
padding_mask (ByteTensor, optional):
|
109 |
+
Batch of masks indicating padded part (B, Tmax).
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
Tensor: Batch of predicted durations in log domain (B, Tmax).
|
113 |
+
|
114 |
+
"""
|
115 |
+
return self._forward(xs, padding_mask, False, utt_embed=utt_embed)
|
116 |
+
|
117 |
+
def inference(self, xs, padding_mask=None, utt_embed=None):
|
118 |
+
"""
|
119 |
+
Inference duration.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
xs (Tensor): Batch of input sequences (B, Tmax, idim).
|
123 |
+
padding_mask (ByteTensor, optional):
|
124 |
+
Batch of masks indicating padded part (B, Tmax).
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
LongTensor: Batch of predicted durations in linear domain (B, Tmax).
|
128 |
+
|
129 |
+
"""
|
130 |
+
return self._forward(xs, padding_mask, True, utt_embed=utt_embed)
|
131 |
+
|
132 |
+
|
133 |
+
class DurationPredictorLoss(torch.nn.Module):
|
134 |
+
"""
|
135 |
+
Loss function module for duration predictor.
|
136 |
+
|
137 |
+
The loss value is Calculated in log domain to make it Gaussian.
|
138 |
+
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(self, offset=1.0, reduction="mean"):
|
142 |
+
"""
|
143 |
+
Args:
|
144 |
+
offset (float, optional): Offset value to avoid nan in log domain.
|
145 |
+
reduction (str): Reduction type in loss calculation.
|
146 |
+
|
147 |
+
"""
|
148 |
+
super(DurationPredictorLoss, self).__init__()
|
149 |
+
self.criterion = torch.nn.MSELoss(reduction=reduction)
|
150 |
+
self.offset = offset
|
151 |
+
|
152 |
+
def forward(self, outputs, targets):
|
153 |
+
"""
|
154 |
+
Calculate forward propagation.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
outputs (Tensor): Batch of prediction durations in log domain (B, T)
|
158 |
+
targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
Tensor: Mean squared error loss value.
|
162 |
+
|
163 |
+
Note:
|
164 |
+
`outputs` is in log domain but `targets` is in linear domain.
|
165 |
+
|
166 |
+
"""
|
167 |
+
# NOTE: outputs is in log domain while targets in linear
|
168 |
+
targets = torch.log(targets.float() + self.offset)
|
169 |
+
loss = self.criterion(outputs, targets)
|
170 |
+
|
171 |
+
return loss
|
Architectures/GeneralLayers/EncoderLayer.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# Northwestern Polytechnical University (Pengcheng Guo)
|
3 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
4 |
+
# Adapted by Florian Lux 2021
|
5 |
+
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from Architectures.GeneralLayers.LayerNorm import LayerNorm
|
11 |
+
|
12 |
+
|
13 |
+
class EncoderLayer(nn.Module):
|
14 |
+
"""
|
15 |
+
Encoder layer module.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
size (int): Input dimension.
|
19 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
20 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
|
21 |
+
can be used as the argument.
|
22 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
23 |
+
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
24 |
+
can be used as the argument.
|
25 |
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
|
26 |
+
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
27 |
+
can be used as the argument.
|
28 |
+
conv_module (torch.nn.Module): Convolution module instance.
|
29 |
+
`ConvlutionModule` instance can be used as the argument.
|
30 |
+
dropout_rate (float): Dropout rate.
|
31 |
+
normalize_before (bool): Whether to use layer_norm before the first block.
|
32 |
+
concat_after (bool): Whether to concat attention layer's input and output.
|
33 |
+
if True, additional linear will be applied.
|
34 |
+
i.e. x -> x + linear(concat(x, att(x)))
|
35 |
+
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
36 |
+
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False, ):
|
40 |
+
super(EncoderLayer, self).__init__()
|
41 |
+
self.self_attn = self_attn
|
42 |
+
self.feed_forward = feed_forward
|
43 |
+
self.feed_forward_macaron = feed_forward_macaron
|
44 |
+
self.conv_module = conv_module
|
45 |
+
self.norm_ff = LayerNorm(size) # for the FNN module
|
46 |
+
self.norm_mha = LayerNorm(size) # for the MHA module
|
47 |
+
if feed_forward_macaron is not None:
|
48 |
+
self.norm_ff_macaron = LayerNorm(size)
|
49 |
+
self.ff_scale = 0.5
|
50 |
+
else:
|
51 |
+
self.ff_scale = 1.0
|
52 |
+
if self.conv_module is not None:
|
53 |
+
self.norm_conv = LayerNorm(size) # for the CNN module
|
54 |
+
self.norm_final = LayerNorm(size) # for the final output of the block
|
55 |
+
self.dropout = nn.Dropout(dropout_rate)
|
56 |
+
self.size = size
|
57 |
+
self.normalize_before = normalize_before
|
58 |
+
self.concat_after = concat_after
|
59 |
+
if self.concat_after:
|
60 |
+
self.concat_linear = nn.Linear(size + size, size)
|
61 |
+
|
62 |
+
def forward(self, x_input, mask, cache=None):
|
63 |
+
"""
|
64 |
+
Compute encoded features.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
|
68 |
+
- w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
|
69 |
+
- w/o pos emb: Tensor (#batch, time, size).
|
70 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
71 |
+
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
75 |
+
torch.Tensor: Mask tensor (#batch, time).
|
76 |
+
|
77 |
+
"""
|
78 |
+
if isinstance(x_input, tuple):
|
79 |
+
x, pos_emb = x_input[0], x_input[1]
|
80 |
+
else:
|
81 |
+
x, pos_emb = x_input, None
|
82 |
+
|
83 |
+
# whether to use macaron style
|
84 |
+
if self.feed_forward_macaron is not None:
|
85 |
+
residual = x
|
86 |
+
if self.normalize_before:
|
87 |
+
x = self.norm_ff_macaron(x)
|
88 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
|
89 |
+
if not self.normalize_before:
|
90 |
+
x = self.norm_ff_macaron(x)
|
91 |
+
|
92 |
+
# multi-headed self-attention module
|
93 |
+
residual = x
|
94 |
+
if self.normalize_before:
|
95 |
+
x = self.norm_mha(x)
|
96 |
+
|
97 |
+
if cache is None:
|
98 |
+
x_q = x
|
99 |
+
else:
|
100 |
+
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
101 |
+
x_q = x[:, -1:, :]
|
102 |
+
residual = residual[:, -1:, :]
|
103 |
+
mask = None if mask is None else mask[:, -1:, :]
|
104 |
+
|
105 |
+
if pos_emb is not None:
|
106 |
+
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
|
107 |
+
else:
|
108 |
+
x_att = self.self_attn(x_q, x, x, mask)
|
109 |
+
|
110 |
+
if self.concat_after:
|
111 |
+
x_concat = torch.cat((x, x_att), dim=-1)
|
112 |
+
x = residual + self.concat_linear(x_concat)
|
113 |
+
else:
|
114 |
+
x = residual + self.dropout(x_att)
|
115 |
+
if not self.normalize_before:
|
116 |
+
x = self.norm_mha(x)
|
117 |
+
|
118 |
+
# convolution module
|
119 |
+
if self.conv_module is not None:
|
120 |
+
residual = x
|
121 |
+
if self.normalize_before:
|
122 |
+
x = self.norm_conv(x)
|
123 |
+
x = residual + self.dropout(self.conv_module(x))
|
124 |
+
if not self.normalize_before:
|
125 |
+
x = self.norm_conv(x)
|
126 |
+
|
127 |
+
# feed forward module
|
128 |
+
residual = x
|
129 |
+
if self.normalize_before:
|
130 |
+
x = self.norm_ff(x)
|
131 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
132 |
+
if not self.normalize_before:
|
133 |
+
x = self.norm_ff(x)
|
134 |
+
|
135 |
+
if self.conv_module is not None:
|
136 |
+
x = self.norm_final(x)
|
137 |
+
|
138 |
+
if cache is not None:
|
139 |
+
x = torch.cat([cache, x], dim=1)
|
140 |
+
|
141 |
+
if pos_emb is not None:
|
142 |
+
return (x, pos_emb), mask
|
143 |
+
|
144 |
+
return x, mask
|
Architectures/GeneralLayers/LayerNorm.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Written by Shigeki Karita, 2019
|
2 |
+
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
# Adapted by Florian Lux, 2021
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class LayerNorm(torch.nn.LayerNorm):
|
9 |
+
"""
|
10 |
+
Layer normalization module.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
nout (int): Output dim size.
|
14 |
+
dim (int): Dimension to be normalized.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, nout, dim=-1, eps=1e-12):
|
18 |
+
"""
|
19 |
+
Construct an LayerNorm object.
|
20 |
+
"""
|
21 |
+
super(LayerNorm, self).__init__(nout, eps=eps)
|
22 |
+
self.dim = dim
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
"""
|
26 |
+
Apply layer normalization.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
x (torch.Tensor): Input tensor.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
torch.Tensor: Normalized tensor.
|
33 |
+
"""
|
34 |
+
if self.dim == -1:
|
35 |
+
return super(LayerNorm, self).forward(x)
|
36 |
+
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
Architectures/GeneralLayers/LengthRegulator.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 Tomoki Hayashi
|
2 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
3 |
+
# Adapted by Florian Lux 2021
|
4 |
+
|
5 |
+
from abc import ABC
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from Utility.utils import pad_list
|
10 |
+
|
11 |
+
|
12 |
+
class LengthRegulator(torch.nn.Module, ABC):
|
13 |
+
"""
|
14 |
+
Length regulator module for feed-forward Transformer.
|
15 |
+
|
16 |
+
This is a module of length regulator described in
|
17 |
+
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
18 |
+
The length regulator expands char or
|
19 |
+
phoneme-level embedding features to frame-level by repeating each
|
20 |
+
feature based on the corresponding predicted durations.
|
21 |
+
|
22 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
23 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
24 |
+
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, pad_value=0.0):
|
28 |
+
"""
|
29 |
+
Initialize length regulator module.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
pad_value (float, optional): Value used for padding.
|
33 |
+
"""
|
34 |
+
super(LengthRegulator, self).__init__()
|
35 |
+
self.pad_value = pad_value
|
36 |
+
|
37 |
+
def forward(self, xs, ds, alpha=1.0):
|
38 |
+
"""
|
39 |
+
Calculate forward propagation.
|
40 |
+
Args:
|
41 |
+
xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
|
42 |
+
ds (LongTensor): Batch of durations of each frame (B, T).
|
43 |
+
alpha (float, optional): Alpha value to control speed of speech.
|
44 |
+
Returns:
|
45 |
+
Tensor: replicated input tensor based on durations (B, T*, D).
|
46 |
+
"""
|
47 |
+
|
48 |
+
if alpha != 1.0:
|
49 |
+
assert alpha > 0
|
50 |
+
ds = torch.round(ds.float() * alpha).long()
|
51 |
+
|
52 |
+
if ds.sum() == 0:
|
53 |
+
ds[ds.sum(dim=1).eq(0)] = 1
|
54 |
+
|
55 |
+
return pad_list([self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)], self.pad_value)
|
56 |
+
|
57 |
+
def _repeat_one_sequence(self, x, d):
|
58 |
+
"""
|
59 |
+
Repeat each frame according to duration
|
60 |
+
"""
|
61 |
+
return torch.repeat_interleave(x, d, dim=0)
|
Architectures/GeneralLayers/MultiLayeredConv1d.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 Tomoki Hayashi
|
2 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
3 |
+
# Adapted by Florian Lux 2021
|
4 |
+
|
5 |
+
"""
|
6 |
+
Layer modules for FFT block in FastSpeech (Feed-forward Transformer).
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class MultiLayeredConv1d(torch.nn.Module):
|
13 |
+
"""
|
14 |
+
Multi-layered conv1d for Transformer block.
|
15 |
+
|
16 |
+
This is a module of multi-layered conv1d designed
|
17 |
+
to replace positionwise feed-forward network
|
18 |
+
in Transformer block, which is introduced in
|
19 |
+
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
20 |
+
|
21 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
22 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
26 |
+
"""
|
27 |
+
Initialize MultiLayeredConv1d module.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
in_chans (int): Number of input channels.
|
31 |
+
hidden_chans (int): Number of hidden channels.
|
32 |
+
kernel_size (int): Kernel size of conv1d.
|
33 |
+
dropout_rate (float): Dropout rate.
|
34 |
+
"""
|
35 |
+
super(MultiLayeredConv1d, self).__init__()
|
36 |
+
self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
|
37 |
+
self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
|
38 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
"""
|
42 |
+
Calculate forward propagation.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
x (torch.Tensor): Batch of input tensors (B, T, in_chans).
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
torch.Tensor: Batch of output tensors (B, T, hidden_chans).
|
49 |
+
"""
|
50 |
+
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
51 |
+
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
|
52 |
+
|
53 |
+
|
54 |
+
class Conv1dLinear(torch.nn.Module):
|
55 |
+
"""
|
56 |
+
Conv1D + Linear for Transformer block.
|
57 |
+
|
58 |
+
A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
62 |
+
"""
|
63 |
+
Initialize Conv1dLinear module.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
in_chans (int): Number of input channels.
|
67 |
+
hidden_chans (int): Number of hidden channels.
|
68 |
+
kernel_size (int): Kernel size of conv1d.
|
69 |
+
dropout_rate (float): Dropout rate.
|
70 |
+
"""
|
71 |
+
super(Conv1dLinear, self).__init__()
|
72 |
+
self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
|
73 |
+
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
|
74 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
"""
|
78 |
+
Calculate forward propagation.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
x (torch.Tensor): Batch of input tensors (B, T, in_chans).
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
torch.Tensor: Batch of output tensors (B, T, hidden_chans).
|
85 |
+
"""
|
86 |
+
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
87 |
+
return self.w_2(self.dropout(x))
|
Architectures/GeneralLayers/MultiSequential.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Written by Shigeki Karita, 2019
|
2 |
+
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
# Adapted by Florian Lux, 2021
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class MultiSequential(torch.nn.Sequential):
|
9 |
+
"""
|
10 |
+
Multi-input multi-output torch.nn.Sequential.
|
11 |
+
"""
|
12 |
+
|
13 |
+
def forward(self, *args):
|
14 |
+
"""
|
15 |
+
Repeat.
|
16 |
+
"""
|
17 |
+
for m in self:
|
18 |
+
args = m(*args)
|
19 |
+
return args
|
20 |
+
|
21 |
+
|
22 |
+
def repeat(N, fn):
|
23 |
+
"""
|
24 |
+
Repeat module N times.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
N (int): Number of repeat time.
|
28 |
+
fn (Callable): Function to generate module.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
MultiSequential: Repeated model instance.
|
32 |
+
"""
|
33 |
+
return MultiSequential(*[fn(n) for n in range(N)])
|
Architectures/GeneralLayers/PositionalEncoding.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Taken from ESPNet
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
class PositionalEncoding(torch.nn.Module):
|
11 |
+
"""
|
12 |
+
Positional encoding.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
d_model (int): Embedding dimension.
|
16 |
+
dropout_rate (float): Dropout rate.
|
17 |
+
max_len (int): Maximum input length.
|
18 |
+
reverse (bool): Whether to reverse the input position.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
22 |
+
"""
|
23 |
+
Construct an PositionalEncoding object.
|
24 |
+
"""
|
25 |
+
super(PositionalEncoding, self).__init__()
|
26 |
+
self.d_model = d_model
|
27 |
+
self.reverse = reverse
|
28 |
+
self.xscale = math.sqrt(self.d_model)
|
29 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
30 |
+
self.pe = None
|
31 |
+
self.extend_pe(torch.tensor(0.0, device=d_model.device).expand(1, max_len))
|
32 |
+
|
33 |
+
def extend_pe(self, x):
|
34 |
+
"""
|
35 |
+
Reset the positional encodings.
|
36 |
+
"""
|
37 |
+
if self.pe is not None:
|
38 |
+
if self.pe.size(1) >= x.size(1):
|
39 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
40 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
41 |
+
return
|
42 |
+
pe = torch.zeros(x.size(1), self.d_model)
|
43 |
+
if self.reverse:
|
44 |
+
position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
45 |
+
else:
|
46 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
47 |
+
div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model))
|
48 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
49 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
50 |
+
pe = pe.unsqueeze(0)
|
51 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
"""
|
55 |
+
Add positional encoding.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
62 |
+
"""
|
63 |
+
self.extend_pe(x)
|
64 |
+
x = x * self.xscale + self.pe[:, : x.size(1)]
|
65 |
+
return self.dropout(x)
|
66 |
+
|
67 |
+
|
68 |
+
class RelPositionalEncoding(torch.nn.Module):
|
69 |
+
"""
|
70 |
+
Relative positional encoding module (new implementation).
|
71 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
72 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
73 |
+
Args:
|
74 |
+
d_model (int): Embedding dimension.
|
75 |
+
dropout_rate (float): Dropout rate.
|
76 |
+
max_len (int): Maximum input length.
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
80 |
+
"""
|
81 |
+
Construct an PositionalEncoding object.
|
82 |
+
"""
|
83 |
+
super(RelPositionalEncoding, self).__init__()
|
84 |
+
self.d_model = d_model
|
85 |
+
self.xscale = math.sqrt(self.d_model)
|
86 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
87 |
+
self.pe = None
|
88 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
89 |
+
|
90 |
+
def extend_pe(self, x):
|
91 |
+
"""Reset the positional encodings."""
|
92 |
+
if self.pe is not None:
|
93 |
+
# self.pe contains both positive and negative parts
|
94 |
+
# the length of self.pe is 2 * input_len - 1
|
95 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
96 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
97 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
98 |
+
return
|
99 |
+
# Suppose `i` means to the position of query vecotr and `j` means the
|
100 |
+
# position of key vector. We use position relative positions when keys
|
101 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
102 |
+
pe_positive = torch.zeros(x.size(1), self.d_model, device=x.device)
|
103 |
+
pe_negative = torch.zeros(x.size(1), self.d_model, device=x.device)
|
104 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32, device=x.device).unsqueeze(1)
|
105 |
+
div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32, device=x.device) * -(math.log(10000.0) / self.d_model))
|
106 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
107 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
108 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
109 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
110 |
+
|
111 |
+
# Reserve the order of positive indices and concat both positive and
|
112 |
+
# negative indices. This is used to support the shifting trick
|
113 |
+
# as in https://arxiv.org/abs/1901.02860
|
114 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
115 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
116 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
117 |
+
self.pe = pe.to(dtype=x.dtype)
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
"""
|
121 |
+
Add positional encoding.
|
122 |
+
Args:
|
123 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
124 |
+
Returns:
|
125 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
126 |
+
"""
|
127 |
+
self.extend_pe(x)
|
128 |
+
x = x * self.xscale
|
129 |
+
pos_emb = self.pe[:, self.pe.size(1) // 2 - x.size(1) + 1: self.pe.size(1) // 2 + x.size(1), ]
|
130 |
+
return self.dropout(x), self.dropout(pos_emb)
|
131 |
+
|
132 |
+
|
133 |
+
class ScaledPositionalEncoding(PositionalEncoding):
|
134 |
+
"""
|
135 |
+
Scaled positional encoding module.
|
136 |
+
|
137 |
+
See Sec. 3.2 https://arxiv.org/abs/1809.08895
|
138 |
+
|
139 |
+
Args:
|
140 |
+
d_model (int): Embedding dimension.
|
141 |
+
dropout_rate (float): Dropout rate.
|
142 |
+
max_len (int): Maximum input length.
|
143 |
+
|
144 |
+
"""
|
145 |
+
|
146 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
147 |
+
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
|
148 |
+
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
149 |
+
|
150 |
+
def reset_parameters(self):
|
151 |
+
self.alpha.data = torch.tensor(1.0)
|
152 |
+
|
153 |
+
def forward(self, x):
|
154 |
+
"""
|
155 |
+
Add positional encoding.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
162 |
+
|
163 |
+
"""
|
164 |
+
self.extend_pe(x)
|
165 |
+
x = x + self.alpha * self.pe[:, : x.size(1)]
|
166 |
+
return self.dropout(x)
|
Architectures/GeneralLayers/PositionwiseFeedForward.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Written by Shigeki Karita, 2019
|
2 |
+
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
# Adapted by Florian Lux, 2021
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
10 |
+
"""
|
11 |
+
Args:
|
12 |
+
idim (int): Input dimenstion.
|
13 |
+
hidden_units (int): The number of hidden units.
|
14 |
+
dropout_rate (float): Dropout rate.
|
15 |
+
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
|
19 |
+
super(PositionwiseFeedForward, self).__init__()
|
20 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
21 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
22 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
23 |
+
self.activation = activation
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
return self.w_2(self.dropout(self.activation(self.w_1(x))))
|
Architectures/GeneralLayers/README.md
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
This directory contains a collection of layers that are used both during training time and during inference time. Large
|
2 |
+
portions of these layers are either directly taken from ESPnet or adaptations of such.
|
Architectures/GeneralLayers/ResidualBlock.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
"""
|
4 |
+
References:
|
5 |
+
- https://github.com/jik876/hifi-gan
|
6 |
+
- https://github.com/kan-bayashi/ParallelWaveGAN
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class Conv1d(torch.nn.Conv1d):
|
13 |
+
"""
|
14 |
+
Conv1d module with customized initialization.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, *args, **kwargs):
|
18 |
+
super(Conv1d, self).__init__(*args, **kwargs)
|
19 |
+
|
20 |
+
def reset_parameters(self):
|
21 |
+
torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
|
22 |
+
if self.bias is not None:
|
23 |
+
torch.nn.init.constant_(self.bias, 0.0)
|
24 |
+
|
25 |
+
|
26 |
+
class Conv1d1x1(Conv1d):
|
27 |
+
"""
|
28 |
+
1x1 Conv1d with customized initialization.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, in_channels, out_channels, bias):
|
32 |
+
super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias)
|
33 |
+
|
34 |
+
|
35 |
+
class HiFiGANResidualBlock(torch.nn.Module):
|
36 |
+
"""Residual block module in HiFiGAN."""
|
37 |
+
|
38 |
+
def __init__(self,
|
39 |
+
kernel_size=3,
|
40 |
+
channels=512,
|
41 |
+
dilations=(1, 3, 5),
|
42 |
+
bias=True,
|
43 |
+
use_additional_convs=True,
|
44 |
+
nonlinear_activation="LeakyReLU",
|
45 |
+
nonlinear_activation_params={"negative_slope": 0.1}, ):
|
46 |
+
"""
|
47 |
+
Initialize HiFiGANResidualBlock module.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
kernel_size (int): Kernel size of dilation convolution layer.
|
51 |
+
channels (int): Number of channels for convolution layer.
|
52 |
+
dilations (List[int]): List of dilation factors.
|
53 |
+
use_additional_convs (bool): Whether to use additional convolution layers.
|
54 |
+
bias (bool): Whether to add bias parameter in convolution layers.
|
55 |
+
nonlinear_activation (str): Activation function module name.
|
56 |
+
nonlinear_activation_params (dict): Hyperparameters for activation function.
|
57 |
+
"""
|
58 |
+
super().__init__()
|
59 |
+
self.use_additional_convs = use_additional_convs
|
60 |
+
self.convs1 = torch.nn.ModuleList()
|
61 |
+
if use_additional_convs:
|
62 |
+
self.convs2 = torch.nn.ModuleList()
|
63 |
+
assert kernel_size % 2 == 1, "Kernel size must be odd number."
|
64 |
+
for dilation in dilations:
|
65 |
+
self.convs1 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
|
66 |
+
torch.nn.Conv1d(channels,
|
67 |
+
channels,
|
68 |
+
kernel_size,
|
69 |
+
1,
|
70 |
+
dilation=dilation,
|
71 |
+
bias=bias,
|
72 |
+
padding=(kernel_size - 1) // 2 * dilation, ), )]
|
73 |
+
if use_additional_convs:
|
74 |
+
self.convs2 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
|
75 |
+
torch.nn.Conv1d(channels,
|
76 |
+
channels,
|
77 |
+
kernel_size,
|
78 |
+
1,
|
79 |
+
dilation=1,
|
80 |
+
bias=bias,
|
81 |
+
padding=(kernel_size - 1) // 2, ), )]
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
"""
|
85 |
+
Calculate forward propagation.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
x (Tensor): Input tensor (B, channels, T).
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
Tensor: Output tensor (B, channels, T).
|
92 |
+
"""
|
93 |
+
for idx in range(len(self.convs1)):
|
94 |
+
xt = self.convs1[idx](x)
|
95 |
+
if self.use_additional_convs:
|
96 |
+
xt = self.convs2[idx](xt)
|
97 |
+
x = xt + x
|
98 |
+
return x
|
Architectures/GeneralLayers/ResidualStack.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 Tomoki Hayashi
|
2 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
3 |
+
# Adapted by Florian Lux 2021
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class ResidualStack(torch.nn.Module):
|
10 |
+
|
11 |
+
def __init__(self, kernel_size=3, channels=32, dilation=1, bias=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2},
|
12 |
+
pad="ReflectionPad1d", pad_params={}, ):
|
13 |
+
"""
|
14 |
+
Initialize ResidualStack module.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
kernel_size (int): Kernel size of dilation convolution layer.
|
18 |
+
channels (int): Number of channels of convolution layers.
|
19 |
+
dilation (int): Dilation factor.
|
20 |
+
bias (bool): Whether to add bias parameter in convolution layers.
|
21 |
+
nonlinear_activation (str): Activation function module name.
|
22 |
+
nonlinear_activation_params (dict): Hyperparameters for activation function.
|
23 |
+
pad (str): Padding function module name before dilated convolution layer.
|
24 |
+
pad_params (dict): Hyperparameters for padding function.
|
25 |
+
|
26 |
+
"""
|
27 |
+
super(ResidualStack, self).__init__()
|
28 |
+
|
29 |
+
# defile residual stack part
|
30 |
+
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
31 |
+
self.stack = torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
|
32 |
+
getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
|
33 |
+
torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
|
34 |
+
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
|
35 |
+
torch.nn.Conv1d(channels, channels, 1, bias=bias), )
|
36 |
+
|
37 |
+
# defile extra layer for skip connection
|
38 |
+
self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)
|
39 |
+
|
40 |
+
def forward(self, c):
|
41 |
+
"""
|
42 |
+
Calculate forward propagation.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
c (Tensor): Input tensor (B, channels, T).
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
Tensor: Output tensor (B, chennels, T).
|
49 |
+
|
50 |
+
"""
|
51 |
+
return self.stack(c) + self.skip_layer(c)
|
Architectures/GeneralLayers/STFT.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Taken from ESPNet
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.functional import stft as torch_stft
|
7 |
+
from torch_complex.tensor import ComplexTensor
|
8 |
+
|
9 |
+
from Utility.utils import make_pad_mask
|
10 |
+
|
11 |
+
|
12 |
+
class STFT(torch.nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, n_fft=512,
|
15 |
+
win_length=None,
|
16 |
+
hop_length=128,
|
17 |
+
window="hann",
|
18 |
+
center=True,
|
19 |
+
normalized=False,
|
20 |
+
onesided=True):
|
21 |
+
super().__init__()
|
22 |
+
self.n_fft = n_fft
|
23 |
+
if win_length is None:
|
24 |
+
self.win_length = n_fft
|
25 |
+
else:
|
26 |
+
self.win_length = win_length
|
27 |
+
self.hop_length = hop_length
|
28 |
+
self.center = center
|
29 |
+
self.normalized = normalized
|
30 |
+
self.onesided = onesided
|
31 |
+
self.window = window
|
32 |
+
|
33 |
+
def extra_repr(self):
|
34 |
+
return (f"n_fft={self.n_fft}, "
|
35 |
+
f"win_length={self.win_length}, "
|
36 |
+
f"hop_length={self.hop_length}, "
|
37 |
+
f"center={self.center}, "
|
38 |
+
f"normalized={self.normalized}, "
|
39 |
+
f"onesided={self.onesided}")
|
40 |
+
|
41 |
+
def forward(self, input_wave, ilens=None):
|
42 |
+
"""
|
43 |
+
STFT forward function.
|
44 |
+
Args:
|
45 |
+
input_wave: (Batch, Nsamples) or (Batch, Nsample, Channels)
|
46 |
+
ilens: (Batch)
|
47 |
+
Returns:
|
48 |
+
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
|
49 |
+
"""
|
50 |
+
bs = input_wave.size(0)
|
51 |
+
|
52 |
+
if input_wave.dim() == 3:
|
53 |
+
multi_channel = True
|
54 |
+
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
|
55 |
+
input_wave = input_wave.transpose(1, 2).reshape(-1, input_wave.size(1))
|
56 |
+
else:
|
57 |
+
multi_channel = False
|
58 |
+
|
59 |
+
# output: (Batch, Freq, Frames, 2=real_imag)
|
60 |
+
# or (Batch, Channel, Freq, Frames, 2=real_imag)
|
61 |
+
if self.window is not None:
|
62 |
+
window_func = getattr(torch, f"{self.window}_window")
|
63 |
+
window = window_func(self.win_length, dtype=input_wave.dtype, device=input_wave.device)
|
64 |
+
else:
|
65 |
+
window = None
|
66 |
+
|
67 |
+
complex_output = torch_stft(input=input_wave,
|
68 |
+
n_fft=self.n_fft,
|
69 |
+
win_length=self.win_length,
|
70 |
+
hop_length=self.hop_length,
|
71 |
+
center=self.center,
|
72 |
+
window=window,
|
73 |
+
normalized=self.normalized,
|
74 |
+
onesided=self.onesided,
|
75 |
+
return_complex=True)
|
76 |
+
output = torch.view_as_real(complex_output)
|
77 |
+
# output: (Batch, Freq, Frames, 2=real_imag)
|
78 |
+
# -> (Batch, Frames, Freq, 2=real_imag)
|
79 |
+
output = output.transpose(1, 2)
|
80 |
+
if multi_channel:
|
81 |
+
# output: (Batch * Channel, Frames, Freq, 2=real_imag)
|
82 |
+
# -> (Batch, Frame, Channel, Freq, 2=real_imag)
|
83 |
+
output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)
|
84 |
+
|
85 |
+
if ilens is not None:
|
86 |
+
if self.center:
|
87 |
+
pad = self.win_length // 2
|
88 |
+
ilens = ilens + 2 * pad
|
89 |
+
|
90 |
+
olens = torch.div((ilens - self.win_length), self.hop_length, rounding_mode='trunc') + 1
|
91 |
+
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
|
92 |
+
else:
|
93 |
+
olens = None
|
94 |
+
|
95 |
+
return output, olens
|
96 |
+
|
97 |
+
def inverse(self, input, ilens=None):
|
98 |
+
"""
|
99 |
+
Inverse STFT.
|
100 |
+
Args:
|
101 |
+
input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
|
102 |
+
ilens: (batch,)
|
103 |
+
Returns:
|
104 |
+
wavs: (batch, samples)
|
105 |
+
ilens: (batch,)
|
106 |
+
"""
|
107 |
+
istft = torch.functional.istft
|
108 |
+
|
109 |
+
if self.window is not None:
|
110 |
+
window_func = getattr(torch, f"{self.window}_window")
|
111 |
+
window = window_func(self.win_length, dtype=input.dtype, device=input.device)
|
112 |
+
else:
|
113 |
+
window = None
|
114 |
+
|
115 |
+
if isinstance(input, ComplexTensor):
|
116 |
+
input = torch.stack([input.real, input.imag], dim=-1)
|
117 |
+
assert input.shape[-1] == 2
|
118 |
+
input = input.transpose(1, 2)
|
119 |
+
|
120 |
+
wavs = istft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=window, center=self.center,
|
121 |
+
normalized=self.normalized, onesided=self.onesided, length=ilens.max() if ilens is not None else ilens)
|
122 |
+
|
123 |
+
return wavs, ilens
|
Architectures/GeneralLayers/Swish.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# Northwestern Polytechnical University (Pengcheng Guo)
|
3 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
4 |
+
# Adapted by Florian Lux 2021
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class Swish(torch.nn.Module):
|
10 |
+
"""
|
11 |
+
Construct a Swish activation function for Conformer.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
"""
|
16 |
+
Return Swish activation function.
|
17 |
+
"""
|
18 |
+
return x * torch.sigmoid(x)
|
Architectures/GeneralLayers/VariancePredictor.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 Tomoki Hayashi
|
2 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
3 |
+
# Adapted by Florian Lux 2023
|
4 |
+
|
5 |
+
from abc import ABC
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
|
10 |
+
from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
|
11 |
+
from Architectures.GeneralLayers.LayerNorm import LayerNorm
|
12 |
+
from Utility.utils import integrate_with_utt_embed
|
13 |
+
|
14 |
+
|
15 |
+
class VariancePredictor(torch.nn.Module, ABC):
|
16 |
+
"""
|
17 |
+
Variance predictor module.
|
18 |
+
|
19 |
+
This is a module of variance predictor described in `FastSpeech 2:
|
20 |
+
Fast and High-Quality End-to-End Text to Speech`_.
|
21 |
+
|
22 |
+
.. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
|
23 |
+
https://arxiv.org/abs/2006.04558
|
24 |
+
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self,
|
28 |
+
idim,
|
29 |
+
n_layers=2,
|
30 |
+
n_chans=384,
|
31 |
+
kernel_size=3,
|
32 |
+
bias=True,
|
33 |
+
dropout_rate=0.5,
|
34 |
+
utt_embed_dim=None,
|
35 |
+
embedding_integration="AdaIN"):
|
36 |
+
"""
|
37 |
+
Initialize duration predictor module.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
idim (int): Input dimension.
|
41 |
+
n_layers (int, optional): Number of convolutional layers.
|
42 |
+
n_chans (int, optional): Number of channels of convolutional layers.
|
43 |
+
kernel_size (int, optional): Kernel size of convolutional layers.
|
44 |
+
dropout_rate (float, optional): Dropout rate.
|
45 |
+
"""
|
46 |
+
super().__init__()
|
47 |
+
self.conv = torch.nn.ModuleList()
|
48 |
+
self.dropouts = torch.nn.ModuleList()
|
49 |
+
self.norms = torch.nn.ModuleList()
|
50 |
+
self.embedding_projections = torch.nn.ModuleList()
|
51 |
+
self.utt_embed_dim = utt_embed_dim
|
52 |
+
self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]
|
53 |
+
|
54 |
+
for idx in range(n_layers):
|
55 |
+
if utt_embed_dim is not None:
|
56 |
+
if embedding_integration == "AdaIN":
|
57 |
+
self.embedding_projections += [AdaIN1d(style_dim=utt_embed_dim, num_features=idim)]
|
58 |
+
elif embedding_integration == "ConditionalLayerNorm":
|
59 |
+
self.embedding_projections += [ConditionalLayerNorm(speaker_embedding_dim=utt_embed_dim, hidden_dim=idim)]
|
60 |
+
else:
|
61 |
+
self.embedding_projections += [torch.nn.Linear(utt_embed_dim + idim, idim)]
|
62 |
+
else:
|
63 |
+
self.embedding_projections += [lambda x: x]
|
64 |
+
in_chans = idim if idx == 0 else n_chans
|
65 |
+
self.conv += [torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias, ),
|
66 |
+
torch.nn.ReLU())]
|
67 |
+
self.norms += [LayerNorm(n_chans, dim=1)]
|
68 |
+
self.dropouts += [torch.nn.Dropout(dropout_rate)]
|
69 |
+
|
70 |
+
self.linear = torch.nn.Linear(n_chans, 1)
|
71 |
+
|
72 |
+
def forward(self, xs, padding_mask=None, utt_embed=None):
|
73 |
+
"""
|
74 |
+
Calculate forward propagation.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
xs (Tensor): Batch of input sequences (B, Tmax, idim).
|
78 |
+
padding_mask (ByteTensor, optional):
|
79 |
+
Batch of masks indicating padded part (B, Tmax).
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
Tensor: Batch of predicted sequences (B, Tmax, 1).
|
83 |
+
"""
|
84 |
+
xs = xs.transpose(1, -1) # (B, idim, Tmax)
|
85 |
+
|
86 |
+
for f, c, d, p in zip(self.conv, self.norms, self.dropouts, self.embedding_projections):
|
87 |
+
xs = f(xs) # (B, C, Tmax)
|
88 |
+
if self.utt_embed_dim is not None:
|
89 |
+
xs = integrate_with_utt_embed(hs=xs.transpose(1, 2), utt_embeddings=utt_embed, projection=p, embedding_training=self.use_conditional_layernorm_embedding_integration).transpose(1, 2)
|
90 |
+
xs = c(xs)
|
91 |
+
xs = d(xs)
|
92 |
+
|
93 |
+
xs = self.linear(xs.transpose(1, 2)) # (B, Tmax, 1)
|
94 |
+
|
95 |
+
if padding_mask is not None:
|
96 |
+
xs = xs.masked_fill(padding_mask, 0.0)
|
97 |
+
|
98 |
+
return xs
|
Architectures/GeneralLayers/__init__.py
ADDED
File without changes
|
Architectures/README.md
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
This directory contains all the models that are used in this toolkit for various tasks. The models' directories contain their
|
2 |
+
feature extractors, their datasets, their architectures, and their train loops.
|
Architectures/ToucanTTS/CodecDiscriminator.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
def weights_init_D(m):
|
6 |
+
classname = m.__class__.__name__
|
7 |
+
if classname.find('Conv') != -1:
|
8 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
|
9 |
+
elif classname.find('BatchNorm') != -1:
|
10 |
+
nn.init.constant_(m.weight, 1)
|
11 |
+
nn.init.constant_(m.bias, 0)
|
12 |
+
|
13 |
+
|
14 |
+
class SpectrogramDiscriminator(torch.nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super().__init__()
|
17 |
+
self.D = DiscriminatorNet()
|
18 |
+
self.D.apply(weights_init_D)
|
19 |
+
|
20 |
+
def _generator_feedback(self, data_generated, data_real):
|
21 |
+
for p in self.D.parameters():
|
22 |
+
p.requires_grad = False # freeze critic
|
23 |
+
|
24 |
+
score_fake, fmap_fake = self.D(data_generated)
|
25 |
+
_, fmap_real = self.D(data_real)
|
26 |
+
|
27 |
+
feature_matching_loss = 0.0
|
28 |
+
for feat_fake, feat_real in zip(fmap_fake, fmap_real):
|
29 |
+
feature_matching_loss += nn.functional.l1_loss(feat_fake, feat_real.detach())
|
30 |
+
|
31 |
+
discr_loss = nn.functional.mse_loss(input=score_fake, target=torch.ones(score_fake.shape, device=score_fake.device), reduction="mean")
|
32 |
+
|
33 |
+
return feature_matching_loss + discr_loss
|
34 |
+
|
35 |
+
def _discriminator_feature_matching(self, data_generated, data_real):
|
36 |
+
for p in self.D.parameters():
|
37 |
+
p.requires_grad = True # unfreeze critic
|
38 |
+
self.D.train()
|
39 |
+
|
40 |
+
score_fake, _ = self.D(data_generated)
|
41 |
+
score_real, _ = self.D(data_real)
|
42 |
+
|
43 |
+
discr_loss = 0.0
|
44 |
+
discr_loss = discr_loss + nn.functional.mse_loss(input=score_fake, target=torch.zeros(score_fake.shape, device=score_fake.device), reduction="mean")
|
45 |
+
discr_loss = discr_loss + nn.functional.mse_loss(input=score_real, target=torch.ones(score_real.shape, device=score_real.device), reduction="mean")
|
46 |
+
|
47 |
+
return discr_loss
|
48 |
+
|
49 |
+
def calc_discriminator_loss(self, data_generated, data_real):
|
50 |
+
return self._discriminator_feature_matching(data_generated.detach(), data_real)
|
51 |
+
|
52 |
+
def calc_generator_feedback(self, data_generated, data_real):
|
53 |
+
return self._generator_feedback(data_generated, data_real)
|
54 |
+
|
55 |
+
|
56 |
+
class DiscriminatorNet(nn.Module):
|
57 |
+
def __init__(self):
|
58 |
+
super().__init__()
|
59 |
+
self.filters = nn.ModuleList([
|
60 |
+
nn.utils.weight_norm(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
61 |
+
nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
62 |
+
nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
63 |
+
nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
64 |
+
nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
|
65 |
+
])
|
66 |
+
|
67 |
+
self.out = nn.utils.weight_norm(nn.Conv2d(32, 1, 3, 1, 1))
|
68 |
+
|
69 |
+
self.fc = nn.Linear(900, 1) # this needs to be changed everytime the window length is changes. It would be nice if this could be done dynamically.
|
70 |
+
|
71 |
+
def forward(self, y):
|
72 |
+
feature_maps = list()
|
73 |
+
feature_maps.append(y)
|
74 |
+
for d in self.filters:
|
75 |
+
y = d(y)
|
76 |
+
feature_maps.append(y)
|
77 |
+
y = nn.functional.leaky_relu(y, 0.1)
|
78 |
+
y = self.out(y)
|
79 |
+
feature_maps.append(y)
|
80 |
+
y = torch.flatten(y, 1, -1)
|
81 |
+
y = self.fc(y)
|
82 |
+
|
83 |
+
return y, feature_maps
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == '__main__':
|
87 |
+
d = SpectrogramDiscriminator()
|
88 |
+
fake = torch.randn([2, 100, 72]) # [Batch, Sequence Length, Spectrogram Buckets]
|
89 |
+
real = torch.randn([2, 100, 72]) # [Batch, Sequence Length, Spectrogram Buckets]
|
90 |
+
|
91 |
+
critic_loss = d.calc_discriminator_loss((fake.unsqueeze(1)), real.unsqueeze(1))
|
92 |
+
generator_loss = d.calc_generator_feedback(fake.unsqueeze(1), real.unsqueeze(1))
|
93 |
+
print(critic_loss)
|
94 |
+
print(generator_loss)
|
Architectures/ToucanTTS/CodecRefinementTransformer.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from Architectures.GeneralLayers.Conformer import Conformer
|
4 |
+
|
5 |
+
|
6 |
+
class CodecRefinementTransformer(torch.nn.Module):
|
7 |
+
|
8 |
+
def __init__(self,
|
9 |
+
attention_dimension=128,
|
10 |
+
num_codebooks=4,
|
11 |
+
codebook_size=1024,
|
12 |
+
backtranslation_dim=8,
|
13 |
+
attention_heads=4,
|
14 |
+
positionwise_conv_kernel_size=1,
|
15 |
+
use_macaron_style_in_conformer=True,
|
16 |
+
use_cnn_in_conformer=False, # for now, we try using just a regular transformer
|
17 |
+
decoder_layers=6,
|
18 |
+
decoder_units=1280,
|
19 |
+
decoder_concat_after=False,
|
20 |
+
conformer_decoder_kernel_size=31,
|
21 |
+
decoder_normalize_before=True,
|
22 |
+
transformer_dec_dropout_rate=0.2,
|
23 |
+
transformer_dec_positional_dropout_rate=0.1,
|
24 |
+
transformer_dec_attn_dropout_rate=0.1,
|
25 |
+
utt_embed_dim=512,
|
26 |
+
use_conditional_layernorm_embedding_integration=False,
|
27 |
+
):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.reconstruction_transformer = Conformer(
|
31 |
+
conformer_type="decoder",
|
32 |
+
attention_dim=num_codebooks * backtranslation_dim,
|
33 |
+
attention_heads=attention_heads,
|
34 |
+
linear_units=decoder_units,
|
35 |
+
num_blocks=decoder_layers,
|
36 |
+
input_layer=None,
|
37 |
+
dropout_rate=transformer_dec_dropout_rate,
|
38 |
+
positional_dropout_rate=transformer_dec_positional_dropout_rate,
|
39 |
+
attention_dropout_rate=transformer_dec_attn_dropout_rate,
|
40 |
+
normalize_before=decoder_normalize_before,
|
41 |
+
concat_after=decoder_concat_after,
|
42 |
+
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
|
43 |
+
macaron_style=use_macaron_style_in_conformer,
|
44 |
+
use_cnn_module=use_cnn_in_conformer,
|
45 |
+
cnn_module_kernel=conformer_decoder_kernel_size,
|
46 |
+
use_output_norm=False,
|
47 |
+
utt_embed=utt_embed_dim,
|
48 |
+
use_conditional_layernorm_embedding_integration=use_conditional_layernorm_embedding_integration
|
49 |
+
)
|
50 |
+
|
51 |
+
self.num_codebooks = num_codebooks
|
52 |
+
self.codebook_size = codebook_size
|
53 |
+
self.input_embeddings = torch.nn.ModuleList()
|
54 |
+
self.backtranslation_heads = torch.nn.ModuleList()
|
55 |
+
self.hierarchical_classifier = torch.nn.ModuleList()
|
56 |
+
self.padding_id = codebook_size + 5
|
57 |
+
for head in range(num_codebooks):
|
58 |
+
self.input_embeddings.append(torch.nn.Embedding(num_embeddings=self.padding_id + 1, embedding_dim=backtranslation_dim, padding_idx=self.padding_id))
|
59 |
+
self.backtranslation_heads.append(torch.nn.Embedding(num_embeddings=self.padding_id + 1, embedding_dim=backtranslation_dim, padding_idx=self.padding_id))
|
60 |
+
self.hierarchical_classifier.append(torch.nn.Linear(num_codebooks * backtranslation_dim + head * backtranslation_dim, codebook_size))
|
61 |
+
|
62 |
+
self.criterion = MaskedRefinementObjective()
|
63 |
+
for backtranslation_head in self.backtranslation_heads:
|
64 |
+
torch.nn.init.normal_(backtranslation_head.weight, mean=0, std=attention_dimension ** -0.5)
|
65 |
+
for input_embedding in self.input_embeddings:
|
66 |
+
torch.nn.init.normal_(input_embedding.weight, mean=0, std=attention_dimension ** -0.5)
|
67 |
+
|
68 |
+
def forward(self, index_sequence, is_inference, speaker_embedding, padding_mask=None, gold_index_sequence=None):
|
69 |
+
"""
|
70 |
+
index_sequence: [batch, codebook_index, time_steps] a sequence of indexes that come from an argmax of the previous prediction layer.
|
71 |
+
is_inference: boolean flag that indicates whether to return the masked language modelling loss or the refined sequence
|
72 |
+
speaker_embedding: [batch, speaker_embed_dim]
|
73 |
+
padding_mask: [batch, time_steps] a mask that is True for all time steps that are padding and should not be considered and False everywhere else.
|
74 |
+
|
75 |
+
return: loss if is_inference is false, otherwise [batch, codebook_index, time_steps] a sequence of indexes with the same shape and same interpretation, refined through iterative masked language modelling.
|
76 |
+
"""
|
77 |
+
|
78 |
+
if not is_inference:
|
79 |
+
index_sequence_padding_accounted = index_sequence.masked_fill(mask=padding_mask.unsqueeze(1), value=self.padding_id)
|
80 |
+
else:
|
81 |
+
index_sequence_padding_accounted = index_sequence # in the case of inference, there is no padding
|
82 |
+
|
83 |
+
sequence_of_continuous_tokens = self.indexes_per_codebook_to_stacked_embedding_vector(index_sequence_padding_accounted) # return [batch, time_steps, num_codebooks x backtranslation_dim]
|
84 |
+
contextualized_sequence = self.contextualize_sequence(sequence_of_continuous_tokens, speaker_embedding, non_padding_mask=~padding_mask if padding_mask is not None else None)
|
85 |
+
|
86 |
+
predicted_indexes_one_hot = list()
|
87 |
+
backtranslated_indexes = list()
|
88 |
+
for head_index, classifier_head in enumerate(self.hierarchical_classifier):
|
89 |
+
# each codebook considers all previous codebooks.
|
90 |
+
predicted_indexes_one_hot.append(classifier_head(torch.cat([contextualized_sequence] + backtranslated_indexes, dim=2)))
|
91 |
+
predicted_lookup_index = torch.argmax(predicted_indexes_one_hot[-1], dim=-1)
|
92 |
+
backtranslation = self.backtranslation_heads[head_index](predicted_lookup_index)
|
93 |
+
if len(backtranslation.size()) == 1:
|
94 |
+
backtranslation = backtranslation.unsqueeze(0)
|
95 |
+
backtranslated_indexes.append(backtranslation)
|
96 |
+
indexes = torch.cat(predicted_indexes_one_hot, dim=2)
|
97 |
+
# [Batch, Sequence, Hidden]
|
98 |
+
indexes = indexes.view(contextualized_sequence.size(0), contextualized_sequence.size(1), self.num_codebooks, self.codebook_size)
|
99 |
+
# [Batch, Sequence, Codebook, Classes]
|
100 |
+
indexes = indexes.transpose(1, 2)
|
101 |
+
# [Batch, Codebook, Sequence, Classes]
|
102 |
+
indexes = indexes.transpose(2, 3)
|
103 |
+
# [Batch, Codebook, Classes, Sequence]
|
104 |
+
indexes = indexes.transpose(0, 1)
|
105 |
+
# [Codebook, Batch, Classes, Sequence]
|
106 |
+
|
107 |
+
if is_inference:
|
108 |
+
return indexes
|
109 |
+
else:
|
110 |
+
return self.criterion(predicted_one_hot=indexes, gold_one_hot=gold_index_sequence, non_pad_mask=~padding_mask)
|
111 |
+
|
112 |
+
def contextualize_sequence(self, masked_sequence, utterance_embedding, non_padding_mask):
|
113 |
+
decoded_speech, _ = self.reconstruction_transformer(masked_sequence, non_padding_mask.unsqueeze(2) if non_padding_mask is not None else None, utterance_embedding=utterance_embedding)
|
114 |
+
return decoded_speech
|
115 |
+
|
116 |
+
def indexes_per_codebook_to_stacked_embedding_vector(self, index_sequence_per_codebook):
|
117 |
+
continuous_frame_sequences = list()
|
118 |
+
|
119 |
+
for codebook_id, backtranslation_head in enumerate(self.backtranslation_heads):
|
120 |
+
continuous_frame_sequences.append(backtranslation_head(index_sequence_per_codebook.transpose(0, 1)[codebook_id]))
|
121 |
+
stacked_embedding_vector = torch.cat(continuous_frame_sequences, dim=-1)
|
122 |
+
return stacked_embedding_vector
|
123 |
+
|
124 |
+
|
125 |
+
class MaskedRefinementObjective(torch.nn.Module):
|
126 |
+
|
127 |
+
def __init__(self):
|
128 |
+
super().__init__()
|
129 |
+
self.classification_loss = torch.nn.CrossEntropyLoss(reduction="none")
|
130 |
+
self.l1_loss = torch.nn.L1Loss(reduction="none")
|
131 |
+
|
132 |
+
def forward(self, predicted_one_hot, gold_one_hot, non_pad_mask):
|
133 |
+
ce = list()
|
134 |
+
for one_hot_pred, one_hot_target in zip(predicted_one_hot, gold_one_hot.transpose(0, 1).transpose(2, 3)):
|
135 |
+
# we iterate over codebooks
|
136 |
+
ce.append(self.classification_loss(one_hot_pred, one_hot_target))
|
137 |
+
classification_loss = torch.stack(ce).sum(0)
|
138 |
+
# make weighted mask and apply it
|
139 |
+
out_masks = non_pad_mask.unsqueeze(-1).to(gold_one_hot.device)
|
140 |
+
out_masks = torch.nn.functional.pad(out_masks.transpose(1, 2), [0, gold_one_hot.size(2) - out_masks.size(1), 0, 0, 0, 0], value=False).transpose(1, 2)
|
141 |
+
out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
|
142 |
+
out_weights /= gold_one_hot.size(0) * gold_one_hot.size(-1)
|
143 |
+
# apply weight
|
144 |
+
classification_loss = classification_loss.mul(out_weights.squeeze()).masked_select(out_masks.squeeze()).sum()
|
145 |
+
|
146 |
+
return classification_loss, classification_loss
|
147 |
+
|
148 |
+
|
149 |
+
def one_hot_sequence_to_token_sequence(batch_of_indexes_one_hot_per_codebook):
|
150 |
+
return torch.argmax(batch_of_indexes_one_hot_per_codebook, dim=-2).transpose(0, 1)
|
151 |
+
|
152 |
+
|
153 |
+
if __name__ == '__main__':
|
154 |
+
from Architectures.ToucanTTS.ToucanTTS import ToucanTTS
|
155 |
+
from Utility.utils import make_pad_mask
|
156 |
+
|
157 |
+
# prepare dummy inputs
|
158 |
+
num_codebooks = 4
|
159 |
+
dummy_text_batch = torch.randint(low=0, high=2, size=[3, 3, 62]).float() # [Batch, Sequence Length, Features per Phone]
|
160 |
+
dummy_text_lens = torch.LongTensor([2, 3, 3])
|
161 |
+
gold_speech_batch = torch.randn([3, num_codebooks, 30, 1024]) # [Batch, Sequence Length, Spectrogram Buckets]
|
162 |
+
gold_speech_lens = torch.LongTensor([10, 30, 20])
|
163 |
+
gold_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5], [5, 5, 10]])
|
164 |
+
gold_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]], [[1.1], [1.2], [0.8]]])
|
165 |
+
gold_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]], [[1.1], [1.2], [0.8]]])
|
166 |
+
dummy_utterance_embed = torch.randn([3, 512]) # [Batch, Dimensions of Speaker Embedding]
|
167 |
+
dummy_language_id = torch.LongTensor([5, 3, 2]).unsqueeze(1)
|
168 |
+
|
169 |
+
# run TTS on pseudo inputs
|
170 |
+
batch_of_indexes_one_hot_per_codebook, _, _, _, _, _ = ToucanTTS(num_codebooks=num_codebooks, use_language_model=False)._forward(dummy_text_batch,
|
171 |
+
dummy_text_lens,
|
172 |
+
gold_speech_batch,
|
173 |
+
gold_speech_lens,
|
174 |
+
gold_durations,
|
175 |
+
gold_pitch,
|
176 |
+
gold_energy,
|
177 |
+
utterance_embedding=dummy_utterance_embed,
|
178 |
+
lang_ids=dummy_language_id)
|
179 |
+
|
180 |
+
# reformat outputs to be a token sequence
|
181 |
+
batch_of_indexes = one_hot_sequence_to_token_sequence(batch_of_indexes_one_hot_per_codebook)
|
182 |
+
|
183 |
+
# refine the output of the TTS with the Language Model
|
184 |
+
refiner = CodecRefinementTransformer()
|
185 |
+
|
186 |
+
loss = refiner(index_sequence=one_hot_sequence_to_token_sequence(gold_speech_batch.transpose(3, 2)).transpose(0, 1), padding_mask=make_pad_mask(gold_speech_lens), is_inference=False, speaker_embedding=dummy_utterance_embed, gold_index_sequence=gold_speech_batch)
|
187 |
+
print(loss)
|
188 |
+
|
189 |
+
refined_indexes = refiner(index_sequence=batch_of_indexes[1].unsqueeze(0), is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
|
190 |
+
print(refined_indexes.shape)
|
191 |
+
refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
|
192 |
+
refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
|
193 |
+
print(refined_indexes.shape)
|
194 |
+
refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
|
195 |
+
refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
|
196 |
+
print(refined_indexes.shape)
|
197 |
+
refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
|
198 |
+
refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
|
199 |
+
print(refined_indexes.shape)
|
Architectures/ToucanTTS/DurationCalculator.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Nagoya University (Tomoki Hayashi)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
# Adapted by Florian Lux 2021
|
4 |
+
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
class DurationCalculator(torch.nn.Module):
|
11 |
+
|
12 |
+
def __init__(self, reduction_factor=1.0):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
@torch.no_grad()
|
16 |
+
def forward(self, att_ws, vis=None):
|
17 |
+
"""
|
18 |
+
Convert alignment matrix to durations.
|
19 |
+
"""
|
20 |
+
if vis is not None:
|
21 |
+
plt.figure(figsize=(8, 4))
|
22 |
+
plt.imshow(att_ws.cpu().numpy(), interpolation='nearest', aspect='auto', origin="lower")
|
23 |
+
plt.xlabel("Inputs")
|
24 |
+
plt.ylabel("Outputs")
|
25 |
+
plt.tight_layout()
|
26 |
+
plt.savefig(vis)
|
27 |
+
plt.close()
|
28 |
+
# calculate duration from 2d alignment matrix
|
29 |
+
durations = torch.stack([att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])])
|
30 |
+
return durations.view(-1)
|
Architectures/ToucanTTS/EnergyCalculator.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Nagoya University (Tomoki Hayashi)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
# Adapted by Florian Lux 2021
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from Architectures.GeneralLayers.STFT import STFT
|
9 |
+
from Utility.utils import pad_list
|
10 |
+
|
11 |
+
|
12 |
+
class EnergyCalculator(torch.nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, fs=16000, n_fft=1024, win_length=None, hop_length=256, window="hann", center=True,
|
15 |
+
normalized=False, onesided=True, use_token_averaged_energy=True, reduction_factor=1):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
self.fs = fs
|
19 |
+
self.n_fft = n_fft
|
20 |
+
self.hop_length = hop_length
|
21 |
+
self.win_length = win_length
|
22 |
+
self.window = window
|
23 |
+
self.use_token_averaged_energy = use_token_averaged_energy
|
24 |
+
if use_token_averaged_energy:
|
25 |
+
assert reduction_factor >= 1
|
26 |
+
self.reduction_factor = reduction_factor
|
27 |
+
|
28 |
+
self.stft = STFT(n_fft=n_fft, win_length=win_length, hop_length=hop_length, window=window, center=center, normalized=normalized, onesided=onesided)
|
29 |
+
|
30 |
+
def output_size(self):
|
31 |
+
return 1
|
32 |
+
|
33 |
+
def get_parameters(self):
|
34 |
+
return dict(fs=self.fs, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, win_length=self.win_length, center=self.stft.center,
|
35 |
+
normalized=self.stft.normalized, use_token_averaged_energy=self.use_token_averaged_energy, reduction_factor=self.reduction_factor)
|
36 |
+
|
37 |
+
def forward(self, input_waves, input_waves_lengths=None, feats_lengths=None, durations=None,
|
38 |
+
durations_lengths=None, norm_by_average=True, text=None):
|
39 |
+
# If not provided, we assume that the inputs have the same length
|
40 |
+
if input_waves_lengths is None:
|
41 |
+
input_waves_lengths = (input_waves.new_ones(input_waves.shape[0], dtype=torch.long) * input_waves.shape[1])
|
42 |
+
|
43 |
+
# Domain-conversion: e.g. Stft: time -> time-freq
|
44 |
+
input_stft, energy_lengths = self.stft(input_waves, input_waves_lengths)
|
45 |
+
|
46 |
+
assert input_stft.dim() >= 4, input_stft.shape
|
47 |
+
assert input_stft.shape[-1] == 2, input_stft.shape
|
48 |
+
|
49 |
+
# input_stft: (..., F, 2) -> (..., F)
|
50 |
+
input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
|
51 |
+
# sum over frequency (B, N, F) -> (B, N)
|
52 |
+
energy = torch.sqrt(torch.clamp(input_power.sum(dim=2), min=1.0e-10))
|
53 |
+
|
54 |
+
# (Optional): Adjust length to match with the features
|
55 |
+
if feats_lengths is not None:
|
56 |
+
energy = [self._adjust_num_frames(e[:el].view(-1), fl) for e, el, fl in zip(energy, energy_lengths, feats_lengths)]
|
57 |
+
energy_lengths = feats_lengths
|
58 |
+
|
59 |
+
# (Optional): Average by duration to calculate token-wise energy
|
60 |
+
if self.use_token_averaged_energy:
|
61 |
+
energy = [self._average_by_duration(e[:el].view(-1), d, text) for e, el, d in zip(energy, energy_lengths, durations)]
|
62 |
+
energy_lengths = durations_lengths
|
63 |
+
|
64 |
+
# Padding
|
65 |
+
if isinstance(energy, list):
|
66 |
+
energy = pad_list(energy, 0.0)
|
67 |
+
|
68 |
+
if norm_by_average:
|
69 |
+
average = energy[0][energy[0] != 0.0].mean()
|
70 |
+
energy = energy / average
|
71 |
+
|
72 |
+
# Return with the shape (B, T, 1)
|
73 |
+
return energy.unsqueeze(-1), energy_lengths
|
74 |
+
|
75 |
+
def _average_by_duration(self, x, d, text=None):
|
76 |
+
d_cumsum = F.pad(d.cumsum(dim=0), (1, 0))
|
77 |
+
x_avg = [x[start:end].mean() if len(x[start:end]) != 0 else x.new_tensor(0.0) for start, end in zip(d_cumsum[:-1], d_cumsum[1:])]
|
78 |
+
|
79 |
+
# find tokens that are not phoneme and set energy to 0
|
80 |
+
# while this makes sense, it make sit harder to model, so we leave this out
|
81 |
+
# if text is not None:
|
82 |
+
# for i, vector in enumerate(text):
|
83 |
+
# if vector[get_feature_to_index_lookup()["phoneme"]] == 0:
|
84 |
+
# x_avg[i] = torch.tensor(0.0, device=x.device)
|
85 |
+
|
86 |
+
return torch.stack(x_avg)
|
87 |
+
|
88 |
+
@staticmethod
|
89 |
+
def _adjust_num_frames(x, num_frames):
|
90 |
+
if num_frames > len(x):
|
91 |
+
x = F.pad(x, (0, num_frames - len(x)))
|
92 |
+
elif num_frames < len(x):
|
93 |
+
x = x[:num_frames]
|
94 |
+
return x
|
Architectures/ToucanTTS/Glow.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import scipy
|
3 |
+
import torch
|
4 |
+
import torch.distributions as dist
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from Architectures.ToucanTTS import glow_utils
|
9 |
+
from Architectures.ToucanTTS.wavenet import WN
|
10 |
+
|
11 |
+
|
12 |
+
class ActNorm(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, channels, ddi=False, **kwargs):
|
15 |
+
super().__init__()
|
16 |
+
self.channels = channels
|
17 |
+
self.initialized = not ddi
|
18 |
+
|
19 |
+
self.logs = nn.Parameter(torch.zeros(1, channels, 1))
|
20 |
+
self.bias = nn.Parameter(torch.zeros(1, channels, 1))
|
21 |
+
|
22 |
+
def forward(self, x, x_mask=None, reverse=False, **kwargs):
|
23 |
+
if x_mask is None:
|
24 |
+
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
|
25 |
+
x_len = torch.sum(x_mask, [1, 2])
|
26 |
+
if not self.initialized:
|
27 |
+
self.initialize(x, x_mask)
|
28 |
+
self.initialized = True
|
29 |
+
|
30 |
+
if reverse:
|
31 |
+
z = (x - self.bias) * torch.exp(-self.logs) * x_mask
|
32 |
+
logdet = torch.sum(-self.logs) * x_len
|
33 |
+
else:
|
34 |
+
z = (self.bias + torch.exp(self.logs) * x) * x_mask
|
35 |
+
logdet = torch.sum(self.logs) * x_len # [b]
|
36 |
+
return z, logdet
|
37 |
+
|
38 |
+
def store_inverse(self):
|
39 |
+
pass
|
40 |
+
|
41 |
+
def set_ddi(self, ddi):
|
42 |
+
self.initialized = not ddi
|
43 |
+
|
44 |
+
def initialize(self, x, x_mask):
|
45 |
+
with torch.no_grad():
|
46 |
+
denom = torch.sum(x_mask, [0, 2])
|
47 |
+
m = torch.sum(x * x_mask, [0, 2]) / denom
|
48 |
+
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
|
49 |
+
v = m_sq - (m ** 2)
|
50 |
+
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
|
51 |
+
|
52 |
+
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
|
53 |
+
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
|
54 |
+
|
55 |
+
self.bias.data.copy_(bias_init)
|
56 |
+
self.logs.data.copy_(logs_init)
|
57 |
+
|
58 |
+
|
59 |
+
class InvConvNear(nn.Module):
|
60 |
+
|
61 |
+
def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs):
|
62 |
+
super().__init__()
|
63 |
+
assert (n_split % 2 == 0)
|
64 |
+
self.channels = channels
|
65 |
+
self.n_split = n_split
|
66 |
+
self.n_sqz = n_sqz
|
67 |
+
self.no_jacobian = no_jacobian
|
68 |
+
|
69 |
+
w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_(), 'complete')[0]
|
70 |
+
if torch.det(w_init) < 0:
|
71 |
+
w_init[:, 0] = -1 * w_init[:, 0]
|
72 |
+
self.lu = lu
|
73 |
+
if lu:
|
74 |
+
# LU decomposition can slightly speed up the inverse
|
75 |
+
np_p, np_l, np_u = scipy.linalg.lu(w_init)
|
76 |
+
np_s = np.diag(np_u)
|
77 |
+
np_sign_s = np.sign(np_s)
|
78 |
+
np_log_s = np.log(np.abs(np_s))
|
79 |
+
np_u = np.triu(np_u, k=1)
|
80 |
+
l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1)
|
81 |
+
eye = np.eye(*w_init.shape, dtype=float)
|
82 |
+
|
83 |
+
self.register_buffer('p', torch.Tensor(np_p.astype(float)))
|
84 |
+
self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
|
85 |
+
self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True)
|
86 |
+
self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True)
|
87 |
+
self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True)
|
88 |
+
self.register_buffer('l_mask', torch.Tensor(l_mask))
|
89 |
+
self.register_buffer('eye', torch.Tensor(eye))
|
90 |
+
else:
|
91 |
+
self.weight = nn.Parameter(w_init)
|
92 |
+
|
93 |
+
def forward(self, x, x_mask=None, reverse=False, **kwargs):
|
94 |
+
b, c, t = x.size()
|
95 |
+
assert (c % self.n_split == 0)
|
96 |
+
if x_mask is None:
|
97 |
+
x_mask = 1
|
98 |
+
x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
|
99 |
+
else:
|
100 |
+
x_len = torch.sum(x_mask, [1, 2])
|
101 |
+
|
102 |
+
x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t)
|
103 |
+
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
|
104 |
+
|
105 |
+
if self.lu:
|
106 |
+
self.weight, log_s = self._get_weight()
|
107 |
+
logdet = log_s.sum()
|
108 |
+
logdet = logdet * (c / self.n_split) * x_len
|
109 |
+
else:
|
110 |
+
logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
|
111 |
+
|
112 |
+
if reverse:
|
113 |
+
if hasattr(self, "weight_inv"):
|
114 |
+
weight = self.weight_inv
|
115 |
+
else:
|
116 |
+
weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
|
117 |
+
logdet = -logdet
|
118 |
+
else:
|
119 |
+
weight = self.weight
|
120 |
+
if self.no_jacobian:
|
121 |
+
logdet = 0
|
122 |
+
|
123 |
+
weight = weight.view(self.n_split, self.n_split, 1, 1).to(x.device)
|
124 |
+
z = F.conv2d(x, weight)
|
125 |
+
|
126 |
+
z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
|
127 |
+
z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
|
128 |
+
return z, logdet
|
129 |
+
|
130 |
+
def _get_weight(self):
|
131 |
+
l, log_s, u = self.l, self.log_s, self.u
|
132 |
+
l = l * self.l_mask + self.eye
|
133 |
+
u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s))
|
134 |
+
weight = torch.matmul(self.p, torch.matmul(l, u))
|
135 |
+
return weight, log_s
|
136 |
+
|
137 |
+
def store_inverse(self):
|
138 |
+
weight, _ = self._get_weight()
|
139 |
+
self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device)
|
140 |
+
|
141 |
+
|
142 |
+
class InvConv(nn.Module):
|
143 |
+
|
144 |
+
def __init__(self, channels, no_jacobian=False, lu=True, **kwargs):
|
145 |
+
super().__init__()
|
146 |
+
w_shape = [channels, channels]
|
147 |
+
w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float)
|
148 |
+
LU_decomposed = lu
|
149 |
+
if not LU_decomposed:
|
150 |
+
# Sample a random orthogonal matrix:
|
151 |
+
self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
|
152 |
+
else:
|
153 |
+
np_p, np_l, np_u = scipy.linalg.lu(w_init)
|
154 |
+
np_s = np.diag(np_u)
|
155 |
+
np_sign_s = np.sign(np_s)
|
156 |
+
np_log_s = np.log(np.abs(np_s))
|
157 |
+
np_u = np.triu(np_u, k=1)
|
158 |
+
l_mask = np.tril(np.ones(w_shape, dtype=float), -1)
|
159 |
+
eye = np.eye(*w_shape, dtype=float)
|
160 |
+
|
161 |
+
self.register_buffer('p', torch.Tensor(np_p.astype(float)))
|
162 |
+
self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
|
163 |
+
self.l = nn.Parameter(torch.Tensor(np_l.astype(float)))
|
164 |
+
self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)))
|
165 |
+
self.u = nn.Parameter(torch.Tensor(np_u.astype(float)))
|
166 |
+
self.l_mask = torch.Tensor(l_mask)
|
167 |
+
self.eye = torch.Tensor(eye)
|
168 |
+
self.w_shape = w_shape
|
169 |
+
self.LU = LU_decomposed
|
170 |
+
self.weight = None
|
171 |
+
|
172 |
+
def get_weight(self, device, reverse):
|
173 |
+
w_shape = self.w_shape
|
174 |
+
self.p = self.p.to(device)
|
175 |
+
self.sign_s = self.sign_s.to(device)
|
176 |
+
self.l_mask = self.l_mask.to(device)
|
177 |
+
self.eye = self.eye.to(device)
|
178 |
+
l = self.l * self.l_mask + self.eye
|
179 |
+
u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
|
180 |
+
dlogdet = self.log_s.sum()
|
181 |
+
if not reverse:
|
182 |
+
w = torch.matmul(self.p, torch.matmul(l, u))
|
183 |
+
else:
|
184 |
+
l = torch.inverse(l.double()).float()
|
185 |
+
u = torch.inverse(u.double()).float()
|
186 |
+
w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
|
187 |
+
return w.view(w_shape[0], w_shape[1], 1), dlogdet
|
188 |
+
|
189 |
+
def forward(self, x, x_mask=None, reverse=False, **kwargs):
|
190 |
+
"""
|
191 |
+
log-det = log|abs(|W|)| * pixels
|
192 |
+
"""
|
193 |
+
b, c, t = x.size()
|
194 |
+
if x_mask is None:
|
195 |
+
x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
|
196 |
+
else:
|
197 |
+
x_len = torch.sum(x_mask, [1, 2])
|
198 |
+
logdet = 0
|
199 |
+
if not reverse:
|
200 |
+
weight, dlogdet = self.get_weight(x.device, reverse)
|
201 |
+
z = F.conv1d(x, weight)
|
202 |
+
if logdet is not None:
|
203 |
+
logdet = logdet + dlogdet * x_len
|
204 |
+
return z, logdet
|
205 |
+
else:
|
206 |
+
if self.weight is None:
|
207 |
+
weight, dlogdet = self.get_weight(x.device, reverse)
|
208 |
+
else:
|
209 |
+
weight, dlogdet = self.weight, self.dlogdet
|
210 |
+
z = F.conv1d(x, weight)
|
211 |
+
if logdet is not None:
|
212 |
+
logdet = logdet - dlogdet * x_len
|
213 |
+
return z, logdet
|
214 |
+
|
215 |
+
def store_inverse(self):
|
216 |
+
self.weight, self.dlogdet = self.get_weight('cuda', reverse=True)
|
217 |
+
|
218 |
+
|
219 |
+
class CouplingBlock(nn.Module):
|
220 |
+
|
221 |
+
def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers,
|
222 |
+
gin_channels=0, p_dropout=0., sigmoid_scale=False, wn=None, use_weightnorm=True):
|
223 |
+
super().__init__()
|
224 |
+
self.in_channels = in_channels
|
225 |
+
self.hidden_channels = hidden_channels
|
226 |
+
self.kernel_size = kernel_size
|
227 |
+
self.dilation_rate = dilation_rate
|
228 |
+
self.n_layers = n_layers
|
229 |
+
self.gin_channels = gin_channels
|
230 |
+
self.p_dropout = p_dropout
|
231 |
+
self.sigmoid_scale = sigmoid_scale
|
232 |
+
|
233 |
+
start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
|
234 |
+
if use_weightnorm:
|
235 |
+
start = torch.nn.utils.weight_norm(start)
|
236 |
+
self.start = start
|
237 |
+
# Initializing last layer to 0 makes the affine coupling layers
|
238 |
+
# do nothing at first. This helps with training stability
|
239 |
+
end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
|
240 |
+
end.weight.data.zero_()
|
241 |
+
end.bias.data.zero_()
|
242 |
+
self.end = end
|
243 |
+
self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels, p_dropout, use_weightnorm=use_weightnorm)
|
244 |
+
if wn is not None:
|
245 |
+
self.wn.in_layers = wn.in_layers
|
246 |
+
self.wn.res_skip_layers = wn.res_skip_layers
|
247 |
+
|
248 |
+
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
|
249 |
+
if x_mask is None:
|
250 |
+
x_mask = 1
|
251 |
+
x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
|
252 |
+
|
253 |
+
x = self.start(x_0) * x_mask
|
254 |
+
x = self.wn(x, x_mask, g)
|
255 |
+
out = self.end(x)
|
256 |
+
|
257 |
+
z_0 = x_0
|
258 |
+
m = out[:, :self.in_channels // 2, :]
|
259 |
+
logs = out[:, self.in_channels // 2:, :]
|
260 |
+
if self.sigmoid_scale:
|
261 |
+
logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
|
262 |
+
if reverse:
|
263 |
+
z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
|
264 |
+
logdet = torch.sum(-logs * x_mask, [1, 2])
|
265 |
+
else:
|
266 |
+
z_1 = (m + torch.exp(logs) * x_1) * x_mask
|
267 |
+
logdet = torch.sum(logs * x_mask, [1, 2])
|
268 |
+
z = torch.cat([z_0, z_1], 1)
|
269 |
+
return z, logdet
|
270 |
+
|
271 |
+
def store_inverse(self):
|
272 |
+
self.wn.remove_weight_norm()
|
273 |
+
|
274 |
+
|
275 |
+
class Glow(nn.Module):
|
276 |
+
|
277 |
+
def __init__(self,
|
278 |
+
in_channels,
|
279 |
+
hidden_channels,
|
280 |
+
kernel_size,
|
281 |
+
dilation_rate,
|
282 |
+
n_blocks,
|
283 |
+
n_layers,
|
284 |
+
condition_integration_projection,
|
285 |
+
p_dropout=0.,
|
286 |
+
n_split=4,
|
287 |
+
n_sqz=2,
|
288 |
+
sigmoid_scale=False,
|
289 |
+
text_condition_channels=0,
|
290 |
+
inv_conv_type='near',
|
291 |
+
share_cond_layers=False,
|
292 |
+
share_wn_layers=0,
|
293 |
+
use_weightnorm=True # If weightnorm is set to false, we can deepcopy the module, which we need to be able to do to perform SWA. Without weightnorm, the module will probably take a little longer to converge.
|
294 |
+
):
|
295 |
+
super().__init__()
|
296 |
+
|
297 |
+
self.in_channels = in_channels
|
298 |
+
self.hidden_channels = hidden_channels
|
299 |
+
self.kernel_size = kernel_size
|
300 |
+
self.dilation_rate = dilation_rate
|
301 |
+
self.n_blocks = n_blocks
|
302 |
+
self.n_layers = n_layers
|
303 |
+
self.p_dropout = p_dropout
|
304 |
+
self.n_split = n_split
|
305 |
+
self.n_sqz = n_sqz
|
306 |
+
self.sigmoid_scale = sigmoid_scale
|
307 |
+
self.text_condition_channels = text_condition_channels
|
308 |
+
self.share_cond_layers = share_cond_layers
|
309 |
+
self.prior_dist = dist.Normal(0, 1)
|
310 |
+
self.g_proj = condition_integration_projection
|
311 |
+
if text_condition_channels != 0 and share_cond_layers:
|
312 |
+
cond_layer = torch.nn.Conv1d(text_condition_channels * n_sqz, 2 * hidden_channels * n_layers, 1)
|
313 |
+
if use_weightnorm:
|
314 |
+
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
|
315 |
+
else:
|
316 |
+
self.cond_layer = cond_layer
|
317 |
+
wn = None
|
318 |
+
self.flows = nn.ModuleList()
|
319 |
+
for b in range(n_blocks):
|
320 |
+
self.flows.append(ActNorm(channels=in_channels * n_sqz))
|
321 |
+
if inv_conv_type == 'near':
|
322 |
+
self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz))
|
323 |
+
if inv_conv_type == 'invconv':
|
324 |
+
self.flows.append(InvConv(channels=in_channels * n_sqz))
|
325 |
+
if share_wn_layers > 0:
|
326 |
+
if b % share_wn_layers == 0:
|
327 |
+
wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, text_condition_channels * n_sqz, p_dropout, share_cond_layers, use_weightnorm=use_weightnorm)
|
328 |
+
self.flows.append(
|
329 |
+
CouplingBlock(
|
330 |
+
in_channels * n_sqz,
|
331 |
+
hidden_channels,
|
332 |
+
kernel_size=kernel_size,
|
333 |
+
dilation_rate=dilation_rate,
|
334 |
+
n_layers=n_layers,
|
335 |
+
gin_channels=text_condition_channels * n_sqz,
|
336 |
+
p_dropout=p_dropout,
|
337 |
+
sigmoid_scale=sigmoid_scale,
|
338 |
+
wn=wn,
|
339 |
+
use_weightnorm=use_weightnorm
|
340 |
+
))
|
341 |
+
|
342 |
+
def forward(self, tgt_mels, infer, mel_out, encoded_texts, tgt_nonpadding, glow_sampling_temperature=0.2):
|
343 |
+
x_recon = mel_out.transpose(1, 2)
|
344 |
+
g = x_recon
|
345 |
+
B, _, T = g.shape
|
346 |
+
if encoded_texts is not None and self.text_condition_channels != 0:
|
347 |
+
g = torch.cat([g, encoded_texts.transpose(1, 2)], 1)
|
348 |
+
g = self.g_proj(g)
|
349 |
+
prior_dist = self.prior_dist
|
350 |
+
if not infer:
|
351 |
+
y_lengths = tgt_nonpadding.sum(-1)
|
352 |
+
tgt_mels = tgt_mels.transpose(1, 2)
|
353 |
+
z_postflow, ldj = self._forward(tgt_mels, tgt_nonpadding, g=g)
|
354 |
+
ldj = ldj / y_lengths / 80
|
355 |
+
try:
|
356 |
+
postflow_loss = -prior_dist.log_prob(z_postflow).mean() - ldj.mean()
|
357 |
+
except ValueError:
|
358 |
+
print("log probability of postflow could not be calculated for this step")
|
359 |
+
postflow_loss = None
|
360 |
+
return postflow_loss
|
361 |
+
else:
|
362 |
+
nonpadding = torch.ones_like(x_recon[:, :1, :]) if tgt_nonpadding is None else tgt_nonpadding
|
363 |
+
z_post = torch.randn(x_recon.shape).to(g.device) * glow_sampling_temperature
|
364 |
+
x_recon, _ = self._forward(z_post, nonpadding, g, reverse=True)
|
365 |
+
return x_recon.transpose(1, 2)
|
366 |
+
|
367 |
+
def _forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
|
368 |
+
logdet_tot = 0
|
369 |
+
if not reverse:
|
370 |
+
flows = self.flows
|
371 |
+
else:
|
372 |
+
flows = reversed(self.flows)
|
373 |
+
if return_hiddens:
|
374 |
+
hs = []
|
375 |
+
if self.n_sqz > 1:
|
376 |
+
x, x_mask_ = glow_utils.squeeze(x, x_mask, self.n_sqz)
|
377 |
+
if g is not None:
|
378 |
+
g, _ = glow_utils.squeeze(g, x_mask, self.n_sqz)
|
379 |
+
x_mask = x_mask_
|
380 |
+
if self.share_cond_layers and g is not None:
|
381 |
+
g = self.cond_layer(g)
|
382 |
+
for f in flows:
|
383 |
+
x, logdet = f(x, x_mask, g=g, reverse=reverse)
|
384 |
+
if return_hiddens:
|
385 |
+
hs.append(x)
|
386 |
+
logdet_tot += logdet
|
387 |
+
if self.n_sqz > 1:
|
388 |
+
x, x_mask = glow_utils.unsqueeze(x, x_mask, self.n_sqz)
|
389 |
+
if return_hiddens:
|
390 |
+
return x, logdet_tot, hs
|
391 |
+
return x, logdet_tot
|
392 |
+
|
393 |
+
def store_inverse(self):
|
394 |
+
def remove_weight_norm(m):
|
395 |
+
try:
|
396 |
+
nn.utils.remove_weight_norm(m)
|
397 |
+
except ValueError: # this module didn't have weight norm
|
398 |
+
return
|
399 |
+
|
400 |
+
self.apply(remove_weight_norm)
|
401 |
+
for f in self.flows:
|
402 |
+
f.store_inverse()
|
Architectures/ToucanTTS/InferenceToucanTTS.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dotwiz
|
2 |
+
import torch
|
3 |
+
from torch.nn import Linear
|
4 |
+
from torch.nn import Sequential
|
5 |
+
from torch.nn import Tanh
|
6 |
+
|
7 |
+
from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
|
8 |
+
from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
|
9 |
+
from Architectures.GeneralLayers.Conformer import Conformer
|
10 |
+
from Architectures.GeneralLayers.DurationPredictor import DurationPredictor
|
11 |
+
from Architectures.GeneralLayers.LengthRegulator import LengthRegulator
|
12 |
+
from Architectures.GeneralLayers.VariancePredictor import VariancePredictor
|
13 |
+
from Architectures.ToucanTTS.Glow import Glow
|
14 |
+
from Preprocessing.articulatory_features import get_feature_to_index_lookup
|
15 |
+
from Utility.utils import integrate_with_utt_embed
|
16 |
+
from Utility.utils import make_non_pad_mask
|
17 |
+
|
18 |
+
|
19 |
+
class ToucanTTS(torch.nn.Module):
|
20 |
+
|
21 |
+
def __init__(self,
|
22 |
+
weights,
|
23 |
+
config):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
self.config = config
|
27 |
+
config = dotwiz.DotWiz(config)
|
28 |
+
|
29 |
+
input_feature_dimensions = config.input_feature_dimensions
|
30 |
+
attention_dimension = config.attention_dimension
|
31 |
+
attention_heads = config.attention_heads
|
32 |
+
positionwise_conv_kernel_size = config.positionwise_conv_kernel_size
|
33 |
+
use_scaled_positional_encoding = config.use_scaled_positional_encoding
|
34 |
+
use_macaron_style_in_conformer = config.use_macaron_style_in_conformer
|
35 |
+
use_cnn_in_conformer = config.use_cnn_in_conformer
|
36 |
+
encoder_layers = config.encoder_layers
|
37 |
+
encoder_units = config.encoder_units
|
38 |
+
encoder_normalize_before = config.encoder_normalize_before
|
39 |
+
encoder_concat_after = config.encoder_concat_after
|
40 |
+
conformer_encoder_kernel_size = config.conformer_encoder_kernel_size
|
41 |
+
transformer_enc_dropout_rate = config.transformer_enc_dropout_rate
|
42 |
+
transformer_enc_positional_dropout_rate = config.transformer_enc_positional_dropout_rate
|
43 |
+
transformer_enc_attn_dropout_rate = config.transformer_enc_attn_dropout_rate
|
44 |
+
decoder_layers = config.decoder_layers
|
45 |
+
decoder_units = config.decoder_units
|
46 |
+
decoder_concat_after = config.decoder_concat_after
|
47 |
+
conformer_decoder_kernel_size = config.conformer_decoder_kernel_size
|
48 |
+
decoder_normalize_before = config.decoder_normalize_before
|
49 |
+
transformer_dec_dropout_rate = config.transformer_dec_dropout_rate
|
50 |
+
transformer_dec_positional_dropout_rate = config.transformer_dec_positional_dropout_rate
|
51 |
+
transformer_dec_attn_dropout_rate = config.transformer_dec_attn_dropout_rate
|
52 |
+
duration_predictor_layers = config.duration_predictor_layers
|
53 |
+
duration_predictor_kernel_size = config.duration_predictor_kernel_size
|
54 |
+
duration_predictor_dropout_rate = config.duration_predictor_dropout_rate
|
55 |
+
pitch_predictor_layers = config.pitch_predictor_layers
|
56 |
+
pitch_predictor_kernel_size = config.pitch_predictor_kernel_size
|
57 |
+
pitch_predictor_dropout = config.pitch_predictor_dropout
|
58 |
+
pitch_embed_kernel_size = config.pitch_embed_kernel_size
|
59 |
+
pitch_embed_dropout = config.pitch_embed_dropout
|
60 |
+
energy_predictor_layers = config.energy_predictor_layers
|
61 |
+
energy_predictor_kernel_size = config.energy_predictor_kernel_size
|
62 |
+
energy_predictor_dropout = config.energy_predictor_dropout
|
63 |
+
energy_embed_kernel_size = config.energy_embed_kernel_size
|
64 |
+
energy_embed_dropout = config.energy_embed_dropout
|
65 |
+
utt_embed_dim = config.utt_embed_dim
|
66 |
+
lang_embs = config.lang_embs
|
67 |
+
embedding_integration = config.embedding_integration
|
68 |
+
glow_kernel_size = config.glow_kernel_size
|
69 |
+
glow_blocks = config.glow_blocks
|
70 |
+
glow_layers = config.glow_layers
|
71 |
+
lang_emb_size = config.lang_emb_size
|
72 |
+
integrate_language_embedding_into_encoder_out = config.integrate_language_embedding_into_encoder_out
|
73 |
+
|
74 |
+
self.input_feature_dimensions = input_feature_dimensions
|
75 |
+
self.attention_dimension = attention_dimension
|
76 |
+
self.use_scaled_pos_enc = use_scaled_positional_encoding
|
77 |
+
self.multilingual_model = lang_embs is not None
|
78 |
+
self.multispeaker_model = utt_embed_dim is not None
|
79 |
+
self.integrate_language_embedding_into_encoder_out = integrate_language_embedding_into_encoder_out
|
80 |
+
self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]
|
81 |
+
|
82 |
+
articulatory_feature_embedding = Sequential(Linear(input_feature_dimensions, 100), Tanh(), Linear(100, attention_dimension))
|
83 |
+
self.encoder = Conformer(conformer_type="encoder",
|
84 |
+
attention_dim=attention_dimension,
|
85 |
+
attention_heads=attention_heads,
|
86 |
+
linear_units=encoder_units,
|
87 |
+
num_blocks=encoder_layers,
|
88 |
+
input_layer=articulatory_feature_embedding,
|
89 |
+
dropout_rate=transformer_enc_dropout_rate,
|
90 |
+
positional_dropout_rate=transformer_enc_positional_dropout_rate,
|
91 |
+
attention_dropout_rate=transformer_enc_attn_dropout_rate,
|
92 |
+
normalize_before=encoder_normalize_before,
|
93 |
+
concat_after=encoder_concat_after,
|
94 |
+
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
|
95 |
+
macaron_style=use_macaron_style_in_conformer,
|
96 |
+
use_cnn_module=True,
|
97 |
+
cnn_module_kernel=conformer_encoder_kernel_size,
|
98 |
+
zero_triu=False,
|
99 |
+
utt_embed=utt_embed_dim,
|
100 |
+
lang_embs=lang_embs,
|
101 |
+
lang_emb_size=lang_emb_size,
|
102 |
+
use_output_norm=True,
|
103 |
+
embedding_integration=embedding_integration)
|
104 |
+
|
105 |
+
if self.integrate_language_embedding_into_encoder_out:
|
106 |
+
if embedding_integration == "AdaIN":
|
107 |
+
self.language_embedding_infusion = AdaIN1d(style_dim=lang_emb_size, num_features=attention_dimension)
|
108 |
+
elif embedding_integration == "ConditionalLayerNorm":
|
109 |
+
self.language_embedding_infusion = ConditionalLayerNorm(speaker_embedding_dim=lang_emb_size, hidden_dim=attention_dimension)
|
110 |
+
else:
|
111 |
+
self.language_embedding_infusion = torch.nn.Linear(attention_dimension + lang_emb_size, attention_dimension)
|
112 |
+
|
113 |
+
self.duration_predictor = DurationPredictor(idim=attention_dimension,
|
114 |
+
n_layers=duration_predictor_layers,
|
115 |
+
n_chans=attention_dimension,
|
116 |
+
kernel_size=duration_predictor_kernel_size,
|
117 |
+
dropout_rate=duration_predictor_dropout_rate,
|
118 |
+
utt_embed_dim=utt_embed_dim,
|
119 |
+
embedding_integration=embedding_integration)
|
120 |
+
|
121 |
+
self.pitch_predictor = VariancePredictor(idim=attention_dimension,
|
122 |
+
n_layers=pitch_predictor_layers,
|
123 |
+
n_chans=attention_dimension,
|
124 |
+
kernel_size=pitch_predictor_kernel_size,
|
125 |
+
dropout_rate=pitch_predictor_dropout,
|
126 |
+
utt_embed_dim=utt_embed_dim,
|
127 |
+
embedding_integration=embedding_integration)
|
128 |
+
|
129 |
+
self.energy_predictor = VariancePredictor(idim=attention_dimension,
|
130 |
+
n_layers=energy_predictor_layers,
|
131 |
+
n_chans=attention_dimension,
|
132 |
+
kernel_size=energy_predictor_kernel_size,
|
133 |
+
dropout_rate=energy_predictor_dropout,
|
134 |
+
utt_embed_dim=utt_embed_dim,
|
135 |
+
embedding_integration=embedding_integration)
|
136 |
+
|
137 |
+
self.pitch_embed = Sequential(torch.nn.Conv1d(in_channels=1,
|
138 |
+
out_channels=attention_dimension,
|
139 |
+
kernel_size=pitch_embed_kernel_size,
|
140 |
+
padding=(pitch_embed_kernel_size - 1) // 2),
|
141 |
+
torch.nn.Dropout(pitch_embed_dropout))
|
142 |
+
|
143 |
+
self.energy_embed = Sequential(torch.nn.Conv1d(in_channels=1,
|
144 |
+
out_channels=attention_dimension,
|
145 |
+
kernel_size=energy_embed_kernel_size,
|
146 |
+
padding=(energy_embed_kernel_size - 1) // 2),
|
147 |
+
torch.nn.Dropout(energy_embed_dropout))
|
148 |
+
|
149 |
+
self.length_regulator = LengthRegulator()
|
150 |
+
|
151 |
+
self.decoder = Conformer(conformer_type="decoder",
|
152 |
+
attention_dim=attention_dimension,
|
153 |
+
attention_heads=attention_heads,
|
154 |
+
linear_units=decoder_units,
|
155 |
+
num_blocks=decoder_layers,
|
156 |
+
input_layer=None,
|
157 |
+
dropout_rate=transformer_dec_dropout_rate,
|
158 |
+
positional_dropout_rate=transformer_dec_positional_dropout_rate,
|
159 |
+
attention_dropout_rate=transformer_dec_attn_dropout_rate,
|
160 |
+
normalize_before=decoder_normalize_before,
|
161 |
+
concat_after=decoder_concat_after,
|
162 |
+
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
|
163 |
+
macaron_style=use_macaron_style_in_conformer,
|
164 |
+
use_cnn_module=use_cnn_in_conformer,
|
165 |
+
cnn_module_kernel=conformer_decoder_kernel_size,
|
166 |
+
use_output_norm=not embedding_integration in ["AdaIN", "ConditionalLayerNorm"],
|
167 |
+
utt_embed=utt_embed_dim,
|
168 |
+
embedding_integration=embedding_integration)
|
169 |
+
|
170 |
+
self.output_projection = torch.nn.Linear(attention_dimension, 128)
|
171 |
+
|
172 |
+
self.post_flow = Glow(
|
173 |
+
in_channels=128,
|
174 |
+
hidden_channels=attention_dimension, # post_glow_hidden
|
175 |
+
kernel_size=glow_kernel_size, # post_glow_kernel_size
|
176 |
+
dilation_rate=1,
|
177 |
+
n_blocks=glow_blocks, # post_glow_n_blocks (original 12 in paper)
|
178 |
+
n_layers=glow_layers, # post_glow_n_block_layers (original 3 in paper)
|
179 |
+
n_split=4,
|
180 |
+
n_sqz=2,
|
181 |
+
text_condition_channels=attention_dimension,
|
182 |
+
share_cond_layers=False, # post_share_cond_layers
|
183 |
+
share_wn_layers=4,
|
184 |
+
sigmoid_scale=False,
|
185 |
+
condition_integration_projection=torch.nn.Conv1d(128 + attention_dimension, attention_dimension, 5, padding=2)
|
186 |
+
)
|
187 |
+
|
188 |
+
self.load_state_dict(weights)
|
189 |
+
self.eval()
|
190 |
+
|
191 |
+
def _forward(self,
|
192 |
+
text_tensors,
|
193 |
+
text_lengths,
|
194 |
+
gold_durations=None,
|
195 |
+
gold_pitch=None,
|
196 |
+
gold_energy=None,
|
197 |
+
duration_scaling_factor=1.0,
|
198 |
+
utterance_embedding=None,
|
199 |
+
lang_ids=None,
|
200 |
+
pitch_variance_scale=1.0,
|
201 |
+
energy_variance_scale=1.0,
|
202 |
+
pause_duration_scaling_factor=1.0,
|
203 |
+
glow_sampling_temperature=0.2):
|
204 |
+
|
205 |
+
if not self.multilingual_model:
|
206 |
+
lang_ids = None
|
207 |
+
|
208 |
+
if not self.multispeaker_model:
|
209 |
+
utterance_embedding = None
|
210 |
+
else:
|
211 |
+
utterance_embedding = torch.nn.functional.normalize(utterance_embedding)
|
212 |
+
|
213 |
+
# encoding the texts
|
214 |
+
text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2)
|
215 |
+
encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)
|
216 |
+
|
217 |
+
if self.integrate_language_embedding_into_encoder_out:
|
218 |
+
lang_embs = self.encoder.language_embedding(lang_ids).squeeze(-1).detach()
|
219 |
+
encoded_texts = integrate_with_utt_embed(hs=encoded_texts, utt_embeddings=lang_embs, projection=self.language_embedding_infusion, embedding_training=self.use_conditional_layernorm_embedding_integration)
|
220 |
+
|
221 |
+
# predicting pitch, energy and durations
|
222 |
+
pitch_predictions = self.pitch_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_pitch is None else gold_pitch
|
223 |
+
energy_predictions = self.energy_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_energy is None else gold_energy
|
224 |
+
predicted_durations = self.duration_predictor.inference(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_durations is None else gold_durations
|
225 |
+
|
226 |
+
# modifying the predictions with control parameters
|
227 |
+
for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
|
228 |
+
if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1:
|
229 |
+
predicted_durations[0][phoneme_index] = 0
|
230 |
+
if phoneme_vector[get_feature_to_index_lookup()["silence"]] == 1 and pause_duration_scaling_factor != 1.0:
|
231 |
+
predicted_durations[0][phoneme_index] = torch.round(predicted_durations[0][phoneme_index].float() * pause_duration_scaling_factor).long()
|
232 |
+
if duration_scaling_factor != 1.0:
|
233 |
+
assert duration_scaling_factor > 0
|
234 |
+
predicted_durations = torch.round(predicted_durations.float() * duration_scaling_factor).long()
|
235 |
+
pitch_predictions = make_near_zero_to_zero(pitch_predictions.squeeze(0)).unsqueeze(0)
|
236 |
+
energy_predictions = make_near_zero_to_zero(energy_predictions.squeeze(0)).unsqueeze(0)
|
237 |
+
pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale)
|
238 |
+
energy_predictions = _scale_variance(energy_predictions, energy_variance_scale)
|
239 |
+
|
240 |
+
# enriching the text with pitch and energy info
|
241 |
+
embedded_pitch_curve = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
|
242 |
+
embedded_energy_curve = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
|
243 |
+
enriched_encoded_texts = encoded_texts + embedded_pitch_curve + embedded_energy_curve
|
244 |
+
|
245 |
+
# predicting durations for text and upsampling accordingly
|
246 |
+
upsampled_enriched_encoded_texts = self.length_regulator(enriched_encoded_texts, predicted_durations)
|
247 |
+
|
248 |
+
# decoding spectrogram
|
249 |
+
decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, None, utterance_embedding=utterance_embedding)
|
250 |
+
|
251 |
+
frames = self.output_projection(decoded_speech)
|
252 |
+
|
253 |
+
refined_codec_frames = self.post_flow(tgt_mels=None, infer=True, mel_out=frames, encoded_texts=upsampled_enriched_encoded_texts, tgt_nonpadding=None, glow_sampling_temperature=glow_sampling_temperature)
|
254 |
+
|
255 |
+
return refined_codec_frames, predicted_durations.squeeze(), pitch_predictions.squeeze(), energy_predictions.squeeze()
|
256 |
+
|
257 |
+
@torch.inference_mode()
|
258 |
+
def forward(self,
|
259 |
+
text,
|
260 |
+
durations=None,
|
261 |
+
pitch=None,
|
262 |
+
energy=None,
|
263 |
+
utterance_embedding=None,
|
264 |
+
return_duration_pitch_energy=False,
|
265 |
+
lang_id=None,
|
266 |
+
duration_scaling_factor=1.0,
|
267 |
+
pitch_variance_scale=1.0,
|
268 |
+
energy_variance_scale=1.0,
|
269 |
+
pause_duration_scaling_factor=1.0,
|
270 |
+
glow_sampling_temperature=0.2):
|
271 |
+
"""
|
272 |
+
Generate the sequence of spectrogram frames given the sequence of vectorized phonemes.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
text: input sequence of vectorized phonemes
|
276 |
+
durations: durations to be used (optional, if not provided, they will be predicted)
|
277 |
+
pitch: token-averaged pitch curve to be used (optional, if not provided, it will be predicted)
|
278 |
+
energy: token-averaged energy curve to be used (optional, if not provided, it will be predicted)
|
279 |
+
return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting
|
280 |
+
utterance_embedding: embedding of speaker information
|
281 |
+
lang_id: id to be fed into the embedding layer that contains language information
|
282 |
+
duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
|
283 |
+
1.0 means no scaling happens, higher values increase durations for the whole
|
284 |
+
utterance, lower values decrease durations for the whole utterance.
|
285 |
+
pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
|
286 |
+
1.0 means no scaling happens, higher values increase variance of the pitch curve,
|
287 |
+
lower values decrease variance of the pitch curve.
|
288 |
+
energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
|
289 |
+
1.0 means no scaling happens, higher values increase variance of the energy curve,
|
290 |
+
lower values decrease variance of the energy curve.
|
291 |
+
pause_duration_scaling_factor: reasonable values are 0.6 < scale < 1.4.
|
292 |
+
scales the durations of pauses on top of the regular duration scaling
|
293 |
+
|
294 |
+
Returns:
|
295 |
+
features spectrogram
|
296 |
+
|
297 |
+
"""
|
298 |
+
# setup batch axis
|
299 |
+
text_length = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
|
300 |
+
if durations is not None:
|
301 |
+
durations = durations.unsqueeze(0).to(text.device)
|
302 |
+
if pitch is not None:
|
303 |
+
pitch = pitch.unsqueeze(0).to(text.device)
|
304 |
+
if energy is not None:
|
305 |
+
energy = energy.unsqueeze(0).to(text.device)
|
306 |
+
if lang_id is not None:
|
307 |
+
lang_id = lang_id.to(text.device)
|
308 |
+
|
309 |
+
outs, \
|
310 |
+
predicted_durations, \
|
311 |
+
pitch_predictions, \
|
312 |
+
energy_predictions = self._forward(text.unsqueeze(0),
|
313 |
+
text_length,
|
314 |
+
gold_durations=durations,
|
315 |
+
gold_pitch=pitch,
|
316 |
+
gold_energy=energy,
|
317 |
+
utterance_embedding=utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None, lang_ids=lang_id,
|
318 |
+
duration_scaling_factor=duration_scaling_factor,
|
319 |
+
pitch_variance_scale=pitch_variance_scale,
|
320 |
+
energy_variance_scale=energy_variance_scale,
|
321 |
+
pause_duration_scaling_factor=pause_duration_scaling_factor,
|
322 |
+
glow_sampling_temperature=glow_sampling_temperature)
|
323 |
+
|
324 |
+
if return_duration_pitch_energy:
|
325 |
+
return outs.squeeze().transpose(0, 1), predicted_durations, pitch_predictions, energy_predictions
|
326 |
+
return outs.squeeze().transpose(0, 1)
|
327 |
+
|
328 |
+
def store_inverse_all(self):
|
329 |
+
def remove_weight_norm(m):
|
330 |
+
try:
|
331 |
+
torch.nn.utils.remove_weight_norm(m)
|
332 |
+
except ValueError: # this module didn't have weight norm
|
333 |
+
return
|
334 |
+
self.post_flow.store_inverse()
|
335 |
+
self.apply(remove_weight_norm)
|
336 |
+
|
337 |
+
|
338 |
+
def _scale_variance(sequence, scale):
|
339 |
+
if scale == 1.0:
|
340 |
+
return sequence
|
341 |
+
average = sequence[0][sequence[0] != 0.0].mean()
|
342 |
+
sequence = sequence - average # center sequence around 0
|
343 |
+
sequence = sequence * scale # scale the variance
|
344 |
+
sequence = sequence + average # move center back to original with changed variance
|
345 |
+
for sequence_index in range(len(sequence[0])):
|
346 |
+
if sequence[0][sequence_index] < 0.0:
|
347 |
+
sequence[0][sequence_index] = 0.0
|
348 |
+
return sequence
|
349 |
+
|
350 |
+
|
351 |
+
def smooth_time_series(matrix, n_neighbors):
|
352 |
+
"""
|
353 |
+
Smooth a 2D matrix along the time axis using a moving average.
|
354 |
+
|
355 |
+
Parameters:
|
356 |
+
- matrix (torch.Tensor): Input matrix (2D tensor) representing the time series.
|
357 |
+
- n_neighbors (int): Number of neighboring rows to include in the moving average.
|
358 |
+
|
359 |
+
Returns:
|
360 |
+
- torch.Tensor: Smoothed matrix.
|
361 |
+
"""
|
362 |
+
smoothed_matrix = torch.zeros_like(matrix)
|
363 |
+
for i in range(matrix.size(0)):
|
364 |
+
lower = max(0, i - n_neighbors)
|
365 |
+
upper = min(matrix.size(0), i + n_neighbors + 1)
|
366 |
+
smoothed_matrix[i] = torch.mean(matrix[lower:upper], dim=0)
|
367 |
+
|
368 |
+
return smoothed_matrix
|
369 |
+
|
370 |
+
|
371 |
+
def make_near_zero_to_zero(sequence):
|
372 |
+
for index in range(len(sequence)):
|
373 |
+
if sequence[index] < 0.2:
|
374 |
+
sequence[index] = 0.0
|
375 |
+
return sequence
|
Architectures/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import pickle
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from Preprocessing.multilinguality.create_distance_lookups import CacheCreator
|
7 |
+
from Utility.utils import load_json_from_path
|
8 |
+
|
9 |
+
|
10 |
+
class LanguageEmbeddingSpaceStructureLoss(torch.nn.Module):
|
11 |
+
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
cc = CacheCreator(cache_root="Preprocessing/multilinguality")
|
15 |
+
if not os.path.exists('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json'):
|
16 |
+
cc.create_tree_cache(cache_root="Preprocessing/multilinguality")
|
17 |
+
if not os.path.exists('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json'):
|
18 |
+
cc.create_map_cache(cache_root="Preprocessing/multilinguality")
|
19 |
+
if not os.path.exists("Preprocessing/multilinguality/asp_dict.pkl"):
|
20 |
+
print("download asp file") # TODO downloader script with release
|
21 |
+
|
22 |
+
self.tree_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json')
|
23 |
+
self.map_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_map_dist.json')
|
24 |
+
with open("Preprocessing/multilinguality/asp_dict.pkl", 'rb') as dictfile:
|
25 |
+
self.asp_sim = pickle.load(dictfile)
|
26 |
+
self.lang_list = list(self.asp_sim.keys()) # list of all languages, to get lang_b's index
|
27 |
+
|
28 |
+
self.largest_value_map_dist = 0.0
|
29 |
+
for _, values in self.map_dist.items():
|
30 |
+
for _, value in values.items():
|
31 |
+
self.largest_value_map_dist = max(self.largest_value_map_dist, value)
|
32 |
+
|
33 |
+
self.iso_codes_to_ids = load_json_from_path("Preprocessing/multilinguality/iso_lookup.json")[-1]
|
34 |
+
self.ids_to_iso_codes = {v: k for k, v in self.iso_codes_to_ids.items()}
|
35 |
+
|
36 |
+
def forward(self, language_ids, language_embeddings):
|
37 |
+
"""
|
38 |
+
Args:
|
39 |
+
language_ids (Tensor): IDs of languages in the same order as the embeddings to calculate the distances according to the metrics.
|
40 |
+
language_embeddings (Tensor): Batch of language embeddings, of which the distances will be compared to the distances according to the metrics.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
Tensor: Language Embedding Structure Loss Value
|
44 |
+
"""
|
45 |
+
|
46 |
+
losses = list()
|
47 |
+
for language_id_1, language_embedding_1 in zip(language_ids, language_embeddings):
|
48 |
+
for language_id_2, language_embedding_2 in zip(language_ids, language_embeddings):
|
49 |
+
if language_id_1 != language_id_2:
|
50 |
+
embed_dist = torch.nn.functional.l1_loss(language_embedding_1, language_embedding_2)
|
51 |
+
lang_1 = self.ids_to_iso_codes[language_id_1]
|
52 |
+
lang_2 = self.ids_to_iso_codes[language_id_2]
|
53 |
+
|
54 |
+
# Value Range Normalized Tree Dist
|
55 |
+
try:
|
56 |
+
tree_dist = self.tree_dist[lang_1][lang_2]
|
57 |
+
except KeyError:
|
58 |
+
tree_dist = self.tree_dist[lang_2][lang_1]
|
59 |
+
|
60 |
+
# Value Range Normalized Map Dist
|
61 |
+
try:
|
62 |
+
map_dist = self.map_dist[lang_1][lang_2] / self.largest_value_map_dist
|
63 |
+
except KeyError:
|
64 |
+
map_dist = self.map_dist[lang_2][lang_1] / self.largest_value_map_dist
|
65 |
+
|
66 |
+
# Value Range Normalized ASP Dist
|
67 |
+
lang_2_idx = self.lang_list.index(lang_2)
|
68 |
+
asp_dist = 1.0 - self.asp_sim[lang_1][lang_2_idx] # it's a similarity measure that goes from 0 to 1, so we subtract it from 1 to turn it into a distance
|
69 |
+
|
70 |
+
# Average distance should be similar to embedding distance to bring some structure into the embedding-space
|
71 |
+
metric_distance = (torch.tensor(tree_dist) + torch.tensor(map_dist) + torch.tensor(asp_dist)) / 3
|
72 |
+
losses.append(torch.nn.functional.l1_loss(embed_dist, metric_distance))
|
73 |
+
|
74 |
+
return sum(losses) / len(losses)
|