ElesisSiegherts commited on
Commit
164dba8
1 Parent(s): 9607672

Upload 128 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. oldVersion/V101/__init__.py +75 -0
  2. oldVersion/V101/__pycache__/__init__.cpython-310.pyc +0 -0
  3. oldVersion/V101/__pycache__/__init__.cpython-38.pyc +0 -0
  4. oldVersion/V101/__pycache__/models.cpython-310.pyc +0 -0
  5. oldVersion/V101/__pycache__/models.cpython-38.pyc +0 -0
  6. oldVersion/V101/models.py +977 -0
  7. oldVersion/V101/text/__init__.py +28 -0
  8. oldVersion/V101/text/__pycache__/__init__.cpython-310.pyc +0 -0
  9. oldVersion/V101/text/__pycache__/__init__.cpython-38.pyc +0 -0
  10. oldVersion/V101/text/__pycache__/chinese.cpython-310.pyc +0 -0
  11. oldVersion/V101/text/__pycache__/chinese.cpython-38.pyc +0 -0
  12. oldVersion/V101/text/__pycache__/cleaner.cpython-310.pyc +0 -0
  13. oldVersion/V101/text/__pycache__/cleaner.cpython-38.pyc +0 -0
  14. oldVersion/V101/text/__pycache__/symbols.cpython-310.pyc +0 -0
  15. oldVersion/V101/text/__pycache__/symbols.cpython-38.pyc +0 -0
  16. oldVersion/V101/text/__pycache__/tone_sandhi.cpython-310.pyc +0 -0
  17. oldVersion/V101/text/__pycache__/tone_sandhi.cpython-38.pyc +0 -0
  18. oldVersion/V101/text/chinese.py +199 -0
  19. oldVersion/V101/text/chinese_bert.py +100 -0
  20. oldVersion/V101/text/cleaner.py +28 -0
  21. oldVersion/V101/text/english.py +214 -0
  22. oldVersion/V101/text/english_bert_mock.py +5 -0
  23. oldVersion/V101/text/japanese.py +112 -0
  24. oldVersion/V101/text/opencpop-strict.txt +429 -0
  25. oldVersion/V101/text/symbols.py +183 -0
  26. oldVersion/V101/text/tone_sandhi.py +769 -0
  27. oldVersion/V110/__init__.py +90 -0
  28. oldVersion/V110/__pycache__/__init__.cpython-310.pyc +0 -0
  29. oldVersion/V110/__pycache__/__init__.cpython-38.pyc +0 -0
  30. oldVersion/V110/__pycache__/models.cpython-310.pyc +0 -0
  31. oldVersion/V110/__pycache__/models.cpython-38.pyc +0 -0
  32. oldVersion/V110/models.py +986 -0
  33. oldVersion/V110/text/__init__.py +29 -0
  34. oldVersion/V110/text/__pycache__/__init__.cpython-310.pyc +0 -0
  35. oldVersion/V110/text/__pycache__/__init__.cpython-38.pyc +0 -0
  36. oldVersion/V110/text/__pycache__/chinese.cpython-310.pyc +0 -0
  37. oldVersion/V110/text/__pycache__/chinese.cpython-38.pyc +0 -0
  38. oldVersion/V110/text/__pycache__/cleaner.cpython-310.pyc +0 -0
  39. oldVersion/V110/text/__pycache__/cleaner.cpython-38.pyc +0 -0
  40. oldVersion/V110/text/__pycache__/japanese.cpython-310.pyc +0 -0
  41. oldVersion/V110/text/__pycache__/japanese.cpython-38.pyc +0 -0
  42. oldVersion/V110/text/__pycache__/symbols.cpython-310.pyc +0 -0
  43. oldVersion/V110/text/__pycache__/symbols.cpython-38.pyc +0 -0
  44. oldVersion/V110/text/__pycache__/tone_sandhi.cpython-310.pyc +0 -0
  45. oldVersion/V110/text/__pycache__/tone_sandhi.cpython-38.pyc +0 -0
  46. oldVersion/V110/text/chinese.py +198 -0
  47. oldVersion/V110/text/chinese_bert.py +97 -0
  48. oldVersion/V110/text/cleaner.py +28 -0
  49. oldVersion/V110/text/english.py +214 -0
  50. oldVersion/V110/text/english_bert_mock.py +5 -0
