XzJosh commited on
Commit
bf0dde6
1 Parent(s): 047257a

Upload 72 files

Browse files
.gitattributes CHANGED
@@ -44,3 +44,4 @@ audio/Taffy/t2~1_234.wav filter=lfs diff=lfs merge=lfs -text
44
  audio/Taffy/t2~1_260.wav filter=lfs diff=lfs merge=lfs -text
45
  audio/Taffy/Taffy_242.wav filter=lfs diff=lfs merge=lfs -text
46
  audio/Taffy/Taffy_250.wav filter=lfs diff=lfs merge=lfs -text
 
 
44
  audio/Taffy/t2~1_260.wav filter=lfs diff=lfs merge=lfs -text
45
  audio/Taffy/Taffy_242.wav filter=lfs diff=lfs merge=lfs -text
46
  audio/Taffy/Taffy_250.wav filter=lfs diff=lfs merge=lfs -text
47
+ text/cmudict_cache.pickle filter=lfs diff=lfs merge=lfs -text
AR/data/bucket_sampler.py CHANGED
@@ -41,12 +41,13 @@ class DistributedBucketSampler(Sampler[T_co]):
41
  if num_replicas is None:
42
  if not dist.is_available():
43
  raise RuntimeError("Requires distributed package to be available")
44
- num_replicas = dist.get_world_size()
45
  if rank is None:
46
  if not dist.is_available():
47
  raise RuntimeError("Requires distributed package to be available")
48
- rank = dist.get_rank()
49
- torch.cuda.set_device(rank)
 
50
  if rank >= num_replicas or rank < 0:
