maxmax20160403 commited on
Commit
07a0a5e
1 Parent(s): adb4281

Upload 16 files

Browse files
Files changed (16) hide show
  1. app.py +108 -0
  2. attentions.py +417 -0
  3. bert/ProsodyModel.py +75 -0
  4. bert/__init__.py +1 -0
  5. bert/config.json +19 -0
  6. bert/prosody_tool.py +426 -0
  7. bert/vocab.txt +0 -0
  8. commons.py +163 -0
  9. configs/bert_vits.json +50 -0
  10. models.py +533 -0
  11. modules.py +522 -0
  12. text/__init__.py +447 -0
  13. text/symbols.py +71 -0
  14. transforms.py +210 -0
  15. utils.py +319 -0
  16. vits_pinyin.py +88 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models import SynthesizerTrn
2
+ from vits_pinyin import VITS_PinYin
3
+ from text import cleaned_text_to_sequence
4
+ from text.symbols import symbols
5
+ import gradio as gr
6
+ import utils
7
+ import torch
8
+ import argparse
9
+ import os
10
+ import re
11
+ import logging
12
+
13
+ logging.getLogger('numba').setLevel(logging.WARNING)
14
+ limitation = os.getenv("SYSTEM") == "spaces"
15
+
16
+
17
+ def create_calback(net_g: SynthesizerTrn, tts_front: VITS_PinYin):
18
+ def tts_calback(text, dur_scale):
19
+ if limitation:
20
+ text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
21
+ max_len = 150
22
+ if text_len > max_len:
23
+ return "Error: Text is too long", None
24
+
25
+ phonemes, char_embeds = tts_front.chinese_to_phonemes(text)
26
+ input_ids = cleaned_text_to_sequence(phonemes)
27
+ with torch.no_grad():
28
+ x_tst = torch.LongTensor(input_ids).unsqueeze(0).to(device)
29
+ x_tst_lengths = torch.LongTensor([len(input_ids)]).to(device)
30
+ x_tst_prosody = torch.FloatTensor(
31
+ char_embeds).unsqueeze(0).to(device)
32
+ audio = net_g.infer(x_tst, x_tst_lengths, x_tst_prosody, noise_scale=0.5,
33
+ length_scale=dur_scale)[0][0, 0].data.cpu().float().numpy()
34
+ del x_tst, x_tst_lengths, x_tst_prosody
35
+ return "Success", (16000, audio)
36
+
37
+ return tts_calback
38
+
39
+
40
+ example = [['天空呈现的透心的蓝,像极了当年。总在这样的时候,透过窗棂,心,在天空里无尽的游弋!柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。', 1],
41
+ ['风的影子翻阅过淡蓝色的信笺,柔和的文字浅浅地漫过我安静的眸,一如几朵悠闲的云儿,忽而氤氲成汽,忽而修饰成花,铅华洗尽后的透彻和靓丽,爽爽朗朗,轻轻盈盈', 1],
42
+ ['时光仿佛有穿越到了从前,在你诗情画意的眼波中,在你舒适浪漫的暇思里,我如风中的思绪徜徉广阔天际,仿佛一片沾染了快乐的羽毛,在云环影绕颤动里浸润着风的呼吸,风的诗韵,那清新的耳语,那婉约的甜蜜,那恬淡的温馨,将一腔情澜染得愈发的缠绵。', 1],]
43
+
44
+
45
+ if __name__ == "__main__":
46
+ parser = argparse.ArgumentParser()
47
+ parser.add_argument("--share", action="store_true",
48
+ default=False, help="share gradio app")
49
+ args = parser.parse_args()
50
+
51
+ device = torch.device("cpu")
52
+
53
+ # pinyin
54
+ tts_front = VITS_PinYin("./bert", device)
55
+
56
+ # config
57
+ hps = utils.get_hparams_from_file("./configs/bert_vits.json")
58
+
59
+ # model
60
+ net_g = SynthesizerTrn(
61
+ len(symbols),
62
+ hps.data.filter_length // 2 + 1,
63
+ hps.train.segment_size // hps.data.hop_length,
64
+ **hps.model)
65
+
66
+ model_path = "vits_bert_model.pth"
67
+ utils.load_model(model_path, net_g)
68
+ net_g.eval()
69
+ net_g.to(device)
70
+
71
+ tts_calback = create_calback(net_g, tts_front)
72
+
73
+ app = gr.Blocks()
74
+ with app:
75
+ gr.Markdown("# Best TTS based on BERT and VITS with some Natural Speech Features Of Microsoft\n\n"
76
+ "code : github.com/PlayVoice/vits_chinese\n\n"
77
+ "1, Hidden prosody embedding from BERT,get natural pauses in grammar\n\n"
78
+ "2, Infer loss from NaturalSpeech,get less sound error\n\n"
79
+ "3, Framework of VITS,get high audio quality\n\n"
80
+ "<video id='video' controls='' preload='yes'>\n\n"
81
+ "<source id='mp4' src='https://user-images.githubusercontent.com/16432329/220678182-4775dec8-9229-4578-870f-2eebc3a5d660.mp4' type='video/mp4'>\n\n"
82
+ "</videos>\n\n"
83
+ )
84
+
85
+ with gr.Tabs():
86
+ with gr.TabItem("TTS"):
87
+ with gr.Row():
88
+ with gr.Column():
89
+ textbox = gr.TextArea(label="Text",
90
+ placeholder="Type your sentence here (Maximum 150 words)",
91
+ value="中文语音合成", elem_id=f"tts-input")
92
+ duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1,
93
+ label='速度 Speed')
94
+ with gr.Column():
95
+ text_output = gr.Textbox(label="Message")
96
+ audio_output = gr.Audio(
97
+ label="Output Audio", elem_id="tts-audio")
98
+ btn = gr.Button("Generate!")
99
+ btn.click(tts_calback,
100
+ inputs=[textbox, duration_slider],
101
+ outputs=[text_output, audio_output])
102
+ gr.Examples(
103
+ examples=example,
104
+ inputs=[textbox, duration_slider],
105
+ outputs=[text_output, audio_output],
106
+ fn=tts_calback
107
+ )
108
+ app.queue(concurrency_count=3).launch(show_api=False, share=args.share)
attentions.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ import commons
9
+ import modules
10
+ from modules import LayerNorm
11
+
12
+
13
+ class Encoder(nn.Module):
14
+ def __init__(
15
+ self,
16
+ hidden_channels,
17
+ filter_channels,
18
+ n_heads,
19
+ n_layers,
20
+ kernel_size=1,
21
+ p_dropout=0.0,
22
+ window_size=4,
23
+ **kwargs
24
+ ):
25
+ super().__init__()
26
+ self.hidden_channels = hidden_channels
27
+ self.filter_channels = filter_channels
28
+ self.n_heads = n_heads
29
+ self.n_layers = n_layers
30
+ self.kernel_size = kernel_size
31
+ self.p_dropout = p_dropout
32
+ self.window_size = window_size
33
+
34
+ self.drop = nn.Dropout(p_dropout)
35
+ self.attn_layers = nn.ModuleList()
36
+ self.norm_layers_1 = nn.ModuleList()
37
+ self.ffn_layers = nn.ModuleList()
38
+ self.norm_layers_2 = nn.ModuleList()
39
+ for i in range(self.n_layers):
40
+ self.attn_layers.append(
41
+ MultiHeadAttention(
42
+ hidden_channels,
43
+ hidden_channels,
44
+ n_heads,
45
+ p_dropout=p_dropout,
46
+ window_size=window_size,
47
+ )
48
+ )
49
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
50
+ self.ffn_layers.append(
51
+ FFN(
52
+ hidden_channels,
53
+ hidden_channels,
54
+ filter_channels,
55
+ kernel_size,
56
+ p_dropout=p_dropout,
57
+ )
58
+ )
59
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
60
+
61
+ def forward(self, x, x_mask):
62
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
63
+ x = x * x_mask
64
+ for i in range(self.n_layers):
65
+ y = self.attn_layers[i](x, x, attn_mask)
66
+ y = self.drop(y)
67
+ x = self.norm_layers_1[i](x + y)
68
+
69
+ y = self.ffn_layers[i](x, x_mask)
70
+ y = self.drop(y)
71
+ x = self.norm_layers_2[i](x + y)
72
+ x = x * x_mask
73
+ return x
74
+
75
+
76
+ class Decoder(nn.Module):
77
+ def __init__(
78
+ self,
79
+ hidden_channels,
80
+ filter_channels,
81
+ n_heads,
82
+ n_layers,
83
+ kernel_size=1,
84
+ p_dropout=0.0,
85
+ proximal_bias=False,
86
+ proximal_init=True,
87
+ **kwargs
88
+ ):
89
+ super().__init__()
90
+ self.hidden_channels = hidden_channels
91
+ self.filter_channels = filter_channels
92
+ self.n_heads = n_heads
93
+ self.n_layers = n_layers
94
+ self.kernel_size = kernel_size
95
+ self.p_dropout = p_dropout
96
+ self.proximal_bias = proximal_bias
97
+ self.proximal_init = proximal_init
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.self_attn_layers = nn.ModuleList()
101
+ self.norm_layers_0 = nn.ModuleList()
102
+ self.encdec_attn_layers = nn.ModuleList()
103
+ self.norm_layers_1 = nn.ModuleList()
104
+ self.ffn_layers = nn.ModuleList()
105
+ self.norm_layers_2 = nn.ModuleList()
106
+ for i in range(self.n_layers):
107
+ self.self_attn_layers.append(
108
+ MultiHeadAttention(
109
+ hidden_channels,
110
+ hidden_channels,
111
+ n_heads,
112
+ p_dropout=p_dropout,
113
+ proximal_bias=proximal_bias,
114
+ proximal_init=proximal_init,
115
+ )
116
+ )
117
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
118
+ self.encdec_attn_layers.append(
119
+ MultiHeadAttention(
120
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
121
+ )
122
+ )
123
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
124
+ self.ffn_layers.append(
125
+ FFN(
126
+ hidden_channels,
127
+ hidden_channels,
128
+ filter_channels,
129
+ kernel_size,
130
+ p_dropout=p_dropout,
131
+ causal=True,
132
+ )
133
+ )
134
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
135
+
136
+ def forward(self, x, x_mask, h, h_mask):
137
+ """
138
+ x: decoder input
139
+ h: encoder output
140
+ """
141
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
142
+ device=x.device, dtype=x.dtype
143
+ )
144
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
145
+ x = x * x_mask
146
+ for i in range(self.n_layers):
147
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
148
+ y = self.drop(y)
149
+ x = self.norm_layers_0[i](x + y)
150
+
151
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
152
+ y = self.drop(y)
153
+ x = self.norm_layers_1[i](x + y)
154
+
155
+ y = self.ffn_layers[i](x, x_mask)
156
+ y = self.drop(y)
157
+ x = self.norm_layers_2[i](x + y)
158
+ x = x * x_mask
159
+ return x
160
+
161
+
162
+ class MultiHeadAttention(nn.Module):
163
+ def __init__(
164
+ self,
165
+ channels,
166
+ out_channels,
167
+ n_heads,
168
+ p_dropout=0.0,
169
+ window_size=None,
170
+ heads_share=True,
171
+ block_length=None,
172
+ proximal_bias=False,
173
+ proximal_init=False,
174
+ ):
175
+ super().__init__()
176
+ assert channels % n_heads == 0
177
+
178
+ self.channels = channels
179
+ self.out_channels = out_channels
180
+ self.n_heads = n_heads
181
+ self.p_dropout = p_dropout
182
+ self.window_size = window_size
183
+ self.heads_share = heads_share
184
+ self.block_length = block_length
185
+ self.proximal_bias = proximal_bias
186
+ self.proximal_init = proximal_init
187
+ self.attn = None
188
+
189
+ self.k_channels = channels // n_heads
190
+ self.conv_q = nn.Conv1d(channels, channels, 1)
191
+ self.conv_k = nn.Conv1d(channels, channels, 1)
192
+ self.conv_v = nn.Conv1d(channels, channels, 1)
193
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
194
+ self.drop = nn.Dropout(p_dropout)
195
+
196
+ if window_size is not None:
197
+ n_heads_rel = 1 if heads_share else n_heads
198
+ rel_stddev = self.k_channels**-0.5
199
+ self.emb_rel_k = nn.Parameter(
200
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
201
+ * rel_stddev
202
+ )
203
+ self.emb_rel_v = nn.Parameter(
204
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
205
+ * rel_stddev
206
+ )
207
+
208
+ nn.init.xavier_uniform_(self.conv_q.weight)
209
+ nn.init.xavier_uniform_(self.conv_k.weight)
210
+ nn.init.xavier_uniform_(self.conv_v.weight)
211
+ if proximal_init:
212
+ with torch.no_grad():
213
+ self.conv_k.weight.copy_(self.conv_q.weight)
214
+ self.conv_k.bias.copy_(self.conv_q.bias)
215
+
216
+ def forward(self, x, c, attn_mask=None):
217
+ q = self.conv_q(x)
218
+ k = self.conv_k(c)
219
+ v = self.conv_v(c)
220
+
221
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
222
+
223
+ x = self.conv_o(x)
224
+ return x
225
+
226
+ def attention(self, query, key, value, mask=None):
227
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
228
+ b, d, t_s, t_t = (*key.size(), query.size(2))
229
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
230
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
232
+
233
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
234
+ if self.window_size is not None:
235
+ assert (
236
+ t_s == t_t
237
+ ), "Relative attention is only available for self-attention."
238
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
239
+ rel_logits = self._matmul_with_relative_keys(
240
+ query / math.sqrt(self.k_channels), key_relative_embeddings
241
+ )
242
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
243
+ scores = scores + scores_local
244
+ if self.proximal_bias:
245
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
246
+ scores = scores + self._attention_bias_proximal(t_s).to(
247
+ device=scores.device, dtype=scores.dtype
248
+ )
249
+ if mask is not None:
250
+ scores = scores.masked_fill(mask == 0, -1e4)
251
+ if self.block_length is not None:
252
+ assert (
253
+ t_s == t_t
254
+ ), "Local attention is only available for self-attention."
255
+ block_mask = (
256
+ torch.ones_like(scores)
257
+ .triu(-self.block_length)
258
+ .tril(self.block_length)
259
+ )
260
+ scores = scores.masked_fill(block_mask == 0, -1e4)
261
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
262
+ p_attn = self.drop(p_attn)
263
+ output = torch.matmul(p_attn, value)
264
+ if self.window_size is not None:
265
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
266
+ value_relative_embeddings = self._get_relative_embeddings(
267
+ self.emb_rel_v, t_s
268
+ )
269
+ output = output + self._matmul_with_relative_values(
270
+ relative_weights, value_relative_embeddings
271
+ )
272
+ output = (
273
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
274
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
275
+ return output, p_attn
276
+
277
+ def _matmul_with_relative_values(self, x, y):
278
+ """
279
+ x: [b, h, l, m]
280
+ y: [h or 1, m, d]
281
+ ret: [b, h, l, d]
282
+ """
283
+ ret = torch.matmul(x, y.unsqueeze(0))
284
+ return ret
285
+
286
+ def _matmul_with_relative_keys(self, x, y):
287
+ """
288
+ x: [b, h, l, d]
289
+ y: [h or 1, m, d]
290
+ ret: [b, h, l, m]
291
+ """
292
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
293
+ return ret
294
+
295
+ def _get_relative_embeddings(self, relative_embeddings, length):
296
+ max_relative_position = 2 * self.window_size + 1
297
+ # Pad first before slice to avoid using cond ops.
298
+ pad_length = max(length - (self.window_size + 1), 0)
299
+ slice_start_position = max((self.window_size + 1) - length, 0)
300
+ slice_end_position = slice_start_position + 2 * length - 1
301
+ if pad_length > 0:
302
+ padded_relative_embeddings = F.pad(
303
+ relative_embeddings,
304
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
305
+ )
306
+ else:
307
+ padded_relative_embeddings = relative_embeddings
308
+ used_relative_embeddings = padded_relative_embeddings[
309
+ :, slice_start_position:slice_end_position
310
+ ]
311
+ return used_relative_embeddings
312
+
313
+ def _relative_position_to_absolute_position(self, x):
314
+ """
315
+ x: [b, h, l, 2*l-1]
316
+ ret: [b, h, l, l]
317
+ """
318
+ batch, heads, length, _ = x.size()
319
+ # Concat columns of pad to shift from relative to absolute indexing.
320
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
321
+
322
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
323
+ x_flat = x.view([batch, heads, length * 2 * length])
324
+ x_flat = F.pad(
325
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
326
+ )
327
+
328
+ # Reshape and slice out the padded elements.
329
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
330
+ :, :, :length, length - 1 :
331
+ ]
332
+ return x_final
333
+
334
+ def _absolute_position_to_relative_position(self, x):
335
+ """
336
+ x: [b, h, l, l]
337
+ ret: [b, h, l, 2*l-1]
338
+ """
339
+ batch, heads, length, _ = x.size()
340
+ # padd along column
341
+ x = F.pad(
342
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
343
+ )
344
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
345
+ # add 0's in the beginning that will skew the elements after reshape
346
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
347
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
348
+ return x_final
349
+
350
+ def _attention_bias_proximal(self, length):
351
+ """Bias for self-attention to encourage attention to close positions.
352
+ Args:
353
+ length: an integer scalar.
354
+ Returns:
355
+ a Tensor with shape [1, 1, length, length]
356
+ """
357
+ r = torch.arange(length, dtype=torch.float32)
358
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
359
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
+
361
+
362
+ class FFN(nn.Module):
363
+ def __init__(
364
+ self,
365
+ in_channels,
366
+ out_channels,
367
+ filter_channels,
368
+ kernel_size,
369
+ p_dropout=0.0,
370
+ activation=None,
371
+ causal=False,
372
+ ):
373
+ super().__init__()
374
+ self.in_channels = in_channels
375
+ self.out_channels = out_channels
376
+ self.filter_channels = filter_channels
377
+ self.kernel_size = kernel_size
378
+ self.p_dropout = p_dropout
379
+ self.activation = activation
380
+ self.causal = causal
381
+
382
+ if causal:
383
+ self.padding = self._causal_padding
384
+ else:
385
+ self.padding = self._same_padding
386
+
387
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
388
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
389
+ self.drop = nn.Dropout(p_dropout)
390
+
391
+ def forward(self, x, x_mask):
392
+ x = self.conv_1(self.padding(x * x_mask))
393
+ if self.activation == "gelu":
394
+ x = x * torch.sigmoid(1.702 * x)
395
+ else:
396
+ x = torch.relu(x)
397
+ x = self.drop(x)
398
+ x = self.conv_2(self.padding(x * x_mask))
399
+ return x * x_mask
400
+
401
+ def _causal_padding(self, x):
402
+ if self.kernel_size == 1:
403
+ return x
404
+ pad_l = self.kernel_size - 1
405
+ pad_r = 0
406
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
407
+ x = F.pad(x, commons.convert_pad_shape(padding))
408
+ return x
409
+
410
+ def _same_padding(self, x):
411
+ if self.kernel_size == 1:
412
+ return x
413
+ pad_l = (self.kernel_size - 1) // 2
414
+ pad_r = self.kernel_size // 2
415
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
416
+ x = F.pad(x, commons.convert_pad_shape(padding))
417
+ return x
bert/ProsodyModel.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from transformers import BertModel, BertConfig, BertTokenizer
7
+
8
+
9
+ class CharEmbedding(nn.Module):
10
+ def __init__(self, model_dir):
11
+ super().__init__()
12
+ self.tokenizer = BertTokenizer.from_pretrained(model_dir)
13
+ self.bert_config = BertConfig.from_pretrained(model_dir)
14
+ self.hidden_size = self.bert_config.hidden_size
15
+ self.bert = BertModel(self.bert_config)
16
+ self.proj = nn.Linear(self.hidden_size, 256)
17
+ self.linear = nn.Linear(256, 3)
18
+
19
+ def text2Token(self, text):
20
+ token = self.tokenizer.tokenize(text)
21
+ txtid = self.tokenizer.convert_tokens_to_ids(token)
22
+ return txtid
23
+
24
+ def forward(self, inputs_ids, inputs_masks, tokens_type_ids):
25
+ out_seq = self.bert(input_ids=inputs_ids,
26
+ attention_mask=inputs_masks,
27
+ token_type_ids=tokens_type_ids)[0]
28
+ out_seq = self.proj(out_seq)
29
+ return out_seq
30
+
31
+
32
+ class TTSProsody(object):
33
+ def __init__(self, path, device):
34
+ self.device = device
35
+ self.char_model = CharEmbedding(path)
36
+ self.char_model.load_state_dict(
37
+ torch.load(
38
+ os.path.join(path, 'prosody_model.pt'),
39
+ map_location="cpu"
40
+ ),
41
+ strict=False
42
+ )
43
+ self.char_model.eval()
44
+ self.char_model.to(self.device)
45
+
46
+ def get_char_embeds(self, text):
47
+ input_ids = self.char_model.text2Token(text)
48
+ input_masks = [1] * len(input_ids)
49
+ type_ids = [0] * len(input_ids)
50
+ input_ids = torch.LongTensor([input_ids]).to(self.device)
51
+ input_masks = torch.LongTensor([input_masks]).to(self.device)
52
+ type_ids = torch.LongTensor([type_ids]).to(self.device)
53
+
54
+ with torch.no_grad():
55
+ char_embeds = self.char_model(
56
+ input_ids, input_masks, type_ids).squeeze(0).cpu()
57
+ return char_embeds
58
+
59
+ def expand_for_phone(self, char_embeds, length): # length of phones for char
60
+ assert char_embeds.size(0) == len(length)
61
+ expand_vecs = list()
62
+ for vec, leng in zip(char_embeds, length):
63
+ vec = vec.expand(leng, -1)
64
+ expand_vecs.append(vec)
65
+ expand_embeds = torch.cat(expand_vecs, 0)
66
+ assert expand_embeds.size(0) == sum(length)
67
+ return expand_embeds.numpy()
68
+
69
+
70
+ if __name__ == "__main__":
71
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
+ prosody = TTSProsody('./bert/', device)
73
+ while True:
74
+ text = input("请输入文本:")
75
+ prosody.get_char_embeds(text)
bert/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .ProsodyModel import TTSProsody
bert/config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_probs_dropout_prob": 0.1,
3
+ "directionality": "bidi",
4
+ "hidden_act": "gelu",
5
+ "hidden_dropout_prob": 0.1,
6
+ "hidden_size": 768,
7
+ "initializer_range": 0.02,
8
+ "intermediate_size": 3072,
9
+ "max_position_embeddings": 512,
10
+ "num_attention_heads": 12,
11
+ "num_hidden_layers": 12,
12
+ "pooler_fc_size": 768,
13
+ "pooler_num_attention_heads": 12,
14
+ "pooler_num_fc_layers": 3,
15
+ "pooler_size_per_head": 128,
16
+ "pooler_type": "first_token_transform",
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 21128
19
+ }
bert/prosody_tool.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def is_chinese(uchar):
2
+ if uchar >= u'\u4e00' and uchar <= u'\u9fa5':
3
+ return True
4
+ else:
5
+ return False
6
+
7
+
8
+ pinyin_dict = {
9
+ "a": ("^", "a"),
10
+ "ai": ("^", "ai"),
11
+ "an": ("^", "an"),
12
+ "ang": ("^", "ang"),
13
+ "ao": ("^", "ao"),
14
+ "ba": ("b", "a"),
15
+ "bai": ("b", "ai"),
16
+ "ban": ("b", "an"),
17
+ "bang": ("b", "ang"),
18
+ "bao": ("b", "ao"),
19
+ "be": ("b", "e"),
20
+ "bei": ("b", "ei"),
21
+ "ben": ("b", "en"),
22
+ "beng": ("b", "eng"),
23
+ "bi": ("b", "i"),
24
+ "bian": ("b", "ian"),
25
+ "biao": ("b", "iao"),
26
+ "bie": ("b", "ie"),
27
+ "bin": ("b", "in"),
28
+ "bing": ("b", "ing"),
29
+ "bo": ("b", "o"),
30
+ "bu": ("b", "u"),
31
+ "ca": ("c", "a"),
32
+ "cai": ("c", "ai"),
33
+ "can": ("c", "an"),
34
+ "cang": ("c", "ang"),
35
+ "cao": ("c", "ao"),
36
+ "ce": ("c", "e"),
37
+ "cen": ("c", "en"),
38
+ "ceng": ("c", "eng"),
39
+ "cha": ("ch", "a"),
40
+ "chai": ("ch", "ai"),
41
+ "chan": ("ch", "an"),
42
+ "chang": ("ch", "ang"),
43
+ "chao": ("ch", "ao"),
44
+ "che": ("ch", "e"),
45
+ "chen": ("ch", "en"),
46
+ "cheng": ("ch", "eng"),
47
+ "chi": ("ch", "iii"),
48
+ "chong": ("ch", "ong"),
49
+ "chou": ("ch", "ou"),
50
+ "chu": ("ch", "u"),
51
+ "chua": ("ch", "ua"),
52
+ "chuai": ("ch", "uai"),
53
+ "chuan": ("ch", "uan"),
54
+ "chuang": ("ch", "uang"),
55
+ "chui": ("ch", "uei"),
56
+ "chun": ("ch", "uen"),
57
+ "chuo": ("ch", "uo"),
58
+ "ci": ("c", "ii"),
59
+ "cong": ("c", "ong"),
60
+ "cou": ("c", "ou"),
61
+ "cu": ("c", "u"),
62
+ "cuan": ("c", "uan"),
63
+ "cui": ("c", "uei"),
64
+ "cun": ("c", "uen"),
65
+ "cuo": ("c", "uo"),
66
+ "da": ("d", "a"),
67
+ "dai": ("d", "ai"),
68
+ "dan": ("d", "an"),
69
+ "dang": ("d", "ang"),
70
+ "dao": ("d", "ao"),
71
+ "de": ("d", "e"),
72
+ "dei": ("d", "ei"),
73
+ "den": ("d", "en"),
74
+ "deng": ("d", "eng"),
75
+ "di": ("d", "i"),
76
+ "dia": ("d", "ia"),
77
+ "dian": ("d", "ian"),
78
+ "diao": ("d", "iao"),
79
+ "die": ("d", "ie"),
80
+ "ding": ("d", "ing"),
81
+ "diu": ("d", "iou"),
82
+ "dong": ("d", "ong"),
83
+ "dou": ("d", "ou"),
84
+ "du": ("d", "u"),
85
+ "duan": ("d", "uan"),
86
+ "dui": ("d", "uei"),
87
+ "dun": ("d", "uen"),
88
+ "duo": ("d", "uo"),
89
+ "e": ("^", "e"),
90
+ "ei": ("^", "ei"),
91
+ "en": ("^", "en"),
92
+ "ng": ("^", "en"),
93
+ "eng": ("^", "eng"),
94
+ "er": ("^", "er"),
95
+ "fa": ("f", "a"),
96
+ "fan": ("f", "an"),
97
+ "fang": ("f", "ang"),
98
+ "fei": ("f", "ei"),
99
+ "fen": ("f", "en"),
100
+ "feng": ("f", "eng"),
101
+ "fo": ("f", "o"),
102
+ "fou": ("f", "ou"),
103
+ "fu": ("f", "u"),
104
+ "ga": ("g", "a"),
105
+ "gai": ("g", "ai"),
106
+ "gan": ("g", "an"),
107
+ "gang": ("g", "ang"),
108
+ "gao": ("g", "ao"),
109
+ "ge": ("g", "e"),
110
+ "gei": ("g", "ei"),
111
+ "gen": ("g", "en"),
112
+ "geng": ("g", "eng"),
113
+ "gong": ("g", "ong"),
114
+ "gou": ("g", "ou"),
115
+ "gu": ("g", "u"),
116
+ "gua": ("g", "ua"),
117
+ "guai": ("g", "uai"),
118
+ "guan": ("g", "uan"),
119
+ "guang": ("g", "uang"),
120
+ "gui": ("g", "uei"),
121
+ "gun": ("g", "uen"),
122
+ "guo": ("g", "uo"),
123
+ "ha": ("h", "a"),
124
+ "hai": ("h", "ai"),
125
+ "han": ("h", "an"),
126
+ "hang": ("h", "ang"),
127
+ "hao": ("h", "ao"),
128
+ "he": ("h", "e"),
129
+ "hei": ("h", "ei"),
130
+ "hen": ("h", "en"),
131
+ "heng": ("h", "eng"),
132
+ "hong": ("h", "ong"),
133
+ "hou": ("h", "ou"),
134
+ "hu": ("h", "u"),
135
+ "hua": ("h", "ua"),
136
+ "huai": ("h", "uai"),
137
+ "huan": ("h", "uan"),
138
+ "huang": ("h", "uang"),
139
+ "hui": ("h", "uei"),
140
+ "hun": ("h", "uen"),
141
+ "huo": ("h", "uo"),
142
+ "ji": ("j", "i"),
143
+ "jia": ("j", "ia"),
144
+ "jian": ("j", "ian"),
145
+ "jiang": ("j", "iang"),
146
+ "jiao": ("j", "iao"),
147
+ "jie": ("j", "ie"),
148
+ "jin": ("j", "in"),
149
+ "jing": ("j", "ing"),
150
+ "jiong": ("j", "iong"),
151
+ "jiu": ("j", "iou"),
152
+ "ju": ("j", "v"),
153
+ "juan": ("j", "van"),
154
+ "jue": ("j", "ve"),
155
+ "jun": ("j", "vn"),
156
+ "ka": ("k", "a"),
157
+ "kai": ("k", "ai"),
158
+ "kan": ("k", "an"),
159
+ "kang": ("k", "ang"),
160
+ "kao": ("k", "ao"),
161
+ "ke": ("k", "e"),
162
+ "kei": ("k", "ei"),
163
+ "ken": ("k", "en"),
164
+ "keng": ("k", "eng"),
165
+ "kong": ("k", "ong"),
166
+ "kou": ("k", "ou"),
167
+ "ku": ("k", "u"),
168
+ "kua": ("k", "ua"),
169
+ "kuai": ("k", "uai"),
170
+ "kuan": ("k", "uan"),
171
+ "kuang": ("k", "uang"),
172
+ "kui": ("k", "uei"),
173
+ "kun": ("k", "uen"),
174
+ "kuo": ("k", "uo"),
175
+ "la": ("l", "a"),
176
+ "lai": ("l", "ai"),
177
+ "lan": ("l", "an"),
178
+ "lang": ("l", "ang"),
179
+ "lao": ("l", "ao"),
180
+ "le": ("l", "e"),
181
+ "lei": ("l", "ei"),
182
+ "leng": ("l", "eng"),
183
+ "li": ("l", "i"),
184
+ "lia": ("l", "ia"),
185
+ "lian": ("l", "ian"),
186
+ "liang": ("l", "iang"),
187
+ "liao": ("l", "iao"),
188
+ "lie": ("l", "ie"),
189
+ "lin": ("l", "in"),
190
+ "ling": ("l", "ing"),
191
+ "liu": ("l", "iou"),
192
+ "lo": ("l", "o"),
193
+ "long": ("l", "ong"),
194
+ "lou": ("l", "ou"),
195
+ "lu": ("l", "u"),
196
+ "lv": ("l", "v"),
197
+ "luan": ("l", "uan"),
198
+ "lve": ("l", "ve"),
199
+ "lue": ("l", "ve"),
200
+ "lun": ("l", "uen"),
201
+ "luo": ("l", "uo"),
202
+ "ma": ("m", "a"),
203
+ "mai": ("m", "ai"),
204
+ "man": ("m", "an"),
205
+ "mang": ("m", "ang"),
206
+ "mao": ("m", "ao"),
207
+ "me": ("m", "e"),
208
+ "mei": ("m", "ei"),
209
+ "men": ("m", "en"),
210
+ "meng": ("m", "eng"),
211
+ "mi": ("m", "i"),
212
+ "mian": ("m", "ian"),
213
+ "miao": ("m", "iao"),
214
+ "mie": ("m", "ie"),
215
+ "min": ("m", "in"),
216
+ "ming": ("m", "ing"),
217
+ "miu": ("m", "iou"),
218
+ "mo": ("m", "o"),
219
+ "mou": ("m", "ou"),
220
+ "mu": ("m", "u"),
221
+ "na": ("n", "a"),
222
+ "nai": ("n", "ai"),
223
+ "nan": ("n", "an"),
224
+ "nang": ("n", "ang"),
225
+ "nao": ("n", "ao"),
226
+ "ne": ("n", "e"),
227
+ "nei": ("n", "ei"),
228
+ "nen": ("n", "en"),
229
+ "neng": ("n", "eng"),
230
+ "ni": ("n", "i"),
231
+ "nia": ("n", "ia"),
232
+ "nian": ("n", "ian"),
233
+ "niang": ("n", "iang"),
234
+ "niao": ("n", "iao"),
235
+ "nie": ("n", "ie"),
236
+ "nin": ("n", "in"),
237
+ "ning": ("n", "ing"),
238
+ "niu": ("n", "iou"),
239
+ "nong": ("n", "ong"),
240
+ "nou": ("n", "ou"),
241
+ "nu": ("n", "u"),
242
+ "nv": ("n", "v"),
243
+ "nuan": ("n", "uan"),
244
+ "nve": ("n", "ve"),
245
+ "nue": ("n", "ve"),
246
+ "nuo": ("n", "uo"),
247
+ "o": ("^", "o"),
248
+ "ou": ("^", "ou"),
249
+ "pa": ("p", "a"),
250
+ "pai": ("p", "ai"),
251
+ "pan": ("p", "an"),
252
+ "pang": ("p", "ang"),
253
+ "pao": ("p", "ao"),
254
+ "pe": ("p", "e"),
255
+ "pei": ("p", "ei"),
256
+ "pen": ("p", "en"),
257
+ "peng": ("p", "eng"),
258
+ "pi": ("p", "i"),
259
+ "pian": ("p", "ian"),
260
+ "piao": ("p", "iao"),
261
+ "pie": ("p", "ie"),
262
+ "pin": ("p", "in"),
263
+ "ping": ("p", "ing"),
264
+ "po": ("p", "o"),
265
+ "pou": ("p", "ou"),
266
+ "pu": ("p", "u"),
267
+ "qi": ("q", "i"),
268
+ "qia": ("q", "ia"),
269
+ "qian": ("q", "ian"),
270
+ "qiang": ("q", "iang"),
271
+ "qiao": ("q", "iao"),
272
+ "qie": ("q", "ie"),
273
+ "qin": ("q", "in"),
274
+ "qing": ("q", "ing"),
275
+ "qiong": ("q", "iong"),
276
+ "qiu": ("q", "iou"),
277
+ "qu": ("q", "v"),
278
+ "quan": ("q", "van"),
279
+ "que": ("q", "ve"),
280
+ "qun": ("q", "vn"),
281
+ "ran": ("r", "an"),
282
+ "rang": ("r", "ang"),
283
+ "rao": ("r", "ao"),
284
+ "re": ("r", "e"),
285
+ "ren": ("r", "en"),
286
+ "reng": ("r", "eng"),
287
+ "ri": ("r", "iii"),
288
+ "rong": ("r", "ong"),
289
+ "rou": ("r", "ou"),
290
+ "ru": ("r", "u"),
291
+ "rua": ("r", "ua"),
292
+ "ruan": ("r", "uan"),
293
+ "rui": ("r", "uei"),
294
+ "run": ("r", "uen"),
295
+ "ruo": ("r", "uo"),
296
+ "sa": ("s", "a"),
297
+ "sai": ("s", "ai"),
298
+ "san": ("s", "an"),
299
+ "sang": ("s", "ang"),
300
+ "sao": ("s", "ao"),
301
+ "se": ("s", "e"),
302
+ "sen": ("s", "en"),
303
+ "seng": ("s", "eng"),
304
+ "sha": ("sh", "a"),
305
+ "shai": ("sh", "ai"),
306
+ "shan": ("sh", "an"),
307
+ "shang": ("sh", "ang"),
308
+ "shao": ("sh", "ao"),
309
+ "she": ("sh", "e"),
310
+ "shei": ("sh", "ei"),
311
+ "shen": ("sh", "en"),
312
+ "sheng": ("sh", "eng"),
313
+ "shi": ("sh", "iii"),
314
+ "shou": ("sh", "ou"),
315
+ "shu": ("sh", "u"),
316
+ "shua": ("sh", "ua"),
317
+ "shuai": ("sh", "uai"),
318
+ "shuan": ("sh", "uan"),
319
+ "shuang": ("sh", "uang"),
320
+ "shui": ("sh", "uei"),
321
+ "shun": ("sh", "uen"),
322
+ "shuo": ("sh", "uo"),
323
+ "si": ("s", "ii"),
324
+ "song": ("s", "ong"),
325
+ "sou": ("s", "ou"),
326
+ "su": ("s", "u"),
327
+ "suan": ("s", "uan"),
328
+ "sui": ("s", "uei"),
329
+ "sun": ("s", "uen"),
330
+ "suo": ("s", "uo"),
331
+ "ta": ("t", "a"),
332
+ "tai": ("t", "ai"),
333
+ "tan": ("t", "an"),
334
+ "tang": ("t", "ang"),
335
+ "tao": ("t", "ao"),
336
+ "te": ("t", "e"),
337
+ "tei": ("t", "ei"),
338
+ "teng": ("t", "eng"),
339
+ "ti": ("t", "i"),
340
+ "tian": ("t", "ian"),
341
+ "tiao": ("t", "iao"),
342
+ "tie": ("t", "ie"),
343
+ "ting": ("t", "ing"),
344
+ "tong": ("t", "ong"),
345
+ "tou": ("t", "ou"),
346
+ "tu": ("t", "u"),
347
+ "tuan": ("t", "uan"),
348
+ "tui": ("t", "uei"),
349
+ "tun": ("t", "uen"),
350
+ "tuo": ("t", "uo"),
351
+ "wa": ("^", "ua"),
352
+ "wai": ("^", "uai"),
353
+ "wan": ("^", "uan"),
354
+ "wang": ("^", "uang"),
355
+ "wei": ("^", "uei"),
356
+ "wen": ("^", "uen"),
357
+ "weng": ("^", "ueng"),
358
+ "wo": ("^", "uo"),
359
+ "wu": ("^", "u"),
360
+ "xi": ("x", "i"),
361
+ "xia": ("x", "ia"),
362
+ "xian": ("x", "ian"),
363
+ "xiang": ("x", "iang"),
364
+ "xiao": ("x", "iao"),
365
+ "xie": ("x", "ie"),
366
+ "xin": ("x", "in"),
367
+ "xing": ("x", "ing"),
368
+ "xiong": ("x", "iong"),
369
+ "xiu": ("x", "iou"),
370
+ "xu": ("x", "v"),
371
+ "xuan": ("x", "van"),
372
+ "xue": ("x", "ve"),
373
+ "xun": ("x", "vn"),
374
+ "ya": ("^", "ia"),
375
+ "yan": ("^", "ian"),
376
+ "yang": ("^", "iang"),
377
+ "yao": ("^", "iao"),
378
+ "ye": ("^", "ie"),
379
+ "yi": ("^", "i"),
380
+ "yin": ("^", "in"),
381
+ "ying": ("^", "ing"),
382
+ "yo": ("^", "iou"),
383
+ "yong": ("^", "iong"),
384
+ "you": ("^", "iou"),
385
+ "yu": ("^", "v"),
386
+ "yuan": ("^", "van"),
387
+ "yue": ("^", "ve"),
388
+ "yun": ("^", "vn"),
389
+ "za": ("z", "a"),
390
+ "zai": ("z", "ai"),
391
+ "zan": ("z", "an"),
392
+ "zang": ("z", "ang"),
393
+ "zao": ("z", "ao"),
394
+ "ze": ("z", "e"),
395
+ "zei": ("z", "ei"),
396
+ "zen": ("z", "en"),
397
+ "zeng": ("z", "eng"),
398
+ "zha": ("zh", "a"),
399
+ "zhai": ("zh", "ai"),
400
+ "zhan": ("zh", "an"),
401
+ "zhang": ("zh", "ang"),
402
+ "zhao": ("zh", "ao"),
403
+ "zhe": ("zh", "e"),
404
+ "zhei": ("zh", "ei"),
405
+ "zhen": ("zh", "en"),
406
+ "zheng": ("zh", "eng"),
407
+ "zhi": ("zh", "iii"),
408
+ "zhong": ("zh", "ong"),
409
+ "zhou": ("zh", "ou"),
410
+ "zhu": ("zh", "u"),
411
+ "zhua": ("zh", "ua"),
412
+ "zhuai": ("zh", "uai"),
413
+ "zhuan": ("zh", "uan"),
414
+ "zhuang": ("zh", "uang"),
415
+ "zhui": ("zh", "uei"),
416
+ "zhun": ("zh", "uen"),
417
+ "zhuo": ("zh", "uo"),
418
+ "zi": ("z", "ii"),
419
+ "zong": ("z", "ong"),
420
+ "zou": ("z", "ou"),
421
+ "zu": ("z", "u"),
422
+ "zuan": ("z", "uan"),
423
+ "zui": ("z", "uei"),
424
+ "zun": ("z", "uen"),
425
+ "zuo": ("z", "uo"),
426
+ }
bert/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
commons.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size * dilation - dilation) / 2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def intersperse(lst, item):
25
+ result = [item] * (len(lst) * 2 + 1)
26
+ result[1::2] = lst
27
+ return result
28
+
29
+
30
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
31
+ """KL(P||Q)"""
32
+ kl = (logs_q - logs_p) - 0.5
33
+ kl += (
34
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
35
+ )
36
+ return kl
37
+
38
+
39
+ def rand_gumbel(shape):
40
+ """Sample from the Gumbel distribution, protect from overflows."""
41
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
42
+ return -torch.log(-torch.log(uniform_samples))
43
+
44
+
45
+ def rand_gumbel_like(x):
46
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
47
+ return g
48
+
49
+
50
+ def slice_segments(x, ids_str, segment_size=4):
51
+ ret = torch.zeros_like(x[:, :, :segment_size])
52
+ for i in range(x.size(0)):
53
+ idx_str = ids_str[i]
54
+ idx_end = idx_str + segment_size
55
+ ret[i] = x[i, :, idx_str:idx_end]
56
+ return ret
57
+
58
+
59
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
60
+ b, d, t = x.size()
61
+ if x_lengths is None:
62
+ x_lengths = t
63
+ ids_str_max = x_lengths - segment_size + 1
64
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
65
+ ret = slice_segments(x, ids_str, segment_size)
66
+ return ret, ids_str
67
+
68
+
69
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
70
+ position = torch.arange(length, dtype=torch.float)
71
+ num_timescales = channels // 2
72
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
73
+ num_timescales - 1
74
+ )
75
+ inv_timescales = min_timescale * torch.exp(
76
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
77
+ )
78
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
79
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
80
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
81
+ signal = signal.view(1, channels, length)
82
+ return signal
83
+
84
+
85
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
86
+ b, channels, length = x.size()
87
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
88
+ return x + signal.to(dtype=x.dtype, device=x.device)
89
+
90
+
91
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
92
+ b, channels, length = x.size()
93
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
94
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
95
+
96
+
97
+ def subsequent_mask(length):
98
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
99
+ return mask
100
+
101
+
102
+ @torch.jit.script
103
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
104
+ n_channels_int = n_channels[0]
105
+ in_act = input_a + input_b
106
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
107
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
108
+ acts = t_act * s_act
109
+ return acts
110
+
111
+
112
+ def convert_pad_shape(pad_shape):
113
+ l = pad_shape[::-1]
114
+ pad_shape = [item for sublist in l for item in sublist]
115
+ return pad_shape
116
+
117
+
118
+ def shift_1d(x):
119
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
120
+ return x
121
+
122
+
123
+ def sequence_mask(length, max_length=None):
124
+ if max_length is None:
125
+ max_length = length.max()
126
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
127
+ return x.unsqueeze(0) < length.unsqueeze(1)
128
+
129
+
130
+ def generate_path(duration, mask):
131
+ """
132
+ duration: [b, 1, t_x]
133
+ mask: [b, 1, t_y, t_x]
134
+ """
135
+ device = duration.device
136
+
137
+ b, _, t_y, t_x = mask.shape
138
+ cum_duration = torch.cumsum(duration, -1)
139
+
140
+ cum_duration_flat = cum_duration.view(b * t_x)
141
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
142
+ path = path.view(b, t_x, t_y)
143
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
144
+ path = path.unsqueeze(1).transpose(2, 3) * mask
145
+ return path
146
+
147
+
148
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
149
+ if isinstance(parameters, torch.Tensor):
150
+ parameters = [parameters]
151
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
152
+ norm_type = float(norm_type)
153
+ if clip_value is not None:
154
+ clip_value = float(clip_value)
155
+
156
+ total_norm = 0
157
+ for p in parameters:
158
+ param_norm = p.grad.data.norm(norm_type)
159
+ total_norm += param_norm.item() ** norm_type
160
+ if clip_value is not None:
161
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
162
+ total_norm = total_norm ** (1.0 / norm_type)
163
+ return total_norm
configs/bert_vits.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 100,
4
+ "eval_interval": 10000,
5
+ "seed": 1234,
6
+ "epochs": 20000,
7
+ "learning_rate": 1e-4,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 8,
11
+ "fp16_run": false,
12
+ "lr_decay": 0.999875,
13
+ "segment_size": 12800,
14
+ "init_lr_ratio": 1,
15
+ "warmup_epochs": 0,
16
+ "c_mel": 45,
17
+ "c_kl": 1.0
18
+ },
19
+ "data": {
20
+ "training_files":"filelists/train.txt",
21
+ "validation_files":"filelists/valid.txt",
22
+ "max_wav_value": 32768.0,
23
+ "sampling_rate": 16000,
24
+ "filter_length": 1024,
25
+ "hop_length": 256,
26
+ "win_length": 1024,
27
+ "n_mel_channels": 80,
28
+ "mel_fmin": 0.0,
29
+ "mel_fmax": null,
30
+ "add_blank": false,
31
+ "n_speakers": 0
32
+ },
33
+ "model": {
34
+ "inter_channels": 192,
35
+ "hidden_channels": 192,
36
+ "filter_channels": 768,
37
+ "n_heads": 2,
38
+ "n_layers": 6,
39
+ "kernel_size": 3,
40
+ "p_dropout": 0.1,
41
+ "resblock": "1",
42
+ "resblock_kernel_sizes": [3,7,11],
43
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
44
+ "upsample_rates": [8,8,2,2],
45
+ "upsample_initial_channel": 512,
46
+ "upsample_kernel_sizes": [16,16,4,4],
47
+ "n_layers_q": 3,
48
+ "use_spectral_norm": false
49
+ }
50
+ }
models.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import commons
8
+ import modules
9
+ import attentions
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+ from commons import init_weights, get_padding
14
+
15
+
16
+ class DurationPredictor(nn.Module):
17
+ def __init__(
18
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
19
+ ):
20
+ super().__init__()
21
+
22
+ self.in_channels = in_channels
23
+ self.filter_channels = filter_channels
24
+ self.kernel_size = kernel_size
25
+ self.p_dropout = p_dropout
26
+ self.gin_channels = gin_channels
27
+
28
+ self.drop = nn.Dropout(p_dropout)
29
+ self.conv_1 = nn.Conv1d(
30
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
31
+ )
32
+ self.norm_1 = modules.LayerNorm(filter_channels)
33
+ self.conv_2 = nn.Conv1d(
34
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
35
+ )
36
+ self.norm_2 = modules.LayerNorm(filter_channels)
37
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
38
+
39
+ if gin_channels != 0:
40
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
41
+
42
+ def forward(self, x, x_mask, g=None):
43
+ x = torch.detach(x)
44
+ if g is not None:
45
+ g = torch.detach(g)
46
+ x = x + self.cond(g)
47
+ x = self.conv_1(x * x_mask)
48
+ x = torch.relu(x)
49
+ x = self.norm_1(x)
50
+ x = self.drop(x)
51
+ x = self.conv_2(x * x_mask)
52
+ x = torch.relu(x)
53
+ x = self.norm_2(x)
54
+ x = self.drop(x)
55
+ x = self.proj(x * x_mask)
56
+ return x * x_mask
57
+
58
+
59
+ class TextEncoder(nn.Module):
60
+ def __init__(
61
+ self,
62
+ n_vocab,
63
+ out_channels,
64
+ hidden_channels,
65
+ filter_channels,
66
+ n_heads,
67
+ n_layers,
68
+ kernel_size,
69
+ p_dropout,
70
+ ):
71
+ super().__init__()
72
+ self.n_vocab = n_vocab
73
+ self.out_channels = out_channels
74
+ self.hidden_channels = hidden_channels
75
+ self.filter_channels = filter_channels
76
+ self.n_heads = n_heads
77
+ self.n_layers = n_layers
78
+ self.kernel_size = kernel_size
79
+ self.p_dropout = p_dropout
80
+
81
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
82
+ self.emb_bert = nn.Linear(256, hidden_channels)
83
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
84
+
85
+ self.encoder = attentions.Encoder(
86
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
87
+ )
88
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
89
+
90
+ def forward(self, x, x_lengths, bert):
91
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
92
+ b = self.emb_bert(bert)
93
+ x = x + b
94
+ x = torch.transpose(x, 1, -1) # [b, h, t]
95
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
96
+ x.dtype
97
+ )
98
+
99
+ x = self.encoder(x * x_mask, x_mask)
100
+ stats = self.proj(x) * x_mask
101
+
102
+ m, logs = torch.split(stats, self.out_channels, dim=1)
103
+ return x, m, logs, x_mask
104
+
105
+
106
+ class ResidualCouplingBlock(nn.Module):
107
+ def __init__(
108
+ self,
109
+ channels,
110
+ hidden_channels,
111
+ kernel_size,
112
+ dilation_rate,
113
+ n_layers,
114
+ n_flows=4,
115
+ gin_channels=0,
116
+ ):
117
+ super().__init__()
118
+ self.channels = channels
119
+ self.hidden_channels = hidden_channels
120
+ self.kernel_size = kernel_size
121
+ self.dilation_rate = dilation_rate
122
+ self.n_layers = n_layers
123
+ self.n_flows = n_flows
124
+ self.gin_channels = gin_channels
125
+
126
+ self.flows = nn.ModuleList()
127
+ for i in range(n_flows):
128
+ self.flows.append(
129
+ modules.ResidualCouplingLayer(
130
+ channels,
131
+ hidden_channels,
132
+ kernel_size,
133
+ dilation_rate,
134
+ n_layers,
135
+ gin_channels=gin_channels,
136
+ mean_only=True,
137
+ )
138
+ )
139
+ self.flows.append(modules.Flip())
140
+
141
+ def forward(self, x, x_mask, g=None, reverse=False):
142
+ if not reverse:
143
+ for flow in self.flows:
144
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
145
+ else:
146
+ for flow in reversed(self.flows):
147
+ x = flow(x, x_mask, g=g, reverse=reverse)
148
+ return x
149
+
150
+ def remove_weight_norm(self):
151
+ for i in range(self.n_flows):
152
+ self.flows[i * 2].remove_weight_norm()
153
+
154
+
155
+ class PosteriorEncoder(nn.Module):
156
+ def __init__(
157
+ self,
158
+ in_channels,
159
+ out_channels,
160
+ hidden_channels,
161
+ kernel_size,
162
+ dilation_rate,
163
+ n_layers,
164
+ gin_channels=0,
165
+ ):
166
+ super().__init__()
167
+ self.in_channels = in_channels
168
+ self.out_channels = out_channels
169
+ self.hidden_channels = hidden_channels
170
+ self.kernel_size = kernel_size
171
+ self.dilation_rate = dilation_rate
172
+ self.n_layers = n_layers
173
+ self.gin_channels = gin_channels
174
+
175
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
176
+ self.enc = modules.WN(
177
+ hidden_channels,
178
+ kernel_size,
179
+ dilation_rate,
180
+ n_layers,
181
+ gin_channels=gin_channels,
182
+ )
183
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
184
+
185
+ def forward(self, x, x_lengths, g=None):
186
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
187
+ x.dtype
188
+ )
189
+ x = self.pre(x) * x_mask
190
+ x = self.enc(x, x_mask, g=g)
191
+ stats = self.proj(x) * x_mask
192
+ m, logs = torch.split(stats, self.out_channels, dim=1)
193
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
194
+ return z, m, logs, x_mask
195
+
196
+ def remove_weight_norm(self):
197
+ self.enc.remove_weight_norm()
198
+
199
+
200
+ class Generator(torch.nn.Module):
201
+ def __init__(
202
+ self,
203
+ initial_channel,
204
+ resblock,
205
+ resblock_kernel_sizes,
206
+ resblock_dilation_sizes,
207
+ upsample_rates,
208
+ upsample_initial_channel,
209
+ upsample_kernel_sizes,
210
+ gin_channels=0,
211
+ ):
212
+ super(Generator, self).__init__()
213
+ self.num_kernels = len(resblock_kernel_sizes)
214
+ self.num_upsamples = len(upsample_rates)
215
+ self.conv_pre = Conv1d(
216
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
217
+ )
218
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
219
+
220
+ self.ups = nn.ModuleList()
221
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
222
+ self.ups.append(
223
+ weight_norm(
224
+ ConvTranspose1d(
225
+ upsample_initial_channel // (2**i),
226
+ upsample_initial_channel // (2 ** (i + 1)),
227
+ k,
228
+ u,
229
+ padding=(k - u) // 2,
230
+ )
231
+ )
232
+ )
233
+
234
+ self.resblocks = nn.ModuleList()
235
+ for i in range(len(self.ups)):
236
+ ch = upsample_initial_channel // (2 ** (i + 1))
237
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
238
+ self.resblocks.append(resblock(ch, k, d))
239
+
240
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
241
+ self.ups.apply(init_weights)
242
+
243
+ if gin_channels != 0:
244
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
245
+
246
+ def forward(self, x, g=None):
247
+ x = self.conv_pre(x)
248
+ if g is not None:
249
+ x = x + self.cond(g)
250
+
251
+ for i in range(self.num_upsamples):
252
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
253
+ x = self.ups[i](x)
254
+ xs = None
255
+ for j in range(self.num_kernels):
256
+ if xs is None:
257
+ xs = self.resblocks[i * self.num_kernels + j](x)
258
+ else:
259
+ xs += self.resblocks[i * self.num_kernels + j](x)
260
+ x = xs / self.num_kernels
261
+ x = F.leaky_relu(x)
262
+ x = self.conv_post(x)
263
+ x = torch.tanh(x)
264
+
265
+ return x
266
+
267
+ def remove_weight_norm(self):
268
+ for l in self.ups:
269
+ remove_weight_norm(l)
270
+ for l in self.resblocks:
271
+ l.remove_weight_norm()
272
+ remove_weight_norm(self.conv_pre)
273
+ remove_weight_norm(self.conv_post)
274
+
275
+
276
+ class DiscriminatorP(torch.nn.Module):
277
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
278
+ super(DiscriminatorP, self).__init__()
279
+ self.period = period
280
+ self.use_spectral_norm = use_spectral_norm
281
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
282
+ self.convs = nn.ModuleList(
283
+ [
284
+ norm_f(
285
+ Conv2d(
286
+ 1,
287
+ 32,
288
+ (kernel_size, 1),
289
+ (stride, 1),
290
+ padding=(get_padding(kernel_size, 1), 0),
291
+ )
292
+ ),
293
+ norm_f(
294
+ Conv2d(
295
+ 32,
296
+ 128,
297
+ (kernel_size, 1),
298
+ (stride, 1),
299
+ padding=(get_padding(kernel_size, 1), 0),
300
+ )
301
+ ),
302
+ norm_f(
303
+ Conv2d(
304
+ 128,
305
+ 512,
306
+ (kernel_size, 1),
307
+ (stride, 1),
308
+ padding=(get_padding(kernel_size, 1), 0),
309
+ )
310
+ ),
311
+ norm_f(
312
+ Conv2d(
313
+ 512,
314
+ 1024,
315
+ (kernel_size, 1),
316
+ (stride, 1),
317
+ padding=(get_padding(kernel_size, 1), 0),
318
+ )
319
+ ),
320
+ norm_f(
321
+ Conv2d(
322
+ 1024,
323
+ 1024,
324
+ (kernel_size, 1),
325
+ 1,
326
+ padding=(get_padding(kernel_size, 1), 0),
327
+ )
328
+ ),
329
+ ]
330
+ )
331
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
332
+
333
+ def forward(self, x):
334
+ fmap = []
335
+
336
+ # 1d to 2d
337
+ b, c, t = x.shape
338
+ if t % self.period != 0: # pad first
339
+ n_pad = self.period - (t % self.period)
340
+ x = F.pad(x, (0, n_pad), "reflect")
341
+ t = t + n_pad
342
+ x = x.view(b, c, t // self.period, self.period)
343
+
344
+ for l in self.convs:
345
+ x = l(x)
346
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
347
+ fmap.append(x)
348
+ x = self.conv_post(x)
349
+ fmap.append(x)
350
+ x = torch.flatten(x, 1, -1)
351
+
352
+ return x, fmap
353
+
354
+
355
+ class DiscriminatorS(torch.nn.Module):
356
+ def __init__(self, use_spectral_norm=False):
357
+ super(DiscriminatorS, self).__init__()
358
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
359
+ self.convs = nn.ModuleList(
360
+ [
361
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
362
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
363
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
364
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
365
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
366
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
367
+ ]
368
+ )
369
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
370
+
371
+ def forward(self, x):
372
+ fmap = []
373
+
374
+ for l in self.convs:
375
+ x = l(x)
376
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
377
+ fmap.append(x)
378
+ x = self.conv_post(x)
379
+ fmap.append(x)
380
+ x = torch.flatten(x, 1, -1)
381
+
382
+ return x, fmap
383
+
384
+
385
+ class MultiPeriodDiscriminator(torch.nn.Module):
386
+ def __init__(self, use_spectral_norm=False):
387
+ super(MultiPeriodDiscriminator, self).__init__()
388
+ periods = [2, 3, 5, 7, 11]
389
+
390
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
391
+ discs = discs + [
392
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
393
+ ]
394
+ self.discriminators = nn.ModuleList(discs)
395
+
396
+ def forward(self, y, y_hat):
397
+ y_d_rs = []
398
+ y_d_gs = []
399
+ fmap_rs = []
400
+ fmap_gs = []
401
+ for i, d in enumerate(self.discriminators):
402
+ y_d_r, fmap_r = d(y)
403
+ y_d_g, fmap_g = d(y_hat)
404
+ y_d_rs.append(y_d_r)
405
+ y_d_gs.append(y_d_g)
406
+ fmap_rs.append(fmap_r)
407
+ fmap_gs.append(fmap_g)
408
+
409
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
410
+
411
+
412
+ class SynthesizerTrn(nn.Module):
413
+ """
414
+ Synthesizer for Training
415
+ """
416
+
417
+ def __init__(
418
+ self,
419
+ n_vocab,
420
+ spec_channels,
421
+ segment_size,
422
+ inter_channels,
423
+ hidden_channels,
424
+ filter_channels,
425
+ n_heads,
426
+ n_layers,
427
+ kernel_size,
428
+ p_dropout,
429
+ resblock,
430
+ resblock_kernel_sizes,
431
+ resblock_dilation_sizes,
432
+ upsample_rates,
433
+ upsample_initial_channel,
434
+ upsample_kernel_sizes,
435
+ n_speakers=0,
436
+ gin_channels=0,
437
+ use_sdp=False,
438
+ **kwargs
439
+ ):
440
+
441
+ super().__init__()
442
+ self.n_vocab = n_vocab
443
+ self.spec_channels = spec_channels
444
+ self.inter_channels = inter_channels
445
+ self.hidden_channels = hidden_channels
446
+ self.filter_channels = filter_channels
447
+ self.n_heads = n_heads
448
+ self.n_layers = n_layers
449
+ self.kernel_size = kernel_size
450
+ self.p_dropout = p_dropout
451
+ self.resblock = resblock
452
+ self.resblock_kernel_sizes = resblock_kernel_sizes
453
+ self.resblock_dilation_sizes = resblock_dilation_sizes
454
+ self.upsample_rates = upsample_rates
455
+ self.upsample_initial_channel = upsample_initial_channel
456
+ self.upsample_kernel_sizes = upsample_kernel_sizes
457
+ self.segment_size = segment_size
458
+ self.n_speakers = n_speakers
459
+ self.gin_channels = gin_channels
460
+
461
+ self.enc_p = TextEncoder(
462
+ n_vocab,
463
+ inter_channels,
464
+ hidden_channels,
465
+ filter_channels,
466
+ n_heads,
467
+ n_layers,
468
+ kernel_size,
469
+ p_dropout,
470
+ )
471
+ self.dec = Generator(
472
+ inter_channels,
473
+ resblock,
474
+ resblock_kernel_sizes,
475
+ resblock_dilation_sizes,
476
+ upsample_rates,
477
+ upsample_initial_channel,
478
+ upsample_kernel_sizes,
479
+ gin_channels=gin_channels,
480
+ )
481
+ self.enc_q = PosteriorEncoder(
482
+ spec_channels,
483
+ inter_channels,
484
+ hidden_channels,
485
+ 5,
486
+ 1,
487
+ 16,
488
+ gin_channels=gin_channels,
489
+ )
490
+ self.flow = ResidualCouplingBlock(
491
+ inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
492
+ )
493
+ self.dp = DurationPredictor(
494
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
495
+ )
496
+ if n_speakers > 1:
497
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
498
+
499
+ def remove_weight_norm(self):
500
+ print("Removing weight norm...")
501
+ self.dec.remove_weight_norm()
502
+ self.flow.remove_weight_norm()
503
+ self.enc_q.remove_weight_norm()
504
+
505
+
506
+ def infer(self, x, x_lengths, bert, sid=None, noise_scale=1, length_scale=1, max_len=None):
507
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, bert)
508
+ if self.n_speakers > 0:
509
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
510
+ else:
511
+ g = None
512
+
513
+ logw = self.dp(x, x_mask, g=g)
514
+ w = torch.exp(logw) * x_mask * length_scale
515
+ w_ceil = torch.ceil(w)
516
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
517
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
518
+ x_mask.dtype
519
+ )
520
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
521
+ attn = commons.generate_path(w_ceil, attn_mask)
522
+
523
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
524
+ 1, 2
525
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
526
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
527
+ 1, 2
528
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
529
+
530
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
531
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
532
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
533
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
modules.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ import commons
13
+ from commons import init_weights, get_padding
14
+ from transforms import piecewise_rational_quadratic_transform
15
+
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
+ class LayerNorm(nn.Module):
21
+ def __init__(self, channels, eps=1e-5):
22
+ super().__init__()
23
+ self.channels = channels
24
+ self.eps = eps
25
+
26
+ self.gamma = nn.Parameter(torch.ones(channels))
27
+ self.beta = nn.Parameter(torch.zeros(channels))
28
+
29
+ def forward(self, x):
30
+ x = x.transpose(1, -1)
31
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
+ return x.transpose(1, -1)
33
+
34
+
35
+ class ConvReluNorm(nn.Module):
36
+ def __init__(
37
+ self,
38
+ in_channels,
39
+ hidden_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ n_layers,
43
+ p_dropout,
44
+ ):
45
+ super().__init__()
46
+ self.in_channels = in_channels
47
+ self.hidden_channels = hidden_channels
48
+ self.out_channels = out_channels
49
+ self.kernel_size = kernel_size
50
+ self.n_layers = n_layers
51
+ self.p_dropout = p_dropout
52
+ assert n_layers > 1, "Number of layers should be larger than 0."
53
+
54
+ self.conv_layers = nn.ModuleList()
55
+ self.norm_layers = nn.ModuleList()
56
+ self.conv_layers.append(
57
+ nn.Conv1d(
58
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
+ )
60
+ )
61
+ self.norm_layers.append(LayerNorm(hidden_channels))
62
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
+ for _ in range(n_layers - 1):
64
+ self.conv_layers.append(
65
+ nn.Conv1d(
66
+ hidden_channels,
67
+ hidden_channels,
68
+ kernel_size,
69
+ padding=kernel_size // 2,
70
+ )
71
+ )
72
+ self.norm_layers.append(LayerNorm(hidden_channels))
73
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
+ self.proj.weight.data.zero_()
75
+ self.proj.bias.data.zero_()
76
+
77
+ def forward(self, x, x_mask):
78
+ x_org = x
79
+ for i in range(self.n_layers):
80
+ x = self.conv_layers[i](x * x_mask)
81
+ x = self.norm_layers[i](x)
82
+ x = self.relu_drop(x)
83
+ x = x_org + self.proj(x)
84
+ return x * x_mask
85
+
86
+
87
+ class DDSConv(nn.Module):
88
+ """
89
+ Dialted and Depth-Separable Convolution
90
+ """
91
+
92
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
+ super().__init__()
94
+ self.channels = channels
95
+ self.kernel_size = kernel_size
96
+ self.n_layers = n_layers
97
+ self.p_dropout = p_dropout
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.convs_sep = nn.ModuleList()
101
+ self.convs_1x1 = nn.ModuleList()
102
+ self.norms_1 = nn.ModuleList()
103
+ self.norms_2 = nn.ModuleList()
104
+ for i in range(n_layers):
105
+ dilation = kernel_size**i
106
+ padding = (kernel_size * dilation - dilation) // 2
107
+ self.convs_sep.append(
108
+ nn.Conv1d(
109
+ channels,
110
+ channels,
111
+ kernel_size,
112
+ groups=channels,
113
+ dilation=dilation,
114
+ padding=padding,
115
+ )
116
+ )
117
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
+ self.norms_1.append(LayerNorm(channels))
119
+ self.norms_2.append(LayerNorm(channels))
120
+
121
+ def forward(self, x, x_mask, g=None):
122
+ if g is not None:
123
+ x = x + g
124
+ for i in range(self.n_layers):
125
+ y = self.convs_sep[i](x * x_mask)
126
+ y = self.norms_1[i](y)
127
+ y = F.gelu(y)
128
+ y = self.convs_1x1[i](y)
129
+ y = self.norms_2[i](y)
130
+ y = F.gelu(y)
131
+ y = self.drop(y)
132
+ x = x + y
133
+ return x * x_mask
134
+
135
+
136
+ class WN(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ hidden_channels,
140
+ kernel_size,
141
+ dilation_rate,
142
+ n_layers,
143
+ gin_channels=0,
144
+ p_dropout=0,
145
+ ):
146
+ super(WN, self).__init__()
147
+ assert kernel_size % 2 == 1
148
+ self.hidden_channels = hidden_channels
149
+ self.kernel_size = (kernel_size,)
150
+ self.dilation_rate = dilation_rate
151
+ self.n_layers = n_layers
152
+ self.gin_channels = gin_channels
153
+ self.p_dropout = p_dropout
154
+
155
+ self.in_layers = torch.nn.ModuleList()
156
+ self.res_skip_layers = torch.nn.ModuleList()
157
+ self.drop = nn.Dropout(p_dropout)
158
+
159
+ if gin_channels != 0:
160
+ cond_layer = torch.nn.Conv1d(
161
+ gin_channels, 2 * hidden_channels * n_layers, 1
162
+ )
163
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164
+
165
+ for i in range(n_layers):
166
+ dilation = dilation_rate**i
167
+ padding = int((kernel_size * dilation - dilation) / 2)
168
+ in_layer = torch.nn.Conv1d(
169
+ hidden_channels,
170
+ 2 * hidden_channels,
171
+ kernel_size,
172
+ dilation=dilation,
173
+ padding=padding,
174
+ )
175
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176
+ self.in_layers.append(in_layer)
177
+
178
+ # last one is not necessary
179
+ if i < n_layers - 1:
180
+ res_skip_channels = 2 * hidden_channels
181
+ else:
182
+ res_skip_channels = hidden_channels
183
+
184
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
185
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
186
+ self.res_skip_layers.append(res_skip_layer)
187
+
188
+ def forward(self, x, x_mask, g=None, **kwargs):
189
+ output = torch.zeros_like(x)
190
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
191
+
192
+ if g is not None:
193
+ g = self.cond_layer(g)
194
+
195
+ for i in range(self.n_layers):
196
+ x_in = self.in_layers[i](x)
197
+ if g is not None:
198
+ cond_offset = i * 2 * self.hidden_channels
199
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
200
+ else:
201
+ g_l = torch.zeros_like(x_in)
202
+
203
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
204
+ acts = self.drop(acts)
205
+
206
+ res_skip_acts = self.res_skip_layers[i](acts)
207
+ if i < self.n_layers - 1:
208
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
209
+ x = (x + res_acts) * x_mask
210
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
211
+ else:
212
+ output = output + res_skip_acts
213
+ return output * x_mask
214
+
215
+ def remove_weight_norm(self):
216
+ if self.gin_channels != 0:
217
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
218
+ for l in self.in_layers:
219
+ torch.nn.utils.remove_weight_norm(l)
220
+ for l in self.res_skip_layers:
221
+ torch.nn.utils.remove_weight_norm(l)
222
+
223
+
224
+ class ResBlock1(torch.nn.Module):
225
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
226
+ super(ResBlock1, self).__init__()
227
+ self.convs1 = nn.ModuleList(
228
+ [
229
+ weight_norm(
230
+ Conv1d(
231
+ channels,
232
+ channels,
233
+ kernel_size,
234
+ 1,
235
+ dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]),
237
+ )
238
+ ),
239
+ weight_norm(
240
+ Conv1d(
241
+ channels,
242
+ channels,
243
+ kernel_size,
244
+ 1,
245
+ dilation=dilation[1],
246
+ padding=get_padding(kernel_size, dilation[1]),
247
+ )
248
+ ),
249
+ weight_norm(
250
+ Conv1d(
251
+ channels,
252
+ channels,
253
+ kernel_size,
254
+ 1,
255
+ dilation=dilation[2],
256
+ padding=get_padding(kernel_size, dilation[2]),
257
+ )
258
+ ),
259
+ ]
260
+ )
261
+ self.convs1.apply(init_weights)
262
+
263
+ self.convs2 = nn.ModuleList(
264
+ [
265
+ weight_norm(
266
+ Conv1d(
267
+ channels,
268
+ channels,
269
+ kernel_size,
270
+ 1,
271
+ dilation=1,
272
+ padding=get_padding(kernel_size, 1),
273
+ )
274
+ ),
275
+ weight_norm(
276
+ Conv1d(
277
+ channels,
278
+ channels,
279
+ kernel_size,
280
+ 1,
281
+ dilation=1,
282
+ padding=get_padding(kernel_size, 1),
283
+ )
284
+ ),
285
+ weight_norm(
286
+ Conv1d(
287
+ channels,
288
+ channels,
289
+ kernel_size,
290
+ 1,
291
+ dilation=1,
292
+ padding=get_padding(kernel_size, 1),
293
+ )
294
+ ),
295
+ ]
296
+ )
297
+ self.convs2.apply(init_weights)
298
+
299
+ def forward(self, x, x_mask=None):
300
+ for c1, c2 in zip(self.convs1, self.convs2):
301
+ xt = F.leaky_relu(x, LRELU_SLOPE)
302
+ if x_mask is not None:
303
+ xt = xt * x_mask
304
+ xt = c1(xt)
305
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
306
+ if x_mask is not None:
307
+ xt = xt * x_mask
308
+ xt = c2(xt)
309
+ x = xt + x
310
+ if x_mask is not None:
311
+ x = x * x_mask
312
+ return x
313
+
314
+ def remove_weight_norm(self):
315
+ for l in self.convs1:
316
+ remove_weight_norm(l)
317
+ for l in self.convs2:
318
+ remove_weight_norm(l)
319
+
320
+
321
+ class ResBlock2(torch.nn.Module):
322
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
323
+ super(ResBlock2, self).__init__()
324
+ self.convs = nn.ModuleList(
325
+ [
326
+ weight_norm(
327
+ Conv1d(
328
+ channels,
329
+ channels,
330
+ kernel_size,
331
+ 1,
332
+ dilation=dilation[0],
333
+ padding=get_padding(kernel_size, dilation[0]),
334
+ )
335
+ ),
336
+ weight_norm(
337
+ Conv1d(
338
+ channels,
339
+ channels,
340
+ kernel_size,
341
+ 1,
342
+ dilation=dilation[1],
343
+ padding=get_padding(kernel_size, dilation[1]),
344
+ )
345
+ ),
346
+ ]
347
+ )
348
+ self.convs.apply(init_weights)
349
+
350
+ def forward(self, x, x_mask=None):
351
+ for c in self.convs:
352
+ xt = F.leaky_relu(x, LRELU_SLOPE)
353
+ if x_mask is not None:
354
+ xt = xt * x_mask
355
+ xt = c(xt)
356
+ x = xt + x
357
+ if x_mask is not None:
358
+ x = x * x_mask
359
+ return x
360
+
361
+ def remove_weight_norm(self):
362
+ for l in self.convs:
363
+ remove_weight_norm(l)
364
+
365
+
366
+ class Log(nn.Module):
367
+ def forward(self, x, x_mask, reverse=False, **kwargs):
368
+ if not reverse:
369
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
370
+ logdet = torch.sum(-y, [1, 2])
371
+ return y, logdet
372
+ else:
373
+ x = torch.exp(x) * x_mask
374
+ return x
375
+
376
+
377
+ class Flip(nn.Module):
378
+ def forward(self, x, *args, reverse=False, **kwargs):
379
+ x = torch.flip(x, [1])
380
+ if not reverse:
381
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
382
+ return x, logdet
383
+ else:
384
+ return x
385
+
386
+
387
+ class ElementwiseAffine(nn.Module):
388
+ def __init__(self, channels):
389
+ super().__init__()
390
+ self.channels = channels
391
+ self.m = nn.Parameter(torch.zeros(channels, 1))
392
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
393
+
394
+ def forward(self, x, x_mask, reverse=False, **kwargs):
395
+ if not reverse:
396
+ y = self.m + torch.exp(self.logs) * x
397
+ y = y * x_mask
398
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
399
+ return y, logdet
400
+ else:
401
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
402
+ return x
403
+
404
+
405
+ class ResidualCouplingLayer(nn.Module):
406
+ def __init__(
407
+ self,
408
+ channels,
409
+ hidden_channels,
410
+ kernel_size,
411
+ dilation_rate,
412
+ n_layers,
413
+ p_dropout=0,
414
+ gin_channels=0,
415
+ mean_only=False,
416
+ ):
417
+ assert channels % 2 == 0, "channels should be divisible by 2"
418
+ super().__init__()
419
+ self.channels = channels
420
+ self.hidden_channels = hidden_channels
421
+ self.kernel_size = kernel_size
422
+ self.dilation_rate = dilation_rate
423
+ self.n_layers = n_layers
424
+ self.half_channels = channels // 2
425
+ self.mean_only = mean_only
426
+
427
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
428
+ self.enc = WN(
429
+ hidden_channels,
430
+ kernel_size,
431
+ dilation_rate,
432
+ n_layers,
433
+ p_dropout=p_dropout,
434
+ gin_channels=gin_channels,
435
+ )
436
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
437
+ self.post.weight.data.zero_()
438
+ self.post.bias.data.zero_()
439
+
440
+ def forward(self, x, x_mask, g=None, reverse=False):
441
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
442
+ h = self.pre(x0) * x_mask
443
+ h = self.enc(h, x_mask, g=g)
444
+ stats = self.post(h) * x_mask
445
+ if not self.mean_only:
446
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
447
+ else:
448
+ m = stats
449
+ logs = torch.zeros_like(m)
450
+
451
+ if not reverse:
452
+ x1 = m + x1 * torch.exp(logs) * x_mask
453
+ x = torch.cat([x0, x1], 1)
454
+ logdet = torch.sum(logs, [1, 2])
455
+ return x, logdet
456
+ else:
457
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
458
+ x = torch.cat([x0, x1], 1)
459
+ return x
460
+
461
+ def remove_weight_norm(self):
462
+ self.enc.remove_weight_norm()
463
+
464
+
465
+ class ConvFlow(nn.Module):
466
+ def __init__(
467
+ self,
468
+ in_channels,
469
+ filter_channels,
470
+ kernel_size,
471
+ n_layers,
472
+ num_bins=10,
473
+ tail_bound=5.0,
474
+ ):
475
+ super().__init__()
476
+ self.in_channels = in_channels
477
+ self.filter_channels = filter_channels
478
+ self.kernel_size = kernel_size
479
+ self.n_layers = n_layers
480
+ self.num_bins = num_bins
481
+ self.tail_bound = tail_bound
482
+ self.half_channels = in_channels // 2
483
+
484
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
485
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
486
+ self.proj = nn.Conv1d(
487
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
488
+ )
489
+ self.proj.weight.data.zero_()
490
+ self.proj.bias.data.zero_()
491
+
492
+ def forward(self, x, x_mask, g=None, reverse=False):
493
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
494
+ h = self.pre(x0)
495
+ h = self.convs(h, x_mask, g=g)
496
+ h = self.proj(h) * x_mask
497
+
498
+ b, c, t = x0.shape
499
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
500
+
501
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
502
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
503
+ self.filter_channels
504
+ )
505
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
506
+
507
+ x1, logabsdet = piecewise_rational_quadratic_transform(
508
+ x1,
509
+ unnormalized_widths,
510
+ unnormalized_heights,
511
+ unnormalized_derivatives,
512
+ inverse=reverse,
513
+ tails="linear",
514
+ tail_bound=self.tail_bound,
515
+ )
516
+
517
+ x = torch.cat([x0, x1], 1) * x_mask
518
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
519
+ if not reverse:
520
+ return x, logdet
521
+ else:
522
+ return x
text/__init__.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from text.symbols import symbols
2
+
3
+
4
+ # Mappings from symbol to numeric ID and vice versa:
5
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
6
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
7
+
8
+
9
+ def cleaned_text_to_sequence(cleaned_text):
10
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
11
+ Args:
12
+ text: string to convert to a sequence
13
+ Returns:
14
+ List of integers corresponding to the symbols in the text
15
+ """
16
+ sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split()]
17
+ return sequence
18
+
19
+
20
+ def sequence_to_text(sequence):
21
+ """Converts a sequence of IDs back to a string"""
22
+ result = ""
23
+ for symbol_id in sequence:
24
+ s = _id_to_symbol[symbol_id]
25
+ result += s
26
+ return result
27
+
28
+
29
+ pinyin_dict = {
30
+ "a": ("^", "a"),
31
+ "ai": ("^", "ai"),
32
+ "an": ("^", "an"),
33
+ "ang": ("^", "ang"),
34
+ "ao": ("^", "ao"),
35
+ "ba": ("b", "a"),
36
+ "bai": ("b", "ai"),
37
+ "ban": ("b", "an"),
38
+ "bang": ("b", "ang"),
39
+ "bao": ("b", "ao"),
40
+ "be": ("b", "e"),
41
+ "bei": ("b", "ei"),
42
+ "ben": ("b", "en"),
43
+ "beng": ("b", "eng"),
44
+ "bi": ("b", "i"),
45
+ "bian": ("b", "ian"),
46
+ "biao": ("b", "iao"),
47
+ "bie": ("b", "ie"),
48
+ "bin": ("b", "in"),
49
+ "bing": ("b", "ing"),
50
+ "bo": ("b", "o"),
51
+ "bu": ("b", "u"),
52
+ "ca": ("c", "a"),
53
+ "cai": ("c", "ai"),
54
+ "can": ("c", "an"),
55
+ "cang": ("c", "ang"),
56
+ "cao": ("c", "ao"),
57
+ "ce": ("c", "e"),
58
+ "cen": ("c", "en"),
59
+ "ceng": ("c", "eng"),
60
+ "cha": ("ch", "a"),
61
+ "chai": ("ch", "ai"),
62
+ "chan": ("ch", "an"),
63
+ "chang": ("ch", "ang"),
64
+ "chao": ("ch", "ao"),
65
+ "che": ("ch", "e"),
66
+ "chen": ("ch", "en"),
67
+ "cheng": ("ch", "eng"),
68
+ "chi": ("ch", "iii"),
69
+ "chong": ("ch", "ong"),
70
+ "chou": ("ch", "ou"),
71
+ "chu": ("ch", "u"),
72
+ "chua": ("ch", "ua"),
73
+ "chuai": ("ch", "uai"),
74
+ "chuan": ("ch", "uan"),
75
+ "chuang": ("ch", "uang"),
76
+ "chui": ("ch", "uei"),
77
+ "chun": ("ch", "uen"),
78
+ "chuo": ("ch", "uo"),
79
+ "ci": ("c", "ii"),
80
+ "cong": ("c", "ong"),
81
+ "cou": ("c", "ou"),
82
+ "cu": ("c", "u"),
83
+ "cuan": ("c", "uan"),
84
+ "cui": ("c", "uei"),
85
+ "cun": ("c", "uen"),
86
+ "cuo": ("c", "uo"),
87
+ "da": ("d", "a"),
88
+ "dai": ("d", "ai"),
89
+ "dan": ("d", "an"),
90
+ "dang": ("d", "ang"),
91
+ "dao": ("d", "ao"),
92
+ "de": ("d", "e"),
93
+ "dei": ("d", "ei"),
94
+ "den": ("d", "en"),
95
+ "deng": ("d", "eng"),
96
+ "di": ("d", "i"),
97
+ "dia": ("d", "ia"),
98
+ "dian": ("d", "ian"),
99
+ "diao": ("d", "iao"),
100
+ "die": ("d", "ie"),
101
+ "ding": ("d", "ing"),
102
+ "diu": ("d", "iou"),
103
+ "dong": ("d", "ong"),
104
+ "dou": ("d", "ou"),
105
+ "du": ("d", "u"),
106
+ "duan": ("d", "uan"),
107
+ "dui": ("d", "uei"),
108
+ "dun": ("d", "uen"),
109
+ "duo": ("d", "uo"),
110
+ "e": ("^", "e"),
111
+ "ei": ("^", "ei"),
112
+ "en": ("^", "en"),
113
+ "ng": ("^", "en"),
114
+ "eng": ("^", "eng"),
115
+ "er": ("^", "er"),
116
+ "fa": ("f", "a"),
117
+ "fan": ("f", "an"),
118
+ "fang": ("f", "ang"),
119
+ "fei": ("f", "ei"),
120
+ "fen": ("f", "en"),
121
+ "feng": ("f", "eng"),
122
+ "fo": ("f", "o"),
123
+ "fou": ("f", "ou"),
124
+ "fu": ("f", "u"),
125
+ "ga": ("g", "a"),
126
+ "gai": ("g", "ai"),
127
+ "gan": ("g", "an"),
128
+ "gang": ("g", "ang"),
129
+ "gao": ("g", "ao"),
130
+ "ge": ("g", "e"),
131
+ "gei": ("g", "ei"),
132
+ "gen": ("g", "en"),
133
+ "geng": ("g", "eng"),
134
+ "gong": ("g", "ong"),
135
+ "gou": ("g", "ou"),
136
+ "gu": ("g", "u"),
137
+ "gua": ("g", "ua"),
138
+ "guai": ("g", "uai"),
139
+ "guan": ("g", "uan"),
140
+ "guang": ("g", "uang"),
141
+ "gui": ("g", "uei"),
142
+ "gun": ("g", "uen"),
143
+ "guo": ("g", "uo"),
144
+ "ha": ("h", "a"),
145
+ "hai": ("h", "ai"),
146
+ "han": ("h", "an"),
147
+ "hang": ("h", "ang"),
148
+ "hao": ("h", "ao"),
149
+ "he": ("h", "e"),
150
+ "hei": ("h", "ei"),
151
+ "hen": ("h", "en"),
152
+ "heng": ("h", "eng"),
153
+ "hong": ("h", "ong"),
154
+ "hou": ("h", "ou"),
155
+ "hu": ("h", "u"),
156
+ "hua": ("h", "ua"),
157
+ "huai": ("h", "uai"),
158
+ "huan": ("h", "uan"),
159
+ "huang": ("h", "uang"),
160
+ "hui": ("h", "uei"),
161
+ "hun": ("h", "uen"),
162
+ "huo": ("h", "uo"),
163
+ "ji": ("j", "i"),
164
+ "jia": ("j", "ia"),
165
+ "jian": ("j", "ian"),
166
+ "jiang": ("j", "iang"),
167
+ "jiao": ("j", "iao"),
168
+ "jie": ("j", "ie"),
169
+ "jin": ("j", "in"),
170
+ "jing": ("j", "ing"),
171
+ "jiong": ("j", "iong"),
172
+ "jiu": ("j", "iou"),
173
+ "ju": ("j", "v"),
174
+ "juan": ("j", "van"),
175
+ "jue": ("j", "ve"),
176
+ "jun": ("j", "vn"),
177
+ "ka": ("k", "a"),
178
+ "kai": ("k", "ai"),
179
+ "kan": ("k", "an"),
180
+ "kang": ("k", "ang"),
181
+ "kao": ("k", "ao"),
182
+ "ke": ("k", "e"),
183
+ "kei": ("k", "ei"),
184
+ "ken": ("k", "en"),
185
+ "keng": ("k", "eng"),
186
+ "kong": ("k", "ong"),
187
+ "kou": ("k", "ou"),
188
+ "ku": ("k", "u"),
189
+ "kua": ("k", "ua"),
190
+ "kuai": ("k", "uai"),
191
+ "kuan": ("k", "uan"),
192
+ "kuang": ("k", "uang"),
193
+ "kui": ("k", "uei"),
194
+ "kun": ("k", "uen"),
195
+ "kuo": ("k", "uo"),
196
+ "la": ("l", "a"),
197
+ "lai": ("l", "ai"),
198
+ "lan": ("l", "an"),
199
+ "lang": ("l", "ang"),
200
+ "lao": ("l", "ao"),
201
+ "le": ("l", "e"),
202
+ "lei": ("l", "ei"),
203
+ "leng": ("l", "eng"),
204
+ "li": ("l", "i"),
205
+ "lia": ("l", "ia"),
206
+ "lian": ("l", "ian"),
207
+ "liang": ("l", "iang"),
208
+ "liao": ("l", "iao"),
209
+ "lie": ("l", "ie"),
210
+ "lin": ("l", "in"),
211
+ "ling": ("l", "ing"),
212
+ "liu": ("l", "iou"),
213
+ "lo": ("l", "o"),
214
+ "long": ("l", "ong"),
215
+ "lou": ("l", "ou"),
216
+ "lu": ("l", "u"),
217
+ "lv": ("l", "v"),
218
+ "luan": ("l", "uan"),
219
+ "lve": ("l", "ve"),
220
+ "lue": ("l", "ve"),
221
+ "lun": ("l", "uen"),
222
+ "luo": ("l", "uo"),
223
+ "ma": ("m", "a"),
224
+ "mai": ("m", "ai"),
225
+ "man": ("m", "an"),
226
+ "mang": ("m", "ang"),
227
+ "mao": ("m", "ao"),
228
+ "me": ("m", "e"),
229
+ "mei": ("m", "ei"),
230
+ "men": ("m", "en"),
231
+ "meng": ("m", "eng"),
232
+ "mi": ("m", "i"),
233
+ "mian": ("m", "ian"),
234
+ "miao": ("m", "iao"),
235
+ "mie": ("m", "ie"),
236
+ "min": ("m", "in"),
237
+ "ming": ("m", "ing"),
238
+ "miu": ("m", "iou"),
239
+ "mo": ("m", "o"),
240
+ "mou": ("m", "ou"),
241
+ "mu": ("m", "u"),
242
+ "na": ("n", "a"),
243
+ "nai": ("n", "ai"),
244
+ "nan": ("n", "an"),
245
+ "nang": ("n", "ang"),
246
+ "nao": ("n", "ao"),
247
+ "ne": ("n", "e"),
248
+ "nei": ("n", "ei"),
249
+ "nen": ("n", "en"),
250
+ "neng": ("n", "eng"),
251
+ "ni": ("n", "i"),
252
+ "nia": ("n", "ia"),
253
+ "nian": ("n", "ian"),
254
+ "niang": ("n", "iang"),
255
+ "niao": ("n", "iao"),
256
+ "nie": ("n", "ie"),
257
+ "nin": ("n", "in"),
258
+ "ning": ("n", "ing"),
259
+ "niu": ("n", "iou"),
260
+ "nong": ("n", "ong"),
261
+ "nou": ("n", "ou"),
262
+ "nu": ("n", "u"),
263
+ "nv": ("n", "v"),
264
+ "nuan": ("n", "uan"),
265
+ "nve": ("n", "ve"),
266
+ "nue": ("n", "ve"),
267
+ "nuo": ("n", "uo"),
268
+ "o": ("^", "o"),
269
+ "ou": ("^", "ou"),
270
+ "pa": ("p", "a"),
271
+ "pai": ("p", "ai"),
272
+ "pan": ("p", "an"),
273
+ "pang": ("p", "ang"),
274
+ "pao": ("p", "ao"),
275
+ "pe": ("p", "e"),
276
+ "pei": ("p", "ei"),
277
+ "pen": ("p", "en"),
278
+ "peng": ("p", "eng"),
279
+ "pi": ("p", "i"),
280
+ "pian": ("p", "ian"),
281
+ "piao": ("p", "iao"),
282
+ "pie": ("p", "ie"),
283
+ "pin": ("p", "in"),
284
+ "ping": ("p", "ing"),
285
+ "po": ("p", "o"),
286
+ "pou": ("p", "ou"),
287
+ "pu": ("p", "u"),
288
+ "qi": ("q", "i"),
289
+ "qia": ("q", "ia"),
290
+ "qian": ("q", "ian"),
291
+ "qiang": ("q", "iang"),
292
+ "qiao": ("q", "iao"),
293
+ "qie": ("q", "ie"),
294
+ "qin": ("q", "in"),
295
+ "qing": ("q", "ing"),
296
+ "qiong": ("q", "iong"),
297
+ "qiu": ("q", "iou"),
298
+ "qu": ("q", "v"),
299
+ "quan": ("q", "van"),
300
+ "que": ("q", "ve"),
301
+ "qun": ("q", "vn"),
302
+ "ran": ("r", "an"),
303
+ "rang": ("r", "ang"),
304
+ "rao": ("r", "ao"),
305
+ "re": ("r", "e"),
306
+ "ren": ("r", "en"),
307
+ "reng": ("r", "eng"),
308
+ "ri": ("r", "iii"),
309
+ "rong": ("r", "ong"),
310
+ "rou": ("r", "ou"),
311
+ "ru": ("r", "u"),
312
+ "rua": ("r", "ua"),
313
+ "ruan": ("r", "uan"),
314
+ "rui": ("r", "uei"),
315
+ "run": ("r", "uen"),
316
+ "ruo": ("r", "uo"),
317
+ "sa": ("s", "a"),
318
+ "sai": ("s", "ai"),
319
+ "san": ("s", "an"),
320
+ "sang": ("s", "ang"),
321
+ "sao": ("s", "ao"),
322
+ "se": ("s", "e"),
323
+ "sen": ("s", "en"),
324
+ "seng": ("s", "eng"),
325
+ "sha": ("sh", "a"),
326
+ "shai": ("sh", "ai"),
327
+ "shan": ("sh", "an"),
328
+ "shang": ("sh", "ang"),
329
+ "shao": ("sh", "ao"),
330
+ "she": ("sh", "e"),
331
+ "shei": ("sh", "ei"),
332
+ "shen": ("sh", "en"),
333
+ "sheng": ("sh", "eng"),
334
+ "shi": ("sh", "iii"),
335
+ "shou": ("sh", "ou"),
336
+ "shu": ("sh", "u"),
337
+ "shua": ("sh", "ua"),
338
+ "shuai": ("sh", "uai"),
339
+ "shuan": ("sh", "uan"),
340
+ "shuang": ("sh", "uang"),
341
+ "shui": ("sh", "uei"),
342
+ "shun": ("sh", "uen"),
343
+ "shuo": ("sh", "uo"),
344
+ "si": ("s", "ii"),
345
+ "song": ("s", "ong"),
346
+ "sou": ("s", "ou"),
347
+ "su": ("s", "u"),
348
+ "suan": ("s", "uan"),
349
+ "sui": ("s", "uei"),
350
+ "sun": ("s", "uen"),
351
+ "suo": ("s", "uo"),
352
+ "ta": ("t", "a"),
353
+ "tai": ("t", "ai"),
354
+ "tan": ("t", "an"),
355
+ "tang": ("t", "ang"),
356
+ "tao": ("t", "ao"),
357
+ "te": ("t", "e"),
358
+ "tei": ("t", "ei"),
359
+ "teng": ("t", "eng"),
360
+ "ti": ("t", "i"),
361
+ "tian": ("t", "ian"),
362
+ "tiao": ("t", "iao"),
363
+ "tie": ("t", "ie"),
364
+ "ting": ("t", "ing"),
365
+ "tong": ("t", "ong"),
366
+ "tou": ("t", "ou"),
367
+ "tu": ("t", "u"),
368
+ "tuan": ("t", "uan"),
369
+ "tui": ("t", "uei"),
370
+ "tun": ("t", "uen"),
371
+ "tuo": ("t", "uo"),
372
+ "wa": ("^", "ua"),
373
+ "wai": ("^", "uai"),
374
+ "wan": ("^", "uan"),
375
+ "wang": ("^", "uang"),
376
+ "wei": ("^", "uei"),
377
+ "wen": ("^", "uen"),
378
+ "weng": ("^", "ueng"),
379
+ "wo": ("^", "uo"),
380
+ "wu": ("^", "u"),
381
+ "xi": ("x", "i"),
382
+ "xia": ("x", "ia"),
383
+ "xian": ("x", "ian"),
384
+ "xiang": ("x", "iang"),
385
+ "xiao": ("x", "iao"),
386
+ "xie": ("x", "ie"),
387
+ "xin": ("x", "in"),
388
+ "xing": ("x", "ing"),
389
+ "xiong": ("x", "iong"),
390
+ "xiu": ("x", "iou"),
391
+ "xu": ("x", "v"),
392
+ "xuan": ("x", "van"),
393
+ "xue": ("x", "ve"),
394
+ "xun": ("x", "vn"),
395
+ "ya": ("^", "ia"),
396
+ "yan": ("^", "ian"),
397
+ "yang": ("^", "iang"),
398
+ "yao": ("^", "iao"),
399
+ "ye": ("^", "ie"),
400
+ "yi": ("^", "i"),
401
+ "yin": ("^", "in"),
402
+ "ying": ("^", "ing"),
403
+ "yo": ("^", "iou"),
404
+ "yong": ("^", "iong"),
405
+ "you": ("^", "iou"),
406
+ "yu": ("^", "v"),
407
+ "yuan": ("^", "van"),
408
+ "yue": ("^", "ve"),
409
+ "yun": ("^", "vn"),
410
+ "za": ("z", "a"),
411
+ "zai": ("z", "ai"),
412
+ "zan": ("z", "an"),
413
+ "zang": ("z", "ang"),
414
+ "zao": ("z", "ao"),
415
+ "ze": ("z", "e"),
416
+ "zei": ("z", "ei"),
417
+ "zen": ("z", "en"),
418
+ "zeng": ("z", "eng"),
419
+ "zha": ("zh", "a"),
420
+ "zhai": ("zh", "ai"),
421
+ "zhan": ("zh", "an"),
422
+ "zhang": ("zh", "ang"),
423
+ "zhao": ("zh", "ao"),
424
+ "zhe": ("zh", "e"),
425
+ "zhei": ("zh", "ei"),
426
+ "zhen": ("zh", "en"),
427
+ "zheng": ("zh", "eng"),
428
+ "zhi": ("zh", "iii"),
429
+ "zhong": ("zh", "ong"),
430
+ "zhou": ("zh", "ou"),
431
+ "zhu": ("zh", "u"),
432
+ "zhua": ("zh", "ua"),
433
+ "zhuai": ("zh", "uai"),
434
+ "zhuan": ("zh", "uan"),
435
+ "zhuang": ("zh", "uang"),
436
+ "zhui": ("zh", "uei"),
437
+ "zhun": ("zh", "uen"),
438
+ "zhuo": ("zh", "uo"),
439
+ "zi": ("z", "ii"),
440
+ "zong": ("z", "ong"),
441
+ "zou": ("z", "ou"),
442
+ "zu": ("z", "u"),
443
+ "zuan": ("z", "uan"),
444
+ "zui": ("z", "uei"),
445
+ "zun": ("z", "uen"),
446
+ "zuo": ("z", "uo"),
447
+ }
text/symbols.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]
2
+
3
+ _initials = [
4
+ "^",
5
+ "b",
6
+ "c",
7
+ "ch",
8
+ "d",
9
+ "f",
10
+ "g",
11
+ "h",
12
+ "j",
13
+ "k",
14
+ "l",
15
+ "m",
16
+ "n",
17
+ "p",
18
+ "q",
19
+ "r",
20
+ "s",
21
+ "sh",
22
+ "t",
23
+ "x",
24
+ "z",
25
+ "zh",
26
+ ]
27
+
28
+ _tones = ["1", "2", "3", "4", "5"]
29
+
30
+ _finals = [
31
+ "a",
32
+ "ai",
33
+ "an",
34
+ "ang",
35
+ "ao",
36
+ "e",
37
+ "ei",
38
+ "en",
39
+ "eng",
40
+ "er",
41
+ "i",
42
+ "ia",
43
+ "ian",
44
+ "iang",
45
+ "iao",
46
+ "ie",
47
+ "ii",
48
+ "iii",
49
+ "in",
50
+ "ing",
51
+ "iong",
52
+ "iou",
53
+ "o",
54
+ "ong",
55
+ "ou",
56
+ "u",
57
+ "ua",
58
+ "uai",
59
+ "uan",
60
+ "uang",
61
+ "uei",
62
+ "uen",
63
+ "ueng",
64
+ "uo",
65
+ "v",
66
+ "van",
67
+ "ve",
68
+ "vn",
69
+ ]
70
+
71
+ symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
transforms.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+
25
+ if tails is None:
26
+ spline_fn = rational_quadratic_spline
27
+ spline_kwargs = {}
28
+ else:
29
+ spline_fn = unconstrained_rational_quadratic_spline
30
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
31
+
32
+ outputs, logabsdet = spline_fn(
33
+ inputs=inputs,
34
+ unnormalized_widths=unnormalized_widths,
35
+ unnormalized_heights=unnormalized_heights,
36
+ unnormalized_derivatives=unnormalized_derivatives,
37
+ inverse=inverse,
38
+ min_bin_width=min_bin_width,
39
+ min_bin_height=min_bin_height,
40
+ min_derivative=min_derivative,
41
+ **spline_kwargs
42
+ )
43
+ return outputs, logabsdet
44
+
45
+
46
+ def searchsorted(bin_locations, inputs, eps=1e-6):
47
+ bin_locations[..., -1] += eps
48
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
49
+
50
+
51
+ def unconstrained_rational_quadratic_spline(
52
+ inputs,
53
+ unnormalized_widths,
54
+ unnormalized_heights,
55
+ unnormalized_derivatives,
56
+ inverse=False,
57
+ tails="linear",
58
+ tail_bound=1.0,
59
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
60
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
61
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
62
+ ):
63
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
64
+ outside_interval_mask = ~inside_interval_mask
65
+
66
+ outputs = torch.zeros_like(inputs)
67
+ logabsdet = torch.zeros_like(inputs)
68
+
69
+ if tails == "linear":
70
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
71
+ constant = np.log(np.exp(1 - min_derivative) - 1)
72
+ unnormalized_derivatives[..., 0] = constant
73
+ unnormalized_derivatives[..., -1] = constant
74
+
75
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
76
+ logabsdet[outside_interval_mask] = 0
77
+ else:
78
+ raise RuntimeError("{} tails are not implemented.".format(tails))
79
+
80
+ (
81
+ outputs[inside_interval_mask],
82
+ logabsdet[inside_interval_mask],
83
+ ) = rational_quadratic_spline(
84
+ inputs=inputs[inside_interval_mask],
85
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
86
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
87
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
88
+ inverse=inverse,
89
+ left=-tail_bound,
90
+ right=tail_bound,
91
+ bottom=-tail_bound,
92
+ top=tail_bound,
93
+ min_bin_width=min_bin_width,
94
+ min_bin_height=min_bin_height,
95
+ min_derivative=min_derivative,
96
+ )
97
+
98
+ return outputs, logabsdet
99
+
100
+
101
+ def rational_quadratic_spline(
102
+ inputs,
103
+ unnormalized_widths,
104
+ unnormalized_heights,
105
+ unnormalized_derivatives,
106
+ inverse=False,
107
+ left=0.0,
108
+ right=1.0,
109
+ bottom=0.0,
110
+ top=1.0,
111
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
112
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
113
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
114
+ ):
115
+ if torch.min(inputs) < left or torch.max(inputs) > right:
116
+ raise ValueError("Input to a transform is not within its domain")
117
+
118
+ num_bins = unnormalized_widths.shape[-1]
119
+
120
+ if min_bin_width * num_bins > 1.0:
121
+ raise ValueError("Minimal bin width too large for the number of bins")
122
+ if min_bin_height * num_bins > 1.0:
123
+ raise ValueError("Minimal bin height too large for the number of bins")
124
+
125
+ widths = F.softmax(unnormalized_widths, dim=-1)
126
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
127
+ cumwidths = torch.cumsum(widths, dim=-1)
128
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
129
+ cumwidths = (right - left) * cumwidths + left
130
+ cumwidths[..., 0] = left
131
+ cumwidths[..., -1] = right
132
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
133
+
134
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
135
+
136
+ heights = F.softmax(unnormalized_heights, dim=-1)
137
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
138
+ cumheights = torch.cumsum(heights, dim=-1)
139
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
140
+ cumheights = (top - bottom) * cumheights + bottom
141
+ cumheights[..., 0] = bottom
142
+ cumheights[..., -1] = top
143
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
144
+
145
+ if inverse:
146
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
147
+ else:
148
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
149
+
150
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
151
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
152
+
153
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
154
+ delta = heights / widths
155
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
156
+
157
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
158
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
159
+
160
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
161
+
162
+ if inverse:
163
+ a = (inputs - input_cumheights) * (
164
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
165
+ ) + input_heights * (input_delta - input_derivatives)
166
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
167
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
168
+ )
169
+ c = -input_delta * (inputs - input_cumheights)
170
+
171
+ discriminant = b.pow(2) - 4 * a * c
172
+ assert (discriminant >= 0).all()
173
+
174
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
175
+ outputs = root * input_bin_widths + input_cumwidths
176
+
177
+ theta_one_minus_theta = root * (1 - root)
178
+ denominator = input_delta + (
179
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
180
+ * theta_one_minus_theta
181
+ )
182
+ derivative_numerator = input_delta.pow(2) * (
183
+ input_derivatives_plus_one * root.pow(2)
184
+ + 2 * input_delta * theta_one_minus_theta
185
+ + input_derivatives * (1 - root).pow(2)
186
+ )
187
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
188
+
189
+ return outputs, -logabsdet
190
+ else:
191
+ theta = (inputs - input_cumwidths) / input_bin_widths
192
+ theta_one_minus_theta = theta * (1 - theta)
193
+
194
+ numerator = input_heights * (
195
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
196
+ )
197
+ denominator = input_delta + (
198
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
199
+ * theta_one_minus_theta
200
+ )
201
+ outputs = input_cumheights + numerator / denominator
202
+
203
+ derivative_numerator = input_delta.pow(2) * (
204
+ input_derivatives_plus_one * theta.pow(2)
205
+ + 2 * input_delta * theta_one_minus_theta
206
+ + input_derivatives * (1 - theta).pow(2)
207
+ )
208
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
209
+
210
+ return outputs, logabsdet
utils.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import sys
4
+ import argparse
5
+ import logging
6
+ import json
7
+ import subprocess
8
+ import numpy as np
9
+ from scipy.io.wavfile import read
10
+ import torch
11
+
12
+ MATPLOTLIB_FLAG = False
13
+
14
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15
+ logger = logging
16
+
17
+
18
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
19
+ assert os.path.isfile(checkpoint_path)
20
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
21
+ iteration = checkpoint_dict["iteration"]
22
+ learning_rate = checkpoint_dict["learning_rate"]
23
+ if optimizer is not None:
24
+ optimizer.load_state_dict(checkpoint_dict["optimizer"])
25
+ saved_state_dict = checkpoint_dict["model"]
26
+ if hasattr(model, "module"):
27
+ state_dict = model.module.state_dict()
28
+ else:
29
+ state_dict = model.state_dict()
30
+ new_state_dict = {}
31
+ for k, v in state_dict.items():
32
+ try:
33
+ new_state_dict[k] = saved_state_dict[k]
34
+ except:
35
+ logger.info("%s is not in the checkpoint" % k)
36
+ new_state_dict[k] = v
37
+ if hasattr(model, "module"):
38
+ model.module.load_state_dict(new_state_dict)
39
+ else:
40
+ model.load_state_dict(new_state_dict)
41
+ logger.info(
42
+ "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
43
+ )
44
+ return model, optimizer, learning_rate, iteration
45
+
46
+
47
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
48
+ logger.info(
49
+ "Saving model and optimizer state at iteration {} to {}".format(
50
+ iteration, checkpoint_path
51
+ )
52
+ )
53
+ if hasattr(model, "module"):
54
+ state_dict = model.module.state_dict()
55
+ else:
56
+ state_dict = model.state_dict()
57
+ torch.save(
58
+ {
59
+ "model": state_dict,
60
+ "iteration": iteration,
61
+ "optimizer": optimizer.state_dict(),
62
+ "learning_rate": learning_rate,
63
+ },
64
+ checkpoint_path,
65
+ )
66
+
67
+
68
+ def load_model(checkpoint_path, model):
69
+ assert os.path.isfile(checkpoint_path)
70
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
71
+ saved_state_dict = checkpoint_dict["model"]
72
+ if hasattr(model, "module"):
73
+ state_dict = model.module.state_dict()
74
+ else:
75
+ state_dict = model.state_dict()
76
+ new_state_dict = {}
77
+ for k, v in state_dict.items():
78
+ try:
79
+ new_state_dict[k] = saved_state_dict[k]
80
+ except:
81
+ logger.info("%s is not in the checkpoint" % k)
82
+ new_state_dict[k] = v
83
+ if hasattr(model, "module"):
84
+ model.module.load_state_dict(new_state_dict)
85
+ else:
86
+ model.load_state_dict(new_state_dict)
87
+ return model
88
+
89
+
90
+ def save_model(model, checkpoint_path):
91
+ if hasattr(model, 'module'):
92
+ state_dict = model.module.state_dict()
93
+ else:
94
+ state_dict = model.state_dict()
95
+ torch.save({'model': state_dict}, checkpoint_path)
96
+
97
+
98
+ def summarize(
99
+ writer,
100
+ global_step,
101
+ scalars={},
102
+ histograms={},
103
+ images={},
104
+ audios={},
105
+ audio_sampling_rate=22050,
106
+ ):
107
+ for k, v in scalars.items():
108
+ writer.add_scalar(k, v, global_step)
109
+ for k, v in histograms.items():
110
+ writer.add_histogram(k, v, global_step)
111
+ for k, v in images.items():
112
+ writer.add_image(k, v, global_step, dataformats="HWC")
113
+ for k, v in audios.items():
114
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
115
+
116
+
117
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
118
+ f_list = glob.glob(os.path.join(dir_path, regex))
119
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
120
+ x = f_list[-1]
121
+ print(x)
122
+ return x
123
+
124
+
125
+ def plot_spectrogram_to_numpy(spectrogram):
126
+ global MATPLOTLIB_FLAG
127
+ if not MATPLOTLIB_FLAG:
128
+ import matplotlib
129
+
130
+ matplotlib.use("Agg")
131
+ MATPLOTLIB_FLAG = True
132
+ mpl_logger = logging.getLogger("matplotlib")
133
+ mpl_logger.setLevel(logging.WARNING)
134
+ import matplotlib.pylab as plt
135
+ import numpy as np
136
+
137
+ fig, ax = plt.subplots(figsize=(10, 2))
138
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
139
+ plt.colorbar(im, ax=ax)
140
+ plt.xlabel("Frames")
141
+ plt.ylabel("Channels")
142
+ plt.tight_layout()
143
+
144
+ fig.canvas.draw()
145
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
146
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
147
+ plt.close()
148
+ return data
149
+
150
+
151
+ def plot_alignment_to_numpy(alignment, info=None):
152
+ global MATPLOTLIB_FLAG
153
+ if not MATPLOTLIB_FLAG:
154
+ import matplotlib
155
+
156
+ matplotlib.use("Agg")
157
+ MATPLOTLIB_FLAG = True
158
+ mpl_logger = logging.getLogger("matplotlib")
159
+ mpl_logger.setLevel(logging.WARNING)
160
+ import matplotlib.pylab as plt
161
+ import numpy as np
162
+
163
+ fig, ax = plt.subplots(figsize=(6, 4))
164
+ im = ax.imshow(
165
+ alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
166
+ )
167
+ fig.colorbar(im, ax=ax)
168
+ xlabel = "Decoder timestep"
169
+ if info is not None:
170
+ xlabel += "\n\n" + info
171
+ plt.xlabel(xlabel)
172
+ plt.ylabel("Encoder timestep")
173
+ plt.tight_layout()
174
+
175
+ fig.canvas.draw()
176
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
177
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
178
+ plt.close()
179
+ return data
180
+
181
+
182
+ def load_wav_to_torch(full_path):
183
+ sampling_rate, data = read(full_path)
184
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
185
+
186
+
187
+ def load_filepaths_and_text(filename, split="|"):
188
+ with open(filename, encoding="utf-8") as f:
189
+ filepaths_and_text = []
190
+ for line in f:
191
+ path_text = line.strip().split(split)
192
+ filepaths_and_text.append(path_text)
193
+ return filepaths_and_text
194
+
195
+
196
+ def get_hparams(init=True):
197
+ parser = argparse.ArgumentParser()
198
+ parser.add_argument(
199
+ "-c",
200
+ "--config",
201
+ type=str,
202
+ default="./configs/bert_vits.json",
203
+ help="JSON file for configuration",
204
+ )
205
+ parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
206
+
207
+ args = parser.parse_args()
208
+ model_dir = os.path.join("./logs", args.model)
209
+
210
+ if not os.path.exists(model_dir):
211
+ os.makedirs(model_dir)
212
+
213
+ config_path = args.config
214
+ config_save_path = os.path.join(model_dir, "config.json")
215
+ if init:
216
+ with open(config_path, "r") as f:
217
+ data = f.read()
218
+ with open(config_save_path, "w") as f:
219
+ f.write(data)
220
+ else:
221
+ with open(config_save_path, "r") as f:
222
+ data = f.read()
223
+ config = json.loads(data)
224
+
225
+ hparams = HParams(**config)
226
+ hparams.model_dir = model_dir
227
+ return hparams
228
+
229
+
230
+ def get_hparams_from_dir(model_dir):
231
+ config_save_path = os.path.join(model_dir, "config.json")
232
+ with open(config_save_path, "r") as f:
233
+ data = f.read()
234
+ config = json.loads(data)
235
+
236
+ hparams = HParams(**config)
237
+ hparams.model_dir = model_dir
238
+ return hparams
239
+
240
+
241
+ def get_hparams_from_file(config_path):
242
+ with open(config_path, "r") as f:
243
+ data = f.read()
244
+ config = json.loads(data)
245
+
246
+ hparams = HParams(**config)
247
+ return hparams
248
+
249
+
250
+ def check_git_hash(model_dir):
251
+ source_dir = os.path.dirname(os.path.realpath(__file__))
252
+ if not os.path.exists(os.path.join(source_dir, ".git")):
253
+ logger.warn(
254
+ "{} is not a git repository, therefore hash value comparison will be ignored.".format(
255
+ source_dir
256
+ )
257
+ )
258
+ return
259
+
260
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
261
+
262
+ path = os.path.join(model_dir, "githash")
263
+ if os.path.exists(path):
264
+ saved_hash = open(path).read()
265
+ if saved_hash != cur_hash:
266
+ logger.warn(
267
+ "git hash values are different. {}(saved) != {}(current)".format(
268
+ saved_hash[:8], cur_hash[:8]
269
+ )
270
+ )
271
+ else:
272
+ open(path, "w").write(cur_hash)
273
+
274
+
275
+ def get_logger(model_dir, filename="train.log"):
276
+ global logger
277
+ logger = logging.getLogger(os.path.basename(model_dir))
278
+ logger.setLevel(logging.DEBUG)
279
+
280
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
281
+ if not os.path.exists(model_dir):
282
+ os.makedirs(model_dir)
283
+ h = logging.FileHandler(os.path.join(model_dir, filename))
284
+ h.setLevel(logging.DEBUG)
285
+ h.setFormatter(formatter)
286
+ logger.addHandler(h)
287
+ return logger
288
+
289
+
290
+ class HParams:
291
+ def __init__(self, **kwargs):
292
+ for k, v in kwargs.items():
293
+ if type(v) == dict:
294
+ v = HParams(**v)
295
+ self[k] = v
296
+
297
+ def keys(self):
298
+ return self.__dict__.keys()
299
+
300
+ def items(self):
301
+ return self.__dict__.items()
302
+
303
+ def values(self):
304
+ return self.__dict__.values()
305
+
306
+ def __len__(self):
307
+ return len(self.__dict__)
308
+
309
+ def __getitem__(self, key):
310
+ return getattr(self, key)
311
+
312
+ def __setitem__(self, key, value):
313
+ return setattr(self, key, value)
314
+
315
+ def __contains__(self, key):
316
+ return key in self.__dict__
317
+
318
+ def __repr__(self):
319
+ return self.__dict__.__repr__()
vits_pinyin.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from pypinyin import Style
4
+ from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
5
+ from pypinyin.converter import DefaultConverter
6
+ from pypinyin.core import Pinyin
7
+
8
+ from text import pinyin_dict
9
+ from bert import TTSProsody
10
+
11
+
12
+ class MyConverter(NeutralToneWith5Mixin, DefaultConverter):
13
+ pass
14
+
15
+
16
+ def is_chinese(uchar):
17
+ if uchar >= u'\u4e00' and uchar <= u'\u9fa5':
18
+ return True
19
+ else:
20
+ return False
21
+
22
+
23
+ def clean_chinese(text: str):
24
+ text = text.strip()
25
+ text_clean = []
26
+ for char in text:
27
+ if (is_chinese(char)):
28
+ text_clean.append(char)
29
+ else:
30
+ if len(text_clean) > 1 and is_chinese(text_clean[-1]):
31
+ text_clean.append(',')
32
+ text_clean = ''.join(text_clean).strip(',')
33
+ return text_clean
34
+
35
+
36
+ class VITS_PinYin:
37
+ def __init__(self, bert_path, device):
38
+ self.pinyin_parser = Pinyin(MyConverter())
39
+ self.prosody = TTSProsody(bert_path, device)
40
+
41
+ def get_phoneme4pinyin(self, pinyins):
42
+ result = []
43
+ count_phone = []
44
+ for pinyin in pinyins:
45
+ if pinyin[:-1] in pinyin_dict:
46
+ tone = pinyin[-1]
47
+ a = pinyin[:-1]
48
+ a1, a2 = pinyin_dict[a]
49
+ result += [a1, a2 + tone]
50
+ count_phone.append(2)
51
+ return result, count_phone
52
+
53
+ def chinese_to_phonemes(self, text):
54
+ text = clean_chinese(text)
55
+ phonemes = ["sil"]
56
+ chars = ['[PAD]']
57
+ count_phone = []
58
+ count_phone.append(1)
59
+ for subtext in text.split(","):
60
+ if (len(subtext) == 0):
61
+ continue
62
+ pinyins = self.correct_pinyin_tone3(subtext)
63
+ sub_p, sub_c = self.get_phoneme4pinyin(pinyins)
64
+ phonemes.extend(sub_p)
65
+ phonemes.append("sp")
66
+ count_phone.extend(sub_c)
67
+ count_phone.append(1)
68
+ chars.append(subtext)
69
+ chars.append(',')
70
+ phonemes.append("sil")
71
+ count_phone.append(1)
72
+ chars.append('[PAD]')
73
+ chars = "".join(chars)
74
+ char_embeds = self.prosody.get_char_embeds(chars)
75
+ char_embeds = self.prosody.expand_for_phone(char_embeds, count_phone)
76
+ return " ".join(phonemes), char_embeds
77
+
78
+ def correct_pinyin_tone3(self, text):
79
+ pinyin_list = [p[0] for p in self.pinyin_parser.pinyin(
80
+ text, style=Style.TONE3, strict=False, neutral_tone_with_five=True)]
81
+ if len(pinyin_list) >= 2:
82
+ for i in range(1, len(pinyin_list)):
83
+ try:
84
+ if re.findall(r'\d', pinyin_list[i-1])[0] == '3' and re.findall(r'\d', pinyin_list[i])[0] == '3':
85
+ pinyin_list[i-1] = pinyin_list[i-1].replace('3', '2')
86
+ except IndexError:
87
+ pass
88
+ return pinyin_list