oldVersion/V101/__init__.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 1.0.1 版本兼容
3
+ https://github.com/fishaudio/Bert-VITS2/releases/tag/1.0.1
4
+ """
5
+ import torch
6
+ import commons
7
+ from .text.cleaner import clean_text
8
+ from .text import cleaned_text_to_sequence
9
+ from oldVersion.V111.text import get_bert
10
+
11
+
12
+ def get_text(text, language_str, hps, device):
13
+ norm_text, phone, tone, word2ph = clean_text(text, language_str)
14
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
15
+
16
+ if hps.data.add_blank:
17
+ phone = commons.intersperse(phone, 0)
18
+ tone = commons.intersperse(tone, 0)
19
+ language = commons.intersperse(language, 0)
20
+ for i in range(len(word2ph)):
21
+ word2ph[i] = word2ph[i] * 2
22
+ word2ph[0] += 1
23
+ bert = get_bert(norm_text, word2ph, language_str, device)
24
+ del word2ph
25
+
26
+ assert bert.shape[-1] == len(phone)
27
+
28
+ phone = torch.LongTensor(phone)
29
+ tone = torch.LongTensor(tone)
30
+ language = torch.LongTensor(language)
31
+
32
+ return bert, phone, tone, language
33
+
34
+
35
+ def infer(
36
+ text,
37
+ sdp_ratio,
38
+ noise_scale,
39
+ noise_scale_w,
40
+ length_scale,
41
+ sid,
42
+ hps,
43
+ net_g,
44
+ device,
45
+ ):
46
+ bert, phones, tones, lang_ids = get_text(text, "ZH", hps, device)
47
+ with torch.no_grad():
48
+ x_tst = phones.to(device).unsqueeze(0)
49
+ tones = tones.to(device).unsqueeze(0)
50
+ lang_ids = lang_ids.to(device).unsqueeze(0)
51
+ bert = bert.to(device).unsqueeze(0)
52
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
53
+ del phones
54
+ speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
55
+ audio = (
56
+ net_g.infer(
57
+ x_tst,
58
+ x_tst_lengths,
59
+ speakers,
60
+ tones,
61
+ lang_ids,
62
+ bert,
63
+ sdp_ratio=sdp_ratio,
64
+ noise_scale=noise_scale,
65
+ noise_scale_w=noise_scale_w,
66
+ length_scale=length_scale,
67
+ )[0][0, 0]
68
+ .data.cpu()
69
+ .float()
70
+ .numpy()
71
+ )
72
+ del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
73
+ if torch.cuda.is_available():
74
+ torch.cuda.empty_cache()
75
+ return audio
oldVersion/V101/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.69 kB). View file
 
oldVersion/V101/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.69 kB). View file
 
oldVersion/V101/__pycache__/models.cpython-310.pyc ADDED
Binary file (20.5 kB). View file
 
oldVersion/V101/__pycache__/models.cpython-38.pyc ADDED
Binary file (20.7 kB). View file
 
oldVersion/V101/models.py ADDED
@@ -0,0 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ import attentions
9
+ import monotonic_align
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+
14
+ from commons import init_weights, get_padding
15
+ from .text import symbols, num_tones, num_languages
16
+
17
+
18
+ class DurationDiscriminator(nn.Module): # vits2
19
+ def __init__(
20
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
21
+ ):
22
+ super().__init__()
23
+
24
+ self.in_channels = in_channels
25
+ self.filter_channels = filter_channels
26
+ self.kernel_size = kernel_size
27
+ self.p_dropout = p_dropout
28
+ self.gin_channels = gin_channels
29
+
30
+ self.drop = nn.Dropout(p_dropout)
31
+ self.conv_1 = nn.Conv1d(
32
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
33
+ )
34
+ self.norm_1 = modules.LayerNorm(filter_channels)
35
+ self.conv_2 = nn.Conv1d(
36
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
37
+ )
38
+ self.norm_2 = modules.LayerNorm(filter_channels)
39
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
40
+
41
+ self.pre_out_conv_1 = nn.Conv1d(
42
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
43
+ )
44
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
45
+ self.pre_out_conv_2 = nn.Conv1d(
46
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
47
+ )
48
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
49
+
50
+ if gin_channels != 0:
51
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
52
+
53
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
54
+
55
+ def forward_probability(self, x, x_mask, dur, g=None):
56
+ dur = self.dur_proj(dur)
57
+ x = torch.cat([x, dur], dim=1)
58
+ x = self.pre_out_conv_1(x * x_mask)
59
+ x = torch.relu(x)
60
+ x = self.pre_out_norm_1(x)
61
+ x = self.drop(x)
62
+ x = self.pre_out_conv_2(x * x_mask)
63
+ x = torch.relu(x)
64
+ x = self.pre_out_norm_2(x)
65
+ x = self.drop(x)
66
+ x = x * x_mask
67
+ x = x.transpose(1, 2)
68
+ output_prob = self.output_layer(x)
69
+ return output_prob
70
+
71
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
72
+ x = torch.detach(x)
73
+ if g is not None:
74
+ g = torch.detach(g)
75
+ x = x + self.cond(g)
76
+ x = self.conv_1(x * x_mask)
77
+ x = torch.relu(x)
78
+ x = self.norm_1(x)
79
+ x = self.drop(x)
80
+ x = self.conv_2(x * x_mask)
81
+ x = torch.relu(x)
82
+ x = self.norm_2(x)
83
+ x = self.drop(x)
84
+
85
+ output_probs = []
86
+ for dur in [dur_r, dur_hat]:
87
+ output_prob = self.forward_probability(x, x_mask, dur, g)
88
+ output_probs.append(output_prob)
89
+
90
+ return output_probs
91
+
92
+
93
+ class TransformerCouplingBlock(nn.Module):
94
+ def __init__(
95
+ self,
96
+ channels,
97
+ hidden_channels,
98
+ filter_channels,
99
+ n_heads,
100
+ n_layers,
101
+ kernel_size,
102
+ p_dropout,
103
+ n_flows=4,
104
+ gin_channels=0,
105
+ share_parameter=False,
106
+ ):
107
+ super().__init__()
108
+ self.channels = channels
109
+ self.hidden_channels = hidden_channels
110
+ self.kernel_size = kernel_size
111
+ self.n_layers = n_layers
112
+ self.n_flows = n_flows
113
+ self.gin_channels = gin_channels
114
+
115
+ self.flows = nn.ModuleList()
116
+
117
+ self.wn = (
118
+ attentions.FFT(
119
+ hidden_channels,
120
+ filter_channels,
121
+ n_heads,
122
+ n_layers,
123
+ kernel_size,
124
+ p_dropout,
125
+ isflow=True,
126
+ gin_channels=self.gin_channels,
127
+ )
128
+ if share_parameter
129
+ else None
130
+ )
131
+
132
+ for i in range(n_flows):
133
+ self.flows.append(
134
+ modules.TransformerCouplingLayer(
135
+ channels,
136
+ hidden_channels,
137
+ kernel_size,
138
+ n_layers,
139
+ n_heads,
140
+ p_dropout,
141
+ filter_channels,
142
+ mean_only=True,
143
+ wn_sharing_parameter=self.wn,
144
+ gin_channels=self.gin_channels,
145
+ )
146
+ )
147
+ self.flows.append(modules.Flip())
148
+
149
+ def forward(self, x, x_mask, g=None, reverse=False):
150
+ if not reverse:
151
+ for flow in self.flows:
152
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
153
+ else:
154
+ for flow in reversed(self.flows):
155
+ x = flow(x, x_mask, g=g, reverse=reverse)
156
+ return x
157
+
158
+
159
+ class StochasticDurationPredictor(nn.Module):
160
+ def __init__(
161
+ self,
162
+ in_channels,
163
+ filter_channels,
164
+ kernel_size,
165
+ p_dropout,
166
+ n_flows=4,
167
+ gin_channels=0,
168
+ ):
169
+ super().__init__()
170
+ filter_channels = in_channels # it needs to be removed from future version.
171
+ self.in_channels = in_channels
172
+ self.filter_channels = filter_channels
173
+ self.kernel_size = kernel_size
174
+ self.p_dropout = p_dropout
175
+ self.n_flows = n_flows
176
+ self.gin_channels = gin_channels
177
+
178
+ self.log_flow = modules.Log()
179
+ self.flows = nn.ModuleList()
180
+ self.flows.append(modules.ElementwiseAffine(2))
181
+ for i in range(n_flows):
182
+ self.flows.append(
183
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
184
+ )
185
+ self.flows.append(modules.Flip())
186
+
187
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
188
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
189
+ self.post_convs = modules.DDSConv(
190
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
191
+ )
192
+ self.post_flows = nn.ModuleList()
193
+ self.post_flows.append(modules.ElementwiseAffine(2))
194
+ for i in range(4):
195
+ self.post_flows.append(
196
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
197
+ )
198
+ self.post_flows.append(modules.Flip())
199
+
200
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
201
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
202
+ self.convs = modules.DDSConv(
203
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
204
+ )
205
+ if gin_channels != 0:
206
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
207
+
208
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
209
+ x = torch.detach(x)
210
+ x = self.pre(x)
211
+ if g is not None:
212
+ g = torch.detach(g)
213
+ x = x + self.cond(g)
214
+ x = self.convs(x, x_mask)
215
+ x = self.proj(x) * x_mask
216
+
217
+ if not reverse:
218
+ flows = self.flows
219
+ assert w is not None
220
+
221
+ logdet_tot_q = 0
222
+ h_w = self.post_pre(w)
223
+ h_w = self.post_convs(h_w, x_mask)
224
+ h_w = self.post_proj(h_w) * x_mask
225
+ e_q = (
226
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
227
+ * x_mask
228
+ )
229
+ z_q = e_q
230
+ for flow in self.post_flows:
231
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
232
+ logdet_tot_q += logdet_q
233
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
234
+ u = torch.sigmoid(z_u) * x_mask
235
+ z0 = (w - u) * x_mask
236
+ logdet_tot_q += torch.sum(
237
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
238
+ )
239
+ logq = (
240
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
241
+ - logdet_tot_q
242
+ )
243
+
244
+ logdet_tot = 0
245
+ z0, logdet = self.log_flow(z0, x_mask)
246
+ logdet_tot += logdet
247
+ z = torch.cat([z0, z1], 1)
248
+ for flow in flows:
249
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
250
+ logdet_tot = logdet_tot + logdet
251
+ nll = (
252
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
253
+ - logdet_tot
254
+ )
255
+ return nll + logq # [b]
256
+ else:
257
+ flows = list(reversed(self.flows))
258
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
259
+ z = (
260
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
261
+ * noise_scale
262
+ )
263
+ for flow in flows:
264
+ z = flow(z, x_mask, g=x, reverse=reverse)
265
+ z0, z1 = torch.split(z, [1, 1], 1)
266
+ logw = z0
267
+ return logw
268
+
269
+
270
+ class DurationPredictor(nn.Module):
271
+ def __init__(
272
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
273
+ ):
274
+ super().__init__()
275
+
276
+ self.in_channels = in_channels
277
+ self.filter_channels = filter_channels
278
+ self.kernel_size = kernel_size
279
+ self.p_dropout = p_dropout
280
+ self.gin_channels = gin_channels
281
+
282
+ self.drop = nn.Dropout(p_dropout)
283
+ self.conv_1 = nn.Conv1d(
284
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
285
+ )
286
+ self.norm_1 = modules.LayerNorm(filter_channels)
287
+ self.conv_2 = nn.Conv1d(
288
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
289
+ )
290
+ self.norm_2 = modules.LayerNorm(filter_channels)
291
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
292
+
293
+ if gin_channels != 0:
294
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
295
+
296
+ def forward(self, x, x_mask, g=None):
297
+ x = torch.detach(x)
298
+ if g is not None:
299
+ g = torch.detach(g)
300
+ x = x + self.cond(g)
301
+ x = self.conv_1(x * x_mask)
302
+ x = torch.relu(x)
303
+ x = self.norm_1(x)
304
+ x = self.drop(x)
305
+ x = self.conv_2(x * x_mask)
306
+ x = torch.relu(x)
307
+ x = self.norm_2(x)
308
+ x = self.drop(x)
309
+ x = self.proj(x * x_mask)
310
+ return x * x_mask
311
+
312
+
313
+ class TextEncoder(nn.Module):
314
+ def __init__(
315
+ self,
316
+ n_vocab,
317
+ out_channels,
318
+ hidden_channels,
319
+ filter_channels,
320
+ n_heads,
321
+ n_layers,
322
+ kernel_size,
323
+ p_dropout,
324
+ gin_channels=0,
325
+ ):
326
+ super().__init__()
327
+ self.n_vocab = n_vocab
328
+ self.out_channels = out_channels
329
+ self.hidden_channels = hidden_channels
330
+ self.filter_channels = filter_channels
331
+ self.n_heads = n_heads
332
+ self.n_layers = n_layers
333
+ self.kernel_size = kernel_size
334
+ self.p_dropout = p_dropout
335
+ self.gin_channels = gin_channels
336
+ self.emb = nn.Embedding(len(symbols), hidden_channels)
337
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
338
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
339
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
340
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
341
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
342
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
343
+
344
+ self.encoder = attentions.Encoder(
345
+ hidden_channels,
346
+ filter_channels,
347
+ n_heads,
348
+ n_layers,
349
+ kernel_size,
350
+ p_dropout,
351
+ gin_channels=self.gin_channels,
352
+ )
353
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
354
+
355
+ def forward(self, x, x_lengths, tone, language, bert, g=None):
356
+ x = (
357
+ self.emb(x)
358
+ + self.tone_emb(tone)
359
+ + self.language_emb(language)
360
+ + self.bert_proj(bert).transpose(1, 2)
361
+ ) * math.sqrt(
362
+ self.hidden_channels
363
+ ) # [b, t, h]
364
+ x = torch.transpose(x, 1, -1) # [b, h, t]
365
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
366
+ x.dtype
367
+ )
368
+
369
+ x = self.encoder(x * x_mask, x_mask, g=g)
370
+ stats = self.proj(x) * x_mask
371
+
372
+ m, logs = torch.split(stats, self.out_channels, dim=1)
373
+ return x, m, logs, x_mask
374
+
375
+
376
+ class ResidualCouplingBlock(nn.Module):
377
+ def __init__(
378
+ self,
379
+ channels,
380
+ hidden_channels,
381
+ kernel_size,
382
+ dilation_rate,
383
+ n_layers,
384
+ n_flows=4,
385
+ gin_channels=0,
386
+ ):
387
+ super().__init__()
388
+ self.channels = channels
389
+ self.hidden_channels = hidden_channels
390
+ self.kernel_size = kernel_size
391
+ self.dilation_rate = dilation_rate
392
+ self.n_layers = n_layers
393
+ self.n_flows = n_flows
394
+ self.gin_channels = gin_channels
395
+
396
+ self.flows = nn.ModuleList()
397
+ for i in range(n_flows):
398
+ self.flows.append(
399
+ modules.ResidualCouplingLayer(
400
+ channels,
401
+ hidden_channels,
402
+ kernel_size,
403
+ dilation_rate,
404
+ n_layers,
405
+ gin_channels=gin_channels,
406
+ mean_only=True,
407
+ )
408
+ )
409
+ self.flows.append(modules.Flip())
410
+
411
+ def forward(self, x, x_mask, g=None, reverse=False):
412
+ if not reverse:
413
+ for flow in self.flows:
414
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
415
+ else:
416
+ for flow in reversed(self.flows):
417
+ x = flow(x, x_mask, g=g, reverse=reverse)
418
+ return x
419
+
420
+
421
+ class PosteriorEncoder(nn.Module):
422
+ def __init__(
423
+ self,
424
+ in_channels,
425
+ out_channels,
426
+ hidden_channels,
427
+ kernel_size,
428
+ dilation_rate,
429
+ n_layers,
430
+ gin_channels=0,
431
+ ):
432
+ super().__init__()
433
+ self.in_channels = in_channels
434
+ self.out_channels = out_channels
435
+ self.hidden_channels = hidden_channels
436
+ self.kernel_size = kernel_size
437
+ self.dilation_rate = dilation_rate
438
+ self.n_layers = n_layers
439
+ self.gin_channels = gin_channels
440
+
441
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
442
+ self.enc = modules.WN(
443
+ hidden_channels,
444
+ kernel_size,
445
+ dilation_rate,
446
+ n_layers,
447
+ gin_channels=gin_channels,
448
+ )
449
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
450
+
451
+ def forward(self, x, x_lengths, g=None):
452
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
453
+ x.dtype
454
+ )
455
+ x = self.pre(x) * x_mask
456
+ x = self.enc(x, x_mask, g=g)
457
+ stats = self.proj(x) * x_mask
458
+ m, logs = torch.split(stats, self.out_channels, dim=1)
459
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
460
+ return z, m, logs, x_mask
461
+
462
+
463
+ class Generator(torch.nn.Module):
464
+ def __init__(
465
+ self,
466
+ initial_channel,
467
+ resblock,
468
+ resblock_kernel_sizes,
469
+ resblock_dilation_sizes,
470
+ upsample_rates,
471
+ upsample_initial_channel,
472
+ upsample_kernel_sizes,
473
+ gin_channels=0,
474
+ ):
475
+ super(Generator, self).__init__()
476
+ self.num_kernels = len(resblock_kernel_sizes)
477
+ self.num_upsamples = len(upsample_rates)
478
+ self.conv_pre = Conv1d(
479
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
480
+ )
481
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
482
+
483
+ self.ups = nn.ModuleList()
484
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
485
+ self.ups.append(
486
+ weight_norm(
487
+ ConvTranspose1d(
488
+ upsample_initial_channel // (2**i),
489
+ upsample_initial_channel // (2 ** (i + 1)),
490
+ k,
491
+ u,
492
+ padding=(k - u) // 2,
493
+ )
494
+ )
495
+ )
496
+
497
+ self.resblocks = nn.ModuleList()
498
+ for i in range(len(self.ups)):
499
+ ch = upsample_initial_channel // (2 ** (i + 1))
500
+ for j, (k, d) in enumerate(
501
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
502
+ ):
503
+ self.resblocks.append(resblock(ch, k, d))
504
+
505
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
506
+ self.ups.apply(init_weights)
507
+
508
+ if gin_channels != 0:
509
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
510
+
511
+ def forward(self, x, g=None):
512
+ x = self.conv_pre(x)
513
+ if g is not None:
514
+ x = x + self.cond(g)
515
+
516
+ for i in range(self.num_upsamples):
517
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
518
+ x = self.ups[i](x)
519
+ xs = None
520
+ for j in range(self.num_kernels):
521
+ if xs is None:
522
+ xs = self.resblocks[i * self.num_kernels + j](x)
523
+ else:
524
+ xs += self.resblocks[i * self.num_kernels + j](x)
525
+ x = xs / self.num_kernels
526
+ x = F.leaky_relu(x)
527
+ x = self.conv_post(x)
528
+ x = torch.tanh(x)
529
+
530
+ return x
531
+
532
+ def remove_weight_norm(self):
533
+ print("Removing weight norm...")
534
+ for l in self.ups:
535
+ remove_weight_norm(l)
536
+ for l in self.resblocks:
537
+ l.remove_weight_norm()
538
+
539
+
540
+ class DiscriminatorP(torch.nn.Module):
541
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
542
+ super(DiscriminatorP, self).__init__()
543
+ self.period = period
544
+ self.use_spectral_norm = use_spectral_norm
545
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
546
+ self.convs = nn.ModuleList(
547
+ [
548
+ norm_f(
549
+ Conv2d(
550
+ 1,
551
+ 32,
552
+ (kernel_size, 1),
553
+ (stride, 1),
554
+ padding=(get_padding(kernel_size, 1), 0),
555
+ )
556
+ ),
557
+ norm_f(
558
+ Conv2d(
559
+ 32,
560
+ 128,
561
+ (kernel_size, 1),
562
+ (stride, 1),
563
+ padding=(get_padding(kernel_size, 1), 0),
564
+ )
565
+ ),
566
+ norm_f(
567
+ Conv2d(
568
+ 128,
569
+ 512,
570
+ (kernel_size, 1),
571
+ (stride, 1),
572
+ padding=(get_padding(kernel_size, 1), 0),
573
+ )
574
+ ),
575
+ norm_f(
576
+ Conv2d(
577
+ 512,
578
+ 1024,
579
+ (kernel_size, 1),
580
+ (stride, 1),
581
+ padding=(get_padding(kernel_size, 1), 0),
582
+ )
583
+ ),
584
+ norm_f(
585
+ Conv2d(
586
+ 1024,
587
+ 1024,
588
+ (kernel_size, 1),
589
+ 1,
590
+ padding=(get_padding(kernel_size, 1), 0),
591
+ )
592
+ ),
593
+ ]
594
+ )
595
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
596
+
597
+ def forward(self, x):
598
+ fmap = []
599
+
600
+ # 1d to 2d
601
+ b, c, t = x.shape
602
+ if t % self.period != 0: # pad first
603
+ n_pad = self.period - (t % self.period)
604
+ x = F.pad(x, (0, n_pad), "reflect")
605
+ t = t + n_pad
606
+ x = x.view(b, c, t // self.period, self.period)
607
+
608
+ for l in self.convs:
609
+ x = l(x)
610
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
611
+ fmap.append(x)
612
+ x = self.conv_post(x)
613
+ fmap.append(x)
614
+ x = torch.flatten(x, 1, -1)
615
+
616
+ return x, fmap
617
+
618
+
619
+ class DiscriminatorS(torch.nn.Module):
620
+ def __init__(self, use_spectral_norm=False):
621
+ super(DiscriminatorS, self).__init__()
622
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
623
+ self.convs = nn.ModuleList(
624
+ [
625
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
626
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
627
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
628
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
629
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
630
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
631
+ ]
632
+ )
633
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
634
+
635
+ def forward(self, x):
636
+ fmap = []
637
+
638
+ for l in self.convs:
639
+ x = l(x)
640
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
641
+ fmap.append(x)
642
+ x = self.conv_post(x)
643
+ fmap.append(x)
644
+ x = torch.flatten(x, 1, -1)
645
+
646
+ return x, fmap
647
+
648
+
649
+ class MultiPeriodDiscriminator(torch.nn.Module):
650
+ def __init__(self, use_spectral_norm=False):
651
+ super(MultiPeriodDiscriminator, self).__init__()
652
+ periods = [2, 3, 5, 7, 11]
653
+
654
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
655
+ discs = discs + [
656
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
657
+ ]
658
+ self.discriminators = nn.ModuleList(discs)
659
+
660
+ def forward(self, y, y_hat):
661
+ y_d_rs = []
662
+ y_d_gs = []
663
+ fmap_rs = []
664
+ fmap_gs = []
665
+ for i, d in enumerate(self.discriminators):
666
+ y_d_r, fmap_r = d(y)
667
+ y_d_g, fmap_g = d(y_hat)
668
+ y_d_rs.append(y_d_r)
669
+ y_d_gs.append(y_d_g)
670
+ fmap_rs.append(fmap_r)
671
+ fmap_gs.append(fmap_g)
672
+
673
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
674
+
675
+
676
+ class ReferenceEncoder(nn.Module):
677
+ """
678
+ inputs --- [N, Ty/r, n_mels*r] mels
679
+ outputs --- [N, ref_enc_gru_size]
680
+ """
681
+
682
+ def __init__(self, spec_channels, gin_channels=0):
683
+ super().__init__()
684
+ self.spec_channels = spec_channels
685
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
686
+ K = len(ref_enc_filters)
687
+ filters = [1] + ref_enc_filters
688
+ convs = [
689
+ weight_norm(
690
+ nn.Conv2d(
691
+ in_channels=filters[i],
692
+ out_channels=filters[i + 1],
693
+ kernel_size=(3, 3),
694
+ stride=(2, 2),
695
+ padding=(1, 1),
696
+ )
697
+ )
698
+ for i in range(K)
699
+ ]
700
+ self.convs = nn.ModuleList(convs)
701
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
702
+
703
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
704
+ self.gru = nn.GRU(
705
+ input_size=ref_enc_filters[-1] * out_channels,
706
+ hidden_size=256 // 2,
707
+ batch_first=True,
708
+ )
709
+ self.proj = nn.Linear(128, gin_channels)
710
+
711
+ def forward(self, inputs, mask=None):
712
+ N = inputs.size(0)
713
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
714
+ for conv in self.convs:
715
+ out = conv(out)
716
+ # out = wn(out)
717
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
718
+
719
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
720
+ T = out.size(1)
721
+ N = out.size(0)
722
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
723
+
724
+ self.gru.flatten_parameters()
725
+ memory, out = self.gru(out) # out --- [1, N, 128]
726
+
727
+ return self.proj(out.squeeze(0))
728
+
729
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
730
+ for i in range(n_convs):
731
+ L = (L - kernel_size + 2 * pad) // stride + 1
732
+ return L
733
+
734
+
735
+ class SynthesizerTrn(nn.Module):
736
+ """
737
+ Synthesizer for Training
738
+ """
739
+
740
+ def __init__(
741
+ self,
742
+ n_vocab,
743
+ spec_channels,
744
+ segment_size,
745
+ inter_channels,
746
+ hidden_channels,
747
+ filter_channels,
748
+ n_heads,
749
+ n_layers,
750
+ kernel_size,
751
+ p_dropout,
752
+ resblock,
753
+ resblock_kernel_sizes,
754
+ resblock_dilation_sizes,
755
+ upsample_rates,
756
+ upsample_initial_channel,
757
+ upsample_kernel_sizes,
758
+ n_speakers=256,
759
+ gin_channels=256,
760
+ use_sdp=True,
761
+ n_flow_layer=4,
762
+ n_layers_trans_flow=3,
763
+ flow_share_parameter=False,
764
+ use_transformer_flow=True,
765
+ **kwargs
766
+ ):
767
+ super().__init__()
768
+ self.n_vocab = n_vocab
769
+ self.spec_channels = spec_channels
770
+ self.inter_channels = inter_channels
771
+ self.hidden_channels = hidden_channels
772
+ self.filter_channels = filter_channels
773
+ self.n_heads = n_heads
774
+ self.n_layers = n_layers
775
+ self.kernel_size = kernel_size
776
+ self.p_dropout = p_dropout
777
+ self.resblock = resblock
778
+ self.resblock_kernel_sizes = resblock_kernel_sizes
779
+ self.resblock_dilation_sizes = resblock_dilation_sizes
780
+ self.upsample_rates = upsample_rates
781
+ self.upsample_initial_channel = upsample_initial_channel
782
+ self.upsample_kernel_sizes = upsample_kernel_sizes
783
+ self.segment_size = segment_size
784
+ self.n_speakers = n_speakers
785
+ self.gin_channels = gin_channels
786
+ self.n_layers_trans_flow = n_layers_trans_flow
787
+ self.use_spk_conditioned_encoder = kwargs.get(
788
+ "use_spk_conditioned_encoder", True
789
+ )
790
+ self.use_sdp = use_sdp
791
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
792
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
793
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
794
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
795
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
796
+ self.enc_gin_channels = gin_channels
797
+ self.enc_p = TextEncoder(
798
+ n_vocab,
799
+ inter_channels,
800
+ hidden_channels,
801
+ filter_channels,
802
+ n_heads,
803
+ n_layers,
804
+ kernel_size,
805
+ p_dropout,
806
+ gin_channels=self.enc_gin_channels,
807
+ )
808
+ self.dec = Generator(
809
+ inter_channels,
810
+ resblock,
811
+ resblock_kernel_sizes,
812
+ resblock_dilation_sizes,
813
+ upsample_rates,
814
+ upsample_initial_channel,
815
+ upsample_kernel_sizes,
816
+ gin_channels=gin_channels,
817
+ )
818
+ self.enc_q = PosteriorEncoder(
819
+ spec_channels,
820
+ inter_channels,
821
+ hidden_channels,
822
+ 5,
823
+ 1,
824
+ 16,
825
+ gin_channels=gin_channels,
826
+ )
827
+ if use_transformer_flow:
828
+ self.flow = TransformerCouplingBlock(
829
+ inter_channels,
830
+ hidden_channels,
831
+ filter_channels,
832
+ n_heads,
833
+ n_layers_trans_flow,
834
+ 5,
835
+ p_dropout,
836
+ n_flow_layer,
837
+ gin_channels=gin_channels,
838
+ share_parameter=flow_share_parameter,
839
+ )
840
+ else:
841
+ self.flow = ResidualCouplingBlock(
842
+ inter_channels,
843
+ hidden_channels,
844
+ 5,
845
+ 1,
846
+ n_flow_layer,
847
+ gin_channels=gin_channels,
848
+ )
849
+ self.sdp = StochasticDurationPredictor(
850
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
851
+ )
852
+ self.dp = DurationPredictor(
853
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
854
+ )
855
+
856
+ if n_speakers > 0:
857
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
858
+ else:
859
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
860
+
861
+ def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert):
862
+ if self.n_speakers >= 0:
863
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
864
+ else:
865
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
866
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert, g=g)
867
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
868
+ z_p = self.flow(z, y_mask, g=g)
869
+
870
+ with torch.no_grad():
871
+ # negative cross-entropy
872
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
873
+ neg_cent1 = torch.sum(
874
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
875
+ ) # [b, 1, t_s]
876
+ neg_cent2 = torch.matmul(
877
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
878
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
879
+ neg_cent3 = torch.matmul(
880
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
881
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
882
+ neg_cent4 = torch.sum(
883
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
884
+ ) # [b, 1, t_s]
885
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
886
+ if self.use_noise_scaled_mas:
887
+ epsilon = (
888
+ torch.std(neg_cent)
889
+ * torch.randn_like(neg_cent)
890
+ * self.current_mas_noise_scale
891
+ )
892
+ neg_cent = neg_cent + epsilon
893
+
894
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
895
+ attn = (
896
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
897
+ .unsqueeze(1)
898
+ .detach()
899
+ )
900
+
901
+ w = attn.sum(2)
902
+
903
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
904
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
905
+
906
+ logw_ = torch.log(w + 1e-6) * x_mask
907
+ logw = self.dp(x, x_mask, g=g)
908
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
909
+ x_mask
910
+ ) # for averaging
911
+
912
+ l_length = l_length_dp + l_length_sdp
913
+
914
+ # expand prior
915
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
916
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
917
+
918
+ z_slice, ids_slice = commons.rand_slice_segments(
919
+ z, y_lengths, self.segment_size
920
+ )
921
+ o = self.dec(z_slice, g=g)
922
+ return (
923
+ o,
924
+ l_length,
925
+ attn,
926
+ ids_slice,
927
+ x_mask,
928
+ y_mask,
929
+ (z, z_p, m_p, logs_p, m_q, logs_q),
930
+ (x, logw, logw_),
931
+ )
932
+
933
+ def infer(
934
+ self,
935
+ x,
936
+ x_lengths,
937
+ sid,
938
+ tone,
939
+ language,
940
+ bert,
941
+ noise_scale=0.667,
942
+ length_scale=1,
943
+ noise_scale_w=0.8,
944
+ max_len=None,
945
+ sdp_ratio=0,
946
+ y=None,
947
+ ):
948
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
949
+ # g = self.gst(y)
950
+ if self.n_speakers > 0:
951
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
952
+ else:
953
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
954
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert, g=g)
955
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
956
+ sdp_ratio
957
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
958
+ w = torch.exp(logw) * x_mask * length_scale
959
+ w_ceil = torch.ceil(w)
960
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
961
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
962
+ x_mask.dtype
963
+ )
964
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
965
+ attn = commons.generate_path(w_ceil, attn_mask)
966
+
967
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
968
+ 1, 2
969
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
970
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
971
+ 1, 2
972
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
973
+
974
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
975
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
976
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
977
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
oldVersion/V101/text/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .symbols import *
2
+
3
+
4
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
5
+
6
+
7
+ def cleaned_text_to_sequence(cleaned_text, tones, language):
8
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
9
+ Args:
10
+ text: string to convert to a sequence
11
+ Returns:
12
+ List of integers corresponding to the symbols in the text
13
+ """
14
+ phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
15
+ tone_start = language_tone_start_map[language]
16
+ tones = [i + tone_start for i in tones]
17
+ lang_id = language_id_map[language]
18
+ lang_ids = [lang_id for i in phones]
19
+ return phones, tones, lang_ids
20
+
21
+
22
+ def get_bert(norm_text, word2ph, language):
23
+ from .chinese_bert import get_bert_feature as zh_bert
24
+ from .english_bert_mock import get_bert_feature as en_bert
25
+
26
+ lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert}
27
+ bert = lang_bert_func_map[language](norm_text, word2ph)
28
+ return bert
oldVersion/V101/text/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.52 kB). View file
 