51
  raise ValueError(
52
  "Invalid rank {}, rank should be in the interval"
 
41
  if num_replicas is None:
42
  if not dist.is_available():
43
  raise RuntimeError("Requires distributed package to be available")
44
+ num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
45
  if rank is None:
46
  if not dist.is_available():
47
  raise RuntimeError("Requires distributed package to be available")
48
+ rank = dist.get_rank() if torch.cuda.is_available() else 0
49
+ if torch.cuda.is_available():
50
+ torch.cuda.set_device(rank)
51
  if rank >= num_replicas or rank < 0:
52
  raise ValueError(
53
  "Invalid rank {}, rank should be in the interval"
AR/models/t2s_lightning_module_onnx.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
2
+ import os, sys
3
+
4
+ now_dir = os.getcwd()
5
+ sys.path.append(now_dir)
6
+ from typing import Dict
7
+
8
+ import torch
9
+ from pytorch_lightning import LightningModule
10
+ from AR.models.t2s_model_onnx import Text2SemanticDecoder
11
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
12
+ from AR.modules.optim import ScaledAdam
13
+
14
+
15
+ class Text2SemanticLightningModule(LightningModule):
16
+ def __init__(self, config, output_dir, is_train=True):
17
+ super().__init__()
18
+ self.config = config
19
+ self.top_k = 3
20
+ self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
21
+ pretrained_s1 = config.get("pretrained_s1")
22
+ if pretrained_s1 and is_train:
23
+ # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
24
+ print(
25
+ self.load_state_dict(
26
+ torch.load(pretrained_s1, map_location="cpu")["weight"]
27
+ )
28
+ )
29
+ if is_train:
30
+ self.automatic_optimization = False
31
+ self.save_hyperparameters()
32
+ self.eval_dir = output_dir / "eval"
33
+ self.eval_dir.mkdir(parents=True, exist_ok=True)
34
+
35
+ def training_step(self, batch: Dict, batch_idx: int):
36
+ opt = self.optimizers()
37
+ scheduler = self.lr_schedulers()
38
+ loss, acc = self.model.forward(
39
+ batch["phoneme_ids"],
40
+ batch["phoneme_ids_len"],
41
+ batch["semantic_ids"],
42
+ batch["semantic_ids_len"],
43
+ batch["bert_feature"],
44
+ )
45
+ self.manual_backward(loss)
46
+ if batch_idx > 0 and batch_idx % 4 == 0:
47
+ opt.step()
48
+ opt.zero_grad()
49
+ scheduler.step()
50
+
51
+ self.log(
52
+ "total_loss",
53
+ loss,
54
+ on_step=True,
55
+ on_epoch=True,
56
+ prog_bar=True,
57
+ sync_dist=True,
58
+ )
59
+ self.log(
60
+ "lr",
61
+ scheduler.get_last_lr()[0],
62
+ on_epoch=True,
63
+ prog_bar=True,
64
+ sync_dist=True,
65
+ )
66
+ self.log(
67
+ f"top_{self.top_k}_acc",
68
+ acc,
69
+ on_step=True,
70
+ on_epoch=True,
71
+ prog_bar=True,
72
+ sync_dist=True,
73
+ )
74
+
75
+ def validation_step(self, batch: Dict, batch_idx: int):
76
+ return
77
+
78
+ def configure_optimizers(self):
79
+ model_parameters = self.model.parameters()
80
+ parameters_names = []
81
+ parameters_names.append(
82
+ [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
83
+ )
84
+ lm_opt = ScaledAdam(
85
+ model_parameters,
86
+ lr=0.01,
87
+ betas=(0.9, 0.95),
88
+ clipping_scale=2.0,
89
+ parameters_names=parameters_names,
90
+ show_dominant_parameters=False,
91
+ clipping_update_period=1000,
92
+ )
93
+
94
+ return {
95
+ "optimizer": lm_opt,
96
+ "lr_scheduler": {
97
+ "scheduler": WarmupCosineLRSchedule(
98
+ lm_opt,
99
+ init_lr=self.config["optimizer"]["lr_init"],
100
+ peak_lr=self.config["optimizer"]["lr"],
101
+ end_lr=self.config["optimizer"]["lr_end"],
102
+ warmup_steps=self.config["optimizer"]["warmup_steps"],
103
+ total_steps=self.config["optimizer"]["decay_steps"],
104
+ )
105
+ },
106
+ }
AR/models/t2s_model.py CHANGED
@@ -302,6 +302,8 @@ class Text2SemanticDecoder(nn.Module):
302
  xy_dec[:, -1]
303
  ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的
304
  # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
 
 
305
  samples = sample(
306
  logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
307
  )[0].unsqueeze(0)
 
302
  xy_dec[:, -1]
303
  ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的
304
  # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
305
+ if(idx==0):###第一次跑不能EOS否则没有了
306
+ logits = logits[:, :-1] ###刨除1024终止符号的概率
307
  samples = sample(
308
  logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
309
  )[0].unsqueeze(0)
AR/models/t2s_model_onnx.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from AR.modules.embedding_onnx import SinePositionalEmbedding
6
+ from AR.modules.embedding_onnx import TokenEmbedding
7
+ from AR.modules.transformer_onnx import LayerNorm
8
+ from AR.modules.transformer_onnx import TransformerEncoder
9
+ from AR.modules.transformer_onnx import TransformerEncoderLayer
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from torchmetrics.classification import MulticlassAccuracy
13
+
14
+ default_config = {
15
+ "embedding_dim": 512,
16
+ "hidden_dim": 512,
17
+ "num_head": 8,
18
+ "num_layers": 12,
19
+ "num_codebook": 8,
20
+ "p_dropout": 0.0,
21
+ "vocab_size": 1024 + 1,
22
+ "phoneme_vocab_size": 512,
23
+ "EOS": 1024,
24
+ }
25
+
26
+ inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
27
+
28
+ def logits_to_probs(
29
+ logits,
30
+ previous_tokens = None,
31
+ temperature: float = 1.0,
32
+ top_k = None,
33
+ top_p = None,
34
+ repetition_penalty: float = 1.0,
35
+ ):
36
+ previous_tokens = previous_tokens.squeeze()
37
+ if previous_tokens is not None and repetition_penalty != 1.0:
38
+ previous_tokens = previous_tokens.long()
39
+ score = torch.gather(logits, dim=0, index=previous_tokens)
40
+ score = torch.where(
41
+ score < 0, score * repetition_penalty, score / repetition_penalty
42
+ )
43
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
44
+
45
+ if top_p is not None and top_p < 1.0:
46
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
47
+ cum_probs = torch.cumsum(
48
+ torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
49
+ )
50
+ sorted_indices_to_remove = cum_probs > top_p
51
+ sorted_indices_to_remove[0] = False # keep at least one option
52
+ indices_to_remove = sorted_indices_to_remove.scatter(
53
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
54
+ )
55
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
56
+
57
+ logits = logits / max(temperature, 1e-5)
58
+
59
+ if top_k is not None:
60
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
61
+ pivot = v.select(-1, -1).unsqueeze(-1)
62
+ logits = torch.where(logits < pivot, inf_tensor_value, logits)
63
+
64
+ probs = torch.nn.functional.softmax(logits, dim=-1)
65
+ return probs
66
+
67
+
68
+ def multinomial_sample_one_no_sync(
69
+ probs_sort
70
+ ): # Does multinomial sampling without a cuda synchronization
71
+ q = torch.randn_like(probs_sort)
72
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
73
+
74
+
75
+ def sample(
76
+ logits,
77
+ previous_tokens,
78
+ **sampling_kwargs,
79
+ ):
80
+ probs = logits_to_probs(
81
+ logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
82
+ )
83
+ idx_next = multinomial_sample_one_no_sync(probs)
84
+ return idx_next, probs
85
+
86
+
87
+ class OnnxEncoder(nn.Module):
88
+ def __init__(self, ar_text_embedding, bert_proj, ar_text_position):
89
+ super().__init__()
90
+ self.ar_text_embedding = ar_text_embedding
91
+ self.bert_proj = bert_proj
92
+ self.ar_text_position = ar_text_position
93
+
94
+ def forward(self, x, bert_feature):
95
+ x = self.ar_text_embedding(x)
96
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
97
+ return self.ar_text_position(x)
98
+
99
+
100
+ class T2SFirstStageDecoder(nn.Module):
101
+ def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
102
+ top_k, early_stop_num, num_layers):
103
+ super().__init__()
104
+ self.ar_audio_embedding = ar_audio_embedding
105
+ self.ar_audio_position = ar_audio_position
106
+ self.h = h
107
+ self.ar_predict_layer = ar_predict_layer
108
+ self.loss_fct = loss_fct
109
+ self.ar_accuracy_metric = ar_accuracy_metric
110
+ self.top_k = top_k
111
+ self.early_stop_num = early_stop_num
112
+ self.num_layers = num_layers
113
+
114
+ def forward(self, x, prompt):
115
+ y = prompt
116
+ x_example = x[:,:,0] * 0.0
117
+ #N, 1, 512
118
+ cache = {
119
+ "all_stage": self.num_layers,
120
+ "k": None,
121
+ "v": None,
122
+ "y_emb": None,
123
+ "first_infer": 1,
124
+ "stage": 0,
125
+ }
126
+
127
+ y_emb = self.ar_audio_embedding(y)
128
+
129
+ cache["y_emb"] = y_emb
130
+ y_pos = self.ar_audio_position(y_emb)
131
+
132
+ xy_pos = torch.concat([x, y_pos], dim=1)
133
+
134
+ y_example = y_pos[:,:,0] * 0.0
135
+ x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
136
+ y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
137
+ y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
138
+ torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
139
+ )
140
+ y_attn_mask = y_attn_mask > 0
141
+
142
+ x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
143
+ y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
144
+ x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
145
+ y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
146
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
147
+ cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
148
+ .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
149
+ cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
150
+ .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
151
+
152
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
153
+ logits = self.ar_predict_layer(xy_dec[:, -1])
154
+ samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
155
+
156
+ y = torch.concat([y, samples], dim=1)
157
+
158
+ return y, cache["k"], cache["v"], cache["y_emb"], x_example
159
+
160
+
161
+ class T2SStageDecoder(nn.Module):
162
+ def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
163
+ top_k, early_stop_num, num_layers):
164
+ super().__init__()
165
+ self.ar_audio_embedding = ar_audio_embedding
166
+ self.ar_audio_position = ar_audio_position
167
+ self.h = h
168
+ self.ar_predict_layer = ar_predict_layer
169
+ self.loss_fct = loss_fct
170
+ self.ar_accuracy_metric = ar_accuracy_metric
171
+ self.top_k = top_k
172
+ self.early_stop_num = early_stop_num
173
+ self.num_layers = num_layers
174
+
175
+ def forward(self, y, k, v, y_emb, x_example):
176
+ cache = {
177
+ "all_stage": self.num_layers,
178
+ "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
179
+ "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
180
+ "y_emb": y_emb,
181
+ "first_infer": 0,
182
+ "stage": 0,
183
+ }
184
+
185
+ y_emb = torch.cat(
186
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
187
+ )
188
+ cache["y_emb"] = y_emb
189
+ y_pos = self.ar_audio_position(y_emb)
190
+
191
+ xy_pos = y_pos[:, -1:]
192
+
193
+ y_example = y_pos[:,:,0] * 0.0
194
+
195
+ xy_attn_mask = torch.cat([x_example, y_example], dim=1)
196
+ xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
197
+
198
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
199
+ logits = self.ar_predict_layer(xy_dec[:, -1])
200
+ samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
201
+
202
+ y = torch.concat([y, samples], dim=1)
203
+
204
+ return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
205
+
206
+
207
+ class Text2SemanticDecoder(nn.Module):
208
+ def __init__(self, config, norm_first=False, top_k=3):
209
+ super(Text2SemanticDecoder, self).__init__()
210
+ self.model_dim = config["model"]["hidden_dim"]
211
+ self.embedding_dim = config["model"]["embedding_dim"]
212
+ self.num_head = config["model"]["head"]
213
+ self.num_layers = config["model"]["n_layer"]
214
+ self.norm_first = norm_first
215
+ self.vocab_size = config["model"]["vocab_size"]
216
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
217
+ self.p_dropout = float(config["model"]["dropout"])
218
+ self.EOS = config["model"]["EOS"]
219
+ self.norm_first = norm_first
220
+ assert self.EOS == self.vocab_size - 1
221
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
222
+ self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
223
+ self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
224
+ self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
225
+ self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
226
+ self.h = TransformerEncoder(
227
+ TransformerEncoderLayer(
228
+ d_model=self.model_dim,
229
+ nhead=self.num_head,
230
+ dim_feedforward=self.model_dim * 4,
231
+ dropout=0.1,
232
+ batch_first=True,
233
+ norm_first=norm_first,
234
+ ),
235
+ num_layers=self.num_layers,
236
+ norm=LayerNorm(self.model_dim) if norm_first else None,
237
+ )
238
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
239
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
240
+ self.ar_accuracy_metric = MulticlassAccuracy(
241
+ self.vocab_size,
242
+ top_k=top_k,
243
+ average="micro",
244
+ multidim_average="global",
245
+ ignore_index=self.EOS,
246
+ )
247
+ self.top_k = torch.LongTensor([1])
248
+ self.early_stop_num = torch.LongTensor([-1])
249
+
250
+ def init_onnx(self):
251
+ self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
252
+ self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
253
+ self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
254
+ self.num_layers)
255
+ self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
256
+ self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
257
+ self.num_layers)
258
+
259
+ def forward(self, x, prompts, bert_feature):
260
+ early_stop_num = self.early_stop_num
261
+ prefix_len = prompts.shape[1]
262
+
263
+ x = self.onnx_encoder(x, bert_feature)
264
+ y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)
265
+
266
+ stop = False
267
+ for idx in range(1, 1500):
268
+ enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
269
+ y, k, v, y_emb, stage, logits, samples = enco
270
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
271
+ stop = True
272
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
273
+ stop = True
274
+ if stop:
275
+ break
276
+ y[0, -1] = 0
277
+ return y, idx
278
+
279
+ def infer(self, x, prompts, bert_feature):
280
+ top_k = self.top_k
281
+ early_stop_num = self.early_stop_num
282
+
283
+ x = self.onnx_encoder(x, bert_feature)
284
+
285
+ y = prompts
286
+ prefix_len = y.shape[1]
287
+ x_len = x.shape[1]
288
+ x_example = x[:,:,0] * 0.0
289
+ x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
290
+ x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
291
+
292
+ stop = False
293
+ cache = {
294
+ "all_stage": self.num_layers,
295
+ "k": [None] * self.num_layers,
296
+ "v": [None] * self.num_layers,
297
+ "y_emb": None,
298
+ "first_infer": 1,
299
+ "stage": 0,
300
+ }
301
+ for idx in range(1500):
302
+ if cache["first_infer"] == 1:
303
+ y_emb = self.ar_audio_embedding(y)
304
+ else:
305
+ y_emb = torch.cat(
306
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
307
+ )
308
+ cache["y_emb"] = y_emb
309
+ y_pos = self.ar_audio_position(y_emb)
310
+ if cache["first_infer"] == 1:
311
+ xy_pos = torch.concat([x, y_pos], dim=1)
312
+ else:
313
+ xy_pos = y_pos[:, -1:]
314
+ y_len = y_pos.shape[1]
315
+ if cache["first_infer"] == 1:
316
+ x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
317
+ y_attn_mask = F.pad(
318
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
319
+ (x_len, 0), value=False
320
+ )
321
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
322
+ else:
323
+ xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
324
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
325
+ logits = self.ar_predict_layer(xy_dec[:, -1])
326
+ samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
327
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
328
+ stop = True
329
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
330
+ stop = True
331
+ if stop:
332
+ if prompts.shape[1] == y.shape[1]:
333
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
334
+ break
335
+ y = torch.concat([y, samples], dim=1)
336
+ cache["first_infer"] = 0
337
+ return y, idx
AR/modules/activation_onnx.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear
7
+ from torch.nn import Module
8
+ from torch.nn.init import constant_
9
+ from torch.nn.init import xavier_normal_
10
+ from torch.nn.init import xavier_uniform_
11
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
12
+ from torch.nn.parameter import Parameter
13
+
14
+ from torch.nn import functional as F
15
+ from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
16
+
17
+
18
+ class MultiheadAttention(Module):
19
+ __constants__ = ["batch_first"]
20
+ bias_k: Optional[torch.Tensor]
21
+ bias_v: Optional[torch.Tensor]
22
+
23
+ def __init__(
24
+ self,
25
+ embed_dim,
26
+ num_heads,
27
+ dropout=0.0,
28
+ bias=True,
29
+ add_bias_kv=False,
30
+ add_zero_attn=False,
31
+ kdim=None,
32
+ vdim=None,
33
+ batch_first=False,
34
+ linear1_cls=Linear,
35
+ linear2_cls=Linear,
36
+ device=None,
37
+ dtype=None,
38
+ ) -> None:
39
+ factory_kwargs = {"device": device, "dtype": dtype}
40
+ super(MultiheadAttention, self).__init__()
41
+ self.embed_dim = embed_dim
42
+ self.kdim = kdim if kdim is not None else embed_dim
43
+ self.vdim = vdim if vdim is not None else embed_dim
44
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
45
+
46
+ self.num_heads = num_heads
47
+ self.dropout = dropout
48
+ self.batch_first = batch_first
49
+ self.head_dim = embed_dim // num_heads
50
+ assert (
51
+ self.head_dim * num_heads == self.embed_dim
52
+ ), "embed_dim must be divisible by num_heads"
53
+
54
+ if add_bias_kv:
55
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
56
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
57
+ else:
58
+ self.bias_k = self.bias_v = None
59
+
60
+ if linear1_cls == Linear:
61
+ if not self._qkv_same_embed_dim:
62
+ self.q_proj_weight = Parameter(
63
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
64
+ )
65
+ self.k_proj_weight = Parameter(
66
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
67
+ )
68
+ self.v_proj_weight = Parameter(
69
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
70
+ )
71
+ self.register_parameter("in_proj_weight", None)
72
+ else:
73
+ self.in_proj_weight = Parameter(
74
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
75
+ )
76
+ self.register_parameter("q_proj_weight", None)
77
+ self.register_parameter("k_proj_weight", None)
78
+ self.register_parameter("v_proj_weight", None)
79
+
80
+ if bias:
81
+ self.in_proj_bias = Parameter(
82
+ torch.empty(3 * embed_dim, **factory_kwargs)
83
+ )
84
+ else:
85
+ self.register_parameter("in_proj_bias", None)
86
+ self.out_proj = NonDynamicallyQuantizableLinear(
87
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
88
+ )
89
+
90
+ self._reset_parameters()
91
+ else:
92
+ if not self._qkv_same_embed_dim:
93
+ raise NotImplementedError
94
+ else:
95
+ self.in_proj_linear = linear1_cls(
96
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
97
+ )
98
+ self.in_proj_weight = self.in_proj_linear.weight
99
+
100
+ self.register_parameter("q_proj_weight", None)
101
+ self.register_parameter("k_proj_weight", None)
102
+ self.register_parameter("v_proj_weight", None)
103
+
104
+ if bias:
105
+ self.in_proj_bias = self.in_proj_linear.bias
106
+ else:
107
+ self.register_parameter("in_proj_bias", None)
108
+
109
+ self.out_proj = linear2_cls(
110
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
111
+ )
112
+
113
+ if self.bias_k is not None:
114
+ xavier_normal_(self.bias_k)
115
+ if self.bias_v is not None:
116
+ xavier_normal_(self.bias_v)
117
+
118
+ self.add_zero_attn = add_zero_attn
119
+
120
+ def _reset_parameters(self):
121
+ if self._qkv_same_embed_dim:
122
+ xavier_uniform_(self.in_proj_weight)
123
+ else:
124
+ xavier_uniform_(self.q_proj_weight)
125
+ xavier_uniform_(self.k_proj_weight)
126
+ xavier_uniform_(self.v_proj_weight)
127
+
128
+ if self.in_proj_bias is not None:
129
+ constant_(self.in_proj_bias, 0.0)
130
+ constant_(self.out_proj.bias, 0.0)
131
+
132
+ if self.bias_k is not None:
133
+ xavier_normal_(self.bias_k)
134
+ if self.bias_v is not None:
135
+ xavier_normal_(self.bias_v)
136
+
137
+ def __setstate__(self, state):
138
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
139
+ if "_qkv_same_embed_dim" not in state:
140
+ state["_qkv_same_embed_dim"] = True
141
+
142
+ super(MultiheadAttention, self).__setstate__(state)
143
+
144
+ def forward(
145
+ self,
146
+ query: Tensor,
147
+ key: Tensor,
148
+ value: Tensor,
149
+ key_padding_mask: Optional[Tensor] = None,
150
+ need_weights: bool = True,
151
+ attn_mask: Optional[Tensor] = None,
152
+ average_attn_weights: bool = True,
153
+ cache=None,
154
+ ) -> Tuple[Tensor, Optional[Tensor]]:
155
+ any_nested = query.is_nested or key.is_nested or value.is_nested
156
+ query = key = value = query.transpose(1, 0)
157
+ attn_output = multi_head_attention_forward_patched(
158
+ query,
159
+ key,
160
+ value,
161
+ self.embed_dim,
162
+ self.num_heads,
163
+ self.in_proj_weight,
164
+ self.in_proj_bias,
165
+ self.bias_k,
166
+ self.bias_v,
167
+ self.add_zero_attn,
168
+ self.dropout,
169
+ self.out_proj.weight,
170
+ self.out_proj.bias,
171
+ training=self.training,
172
+ key_padding_mask=key_padding_mask,
173
+ need_weights=need_weights,
174
+ attn_mask=attn_mask,
175
+ average_attn_weights=average_attn_weights,
176
+ cache=cache,
177
+ )
178
+ return attn_output.transpose(1, 0)
AR/modules/embedding_onnx.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class TokenEmbedding(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embedding_dim: int,
12
+ vocab_size: int,
13
+ dropout: float = 0.0,
14
+ ):
15
+ super().__init__()
16
+
17
+ self.vocab_size = vocab_size
18
+ self.embedding_dim = embedding_dim
19
+
20
+ self.dropout = torch.nn.Dropout(p=dropout)
21
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
22
+
23
+ @property
24
+ def weight(self) -> torch.Tensor:
25
+ return self.word_embeddings.weight
26
+
27
+ def embedding(self, index: int) -> torch.Tensor:
28
+ return self.word_embeddings.weight[index : index + 1]
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = self.word_embeddings(x)
32
+ x = self.dropout(x)
33
+ return x
34
+
35
+
36
+ class SinePositionalEmbedding(nn.Module):
37
+ def __init__(
38
+ self,
39
+ embedding_dim: int,
40
+ dropout: float = 0.0,
41
+ scale: bool = False,
42
+ alpha: bool = False,
43
+ ):
44
+ super().__init__()
45
+ self.embedding_dim = embedding_dim
46
+ self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
47
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
48
+ self.dropout = torch.nn.Dropout(p=dropout)
49
+ self.reverse = False
50
+ self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
51
+
52
+ def extend_pe(self, x):
53
+ position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
54
+ scpe = (position * self.div_term).unsqueeze(0)
55
+ pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
56
+ pe = pe.contiguous().view(1, -1, self.embedding_dim)
57
+ return pe
58
+
59
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ pe = self.extend_pe(x)
61
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
62
+ output = output * self.x_scale + self.alpha * pe
63
+ return self.dropout(output)
AR/modules/patched_mha_with_cache_onnx.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _mha_shape_check,
4
+ _canonical_mask,
5
+ _none_or_dtype,
6
+ _in_projection_packed,
7
+ )
8
+
9
+ def multi_head_attention_forward_patched(
10
+ query,
11
+ key,
12
+ value,
13
+ embed_dim_to_check: int,
14
+ num_heads: int,
15
+ in_proj_weight,
16
+ in_proj_bias: Optional[Tensor],
17
+ bias_k: Optional[Tensor],
18
+ bias_v: Optional[Tensor],
19
+ add_zero_attn: bool,
20
+ dropout_p: float,
21
+ out_proj_weight: Tensor,
22
+ out_proj_bias: Optional[Tensor],
23
+ training: bool = True,
24
+ key_padding_mask: Optional[Tensor] = None,
25
+ need_weights: bool = True,
26
+ attn_mask: Optional[Tensor] = None,
27
+ use_separate_proj_weight: bool = False,
28
+ q_proj_weight: Optional[Tensor] = None,
29
+ k_proj_weight: Optional[Tensor] = None,
30
+ v_proj_weight: Optional[Tensor] = None,
31
+ static_k: Optional[Tensor] = None,
32
+ static_v: Optional[Tensor] = None,
33
+ average_attn_weights: bool = True,
34
+ is_causal: bool = False,
35
+ cache=None,
36
+ ) -> Tuple[Tensor, Optional[Tensor]]:
37
+
38
+ # set up shape vars
39
+ _, _, embed_dim = query.shape
40
+ attn_mask = _canonical_mask(
41
+ mask=attn_mask,
42
+ mask_name="attn_mask",
43
+ other_type=None,
44
+ other_name="",
45
+ target_type=query.dtype,
46
+ check_other=False,
47
+ )
48
+ head_dim = embed_dim // num_heads
49
+
50
+ proj_qkv = linear(query, in_proj_weight, in_proj_bias)
51
+ proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
52
+ q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
53
+
54
+ if cache["first_infer"] == 1:
55
+ cache["k"][cache["stage"]] = k
56
+ cache["v"][cache["stage"]] = v
57
+ else:
58
+ cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
59
+ cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
60
+ k = cache["k"][cache["stage"]]
61
+ v = cache["v"][cache["stage"]]
62
+ cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
63
+
64
+ attn_mask = _canonical_mask(
65
+ mask=attn_mask,
66
+ mask_name="attn_mask",
67
+ other_type=None,
68
+ other_name="",
69
+ target_type=q.dtype,
70
+ check_other=False,
71
+ )
72
+ attn_mask = attn_mask.unsqueeze(0)
73
+
74
+ q = q.view(-1, num_heads, head_dim).transpose(0, 1)
75
+ k = k.view(-1, num_heads, head_dim).transpose(0, 1)
76
+ v = v.view(-1, num_heads, head_dim).transpose(0, 1)
77
+
78
+ dropout_p = 0.0
79
+ attn_mask = attn_mask.unsqueeze(0)
80
+ q = q.view(num_heads, -1, head_dim).unsqueeze(0)
81
+ k = k.view(num_heads, -1, head_dim).unsqueeze(0)
82
+ v = v.view(num_heads, -1, head_dim).unsqueeze(0)
83
+ attn_output = scaled_dot_product_attention(
84
+ q, k, v, attn_mask, dropout_p, is_causal
85
+ )
86
+ attn_output = (
87
+ attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
88
+ )
89
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
90
+ attn_output = attn_output.view(-1, 1, attn_output.size(1))
91
+
92
+ return attn_output
AR/modules/transformer_onnx.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import List
8
+ from typing import Optional
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+ from AR.modules.activation_onnx import MultiheadAttention
14
+ from AR.modules.scaling import BalancedDoubleSwish
15
+ from torch import nn
16
+ from torch import Tensor
17
+ from torch.nn import functional as F
18
+
19
+ _shape_t = Union[int, List[int], torch.Size]
20
+
21
+
22
+ class LayerNorm(nn.Module):
23
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
24
+ normalized_shape: Tuple[int, ...]
25
+ eps: float
26
+ elementwise_affine: bool
27
+
28
+ def __init__(
29
+ self,
30
+ normalized_shape: _shape_t,
31
+ eps: float = 1e-5,
32
+ elementwise_affine: bool = True,
33
+ device=None,
34
+ dtype=None,
35
+ ) -> None:
36
+ factory_kwargs = {"device": device, "dtype": dtype}
37
+ super(LayerNorm, self).__init__()
38
+ if isinstance(normalized_shape, numbers.Integral):
39
+ # mypy error: incompatible types in assignment
40
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
41
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
42
+ self.eps = eps
43
+ self.elementwise_affine = elementwise_affine
44
+ if self.elementwise_affine:
45
+ self.weight = nn.Parameter(
46
+ torch.empty(self.normalized_shape, **factory_kwargs)
47
+ )
48
+ self.bias = nn.Parameter(
49
+ torch.empty(self.normalized_shape, **factory_kwargs)
50
+ )
51
+ else:
52
+ self.register_parameter("weight", None)
53
+ self.register_parameter("bias", None)
54
+
55
+ self.reset_parameters()
56
+
57
+ def reset_parameters(self) -> None:
58
+ if self.elementwise_affine:
59
+ nn.init.ones_(self.weight)
60
+ nn.init.zeros_(self.bias)
61
+
62
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
63
+ if isinstance(input, tuple):
64
+ input, embedding = input
65
+ return (
66
+ F.layer_norm(
67
+ input,
68
+ self.normalized_shape,
69
+ self.weight,
70
+ self.bias,
71
+ self.eps,
72
+ ),
73
+ embedding,
74
+ )
75
+
76
+ assert embedding is None
77
+ return F.layer_norm(
78
+ input, self.normalized_shape, self.weight, self.bias, self.eps
79
+ )
80
+
81
+ def extra_repr(self) -> str:
82
+ return (
83
+ "{normalized_shape}, eps={eps}, "
84
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
85
+ )
86
+
87
+
88
+ class IdentityNorm(nn.Module):
89
+ def __init__(
90
+ self,
91
+ d_model: int,
92
+ eps: float = 1e-5,
93
+ device=None,
94
+ dtype=None,
95
+ ) -> None:
96
+ super(IdentityNorm, self).__init__()
97
+
98
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
99
+ if isinstance(input, tuple):
100
+ return input
101
+
102
+ assert embedding is None
103
+ return input
104
+
105
+
106
+ class TransformerEncoder(nn.Module):
107
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
108
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
109
+
110
+ Args:
111
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
112
+ num_layers: the number of sub-encoder-layers in the encoder (required).
113
+ norm: the layer normalization component (optional).
114
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
115
+ (and convert back on output). This will improve the overall performance of
116
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
117
+
118
+ Examples::
119
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
120
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
121
+ >>> src = torch.rand(10, 32, 512)
122
+ >>> out = transformer_encoder(src)
123
+ """
124
+ __constants__ = ["norm"]
125
+
126
+ def __init__(self, encoder_layer, num_layers, norm=None):
127
+ super(TransformerEncoder, self).__init__()
128
+ self.layers = _get_clones(encoder_layer, num_layers)
129
+ self.num_layers = num_layers
130
+ self.norm = norm
131
+
132
+ def forward(
133
+ self,
134
+ src: Tensor,
135
+ mask: Optional[Tensor] = None,
136
+ src_key_padding_mask: Optional[Tensor] = None,
137
+ return_layer_states: bool = False,
138
+ cache=None,
139
+ ) -> Tensor:
140
+ output = src
141
+ for mod in self.layers:
142
+ output = mod(
143
+ output,
144
+ src_mask=mask,
145
+ src_key_padding_mask=src_key_padding_mask,
146
+ cache=cache,
147
+ )
148
+
149
+ if self.norm is not None:
150
+ output = self.norm(output)
151
+
152
+ return output
153
+
154
+
155
+ class TransformerEncoderLayer(nn.Module):
156
+ __constants__ = ["batch_first", "norm_first"]
157
+ def __init__(
158
+ self,
159
+ d_model: int,
160
+ nhead: int,
161
+ dim_feedforward: int = 2048,
162
+ dropout: float = 0.1,
163
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
164
+ batch_first: bool = False,
165
+ norm_first: bool = False,
166
+ device=None,
167
+ dtype=None,
168
+ linear1_self_attention_cls: nn.Module = nn.Linear,
169
+ linear2_self_attention_cls: nn.Module = nn.Linear,
170
+ linear1_feedforward_cls: nn.Module = nn.Linear,
171
+ linear2_feedforward_cls: nn.Module = nn.Linear,
172
+ layer_norm_cls: nn.Module = LayerNorm,
173
+ layer_norm_eps: float = 1e-5,
174
+ adaptive_layer_norm=False,
175
+ ) -> None:
176
+ factory_kwargs = {"device": device, "dtype": dtype}
177
+ super(TransformerEncoderLayer, self).__init__()
178
+ self.self_attn = MultiheadAttention(
179
+ d_model, # 512 16
180
+ nhead,
181
+ dropout=dropout,
182
+ batch_first=batch_first,
183
+ linear1_cls=linear1_self_attention_cls,
184
+ linear2_cls=linear2_self_attention_cls,
185
+ **factory_kwargs,
186
+ )
187
+ self.linear1 = linear1_feedforward_cls(
188
+ d_model, dim_feedforward, **factory_kwargs
189
+ )
190
+ self.dropout = nn.Dropout(dropout)
191
+ self.linear2 = linear2_feedforward_cls(
192
+ dim_feedforward, d_model, **factory_kwargs
193
+ )
194
+ self.norm_first = norm_first
195
+ self.dropout1 = nn.Dropout(dropout)
196
+ self.dropout2 = nn.Dropout(dropout)
197
+ if isinstance(activation, str):
198
+ activation = _get_activation_fn(activation)
199
+ elif isinstance(activation, partial):
200
+ activation = activation(d_model)
201
+ elif activation == BalancedDoubleSwish:
202
+ activation = BalancedDoubleSwish(d_model)
203
+ self.activation = activation
204
+
205
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
206
+ if layer_norm_cls == IdentityNorm:
207
+ norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
208
+ else:
209
+ norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
210
+
211
+ if adaptive_layer_norm:
212
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
213
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
214
+ else:
215
+ self.norm1 = norm1
216
+ self.norm2 = norm2
217
+
218
+ def __setstate__(self, state):
219
+ super(TransformerEncoderLayer, self).__setstate__(state)
220
+ if not hasattr(self, "activation"):
221
+ self.activation = F.relu
222
+
223
+ def forward(
224
+ self,
225
+ src: Tensor,
226
+ src_mask: Optional[Tensor] = None,
227
+ src_key_padding_mask: Optional[Tensor] = None,
228
+ cache=None,
229
+ ) -> Tensor:
230
+ x = src
231
+ stage_embedding = None
232
+ x = self.norm1(
233
+ x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
234
+ stage_embedding,
235
+ )
236
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
237
+
238
+ return x
239
+
240
+ def _sa_block(
241
+ self,
242
+ x: Tensor,
243
+ attn_mask: Optional[Tensor],
244
+ key_padding_mask: Optional[Tensor],
245
+ cache=None,
246
+ ) -> Tensor:
247
+ x = self.self_attn(
248
+ x,
249
+ x,
250
+ x,
251
+ attn_mask=attn_mask,
252
+ key_padding_mask=key_padding_mask,
253
+ need_weights=False,
254
+ cache=cache,
255
+ )
256
+ return self.dropout1(x)
257
+
258
+ def _ff_block(self, x: Tensor) -> Tensor:
259
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
260
+ return self.dropout2(x)
261
+
262
+
263
+ class AdaptiveLayerNorm(nn.Module):
264
+ r"""Adaptive Layer Normalization"""
265
+
266
+ def __init__(self, d_model, norm) -> None:
267
+ super(AdaptiveLayerNorm, self).__init__()
268
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
269
+ self.norm = norm
270
+ self.d_model = d_model
271
+ self.eps = self.norm.eps
272
+
273
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
274
+ if isinstance(input, tuple):
275
+ input, embedding = input
276
+ weight, bias = torch.split(
277
+ self.project_layer(embedding),
278
+ split_size_or_sections=self.d_model,
279
+ dim=-1,
280
+ )
281
+ return (weight * self.norm(input) + bias, embedding)
282
+
283
+ weight, bias = torch.split(
284
+ self.project_layer(embedding),
285
+ split_size_or_sections=self.d_model,
286
+ dim=-1,
287
+ )
288
+ return weight * self.norm(input) + bias
289
+
290
+
291
+ def _get_clones(module, N):
292
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
app.py CHANGED
@@ -1,10 +1,33 @@
1
- import os,re
2
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- gpt_path = os.environ.get(
5
- "gpt_path", "models/Taffy/Taffy-e5.ckpt"
6
- )
7
- sovits_path = os.environ.get("sovits_path", "models/Taffy/Taffy_e20_s1020.pth")
 
 
 
 
 
 
8
  cnhubert_base_path = os.environ.get(
9
  "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
10
  )
@@ -13,6 +36,8 @@ bert_path = os.environ.get(
13
  )
14
  infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
15
  infer_ttswebui = int(infer_ttswebui)
 
 
16
  if "_CUDA_VISIBLE_DEVICES" in os.environ:
17
  os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
18
  is_half = eval(os.environ.get("is_half", "True"))
@@ -22,10 +47,6 @@ import numpy as np
22
  import librosa,torch
23
  from feature_extractor import cnhubert
24
  cnhubert.cnhubert_base_path=cnhubert_base_path
25
- import ssl
26
- ssl._create_default_https_context = ssl._create_unverified_context
27
- import nltk
28
- nltk.download('cmudict')
29
 
30
  from module.models import SynthesizerTrn
31
  from AR.models.t2s_lightning_module import Text2SemanticLightningModule
@@ -34,12 +55,17 @@ from text.cleaner import clean_text
34
  from time import time as ttime
35
  from module.mel_processing import spectrogram_torch
36
  from my_utils import load_audio
 
 
37
 
38
- device = "cuda" if torch.cuda.is_available() else "cpu"
39
 
40
- is_half = eval(
41
- os.environ.get("is_half", "True" if torch.cuda.is_available() else "False")
42
- )
 
 
 
43
 
44
  tokenizer = AutoTokenizer.from_pretrained(bert_path)
45
  bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
@@ -48,13 +74,11 @@ if is_half == True:
48
  else:
49
  bert_model = bert_model.to(device)
50
 
51
-
52
- # bert_model=bert_model.to(device)
53
  def get_bert_feature(text, word2ph):
54
  with torch.no_grad():
55
  inputs = tokenizer(text, return_tensors="pt")
56
  for i in inputs:
57
- inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
58
  res = bert_model(**inputs, output_hidden_states=True)
59
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
60
  assert len(word2ph) == len(text)
@@ -63,15 +87,8 @@ def get_bert_feature(text, word2ph):
63
  repeat_feature = res[i].repeat(word2ph[i], 1)
64
  phone_level_feature.append(repeat_feature)
65
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
66
- # if(is_half==True):phone_level_feature=phone_level_feature.half()
67
  return phone_level_feature.T
68
 
69
-
70
- n_semantic = 1024
71
-
72
- dict_s2=torch.load(sovits_path,map_location="cpu")
73
- hps=dict_s2["config"]
74
-
75
  class DictToAttrRecursive(dict):
76
  def __init__(self, input_dict):
77
  super().__init__(input_dict)
@@ -100,11 +117,6 @@ class DictToAttrRecursive(dict):
100
  raise AttributeError(f"Attribute {item} not found")
101
 
102
 
103
- hps = DictToAttrRecursive(hps)
104
-
105
- hps.model.semantic_frame_rate = "25hz"
106
- dict_s1 = torch.load(gpt_path, map_location="cpu")
107
- config = dict_s1["config"]
108
  ssl_model = cnhubert.get_model()
109
  if is_half == True:
110
  ssl_model = ssl_model.half().to(device)
@@ -123,13 +135,15 @@ def change_sovits_weights(sovits_path):
123
  n_speakers=hps.data.n_speakers,
124
  **hps.model
125
  )
126
- del vq_model.enc_q
 
127
  if is_half == True:
128
  vq_model = vq_model.half().to(device)
129
  else:
130
  vq_model = vq_model.to(device)
131
  vq_model.eval()
132
  print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
 
133
  change_sovits_weights(sovits_path)
134
 
135
  def change_gpt_weights(gpt_path):
@@ -146,9 +160,9 @@ def change_gpt_weights(gpt_path):
146
  t2s_model.eval()
147
  total = sum([param.nelement() for param in t2s_model.parameters()])
148
  print("Number of parameter: %.2fM" % (total / 1e6))
 
149
  change_gpt_weights(gpt_path)
150
 
151
-
152
  def get_spepc(hps, filename):
153
  audio = load_audio(filename, int(hps.data.sampling_rate))
154
  audio = torch.FloatTensor(audio)
@@ -165,14 +179,91 @@ def get_spepc(hps, filename):
165
  return spec
166
 
167
 
168
- dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
- def get_tts_wav(selected_text, prompt_text, prompt_language, text, text_language):
172
- ref_wav_path = text_to_audio_mappings.get(selected_text, "")
173
- if not ref_wav_path:
174
- print("Audio file not found for the selected text.")
175
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  t0 = ttime()
177
  prompt_text = prompt_text.strip("\n")
178
  prompt_language, text = prompt_language, text.strip("\n")
@@ -201,28 +292,38 @@ def get_tts_wav(selected_text, prompt_text, prompt_language, text, text_language
201
  t1 = ttime()
202
  prompt_language = dict_language[prompt_language]
203
  text_language = dict_language[text_language]
204
- phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
205
- phones1 = cleaned_text_to_sequence(phones1)
206
- texts = text.split("\n")
207
- audio_opt = []
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  for text in texts:
210
  # 解决输入目标文本的空行导致报错的问题
211
  if (len(text.strip()) == 0):
212
  continue
213
- phones2, word2ph2, norm_text2 = clean_text(text, text_language)
214
- phones2 = cleaned_text_to_sequence(phones2)
215
- if prompt_language == "zh":
216
- bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
217
  else:
218
- bert1 = torch.zeros(
219
- (1024, len(phones1)),
220
- dtype=torch.float16 if is_half == True else torch.float32,
221
- ).to(device)
222
- if text_language == "zh":
223
- bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
224
  else:
225
- bert2 = torch.zeros((1024, len(phones2))).to(bert1)
 
226
  bert = torch.cat([bert1, bert2], 1)
227
 
228
  all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
@@ -345,85 +446,96 @@ def cut2(inp):
345
  def cut3(inp):
346
  inp = inp.strip("\n")
347
  return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
348
-
349
- def scan_audio_files(folder_path):
350
- """ 扫描指定文件夹获取音频文件列表 """
351
- return [f for f in os.listdir(folder_path) if f.endswith('.wav')]
352
-
353
- def load_audio_text_mappings(folder_path, list_file_name):
354
- text_to_audio_mappings = {}
355
- audio_to_text_mappings = {}
356
- with open(os.path.join(folder_path, list_file_name), 'r', encoding='utf-8') as file:
357
- for line in file:
358
- parts = line.strip().split('|')
359
- if len(parts) >= 4:
360
- audio_file_name = parts[0]
361
- text = parts[3]
362
- audio_file_path = os.path.join(folder_path, audio_file_name)
363
- text_to_audio_mappings[text] = audio_file_path
364
- audio_to_text_mappings[audio_file_path] = text
365
- return text_to_audio_mappings, audio_to_text_mappings
366
-
367
- audio_folder_path = 'audio/Taffy'
368
- text_to_audio_mappings, audio_to_text_mappings = load_audio_text_mappings(audio_folder_path, 'Taffy.list')
 
 
 
 
 
 
 
 
 
369
 
370
  with gr.Blocks(title="GPT-SoVITS WebUI") as app:
371
- gr.Markdown(value="""
372
- # <center>【AI塔菲】在线语音生成(GPT-SoVITS)\n
373
-
374
- ### <center>模型作者:Xz乔希 https://space.bilibili.com/5859321\n
375
- ### <center>数据集下载:https://huggingface.co/datasets/XzJosh/audiodataset\n
376
- ### <center>声音归属:永雏塔菲 https://space.bilibili.com/1265680561\n
377
- ### <center>GPT-SoVITS项目:https://github.com/RVC-Boss/GPT-SoVITS\n
378
- ### <center>使用本模型请严格遵守法律法规!发布二创作品请标注本项目作者及链接、作品使用GPT-SoVITS AI生成!\n
379
- ### <center>⚠️在线端不稳定且生成速度较慢,强烈建议下载模型本地推理!\n
380
- """)
381
- # with gr.Tabs():
382
- # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
383
  with gr.Group():
384
- gr.Markdown(value="*参考音频选择(必选)")
385
  with gr.Row():
386
- audio_select = gr.Dropdown(label="选择参考音频(不建议选较长的)", choices=list(text_to_audio_mappings.keys()))
387
- ref_audio = gr.Audio(label="参考音频试听")
388
- ref_text = gr.Textbox(label="参考音频文本")
389
-
390
- # 定义更新参考文本的函数
391
- def update_ref_text_and_audio(selected_text):
392
- audio_path = text_to_audio_mappings.get(selected_text, "")
393
- return selected_text, audio_path
394
-
395
- # 绑定下拉菜单的变化到更新函数
396
- audio_select.change(update_ref_text_and_audio, [audio_select], [ref_text, ref_audio])
397
-
398
- # 其他 Gradio 组件和功能
399
- prompt_language = gr.Dropdown(
400
- label="参考音频语种", choices=["中文", "英文", "日文"], value="中文"
401
- )
402
- gr.Markdown(value="*请填写需要合成的目标文本")
403
  with gr.Row():
404
- text = gr.Textbox(label="需要合成的文本", value="")
 
 
 
 
 
 
 
405
  text_language = gr.Dropdown(
406
- label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
 
 
 
 
 
 
407
  )
408
- inference_button = gr.Button("合成语音", variant="primary")
409
- output = gr.Audio(label="输出的语音")
 
410
  inference_button.click(
411
  get_tts_wav,
412
- [audio_select, ref_text, prompt_language, text, text_language],
413
  [output],
414
  )
415
 
416
-
417
- gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
418
- with gr.Row():
419
- text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
420
- button1 = gr.Button("凑五句一切", variant="primary")
421
- button2 = gr.Button("凑50字一切", variant="primary")
422
- button3 = gr.Button("按中文句号。切", variant="primary")
423
- text_opt = gr.Textbox(label="切分后文本", value="")
424
- button1.click(cut1, [text_inp], [text_opt])
425
- button2.click(cut2, [text_inp], [text_opt])
426
- button3.click(cut3, [text_inp], [text_opt])
427
-
428
- app.queue(max_size=10)
429
- app.launch(inbrowser=True)
 
 
 
 
 
 
 
 
1
+ import os,re,logging
2
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
3
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
4
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
5
+ logging.getLogger("httpx").setLevel(logging.ERROR)
6
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
7
+
8
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
9
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
10
+ import pdb
11
+
12
+ if os.path.exists("./gweight.txt"):
13
+ with open("./gweight.txt", 'r',encoding="utf-8") as file:
14
+ gweight_data = file.read()
15
+ gpt_path = os.environ.get(
16
+ "gpt_path", gweight_data)
17
+ else:
18
+ gpt_path = os.environ.get(
19
+ "gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
20
 
21
+ if os.path.exists("./sweight.txt"):
22
+ with open("./sweight.txt", 'r',encoding="utf-8") as file:
23
+ sweight_data = file.read()
24
+ sovits_path = os.environ.get("sovits_path", sweight_data)
25
+ else:
26
+ sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
27
+ # gpt_path = os.environ.get(
28
+ # "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
29
+ # )
30
+ # sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
31
  cnhubert_base_path = os.environ.get(
32
  "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
33
  )
 
36
  )
37
  infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
38
  infer_ttswebui = int(infer_ttswebui)
39
+ is_share = os.environ.get("is_share", "False")
40
+ is_share=eval(is_share)
41
  if "_CUDA_VISIBLE_DEVICES" in os.environ:
42
  os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
43
  is_half = eval(os.environ.get("is_half", "True"))
 
47
  import librosa,torch
48
  from feature_extractor import cnhubert
49
  cnhubert.cnhubert_base_path=cnhubert_base_path
 
 
 
 
50
 
51
  from module.models import SynthesizerTrn
52
  from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 
55
  from time import time as ttime
56
  from module.mel_processing import spectrogram_torch
57
  from my_utils import load_audio
58
+ from tools.i18n.i18n import I18nAuto
59
+ i18n = I18nAuto()
60
 
61
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
62
 
63
+ if torch.cuda.is_available():
64
+ device = "cuda"
65
+ elif torch.backends.mps.is_available():
66
+ device = "mps"
67
+ else:
68
+ device = "cpu"
69
 
70
  tokenizer = AutoTokenizer.from_pretrained(bert_path)
71
  bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
 
74
  else:
75
  bert_model = bert_model.to(device)
76
 
 
 
77
  def get_bert_feature(text, word2ph):
78
  with torch.no_grad():
79
  inputs = tokenizer(text, return_tensors="pt")
80
  for i in inputs:
81
+ inputs[i] = inputs[i].to(device)
82
  res = bert_model(**inputs, output_hidden_states=True)
83
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
84
  assert len(word2ph) == len(text)
 
87
  repeat_feature = res[i].repeat(word2ph[i], 1)
88
  phone_level_feature.append(repeat_feature)
89
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
 
90
  return phone_level_feature.T
91
 
 
 
 
 
 
 
92
  class DictToAttrRecursive(dict):
93
  def __init__(self, input_dict):
94
  super().__init__(input_dict)
 
117
  raise AttributeError(f"Attribute {item} not found")
118
 
119
 
 
 
 
 
 
120
  ssl_model = cnhubert.get_model()
121
  if is_half == True:
122
  ssl_model = ssl_model.half().to(device)
 
135
  n_speakers=hps.data.n_speakers,
136
  **hps.model
137
  )
138
+ if("pretrained"not in sovits_path):
139
+ del vq_model.enc_q
140
  if is_half == True:
141
  vq_model = vq_model.half().to(device)
142
  else:
143
  vq_model = vq_model.to(device)
144
  vq_model.eval()
145
  print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
146
+ with open("./sweight.txt","w",encoding="utf-8")as f:f.write(sovits_path)
147
  change_sovits_weights(sovits_path)
148
 
149
  def change_gpt_weights(gpt_path):
 
160
  t2s_model.eval()
161
  total = sum([param.nelement() for param in t2s_model.parameters()])
162
  print("Number of parameter: %.2fM" % (total / 1e6))
163
+ with open("./gweight.txt","w",encoding="utf-8")as f:f.write(gpt_path)
164
  change_gpt_weights(gpt_path)
165
 
 
166
  def get_spepc(hps, filename):
167
  audio = load_audio(filename, int(hps.data.sampling_rate))
168
  audio = torch.FloatTensor(audio)
 
179
  return spec
180
 
181
 
182
+ dict_language={
183
+ i18n("中文"):"zh",
184
+ i18n("英文"):"en",
185
+ i18n("日文"):"ja"
186
+ }
187
+
188
+
189
+ def splite_en_inf(sentence, language):
190
+ pattern = re.compile(r'[a-zA-Z. ]+')
191
+ textlist = []
192
+ langlist = []
193
+ pos = 0
194
+ for match in pattern.finditer(sentence):
195
+ start, end = match.span()
196
+ if start > pos:
197
+ textlist.append(sentence[pos:start])
198
+ langlist.append(language)
199
+ textlist.append(sentence[start:end])
200
+ langlist.append("en")
201
+ pos = end
202
+ if pos < len(sentence):
203
+ textlist.append(sentence[pos:])
204
+ langlist.append(language)
205
+
206
+ return textlist, langlist
207
 
208
 
209
+ def clean_text_inf(text, language):
210
+ phones, word2ph, norm_text = clean_text(text, language)
211
+ phones = cleaned_text_to_sequence(phones)
212
+
213
+ return phones, word2ph, norm_text
214
+
215
+
216
+ def get_bert_inf(phones, word2ph, norm_text, language):
217
+ if language == "zh":
218
+ bert = get_bert_feature(norm_text, word2ph).to(device)
219
+ else:
220
+ bert = torch.zeros(
221
+ (1024, len(phones)),
222
+ dtype=torch.float16 if is_half == True else torch.float32,
223
+ ).to(device)
224
+
225
+ return bert
226
+
227
+
228
+ def nonen_clean_text_inf(text, language):
229
+ textlist, langlist = splite_en_inf(text, language)
230
+ phones_list = []
231
+ word2ph_list = []
232
+ norm_text_list = []
233
+ for i in range(len(textlist)):
234
+ lang = langlist[i]
235
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
236
+ phones_list.append(phones)
237
+ if lang == "en" or "ja":
238
+ pass
239
+ else:
240
+ word2ph_list.append(word2ph)
241
+ norm_text_list.append(norm_text)
242
+ print(word2ph_list)
243
+ phones = sum(phones_list, [])
244
+ word2ph = sum(word2ph_list, [])
245
+ norm_text = ' '.join(norm_text_list)
246
+
247
+ return phones, word2ph, norm_text
248
+
249
+
250
+ def nonen_get_bert_inf(text, language):
251
+ textlist, langlist = splite_en_inf(text, language)
252
+ print(textlist)
253
+ print(langlist)
254
+ bert_list = []
255
+ for i in range(len(textlist)):
256
+ text = textlist[i]
257
+ lang = langlist[i]
258
+ phones, word2ph, norm_text = clean_text_inf(text, lang)
259
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
260
+ bert_list.append(bert)
261
+ bert = torch.cat(bert_list, dim=1)
262
+
263
+ return bert
264
+
265
+ #i18n("不切"),i18n("凑五句一切"),i18n("凑50字一切"),i18n("按中文句号。切"),i18n("按英文句号.切")
266
+ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,how_to_cut=i18n("不切")):
267
  t0 = ttime()
268
  prompt_text = prompt_text.strip("\n")
269
  prompt_language, text = prompt_language, text.strip("\n")
 
292
  t1 = ttime()
293
  prompt_language = dict_language[prompt_language]
294
  text_language = dict_language[text_language]
 
 
 
 
295
 
296
+ if prompt_language == "en":
297
+ phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
298
+ else:
299
+ phones1, word2ph1, norm_text1 = nonen_clean_text_inf(prompt_text, prompt_language)
300
+ if(how_to_cut==i18n("凑五句一切")):text=cut1(text)
301
+ elif(how_to_cut==i18n("凑50字一切")):text=cut2(text)
302
+ elif(how_to_cut==i18n("按中文句号。切")):text=cut3(text)
303
+ elif(how_to_cut==i18n("按英文句号.切")):text=cut4(text)
304
+ text = text.replace("\n\n","\n").replace("\n\n","\n").replace("\n\n","\n")
305
+ if(text[-1]not in splits):text+="。"if text_language!="en"else "."
306
+ texts=text.split("\n")
307
+ audio_opt = []
308
+ if prompt_language == "en":
309
+ bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
310
+ else:
311
+ bert1 = nonen_get_bert_inf(prompt_text, prompt_language)
312
+
313
  for text in texts:
314
  # 解决输入目标文本的空行导致报错的问题
315
  if (len(text.strip()) == 0):
316
  continue
317
+ if text_language == "en":
318
+ phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
 
 
319
  else:
320
+ phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language)
321
+
322
+ if text_language == "en":
323
+ bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language)
 
 
324
  else:
325
+ bert2 = nonen_get_bert_inf(text, text_language)
326
+
327
  bert = torch.cat([bert1, bert2], 1)
328
 
329
  all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
 
446
  def cut3(inp):
447
  inp = inp.strip("\n")
448
  return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
449
+ def cut4(inp):
450
+ inp = inp.strip("\n")
451
+ return "\n".join(["%s." % item for item in inp.strip(".").split(".")])
452
+
453
+ def custom_sort_key(s):
454
+ # 使用正则表达式提取字符串中的数字部分和非数字部分
455
+ parts = re.split('(\d+)', s)
456
+ # 将数字部分转换为整数,非数字部分保持不变
457
+ parts = [int(part) if part.isdigit() else part for part in parts]
458
+ return parts
459
+
460
+ def change_choices():
461
+ SoVITS_names, GPT_names = get_weights_names()
462
+ return {"choices": sorted(SoVITS_names,key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names,key=custom_sort_key), "__type__": "update"}
463
+
464
+ pretrained_sovits_name="GPT_SoVITS/pretrained_models/s2G488k.pth"
465
+ pretrained_gpt_name="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
466
+ SoVITS_weight_root="SoVITS_weights"
467
+ GPT_weight_root="GPT_weights"
468
+ os.makedirs(SoVITS_weight_root,exist_ok=True)
469
+ os.makedirs(GPT_weight_root,exist_ok=True)
470
+ def get_weights_names():
471
+ SoVITS_names = [pretrained_sovits_name]
472
+ for name in os.listdir(SoVITS_weight_root):
473
+ if name.endswith(".pth"):SoVITS_names.append("%s/%s"%(SoVITS_weight_root,name))
474
+ GPT_names = [pretrained_gpt_name]
475
+ for name in os.listdir(GPT_weight_root):
476
+ if name.endswith(".ckpt"): GPT_names.append("%s/%s"%(GPT_weight_root,name))
477
+ return SoVITS_names,GPT_names
478
+ SoVITS_names,GPT_names = get_weights_names()
479
 
480
  with gr.Blocks(title="GPT-SoVITS WebUI") as app:
481
+ gr.Markdown(
482
+ value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
483
+ )
 
 
 
 
 
 
 
 
 
484
  with gr.Group():
485
+ gr.Markdown(value=i18n("模型切换"))
486
  with gr.Row():
487
+ GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path,interactive=True)
488
+ SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path,interactive=True)
489
+ refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
490
+ refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
491
+ SoVITS_dropdown.change(change_sovits_weights,[SoVITS_dropdown],[])
492
+ GPT_dropdown.change(change_gpt_weights,[GPT_dropdown],[])
493
+ gr.Markdown(value=i18n("*请上传并填写参考信息"))
 
 
 
 
 
 
 
 
 
 
494
  with gr.Row():
495
+ inp_ref = gr.Audio(label=i18n("请上传参考音频"), type="filepath")
496
+ prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
497
+ prompt_language = gr.Dropdown(
498
+ label=i18n("参考音频的语种"),choices=[i18n("中文"),i18n("英文"),i18n("日文")],value=i18n("中文")
499
+ )
500
+ gr.Markdown(value=i18n("*请填写需要合成的目标文本。中英混合选中文,日英混合选日文,中日混合暂不支持,非目标语言文本自动遗弃。"))
501
+ with gr.Row():
502
+ text = gr.Textbox(label=i18n("需要合成的文本"), value="")
503
  text_language = gr.Dropdown(
504
+ label=i18n("需要合成的语种"),choices=[i18n("中文"),i18n("英文"),i18n("日文")],value=i18n("中文")
505
+ )
506
+ how_to_cut = gr.Radio(
507
+ label=i18n("怎么切"),
508
+ choices=[i18n("不切"),i18n("凑五句一切"),i18n("凑50字一切"),i18n("按中文句号。切"),i18n("按英文句号.切"),],
509
+ value=i18n("凑50字一切"),
510
+ interactive=True,
511
  )
512
+ inference_button = gr.Button(i18n("合成语音"), variant="primary")
513
+ output = gr.Audio(label=i18n("输出的语音"))
514
+
515
  inference_button.click(
516
  get_tts_wav,
517
+ [inp_ref, prompt_text, prompt_language, text, text_language,how_to_cut],
518
  [output],
519
  )
520
 
521
+ gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
522
+ with gr.Row():
523
+ text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"),value="")
524
+ button1 = gr.Button(i18n("凑五句一切"), variant="primary")
525
+ button2 = gr.Button(i18n("凑50字一切"), variant="primary")
526
+ button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
527
+ button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
528
+ text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
529
+ button1.click(cut1, [text_inp], [text_opt])
530
+ button2.click(cut2, [text_inp], [text_opt])
531
+ button3.click(cut3, [text_inp], [text_opt])
532
+ button4.click(cut4, [text_inp], [text_opt])
533
+ gr.Markdown(value=i18n("后续将支持混合语种编码文本输入。"))
534
+
535
+ app.queue(concurrency_count=511, max_size=1022).launch(
536
+ server_name="0.0.0.0",
537
+ inbrowser=True,
538
+ share=is_share,
539
+ server_port=infer_ttswebui,
540
+ quiet=True,
541
+ )
module/attentions_onnx.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from module import commons
7
+ from module.modules import LayerNorm
8
+
9
+
10
+ class LayerNorm(nn.Module):
11
+ def __init__(self, channels, eps=1e-5):
12
+ super().__init__()
13
+ self.channels = channels
14
+ self.eps = eps
15
+
16
+ self.gamma = nn.Parameter(torch.ones(channels))
17
+ self.beta = nn.Parameter(torch.zeros(channels))
18
+
19
+ def forward(self, x):
20
+ x = x.transpose(1, -1)
21
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
22
+ return x.transpose(1, -1)
23
+
24
+
25
+ @torch.jit.script
26
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
27
+ n_channels_int = n_channels[0]
28
+ in_act = input_a + input_b
29
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
30
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
31
+ acts = t_act * s_act
32
+ return acts
33
+
34
+
35
+ class Encoder(nn.Module):
36
+ def __init__(
37
+ self,
38
+ hidden_channels,
39
+ filter_channels,
40
+ n_heads,
41
+ n_layers,
42
+ kernel_size=1,
43
+ p_dropout=0.0,
44
+ window_size=4,
45
+ isflow=True,
46
+ **kwargs
47
+ ):
48
+ super().__init__()
49
+ self.hidden_channels = hidden_channels
50
+ self.filter_channels = filter_channels
51
+ self.n_heads = n_heads
52
+ self.n_layers = n_layers
53
+ self.kernel_size = kernel_size
54
+ self.p_dropout = p_dropout
55
+ self.window_size = window_size
56
+ # if isflow:
57
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
58
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
59
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
60
+ # self.gin_channels = 256
61
+ self.cond_layer_idx = self.n_layers
62
+ if "gin_channels" in kwargs:
63
+ self.gin_channels = kwargs["gin_channels"]
64
+ if self.gin_channels != 0:
65
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
66
+ # vits2 says 3rd block, so idx is 2 by default
67
+ self.cond_layer_idx = (
68
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
69
+ )
70
+ logging.debug(self.gin_channels, self.cond_layer_idx)
71
+ assert (
72
+ self.cond_layer_idx < self.n_layers
73
+ ), "cond_layer_idx should be less than n_layers"
74
+ self.drop = nn.Dropout(p_dropout)
75
+ self.attn_layers = nn.ModuleList()
76
+ self.norm_layers_1 = nn.ModuleList()
77
+ self.ffn_layers = nn.ModuleList()
78
+ self.norm_layers_2 = nn.ModuleList()
79
+ for i in range(self.n_layers):
80
+ self.attn_layers.append(
81
+ MultiHeadAttention(
82
+ hidden_channels,
83
+ hidden_channels,
84
+ n_heads,
85
+ p_dropout=p_dropout,
86
+ window_size=window_size,
87
+ )
88
+ )
89
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
90
+ self.ffn_layers.append(
91
+ FFN(
92
+ hidden_channels,
93
+ hidden_channels,
94
+ filter_channels,
95
+ kernel_size,
96
+ p_dropout=p_dropout,
97
+ )
98
+ )
99
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
100
+
101
+ def forward(self, x, x_mask, g=None):
102
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
103
+ x = x * x_mask
104
+ for i in range(self.n_layers):
105
+ if i == self.cond_layer_idx and g is not None:
106
+ g = self.spk_emb_linear(g.transpose(1, 2))
107
+ g = g.transpose(1, 2)
108
+ x = x + g
109
+ x = x * x_mask
110
+ y = self.attn_layers[i](x, x, attn_mask)
111
+ y = self.drop(y)
112
+ x = self.norm_layers_1[i](x + y)
113
+
114
+ y = self.ffn_layers[i](x, x_mask)
115
+ y = self.drop(y)
116
+ x = self.norm_layers_2[i](x + y)
117
+ x = x * x_mask
118
+ return x
119
+
120
+
121
+ class MultiHeadAttention(nn.Module):
122
+ def __init__(
123
+ self,
124
+ channels,
125
+ out_channels,
126
+ n_heads,
127
+ p_dropout=0.0,
128
+ window_size=None,
129
+ heads_share=True,
130
+ block_length=None,
131
+ proximal_bias=False,
132
+ proximal_init=False,
133
+ ):
134
+ super().__init__()
135
+ assert channels % n_heads == 0
136
+
137
+ self.channels = channels
138
+ self.out_channels = out_channels
139
+ self.n_heads = n_heads
140
+ self.p_dropout = p_dropout
141
+ self.window_size = window_size
142
+ self.heads_share = heads_share
143
+ self.block_length = block_length
144
+ self.proximal_bias = proximal_bias
145
+ self.proximal_init = proximal_init
146
+ self.attn = None
147
+
148
+ self.k_channels = channels // n_heads
149
+ self.conv_q = nn.Conv1d(channels, channels, 1)
150
+ self.conv_k = nn.Conv1d(channels, channels, 1)
151
+ self.conv_v = nn.Conv1d(channels, channels, 1)
152
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
153
+ self.drop = nn.Dropout(p_dropout)
154
+
155
+ if window_size is not None:
156
+ n_heads_rel = 1 if heads_share else n_heads
157
+ rel_stddev = self.k_channels**-0.5
158
+ self.emb_rel_k = nn.Parameter(
159
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
160
+ * rel_stddev
161
+ )
162
+ self.emb_rel_v = nn.Parameter(
163
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
164
+ * rel_stddev
165
+ )
166
+
167
+ nn.init.xavier_uniform_(self.conv_q.weight)
168
+ nn.init.xavier_uniform_(self.conv_k.weight)
169
+ nn.init.xavier_uniform_(self.conv_v.weight)
170
+ if proximal_init:
171
+ with torch.no_grad():
172
+ self.conv_k.weight.copy_(self.conv_q.weight)
173
+ self.conv_k.bias.copy_(self.conv_q.bias)
174
+
175
+ def forward(self, x, c, attn_mask=None):
176
+ q = self.conv_q(x)
177
+ k = self.conv_k(c)
178
+ v = self.conv_v(c)
179
+
180
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
181
+
182
+ x = self.conv_o(x)
183
+ return x
184
+
185
+ def attention(self, query, key, value, mask=None):
186
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
187
+ b, d, t_s, _ = (*key.size(), query.size(2))
188
+ query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
189
+ key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
190
+ value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
191
+
192
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
193
+ if self.window_size is not None:
194
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
195
+ rel_logits = self._matmul_with_relative_keys(
196
+ query / math.sqrt(self.k_channels), key_relative_embeddings
197
+ )
198
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
199
+ scores = scores + scores_local
200
+ if mask is not None:
201
+ scores = scores.masked_fill(mask == 0, -1e4)
202
+ if self.block_length is not None:
203
+ block_mask = (
204
+ torch.ones_like(scores)
205
+ .triu(-self.block_length)
206
+ .tril(self.block_length)
207
+ )
208
+ scores = scores.masked_fill(block_mask == 0, -1e4)
209
+ p_attn = F.softmax(scores, dim=-1)
210
+ p_attn = self.drop(p_attn)
211
+ output = torch.matmul(p_attn, value)
212
+ if self.window_size is not None:
213
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
214
+ value_relative_embeddings = self._get_relative_embeddings(
215
+ self.emb_rel_v, t_s
216
+ )
217
+ output = output + self._matmul_with_relative_values(
218
+ relative_weights, value_relative_embeddings
219
+ )
220
+ output = (
221
+ output.transpose(2, 3).contiguous().view(b, d, -1)
222
+ )
223
+ return output, p_attn
224
+
225
+ def _matmul_with_relative_values(self, x, y):
226
+ """
227
+ x: [b, h, l, m]
228
+ y: [h or 1, m, d]
229
+ ret: [b, h, l, d]
230
+ """
231
+ ret = torch.matmul(x, y.unsqueeze(0))
232
+ return ret
233
+
234
+ def _matmul_with_relative_keys(self, x, y):
235
+ """
236
+ x: [b, h, l, d]
237
+ y: [h or 1, m, d]
238
+ ret: [b, h, l, m]
239
+ """
240
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
241
+ return ret
242
+
243
+ def _get_relative_embeddings(self, relative_embeddings, length):
244
+ max_relative_position = 2 * self.window_size + 1
245
+ # Pad first before slice to avoid using cond ops.
246
+ pad_length = max(length - (self.window_size + 1), 0)
247
+ slice_start_position = max((self.window_size + 1) - length, 0)
248
+ slice_end_position = slice_start_position + 2 * length - 1
249
+ if pad_length > 0:
250
+ padded_relative_embeddings = F.pad(
251
+ relative_embeddings,
252
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
253
+ )
254
+ else:
255
+ padded_relative_embeddings = relative_embeddings
256
+ used_relative_embeddings = padded_relative_embeddings[
257
+ :, slice_start_position:slice_end_position
258
+ ]
259
+ return used_relative_embeddings
260
+
261
+ def _relative_position_to_absolute_position(self, x):
262
+ """
263
+ x: [b, h, l, 2*l-1]
264
+ ret: [b, h, l, l]
265
+ """
266
+ batch, heads, length, _ = x.size()
267
+ # Concat columns of pad to shift from relative to absolute indexing.
268
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
269
+
270
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
271
+ x_flat = x.view([batch, heads, length * 2 * length])
272
+ x_flat = F.pad(
273
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
274
+ )
275
+
276
+ # Reshape and slice out the padded elements.
277
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
278
+ :, :, :length, length - 1 :
279
+ ]
280
+ return x_final
281
+
282
+ def _absolute_position_to_relative_position(self, x):
283
+ """
284
+ x: [b, h, l, l]
285
+ ret: [b, h, l, 2*l-1]
286
+ """
287
+ batch, heads, length, _ = x.size()
288
+ # padd along column
289
+ x = F.pad(
290
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
291
+ )
292
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
293
+ # add 0's in the beginning that will skew the elements after reshape
294
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
295
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
296
+ return x_final
297
+
298
+ def _attention_bias_proximal(self, length):
299
+ """Bias for self-attention to encourage attention to close positions.
300
+ Args:
301
+ length: an integer scalar.
302
+ Returns:
303
+ a Tensor with shape [1, 1, length, length]
304
+ """
305
+ r = torch.arange(length, dtype=torch.float32)
306
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
307
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
308
+
309
+
310
+ class FFN(nn.Module):
311
+ def __init__(
312
+ self,
313
+ in_channels,
314
+ out_channels,
315
+ filter_channels,
316
+ kernel_size,
317
+ p_dropout=0.0,
318
+ activation=None,
319
+ causal=False,
320
+ ):
321
+ super().__init__()
322
+ self.in_channels = in_channels
323
+ self.out_channels = out_channels
324
+ self.filter_channels = filter_channels
325
+ self.kernel_size = kernel_size
326
+ self.p_dropout = p_dropout
327
+ self.activation = activation
328
+ self.causal = causal
329
+
330
+ if causal:
331
+ self.padding = self._causal_padding
332
+ else:
333
+ self.padding = self._same_padding
334
+
335
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
336
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
337
+ self.drop = nn.Dropout(p_dropout)
338
+
339
+ def forward(self, x, x_mask):
340
+ x = self.conv_1(self.padding(x * x_mask))
341
+ if self.activation == "gelu":
342
+ x = x * torch.sigmoid(1.702 * x)
343
+ else:
344
+ x = torch.relu(x)
345
+ x = self.drop(x)
346
+ x = self.conv_2(self.padding(x * x_mask))
347
+ return x * x_mask
348
+
349
+ def _causal_padding(self, x):
350
+ if self.kernel_size == 1:
351
+ return x
352
+ pad_l = self.kernel_size - 1
353
+ pad_r = 0
354
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
355
+ x = F.pad(x, commons.convert_pad_shape(padding))
356
+ return x
357
+
358
+ def _same_padding(self, x):
359
+ if self.kernel_size == 1:
360
+ return x
361
+ pad_l = (self.kernel_size - 1) // 2
362
+ pad_r = self.kernel_size // 2
363
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
364
+ x = F.pad(x, commons.convert_pad_shape(padding))
365
+ return x
module/models_onnx.py ADDED
@@ -0,0 +1,920 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from module import commons
8
+ from module import modules
9
+ from module import attentions_onnx as attentions
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+ from module.commons import init_weights, get_padding
14
+ from module.mrte_model import MRTE
15
+ from module.quantize import ResidualVectorQuantizer
16
+ from text import symbols
17
+ from torch.cuda.amp import autocast
18
+
19
+
20
+ class StochasticDurationPredictor(nn.Module):
21
+ def __init__(
22
+ self,
23
+ in_channels,
24
+ filter_channels,
25
+ kernel_size,
26
+ p_dropout,
27
+ n_flows=4,
28
+ gin_channels=0,
29
+ ):
30
+ super().__init__()
31
+ filter_channels = in_channels # it needs to be removed from future version.
32
+ self.in_channels = in_channels
33
+ self.filter_channels = filter_channels
34
+ self.kernel_size = kernel_size
35
+ self.p_dropout = p_dropout
36
+ self.n_flows = n_flows
37
+ self.gin_channels = gin_channels
38
+
39
+ self.log_flow = modules.Log()
40
+ self.flows = nn.ModuleList()
41
+ self.flows.append(modules.ElementwiseAffine(2))
42
+ for i in range(n_flows):
43
+ self.flows.append(
44
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
45
+ )
46
+ self.flows.append(modules.Flip())
47
+
48
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
49
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
50
+ self.post_convs = modules.DDSConv(
51
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
52
+ )
53
+ self.post_flows = nn.ModuleList()
54
+ self.post_flows.append(modules.ElementwiseAffine(2))
55
+ for i in range(4):
56
+ self.post_flows.append(
57
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
58
+ )
59
+ self.post_flows.append(modules.Flip())
60
+
61
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
62
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
63
+ self.convs = modules.DDSConv(
64
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
65
+ )
66
+ if gin_channels != 0:
67
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
68
+
69
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
70
+ x = torch.detach(x)
71
+ x = self.pre(x)
72
+ if g is not None:
73
+ g = torch.detach(g)
74
+ x = x + self.cond(g)
75
+ x = self.convs(x, x_mask)
76
+ x = self.proj(x) * x_mask
77
+
78
+ if not reverse:
79
+ flows = self.flows
80
+ assert w is not None
81
+
82
+ logdet_tot_q = 0
83
+ h_w = self.post_pre(w)
84
+ h_w = self.post_convs(h_w, x_mask)
85
+ h_w = self.post_proj(h_w) * x_mask
86
+ e_q = (
87
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
88
+ * x_mask
89
+ )
90
+ z_q = e_q
91
+ for flow in self.post_flows:
92
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
93
+ logdet_tot_q += logdet_q
94
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
95
+ u = torch.sigmoid(z_u) * x_mask
96
+ z0 = (w - u) * x_mask
97
+ logdet_tot_q += torch.sum(
98
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
99
+ )
100
+ logq = (
101
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
102
+ - logdet_tot_q
103
+ )
104
+
105
+ logdet_tot = 0
106
+ z0, logdet = self.log_flow(z0, x_mask)
107
+ logdet_tot += logdet
108
+ z = torch.cat([z0, z1], 1)
109
+ for flow in flows:
110
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
111
+ logdet_tot = logdet_tot + logdet
112
+ nll = (
113
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
114
+ - logdet_tot
115
+ )
116
+ return nll + logq # [b]
117
+ else:
118
+ flows = list(reversed(self.flows))
119
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
120
+ z = (
121
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
122
+ * noise_scale
123
+ )
124
+ for flow in flows:
125
+ z = flow(z, x_mask, g=x, reverse=reverse)
126
+ z0, z1 = torch.split(z, [1, 1], 1)
127
+ logw = z0
128
+ return logw
129
+
130
+
131
+ class DurationPredictor(nn.Module):
132
+ def __init__(
133
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
134
+ ):
135
+ super().__init__()
136
+
137
+ self.in_channels = in_channels
138
+ self.filter_channels = filter_channels
139
+ self.kernel_size = kernel_size
140
+ self.p_dropout = p_dropout
141
+ self.gin_channels = gin_channels
142
+
143
+ self.drop = nn.Dropout(p_dropout)
144
+ self.conv_1 = nn.Conv1d(
145
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
146
+ )
147
+ self.norm_1 = modules.LayerNorm(filter_channels)
148
+ self.conv_2 = nn.Conv1d(
149
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
150
+ )
151
+ self.norm_2 = modules.LayerNorm(filter_channels)
152
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
153
+
154
+ if gin_channels != 0:
155
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
156
+
157
+ def forward(self, x, x_mask, g=None):
158
+ x = torch.detach(x)
159
+ if g is not None:
160
+ g = torch.detach(g)
161
+ x = x + self.cond(g)
162
+ x = self.conv_1(x * x_mask)
163
+ x = torch.relu(x)
164
+ x = self.norm_1(x)
165
+ x = self.drop(x)
166
+ x = self.conv_2(x * x_mask)
167
+ x = torch.relu(x)
168
+ x = self.norm_2(x)
169
+ x = self.drop(x)
170
+ x = self.proj(x * x_mask)
171
+ return x * x_mask
172
+
173
+
174
+ class TextEncoder(nn.Module):
175
+ def __init__(
176
+ self,
177
+ out_channels,
178
+ hidden_channels,
179
+ filter_channels,
180
+ n_heads,
181
+ n_layers,
182
+ kernel_size,
183
+ p_dropout,
184
+ latent_channels=192,
185
+ ):
186
+ super().__init__()
187
+ self.out_channels = out_channels
188
+ self.hidden_channels = hidden_channels
189
+ self.filter_channels = filter_channels
190
+ self.n_heads = n_heads
191
+ self.n_layers = n_layers
192
+ self.kernel_size = kernel_size
193
+ self.p_dropout = p_dropout
194
+ self.latent_channels = latent_channels
195
+
196
+ self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
197
+
198
+ self.encoder_ssl = attentions.Encoder(
199
+ hidden_channels,
200
+ filter_channels,
201
+ n_heads,
202
+ n_layers // 2,
203
+ kernel_size,
204
+ p_dropout,
205
+ )
206
+
207
+ self.encoder_text = attentions.Encoder(
208
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
209
+ )
210
+ self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
211
+
212
+ self.mrte = MRTE()
213
+
214
+ self.encoder2 = attentions.Encoder(
215
+ hidden_channels,
216
+ filter_channels,
217
+ n_heads,
218
+ n_layers // 2,
219
+ kernel_size,
220
+ p_dropout,
221
+ )
222
+
223
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
224
+
225
+ def forward(self, y, text, ge):
226
+ y_mask = torch.ones_like(y[:1,:1,:])
227
+
228
+ y = self.ssl_proj(y * y_mask) * y_mask
229
+ y = self.encoder_ssl(y * y_mask, y_mask)
230
+
231
+ text_mask = torch.ones_like(text).to(y.dtype).unsqueeze(0)
232
+
233
+ text = self.text_embedding(text).transpose(1, 2)
234
+ text = self.encoder_text(text * text_mask, text_mask)
235
+ y = self.mrte(y, y_mask, text, text_mask, ge)
236
+
237
+ y = self.encoder2(y * y_mask, y_mask)
238
+
239
+ stats = self.proj(y) * y_mask
240
+ m, logs = torch.split(stats, self.out_channels, dim=1)
241
+ return y, m, logs, y_mask
242
+
243
+ def extract_latent(self, x):
244
+ x = self.ssl_proj(x)
245
+ quantized, codes, commit_loss, quantized_list = self.quantizer(x)
246
+ return codes.transpose(0, 1)
247
+
248
+ def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
249
+ quantized = self.quantizer.decode(codes)
250
+
251
+ y = self.vq_proj(quantized) * y_mask
252
+ y = self.encoder_ssl(y * y_mask, y_mask)
253
+
254
+ y = self.mrte(y, y_mask, refer, refer_mask, ge)
255
+
256
+ y = self.encoder2(y * y_mask, y_mask)
257
+
258
+ stats = self.proj(y) * y_mask
259
+ m, logs = torch.split(stats, self.out_channels, dim=1)
260
+ return y, m, logs, y_mask, quantized
261
+
262
+
263
+ class ResidualCouplingBlock(nn.Module):
264
+ def __init__(
265
+ self,
266
+ channels,
267
+ hidden_channels,
268
+ kernel_size,
269
+ dilation_rate,
270
+ n_layers,
271
+ n_flows=4,
272
+ gin_channels=0,
273
+ ):
274
+ super().__init__()
275
+ self.channels = channels
276
+ self.hidden_channels = hidden_channels
277
+ self.kernel_size = kernel_size
278
+ self.dilation_rate = dilation_rate
279
+ self.n_layers = n_layers
280
+ self.n_flows = n_flows
281
+ self.gin_channels = gin_channels
282
+
283
+ self.flows = nn.ModuleList()
284
+ for i in range(n_flows):
285
+ self.flows.append(
286
+ modules.ResidualCouplingLayer(
287
+ channels,
288
+ hidden_channels,
289
+ kernel_size,
290
+ dilation_rate,
291
+ n_layers,
292
+ gin_channels=gin_channels,
293
+ mean_only=True,
294
+ )
295
+ )
296
+ self.flows.append(modules.Flip())
297
+
298
+ def forward(self, x, x_mask, g=None, reverse=False):
299
+ if not reverse:
300
+ for flow in self.flows:
301
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
302
+ else:
303
+ for flow in reversed(self.flows):
304
+ x = flow(x, x_mask, g=g, reverse=reverse)
305
+ return x
306
+
307
+
308
+ class PosteriorEncoder(nn.Module):
309
+ def __init__(
310
+ self,
311
+ in_channels,
312
+ out_channels,
313
+ hidden_channels,
314
+ kernel_size,
315
+ dilation_rate,
316
+ n_layers,
317
+ gin_channels=0,
318
+ ):
319
+ super().__init__()
320
+ self.in_channels = in_channels
321
+ self.out_channels = out_channels
322
+ self.hidden_channels = hidden_channels
323
+ self.kernel_size = kernel_size
324
+ self.dilation_rate = dilation_rate
325
+ self.n_layers = n_layers
326
+ self.gin_channels = gin_channels
327
+
328
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
329
+ self.enc = modules.WN(
330
+ hidden_channels,
331
+ kernel_size,
332
+ dilation_rate,
333
+ n_layers,
334
+ gin_channels=gin_channels,
335
+ )
336
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
337
+
338
+ def forward(self, x, x_lengths, g=None):
339
+ if g != None:
340
+ g = g.detach()
341
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
342
+ x.dtype
343
+ )
344
+ x = self.pre(x) * x_mask
345
+ x = self.enc(x, x_mask, g=g)
346
+ stats = self.proj(x) * x_mask
347
+ m, logs = torch.split(stats, self.out_channels, dim=1)
348
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
349
+ return z, m, logs, x_mask
350
+
351
+
352
+ class WNEncoder(nn.Module):
353
+ def __init__(
354
+ self,
355
+ in_channels,
356
+ out_channels,
357
+ hidden_channels,
358
+ kernel_size,
359
+ dilation_rate,
360
+ n_layers,
361
+ gin_channels=0,
362
+ ):
363
+ super().__init__()
364
+ self.in_channels = in_channels
365
+ self.out_channels = out_channels
366
+ self.hidden_channels = hidden_channels
367
+ self.kernel_size = kernel_size
368
+ self.dilation_rate = dilation_rate
369
+ self.n_layers = n_layers
370
+ self.gin_channels = gin_channels
371
+
372
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
373
+ self.enc = modules.WN(
374
+ hidden_channels,
375
+ kernel_size,
376
+ dilation_rate,
377
+ n_layers,
378
+ gin_channels=gin_channels,
379
+ )
380
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
381
+ self.norm = modules.LayerNorm(out_channels)
382
+
383
+ def forward(self, x, x_lengths, g=None):
384
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
385
+ x.dtype
386
+ )
387
+ x = self.pre(x) * x_mask
388
+ x = self.enc(x, x_mask, g=g)
389
+ out = self.proj(x) * x_mask
390
+ out = self.norm(out)
391
+ return out
392
+
393
+
394
+ class Generator(torch.nn.Module):
395
+ def __init__(
396
+ self,
397
+ initial_channel,
398
+ resblock,
399
+ resblock_kernel_sizes,
400
+ resblock_dilation_sizes,
401
+ upsample_rates,
402
+ upsample_initial_channel,
403
+ upsample_kernel_sizes,
404
+ gin_channels=0,
405
+ ):
406
+ super(Generator, self).__init__()
407
+ self.num_kernels = len(resblock_kernel_sizes)
408
+ self.num_upsamples = len(upsample_rates)
409
+ self.conv_pre = Conv1d(
410
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
411
+ )
412
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
413
+
414
+ self.ups = nn.ModuleList()
415
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
416
+ self.ups.append(
417
+ weight_norm(
418
+ ConvTranspose1d(
419
+ upsample_initial_channel // (2**i),
420
+ upsample_initial_channel // (2 ** (i + 1)),
421
+ k,
422
+ u,
423
+ padding=(k - u) // 2,
424
+ )
425
+ )
426
+ )
427
+
428
+ self.resblocks = nn.ModuleList()
429
+ for i in range(len(self.ups)):
430
+ ch = upsample_initial_channel // (2 ** (i + 1))
431
+ for j, (k, d) in enumerate(
432
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
433
+ ):
434
+ self.resblocks.append(resblock(ch, k, d))
435
+
436
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
437
+ self.ups.apply(init_weights)
438
+
439
+ if gin_channels != 0:
440
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
441
+
442
+ def forward(self, x, g=None):
443
+ x = self.conv_pre(x)
444
+ if g is not None:
445
+ x = x + self.cond(g)
446
+
447
+ for i in range(self.num_upsamples):
448
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
449
+ x = self.ups[i](x)
450
+ xs = None
451
+ for j in range(self.num_kernels):
452
+ if xs is None:
453
+ xs = self.resblocks[i * self.num_kernels + j](x)
454
+ else:
455
+ xs += self.resblocks[i * self.num_kernels + j](x)
456
+ x = xs / self.num_kernels
457
+ x = F.leaky_relu(x)
458
+ x = self.conv_post(x)
459
+ x = torch.tanh(x)
460
+
461
+ return x
462
+
463
+ def remove_weight_norm(self):
464
+ print("Removing weight norm...")
465
+ for l in self.ups:
466
+ remove_weight_norm(l)
467
+ for l in self.resblocks:
468
+ l.remove_weight_norm()
469
+
470
+
471
+ class DiscriminatorP(torch.nn.Module):
472
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
473
+ super(DiscriminatorP, self).__init__()
474
+ self.period = period
475
+ self.use_spectral_norm = use_spectral_norm
476
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
477
+ self.convs = nn.ModuleList(
478
+ [
479
+ norm_f(
480
+ Conv2d(
481
+ 1,
482
+ 32,
483
+ (kernel_size, 1),
484
+ (stride, 1),
485
+ padding=(get_padding(kernel_size, 1), 0),
486
+ )
487
+ ),
488
+ norm_f(
489
+ Conv2d(
490
+ 32,
491
+ 128,
492
+ (kernel_size, 1),
493
+ (stride, 1),
494
+ padding=(get_padding(kernel_size, 1), 0),
495
+ )
496
+ ),
497
+ norm_f(
498
+ Conv2d(
499
+ 128,
500
+ 512,
501
+ (kernel_size, 1),
502
+ (stride, 1),
503
+ padding=(get_padding(kernel_size, 1), 0),
504
+ )
505
+ ),
506
+ norm_f(
507
+ Conv2d(
508
+ 512,
509
+ 1024,
510
+ (kernel_size, 1),
511
+ (stride, 1),
512
+ padding=(get_padding(kernel_size, 1), 0),
513
+ )
514
+ ),
515
+ norm_f(
516
+ Conv2d(
517
+ 1024,
518
+ 1024,
519
+ (kernel_size, 1),
520
+ 1,
521
+ padding=(get_padding(kernel_size, 1), 0),
522
+ )
523
+ ),
524
+ ]
525
+ )
526
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
527
+
528
+ def forward(self, x):
529
+ fmap = []
530
+
531
+ # 1d to 2d
532
+ b, c, t = x.shape
533
+ if t % self.period != 0: # pad first
534
+ n_pad = self.period - (t % self.period)
535
+ x = F.pad(x, (0, n_pad), "reflect")
536
+ t = t + n_pad
537
+ x = x.view(b, c, t // self.period, self.period)
538
+
539
+ for l in self.convs:
540
+ x = l(x)
541
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
542
+ fmap.append(x)
543
+ x = self.conv_post(x)
544
+ fmap.append(x)
545
+ x = torch.flatten(x, 1, -1)
546
+
547
+ return x, fmap
548
+
549
+
550
+ class DiscriminatorS(torch.nn.Module):
551
+ def __init__(self, use_spectral_norm=False):
552
+ super(DiscriminatorS, self).__init__()
553
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
554
+ self.convs = nn.ModuleList(
555
+ [
556
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
557
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
558
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
559
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
560
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
561
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
562
+ ]
563
+ )
564
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
565
+
566
+ def forward(self, x):
567
+ fmap = []
568
+
569
+ for l in self.convs:
570
+ x = l(x)
571
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
572
+ fmap.append(x)
573
+ x = self.conv_post(x)
574
+ fmap.append(x)
575
+ x = torch.flatten(x, 1, -1)
576
+
577
+ return x, fmap
578
+
579
+
580
+ class MultiPeriodDiscriminator(torch.nn.Module):
581
+ def __init__(self, use_spectral_norm=False):
582
+ super(MultiPeriodDiscriminator, self).__init__()
583
+ periods = [2, 3, 5, 7, 11]
584
+
585
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
586
+ discs = discs + [
587
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
588
+ ]
589
+ self.discriminators = nn.ModuleList(discs)
590
+
591
+ def forward(self, y, y_hat):
592
+ y_d_rs = []
593
+ y_d_gs = []
594
+ fmap_rs = []
595
+ fmap_gs = []
596
+ for i, d in enumerate(self.discriminators):
597
+ y_d_r, fmap_r = d(y)
598
+ y_d_g, fmap_g = d(y_hat)
599
+ y_d_rs.append(y_d_r)
600
+ y_d_gs.append(y_d_g)
601
+ fmap_rs.append(fmap_r)
602
+ fmap_gs.append(fmap_g)
603
+
604
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
605
+
606
+
607
+ class ReferenceEncoder(nn.Module):
608
+ """
609
+ inputs --- [N, Ty/r, n_mels*r] mels
610
+ outputs --- [N, ref_enc_gru_size]
611
+ """
612
+
613
+ def __init__(self, spec_channels, gin_channels=0):
614
+ super().__init__()
615
+ self.spec_channels = spec_channels
616
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
617
+ K = len(ref_enc_filters)
618
+ filters = [1] + ref_enc_filters
619
+ convs = [
620
+ weight_norm(
621
+ nn.Conv2d(
622
+ in_channels=filters[i],
623
+ out_channels=filters[i + 1],
624
+ kernel_size=(3, 3),
625
+ stride=(2, 2),
626
+ padding=(1, 1),
627
+ )
628
+ )
629
+ for i in range(K)
630
+ ]
631
+ self.convs = nn.ModuleList(convs)
632
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
633
+
634
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
635
+ self.gru = nn.GRU(
636
+ input_size=ref_enc_filters[-1] * out_channels,
637
+ hidden_size=256 // 2,
638
+ batch_first=True,
639
+ )
640
+ self.proj = nn.Linear(128, gin_channels)
641
+
642
+ def forward(self, inputs):
643
+ N = inputs.size(0)
644
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
645
+ for conv in self.convs:
646
+ out = conv(out)
647
+ # out = wn(out)
648
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
649
+
650
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
651
+ T = out.size(1)
652
+ N = out.size(0)
653
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
654
+
655
+ self.gru.flatten_parameters()
656
+ memory, out = self.gru(out) # out --- [1, N, 128]
657
+
658
+ return self.proj(out.squeeze(0)).unsqueeze(-1)
659
+
660
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
661
+ for i in range(n_convs):
662
+ L = (L - kernel_size + 2 * pad) // stride + 1
663
+ return L
664
+
665
+
666
+ class Quantizer_module(torch.nn.Module):
667
+ def __init__(self, n_e, e_dim):
668
+ super(Quantizer_module, self).__init__()
669
+ self.embedding = nn.Embedding(n_e, e_dim)
670
+ self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
671
+
672
+ def forward(self, x):
673
+ d = (
674
+ torch.sum(x**2, 1, keepdim=True)
675
+ + torch.sum(self.embedding.weight**2, 1)
676
+ - 2 * torch.matmul(x, self.embedding.weight.T)
677
+ )
678
+ min_indicies = torch.argmin(d, 1)
679
+ z_q = self.embedding(min_indicies)
680
+ return z_q, min_indicies
681
+
682
+
683
+ class Quantizer(torch.nn.Module):
684
+ def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
685
+ super(Quantizer, self).__init__()
686
+ assert embed_dim % n_code_groups == 0
687
+ self.quantizer_modules = nn.ModuleList(
688
+ [
689
+ Quantizer_module(n_codes, embed_dim // n_code_groups)
690
+ for _ in range(n_code_groups)
691
+ ]
692
+ )
693
+ self.n_code_groups = n_code_groups
694
+ self.embed_dim = embed_dim
695
+
696
+ def forward(self, xin):
697
+ # B, C, T
698
+ B, C, T = xin.shape
699
+ xin = xin.transpose(1, 2)
700
+ x = xin.reshape(-1, self.embed_dim)
701
+ x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
702
+ min_indicies = []
703
+ z_q = []
704
+ for _x, m in zip(x, self.quantizer_modules):
705
+ _z_q, _min_indicies = m(_x)
706
+ z_q.append(_z_q)
707
+ min_indicies.append(_min_indicies) # B * T,
708
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
709
+ loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
710
+ (z_q - xin.detach()) ** 2
711
+ )
712
+ z_q = xin + (z_q - xin).detach()
713
+ z_q = z_q.transpose(1, 2)
714
+ codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
715
+ return z_q, loss, codes.transpose(1, 2)
716
+
717
+ def embed(self, x):
718
+ # idx: N, 4, T
719
+ x = x.transpose(1, 2)
720
+ x = torch.split(x, 1, 2)
721
+ ret = []
722
+ for q, embed in zip(x, self.quantizer_modules):
723
+ q = embed.embedding(q.squeeze(-1))
724
+ ret.append(q)
725
+ ret = torch.cat(ret, -1)
726
+ return ret.transpose(1, 2) # N, C, T
727
+
728
+
729
+ class CodePredictor(nn.Module):
730
+ def __init__(
731
+ self,
732
+ hidden_channels,
733
+ filter_channels,
734
+ n_heads,
735
+ n_layers,
736
+ kernel_size,
737
+ p_dropout,
738
+ n_q=8,
739
+ dims=1024,
740
+ ssl_dim=768,
741
+ ):
742
+ super().__init__()
743
+ self.hidden_channels = hidden_channels
744
+ self.filter_channels = filter_channels
745
+ self.n_heads = n_heads
746
+ self.n_layers = n_layers
747
+ self.kernel_size = kernel_size
748
+ self.p_dropout = p_dropout
749
+
750
+ self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
751
+ self.ref_enc = modules.MelStyleEncoder(
752
+ ssl_dim, style_vector_dim=hidden_channels
753
+ )
754
+
755
+ self.encoder = attentions.Encoder(
756
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
757
+ )
758
+
759
+ self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
760
+ self.n_q = n_q
761
+ self.dims = dims
762
+
763
+ def forward(self, x, x_mask, refer, codes, infer=False):
764
+ x = x.detach()
765
+ x = self.vq_proj(x * x_mask) * x_mask
766
+ g = self.ref_enc(refer, x_mask)
767
+ x = x + g
768
+ x = self.encoder(x * x_mask, x_mask)
769
+ x = self.out_proj(x * x_mask) * x_mask
770
+ logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
771
+ 2, 3
772
+ )
773
+ target = codes[1:].transpose(0, 1)
774
+ if not infer:
775
+ logits = logits.reshape(-1, self.dims)
776
+ target = target.reshape(-1)
777
+ loss = torch.nn.functional.cross_entropy(logits, target)
778
+ return loss
779
+ else:
780
+ _, top10_preds = torch.topk(logits, 10, dim=-1)
781
+ correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
782
+ top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
783
+
784
+ print("Top-10 Accuracy:", top3_acc, "%")
785
+
786
+ pred_codes = torch.argmax(logits, dim=-1)
787
+ acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
788
+ print("Top-1 Accuracy:", acc, "%")
789
+
790
+ return pred_codes.transpose(0, 1)
791
+
792
+
793
+ class SynthesizerTrn(nn.Module):
794
+ """
795
+ Synthesizer for Training
796
+ """
797
+
798
+ def __init__(
799
+ self,
800
+ spec_channels,
801
+ segment_size,
802
+ inter_channels,
803
+ hidden_channels,
804
+ filter_channels,
805
+ n_heads,
806
+ n_layers,
807
+ kernel_size,
808
+ p_dropout,
809
+ resblock,
810
+ resblock_kernel_sizes,
811
+ resblock_dilation_sizes,
812
+ upsample_rates,
813
+ upsample_initial_channel,
814
+ upsample_kernel_sizes,
815
+ n_speakers=0,
816
+ gin_channels=0,
817
+ use_sdp=True,
818
+ semantic_frame_rate=None,
819
+ freeze_quantizer=None,
820
+ **kwargs
821
+ ):
822
+ super().__init__()
823
+ self.spec_channels = spec_channels
824
+ self.inter_channels = inter_channels
825
+ self.hidden_channels = hidden_channels
826
+ self.filter_channels = filter_channels
827
+ self.n_heads = n_heads
828
+ self.n_layers = n_layers
829
+ self.kernel_size = kernel_size
830
+ self.p_dropout = p_dropout
831
+ self.resblock = resblock
832
+ self.resblock_kernel_sizes = resblock_kernel_sizes
833
+ self.resblock_dilation_sizes = resblock_dilation_sizes
834
+ self.upsample_rates = upsample_rates
835
+ self.upsample_initial_channel = upsample_initial_channel
836
+ self.upsample_kernel_sizes = upsample_kernel_sizes
837
+ self.segment_size = segment_size
838
+ self.n_speakers = n_speakers
839
+ self.gin_channels = gin_channels
840
+
841
+ self.use_sdp = use_sdp
842
+ self.enc_p = TextEncoder(
843
+ inter_channels,
844
+ hidden_channels,
845
+ filter_channels,
846
+ n_heads,
847
+ n_layers,
848
+ kernel_size,
849
+ p_dropout,
850
+ )
851
+ self.dec = Generator(
852
+ inter_channels,
853
+ resblock,
854
+ resblock_kernel_sizes,
855
+ resblock_dilation_sizes,
856
+ upsample_rates,
857
+ upsample_initial_channel,
858
+ upsample_kernel_sizes,
859
+ gin_channels=gin_channels,
860
+ )
861
+ self.enc_q = PosteriorEncoder(
862
+ spec_channels,
863
+ inter_channels,
864
+ hidden_channels,
865
+ 5,
866
+ 1,
867
+ 16,
868
+ gin_channels=gin_channels,
869
+ )
870
+ self.flow = ResidualCouplingBlock(
871
+ inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
872
+ )
873
+
874
+ self.ref_enc = modules.MelStyleEncoder(
875
+ spec_channels, style_vector_dim=gin_channels
876
+ )
877
+
878
+ ssl_dim = 768
879
+ self.ssl_dim = ssl_dim
880
+ assert semantic_frame_rate in ["25hz", "50hz"]
881
+ self.semantic_frame_rate = semantic_frame_rate
882
+ if semantic_frame_rate == "25hz":
883
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
884
+ else:
885
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
886
+
887
+ self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
888
+ if freeze_quantizer:
889
+ self.ssl_proj.requires_grad_(False)
890
+ self.quantizer.requires_grad_(False)
891
+ # self.enc_p.text_embedding.requires_grad_(False)
892
+ # self.enc_p.encoder_text.requires_grad_(False)
893
+ # self.enc_p.mrte.requires_grad_(False)
894
+
895
+ def forward(self, codes, text, refer):
896
+ refer_mask = torch.ones_like(refer[:1,:1,:])
897
+ ge = self.ref_enc(refer * refer_mask, refer_mask)
898
+
899
+ y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
900
+ text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
901
+
902
+ quantized = self.quantizer.decode(codes)
903
+ if self.semantic_frame_rate == "25hz":
904
+ dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
905
+ quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
906
+
907
+ x, m_p, logs_p, y_mask = self.enc_p(
908
+ quantized, text, ge
909
+ )
910
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
911
+
912
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
913
+
914
+ o = self.dec((z * y_mask)[:, :, :], g=ge)
915
+ return o
916
+
917
+ def extract_latent(self, x):
918
+ ssl = self.ssl_proj(x)
919
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
920
+ return codes.transpose(0, 1)
onnx_export.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from module.models_onnx import SynthesizerTrn, symbols
2
+ from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
3
+ import torch
4
+ import torchaudio
5
+ from torch import nn
6
+ from feature_extractor import cnhubert
7
+ cnhubert_base_path = "pretrained_models/chinese-hubert-base"
8
+ cnhubert.cnhubert_base_path=cnhubert_base_path
9
+ ssl_model = cnhubert.get_model()
10
+ from text import cleaned_text_to_sequence
11
+ import soundfile
12
+ from my_utils import load_audio
13
+ import os
14
+ import json
15
+
16
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
17
+ hann_window = torch.hann_window(win_size).to(
18
+ dtype=y.dtype, device=y.device
19
+ )
20
+ y = torch.nn.functional.pad(
21
+ y.unsqueeze(1),
22
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
23
+ mode="reflect",
24
+ )
25
+ y = y.squeeze(1)
26
+ spec = torch.stft(
27
+ y,
28
+ n_fft,
29
+ hop_length=hop_size,
30
+ win_length=win_size,
31
+ window=hann_window,
32
+ center=center,
33
+ pad_mode="reflect",
34
+ normalized=False,
35
+ onesided=True,
36
+ return_complex=False,
37
+ )
38
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
39
+ return spec
40
+
41
+
42
+ class DictToAttrRecursive(dict):
43
+ def __init__(self, input_dict):
44
+ super().__init__(input_dict)
45
+ for key, value in input_dict.items():
46
+ if isinstance(value, dict):
47
+ value = DictToAttrRecursive(value)
48
+ self[key] = value
49
+ setattr(self, key, value)
50
+
51
+ def __getattr__(self, item):
52
+ try:
53
+ return self[item]
54
+ except KeyError:
55
+ raise AttributeError(f"Attribute {item} not found")
56
+
57
+ def __setattr__(self, key, value):
58
+ if isinstance(value, dict):
59
+ value = DictToAttrRecursive(value)
60
+ super(DictToAttrRecursive, self).__setitem__(key, value)
61
+ super().__setattr__(key, value)
62
+
63
+ def __delattr__(self, item):
64
+ try:
65
+ del self[item]
66
+ except KeyError:
67
+ raise AttributeError(f"Attribute {item} not found")
68
+
69
+
70
+ class T2SEncoder(nn.Module):
71
+ def __init__(self, t2s, vits):
72
+ super().__init__()
73
+ self.encoder = t2s.onnx_encoder
74
+ self.vits = vits
75
+
76
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
77
+ codes = self.vits.extract_latent(ssl_content)
78
+ prompt_semantic = codes[0, 0]
79
+ bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
80
+ all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
81
+ bert = bert.unsqueeze(0)
82
+ prompt = prompt_semantic.unsqueeze(0)
83
+ return self.encoder(all_phoneme_ids, bert), prompt
84
+
85
+
86
+ class T2SModel(nn.Module):
87
+ def __init__(self, t2s_path, vits_model):
88
+ super().__init__()
89
+ dict_s1 = torch.load(t2s_path, map_location="cpu")
90
+ self.config = dict_s1["config"]
91
+ self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
92
+ self.t2s_model.load_state_dict(dict_s1["weight"])
93
+ self.t2s_model.eval()
94
+ self.vits_model = vits_model.vq_model
95
+ self.hz = 50
96
+ self.max_sec = self.config["data"]["max_sec"]
97
+ self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
98
+ self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
99
+ self.t2s_model = self.t2s_model.model
100
+ self.t2s_model.init_onnx()
101
+ self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
102
+ self.first_stage_decoder = self.t2s_model.first_stage_decoder
103
+ self.stage_decoder = self.t2s_model.stage_decoder
104
+ #self.t2s_model = torch.jit.script(self.t2s_model)
105
+
106
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
107
+ early_stop_num = self.t2s_model.early_stop_num
108
+
109
+ #[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
110
+ x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
111
+
112
+ prefix_len = prompts.shape[1]
113
+
114
+ #[1,N,512] [1,N]
115
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
116
+
117
+ stop = False
118
+ for idx in range(1, 1500):
119
+ #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
120
+ enco = self.stage_decoder(y, k, v, y_emb, x_example)
121
+ y, k, v, y_emb, logits, samples = enco
122
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
123
+ stop = True
124
+ if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
125
+ stop = True
126
+ if stop:
127
+ break
128
+ y[0, -1] = 0
129
+
130
+ return y[:, -idx:].unsqueeze(0)
131
+
132
+ def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
133
+ #self.onnx_encoder = torch.jit.script(self.onnx_encoder)
134
+ if dynamo:
135
+ export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
136
+ onnx_encoder_export_output = torch.onnx.dynamo_export(
137
+ self.onnx_encoder,
138
+ (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
139
+ export_options=export_options
140
+ )
141
+ onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
142
+ return
143
+ torch.onnx.export(
144
+ self.onnx_encoder,
145
+ (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
146
+ f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
147
+ input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
148
+ output_names=["x", "prompts"],
149
+ dynamic_axes={
150
+ "ref_seq": [1],
151
+ "text_seq": [1],
152
+ "ref_bert": [0],
153
+ "text_bert": [0],
154
+ "ssl_content": [2],
155
+ },
156
+ opset_version=16
157
+ )
158
+ x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
159
+ torch.exp
160
+ torch.onnx.export(
161
+ self.first_stage_decoder,
162
+ (x, prompts),
163
+ f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
164
+ input_names=["x", "prompts"],
165
+ output_names=["y", "k", "v", "y_emb", "x_example"],
166
+ dynamic_axes={
167
+ "x": [1],
168
+ "prompts": [1],
169
+ },
170
+ verbose=True,
171
+ opset_version=16
172
+ )
173
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
174
+
175
+ torch.onnx.export(
176
+ self.stage_decoder,
177
+ (y, k, v, y_emb, x_example),
178
+ f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
179
+ input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
180
+ output_names=["y", "k", "v", "y_emb", "logits", "samples"],
181
+ dynamic_axes={
182
+ "iy": [1],
183
+ "ik": [1],
184
+ "iv": [1],
185
+ "iy_emb": [1],
186
+ "ix_example": [1],
187
+ },
188
+ verbose=True,
189
+ opset_version=16
190
+ )
191
+
192
+
193
+ class VitsModel(nn.Module):
194
+ def __init__(self, vits_path):
195
+ super().__init__()
196
+ dict_s2 = torch.load(vits_path,map_location="cpu")
197
+ self.hps = dict_s2["config"]
198
+ self.hps = DictToAttrRecursive(self.hps)
199
+ self.hps.model.semantic_frame_rate = "25hz"
200
+ self.vq_model = SynthesizerTrn(
201
+ self.hps.data.filter_length // 2 + 1,
202
+ self.hps.train.segment_size // self.hps.data.hop_length,
203
+ n_speakers=self.hps.data.n_speakers,
204
+ **self.hps.model
205
+ )
206
+ self.vq_model.eval()
207
+ self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
208
+
209
+ def forward(self, text_seq, pred_semantic, ref_audio):
210
+ refer = spectrogram_torch(
211
+ ref_audio,
212
+ self.hps.data.filter_length,
213
+ self.hps.data.sampling_rate,
214
+ self.hps.data.hop_length,
215
+ self.hps.data.win_length,
216
+ center=False
217
+ )
218
+ return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
219
+
220
+
221
+ class GptSoVits(nn.Module):
222
+ def __init__(self, vits, t2s):
223
+ super().__init__()
224
+ self.vits = vits
225
+ self.t2s = t2s
226
+
227
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content):
228
+ pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
229
+ return self.vits(text_seq, pred_semantic, ref_audio)
230
+
231
+ def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
232
+ self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
233
+ pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
234
+ torch.onnx.export(
235
+ self.vits,
236
+ (text_seq, pred_semantic, ref_audio),
237
+ f"onnx/{project_name}/{project_name}_vits.onnx",
238
+ input_names=["text_seq", "pred_semantic", "ref_audio"],
239
+ output_names=["audio"],
240
+ dynamic_axes={
241
+ "text_seq": [1],
242
+ "pred_semantic": [2],
243
+ "ref_audio": [1],
244
+ },
245
+ opset_version=17
246
+ )
247
+
248
+
249
+ class SSLModel(nn.Module):
250
+ def __init__(self):
251
+ super().__init__()
252
+ self.ssl = ssl_model
253
+
254
+ def forward(self, ref_audio_16k):
255
+ return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
256
+
257
+
258
+ def export(vits_path, gpt_path, project_name):
259
+ vits = VitsModel(vits_path)
260
+ gpt = T2SModel(gpt_path, vits)
261
+ gpt_sovits = GptSoVits(vits, gpt)
262
+ ssl = SSLModel()
263
+ ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
264
+ text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
265
+ ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
266
+ text_bert = torch.randn((text_seq.shape[1], 1024)).float()
267
+ ref_audio = torch.randn((1, 48000 * 5)).float()
268
+ # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
269
+ ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
270
+ ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
271
+
272
+ try:
273
+ os.mkdir(f"onnx/{project_name}")
274
+ except:
275
+ pass
276
+
277
+ ssl_content = ssl(ref_audio_16k).float()
278
+
279
+ a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
280
+
281
+ # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
282
+
283
+ gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
284
+
285
+ MoeVSConf = {
286
+ "Folder" : f"{project_name}",
287
+ "Name" : f"{project_name}",
288
+ "Type" : "GPT-SoVits",
289
+ "Rate" : vits.hps.data.sampling_rate,
290
+ "NumLayers": gpt.t2s_model.num_layers,
291
+ "EmbeddingDim": gpt.t2s_model.embedding_dim,
292
+ "Dict": "BasicDict",
293
+ "BertPath": "chinese-roberta-wwm-ext-large",
294
+ "Symbol": symbols,
295
+ "AddBlank": False
296
+ }
297
+
298
+ MoeVSConfJson = json.dumps(MoeVSConf)
299
+ with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
300
+ json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
301
+
302
+
303
+ if __name__ == "__main__":
304
+ try:
305
+ os.mkdir("onnx")
306
+ except:
307
+ pass
308
+
309
+ gpt_path = "pt_model/koharu-e20.ckpt"
310
+ vits_path = "pt_model/koharu_e20_s4960.pth"
311
+ exp_path = "koharu"
312
+ export(vits_path, gpt_path, exp_path)
313
+
314
+ # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
prepare_datasets/1-get-text.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+
5
+ inp_text = os.environ.get("inp_text")
6
+ inp_wav_dir = os.environ.get("inp_wav_dir")
7
+ exp_name = os.environ.get("exp_name")
8
+ i_part = os.environ.get("i_part")
9
+ all_parts = os.environ.get("all_parts")
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
11
+ opt_dir = os.environ.get("opt_dir")
12
+ bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
13
+ is_half = eval(os.environ.get("is_half", "True"))
14
+ import sys, numpy as np, traceback, pdb
15
+ import os.path
16
+ from glob import glob
17
+ from tqdm import tqdm
18
+ from text.cleaner import clean_text
19
+ import torch
20
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
21
+ import numpy as np
22
+
23
+ # inp_text=sys.argv[1]
24
+ # inp_wav_dir=sys.argv[2]
25
+ # exp_name=sys.argv[3]
26
+ # i_part=sys.argv[4]
27
+ # all_parts=sys.argv[5]
28
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
29
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
30
+ # bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
31
+
32
+ from time import time as ttime
33
+ import shutil
34
+
35
+
36
+ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
37
+ dir = os.path.dirname(path)
38
+ name = os.path.basename(path)
39
+ tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part)
40
+ torch.save(fea, tmp_path)
41
+ shutil.move(tmp_path, "%s/%s" % (dir, name))
42
+
43
+
44
+
45
+ txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
46
+ if os.path.exists(txt_path) == False:
47
+ bert_dir = "%s/3-bert" % (opt_dir)
48
+ os.makedirs(opt_dir, exist_ok=True)
49
+ os.makedirs(bert_dir, exist_ok=True)
50
+ if torch.cuda.is_available():
51
+ device = "cuda:0"
52
+ elif torch.backends.mps.is_available():
53
+ device = "mps"
54
+ else:
55
+ device = "cpu"
56
+ tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
57
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
58
+ if is_half == True:
59
+ bert_model = bert_model.half().to(device)
60
+ else:
61
+ bert_model = bert_model.to(device)
62
+
63
+ def get_bert_feature(text, word2ph):
64
+ with torch.no_grad():
65
+ inputs = tokenizer(text, return_tensors="pt")
66
+ for i in inputs:
67
+ inputs[i] = inputs[i].to(device)
68
+ res = bert_model(**inputs, output_hidden_states=True)
69
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
70
+
71
+ assert len(word2ph) == len(text)
72
+ phone_level_feature = []
73
+ for i in range(len(word2ph)):
74
+ repeat_feature = res[i].repeat(word2ph[i], 1)
75
+ phone_level_feature.append(repeat_feature)
76
+
77
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
78
+
79
+ return phone_level_feature.T
80
+
81
+ def process(data, res):
82
+ for name, text, lan in data:
83
+ try:
84
+ name = os.path.basename(name)
85
+ phones, word2ph, norm_text = clean_text(
86
+ text.replace("%", "-").replace("¥", ","), lan
87
+ )
88
+ path_bert = "%s/%s.pt" % (bert_dir, name)
89
+ if os.path.exists(path_bert) == False and lan == "zh":
90
+ bert_feature = get_bert_feature(norm_text, word2ph)
91
+ assert bert_feature.shape[-1] == len(phones)
92
+ # torch.save(bert_feature, path_bert)
93
+ my_save(bert_feature, path_bert)
94
+ phones = " ".join(phones)
95
+ # res.append([name,phones])
96
+ res.append([name, phones, word2ph, norm_text])
97
+ except:
98
+ print(name, text, traceback.format_exc())
99
+
100
+ todo = []
101
+ res = []
102
+ with open(inp_text, "r", encoding="utf8") as f:
103
+ lines = f.read().strip("\n").split("\n")
104
+
105
+ language_v1_to_language_v2 = {
106
+ "ZH": "zh",
107
+ "zh": "zh",
108
+ "JP": "ja",
109
+ "jp": "ja",
110
+ "JA": "ja",
111
+ "ja": "ja",
112
+ "EN": "en",
113
+ "en": "en",
114
+ "En": "en",
115
+ }
116
+ for line in lines[int(i_part) :: int(all_parts)]:
117
+ try:
118
+ wav_name, spk_name, language, text = line.split("|")
119
+ # todo.append([name,text,"zh"])
120
+ todo.append(
121
+ [wav_name, text, language_v1_to_language_v2.get(language, language)]
122
+ )
123
+ except:
124
+ print(line, traceback.format_exc())
125
+
126
+ process(todo, res)
127
+ opt = []
128
+ for name, phones, word2ph, norm_text in res:
129
+ opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
130
+ with open(txt_path, "w", encoding="utf8") as f:
131
+ f.write("\n".join(opt) + "\n")
prepare_datasets/2-get-hubert-wav32k.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import sys,os
4
+ inp_text= os.environ.get("inp_text")
5
+ inp_wav_dir= os.environ.get("inp_wav_dir")
6
+ exp_name= os.environ.get("exp_name")
7
+ i_part= os.environ.get("i_part")
8
+ all_parts= os.environ.get("all_parts")
9
+ os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
10
+ from feature_extractor import cnhubert
11
+ opt_dir= os.environ.get("opt_dir")
12
+ cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
13
+ is_half=eval(os.environ.get("is_half","True"))
14
+
15
+ import pdb,traceback,numpy as np,logging
16
+ from scipy.io import wavfile
17
+ import librosa,torch
18
+ now_dir = os.getcwd()
19
+ sys.path.append(now_dir)
20
+ from my_utils import load_audio
21
+
22
+ # from config import cnhubert_base_path
23
+ # cnhubert.cnhubert_base_path=cnhubert_base_path
24
+ # inp_text=sys.argv[1]
25
+ # inp_wav_dir=sys.argv[2]
26
+ # exp_name=sys.argv[3]
27
+ # i_part=sys.argv[4]
28
+ # all_parts=sys.argv[5]
29
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
30
+ # cnhubert.cnhubert_base_path=sys.argv[7]
31
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
32
+
33
+ from time import time as ttime
34
+ import shutil
35
+ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
36
+ dir=os.path.dirname(path)
37
+ name=os.path.basename(path)
38
+ tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
39
+ torch.save(fea,tmp_path)
40
+ shutil.move(tmp_path,"%s/%s"%(dir,name))
41
+
42
+ hubert_dir="%s/4-cnhubert"%(opt_dir)
43
+ wav32dir="%s/5-wav32k"%(opt_dir)
44
+ os.makedirs(opt_dir,exist_ok=True)
45
+ os.makedirs(hubert_dir,exist_ok=True)
46
+ os.makedirs(wav32dir,exist_ok=True)
47
+
48
+ maxx=0.95
49
+ alpha=0.5
50
+ if torch.cuda.is_available():
51
+ device = "cuda:0"
52
+ elif torch.backends.mps.is_available():
53
+ device = "mps"
54
+ else:
55
+ device = "cpu"
56
+ model=cnhubert.get_model()
57
+ # is_half=False
58
+ if(is_half==True):
59
+ model=model.half().to(device)
60
+ else:
61
+ model = model.to(device)
62
+
63
+ nan_fails=[]
64
+ def name2go(wav_name):
65
+ hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
66
+ if(os.path.exists(hubert_path)):return
67
+ wav_path="%s/%s"%(inp_wav_dir,wav_name)
68
+ tmp_audio = load_audio(wav_path, 32000)
69
+ tmp_max = np.abs(tmp_audio).max()
70
+ if tmp_max > 2.2:
71
+ print("%s-filtered" % (wav_name, tmp_max))
72
+ return
73
+ tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
74
+ tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
75
+ tmp_audio = librosa.resample(
76
+ tmp_audio32b, orig_sr=32000, target_sr=16000
77
+ )#不是重采样问题
78
+ tensor_wav16 = torch.from_numpy(tmp_audio)
79
+ if (is_half == True):
80
+ tensor_wav16=tensor_wav16.half().to(device)
81
+ else:
82
+ tensor_wav16 = tensor_wav16.to(device)
83
+ ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
84
+ if np.isnan(ssl.detach().numpy()).sum()!= 0:
85
+ nan_fails.append(wav_name)
86
+ print("nan filtered:%s"%wav_name)
87
+ return
88
+ wavfile.write(
89
+ "%s/%s"%(wav32dir,wav_name),
90
+ 32000,
91
+ tmp_audio32.astype("int16"),
92
+ )
93
+ my_save(ssl,hubert_path )
94
+
95
+ with open(inp_text,"r",encoding="utf8")as f:
96
+ lines=f.read().strip("\n").split("\n")
97
+
98
+ for line in lines[int(i_part)::int(all_parts)]:
99
+ try:
100
+ # wav_name,text=line.split("\t")
101
+ wav_name, spk_name, language, text = line.split("|")
102
+ wav_name=os.path.basename(wav_name)
103
+ name2go(wav_name)
104
+ except:
105
+ print(line,traceback.format_exc())
106
+
107
+ if(len(nan_fails)>0 and is_half==True):
108
+ is_half=False
109
+ model=model.float()
110
+ for wav_name in nan_fails:
111
+ try:
112
+ name2go(wav_name)
113
+ except:
114
+ print(wav_name,traceback.format_exc())
prepare_datasets/3-get-semantic.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ inp_text = os.environ.get("inp_text")
4
+ exp_name = os.environ.get("exp_name")
5
+ i_part = os.environ.get("i_part")
6
+ all_parts = os.environ.get("all_parts")
7
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
8
+ opt_dir = os.environ.get("opt_dir")
9
+ pretrained_s2G = os.environ.get("pretrained_s2G")
10
+ s2config_path = os.environ.get("s2config_path")
11
+ is_half = eval(os.environ.get("is_half", "True"))
12
+ import math, traceback
13
+ import multiprocessing
14
+ import sys, pdb
15
+
16
+ now_dir = os.getcwd()
17
+ sys.path.append(now_dir)
18
+ from random import shuffle
19
+ import torch.multiprocessing as mp
20
+ from glob import glob
21
+ from tqdm import tqdm
22
+ import logging, librosa, utils, torch
23
+ from module.models import SynthesizerTrn
24
+
25
+ logging.getLogger("numba").setLevel(logging.WARNING)
26
+ # from config import pretrained_s2G
27
+
28
+ # inp_text=sys.argv[1]
29
+ # exp_name=sys.argv[2]
30
+ # i_part=sys.argv[3]
31
+ # all_parts=sys.argv[4]
32
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[5]
33
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
34
+
35
+
36
+ hubert_dir = "%s/4-cnhubert" % (opt_dir)
37
+ semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
38
+ if os.path.exists(semantic_path) == False:
39
+ os.makedirs(opt_dir, exist_ok=True)
40
+
41
+ if torch.cuda.is_available():
42
+ device = "cuda"
43
+ elif torch.backends.mps.is_available():
44
+ device = "mps"
45
+ else:
46
+ device = "cpu"
47
+ hps = utils.get_hparams_from_file(s2config_path)
48
+ vq_model = SynthesizerTrn(
49
+ hps.data.filter_length // 2 + 1,
50
+ hps.train.segment_size // hps.data.hop_length,
51
+ n_speakers=hps.data.n_speakers,
52
+ **hps.model
53
+ )
54
+ if is_half == True:
55
+ vq_model = vq_model.half().to(device)
56
+ else:
57
+ vq_model = vq_model.to(device)
58
+ vq_model.eval()
59
+ # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
60
+ # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
61
+ print(
62
+ vq_model.load_state_dict(
63
+ torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
64
+ )
65
+ )
66
+
67
+ def name2go(wav_name, lines):
68
+ hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
69
+ if os.path.exists(hubert_path) == False:
70
+ return
71
+ ssl_content = torch.load(hubert_path, map_location="cpu")
72
+ if is_half == True:
73
+ ssl_content = ssl_content.half().to(device)
74
+ else:
75
+ ssl_content = ssl_content.to(device)
76
+ codes = vq_model.extract_latent(ssl_content)
77
+ semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
78
+ lines.append("%s\t%s" % (wav_name, semantic))
79
+
80
+ with open(inp_text, "r", encoding="utf8") as f:
81
+ lines = f.read().strip("\n").split("\n")
82
+
83
+ lines1 = []
84
+ for line in lines[int(i_part) :: int(all_parts)]:
85
+ # print(line)
86
+ try:
87
+ # wav_name,text=line.split("\t")
88
+ wav_name, spk_name, language, text = line.split("|")
89
+ wav_name = os.path.basename(wav_name)
90
+ # name2go(name,lines1)
91
+ name2go(wav_name, lines1)
92
+ except:
93
+ print(line, traceback.format_exc())
94
+ with open(semantic_path, "w", encoding="utf8") as f:
95
+ f.write("\n".join(lines1))
process_ckpt.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from collections import OrderedDict
3
+
4
+ import torch
5
+ from tools.i18n.i18n import I18nAuto
6
+
7
+ i18n = I18nAuto()
8
+
9
+
10
+ def savee(ckpt, name, epoch, steps, hps):
11
+ try:
12
+ opt = OrderedDict()
13
+ opt["weight"] = {}
14
+ for key in ckpt.keys():
15
+ if "enc_q" in key:
16
+ continue
17
+ opt["weight"][key] = ckpt[key].half()
18
+ opt["config"] = hps
19
+ opt["info"] = "%sepoch_%siteration" % (epoch, steps)
20
+ torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
21
+ return "Success."
22
+ except:
23
+ return traceback.format_exc()
text/tone_sandhi.py CHANGED
@@ -455,6 +455,35 @@ class ToneSandhi:
455
  "电子",
456
  "人人",
457
  "虎虎",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  }
459
  self.punc = ":,;。?!“”‘’':,;.?!"
460
 
 
455
  "电子",
456
  "人人",
457
  "虎虎",
458
+ "幺幺",
459
+ "干嘛",
460
+ "学子",
461
+ "哈哈",
462
+ "数数",
463
+ "袅袅",
464
+ "局地",
465
+ "以下",
466
+ "娃哈哈",
467
+ "花花草草",
468
+ "留得",
469
+ "耕地",
470
+ "想想",
471
+ "熙熙",
472
+ "攘攘",
473
+ "卵子",
474
+ "死死",
475
+ "冉冉",
476
+ "恳恳",
477
+ "佼佼",
478
+ "吵吵",
479
+ "打打",
480
+ "考考",
481
+ "整整",
482
+ "莘莘",
483
+ "落地",
484
+ "算子",
485
+ "家家户户",
486
+ "青青",
487
  }
488
  self.punc = ":,;。?!“”‘’':,;.?!"
489