oldVersion/V101/text/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.53 kB). View file
 
oldVersion/V101/text/__pycache__/chinese.cpython-310.pyc ADDED
Binary file (4.61 kB). View file
 
oldVersion/V101/text/__pycache__/chinese.cpython-38.pyc ADDED
Binary file (4.53 kB). View file
 
oldVersion/V101/text/__pycache__/cleaner.cpython-310.pyc ADDED
Binary file (946 Bytes). View file
 
oldVersion/V101/text/__pycache__/cleaner.cpython-38.pyc ADDED
Binary file (936 Bytes). View file
 
oldVersion/V101/text/__pycache__/symbols.cpython-310.pyc ADDED
Binary file (1.48 kB). View file
 
oldVersion/V101/text/__pycache__/symbols.cpython-38.pyc ADDED
Binary file (1.82 kB). View file
 
oldVersion/V101/text/__pycache__/tone_sandhi.cpython-310.pyc ADDED
Binary file (13.4 kB). View file
 
oldVersion/V101/text/__pycache__/tone_sandhi.cpython-38.pyc ADDED
Binary file (15.6 kB). View file
 
oldVersion/V101/text/chinese.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import cn2an
5
+ from pypinyin import lazy_pinyin, Style
6
+
7
+
8
+ from .symbols import punctuation
9
+ from .tone_sandhi import ToneSandhi
10
+
11
+ current_file_path = os.path.dirname(__file__)
12
+ pinyin_to_symbol_map = {
13
+ line.split("\t")[0]: line.strip().split("\t")[1]
14
+ for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
15
+ }
16
+
17
+ import jieba.posseg as psg
18
+
19
+
20
+ rep_map = {
21
+ ":": ",",
22
+ ";": ",",
23
+ ",": ",",
24
+ "。": ".",
25
+ "!": "!",
26
+ "?": "?",
27
+ "\n": ".",
28
+ "·": ",",
29
+ "、": ",",
30
+ "...": "…",
31
+ "$": ".",
32
+ "“": "'",
33
+ "”": "'",
34
+ "‘": "'",
35
+ "’": "'",
36
+ "(": "'",
37
+ ")": "'",
38
+ "(": "'",
39
+ ")": "'",
40
+ "《": "'",
41
+ "》": "'",
42
+ "【": "'",
43
+ "】": "'",
44
+ "[": "'",
45
+ "]": "'",
46
+ "—": "-",
47
+ "~": "-",
48
+ "~": "-",
49
+ "「": "'",
50
+ "」": "'",
51
+ }
52
+
53
+ tone_modifier = ToneSandhi()
54
+
55
+
56
+ def replace_punctuation(text):
57
+ text = text.replace("嗯", "恩").replace("呣", "母")
58
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
59
+
60
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
61
+
62
+ replaced_text = re.sub(
63
+ r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
64
+ )
65
+
66
+ return replaced_text
67
+
68
+
69
+ def g2p(text):
70
+ pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
71
+ sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
72
+ phones, tones, word2ph = _g2p(sentences)
73
+ assert sum(word2ph) == len(phones)
74
+ assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch.
75
+ phones = ["_"] + phones + ["_"]
76
+ tones = [0] + tones + [0]
77
+ word2ph = [1] + word2ph + [1]
78
+ return phones, tones, word2ph
79
+
80
+
81
+ def _get_initials_finals(word):
82
+ initials = []
83
+ finals = []
84
+ orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
85
+ orig_finals = lazy_pinyin(
86
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
87
+ )
88
+ for c, v in zip(orig_initials, orig_finals):
89
+ initials.append(c)
90
+ finals.append(v)
91
+ return initials, finals
92
+
93
+
94
+ def _g2p(segments):
95
+ phones_list = []
96
+ tones_list = []
97
+ word2ph = []
98
+ for seg in segments:
99
+ # Replace all English words in the sentence
100
+ seg = re.sub("[a-zA-Z]+", "", seg)
101
+ seg_cut = psg.lcut(seg)
102
+ initials = []
103
+ finals = []
104
+ seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
105
+ for word, pos in seg_cut:
106
+ if pos == "eng":
107
+ continue
108
+ sub_initials, sub_finals = _get_initials_finals(word)
109
+ sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
110
+ initials.append(sub_initials)
111
+ finals.append(sub_finals)
112
+
113
+ # assert len(sub_initials) == len(sub_finals) == len(word)
114
+ initials = sum(initials, [])
115
+ finals = sum(finals, [])
116
+ #
117
+ for c, v in zip(initials, finals):
118
+ raw_pinyin = c + v
119
+ # NOTE: post process for pypinyin outputs
120
+ # we discriminate i, ii and iii
121
+ if c == v:
122
+ assert c in punctuation
123
+ phone = [c]
124
+ tone = "0"
125
+ word2ph.append(1)
126
+ else:
127
+ v_without_tone = v[:-1]
128
+ tone = v[-1]
129
+
130
+ pinyin = c + v_without_tone
131
+ assert tone in "12345"
132
+
133
+ if c:
134
+ # 多音节
135
+ v_rep_map = {
136
+ "uei": "ui",
137
+ "iou": "iu",
138
+ "uen": "un",
139
+ }
140
+ if v_without_tone in v_rep_map.keys():
141
+ pinyin = c + v_rep_map[v_without_tone]
142
+ else:
143
+ # 单音节
144
+ pinyin_rep_map = {
145
+ "ing": "ying",
146
+ "i": "yi",
147
+ "in": "yin",
148
+ "u": "wu",
149
+ }
150
+ if pinyin in pinyin_rep_map.keys():
151
+ pinyin = pinyin_rep_map[pinyin]
152
+ else:
153
+ single_rep_map = {
154
+ "v": "yu",
155
+ "e": "e",
156
+ "i": "y",
157
+ "u": "w",
158
+ }
159
+ if pinyin[0] in single_rep_map.keys():
160
+ pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
161
+
162
+ assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
163
+ phone = pinyin_to_symbol_map[pinyin].split(" ")
164
+ word2ph.append(len(phone))
165
+
166
+ phones_list += phone
167
+ tones_list += [int(tone)] * len(phone)
168
+ return phones_list, tones_list, word2ph
169
+
170
+
171
+ def text_normalize(text):
172
+ numbers = re.findall(r"\d+(?:\.?\d+)?", text)
173
+ for number in numbers:
174
+ text = text.replace(number, cn2an.an2cn(number), 1)
175
+ text = replace_punctuation(text)
176
+ return text
177
+
178
+
179
+ def get_bert_feature(text, word2ph):
180
+ from text import chinese_bert
181
+
182
+ return chinese_bert.get_bert_feature(text, word2ph)
183
+
184
+
185
+ if __name__ == "__main__":
186
+ from text.chinese_bert import get_bert_feature
187
+
188
+ text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
189
+ text = text_normalize(text)
190
+ print(text)
191
+ phones, tones, word2ph = g2p(text)
192
+ bert = get_bert_feature(text, word2ph)
193
+
194
+ print(phones, tones, word2ph, bert.shape)
195
+
196
+
197
+ # # 示例用法
198
+ # text = "这是一个示例文本:,你好!这是一个测试...."
199
+ # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
oldVersion/V101/text/chinese_bert.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
4
+
5
+ device = torch.device(
6
+ "cuda"
7
+ if torch.cuda.is_available()
8
+ else (
9
+ "mps"
10
+ if sys.platform == "darwin" and torch.backends.mps.is_available()
11
+ else "cpu"
12
+ )
13
+ )
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained("./bert/chinese-roberta-wwm-ext-large")
16
+ model = AutoModelForMaskedLM.from_pretrained("./bert/chinese-roberta-wwm-ext-large").to(
17
+ device
18
+ )
19
+
20
+
21
+ def get_bert_feature(text, word2ph):
22
+ with torch.no_grad():
23
+ inputs = tokenizer(text, return_tensors="pt")
24
+ for i in inputs:
25
+ inputs[i] = inputs[i].to(device)
26
+ res = model(**inputs, output_hidden_states=True)
27
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
28
+
29
+ assert len(word2ph) == len(text) + 2
30
+ word2phone = word2ph
31
+ phone_level_feature = []
32
+ for i in range(len(word2phone)):
33
+ repeat_feature = res[i].repeat(word2phone[i], 1)
34
+ phone_level_feature.append(repeat_feature)
35
+
36
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
37
+
38
+ return phone_level_feature.T
39
+
40
+
41
+ if __name__ == "__main__":
42
+ # feature = get_bert_feature('你好,我是说的道理。')
43
+ import torch
44
+
45
+ word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
46
+ word2phone = [
47
+ 1,
48
+ 2,
49
+ 1,
50
+ 2,
51
+ 2,
52
+ 1,
53
+ 2,
54
+ 2,
55
+ 1,
56
+ 2,
57
+ 2,
58
+ 1,
59
+ 2,
60
+ 2,
61
+ 2,
62
+ 2,
63
+ 2,
64
+ 1,
65
+ 1,
66
+ 2,
67
+ 2,
68
+ 1,
69
+ 2,
70
+ 2,
71
+ 2,
72
+ 2,
73
+ 1,
74
+ 2,
75
+ 2,
76
+ 2,
77
+ 2,
78
+ 2,
79
+ 1,
80
+ 2,
81
+ 2,
82
+ 2,
83
+ 2,
84
+ 1,
85
+ ]
86
+
87
+ # 计算总帧数
88
+ total_frames = sum(word2phone)
89
+ print(word_level_feature.shape)
90
+ print(word2phone)
91
+ phone_level_feature = []
92
+ for i in range(len(word2phone)):
93
+ print(word_level_feature[i].shape)
94
+
95
+ # 对每个词重复word2phone[i]次
96
+ repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
97
+ phone_level_feature.append(repeat_feature)
98
+
99
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
100
+ print(phone_level_feature.shape) # torch.Size([36, 1024])
oldVersion/V101/text/cleaner.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import chinese, cleaned_text_to_sequence
2
+
3
+
4
+ language_module_map = {"ZH": chinese}
5
+
6
+
7
+ def clean_text(text, language):
8
+ language_module = language_module_map[language]
9
+ norm_text = language_module.text_normalize(text)
10
+ phones, tones, word2ph = language_module.g2p(norm_text)
11
+ return norm_text, phones, tones, word2ph
12
+
13
+
14
+ def clean_text_bert(text, language):
15
+ language_module = language_module_map[language]
16
+ norm_text = language_module.text_normalize(text)
17
+ phones, tones, word2ph = language_module.g2p(norm_text)
18
+ bert = language_module.get_bert_feature(norm_text, word2ph)
19
+ return phones, tones, bert
20
+
21
+
22
+ def text_to_sequence(text, language):
23
+ norm_text, phones, tones, word2ph = clean_text(text, language)
24
+ return cleaned_text_to_sequence(phones, tones, language)
25
+
26
+
27
+ if __name__ == "__main__":
28
+ pass
oldVersion/V101/text/english.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import os
3
+ import re
4
+ from g2p_en import G2p
5
+
6
+ from text import symbols
7
+
8
+ current_file_path = os.path.dirname(__file__)
9
+ CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
10
+ CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
11
+ _g2p = G2p()
12
+
13
+ arpa = {
14
+ "AH0",
15
+ "S",
16
+ "AH1",
17
+ "EY2",
18
+ "AE2",
19
+ "EH0",
20
+ "OW2",
21
+ "UH0",
22
+ "NG",
23
+ "B",
24
+ "G",
25
+ "AY0",
26
+ "M",
27
+ "AA0",
28
+ "F",
29
+ "AO0",
30
+ "ER2",
31
+ "UH1",
32
+ "IY1",
33
+ "AH2",
34
+ "DH",
35
+ "IY0",
36
+ "EY1",
37
+ "IH0",
38
+ "K",
39
+ "N",
40
+ "W",
41
+ "IY2",
42
+ "T",
43
+ "AA1",
44
+ "ER1",
45
+ "EH2",
46
+ "OY0",
47
+ "UH2",
48
+ "UW1",
49
+ "Z",
50
+ "AW2",
51
+ "AW1",
52
+ "V",
53
+ "UW2",
54
+ "AA2",
55
+ "ER",
56
+ "AW0",
57
+ "UW0",
58
+ "R",
59
+ "OW1",
60
+ "EH1",
61
+ "ZH",
62
+ "AE0",
63
+ "IH2",
64
+ "IH",
65
+ "Y",
66
+ "JH",
67
+ "P",
68
+ "AY1",
69
+ "EY0",
70
+ "OY2",
71
+ "TH",
72
+ "HH",
73
+ "D",
74
+ "ER0",
75
+ "CH",
76
+ "AO1",
77
+ "AE1",
78
+ "AO2",
79
+ "OY1",
80
+ "AY2",
81
+ "IH1",
82
+ "OW0",
83
+ "L",
84
+ "SH",
85
+ }
86
+
87
+
88
+ def post_replace_ph(ph):
89
+ rep_map = {
90
+ ":": ",",
91
+ ";": ",",
92
+ ",": ",",
93
+ "。": ".",
94
+ "!": "!",
95
+ "?": "?",
96
+ "\n": ".",
97
+ "·": ",",
98
+ "、": ",",
99
+ "...": "…",
100
+ "v": "V",
101
+ }
102
+ if ph in rep_map.keys():
103
+ ph = rep_map[ph]
104
+ if ph in symbols:
105
+ return ph
106
+ if ph not in symbols:
107
+ ph = "UNK"
108
+ return ph
109
+
110
+
111
+ def read_dict():
112
+ g2p_dict = {}
113
+ start_line = 49
114
+ with open(CMU_DICT_PATH) as f:
115
+ line = f.readline()
116
+ line_index = 1
117
+ while line:
118
+ if line_index >= start_line:
119
+ line = line.strip()
120
+ word_split = line.split(" ")
121
+ word = word_split[0]
122
+
123
+ syllable_split = word_split[1].split(" - ")
124
+ g2p_dict[word] = []
125
+ for syllable in syllable_split:
126
+ phone_split = syllable.split(" ")
127
+ g2p_dict[word].append(phone_split)
128
+
129
+ line_index = line_index + 1
130
+ line = f.readline()
131
+
132
+ return g2p_dict
133
+
134
+
135
+ def cache_dict(g2p_dict, file_path):
136
+ with open(file_path, "wb") as pickle_file:
137
+ pickle.dump(g2p_dict, pickle_file)
138
+
139
+
140
+ def get_dict():
141
+ if os.path.exists(CACHE_PATH):
142
+ with open(CACHE_PATH, "rb") as pickle_file:
143
+ g2p_dict = pickle.load(pickle_file)
144
+ else:
145
+ g2p_dict = read_dict()
146
+ cache_dict(g2p_dict, CACHE_PATH)
147
+
148
+ return g2p_dict
149
+
150
+
151
+ eng_dict = get_dict()
152
+
153
+
154
+ def refine_ph(phn):
155
+ tone = 0
156
+ if re.search(r"\d$", phn):
157
+ tone = int(phn[-1]) + 1
158
+ phn = phn[:-1]
159
+ return phn.lower(), tone
160
+
161
+
162
+ def refine_syllables(syllables):
163
+ tones = []
164
+ phonemes = []
165
+ for phn_list in syllables:
166
+ for i in range(len(phn_list)):
167
+ phn = phn_list[i]
168
+ phn, tone = refine_ph(phn)
169
+ phonemes.append(phn)
170
+ tones.append(tone)
171
+ return phonemes, tones
172
+
173
+
174
+ def text_normalize(text):
175
+ # todo: eng text normalize
176
+ return text
177
+
178
+
179
+ def g2p(text):
180
+ phones = []
181
+ tones = []
182
+ words = re.split(r"([,;.\-\?\!\s+])", text)
183
+ for w in words:
184
+ if w.upper() in eng_dict:
185
+ phns, tns = refine_syllables(eng_dict[w.upper()])
186
+ phones += phns
187
+ tones += tns
188
+ else:
189
+ phone_list = list(filter(lambda p: p != " ", _g2p(w)))
190
+ for ph in phone_list:
191
+ if ph in arpa:
192
+ ph, tn = refine_ph(ph)
193
+ phones.append(ph)
194
+ tones.append(tn)
195
+ else:
196
+ phones.append(ph)
197
+ tones.append(0)
198
+ # todo: implement word2ph
199
+ word2ph = [1 for i in phones]
200
+
201
+ phones = [post_replace_ph(i) for i in phones]
202
+ return phones, tones, word2ph
203
+
204
+
205
+ if __name__ == "__main__":
206
+ # print(get_dict())
207
+ # print(eng_word_to_phoneme("hello"))
208
+ print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))
209
+ # all_phones = set()
210
+ # for k, syllables in eng_dict.items():
211
+ # for group in syllables:
212
+ # for ph in group:
213
+ # all_phones.add(ph)
214
+ # print(all_phones)
oldVersion/V101/text/english_bert_mock.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_bert_feature(norm_text, word2ph):
5
+ return torch.zeros(1024, sum(word2ph))
oldVersion/V101/text/japanese.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/CjangCjengh/vits/blob/main/text/japanese.py
2
+ import re
3
+ import sys
4
+
5
+ import pyopenjtalk
6
+
7
+ from . import symbols
8
+
9
+ # Regular expression matching Japanese without punctuation marks:
10
+ _japanese_characters = re.compile(
11
+ r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
12
+ )
13
+
14
+ # Regular expression matching non-Japanese characters or punctuation marks:
15
+ _japanese_marks = re.compile(
16
+ r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
17
+ )
18
+
19
+ # List of (symbol, Japanese) pairs for marks:
20
+ _symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("%", "パーセント")]]
21
+
22
+
23
+ # List of (consonant, sokuon) pairs:
24
+ _real_sokuon = [
25
+ (re.compile("%s" % x[0]), x[1])
26
+ for x in [
27
+ (r"Q([↑↓]*[kg])", r"k#\1"),
28
+ (r"Q([↑↓]*[tdjʧ])", r"t#\1"),
29
+ (r"Q([↑↓]*[sʃ])", r"s\1"),
30
+ (r"Q([↑↓]*[pb])", r"p#\1"),
31
+ ]
32
+ ]
33
+
34
+ # List of (consonant, hatsuon) pairs:
35
+ _real_hatsuon = [
36
+ (re.compile("%s" % x[0]), x[1])
37
+ for x in [
38
+ (r"N([↑↓]*[pbm])", r"m\1"),
39
+ (r"N([↑↓]*[ʧʥj])", r"n^\1"),
40
+ (r"N([↑↓]*[tdn])", r"n\1"),
41
+ (r"N([↑↓]*[kg])", r"ŋ\1"),
42
+ ]
43
+ ]
44
+
45
+
46
+ def post_replace_ph(ph):
47
+ rep_map = {
48
+ ":": ",",
49
+ ";": ",",
50
+ ",": ",",
51
+ "。": ".",
52
+ "!": "!",
53
+ "?": "?",
54
+ "\n": ".",
55
+ "·": ",",
56
+ "、": ",",
57
+ "...": "…",
58
+ "v": "V",
59
+ }
60
+ if ph in rep_map.keys():
61
+ ph = rep_map[ph]
62
+ if ph in symbols:
63
+ return ph
64
+ if ph not in symbols:
65
+ ph = "UNK"
66
+ return ph
67
+
68
+
69
+ def symbols_to_japanese(text):
70
+ for regex, replacement in _symbols_to_japanese:
71
+ text = re.sub(regex, replacement, text)
72
+ return text
73
+
74
+
75
+ def preprocess_jap(text):
76
+ """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html"""
77
+ text = symbols_to_japanese(text)
78
+ sentences = re.split(_japanese_marks, text)
79
+ marks = re.findall(_japanese_marks, text)
80
+ text = []
81
+ for i, sentence in enumerate(sentences):
82
+ if re.match(_japanese_characters, sentence):
83
+ p = pyopenjtalk.g2p(sentence)
84
+ text += p.split(" ")
85
+
86
+ if i < len(marks):
87
+ text += [marks[i].replace(" ", "")]
88
+ return text
89
+
90
+
91
+ def text_normalize(text):
92
+ # todo: jap text normalize
93
+ return text
94
+
95
+
96
+ def g2p(norm_text):
97
+ phones = preprocess_jap(norm_text)
98
+ phones = [post_replace_ph(i) for i in phones]
99
+ # todo: implement tones and word2ph
100
+ tones = [0 for i in phones]
101
+ word2ph = [1 for i in phones]
102
+ return phones, tones, word2ph
103
+
104
+
105
+ if __name__ == "__main__":
106
+ for line in open("../../../Downloads/transcript_utf8.txt").readlines():
107
+ text = line.split(":")[1]
108
+ phones, tones, word2ph = g2p(text)
109
+ for p in phones:
110
+ if p == "z":
111
+ print(text, phones)
112
+ sys.exit(0)
oldVersion/V101/text/opencpop-strict.txt ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a AA a
2
+ ai AA ai
3
+ an AA an
4
+ ang AA ang
5
+ ao AA ao
6
+ ba b a
7
+ bai b ai
8
+ ban b an
9
+ bang b ang
10
+ bao b ao
11
+ bei b ei
12
+ ben b en
13
+ beng b eng
14
+ bi b i
15
+ bian b ian
16
+ biao b iao
17
+ bie b ie
18
+ bin b in
19
+ bing b ing
20
+ bo b o
21
+ bu b u
22
+ ca c a
23
+ cai c ai
24
+ can c an
25
+ cang c ang
26
+ cao c ao
27
+ ce c e
28
+ cei c ei
29
+ cen c en
30
+ ceng c eng
31
+ cha ch a
32
+ chai ch ai
33
+ chan ch an
34
+ chang ch ang
35
+ chao ch ao
36
+ che ch e
37
+ chen ch en
38
+ cheng ch eng
39
+ chi ch ir
40
+ chong ch ong
41
+ chou ch ou
42
+ chu ch u
43
+ chua ch ua
44
+ chuai ch uai
45
+ chuan ch uan
46
+ chuang ch uang
47
+ chui ch ui
48
+ chun ch un
49
+ chuo ch uo
50
+ ci c i0
51
+ cong c ong
52
+ cou c ou
53
+ cu c u
54
+ cuan c uan
55
+ cui c ui
56
+ cun c un
57
+ cuo c uo
58
+ da d a
59
+ dai d ai
60
+ dan d an
61
+ dang d ang
62
+ dao d ao
63
+ de d e
64
+ dei d ei
65
+ den d en
66
+ deng d eng
67
+ di d i
68
+ dia d ia
69
+ dian d ian
70
+ diao d iao
71
+ die d ie
72
+ ding d ing
73
+ diu d iu
74
+ dong d ong
75
+ dou d ou
76
+ du d u
77
+ duan d uan
78
+ dui d ui
79
+ dun d un
80
+ duo d uo
81
+ e EE e
82
+ ei EE ei
83
+ en EE en
84
+ eng EE eng
85
+ er EE er
86
+ fa f a
87
+ fan f an
88
+ fang f ang
89
+ fei f ei
90
+ fen f en
91
+ feng f eng
92
+ fo f o
93
+ fou f ou
94
+ fu f u
95
+ ga g a
96
+ gai g ai
97
+ gan g an
98
+ gang g ang
99
+ gao g ao
100
+ ge g e
101
+ gei g ei
102
+ gen g en
103
+ geng g eng
104
+ gong g ong
105
+ gou g ou
106
+ gu g u
107
+ gua g ua
108
+ guai g uai
109
+ guan g uan
110
+ guang g uang
111
+ gui g ui
112
+ gun g un
113
+ guo g uo
114
+ ha h a
115
+ hai h ai
116
+ han h an
117
+ hang h ang
118
+ hao h ao
119
+ he h e
120
+ hei h ei
121
+ hen h en
122
+ heng h eng
123
+ hong h ong
124
+ hou h ou
125
+ hu h u
126
+ hua h ua
127
+ huai h uai
128
+ huan h uan
129
+ huang h uang
130
+ hui h ui
131
+ hun h un
132
+ huo h uo
133
+ ji j i
134
+ jia j ia
135
+ jian j ian
136
+ jiang j iang
137
+ jiao j iao
138
+ jie j ie
139
+ jin j in
140
+ jing j ing
141
+ jiong j iong
142
+ jiu j iu
143
+ ju j v
144
+ jv j v
145
+ juan j van
146
+ jvan j van
147
+ jue j ve
148
+ jve j ve
149
+ jun j vn
150
+ jvn j vn
151
+ ka k a
152
+ kai k ai
153
+ kan k an
154
+ kang k ang
155
+ kao k ao
156
+ ke k e
157
+ kei k ei
158
+ ken k en
159
+ keng k eng
160
+ kong k ong
161
+ kou k ou
162
+ ku k u
163
+ kua k ua
164
+ kuai k uai
165
+ kuan k uan
166
+ kuang k uang
167
+ kui k ui
168
+ kun k un
169
+ kuo k uo
170
+ la l a
171
+ lai l ai
172
+ lan l an
173
+ lang l ang
174
+ lao l ao
175
+ le l e
176
+ lei l ei
177
+ leng l eng
178
+ li l i
179
+ lia l ia
180
+ lian l ian
181
+ liang l iang
182
+ liao l iao
183
+ lie l ie
184
+ lin l in
185
+ ling l ing
186
+ liu l iu
187
+ lo l o
188
+ long l ong
189
+ lou l ou
190
+ lu l u
191
+ luan l uan
192
+ lun l un
193
+ luo l uo
194
+ lv l v
195
+ lve l ve
196
+ ma m a
197
+ mai m ai
198
+ man m an
199
+ mang m ang
200
+ mao m ao
201
+ me m e
202
+ mei m ei
203
+ men m en
204
+ meng m eng
205
+ mi m i
206
+ mian m ian
207
+ miao m iao
208
+ mie m ie
209
+ min m in
210
+ ming m ing
211
+ miu m iu
212
+ mo m o
213
+ mou m ou
214
+ mu m u
215
+ na n a
216
+ nai n ai
217
+ nan n an
218
+ nang n ang
219
+ nao n ao
220
+ ne n e
221
+ nei n ei
222
+ nen n en
223
+ neng n eng
224
+ ni n i
225
+ nian n ian
226
+ niang n iang
227
+ niao n iao
228
+ nie n ie
229
+ nin n in
230
+ ning n ing
231
+ niu n iu
232
+ nong n ong
233
+ nou n ou
234
+ nu n u
235
+ nuan n uan
236
+ nun n un
237
+ nuo n uo
238
+ nv n v
239
+ nve n ve
240
+ o OO o
241
+ ou OO ou
242
+ pa p a
243
+ pai p ai
244
+ pan p an
245
+ pang p ang
246
+ pao p ao
247
+ pei p ei
248
+ pen p en
249
+ peng p eng
250
+ pi p i
251
+ pian p ian
252
+ piao p iao
253
+ pie p ie
254
+ pin p in
255
+ ping p ing
256
+ po p o
257
+ pou p ou
258
+ pu p u
259
+ qi q i
260
+ qia q ia
261
+ qian q ian
262
+ qiang q iang
263
+ qiao q iao
264
+ qie q ie
265
+ qin q in
266
+ qing q ing
267
+ qiong q iong
268
+ qiu q iu
269
+ qu q v
270
+ qv q v
271
+ quan q van
272
+ qvan q van
273
+ que q ve
274
+ qve q ve
275
+ qun q vn
276
+ qvn q vn
277
+ ran r an
278
+ rang r ang
279
+ rao r ao
280
+ re r e
281
+ ren r en
282
+ reng r eng
283
+ ri r ir
284
+ rong r ong
285
+ rou r ou
286
+ ru r u
287
+ rua r ua
288
+ ruan r uan
289
+ rui r ui
290
+ run r un
291
+ ruo r uo
292
+ sa s a
293
+ sai s ai
294
+ san s an
295
+ sang s ang
296
+ sao s ao
297
+ se s e
298
+ sen s en
299
+ seng s eng
300
+ sha sh a
301
+ shai sh ai
302
+ shan sh an
303
+ shang sh ang
304
+ shao sh ao
305
+ she sh e
306
+ shei sh ei
307
+ shen sh en
308
+ sheng sh eng
309
+ shi sh ir
310
+ shou sh ou
311
+ shu sh u
312
+ shua sh ua
313
+ shuai sh uai
314
+ shuan sh uan
315
+ shuang sh uang
316
+ shui sh ui
317
+ shun sh un
318
+ shuo sh uo
319
+ si s i0
320
+ song s ong
321
+ sou s ou
322
+ su s u
323
+ suan s uan
324
+ sui s ui
325
+ sun s un
326
+ suo s uo
327
+ ta t a
328
+ tai t ai
329
+ tan t an
330
+ tang t ang
331
+ tao t ao
332
+ te t e
333
+ tei t ei
334
+ teng t eng
335
+ ti t i
336
+ tian t ian
337
+ tiao t iao
338
+ tie t ie
339
+ ting t ing
340
+ tong t ong
341
+ tou t ou
342
+ tu t u
343
+ tuan t uan
344
+ tui t ui
345
+ tun t un
346
+ tuo t uo
347
+ wa w a
348
+ wai w ai
349
+ wan w an
350
+ wang w ang
351
+ wei w ei
352
+ wen w en
353
+ weng w eng
354
+ wo w o
355
+ wu w u
356
+ xi x i
357
+ xia x ia
358
+ xian x ian
359
+ xiang x iang
360
+ xiao x iao
361
+ xie x ie
362
+ xin x in
363
+ xing x ing
364
+ xiong x iong
365
+ xiu x iu
366
+ xu x v
367
+ xv x v
368
+ xuan x van
369
+ xvan x van
370
+ xue x ve
371
+ xve x ve
372
+ xun x vn
373
+ xvn x vn
374
+ ya y a
375
+ yan y En
376
+ yang y ang
377
+ yao y ao
378
+ ye y E
379
+ yi y i
380
+ yin y in
381
+ ying y ing
382
+ yo y o
383
+ yong y ong
384
+ you y ou
385
+ yu y v
386
+ yv y v
387
+ yuan y van
388
+ yvan y van
389
+ yue y ve
390
+ yve y ve
391
+ yun y vn
392
+ yvn y vn
393
+ za z a
394
+ zai z ai
395
+ zan z an
396
+ zang z ang
397
+ zao z ao
398
+ ze z e
399
+ zei z ei
400
+ zen z en
401
+ zeng z eng
402
+ zha zh a
403
+ zhai zh ai
404
+ zhan zh an
405
+ zhang zh ang
406
+ zhao zh ao
407
+ zhe zh e
408
+ zhei zh ei
409
+ zhen zh en
410
+ zheng zh eng
411
+ zhi zh ir
412
+ zhong zh ong
413
+ zhou zh ou
414
+ zhu zh u
415
+ zhua zh ua
416
+ zhuai zh uai
417
+ zhuan zh uan
418
+ zhuang zh uang
419
+ zhui zh ui
420
+ zhun zh un
421
+ zhuo zh uo
422
+ zi z i0
423
+ zong z ong
424
+ zou z ou
425
+ zu z u
426
+ zuan z uan
427
+ zui z ui
428
+ zun z un
429
+ zuo z uo
oldVersion/V101/text/symbols.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2
+ pu_symbols = punctuation + ["SP", "UNK"]
3
+ pad = "_"
4
+
5
+ # chinese
6
+ zh_symbols = [
7
+ "E",
8
+ "En",
9
+ "a",
10
+ "ai",
11
+ "an",
12
+ "ang",
13
+ "ao",
14
+ "b",
15
+ "c",
16
+ "ch",
17
+ "d",
18
+ "e",
19
+ "ei",
20
+ "en",
21
+ "eng",
22
+ "er",
23
+ "f",
24
+ "g",
25
+ "h",
26
+ "i",
27
+ "i0",
28
+ "ia",
29
+ "ian",
30
+ "iang",
31
+ "iao",
32
+ "ie",
33
+ "in",
34
+ "ing",
35
+ "iong",
36
+ "ir",
37
+ "iu",
38
+ "j",
39
+ "k",
40
+ "l",
41
+ "m",
42
+ "n",
43
+ "o",
44
+ "ong",
45
+ "ou",
46
+ "p",
47
+ "q",
48
+ "r",
49
+ "s",
50
+ "sh",
51
+ "t",
52
+ "u",
53
+ "ua",
54
+ "uai",
55
+ "uan",
56
+ "uang",
57
+ "ui",
58
+ "un",
59
+ "uo",
60
+ "v",
61
+ "van",
62
+ "ve",
63
+ "vn",
64
+ "w",
65
+ "x",
66
+ "y",
67
+ "z",
68
+ "zh",
69
+ "AA",
70
+ "EE",
71
+ "OO",
72
+ ]
73
+ num_zh_tones = 6
74
+
75
+ # japanese
76
+ ja_symbols = [
77
+ "I",
78
+ "N",
79
+ "U",
80
+ "a",
81
+ "b",
82
+ "by",
83
+ "ch",
84
+ "cl",
85
+ "d",
86
+ "dy",
87
+ "e",
88
+ "f",
89
+ "g",
90
+ "gy",
91
+ "h",
92
+ "hy",
93
+ "i",
94
+ "j",
95
+ "k",
96
+ "ky",
97
+ "m",
98
+ "my",
99
+ "n",
100
+ "ny",
101
+ "o",
102
+ "p",
103
+ "py",
104
+ "r",
105
+ "ry",
106
+ "s",
107
+ "sh",
108
+ "t",
109
+ "ts",
110
+ "u",
111
+ "V",
112
+ "w",
113
+ "y",
114
+ "z",
115
+ ]
116
+ num_ja_tones = 1
117
+
118
+ # English
119
+ en_symbols = [
120
+ "aa",
121
+ "ae",
122
+ "ah",
123
+ "ao",
124
+ "aw",
125
+ "ay",
126
+ "b",
127
+ "ch",
128
+ "d",
129
+ "dh",
130
+ "eh",
131
+ "er",
132
+ "ey",
133
+ "f",
134
+ "g",
135
+ "hh",
136
+ "ih",
137
+ "iy",
138
+ "jh",
139
+ "k",
140
+ "l",
141
+ "m",
142
+ "n",
143
+ "ng",
144
+ "ow",
145
+ "oy",
146
+ "p",
147
+ "r",
148
+ "s",
149
+ "sh",
150
+ "t",
151
+ "th",
152
+ "uh",
153
+ "uw",
154
+ "V",
155
+ "w",
156
+ "y",
157
+ "z",
158
+ "zh",
159
+ ]
160
+ num_en_tones = 4
161
+
162
+ # combine all symbols
163
+ normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
164
+ symbols = [pad] + normal_symbols + pu_symbols
165
+ sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
166
+
167
+ # combine all tones
168
+ num_tones = num_zh_tones + num_ja_tones + num_en_tones
169
+
170
+ # language maps
171
+ language_id_map = {"ZH": 0, "JA": 1, "EN": 2}
172
+ num_languages = len(language_id_map.keys())
173
+
174
+ language_tone_start_map = {
175
+ "ZH": 0,
176
+ "JA": num_zh_tones,
177
+ "EN": num_zh_tones + num_ja_tones,
178
+ }
179
+
180
+ if __name__ == "__main__":
181
+ a = set(zh_symbols)
182
+ b = set(en_symbols)
183
+ print(sorted(a & b))
oldVersion/V101/text/tone_sandhi.py ADDED
@@ -0,0 +1,769 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ from typing import Tuple
16
+
17
+ import jieba
18
+ from pypinyin import lazy_pinyin
19
+ from pypinyin import Style
20
+
21
+
22
+ class ToneSandhi:
23
+ def __init__(self):
24
+ self.must_neural_tone_words = {
25
+ "麻烦",
26
+ "麻利",
27
+ "鸳鸯",
28
+ "高粱",
29
+ "骨头",
30
+ "骆驼",
31
+ "马虎",
32
+ "首饰",
33
+ "馒头",
34
+ "馄饨",
35
+ "风筝",
36
+ "难为",
37
+ "队伍",
38
+ "阔气",
39
+ "闺女",
40
+ "门道",
41
+ "锄头",
42
+ "铺盖",
43
+ "铃铛",
44
+ "铁匠",
45
+ "钥匙",
46
+ "里脊",
47
+ "里头",
48
+ "部分",
49
+ "那么",
50
+ "道士",
51
+ "造化",
52
+ "迷糊",
53
+ "连累",
54
+ "这么",
55
+ "这个",
56
+ "运气",
57
+ "过去",
58
+ "软和",
59
+ "转悠",
60
+ "踏实",
61
+ "跳蚤",
62
+ "跟头",
63
+ "趔趄",
64
+ "财主",
65
+ "豆腐",
66
+ "讲究",
67
+ "记性",
68
+ "记号",
69
+ "认识",
70
+ "规矩",
71
+ "见识",
72
+ "裁缝",
73
+ "补丁",
74
+ "衣裳",
75
+ "衣服",
76
+ "衙门",
77
+ "街坊",
78
+ "行李",
79
+ "行当",
80
+ "蛤蟆",
81
+ "蘑菇",
82
+ "薄荷",
83
+ "葫芦",
84
+ "葡萄",
85
+ "萝卜",
86
+ "荸荠",
87
+ "苗条",
88
+ "苗头",
89
+ "苍蝇",
90
+ "芝麻",
91
+ "舒服",
92
+ "舒坦",
93
+ "舌头",
94
+ "自在",
95
+ "膏药",
96
+ "脾气",
97
+ "脑袋",
98
+ "脊梁",
99
+ "能耐",
100
+ "胳膊",
101
+ "胭脂",
102
+ "胡萝",
103
+ "胡琴",
104
+ "胡同",
105
+ "聪明",
106
+ "耽误",
107
+ "耽搁",
108
+ "耷拉",
109
+ "耳朵",
110
+ "老爷",
111
+ "老实",
112
+ "老婆",
113
+ "老头",
114
+ "老太",
115
+ "翻腾",
116
+ "罗嗦",
117
+ "罐头",
118
+ "编辑",
119
+ "结实",
120
+ "红火",
121
+ "累赘",
122
+ "糨糊",
123
+ "糊涂",
124
+ "精神",
125
+ "粮食",
126
+ "簸箕",
127
+ "篱笆",
128
+ "算计",
129
+ "算盘",
130
+ "答应",
131
+ "笤帚",
132
+ "笑语",
133
+ "笑话",
134
+ "窟窿",
135
+ "窝囊",
136
+ "窗户",
137
+ "稳当",
138
+ "稀罕",
139
+ "称呼",
140
+ "秧歌",
141
+ "秀气",
142
+ "秀才",
143
+ "福气",
144
+ "祖宗",
145
+ "砚台",
146
+ "码头",
147
+ "石榴",
148
+ "石头",
149
+ "石匠",
150
+ "知识",
151
+ "眼睛",
152
+ "眯缝",
153
+ "眨巴",
154
+ "眉毛",
155
+ "相声",
156
+ "盘算",
157
+ "白净",
158
+ "痢疾",
159
+ "痛快",
160
+ "疟疾",
161
+ "疙瘩",
162
+ "疏忽",
163
+ "畜生",
164
+ "生意",
165
+ "甘蔗",
166
+ "琵琶",
167
+ "琢磨",
168
+ "琉璃",
169
+ "玻璃",
170
+ "玫瑰",
171
+ "玄乎",
172
+ "狐狸",
173
+ "状元",
174
+ "特务",
175
+ "牲口",
176
+ "牙碜",
177
+ "牌楼",
178
+ "爽快",
179
+ "爱人",
180
+ "热闹",
181
+ "烧饼",
182
+ "烟筒",
183
+ "烂糊",
184
+ "点心",
185
+ "炊帚",
186
+ "灯笼",
187
+ "火候",
188
+ "漂亮",
189
+ "滑溜",
190
+ "溜达",
191
+ "温和",
192
+ "清楚",
193
+ "消息",
194
+ "浪头",
195
+ "活泼",
196
+ "比方",
197
+ "正经",
198
+ "欺负",
199
+ "模糊",
200
+ "槟榔",
201
+ "棺材",
202
+ "棒槌",
203
+ "棉花",
204
+ "核桃",
205
+ "栅栏",
206
+ "柴火",
207
+ "架势",
208
+ "枕头",
209
+ "枇杷",
210
+ "机灵",
211
+ "本事",
212
+ "木头",
213
+ "木匠",
214
+ "朋友",
215
+ "月饼",
216
+ "月亮",
217
+ "暖和",
218
+ "明白",
219
+ "时候",
220
+ "新鲜",
221
+ "故事",
222
+ "收拾",
223
+ "收成",
224
+ "提防",
225
+ "挖苦",
226
+ "挑剔",
227
+ "指甲",
228
+ "指头",
229
+ "拾掇",
230
+ "拳头",
231
+ "拨弄",
232
+ "招牌",
233
+ "招呼",
234
+ "抬举",
235
+ "护士",
236
+ "折腾",
237
+ "扫帚",
238
+ "打量",
239
+ "打算",
240
+ "打点",
241
+ "打扮",
242
+ "打听",
243
+ "打发",
244
+ "扎实",
245
+ "扁担",
246
+ "戒指",
247
+ "懒得",
248
+ "意识",
249
+ "意思",
250
+ "情形",
251
+ "悟性",
252
+ "怪物",
253
+ "思量",
254
+ "怎么",
255
+ "念头",
256
+ "念叨",
257
+ "快活",
258
+ "忙活",
259
+ "志气",
260
+ "心思",
261
+ "得罪",
262
+ "张罗",
263
+ "弟兄",
264
+ "开通",
265
+ "应酬",
266
+ "庄稼",
267
+ "干事",
268
+ "帮手",
269
+ "帐篷",
270
+ "希罕",
271
+ "师父",
272
+ "师傅",
273
+ "巴结",
274
+ "巴掌",
275
+ "差事",
276
+ "工夫",
277
+ "岁数",
278
+ "屁股",
279
+ "尾巴",
280
+ "少爷",
281
+ "小气",
282
+ "小伙",
283
+ "将就",
284
+ "对头",
285
+ "对付",
286
+ "寡妇",
287
+ "家伙",
288
+ "客气",
289
+ "实在",
290
+ "官司",
291
+ "学问",
292
+ "学生",
293
+ "字号",
294
+ "嫁妆",
295
+ "媳妇",
296
+ "媒人",
297
+ "婆家",
298
+ "娘家",
299
+ "委屈",
300
+ "姑娘",
301
+ "姐夫",
302
+ "妯娌",
303
+ "妥当",
304
+ "妖精",
305
+ "奴才",
306
+ "女婿",
307
+ "头发",
308
+ "太阳",
309
+ "大爷",
310
+ "大方",
311
+ "大意",
312
+ "大夫",
313
+ "多少",
314
+ "多么",
315
+ "外甥",
316
+ "壮实",
317
+ "地道",
318
+ "地方",
319
+ "在乎",
320
+ "困难",
321
+ "嘴巴",
322
+ "嘱咐",
323
+ "嘟囔",
324
+ "嘀咕",
325
+ "喜欢",
326
+ "喇嘛",
327
+ "喇叭",
328
+ "商量",
329
+ "唾沫",
330
+ "哑巴",
331
+ "哈欠",
332
+ "哆嗦",
333
+ "咳嗽",
334
+ "和尚",
335
+ "告诉",
336
+ "告示",
337
+ "含糊",
338
+ "吓唬",
339
+ "后头",
340
+ "名字",
341
+ "名堂",
342
+ "合同",
343
+ "吆喝",
344
+ "叫唤",
345
+ "口袋",
346
+ "厚道",
347
+ "厉害",
348
+ "千斤",
349
+ "包袱",
350
+ "包涵",
351
+ "匀称",
352
+ "勤快",
353
+ "动静",
354
+ "动弹",
355
+ "功夫",
356
+ "力气",
357
+ "前头",
358
+ "刺猬",
359
+ "刺激",
360
+ "别扭",
361
+ "利落",
362
+ "利索",
363
+ "利害",
364
+ "分析",
365
+ "出息",
366
+ "凑合",
367
+ "凉快",
368
+ "冷战",
369
+ "冤枉",
370
+ "冒失",
371
+ "养活",
372
+ "关系",
373
+ "先生",
374
+ "兄弟",
375
+ "便宜",
376
+ "使唤",
377
+ "佩服",
378
+ "作坊",
379
+ "体面",
380
+ "位置",
381
+ "似的",
382
+ "伙计",
383
+ "休息",
384
+ "什么",
385
+ "人家",
386
+ "亲戚",
387
+ "亲家",
388
+ "交情",
389
+ "云彩",
390
+ "事情",
391
+ "买卖",
392
+ "主意",
393
+ "丫头",
394
+ "丧气",
395
+ "两口",
396
+ "东西",
397
+ "东家",
398
+ "世故",
399
+ "不由",
400
+ "不在",
401
+ "下水",
402
+ "下巴",
403
+ "上头",
404
+ "上司",
405
+ "丈夫",
406
+ "丈人",
407
+ "一辈",
408
+ "那个",
409
+ "菩萨",
410
+ "父亲",
411
+ "母亲",
412
+ "咕噜",
413
+ "邋遢",
414
+ "费用",
415
+ "冤家",
416
+ "甜头",
417
+ "介绍",
418
+ "荒唐",
419
+ "大人",
420
+ "泥鳅",
421
+ "幸福",
422
+ "熟悉",
423
+ "计划",
424
+ "扑腾",
425
+ "蜡烛",
426
+ "姥爷",
427
+ "照顾",
428
+ "喉咙",
429
+ "吉他",
430
+ "弄堂",
431
+ "蚂蚱",
432
+ "凤凰",
433
+ "拖沓",
434
+ "寒碜",
435
+ "糟蹋",
436
+ "倒腾",
437
+ "报复",
438
+ "逻辑",
439
+ "盘缠",
440
+ "喽啰",
441
+ "牢骚",
442
+ "咖喱",
443
+ "扫把",
444
+ "惦记",
445
+ }
446
+ self.must_not_neural_tone_words = {
447
+ "男子",
448
+ "女子",
449
+ "分子",
450
+ "原子",
451
+ "量子",
452
+ "莲子",
453
+ "石子",
454
+ "瓜子",
455
+ "电子",
456
+ "人人",
457
+ "虎虎",
458
+ }
459
+ self.punc = ":,;。?!“”‘’':,;.?!"
460
+
461
+ # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
462
+ # e.g.
463
+ # word: "家里"
464
+ # pos: "s"
465
+ # finals: ['ia1', 'i3']
466
+ def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
467
+ # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
468
+ for j, item in enumerate(word):
469
+ if (
470
+ j - 1 >= 0
471
+ and item == word[j - 1]
472
+ and pos[0] in {"n", "v", "a"}
473
+ and word not in self.must_not_neural_tone_words
474
+ ):
475
+ finals[j] = finals[j][:-1] + "5"
476
+ ge_idx = word.find("个")
477
+ if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
478
+ finals[-1] = finals[-1][:-1] + "5"
479
+ elif len(word) >= 1 and word[-1] in "的地得":
480
+ finals[-1] = finals[-1][:-1] + "5"
481
+ # e.g. 走了, 看着, 去过
482
+ # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
483
+ # finals[-1] = finals[-1][:-1] + "5"
484
+ elif (
485
+ len(word) > 1
486
+ and word[-1] in "们子"
487
+ and pos in {"r", "n"}
488
+ and word not in self.must_not_neural_tone_words
489
+ ):
490
+ finals[-1] = finals[-1][:-1] + "5"
491
+ # e.g. 桌上, 地下, 家里
492
+ elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
493
+ finals[-1] = finals[-1][:-1] + "5"
494
+ # e.g. 上来, 下去
495
+ elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
496
+ finals[-1] = finals[-1][:-1] + "5"
497
+ # 个做量词
498
+ elif (
499
+ ge_idx >= 1
500
+ and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
501
+ ) or word == "个":
502
+ finals[ge_idx] = finals[ge_idx][:-1] + "5"
503
+ else:
504
+ if (
505
+ word in self.must_neural_tone_words
506
+ or word[-2:] in self.must_neural_tone_words
507
+ ):
508
+ finals[-1] = finals[-1][:-1] + "5"
509
+
510
+ word_list = self._split_word(word)
511
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
512
+ for i, word in enumerate(word_list):
513
+ # conventional neural in Chinese
514
+ if (
515
+ word in self.must_neural_tone_words
516
+ or word[-2:] in self.must_neural_tone_words
517
+ ):
518
+ finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
519
+ finals = sum(finals_list, [])
520
+ return finals
521
+
522
+ def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
523
+ # e.g. 看不懂
524
+ if len(word) == 3 and word[1] == "不":
525
+ finals[1] = finals[1][:-1] + "5"
526
+ else:
527
+ for i, char in enumerate(word):
528
+ # "不" before tone4 should be bu2, e.g. 不怕
529
+ if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4":
530
+ finals[i] = finals[i][:-1] + "2"
531
+ return finals
532
+
533
+ def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
534
+ # "一" in number sequences, e.g. 一零零, 二一零
535
+ if word.find("一") != -1 and all(
536
+ [item.isnumeric() for item in word if item != "一"]
537
+ ):
538
+ return finals
539
+ # "一" between reduplication words shold be yi5, e.g. 看一看
540
+ elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
541
+ finals[1] = finals[1][:-1] + "5"
542
+ # when "一" is ordinal word, it should be yi1
543
+ elif word.startswith("第一"):
544
+ finals[1] = finals[1][:-1] + "1"
545
+ else:
546
+ for i, char in enumerate(word):
547
+ if char == "一" and i + 1 < len(word):
548
+ # "一" before tone4 should be yi2, e.g. 一段
549
+ if finals[i + 1][-1] == "4":
550
+ finals[i] = finals[i][:-1] + "2"
551
+ # "一" before non-tone4 should be yi4, e.g. 一天
552
+ else:
553
+ # "一" 后面如果是标点,还读一声
554
+ if word[i + 1] not in self.punc:
555
+ finals[i] = finals[i][:-1] + "4"
556
+ return finals
557
+
558
+ def _split_word(self, word: str) -> List[str]:
559
+ word_list = jieba.cut_for_search(word)
560
+ word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
561
+ first_subword = word_list[0]
562
+ first_begin_idx = word.find(first_subword)
563
+ if first_begin_idx == 0:
564
+ second_subword = word[len(first_subword) :]
565
+ new_word_list = [first_subword, second_subword]
566
+ else:
567
+ second_subword = word[: -len(first_subword)]
568
+ new_word_list = [second_subword, first_subword]
569
+ return new_word_list
570
+
571
+ def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
572
+ if len(word) == 2 and self._all_tone_three(finals):
573
+ finals[0] = finals[0][:-1] + "2"
574
+ elif len(word) == 3:
575
+ word_list = self._split_word(word)
576
+ if self._all_tone_three(finals):
577
+ # disyllabic + monosyllabic, e.g. 蒙古/包
578
+ if len(word_list[0]) == 2:
579
+ finals[0] = finals[0][:-1] + "2"
580
+ finals[1] = finals[1][:-1] + "2"
581
+ # monosyllabic + disyllabic, e.g. 纸/老虎
582
+ elif len(word_list[0]) == 1:
583
+ finals[1] = finals[1][:-1] + "2"
584
+ else:
585
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
586
+ if len(finals_list) == 2:
587
+ for i, sub in enumerate(finals_list):
588
+ # e.g. 所有/人
589
+ if self._all_tone_three(sub) and len(sub) == 2:
590
+ finals_list[i][0] = finals_list[i][0][:-1] + "2"
591
+ # e.g. 好/喜欢
592
+ elif (
593
+ i == 1
594
+ and not self._all_tone_three(sub)
595
+ and finals_list[i][0][-1] == "3"
596
+ and finals_list[0][-1][-1] == "3"
597
+ ):
598
+ finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
599
+ finals = sum(finals_list, [])
600
+ # split idiom into two words who's length is 2
601
+ elif len(word) == 4:
602
+ finals_list = [finals[:2], finals[2:]]
603
+ finals = []
604
+ for sub in finals_list:
605
+ if self._all_tone_three(sub):
606
+ sub[0] = sub[0][:-1] + "2"
607
+ finals += sub
608
+
609
+ return finals
610
+
611
+ def _all_tone_three(self, finals: List[str]) -> bool:
612
+ return all(x[-1] == "3" for x in finals)
613
+
614
+ # merge "不" and the word behind it
615
+ # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
616
+ def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
617
+ new_seg = []
618
+ last_word = ""
619
+ for word, pos in seg:
620
+ if last_word == "不":
621
+ word = last_word + word
622
+ if word != "不":
623
+ new_seg.append((word, pos))
624
+ last_word = word[:]
625
+ if last_word == "不":
626
+ new_seg.append((last_word, "d"))
627
+ last_word = ""
628
+ return new_seg
629
+
630
+ # function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
631
+ # function 2: merge single "一" and the word behind it
632
+ # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
633
+ # e.g.
634
+ # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
635
+ # output seg: [['听一听', 'v']]
636
+ def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
637
+ new_seg = []
638
+ # function 1
639
+ for i, (word, pos) in enumerate(seg):
640
+ if (
641
+ i - 1 >= 0
642
+ and word == "一"
643
+ and i + 1 < len(seg)
644
+ and seg[i - 1][0] == seg[i + 1][0]
645
+ and seg[i - 1][1] == "v"
646
+ ):
647
+ new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
648
+ else:
649
+ if (
650
+ i - 2 >= 0
651
+ and seg[i - 1][0] == "一"
652
+ and seg[i - 2][0] == word
653
+ and pos == "v"
654
+ ):
655
+ continue
656
+ else:
657
+ new_seg.append([word, pos])
658
+ seg = new_seg
659
+ new_seg = []
660
+ # function 2
661
+ for i, (word, pos) in enumerate(seg):
662
+ if new_seg and new_seg[-1][0] == "一":
663
+ new_seg[-1][0] = new_seg[-1][0] + word
664
+ else:
665
+ new_seg.append([word, pos])
666
+ return new_seg
667
+
668
+ # the first and the second words are all_tone_three
669
+ def _merge_continuous_three_tones(
670
+ self, seg: List[Tuple[str, str]]
671
+ ) -> List[Tuple[str, str]]:
672
+ new_seg = []
673
+ sub_finals_list = [
674
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
675
+ for (word, pos) in seg
676
+ ]
677
+ assert len(sub_finals_list) == len(seg)
678
+ merge_last = [False] * len(seg)
679
+ for i, (word, pos) in enumerate(seg):
680
+ if (
681
+ i - 1 >= 0
682
+ and self._all_tone_three(sub_finals_list[i - 1])
683
+ and self._all_tone_three(sub_finals_list[i])
684
+ and not merge_last[i - 1]
685
+ ):
686
+ # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
687
+ if (
688
+ not self._is_reduplication(seg[i - 1][0])
689
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
690
+ ):
691
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
692
+ merge_last[i] = True
693
+ else:
694
+ new_seg.append([word, pos])
695
+ else:
696
+ new_seg.append([word, pos])
697
+
698
+ return new_seg
699
+
700
+ def _is_reduplication(self, word: str) -> bool:
701
+ return len(word) == 2 and word[0] == word[1]
702
+
703
+ # the last char of first word and the first char of second word is tone_three
704
+ def _merge_continuous_three_tones_2(
705
+ self, seg: List[Tuple[str, str]]
706
+ ) -> List[Tuple[str, str]]:
707
+ new_seg = []
708
+ sub_finals_list = [
709
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
710
+ for (word, pos) in seg
711
+ ]
712
+ assert len(sub_finals_list) == len(seg)
713
+ merge_last = [False] * len(seg)
714
+ for i, (word, pos) in enumerate(seg):
715
+ if (
716
+ i - 1 >= 0
717
+ and sub_finals_list[i - 1][-1][-1] == "3"
718
+ and sub_finals_list[i][0][-1] == "3"
719
+ and not merge_last[i - 1]
720
+ ):
721
+ # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
722
+ if (
723
+ not self._is_reduplication(seg[i - 1][0])
724
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
725
+ ):
726
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
727
+ merge_last[i] = True
728
+ else:
729
+ new_seg.append([word, pos])
730
+ else:
731
+ new_seg.append([word, pos])
732
+ return new_seg
733
+
734
+ def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
735
+ new_seg = []
736
+ for i, (word, pos) in enumerate(seg):
737
+ if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#":
738
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
739
+ else:
740
+ new_seg.append([word, pos])
741
+ return new_seg
742
+
743
+ def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
744
+ new_seg = []
745
+ for i, (word, pos) in enumerate(seg):
746
+ if new_seg and word == new_seg[-1][0]:
747
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
748
+ else:
749
+ new_seg.append([word, pos])
750
+ return new_seg
751
+
752
+ def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
753
+ seg = self._merge_bu(seg)
754
+ try:
755
+ seg = self._merge_yi(seg)
756
+ except:
757
+ print("_merge_yi failed")
758
+ seg = self._merge_reduplication(seg)
759
+ seg = self._merge_continuous_three_tones(seg)
760
+ seg = self._merge_continuous_three_tones_2(seg)
761
+ seg = self._merge_er(seg)
762
+ return seg
763
+
764
+ def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
765
+ finals = self._bu_sandhi(word, finals)
766
+ finals = self._yi_sandhi(word, finals)
767
+ finals = self._neural_sandhi(word, pos, finals)
768
+ finals = self._three_sandhi(word, finals)
769
+ return finals
oldVersion/V110/__init__.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 1.1 版本兼容
3
+ https://github.com/fishaudio/Bert-VITS2/releases/tag/1.1
4
+ """
5
+ import torch
6
+ import commons
7
+ from .text.cleaner import clean_text
8
+ from .text import cleaned_text_to_sequence
9
+ from oldVersion.V111.text import get_bert
10
+
11
+
12
+ def get_text(text, language_str, hps, device):
13
+ norm_text, phone, tone, word2ph = clean_text(text, language_str)
14
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
15
+
16
+ if hps.data.add_blank:
17
+ phone = commons.intersperse(phone, 0)
18
+ tone = commons.intersperse(tone, 0)
19
+ language = commons.intersperse(language, 0)
20
+ for i in range(len(word2ph)):
21
+ word2ph[i] = word2ph[i] * 2
22
+ word2ph[0] += 1
23
+ bert = get_bert(norm_text, word2ph, language_str, device)
24
+ del word2ph
25
+ assert bert.shape[-1] == len(phone), phone
26
+
27
+ if language_str == "ZH":
28
+ bert = bert
29
+ ja_bert = torch.zeros(768, len(phone))
30
+ elif language_str == "JP":
31
+ ja_bert = bert
32
+ bert = torch.zeros(1024, len(phone))
33
+ else:
34
+ bert = torch.zeros(1024, len(phone))
35
+ ja_bert = torch.zeros(768, len(phone))
36
+
37
+ assert bert.shape[-1] == len(
38
+ phone
39
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
40
+
41
+ phone = torch.LongTensor(phone)
42
+ tone = torch.LongTensor(tone)
43
+ language = torch.LongTensor(language)
44
+ return bert, ja_bert, phone, tone, language
45
+
46
+
47
+ def infer(
48
+ text,
49
+ sdp_ratio,
50
+ noise_scale,
51
+ noise_scale_w,
52
+ length_scale,
53
+ sid,
54
+ language,
55
+ hps,
56
+ net_g,
57
+ device,
58
+ ):
59
+ bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps, device)
60
+ with torch.no_grad():
61
+ x_tst = phones.to(device).unsqueeze(0)
62
+ tones = tones.to(device).unsqueeze(0)
63
+ lang_ids = lang_ids.to(device).unsqueeze(0)
64
+ bert = bert.to(device).unsqueeze(0)
65
+ ja_bert = ja_bert.to(device).unsqueeze(0)
66
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
67
+ del phones
68
+ speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
69
+ audio = (
70
+ net_g.infer(
71
+ x_tst,
72
+ x_tst_lengths,
73
+ speakers,
74
+ tones,
75
+ lang_ids,
76
+ bert,
77
+ ja_bert,
78
+ sdp_ratio=sdp_ratio,
79
+ noise_scale=noise_scale,
80
+ noise_scale_w=noise_scale_w,
81
+ length_scale=length_scale,
82
+ )[0][0, 0]
83
+ .data.cpu()
84
+ .float()
85
+ .numpy()
86
+ )
87
+ del x_tst, x_tst_lengths, speakers, tones, lang_ids, bert, ja_bert
88
+ if torch.cuda.is_available():
89
+ torch.cuda.empty_cache()
90
+ return audio
oldVersion/V110/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.95 kB). View file
 
oldVersion/V110/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.94 kB). View file
 
oldVersion/V110/__pycache__/models.cpython-310.pyc ADDED
Binary file (20.7 kB). View file
 
oldVersion/V110/__pycache__/models.cpython-38.pyc ADDED
Binary file (20.9 kB). View file
 
oldVersion/V110/models.py ADDED
@@ -0,0 +1,986 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ import attentions
9
+ import monotonic_align
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+
14
+ from commons import init_weights, get_padding
15
+ from .text import symbols, num_tones, num_languages
16
+
17
+
18
+ class DurationDiscriminator(nn.Module): # vits2
19
+ def __init__(
20
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
21
+ ):
22
+ super().__init__()
23
+
24
+ self.in_channels = in_channels
25
+ self.filter_channels = filter_channels
26
+ self.kernel_size = kernel_size
27
+ self.p_dropout = p_dropout
28
+ self.gin_channels = gin_channels
29
+
30
+ self.drop = nn.Dropout(p_dropout)
31
+ self.conv_1 = nn.Conv1d(
32
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
33
+ )
34
+ self.norm_1 = modules.LayerNorm(filter_channels)
35
+ self.conv_2 = nn.Conv1d(
36
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
37
+ )
38
+ self.norm_2 = modules.LayerNorm(filter_channels)
39
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
40
+
41
+ self.pre_out_conv_1 = nn.Conv1d(
42
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
43
+ )
44
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
45
+ self.pre_out_conv_2 = nn.Conv1d(
46
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
47
+ )
48
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
49
+
50
+ if gin_channels != 0:
51
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
52
+
53
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
54
+
55
+ def forward_probability(self, x, x_mask, dur, g=None):
56
+ dur = self.dur_proj(dur)
57
+ x = torch.cat([x, dur], dim=1)
58
+ x = self.pre_out_conv_1(x * x_mask)
59
+ x = torch.relu(x)
60
+ x = self.pre_out_norm_1(x)
61
+ x = self.drop(x)
62
+ x = self.pre_out_conv_2(x * x_mask)
63
+ x = torch.relu(x)
64
+ x = self.pre_out_norm_2(x)
65
+ x = self.drop(x)
66
+ x = x * x_mask
67
+ x = x.transpose(1, 2)
68
+ output_prob = self.output_layer(x)
69
+ return output_prob
70
+
71
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
72
+ x = torch.detach(x)
73
+ if g is not None:
74
+ g = torch.detach(g)
75
+ x = x + self.cond(g)
76
+ x = self.conv_1(x * x_mask)
77
+ x = torch.relu(x)
78
+ x = self.norm_1(x)
79
+ x = self.drop(x)
80
+ x = self.conv_2(x * x_mask)
81
+ x = torch.relu(x)
82
+ x = self.norm_2(x)
83
+ x = self.drop(x)
84
+
85
+ output_probs = []
86
+ for dur in [dur_r, dur_hat]:
87
+ output_prob = self.forward_probability(x, x_mask, dur, g)
88
+ output_probs.append(output_prob)
89
+
90
+ return output_probs
91
+
92
+
93
+ class TransformerCouplingBlock(nn.Module):
94
+ def __init__(
95
+ self,
96
+ channels,
97
+ hidden_channels,
98
+ filter_channels,
99
+ n_heads,
100
+ n_layers,
101
+ kernel_size,
102
+ p_dropout,
103
+ n_flows=4,
104
+ gin_channels=0,
105
+ share_parameter=False,
106
+ ):
107
+ super().__init__()
108
+ self.channels = channels
109
+ self.hidden_channels = hidden_channels
110
+ self.kernel_size = kernel_size
111
+ self.n_layers = n_layers
112
+ self.n_flows = n_flows
113
+ self.gin_channels = gin_channels
114
+
115
+ self.flows = nn.ModuleList()
116
+
117
+ self.wn = (
118
+ attentions.FFT(
119
+ hidden_channels,
120
+ filter_channels,
121
+ n_heads,
122
+ n_layers,
123
+ kernel_size,
124
+ p_dropout,
125
+ isflow=True,
126
+ gin_channels=self.gin_channels,
127
+ )
128
+ if share_parameter
129
+ else None
130
+ )
131
+
132
+ for i in range(n_flows):
133
+ self.flows.append(
134
+ modules.TransformerCouplingLayer(
135
+ channels,
136
+ hidden_channels,
137
+ kernel_size,
138
+ n_layers,
139
+ n_heads,
140
+ p_dropout,
141
+ filter_channels,
142
+ mean_only=True,
143
+ wn_sharing_parameter=self.wn,
144
+ gin_channels=self.gin_channels,
145
+ )
146
+ )
147
+ self.flows.append(modules.Flip())
148
+
149
+ def forward(self, x, x_mask, g=None, reverse=False):
150
+ if not reverse:
151
+ for flow in self.flows:
152
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
153
+ else:
154
+ for flow in reversed(self.flows):
155
+ x = flow(x, x_mask, g=g, reverse=reverse)
156
+ return x
157
+
158
+
159
+ class StochasticDurationPredictor(nn.Module):
160
+ def __init__(
161
+ self,
162
+ in_channels,
163
+ filter_channels,
164
+ kernel_size,
165
+ p_dropout,
166
+ n_flows=4,
167
+ gin_channels=0,
168
+ ):
169
+ super().__init__()
170
+ filter_channels = in_channels # it needs to be removed from future version.
171
+ self.in_channels = in_channels
172
+ self.filter_channels = filter_channels
173
+ self.kernel_size = kernel_size
174
+ self.p_dropout = p_dropout
175
+ self.n_flows = n_flows
176
+ self.gin_channels = gin_channels
177
+
178
+ self.log_flow = modules.Log()
179
+ self.flows = nn.ModuleList()
180
+ self.flows.append(modules.ElementwiseAffine(2))
181
+ for i in range(n_flows):
182
+ self.flows.append(
183
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
184
+ )
185
+ self.flows.append(modules.Flip())
186
+
187
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
188
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
189
+ self.post_convs = modules.DDSConv(
190
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
191
+ )
192
+ self.post_flows = nn.ModuleList()
193
+ self.post_flows.append(modules.ElementwiseAffine(2))
194
+ for i in range(4):
195
+ self.post_flows.append(
196
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
197
+ )
198
+ self.post_flows.append(modules.Flip())
199
+
200
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
201
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
202
+ self.convs = modules.DDSConv(
203
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
204
+ )
205
+ if gin_channels != 0:
206
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
207
+
208
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
209
+ x = torch.detach(x)
210
+ x = self.pre(x)
211
+ if g is not None:
212
+ g = torch.detach(g)
213
+ x = x + self.cond(g)
214
+ x = self.convs(x, x_mask)
215
+ x = self.proj(x) * x_mask
216
+
217
+ if not reverse:
218
+ flows = self.flows
219
+ assert w is not None
220
+
221
+ logdet_tot_q = 0
222
+ h_w = self.post_pre(w)
223
+ h_w = self.post_convs(h_w, x_mask)
224
+ h_w = self.post_proj(h_w) * x_mask
225
+ e_q = (
226
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
227
+ * x_mask
228
+ )
229
+ z_q = e_q
230
+ for flow in self.post_flows:
231
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
232
+ logdet_tot_q += logdet_q
233
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
234
+ u = torch.sigmoid(z_u) * x_mask
235
+ z0 = (w - u) * x_mask
236
+ logdet_tot_q += torch.sum(
237
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
238
+ )
239
+ logq = (
240
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
241
+ - logdet_tot_q
242
+ )
243
+
244
+ logdet_tot = 0
245
+ z0, logdet = self.log_flow(z0, x_mask)
246
+ logdet_tot += logdet
247
+ z = torch.cat([z0, z1], 1)
248
+ for flow in flows:
249
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
250
+ logdet_tot = logdet_tot + logdet
251
+ nll = (
252
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
253
+ - logdet_tot
254
+ )
255
+ return nll + logq # [b]
256
+ else:
257
+ flows = list(reversed(self.flows))
258
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
259
+ z = (
260
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
261
+ * noise_scale
262
+ )
263
+ for flow in flows:
264
+ z = flow(z, x_mask, g=x, reverse=reverse)
265
+ z0, z1 = torch.split(z, [1, 1], 1)
266
+ logw = z0
267
+ return logw
268
+
269
+
270
+ class DurationPredictor(nn.Module):
271
+ def __init__(
272
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
273
+ ):
274
+ super().__init__()
275
+
276
+ self.in_channels = in_channels
277
+ self.filter_channels = filter_channels
278
+ self.kernel_size = kernel_size
279
+ self.p_dropout = p_dropout
280
+ self.gin_channels = gin_channels
281
+
282
+ self.drop = nn.Dropout(p_dropout)
283
+ self.conv_1 = nn.Conv1d(
284
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
285
+ )
286
+ self.norm_1 = modules.LayerNorm(filter_channels)
287
+ self.conv_2 = nn.Conv1d(
288
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
289
+ )
290
+ self.norm_2 = modules.LayerNorm(filter_channels)
291
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
292
+
293
+ if gin_channels != 0:
294
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
295
+
296
+ def forward(self, x, x_mask, g=None):
297
+ x = torch.detach(x)
298
+ if g is not None:
299
+ g = torch.detach(g)
300
+ x = x + self.cond(g)
301
+ x = self.conv_1(x * x_mask)
302
+ x = torch.relu(x)
303
+ x = self.norm_1(x)
304
+ x = self.drop(x)
305
+ x = self.conv_2(x * x_mask)
306
+ x = torch.relu(x)
307
+ x = self.norm_2(x)
308
+ x = self.drop(x)
309
+ x = self.proj(x * x_mask)
310
+ return x * x_mask
311
+
312
+
313
+ class TextEncoder(nn.Module):
314
+ def __init__(
315
+ self,
316
+ n_vocab,
317
+ out_channels,
318
+ hidden_channels,
319
+ filter_channels,
320
+ n_heads,
321
+ n_layers,
322
+ kernel_size,
323
+ p_dropout,
324
+ gin_channels=0,
325
+ ):
326
+ super().__init__()
327
+ self.n_vocab = n_vocab
328
+ self.out_channels = out_channels
329
+ self.hidden_channels = hidden_channels
330
+ self.filter_channels = filter_channels
331
+ self.n_heads = n_heads
332
+ self.n_layers = n_layers
333
+ self.kernel_size = kernel_size
334
+ self.p_dropout = p_dropout
335
+ self.gin_channels = gin_channels
336
+ self.emb = nn.Embedding(len(symbols), hidden_channels)
337
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
338
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
339
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
340
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
341
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
342
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
343
+ self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)
344
+
345
+ self.encoder = attentions.Encoder(
346
+ hidden_channels,
347
+ filter_channels,
348
+ n_heads,
349
+ n_layers,
350
+ kernel_size,
351
+ p_dropout,
352
+ gin_channels=self.gin_channels,
353
+ )
354
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
355
+
356
+ def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
357
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
358
+ ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
359
+ x = (
360
+ self.emb(x)
361
+ + self.tone_emb(tone)
362
+ + self.language_emb(language)
363
+ + bert_emb
364
+ + ja_bert_emb
365
+ ) * math.sqrt(
366
+ self.hidden_channels
367
+ ) # [b, t, h]
368
+ x = torch.transpose(x, 1, -1) # [b, h, t]
369
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
370
+ x.dtype
371
+ )
372
+
373
+ x = self.encoder(x * x_mask, x_mask, g=g)
374
+ stats = self.proj(x) * x_mask
375
+
376
+ m, logs = torch.split(stats, self.out_channels, dim=1)
377
+ return x, m, logs, x_mask
378
+
379
+
380
+ class ResidualCouplingBlock(nn.Module):
381
+ def __init__(
382
+ self,
383
+ channels,
384
+ hidden_channels,
385
+ kernel_size,
386
+ dilation_rate,
387
+ n_layers,
388
+ n_flows=4,
389
+ gin_channels=0,
390
+ ):
391
+ super().__init__()
392
+ self.channels = channels
393
+ self.hidden_channels = hidden_channels
394
+ self.kernel_size = kernel_size
395
+ self.dilation_rate = dilation_rate
396
+ self.n_layers = n_layers
397
+ self.n_flows = n_flows
398
+ self.gin_channels = gin_channels
399
+
400
+ self.flows = nn.ModuleList()
401
+ for i in range(n_flows):
402
+ self.flows.append(
403
+ modules.ResidualCouplingLayer(
404
+ channels,
405
+ hidden_channels,
406
+ kernel_size,
407
+ dilation_rate,
408
+ n_layers,
409
+ gin_channels=gin_channels,
410
+ mean_only=True,
411
+ )
412
+ )
413
+ self.flows.append(modules.Flip())
414
+
415
+ def forward(self, x, x_mask, g=None, reverse=False):
416
+ if not reverse:
417
+ for flow in self.flows:
418
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
419
+ else:
420
+ for flow in reversed(self.flows):
421
+ x = flow(x, x_mask, g=g, reverse=reverse)
422
+ return x
423
+
424
+
425
+ class PosteriorEncoder(nn.Module):
426
+ def __init__(
427
+ self,
428
+ in_channels,
429
+ out_channels,
430
+ hidden_channels,
431
+ kernel_size,
432
+ dilation_rate,
433
+ n_layers,
434
+ gin_channels=0,
435
+ ):
436
+ super().__init__()
437
+ self.in_channels = in_channels
438
+ self.out_channels = out_channels
439
+ self.hidden_channels = hidden_channels
440
+ self.kernel_size = kernel_size
441
+ self.dilation_rate = dilation_rate
442
+ self.n_layers = n_layers
443
+ self.gin_channels = gin_channels
444
+
445
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
446
+ self.enc = modules.WN(
447
+ hidden_channels,
448
+ kernel_size,
449
+ dilation_rate,
450
+ n_layers,
451
+ gin_channels=gin_channels,
452
+ )
453
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
454
+
455
+ def forward(self, x, x_lengths, g=None):
456
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
457
+ x.dtype
458
+ )
459
+ x = self.pre(x) * x_mask
460
+ x = self.enc(x, x_mask, g=g)
461
+ stats = self.proj(x) * x_mask
462
+ m, logs = torch.split(stats, self.out_channels, dim=1)
463
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
464
+ return z, m, logs, x_mask
465
+
466
+
467
+ class Generator(torch.nn.Module):
468
+ def __init__(
469
+ self,
470
+ initial_channel,
471
+ resblock,
472
+ resblock_kernel_sizes,
473
+ resblock_dilation_sizes,
474
+ upsample_rates,
475
+ upsample_initial_channel,
476
+ upsample_kernel_sizes,
477
+ gin_channels=0,
478
+ ):
479
+ super(Generator, self).__init__()
480
+ self.num_kernels = len(resblock_kernel_sizes)
481
+ self.num_upsamples = len(upsample_rates)
482
+ self.conv_pre = Conv1d(
483
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
484
+ )
485
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
486
+
487
+ self.ups = nn.ModuleList()
488
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
489
+ self.ups.append(
490
+ weight_norm(
491
+ ConvTranspose1d(
492
+ upsample_initial_channel // (2**i),
493
+ upsample_initial_channel // (2 ** (i + 1)),
494
+ k,
495
+ u,
496
+ padding=(k - u) // 2,
497
+ )
498
+ )
499
+ )
500
+
501
+ self.resblocks = nn.ModuleList()
502
+ for i in range(len(self.ups)):
503
+ ch = upsample_initial_channel // (2 ** (i + 1))
504
+ for j, (k, d) in enumerate(
505
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
506
+ ):
507
+ self.resblocks.append(resblock(ch, k, d))
508
+
509
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
510
+ self.ups.apply(init_weights)
511
+
512
+ if gin_channels != 0:
513
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
514
+
515
+ def forward(self, x, g=None):
516
+ x = self.conv_pre(x)
517
+ if g is not None:
518
+ x = x + self.cond(g)
519
+
520
+ for i in range(self.num_upsamples):
521
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
522
+ x = self.ups[i](x)
523
+ xs = None
524
+ for j in range(self.num_kernels):
525
+ if xs is None:
526
+ xs = self.resblocks[i * self.num_kernels + j](x)
527
+ else:
528
+ xs += self.resblocks[i * self.num_kernels + j](x)
529
+ x = xs / self.num_kernels
530
+ x = F.leaky_relu(x)
531
+ x = self.conv_post(x)
532
+ x = torch.tanh(x)
533
+
534
+ return x
535
+
536
+ def remove_weight_norm(self):
537
+ print("Removing weight norm...")
538
+ for layer in self.ups:
539
+ remove_weight_norm(layer)
540
+ for layer in self.resblocks:
541
+ layer.remove_weight_norm()
542
+
543
+
544
+ class DiscriminatorP(torch.nn.Module):
545
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
546
+ super(DiscriminatorP, self).__init__()
547
+ self.period = period
548
+ self.use_spectral_norm = use_spectral_norm
549
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
550
+ self.convs = nn.ModuleList(
551
+ [
552
+ norm_f(
553
+ Conv2d(
554
+ 1,
555
+ 32,
556
+ (kernel_size, 1),
557
+ (stride, 1),
558
+ padding=(get_padding(kernel_size, 1), 0),
559
+ )
560
+ ),
561
+ norm_f(
562
+ Conv2d(
563
+ 32,
564
+ 128,
565
+ (kernel_size, 1),
566
+ (stride, 1),
567
+ padding=(get_padding(kernel_size, 1), 0),
568
+ )
569
+ ),
570
+ norm_f(
571
+ Conv2d(
572
+ 128,
573
+ 512,
574
+ (kernel_size, 1),
575
+ (stride, 1),
576
+ padding=(get_padding(kernel_size, 1), 0),
577
+ )
578
+ ),
579
+ norm_f(
580
+ Conv2d(
581
+ 512,
582
+ 1024,
583
+ (kernel_size, 1),
584
+ (stride, 1),
585
+ padding=(get_padding(kernel_size, 1), 0),
586
+ )
587
+ ),
588
+ norm_f(
589
+ Conv2d(
590
+ 1024,
591
+ 1024,
592
+ (kernel_size, 1),
593
+ 1,
594
+ padding=(get_padding(kernel_size, 1), 0),
595
+ )
596
+ ),
597
+ ]
598
+ )
599
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
600
+
601
+ def forward(self, x):
602
+ fmap = []
603
+
604
+ # 1d to 2d
605
+ b, c, t = x.shape
606
+ if t % self.period != 0: # pad first
607
+ n_pad = self.period - (t % self.period)
608
+ x = F.pad(x, (0, n_pad), "reflect")
609
+ t = t + n_pad
610
+ x = x.view(b, c, t // self.period, self.period)
611
+
612
+ for layer in self.convs:
613
+ x = layer(x)
614
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
615
+ fmap.append(x)
616
+ x = self.conv_post(x)
617
+ fmap.append(x)
618
+ x = torch.flatten(x, 1, -1)
619
+
620
+ return x, fmap
621
+
622
+
623
+ class DiscriminatorS(torch.nn.Module):
624
+ def __init__(self, use_spectral_norm=False):
625
+ super(DiscriminatorS, self).__init__()
626
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
627
+ self.convs = nn.ModuleList(
628
+ [
629
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
630
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
631
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
632
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
633
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
634
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
635
+ ]
636
+ )
637
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
638
+
639
+ def forward(self, x):
640
+ fmap = []
641
+
642
+ for layer in self.convs:
643
+ x = layer(x)
644
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
645
+ fmap.append(x)
646
+ x = self.conv_post(x)
647
+ fmap.append(x)
648
+ x = torch.flatten(x, 1, -1)
649
+
650
+ return x, fmap
651
+
652
+
653
+ class MultiPeriodDiscriminator(torch.nn.Module):
654
+ def __init__(self, use_spectral_norm=False):
655
+ super(MultiPeriodDiscriminator, self).__init__()
656
+ periods = [2, 3, 5, 7, 11]
657
+
658
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
659
+ discs = discs + [
660
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
661
+ ]
662
+ self.discriminators = nn.ModuleList(discs)
663
+
664
+ def forward(self, y, y_hat):
665
+ y_d_rs = []
666
+ y_d_gs = []
667
+ fmap_rs = []
668
+ fmap_gs = []
669
+ for i, d in enumerate(self.discriminators):
670
+ y_d_r, fmap_r = d(y)
671
+ y_d_g, fmap_g = d(y_hat)
672
+ y_d_rs.append(y_d_r)
673
+ y_d_gs.append(y_d_g)
674
+ fmap_rs.append(fmap_r)
675
+ fmap_gs.append(fmap_g)
676
+
677
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
678
+
679
+
680
+ class ReferenceEncoder(nn.Module):
681
+ """
682
+ inputs --- [N, Ty/r, n_mels*r] mels
683
+ outputs --- [N, ref_enc_gru_size]
684
+ """
685
+
686
+ def __init__(self, spec_channels, gin_channels=0):
687
+ super().__init__()
688
+ self.spec_channels = spec_channels
689
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
690
+ K = len(ref_enc_filters)
691
+ filters = [1] + ref_enc_filters
692
+ convs = [
693
+ weight_norm(
694
+ nn.Conv2d(
695
+ in_channels=filters[i],
696
+ out_channels=filters[i + 1],
697
+ kernel_size=(3, 3),
698
+ stride=(2, 2),
699
+ padding=(1, 1),
700
+ )
701
+ )
702
+ for i in range(K)
703
+ ]
704
+ self.convs = nn.ModuleList(convs)
705
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
706
+
707
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
708
+ self.gru = nn.GRU(
709
+ input_size=ref_enc_filters[-1] * out_channels,
710
+ hidden_size=256 // 2,
711
+ batch_first=True,
712
+ )
713
+ self.proj = nn.Linear(128, gin_channels)
714
+
715
+ def forward(self, inputs, mask=None):
716
+ N = inputs.size(0)
717
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
718
+ for conv in self.convs:
719
+ out = conv(out)
720
+ # out = wn(out)
721
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
722
+
723
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
724
+ T = out.size(1)
725
+ N = out.size(0)
726
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
727
+
728
+ self.gru.flatten_parameters()
729
+ memory, out = self.gru(out) # out --- [1, N, 128]
730
+
731
+ return self.proj(out.squeeze(0))
732
+
733
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
734
+ for i in range(n_convs):
735
+ L = (L - kernel_size + 2 * pad) // stride + 1
736
+ return L
737
+
738
+
739
+ class SynthesizerTrn(nn.Module):
740
+ """
741
+ Synthesizer for Training
742
+ """
743
+
744
+ def __init__(
745
+ self,
746
+ n_vocab,
747
+ spec_channels,
748
+ segment_size,
749
+ inter_channels,
750
+ hidden_channels,
751
+ filter_channels,
752
+ n_heads,
753
+ n_layers,
754
+ kernel_size,
755
+ p_dropout,
756
+ resblock,
757
+ resblock_kernel_sizes,
758
+ resblock_dilation_sizes,
759
+ upsample_rates,
760
+ upsample_initial_channel,
761
+ upsample_kernel_sizes,
762
+ n_speakers=256,
763
+ gin_channels=256,
764
+ use_sdp=True,
765
+ n_flow_layer=4,
766
+ n_layers_trans_flow=6,
767
+ flow_share_parameter=False,
768
+ use_transformer_flow=True,
769
+ **kwargs
770
+ ):
771
+ super().__init__()
772
+ self.n_vocab = n_vocab
773
+ self.spec_channels = spec_channels
774
+ self.inter_channels = inter_channels
775
+ self.hidden_channels = hidden_channels
776
+ self.filter_channels = filter_channels
777
+ self.n_heads = n_heads
778
+ self.n_layers = n_layers
779
+ self.kernel_size = kernel_size
780
+ self.p_dropout = p_dropout
781
+ self.resblock = resblock
782
+ self.resblock_kernel_sizes = resblock_kernel_sizes
783
+ self.resblock_dilation_sizes = resblock_dilation_sizes
784
+ self.upsample_rates = upsample_rates
785
+ self.upsample_initial_channel = upsample_initial_channel
786
+ self.upsample_kernel_sizes = upsample_kernel_sizes
787
+ self.segment_size = segment_size
788
+ self.n_speakers = n_speakers
789
+ self.gin_channels = gin_channels
790
+ self.n_layers_trans_flow = n_layers_trans_flow
791
+ self.use_spk_conditioned_encoder = kwargs.get(
792
+ "use_spk_conditioned_encoder", True
793
+ )
794
+ self.use_sdp = use_sdp
795
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
796
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
797
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
798
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
799
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
800
+ self.enc_gin_channels = gin_channels
801
+ self.enc_p = TextEncoder(
802
+ n_vocab,
803
+ inter_channels,
804
+ hidden_channels,
805
+ filter_channels,
806
+ n_heads,
807
+ n_layers,
808
+ kernel_size,
809
+ p_dropout,
810
+ gin_channels=self.enc_gin_channels,
811
+ )
812
+ self.dec = Generator(
813
+ inter_channels,
814
+ resblock,
815
+ resblock_kernel_sizes,
816
+ resblock_dilation_sizes,
817
+ upsample_rates,
818
+ upsample_initial_channel,
819
+ upsample_kernel_sizes,
820
+ gin_channels=gin_channels,
821
+ )
822
+ self.enc_q = PosteriorEncoder(
823
+ spec_channels,
824
+ inter_channels,
825
+ hidden_channels,
826
+ 5,
827
+ 1,
828
+ 16,
829
+ gin_channels=gin_channels,
830
+ )
831
+ if use_transformer_flow:
832
+ self.flow = TransformerCouplingBlock(
833
+ inter_channels,
834
+ hidden_channels,
835
+ filter_channels,
836
+ n_heads,
837
+ n_layers_trans_flow,
838
+ 5,
839
+ p_dropout,
840
+ n_flow_layer,
841
+ gin_channels=gin_channels,
842
+ share_parameter=flow_share_parameter,
843
+ )
844
+ else:
845
+ self.flow = ResidualCouplingBlock(
846
+ inter_channels,
847
+ hidden_channels,
848
+ 5,
849
+ 1,
850
+ n_flow_layer,
851
+ gin_channels=gin_channels,
852
+ )
853
+ self.sdp = StochasticDurationPredictor(
854
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
855
+ )
856
+ self.dp = DurationPredictor(
857
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
858
+ )
859
+
860
+ if n_speakers > 0:
861
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
862
+ else:
863
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
864
+
865
+ def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
866
+ if self.n_speakers > 0:
867
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
868
+ else:
869
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
870
+ x, m_p, logs_p, x_mask = self.enc_p(
871
+ x, x_lengths, tone, language, bert, ja_bert, g=g
872
+ )
873
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
874
+ z_p = self.flow(z, y_mask, g=g)
875
+
876
+ with torch.no_grad():
877
+ # negative cross-entropy
878
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
879
+ neg_cent1 = torch.sum(
880
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
881
+ ) # [b, 1, t_s]
882
+ neg_cent2 = torch.matmul(
883
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
884
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
885
+ neg_cent3 = torch.matmul(
886
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
887
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
888
+ neg_cent4 = torch.sum(
889
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
890
+ ) # [b, 1, t_s]
891
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
892
+ if self.use_noise_scaled_mas:
893
+ epsilon = (
894
+ torch.std(neg_cent)
895
+ * torch.randn_like(neg_cent)
896
+ * self.current_mas_noise_scale
897
+ )
898
+ neg_cent = neg_cent + epsilon
899
+
900
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
901
+ attn = (
902
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
903
+ .unsqueeze(1)
904
+ .detach()
905
+ )
906
+
907
+ w = attn.sum(2)
908
+
909
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
910
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
911
+
912
+ logw_ = torch.log(w + 1e-6) * x_mask
913
+ logw = self.dp(x, x_mask, g=g)
914
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
915
+ x_mask
916
+ ) # for averaging
917
+
918
+ l_length = l_length_dp + l_length_sdp
919
+
920
+ # expand prior
921
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
922
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
923
+
924
+ z_slice, ids_slice = commons.rand_slice_segments(
925
+ z, y_lengths, self.segment_size
926
+ )
927
+ o = self.dec(z_slice, g=g)
928
+ return (
929
+ o,
930
+ l_length,
931
+ attn,
932
+ ids_slice,
933
+ x_mask,
934
+ y_mask,
935
+ (z, z_p, m_p, logs_p, m_q, logs_q),
936
+ (x, logw, logw_),
937
+ )
938
+
939
+ def infer(
940
+ self,
941
+ x,
942
+ x_lengths,
943
+ sid,
944
+ tone,
945
+ language,
946
+ bert,
947
+ ja_bert,
948
+ noise_scale=0.667,
949
+ length_scale=1,
950
+ noise_scale_w=0.8,
951
+ max_len=None,
952
+ sdp_ratio=0,
953
+ y=None,
954
+ ):
955
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
956
+ # g = self.gst(y)
957
+ if self.n_speakers > 0:
958
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
959
+ else:
960
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
961
+ x, m_p, logs_p, x_mask = self.enc_p(
962
+ x, x_lengths, tone, language, bert, ja_bert, g=g
963
+ )
964
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
965
+ sdp_ratio
966
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
967
+ w = torch.exp(logw) * x_mask * length_scale
968
+ w_ceil = torch.ceil(w)
969
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
970
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
971
+ x_mask.dtype
972
+ )
973
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
974
+ attn = commons.generate_path(w_ceil, attn_mask)
975
+
976
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
977
+ 1, 2
978
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
979
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
980
+ 1, 2
981
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
982
+
983
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
984
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
985
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
986
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
oldVersion/V110/text/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .symbols import *
2
+
3
+
4
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
5
+
6
+
7
+ def cleaned_text_to_sequence(cleaned_text, tones, language):
8
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
9
+ Args:
10
+ text: string to convert to a sequence
11
+ Returns:
12
+ List of integers corresponding to the symbols in the text
13
+ """
14
+ phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
15
+ tone_start = language_tone_start_map[language]
16
+ tones = [i + tone_start for i in tones]
17
+ lang_id = language_id_map[language]
18
+ lang_ids = [lang_id for i in phones]
19
+ return phones, tones, lang_ids
20
+
21
+
22
+ def get_bert(norm_text, word2ph, language, device):
23
+ from .chinese_bert import get_bert_feature as zh_bert
24
+ from .english_bert_mock import get_bert_feature as en_bert
25
+ from .japanese_bert import get_bert_feature as jp_bert
26
+
27
+ lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
28
+ bert = lang_bert_func_map[language](norm_text, word2ph, device)
29
+ return bert
oldVersion/V110/text/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.57 kB). View file
 
oldVersion/V110/text/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.58 kB). View file
 
oldVersion/V110/text/__pycache__/chinese.cpython-310.pyc ADDED
Binary file (4.61 kB). View file
 
oldVersion/V110/text/__pycache__/chinese.cpython-38.pyc ADDED
Binary file (4.53 kB). View file
 
oldVersion/V110/text/__pycache__/cleaner.cpython-310.pyc ADDED
Binary file (973 Bytes). View file
 
oldVersion/V110/text/__pycache__/cleaner.cpython-38.pyc ADDED
Binary file (963 Bytes). View file
 
oldVersion/V110/text/__pycache__/japanese.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
oldVersion/V110/text/__pycache__/japanese.cpython-38.pyc ADDED
Binary file (14 kB). View file
 
oldVersion/V110/text/__pycache__/symbols.cpython-310.pyc ADDED
Binary file (1.5 kB). View file
 
oldVersion/V110/text/__pycache__/symbols.cpython-38.pyc ADDED
Binary file (1.85 kB). View file
 
oldVersion/V110/text/__pycache__/tone_sandhi.cpython-310.pyc ADDED
Binary file (13.4 kB). View file
 
oldVersion/V110/text/__pycache__/tone_sandhi.cpython-38.pyc ADDED
Binary file (15.6 kB). View file
 
oldVersion/V110/text/chinese.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import cn2an
5
+ from pypinyin import lazy_pinyin, Style
6
+
7
+ from .symbols import punctuation
8
+ from .tone_sandhi import ToneSandhi
9
+
10
+ current_file_path = os.path.dirname(__file__)
11
+ pinyin_to_symbol_map = {
12
+ line.split("\t")[0]: line.strip().split("\t")[1]
13
+ for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
14
+ }
15
+
16
+ import jieba.posseg as psg
17
+
18
+
19
+ rep_map = {
20
+ ":": ",",
21
+ ";": ",",
22
+ ",": ",",
23
+ "。": ".",
24
+ "!": "!",
25
+ "?": "?",
26
+ "\n": ".",
27
+ "·": ",",
28
+ "、": ",",
29
+ "...": "…",
30
+ "$": ".",
31
+ "“": "'",
32
+ "”": "'",
33
+ "‘": "'",
34
+ "’": "'",
35
+ "(": "'",
36
+ ")": "'",
37
+ "(": "'",
38
+ ")": "'",
39
+ "《": "'",
40
+ "》": "'",
41
+ "【": "'",
42
+ "】": "'",
43
+ "[": "'",
44
+ "]": "'",
45
+ "—": "-",
46
+ "~": "-",
47
+ "~": "-",
48
+ "「": "'",
49
+ "」": "'",
50
+ }
51
+
52
+ tone_modifier = ToneSandhi()
53
+
54
+
55
+ def replace_punctuation(text):
56
+ text = text.replace("嗯", "恩").replace("呣", "母")
57
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
58
+
59
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
60
+
61
+ replaced_text = re.sub(
62
+ r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
63
+ )
64
+
65
+ return replaced_text
66
+
67
+
68
+ def g2p(text):
69
+ pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
70
+ sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
71
+ phones, tones, word2ph = _g2p(sentences)
72
+ assert sum(word2ph) == len(phones)
73
+ assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch.
74
+ phones = ["_"] + phones + ["_"]
75
+ tones = [0] + tones + [0]
76
+ word2ph = [1] + word2ph + [1]
77
+ return phones, tones, word2ph
78
+
79
+
80
+ def _get_initials_finals(word):
81
+ initials = []
82
+ finals = []
83
+ orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
84
+ orig_finals = lazy_pinyin(
85
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
86
+ )
87
+ for c, v in zip(orig_initials, orig_finals):
88
+ initials.append(c)
89
+ finals.append(v)
90
+ return initials, finals
91
+
92
+
93
+ def _g2p(segments):
94
+ phones_list = []
95
+ tones_list = []
96
+ word2ph = []
97
+ for seg in segments:
98
+ # Replace all English words in the sentence
99
+ seg = re.sub("[a-zA-Z]+", "", seg)
100
+ seg_cut = psg.lcut(seg)
101
+ initials = []
102
+ finals = []
103
+ seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
104
+ for word, pos in seg_cut:
105
+ if pos == "eng":
106
+ continue
107
+ sub_initials, sub_finals = _get_initials_finals(word)
108
+ sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
109
+ initials.append(sub_initials)
110
+ finals.append(sub_finals)
111
+
112
+ # assert len(sub_initials) == len(sub_finals) == len(word)
113
+ initials = sum(initials, [])
114
+ finals = sum(finals, [])
115
+ #
116
+ for c, v in zip(initials, finals):
117
+ raw_pinyin = c + v
118
+ # NOTE: post process for pypinyin outputs
119
+ # we discriminate i, ii and iii
120
+ if c == v:
121
+ assert c in punctuation
122
+ phone = [c]
123
+ tone = "0"
124
+ word2ph.append(1)
125
+ else:
126
+ v_without_tone = v[:-1]
127
+ tone = v[-1]
128
+
129
+ pinyin = c + v_without_tone
130
+ assert tone in "12345"
131
+
132
+ if c:
133
+ # 多音节
134
+ v_rep_map = {
135
+ "uei": "ui",
136
+ "iou": "iu",
137
+ "uen": "un",
138
+ }
139
+ if v_without_tone in v_rep_map.keys():
140
+ pinyin = c + v_rep_map[v_without_tone]
141
+ else:
142
+ # 单音节
143
+ pinyin_rep_map = {
144
+ "ing": "ying",
145
+ "i": "yi",
146
+ "in": "yin",
147
+ "u": "wu",
148
+ }
149
+ if pinyin in pinyin_rep_map.keys():
150
+ pinyin = pinyin_rep_map[pinyin]
151
+ else:
152
+ single_rep_map = {
153
+ "v": "yu",
154
+ "e": "e",
155
+ "i": "y",
156
+ "u": "w",
157
+ }
158
+ if pinyin[0] in single_rep_map.keys():
159
+ pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
160
+
161
+ assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
162
+ phone = pinyin_to_symbol_map[pinyin].split(" ")
163
+ word2ph.append(len(phone))
164
+
165
+ phones_list += phone
166
+ tones_list += [int(tone)] * len(phone)
167
+ return phones_list, tones_list, word2ph
168
+
169
+
170
+ def text_normalize(text):
171
+ numbers = re.findall(r"\d+(?:\.?\d+)?", text)
172
+ for number in numbers:
173
+ text = text.replace(number, cn2an.an2cn(number), 1)
174
+ text = replace_punctuation(text)
175
+ return text
176
+
177
+
178
+ def get_bert_feature(text, word2ph):
179
+ from text import chinese_bert
180
+
181
+ return chinese_bert.get_bert_feature(text, word2ph)
182
+
183
+
184
+ if __name__ == "__main__":
185
+ from text.chinese_bert import get_bert_feature
186
+
187
+ text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
188
+ text = text_normalize(text)
189
+ print(text)
190
+ phones, tones, word2ph = g2p(text)
191
+ bert = get_bert_feature(text, word2ph)
192
+
193
+ print(phones, tones, word2ph, bert.shape)
194
+
195
+
196
+ # # 示例用法
197
+ # text = "这是一个示例文本:,你好!这是一个测试...."
198
+ # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
oldVersion/V110/text/chinese_bert.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained("./bert/chinese-roberta-wwm-ext-large")
6
+
7
+
8
+ def get_bert_feature(text, word2ph, device=None):
9
+ if (
10
+ sys.platform == "darwin"
11
+ and torch.backends.mps.is_available()
12
+ and device == "cpu"
13
+ ):
14
+ device = "mps"
15
+ if not device:
16
+ device = "cuda"
17
+ model = AutoModelForMaskedLM.from_pretrained(
18
+ "./bert/chinese-roberta-wwm-ext-large"
19
+ ).to(device)
20
+ with torch.no_grad():
21
+ inputs = tokenizer(text, return_tensors="pt")
22
+ for i in inputs:
23
+ inputs[i] = inputs[i].to(device)
24
+ res = model(**inputs, output_hidden_states=True)
25
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
26
+
27
+ assert len(word2ph) == len(text) + 2
28
+ word2phone = word2ph
29
+ phone_level_feature = []
30
+ for i in range(len(word2phone)):
31
+ repeat_feature = res[i].repeat(word2phone[i], 1)
32
+ phone_level_feature.append(repeat_feature)
33
+
34
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
35
+
36
+ return phone_level_feature.T
37
+
38
+
39
+ if __name__ == "__main__":
40
+ import torch
41
+
42
+ word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
43
+ word2phone = [
44
+ 1,
45
+ 2,
46
+ 1,
47
+ 2,
48
+ 2,
49
+ 1,
50
+ 2,
51
+ 2,
52
+ 1,
53
+ 2,
54
+ 2,
55
+ 1,
56
+ 2,
57
+ 2,
58
+ 2,
59
+ 2,
60
+ 2,
61
+ 1,
62
+ 1,
63
+ 2,
64
+ 2,
65
+ 1,
66
+ 2,
67
+ 2,
68
+ 2,
69
+ 2,
70
+ 1,
71
+ 2,
72
+ 2,
73
+ 2,
74
+ 2,
75
+ 2,
76
+ 1,
77
+ 2,
78
+ 2,
79
+ 2,
80
+ 2,
81
+ 1,
82
+ ]
83
+
84
+ # 计算总帧数
85
+ total_frames = sum(word2phone)
86
+ print(word_level_feature.shape)
87
+ print(word2phone)
88
+ phone_level_feature = []
89
+ for i in range(len(word2phone)):
90
+ print(word_level_feature[i].shape)
91
+
92
+ # 对每个词重复word2phone[i]次
93
+ repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
94
+ phone_level_feature.append(repeat_feature)
95
+
96
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
97
+ print(phone_level_feature.shape) # torch.Size([36, 1024])
oldVersion/V110/text/cleaner.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import chinese, japanese, cleaned_text_to_sequence
2
+
3
+
4
+ language_module_map = {"ZH": chinese, "JP": japanese}
5
+
6
+
7
+ def clean_text(text, language):
8
+ language_module = language_module_map[language]
9
+ norm_text = language_module.text_normalize(text)
10
+ phones, tones, word2ph = language_module.g2p(norm_text)
11
+ return norm_text, phones, tones, word2ph
12
+
13
+
14
+ def clean_text_bert(text, language):
15
+ language_module = language_module_map[language]
16
+ norm_text = language_module.text_normalize(text)
17
+ phones, tones, word2ph = language_module.g2p(norm_text)
18
+ bert = language_module.get_bert_feature(norm_text, word2ph)
19
+ return phones, tones, bert
20
+
21
+
22
+ def text_to_sequence(text, language):
23
+ norm_text, phones, tones, word2ph = clean_text(text, language)
24
+ return cleaned_text_to_sequence(phones, tones, language)
25
+
26
+
27
+ if __name__ == "__main__":
28
+ pass
oldVersion/V110/text/english.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import os
3
+ import re
4
+ from g2p_en import G2p
5
+
6
+ from . import symbols
7
+
8
+ current_file_path = os.path.dirname(__file__)
9
+ CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
10
+ CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
11
+ _g2p = G2p()
12
+
13
+ arpa = {
14
+ "AH0",
15
+ "S",
16
+ "AH1",
17
+ "EY2",
18
+ "AE2",
19
+ "EH0",
20
+ "OW2",
21
+ "UH0",
22
+ "NG",
23
+ "B",
24
+ "G",
25
+ "AY0",
26
+ "M",
27
+ "AA0",
28
+ "F",
29
+ "AO0",
30
+ "ER2",
31
+ "UH1",
32
+ "IY1",
33
+ "AH2",
34
+ "DH",
35
+ "IY0",
36
+ "EY1",
37
+ "IH0",
38
+ "K",
39
+ "N",
40
+ "W",
41
+ "IY2",
42
+ "T",
43
+ "AA1",
44
+ "ER1",
45
+ "EH2",
46
+ "OY0",
47
+ "UH2",
48
+ "UW1",
49
+ "Z",
50
+ "AW2",
51
+ "AW1",
52
+ "V",
53
+ "UW2",
54
+ "AA2",
55
+ "ER",
56
+ "AW0",
57
+ "UW0",
58
+ "R",
59
+ "OW1",
60
+ "EH1",
61
+ "ZH",
62
+ "AE0",
63
+ "IH2",
64
+ "IH",
65
+ "Y",
66
+ "JH",
67
+ "P",
68
+ "AY1",
69
+ "EY0",
70
+ "OY2",
71
+ "TH",
72
+ "HH",
73
+ "D",
74
+ "ER0",
75
+ "CH",
76
+ "AO1",
77
+ "AE1",
78
+ "AO2",
79
+ "OY1",
80
+ "AY2",
81
+ "IH1",
82
+ "OW0",
83
+ "L",
84
+ "SH",
85
+ }
86
+
87
+
88
+ def post_replace_ph(ph):
89
+ rep_map = {
90
+ ":": ",",
91
+ ";": ",",
92
+ ",": ",",
93
+ "。": ".",
94
+ "!": "!",
95
+ "?": "?",
96
+ "\n": ".",
97
+ "·": ",",
98
+ "、": ",",
99
+ "...": "…",
100
+ "v": "V",
101
+ }
102
+ if ph in rep_map.keys():
103
+ ph = rep_map[ph]
104
+ if ph in symbols:
105
+ return ph
106
+ if ph not in symbols:
107
+ ph = "UNK"
108
+ return ph
109
+
110
+
111
+ def read_dict():
112
+ g2p_dict = {}
113
+ start_line = 49
114
+ with open(CMU_DICT_PATH) as f:
115
+ line = f.readline()
116
+ line_index = 1
117
+ while line:
118
+ if line_index >= start_line:
119
+ line = line.strip()
120
+ word_split = line.split(" ")
121
+ word = word_split[0]
122
+
123
+ syllable_split = word_split[1].split(" - ")
124
+ g2p_dict[word] = []
125
+ for syllable in syllable_split:
126
+ phone_split = syllable.split(" ")
127
+ g2p_dict[word].append(phone_split)
128
+
129
+ line_index = line_index + 1
130
+ line = f.readline()
131
+
132
+ return g2p_dict
133
+
134
+
135
+ def cache_dict(g2p_dict, file_path):
136
+ with open(file_path, "wb") as pickle_file:
137
+ pickle.dump(g2p_dict, pickle_file)
138
+
139
+
140
+ def get_dict():
141
+ if os.path.exists(CACHE_PATH):
142
+ with open(CACHE_PATH, "rb") as pickle_file:
143
+ g2p_dict = pickle.load(pickle_file)
144
+ else:
145
+ g2p_dict = read_dict()
146
+ cache_dict(g2p_dict, CACHE_PATH)
147
+
148
+ return g2p_dict
149
+
150
+
151
+ eng_dict = get_dict()
152
+
153
+
154
+ def refine_ph(phn):
155
+ tone = 0
156
+ if re.search(r"\d$", phn):
157
+ tone = int(phn[-1]) + 1
158
+ phn = phn[:-1]
159
+ return phn.lower(), tone
160
+
161
+
162
+ def refine_syllables(syllables):
163
+ tones = []
164
+ phonemes = []
165
+ for phn_list in syllables:
166
+ for i in range(len(phn_list)):
167
+ phn = phn_list[i]
168
+ phn, tone = refine_ph(phn)
169
+ phonemes.append(phn)
170
+ tones.append(tone)
171
+ return phonemes, tones
172
+
173
+
174
+ def text_normalize(text):
175
+ # todo: eng text normalize
176
+ return text
177
+
178
+
179
+ def g2p(text):
180
+ phones = []
181
+ tones = []
182
+ words = re.split(r"([,;.\-\?\!\s+])", text)
183
+ for w in words:
184
+ if w.upper() in eng_dict:
185
+ phns, tns = refine_syllables(eng_dict[w.upper()])
186
+ phones += phns
187
+ tones += tns
188
+ else:
189
+ phone_list = list(filter(lambda p: p != " ", _g2p(w)))
190
+ for ph in phone_list:
191
+ if ph in arpa:
192
+ ph, tn = refine_ph(ph)
193
+ phones.append(ph)
194
+ tones.append(tn)
195
+ else:
196
+ phones.append(ph)
197
+ tones.append(0)
198
+ # todo: implement word2ph
199
+ word2ph = [1 for i in phones]
200
+
201
+ phones = [post_replace_ph(i) for i in phones]
202
+ return phones, tones, word2ph
203
+
204
+
205
+ if __name__ == "__main__":
206
+ # print(get_dict())
207
+ # print(eng_word_to_phoneme("hello"))
208
+ print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))
209
+ # all_phones = set()
210
+ # for k, syllables in eng_dict.items():
211
+ # for group in syllables:
212
+ # for ph in group:
213
+ # all_phones.add(ph)
214
+ # print(all_phones)
oldVersion/V110/text/english_bert_mock.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_bert_feature(norm_text, word2ph):
5
+ return torch.zeros(1024, sum(word2ph))