Kit-Lemonfoot committed on
Commit
d681f3e
1 Parent(s): 12f6da1

Upload 92 files

Browse files
GPT_SoVITS/AR/models/t2s_model.py CHANGED
@@ -1,5 +1,9 @@
1
  # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
  # reference: https://github.com/lifeiteng/vall-e
3
  import torch
4
  from tqdm import tqdm
5
 
@@ -35,6 +39,155 @@ default_config = {
35
  }
36
 
37
 
38
  class Text2SemanticDecoder(nn.Module):
39
  def __init__(self, config, norm_first=False, top_k=3):
40
  super(Text2SemanticDecoder, self).__init__()
@@ -89,6 +242,37 @@ class Text2SemanticDecoder(nn.Module):
89
  ignore_index=self.EOS,
90
  )
91
 
92
  def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
93
  x = self.ar_text_embedding(x)
94
  x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -116,7 +300,7 @@ class Text2SemanticDecoder(nn.Module):
116
  (0, y_len),
117
  value=True,
118
  )
119
-
120
  y_attn_mask = F.pad(
121
  torch.triu(
122
  torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -212,6 +396,7 @@ class Text2SemanticDecoder(nn.Module):
212
  (0, y_len),
213
  value=True,
214
  )
 
215
  y_attn_mask = F.pad(
216
  torch.triu(
217
  torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -277,7 +462,7 @@ class Text2SemanticDecoder(nn.Module):
277
  value=True,
278
  )
279
  y_attn_mask = F.pad(
280
- torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
281
  (x_len, 0),
282
  value=False,
283
  )
@@ -321,16 +506,226 @@ class Text2SemanticDecoder(nn.Module):
321
  # shift targets by one position
322
  return targets[:, :-1], targets[:, 1:]
323
 
324
- def infer_panel(
325
  self,
326
- x, ##### all text tokens
327
- x_lens,
328
- prompts, #### reference audio tokens
329
- bert_feature,
330
  top_k: int = -100,
331
  top_p: int = 100,
332
  early_stop_num: int = -1,
333
  temperature: float = 1.0,
334
  ):
335
  x = self.ar_text_embedding(x)
336
  x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -343,17 +738,9 @@ class Text2SemanticDecoder(nn.Module):
343
  x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
344
  stop = False
345
  # print(1111111,self.num_layers)
346
- cache = {
347
- "all_stage": self.num_layers,
348
- "k": [None] * self.num_layers, ### hand-built per layer according to the config
349
- "v": [None] * self.num_layers,
350
- # "xy_pos":None,## the y_pos positional encoding differs every step, so it cannot be cached and xy_pos must be rebuilt each time. Mostly a matter of how it is written; the history could be kept consistent, but the cost is negligible, so leave it.
351
- "y_emb": None, ## only the newest samples need an embedding; concatenate it onto the history
352
- # "logits":None,### the original already computes only the tail and concatenates, nothing to change
353
- # "xy_dec":None,### not needed; only the last position is used for the logits anyway
354
- "first_infer": 1,
355
- "stage": 0,
356
- }
357
  ################### first step ##########################
358
  if y is not None:
359
  y_emb = self.ar_audio_embedding(y)
@@ -361,7 +748,6 @@ class Text2SemanticDecoder(nn.Module):
361
  prefix_len = y.shape[1]
362
  y_pos = self.ar_audio_position(y_emb)
363
  xy_pos = torch.concat([x, y_pos], dim=1)
364
- cache["y_emb"] = y_emb
365
  ref_free = False
366
  else:
367
  y_emb = None
@@ -372,77 +758,58 @@ class Text2SemanticDecoder(nn.Module):
372
  y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
373
  ref_free = True
374
 
375
  x_attn_mask_pad = F.pad(
376
- x_attn_mask,
377
- (0, y_len), ### extend the all-0 xx block with an all-1 xy block, (x, x+y)
378
- value=True,
379
- )
380
  y_attn_mask = F.pad( ### extend the upper-right 1s of yy with 0s for the x part on the left, (y, x+y)
381
  torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
382
  (x_len, 0),
383
  value=False,
384
  )
385
- xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
386
- x.device
387
- )
388
-
389
-
390
  for idx in tqdm(range(1500)):
391
-
392
- xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
393
  logits = self.ar_predict_layer(
394
  xy_dec[:, -1]
395
- ) ## no change needed: with the cache there is only one frame by default, same as taking the last frame
396
- # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
397
- if(idx==0):### the first step must not produce EOS, otherwise nothing is generated
398
- logits = logits[:, :-1] ### strip the probability of the 1024 end-of-sequence symbol
399
  samples = sample(
400
- logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
401
- )[0].unsqueeze(0)
402
- # the semantic_ids generated this step and the previous y form the new y
403
- # print(samples.shape)#[1,1]# the first 1 is the batch size
404
- y = torch.concat([y, samples], dim=1)
405
 
406
  if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
407
  print("use early stop num:", early_stop_num)
408
  stop = True
409
 
410
  if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
411
- # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
412
  stop = True
413
  if stop:
414
- # if prompts.shape[1] == y.shape[1]:
415
- # y = torch.concat([y, torch.zeros_like(samples)], dim=1)
416
- # print("bad zero prediction")
417
  if y.shape[1]==0:
418
  y = torch.concat([y, torch.zeros_like(samples)], dim=1)
419
  print("bad zero prediction")
420
  print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
421
  break
422
-
423
  ####################### update next step ###################################
424
- cache["first_infer"] = 0
425
- if cache["y_emb"] is not None:
426
- y_emb = torch.cat(
427
- [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
428
- )
429
- cache["y_emb"] = y_emb
430
- y_pos = self.ar_audio_position(y_emb)
431
- xy_pos = y_pos[:, -1:]
432
- else:
433
- y_emb = self.ar_audio_embedding(y[:, -1:])
434
- cache["y_emb"] = y_emb
435
- y_pos = self.ar_audio_position(y_emb)
436
- xy_pos = y_pos
437
- y_len = y_pos.shape[1]
438
-
439
- ### rightmost column (wrong)
440
- # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
441
- # xy_attn_mask[:,-1]=False
442
- ### bottom row (correct)
443
- xy_attn_mask = torch.zeros(
444
- (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
445
- )
446
  if ref_free:
447
  return y[:, :-1], 0
448
- return y[:, :-1], idx-1
 
1
  # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
  # reference: https://github.com/lifeiteng/vall-e
3
+ import os, sys
4
+ now_dir = os.getcwd()
5
+ sys.path.append(now_dir)
6
+ from typing import List
7
  import torch
8
  from tqdm import tqdm
9
 
 
39
  }
40
 
41
 
42
+ @torch.jit.script
43
+ class T2SMLP:
44
+ def __init__(self, w1, b1, w2, b2):
45
+ self.w1 = w1
46
+ self.b1 = b1
47
+ self.w2 = w2
48
+ self.b2 = b2
49
+
50
+ def forward(self, x):
51
+ x = F.relu(F.linear(x, self.w1, self.b1))
52
+ x = F.linear(x, self.w2, self.b2)
53
+ return x
54
+
55
+
56
+ @torch.jit.script
57
+ class T2SBlock:
58
+ def __init__(
59
+ self,
60
+ num_heads,
61
+ hidden_dim: int,
62
+ mlp: T2SMLP,
63
+ qkv_w,
64
+ qkv_b,
65
+ out_w,
66
+ out_b,
67
+ norm_w1,
68
+ norm_b1,
69
+ norm_eps1,
70
+ norm_w2,
71
+ norm_b2,
72
+ norm_eps2,
73
+ ):
74
+ self.num_heads = num_heads
75
+ self.mlp = mlp
76
+ self.hidden_dim: int = hidden_dim
77
+ self.qkv_w = qkv_w
78
+ self.qkv_b = qkv_b
79
+ self.out_w = out_w
80
+ self.out_b = out_b
81
+ self.norm_w1 = norm_w1
82
+ self.norm_b1 = norm_b1
83
+ self.norm_eps1 = norm_eps1
84
+ self.norm_w2 = norm_w2
85
+ self.norm_b2 = norm_b2
86
+ self.norm_eps2 = norm_eps2
87
+
88
+ @torch.jit.ignore
89
+ def to_mask(self, x, padding_mask):
90
+ return x*padding_mask if padding_mask is not None else x
91
+
92
+ def process_prompt(self, x, attn_mask : torch.Tensor, padding_mask:torch.Tensor=None):
93
+
94
+
95
+ q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
96
+
97
+ batch_size = q.shape[0]
98
+ q_len = q.shape[1]
99
+ kv_len = k.shape[1]
100
+
101
+ q = self.to_mask(q, padding_mask)
102
+ k_cache = self.to_mask(k, padding_mask)
103
+ v_cache = self.to_mask(v, padding_mask)
104
+
105
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
106
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
107
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
108
+
109
+ attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
110
+
111
+ attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
112
+ attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
113
+ attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
114
+
115
+ x = self.to_mask(x + attn, padding_mask)
116
+ x = F.layer_norm(
117
+ x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
118
+ )
119
+ x = self.to_mask(x + self.mlp.forward(self.to_mask(x, padding_mask)), padding_mask)
120
+ x = F.layer_norm(
121
+ x,
122
+ [self.hidden_dim],
123
+ self.norm_w2,
124
+ self.norm_b2,
125
+ self.norm_eps2,
126
+ )
127
+ return x, k_cache, v_cache
128
+
129
+ def decode_next_token(self, x, k_cache, v_cache):
130
+ q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
131
+
132
+ k_cache = torch.cat([k_cache, k], dim=1)
133
+ v_cache = torch.cat([v_cache, v], dim=1)
134
+
135
+ batch_size = q.shape[0]
136
+ q_len = q.shape[1]
137
+ kv_len = k_cache.shape[1]
138
+
139
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
140
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
141
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
142
+
143
+
144
+ attn = F.scaled_dot_product_attention(q, k, v)
145
+
146
+ attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
147
+ attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
148
+ attn = F.linear(attn, self.out_w, self.out_b)
149
+
150
+ x = x + attn
151
+ x = F.layer_norm(
152
+ x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
153
+ )
154
+ x = x + self.mlp.forward(x)
155
+ x = F.layer_norm(
156
+ x,
157
+ [self.hidden_dim],
158
+ self.norm_w2,
159
+ self.norm_b2,
160
+ self.norm_eps2,
161
+ )
162
+ return x, k_cache, v_cache
163
+
164
+
165
+ @torch.jit.script
166
+ class T2STransformer:
167
+ def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
168
+ self.num_blocks : int = num_blocks
169
+ self.blocks = blocks
170
+
171
+ def process_prompt(
172
+ self, x, attn_mask : torch.Tensor,
173
+ padding_mask : torch.Tensor=None,
174
+ ):
175
+ k_cache : List[torch.Tensor] = []
176
+ v_cache : List[torch.Tensor] = []
177
+ for i in range(self.num_blocks):
178
+ x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
179
+ k_cache.append(k_cache_)
180
+ v_cache.append(v_cache_)
181
+ return x, k_cache, v_cache
182
+
183
+ def decode_next_token(
184
+ self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
185
+ ):
186
+ for i in range(self.num_blocks):
187
+ x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
188
+ return x, k_cache, v_cache
189
+
190
+
191
  class Text2SemanticDecoder(nn.Module):
192
  def __init__(self, config, norm_first=False, top_k=3):
193
  super(Text2SemanticDecoder, self).__init__()
 
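Note on the added classes (my summary, not part of the diff): T2SMLP, T2SBlock and T2STransformer re-implement the decoder layers around an explicit per-layer K/V cache so that F.scaled_dot_product_attention can be used. process_prompt runs the whole prompt once and returns the caches; decode_next_token then feeds a single new position per step and appends its keys/values. Below is a minimal sketch of how such a two-phase API is typically driven; predict_head and embed_step are hypothetical stand-ins for ar_predict_layer and the embedding/position update used later in the diff.

import torch

def kv_cache_decode(transformer, xy_pos, attn_mask, predict_head, embed_step, eos_id, max_steps=1500):
    # Phase 1: the full prompt goes through once; per-layer K/V caches come back.
    h, k_cache, v_cache = transformer.process_prompt(xy_pos, attn_mask)
    generated = []
    for step in range(max_steps):
        logits = predict_head(h[:, -1])              # logits for the newest position only
        token = logits.argmax(dim=-1, keepdim=True)  # greedy here; the diff uses sample()
        if int(token[0, 0]) == eos_id:
            break
        generated.append(token)
        # Phase 2: only the newest embedding is fed back; each layer's K/V grows by one step.
        h, k_cache, v_cache = transformer.decode_next_token(embed_step(token, step), k_cache, v_cache)
    return torch.cat(generated, dim=1) if generated else None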
242
  ignore_index=self.EOS,
243
  )
244
 
245
+ blocks = []
246
+
247
+ for i in range(self.num_layers):
248
+ layer = self.h.layers[i]
249
+ t2smlp = T2SMLP(
250
+ layer.linear1.weight,
251
+ layer.linear1.bias,
252
+ layer.linear2.weight,
253
+ layer.linear2.bias
254
+ )
255
+
256
+ block = T2SBlock(
257
+ self.num_head,
258
+ self.model_dim,
259
+ t2smlp,
260
+ layer.self_attn.in_proj_weight,
261
+ layer.self_attn.in_proj_bias,
262
+ layer.self_attn.out_proj.weight,
263
+ layer.self_attn.out_proj.bias,
264
+ layer.norm1.weight,
265
+ layer.norm1.bias,
266
+ layer.norm1.eps,
267
+ layer.norm2.weight,
268
+ layer.norm2.bias,
269
+ layer.norm2.eps
270
+ )
271
+
272
+ blocks.append(block)
273
+
274
+ self.t2s_transformer = T2STransformer(self.num_layers, blocks)
275
+
276
  def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
277
  x = self.ar_text_embedding(x)
278
  x = x + self.bert_proj(bert_feature.transpose(1, 2))
 
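Reading of the block construction above (an observation, not stated in the diff): each T2SBlock stores references to the Parameter tensors of the corresponding layer in self.h rather than copies, so no weights are duplicated and a checkpoint loaded into self.h is automatically what the scripted fast path sees. A small hypothetical check:

import torch

def check_weight_sharing(decoder):
    # decoder: a Text2SemanticDecoder instance after __init__
    blk = decoder.t2s_transformer.blocks[0]
    layer = decoder.h.layers[0]
    shared = blk.qkv_w.data_ptr() == layer.self_attn.in_proj_weight.data_ptr()
    print("block 0 shares in_proj weights with self.h:", shared)
    return shared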
300
  (0, y_len),
301
  value=True,
302
  )
303
+ # x_attn_mask[:, x_len]=False
304
  y_attn_mask = F.pad(
305
  torch.triu(
306
  torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
 
396
  (0, y_len),
397
  value=True,
398
  )
399
+ # x_attn_mask[:, x_len]=False
400
  y_attn_mask = F.pad(
401
  torch.triu(
402
  torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
 
462
  value=True,
463
  )
464
  y_attn_mask = F.pad(
465
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),
466
  (x_len, 0),
467
  value=False,
468
  )
 
506
  # 错位
507
  return targets[:, :-1], targets[:, 1:]
508
 
509
+ def infer_panel_batch_infer_with_flash_attn(
510
  self,
511
+ x:torch.LongTensor, ##### all text tokens
512
+ x_lens:torch.LongTensor,
513
+ prompts:torch.LongTensor, #### reference audio tokens
514
+ bert_feature:torch.LongTensor,
515
  top_k: int = -100,
516
  top_p: int = 100,
517
  early_stop_num: int = -1,
518
  temperature: float = 1.0,
519
+ repetition_penalty: float = 1.35,
520
+ **kwargs,
521
+ ):
522
+ # # fp16 affects the results (compared with no padding)
523
+ # bert_feature_dtype = bert_feature[0].dtype
524
+ # if not hasattr(self.bert_proj, "dtype"):
525
+ # self.bert_proj.dtype = torch.float32
526
+ # self.bert_proj=self.bert_proj.float()
527
+
528
+ ## embed the phones and project the bert_features first, then pad to the same length (the padding strategy affects what the T2S model generates.)
529
+ ## applying the Linear after padding introduces an error (compared with no padding), which is absurd...
530
+ max_len = kwargs.get("max_len",x_lens.max())
531
+ # for x_item, bert_item in zip(x, bert_feature):
532
+ # max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
533
+ x_list = [self.ar_text_embedding(item) for item in x]
534
+ x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
535
+ x = torch.stack(x_list, dim=0)
536
+
537
+ bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
538
+ bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
539
+ bert_feature = torch.stack(bert_features_list, dim=0)
540
+
541
+
542
+ # bert_feature = self.bert_proj(bert_feature.transpose(1, 2).float()).to(dtype=bert_feature_dtype)
543
+ # x = self.ar_text_embedding(x)
544
+ x = x + bert_feature
545
+ x = self.ar_text_position(x)
546
+
547
+ # AR Decoder
548
+ y = prompts
549
+
550
+ x_len = x.shape[1]
551
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
552
+ stop = False
553
+
554
+ k_cache = None
555
+ v_cache = None
556
+ ################### first step ##########################
557
+ if y is not None:
558
+ y_emb = self.ar_audio_embedding(y)
559
+ y_len = y_emb.shape[1]
560
+ prefix_len = y.shape[1]
561
+ y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
562
+ y_pos = self.ar_audio_position(y_emb)
563
+ xy_pos = torch.concat([x, y_pos], dim=1)
564
+ ref_free = False
565
+ else:
566
+ y_emb = None
567
+ y_len = 0
568
+ prefix_len = 0
569
+ y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
570
+ y_pos = None
571
+ xy_pos = x
572
+ y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
573
+ ref_free = True
574
+
575
+
576
+ ##### create mask #####
577
+ bsz = x.shape[0]
578
+ src_len = x_len + y_len
579
+ y_paddind_mask = make_pad_mask(y_lens, y_len)
580
+ x_paddind_mask = make_pad_mask(x_lens, max_len)
581
+
582
+ # (bsz, x_len + y_len)
583
+ xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
584
+
585
+ x_mask = F.pad(
586
+ x_attn_mask,
587
+ (0, y_len), ### extend the all-0 xx block with an all-1 xy block, (x, x+y)
588
+ value=True,
589
+ )
590
+ y_mask = F.pad( ### extend the upper-right 1s of yy with 0s for the x part on the left, (y, x+y)
591
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
592
+ (x_len, 0),
593
+ value=False,
594
+ )
595
+
596
+ xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
597
+ # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
598
+ _xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
599
+ xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
600
+ xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
601
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
602
+ xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
603
+
604
+ xy_padding_mask = ~xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
605
+ xy_padding_mask = xy_padding_mask.to(dtype=x.dtype)
606
+
607
+ ###### decode #####
608
+ y_list = [None]*y.shape[0]
609
+ batch_idx_map = list(range(y.shape[0]))
610
+ idx_list = [None]*y.shape[0]
611
+ for idx in tqdm(range(1500)):
612
+ if idx == 0:
613
+ xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask)
614
+ else:
615
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
616
+
617
+ logits = self.ar_predict_layer(
618
+ xy_dec[:, -1]
619
+ )
620
+
621
+ if idx == 0:
622
+ xy_attn_mask = None
623
+ logits = logits[:, :-1]
624
+
625
+ samples = sample(
626
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
627
+ )[0]
628
+
629
+ y = torch.concat([y, samples], dim=1)
630
+
631
+ ####### remove finished sequences from the batch to further cut the compute
632
+ reserved_idx_of_batch_for_y = None
633
+ if (self.EOS in samples[:, 0]) or \
634
+ (self.EOS in torch.argmax(logits, dim=-1)): ### stop once EOS is generated
635
+ l = samples[:, 0]==self.EOS
636
+ removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
637
+ reserved_idx_of_batch_for_y = torch.where(l==False)[0]
638
+ # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
639
+ for i in removed_idx_of_batch_for_y:
640
+ batch_index = batch_idx_map[i]
641
+ idx_list[batch_index] = idx - 1
642
+ y_list[batch_index] = y[i, :-1]
643
+
644
+ batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
645
+
646
+ # keep only the sequences in the batch that have not finished generating
647
+ if reserved_idx_of_batch_for_y is not None:
648
+ # index = torch.LongTensor(batch_idx_map).to(y.device)
649
+ y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
650
+ if k_cache is not None :
651
+ for i in range(len(k_cache)):
652
+ k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
653
+ v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
654
+
655
+
656
+ if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
657
+ print("use early stop num:", early_stop_num)
658
+ stop = True
659
+ for i, batch_index in enumerate(batch_idx_map):
660
+ batch_index = batch_idx_map[i]
661
+ idx_list[batch_index] = idx
662
+ y_list[batch_index] = y[i, :-1]
663
+
664
+ if not (None in idx_list):
665
+ stop = True
666
+
667
+ if stop:
668
+ if y.shape[1]==0:
669
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
670
+ print("bad zero prediction")
671
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
672
+ break
673
+
674
+ ####################### update next step ###################################
675
+ y_emb = self.ar_audio_embedding(y[:, -1:])
676
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
677
+
678
+ if (None in idx_list):
679
+ for i in range(x.shape[0]):
680
+ if idx_list[i] is None:
681
+ idx_list[i] = 1500-1 ### if EOS was never generated, fall back to the maximum length
682
+
683
+ if ref_free:
684
+ return y_list, [0]*x.shape[0]
685
+ return y_list, idx_list
686
+
687
+ def infer_panel_0307(self,
688
+ x:List[torch.LongTensor], ##### all text tokens
689
+ x_lens:torch.LongTensor,
690
+ prompts:torch.LongTensor, #### reference audio tokens
691
+ bert_feature:torch.LongTensor,
692
+ top_k: int = -100,
693
+ top_p: int = 100,
694
+ early_stop_num: int = -1,
695
+ temperature: float = 1.0,
696
+ repetition_penalty: float = 1.35,
697
+ **kwargs
698
+ ):
699
+ y_list = []
700
+ idx_list = []
701
+ for i in range(len(x)):
702
+ y, idx = self.infer_panel_with_flash_attn_only(x[i].unsqueeze(0),
703
+ x_lens[i],
704
+ prompts[i].unsqueeze(0),
705
+ bert_feature[i].unsqueeze(0),
706
+ top_k,
707
+ top_p,
708
+ early_stop_num,
709
+ temperature,
710
+ repetition_penalty,
711
+ **kwargs)
712
+ y_list.append(y[0])
713
+ idx_list.append(idx)
714
+
715
+ return y_list, idx_list
716
+
717
+ def infer_panel_with_flash_attn_only(
718
+ self,
719
+ x:torch.LongTensor, ##### all text tokens
720
+ x_lens:torch.LongTensor,
721
+ prompts:torch.LongTensor, #### reference audio tokens
722
+ bert_feature:torch.LongTensor,
723
+ top_k: int = -100,
724
+ top_p: int = 100,
725
+ early_stop_num: int = -1,
726
+ temperature: float = 1.0,
727
+ repetition_penalty: float = 1.35,
728
+ **kwargs
729
  ):
730
  x = self.ar_text_embedding(x)
731
  x = x + self.bert_proj(bert_feature.transpose(1, 2))
 
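The batched decoder above keeps three parallel structures: y_list and idx_list, indexed by the caller's original batch positions, and batch_idx_map, which maps rows of the shrinking live batch back to those positions. When a row emits EOS, its result is written back and the row is dropped from y and from every layer's K/V cache with index_select. A standalone sketch of that bookkeeping, simplified from the loop above (hypothetical helper name):

import torch

def drop_finished_rows(y, k_cache, v_cache, eos_id, batch_idx_map, y_list, idx_list, step):
    finished = (y[:, -1] == eos_id)                  # rows whose newest token is EOS
    keep = torch.where(~finished)[0]
    for i in torch.where(finished)[0].tolist():
        orig = batch_idx_map[i]                      # original position in the caller's batch
        y_list[orig] = y[i, :-1]                     # store the row without the trailing EOS
        idx_list[orig] = step
    batch_idx_map = [batch_idx_map[i] for i in keep.tolist()]
    y = y.index_select(0, keep)
    if k_cache is not None:
        k_cache = [k.index_select(0, keep) for k in k_cache]
        v_cache = [v.index_select(0, keep) for v in v_cache]
    return y, k_cache, v_cache, batch_idx_map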
738
  x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
739
  stop = False
740
  # print(1111111,self.num_layers)
741
+
742
+ k_cache = None
743
+ v_cache = None
744
  ################### first step ##########################
745
  if y is not None:
746
  y_emb = self.ar_audio_embedding(y)
 
748
  prefix_len = y.shape[1]
749
  y_pos = self.ar_audio_position(y_emb)
750
  xy_pos = torch.concat([x, y_pos], dim=1)
 
751
  ref_free = False
752
  else:
753
  y_emb = None
 
758
  y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
759
  ref_free = True
760
 
761
+ bsz = x.shape[0]
762
+ src_len = x_len + y_len
763
  x_attn_mask_pad = F.pad(
764
+ x_attn_mask,
765
+ (0, y_len), ### extend the all-0 xx block with an all-1 xy block, (x, x+y)
766
+ value=True,
767
+ )
768
  y_attn_mask = F.pad( ### extend the upper-right 1s of yy with 0s for the x part on the left, (y, x+y)
769
  torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
770
  (x_len, 0),
771
  value=False,
772
  )
773
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).unsqueeze(0).expand(bsz*self.num_head, -1, -1).view(bsz, self.num_head, src_len, src_len).to(x.device)
774
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
775
+ xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
776
  for idx in tqdm(range(1500)):
777
+ if xy_attn_mask is not None:
778
+ xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
779
+ else:
780
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
781
+
782
  logits = self.ar_predict_layer(
783
  xy_dec[:, -1]
784
+ )
785
+
786
+ if idx == 0:
787
+ xy_attn_mask = None
788
+ logits = logits[:, :-1]
789
+
790
  samples = sample(
791
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
792
+ )[0]
793
+
794
+ y = torch.concat([y, samples], dim=1)
 
795
 
796
  if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
797
  print("use early stop num:", early_stop_num)
798
  stop = True
799
 
800
  if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
 
801
  stop = True
802
  if stop:
803
  if y.shape[1]==0:
804
  y = torch.concat([y, torch.zeros_like(samples)], dim=1)
805
  print("bad zero prediction")
806
  print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
807
  break
808
+
809
  ####################### update next step ###################################
810
+ y_emb = self.ar_audio_embedding(y[:, -1:])
811
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
812
+
813
  if ref_free:
814
  return y[:, :-1], 0
815
+ return y[:, :-1], idx - 1
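One detail of the rewritten single-sample loop above: instead of re-running self.ar_audio_position over the whole history at every step, the new code positions only the newest frame by indexing the precomputed sine table, i.e. y_emb * x_scale + alpha * pe[:, y_len + idx]. A sketch of that update in isolation; it assumes SinePositionalEmbedding applies x * x_scale + alpha * pe[:, :T], which is the form the expression in the diff relies on.

import torch

def position_single_frame(pos_emb, y_emb, abs_index):
    # pos_emb: the ar_audio_position module; y_emb: (B, 1, D) embedding of the newest token;
    # abs_index: absolute position of that token in the audio sequence (y_len + idx above).
    pe_slice = pos_emb.pe[:, abs_index].to(dtype=y_emb.dtype, device=y_emb.device)
    return y_emb * pos_emb.x_scale + pos_emb.alpha * pe_slice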
GPT_SoVITS/AR/models/t2s_model_batch_only.py ADDED
@@ -0,0 +1,483 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from AR.models.utils import make_pad_mask
6
+ from AR.models.utils import (
7
+ topk_sampling,
8
+ sample,
9
+ logits_to_probs,
10
+ multinomial_sample_one_no_sync,
11
+ dpo_loss,
12
+ make_reject_y,
13
+ get_batch_logps
14
+ )
15
+ from AR.modules.embedding import SinePositionalEmbedding
16
+ from AR.modules.embedding import TokenEmbedding
17
+ from AR.modules.transformer import LayerNorm
18
+ from AR.modules.transformer import TransformerEncoder
19
+ from AR.modules.transformer import TransformerEncoderLayer
20
+ from torch import nn
21
+ from torch.nn import functional as F
22
+ from torchmetrics.classification import MulticlassAccuracy
23
+
24
+ default_config = {
25
+ "embedding_dim": 512,
26
+ "hidden_dim": 512,
27
+ "num_head": 8,
28
+ "num_layers": 12,
29
+ "num_codebook": 8,
30
+ "p_dropout": 0.0,
31
+ "vocab_size": 1024 + 1,
32
+ "phoneme_vocab_size": 512,
33
+ "EOS": 1024,
34
+ }
35
+
36
+
37
+ class Text2SemanticDecoder(nn.Module):
38
+ def __init__(self, config, norm_first=False, top_k=3):
39
+ super(Text2SemanticDecoder, self).__init__()
40
+ self.model_dim = config["model"]["hidden_dim"]
41
+ self.embedding_dim = config["model"]["embedding_dim"]
42
+ self.num_head = config["model"]["head"]
43
+ self.num_layers = config["model"]["n_layer"]
44
+ self.norm_first = norm_first
45
+ self.vocab_size = config["model"]["vocab_size"]
46
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
47
+ self.p_dropout = config["model"]["dropout"]
48
+ self.EOS = config["model"]["EOS"]
49
+ self.norm_first = norm_first
50
+ assert self.EOS == self.vocab_size - 1
51
+ # should be same as num of kmeans bin
52
+ # assert self.EOS == 1024
53
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
54
+ self.ar_text_embedding = TokenEmbedding(
55
+ self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
56
+ )
57
+ self.ar_text_position = SinePositionalEmbedding(
58
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True
59
+ )
60
+ self.ar_audio_embedding = TokenEmbedding(
61
+ self.embedding_dim, self.vocab_size, self.p_dropout
62
+ )
63
+ self.ar_audio_position = SinePositionalEmbedding(
64
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True
65
+ )
66
+
67
+ self.h = TransformerEncoder(
68
+ TransformerEncoderLayer(
69
+ d_model=self.model_dim,
70
+ nhead=self.num_head,
71
+ dim_feedforward=self.model_dim * 4,
72
+ dropout=0.1,
73
+ batch_first=True,
74
+ norm_first=norm_first,
75
+ ),
76
+ num_layers=self.num_layers,
77
+ norm=LayerNorm(self.model_dim) if norm_first else None,
78
+ )
79
+
80
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
81
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
82
+
83
+ self.ar_accuracy_metric = MulticlassAccuracy(
84
+ self.vocab_size,
85
+ top_k=top_k,
86
+ average="micro",
87
+ multidim_average="global",
88
+ ignore_index=self.EOS,
89
+ )
90
+
91
+ def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
92
+ x = self.ar_text_embedding(x)
93
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
94
+ x = self.ar_text_position(x)
95
+ x_mask = make_pad_mask(x_lens)
96
+
97
+ y_mask = make_pad_mask(y_lens)
98
+ y_mask_int = y_mask.type(torch.int64)
99
+ codes = y.type(torch.int64) * (1 - y_mask_int)
100
+
101
+ # Training
102
+ # AR Decoder
103
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
104
+ x_len = x_lens.max()
105
+ y_len = y_lens.max()
106
+ y_emb = self.ar_audio_embedding(y)
107
+ y_pos = self.ar_audio_position(y_emb)
108
+
109
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
110
+
111
+ ar_xy_padding_mask = xy_padding_mask
112
+
113
+ x_attn_mask = F.pad(
114
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
115
+ (0, y_len),
116
+ value=True,
117
+ )
118
+
119
+ y_attn_mask = F.pad(
120
+ torch.triu(
121
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
122
+ diagonal=1,
123
+ ),
124
+ (x_len, 0),
125
+ value=False,
126
+ )
127
+
128
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
129
+ bsz, src_len = x.shape[0], x_len + y_len
130
+ _xy_padding_mask = (
131
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
132
+ .expand(-1, self.num_head, -1, -1)
133
+ .reshape(bsz * self.num_head, 1, src_len)
134
+ )
135
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
136
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
137
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
138
+ xy_attn_mask = new_attn_mask
139
+ # x and the complete y are fed to the model in a single pass
140
+ xy_pos = torch.concat([x, y_pos], dim=1)
141
+
142
+ return xy_pos, xy_attn_mask, targets
143
+
144
+ def forward(self, x, x_lens, y, y_lens, bert_feature):
145
+ """
146
+ x: phoneme_ids
147
+ y: semantic_ids
148
+ """
149
+
150
+ reject_y, reject_y_lens = make_reject_y(y, y_lens)
151
+
152
+ xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
153
+
154
+ xy_dec, _ = self.h(
155
+ (xy_pos, None),
156
+ mask=xy_attn_mask,
157
+ )
158
+ x_len = x_lens.max()
159
+ logits = self.ar_predict_layer(xy_dec[:, x_len:])
160
+
161
+ ###### DPO #############
162
+ reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
163
+
164
+ reject_xy_dec, _ = self.h(
165
+ (reject_xy_pos, None),
166
+ mask=reject_xy_attn_mask,
167
+ )
168
+ x_len = x_lens.max()
169
+ reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
170
+
171
+ # loss
172
+ # from feiteng: the longer the duration, the larger the gradient update should be, hence sum
173
+
174
+ loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
175
+ acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
176
+
177
+ A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
178
+ loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
179
+
180
+ loss = loss_1 + loss_2
181
+
182
+ return loss, acc
183
+
184
+ def forward_old(self, x, x_lens, y, y_lens, bert_feature):
185
+ """
186
+ x: phoneme_ids
187
+ y: semantic_ids
188
+ """
189
+ x = self.ar_text_embedding(x)
190
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
191
+ x = self.ar_text_position(x)
192
+ x_mask = make_pad_mask(x_lens)
193
+
194
+ y_mask = make_pad_mask(y_lens)
195
+ y_mask_int = y_mask.type(torch.int64)
196
+ codes = y.type(torch.int64) * (1 - y_mask_int)
197
+
198
+ # Training
199
+ # AR Decoder
200
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
201
+ x_len = x_lens.max()
202
+ y_len = y_lens.max()
203
+ y_emb = self.ar_audio_embedding(y)
204
+ y_pos = self.ar_audio_position(y_emb)
205
+
206
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
207
+ ar_xy_padding_mask = xy_padding_mask
208
+
209
+ x_attn_mask = F.pad(
210
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
211
+ (0, y_len),
212
+ value=True,
213
+ )
214
+ y_attn_mask = F.pad(
215
+ torch.triu(
216
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
217
+ diagonal=1,
218
+ ),
219
+ (x_len, 0),
220
+ value=False,
221
+ )
222
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
223
+ bsz, src_len = x.shape[0], x_len + y_len
224
+ _xy_padding_mask = (
225
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
226
+ .expand(-1, self.num_head, -1, -1)
227
+ .reshape(bsz * self.num_head, 1, src_len)
228
+ )
229
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
230
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
231
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
232
+ xy_attn_mask = new_attn_mask
233
+ # x and the complete y are fed to the model in a single pass
234
+ xy_pos = torch.concat([x, y_pos], dim=1)
235
+ xy_dec, _ = self.h(
236
+ (xy_pos, None),
237
+ mask=xy_attn_mask,
238
+ )
239
+ logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
240
+ # loss
241
+ # from feiteng: the longer the duration, the larger the gradient update should be, hence sum
242
+ loss = F.cross_entropy(logits, targets, reduction="sum")
243
+ acc = self.ar_accuracy_metric(logits.detach(), targets).item()
244
+ return loss, acc
245
+
246
+ # need to check how this function differs from forward, and what to feed as prompts when there is no semantic
247
+ def infer(
248
+ self,
249
+ x,
250
+ x_lens,
251
+ prompts,
252
+ bert_feature,
253
+ top_k: int = -100,
254
+ early_stop_num: int = -1,
255
+ temperature: float = 1.0,
256
+ ):
257
+ x = self.ar_text_embedding(x)
258
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
259
+ x = self.ar_text_position(x)
260
+
261
+ # AR Decoder
262
+ y = prompts
263
+ prefix_len = y.shape[1]
264
+ x_len = x.shape[1]
265
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
266
+ stop = False
267
+ for _ in tqdm(range(1500)):
268
+ y_emb = self.ar_audio_embedding(y)
269
+ y_pos = self.ar_audio_position(y_emb)
270
+ # x is fed to the model together with the gradually growing y
271
+ xy_pos = torch.concat([x, y_pos], dim=1)
272
+ y_len = y.shape[1]
273
+ x_attn_mask_pad = F.pad(
274
+ x_attn_mask,
275
+ (0, y_len),
276
+ value=True,
277
+ )
278
+ y_attn_mask = F.pad(
279
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
280
+ (x_len, 0),
281
+ value=False,
282
+ )
283
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
284
+ y.device
285
+ )
286
+
287
+ xy_dec, _ = self.h(
288
+ (xy_pos, None),
289
+ mask=xy_attn_mask,
290
+ )
291
+ logits = self.ar_predict_layer(xy_dec[:, -1])
292
+ samples = topk_sampling(
293
+ logits, top_k=top_k, top_p=1.0, temperature=temperature
294
+ )
295
+
296
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
297
+ print("use early stop num:", early_stop_num)
298
+ stop = True
299
+
300
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
301
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
302
+ stop = True
303
+ if stop:
304
+ if prompts.shape[1] == y.shape[1]:
305
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
306
+ print("bad zero prediction")
307
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
308
+ break
309
+ # the semantic_ids generated this step and the previous y form the new y
310
+ # print(samples.shape)#[1,1]# the first 1 is the batch size
311
+ # import os
312
+ # os._exit(2333)
313
+ y = torch.concat([y, samples], dim=1)
314
+ return y
315
+
316
+ def pad_y_eos(self, y, y_mask_int, eos_id):
317
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
318
+ y_mask_int, (0, 1), value=1
319
+ )
320
+ # shift targets by one position
321
+ return targets[:, :-1], targets[:, 1:]
322
+
323
+ def infer_panel(
324
+ self,
325
+ x, ##### all text tokens
326
+ x_lens,
327
+ prompts, #### reference audio tokens
328
+ bert_feature,
329
+ top_k: int = -100,
330
+ top_p: int = 100,
331
+ early_stop_num: int = -1,
332
+ temperature: float = 1.0,
333
+ ):
334
+ x = self.ar_text_embedding(x)
335
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
336
+ x = self.ar_text_position(x)
337
+
338
+ # AR Decoder
339
+ y = prompts
340
+
341
+ x_len = x.shape[1]
342
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
343
+ stop = False
344
+ # print(1111111,self.num_layers)
345
+ cache = {
346
+ "all_stage": self.num_layers,
347
+ "k": [None] * self.num_layers, ### hand-built per layer according to the config
348
+ "v": [None] * self.num_layers,
349
+ # "xy_pos":None,## the y_pos positional encoding differs every step, so it cannot be cached and xy_pos must be rebuilt each time. Mostly a matter of how it is written; the history could be kept consistent, but the cost is negligible, so leave it.
350
+ "y_emb": None, ## only the newest samples need an embedding; concatenate it onto the history
351
+ # "logits":None,### the original already computes only the tail and concatenates, nothing to change
352
+ # "xy_dec":None,### not needed; only the last position is used for the logits anyway
353
+ "first_infer": 1,
354
+ "stage": 0,
355
+ }
356
+ ################### first step ##########################
357
+ if y is not None:
358
+ y_emb = self.ar_audio_embedding(y)
359
+ y_len = y_emb.shape[1]
360
+ prefix_len = y.shape[1]
361
+ y_pos = self.ar_audio_position(y_emb)
362
+ xy_pos = torch.concat([x, y_pos], dim=1)
363
+ cache["y_emb"] = y_emb
364
+ ref_free = False
365
+ else:
366
+ y_emb = None
367
+ y_len = 0
368
+ prefix_len = 0
369
+ y_pos = None
370
+ xy_pos = x
371
+ y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
372
+ ref_free = True
373
+
374
+ x_attn_mask_pad = F.pad(
375
+ x_attn_mask,
376
+ (0, y_len), ### extend the all-0 xx block with an all-1 xy block, (x, x+y)
377
+ value=True,
378
+ )
379
+ y_attn_mask = F.pad( ### extend the upper-right 1s of yy with 0s for the x part on the left, (y, x+y)
380
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
381
+ (x_len, 0),
382
+ value=False,
383
+ )
384
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
385
+ x.device
386
+ )
387
+
388
+ y_list = [None]*y.shape[0]
389
+ batch_idx_map = list(range(y.shape[0]))
390
+ idx_list = [None]*y.shape[0]
391
+ for idx in tqdm(range(1500)):
392
+
393
+ xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
394
+ logits = self.ar_predict_layer(
395
+ xy_dec[:, -1]
396
+ ) ## no change needed: with the cache there is only one frame by default, same as taking the last frame
397
+ # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
398
+ if(idx==0):### the first step must not produce EOS, otherwise nothing is generated
399
+ logits = logits[:, :-1] ### strip the probability of the 1024 end-of-sequence symbol
400
+ samples = sample(
401
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
402
+ )[0]
403
+ # the semantic_ids generated this step and the previous y form the new y
404
+ # print(samples.shape)#[1,1]# the first 1 is the batch size
405
+ y = torch.concat([y, samples], dim=1)
406
+
407
+ # remove sequences that have finished generating
408
+ reserved_idx_of_batch_for_y = None
409
+ if (self.EOS in torch.argmax(logits, dim=-1)) or \
410
+ (self.EOS in samples[:, 0]): ### stop once EOS is generated
411
+ l = samples[:, 0]==self.EOS
412
+ removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
413
+ reserved_idx_of_batch_for_y = torch.where(l==False)[0]
414
+ # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
415
+ for i in removed_idx_of_batch_for_y:
416
+ batch_index = batch_idx_map[i]
417
+ idx_list[batch_index] = idx - 1
418
+ y_list[batch_index] = y[i, :-1]
419
+
420
+ batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
421
+
422
+ # keep only the sequences that have not finished generating
423
+ if reserved_idx_of_batch_for_y is not None:
424
+ # index = torch.LongTensor(batch_idx_map).to(y.device)
425
+ y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
426
+ if cache["y_emb"] is not None:
427
+ cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
428
+ if cache["k"] is not None:
429
+ for i in range(self.num_layers):
430
+ # because k/v were transposed, the batch dim is 1
431
+ cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
432
+ cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
433
+
434
+
435
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
436
+ print("use early stop num:", early_stop_num)
437
+ stop = True
438
+
439
+ if not (None in idx_list):
440
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
441
+ stop = True
442
+ if stop:
443
+ # if prompts.shape[1] == y.shape[1]:
444
+ # y = torch.concat([y, torch.zeros_like(samples)], dim=1)
445
+ # print("bad zero prediction")
446
+ if y.shape[1]==0:
447
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
448
+ print("bad zero prediction")
449
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
450
+ break
451
+
452
+ ####################### update next step ###################################
453
+ cache["first_infer"] = 0
454
+ if cache["y_emb"] is not None:
455
+ y_emb = torch.cat(
456
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
457
+ )
458
+ cache["y_emb"] = y_emb
459
+ y_pos = self.ar_audio_position(y_emb)
460
+ xy_pos = y_pos[:, -1:]
461
+ else:
462
+ y_emb = self.ar_audio_embedding(y[:, -1:])
463
+ cache["y_emb"] = y_emb
464
+ y_pos = self.ar_audio_position(y_emb)
465
+ xy_pos = y_pos
466
+ y_len = y_pos.shape[1]
467
+
468
+ ### rightmost column (wrong)
469
+ # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
470
+ # xy_attn_mask[:,-1]=False
471
+ ### bottom row (correct)
472
+ xy_attn_mask = torch.zeros(
473
+ (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
474
+ )
475
+
476
+ if (None in idx_list):
477
+ for i in range(x.shape[0]):
478
+ if idx_list[i] is None:
479
+ idx_list[i] = 1500-1 ### if EOS was never generated, fall back to the maximum length
480
+
481
+ if ref_free:
482
+ return y_list, [0]*x.shape[0]
483
+ return y_list, idx_list
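t2s_model_batch_only.py keeps the older dict-based cache ("k"/"v"/"y_emb"/"first_infer") while adding the same finished-row pruning as the new path in t2s_model.py. One difference to keep in mind: the legacy cache stores K/V time-major (the comment above notes the batch dim is 1 because of a transpose in the projection), whereas the new T2SBlock caches are batch-first. A contrast sketch under those layout assumptions (hypothetical helpers):

import torch

def prune_legacy_cache(cache, keep_rows):
    # dict cache from this file: K/V are time-major, so the batch sits on dim 1
    for i in range(cache["all_stage"]):
        cache["k"][i] = cache["k"][i].index_select(1, keep_rows)
        cache["v"][i] = cache["v"][i].index_select(1, keep_rows)
    if cache["y_emb"] is not None:
        cache["y_emb"] = cache["y_emb"].index_select(0, keep_rows)
    return cache

def prune_block_cache(k_cache, v_cache, keep_rows):
    # T2SBlock cache in t2s_model.py: plain (B, T, D) tensors, batch on dim 0
    k_cache = [k.index_select(0, keep_rows) for k in k_cache]
    v_cache = [v.index_select(0, keep_rows) for v in v_cache]
    return k_cache, v_cache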
GPT_SoVITS/AR/models/utils.py CHANGED
@@ -115,17 +115,17 @@ def logits_to_probs(
115
  top_p: Optional[int] = None,
116
  repetition_penalty: float = 1.0,
117
  ):
118
- if previous_tokens is not None:
119
- previous_tokens = previous_tokens.squeeze()
120
  # print(logits.shape,previous_tokens.shape)
121
  # pdb.set_trace()
122
  if previous_tokens is not None and repetition_penalty != 1.0:
123
  previous_tokens = previous_tokens.long()
124
- score = torch.gather(logits, dim=0, index=previous_tokens)
125
  score = torch.where(
126
  score < 0, score * repetition_penalty, score / repetition_penalty
127
  )
128
- logits.scatter_(dim=0, index=previous_tokens, src=score)
129
 
130
  if top_p is not None and top_p < 1.0:
131
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -133,9 +133,9 @@ def logits_to_probs(
133
  torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
134
  )
135
  sorted_indices_to_remove = cum_probs > top_p
136
- sorted_indices_to_remove[0] = False # keep at least one option
137
  indices_to_remove = sorted_indices_to_remove.scatter(
138
- dim=0, index=sorted_indices, src=sorted_indices_to_remove
139
  )
140
  logits = logits.masked_fill(indices_to_remove, -float("Inf"))
141
 
@@ -143,7 +143,7 @@ def logits_to_probs(
143
 
144
  if top_k is not None:
145
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
146
- pivot = v.select(-1, -1).unsqueeze(-1)
147
  logits = torch.where(logits < pivot, -float("Inf"), logits)
148
 
149
  probs = torch.nn.functional.softmax(logits, dim=-1)
 
115
  top_p: Optional[int] = None,
116
  repetition_penalty: float = 1.0,
117
  ):
118
+ # if previous_tokens is not None:
119
+ # previous_tokens = previous_tokens.squeeze()
120
  # print(logits.shape,previous_tokens.shape)
121
  # pdb.set_trace()
122
  if previous_tokens is not None and repetition_penalty != 1.0:
123
  previous_tokens = previous_tokens.long()
124
+ score = torch.gather(logits, dim=1, index=previous_tokens)
125
  score = torch.where(
126
  score < 0, score * repetition_penalty, score / repetition_penalty
127
  )
128
+ logits.scatter_(dim=1, index=previous_tokens, src=score)
129
 
130
  if top_p is not None and top_p < 1.0:
131
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
 
133
  torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
134
  )
135
  sorted_indices_to_remove = cum_probs > top_p
136
+ sorted_indices_to_remove[:, 0] = False # keep at least one option
137
  indices_to_remove = sorted_indices_to_remove.scatter(
138
+ dim=1, index=sorted_indices, src=sorted_indices_to_remove
139
  )
140
  logits = logits.masked_fill(indices_to_remove, -float("Inf"))
141
 
 
143
 
144
  if top_k is not None:
145
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
146
+ pivot = v[: , -1].unsqueeze(-1)
147
  logits = torch.where(logits < pivot, -float("Inf"), logits)
148
 
149
  probs = torch.nn.functional.softmax(logits, dim=-1)
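The utils.py change switches logits_to_probs from a single 1-D logits vector (gather/scatter on dim=0, pivot via v.select(-1, -1)) to batched 2-D logits, with the repetition penalty, top-p and top-k filters all acting on dim=1. A self-contained sketch of the batched filtering the edited function now assumes; logits is (B, V), previous_tokens is (B, T), and the default values here are only illustrative.

import torch

def batched_filters(logits, previous_tokens, repetition_penalty=1.35, top_p=0.95, top_k=20):
    if previous_tokens is not None and repetition_penalty != 1.0:
        prev = previous_tokens.long()
        score = torch.gather(logits, dim=1, index=prev)
        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
        logits = logits.scatter(dim=1, index=prev, src=score)
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=1)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[:, 0] = False                         # keep at least one option per row
        logits = logits.masked_fill(remove.scatter(1, sorted_idx, remove), -float("inf"))
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
        logits = torch.where(logits < v[:, -1].unsqueeze(-1), -float("inf"), logits)
    return torch.softmax(logits, dim=-1)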
GPT_SoVITS/AR/modules/patched_mha_with_cache.py CHANGED
@@ -1,465 +1,465 @@
1
- from torch.nn.functional import *
2
- from torch.nn.functional import (
3
- _mha_shape_check,
4
- _canonical_mask,
5
- _none_or_dtype,
6
- _in_projection_packed,
7
- )
8
- from torch.nn import functional as F
9
- import torch
10
- # Tensor = torch.Tensor
11
- # from typing import Callable, List, Optional, Tuple, Union
12
-
13
-
14
- def multi_head_attention_forward_patched(
15
- query: Tensor,
16
- key: Tensor,
17
- value: Tensor,
18
- embed_dim_to_check: int,
19
- num_heads: int,
20
- in_proj_weight: Optional[Tensor],
21
- in_proj_bias: Optional[Tensor],
22
- bias_k: Optional[Tensor],
23
- bias_v: Optional[Tensor],
24
- add_zero_attn: bool,
25
- dropout_p: float,
26
- out_proj_weight: Tensor,
27
- out_proj_bias: Optional[Tensor],
28
- training: bool = True,
29
- key_padding_mask: Optional[Tensor] = None,
30
- need_weights: bool = True,
31
- attn_mask: Optional[Tensor] = None,
32
- use_separate_proj_weight: bool = False,
33
- q_proj_weight: Optional[Tensor] = None,
34
- k_proj_weight: Optional[Tensor] = None,
35
- v_proj_weight: Optional[Tensor] = None,
36
- static_k: Optional[Tensor] = None,
37
- static_v: Optional[Tensor] = None,
38
- average_attn_weights: bool = True,
39
- is_causal: bool = False,
40
- cache=None,
41
- ) -> Tuple[Tensor, Optional[Tensor]]:
42
- r"""
43
- Args:
44
- query, key, value: map a query and a set of key-value pairs to an output.
45
- See "Attention Is All You Need" for more details.
46
- embed_dim_to_check: total dimension of the model.
47
- num_heads: parallel attention heads.
48
- in_proj_weight, in_proj_bias: input projection weight and bias.
49
- bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
50
- add_zero_attn: add a new batch of zeros to the key and
51
- value sequences at dim=1.
52
- dropout_p: probability of an element to be zeroed.
53
- out_proj_weight, out_proj_bias: the output projection weight and bias.
54
- training: apply dropout if is ``True``.
55
- key_padding_mask: if provided, specified padding elements in the key will
56
- be ignored by the attention. This is an binary mask. When the value is True,
57
- the corresponding value on the attention layer will be filled with -inf.
58
- need_weights: output attn_output_weights.
59
- Default: `True`
60
- Note: `needs_weight` defaults to `True`, but should be set to `False`
61
- For best performance when attention weights are not nedeeded.
62
- *Setting needs_weights to `True`
63
- leads to a significant performance degradation.*
64
- attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
65
- the batches while a 3D mask allows to specify a different mask for the entries of each batch.
66
- is_causal: If specified, applies a causal mask as attention mask, and ignores
67
- attn_mask for computing scaled dot product attention.
68
- Default: ``False``.
69
- .. warning::
70
- is_causal is provides a hint that the attn_mask is the
71
- causal mask.Providing incorrect hints can result in
72
- incorrect execution, including forward and backward
73
- compatibility.
74
- use_separate_proj_weight: the function accept the proj. weights for query, key,
75
- and value in different forms. If false, in_proj_weight will be used, which is
76
- a combination of q_proj_weight, k_proj_weight, v_proj_weight.
77
- q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
78
- static_k, static_v: static key and value used for attention operators.
79
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
80
- Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
81
- when ``need_weights=True.``. Default: True
82
-
83
-
84
- Shape:
85
- Inputs:
86
- - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
87
- the embedding dimension.
88
- - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
89
- the embedding dimension.
90
- - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
91
- the embedding dimension.
92
- - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
93
- If a FloatTensor is provided, it will be directly added to the value.
94
- If a BoolTensor is provided, the positions with the
95
- value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
96
- - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
97
- 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
98
- S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
99
- positions. If a BoolTensor is provided, positions with ``True``
100
- are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
101
- is provided, it will be added to the attention weight.
102
- - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
103
- N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
104
- - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
105
- N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
106
-
107
- Outputs:
108
- - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
109
- E is the embedding dimension.
110
- - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
111
- attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
112
- :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
113
- :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
114
- head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
115
- """
116
- tens_ops = (
117
- query,
118
- key,
119
- value,
120
- in_proj_weight,
121
- in_proj_bias,
122
- bias_k,
123
- bias_v,
124
- out_proj_weight,
125
- out_proj_bias,
126
- )
127
- if has_torch_function(tens_ops):
128
- return handle_torch_function(
129
- multi_head_attention_forward,
130
- tens_ops,
131
- query,
132
- key,
133
- value,
134
- embed_dim_to_check,
135
- num_heads,
136
- in_proj_weight,
137
- in_proj_bias,
138
- bias_k,
139
- bias_v,
140
- add_zero_attn,
141
- dropout_p,
142
- out_proj_weight,
143
- out_proj_bias,
144
- training=training,
145
- key_padding_mask=key_padding_mask,
146
- need_weights=need_weights,
147
- attn_mask=attn_mask,
148
- is_causal=is_causal,
149
- use_separate_proj_weight=use_separate_proj_weight,
150
- q_proj_weight=q_proj_weight,
151
- k_proj_weight=k_proj_weight,
152
- v_proj_weight=v_proj_weight,
153
- static_k=static_k,
154
- static_v=static_v,
155
- average_attn_weights=average_attn_weights,
156
- cache=cache,
157
- )
158
-
159
- is_batched = _mha_shape_check(
160
- query, key, value, key_padding_mask, attn_mask, num_heads
161
- )
162
-
163
- # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
164
- # is batched, run the computation and before returning squeeze the
165
- # batch dimension so that the output doesn't carry this temporary batch dimension.
166
- if not is_batched:
167
- # unsqueeze if the input is unbatched
168
- query = query.unsqueeze(1)
169
- key = key.unsqueeze(1)
170
- value = value.unsqueeze(1)
171
- if key_padding_mask is not None:
172
- key_padding_mask = key_padding_mask.unsqueeze(0)
173
-
174
- # set up shape vars
175
- tgt_len, bsz, embed_dim = query.shape
176
- src_len, _, _ = key.shape
177
-
178
- key_padding_mask = _canonical_mask(
179
- mask=key_padding_mask,
180
- mask_name="key_padding_mask",
181
- other_type=_none_or_dtype(attn_mask),
182
- other_name="attn_mask",
183
- target_type=query.dtype,
184
- )
185
-
186
- if is_causal and attn_mask is None:
187
- raise RuntimeError(
188
- "Need attn_mask if specifying the is_causal hint. "
189
- "You may use the Transformer module method "
190
- "`generate_square_subsequent_mask` to create this mask."
191
- )
192
-
193
- if is_causal and key_padding_mask is None and not need_weights:
194
- # when we have a kpm or need weights, we need attn_mask
195
- # Otherwise, we use the is_causal hint go as is_causal
196
- # indicator to SDPA.
197
- attn_mask = None
198
- else:
199
- attn_mask = _canonical_mask(
200
- mask=attn_mask,
201
- mask_name="attn_mask",
202
- other_type=None,
203
- other_name="",
204
- target_type=query.dtype,
205
- check_other=False,
206
- )
207
-
208
- if key_padding_mask is not None:
209
- # We have the attn_mask, and use that to merge kpm into it.
210
- # Turn off use of is_causal hint, as the merged mask is no
211
- # longer causal.
212
- is_causal = False
213
-
214
- assert (
215
- embed_dim == embed_dim_to_check
216
- ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
217
- if isinstance(embed_dim, torch.Tensor):
218
- # embed_dim can be a tensor when JIT tracing
219
- head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
220
- else:
221
- head_dim = embed_dim // num_heads
222
- assert (
223
- head_dim * num_heads == embed_dim
224
- ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
225
- if use_separate_proj_weight:
226
- # allow MHA to have different embedding dimensions when separate projection weights are used
227
- assert (
228
- key.shape[:2] == value.shape[:2]
229
- ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
230
- else:
231
- assert (
232
- key.shape == value.shape
233
- ), f"key shape {key.shape} does not match value shape {value.shape}"
234
-
235
- #
236
- # compute in-projection
237
- #
238
- if not use_separate_proj_weight:
239
- assert (
240
- in_proj_weight is not None
241
- ), "use_separate_proj_weight is False but in_proj_weight is None"
242
- q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
243
- else:
244
- assert (
245
- q_proj_weight is not None
246
- ), "use_separate_proj_weight is True but q_proj_weight is None"
247
- assert (
248
- k_proj_weight is not None
249
- ), "use_separate_proj_weight is True but k_proj_weight is None"
250
- assert (
251
- v_proj_weight is not None
252
- ), "use_separate_proj_weight is True but v_proj_weight is None"
253
- if in_proj_bias is None:
254
- b_q = b_k = b_v = None
255
- else:
256
- b_q, b_k, b_v = in_proj_bias.chunk(3)
257
- q, k, v = _in_projection(
258
- query,
259
- key,
260
- value,
261
- q_proj_weight,
262
- k_proj_weight,
263
- v_proj_weight,
264
- b_q,
265
- b_k,
266
- b_v,
267
- )
268
- if cache != None:
269
- if cache["first_infer"] == 1:
270
- cache["k"][cache["stage"]] = k
271
- # print(0,cache["k"].shape)
272
- cache["v"][cache["stage"]] = v
273
- else: ### each of the 12 layers keeps its own cached k/v
274
- # print(1,cache["k"].shape)
275
- cache["k"][cache["stage"]] = torch.cat(
276
- [cache["k"][cache["stage"]], k], 0
277
- ) ## the time axis used to be dim 1, but the projection may transpose, so time ends up on dim 0
278
- cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
279
- # print(2, cache["k"].shape)
280
- src_len = cache["k"][cache["stage"]].shape[0]
281
- k = cache["k"][cache["stage"]]
282
- v = cache["v"][cache["stage"]]
283
- # if attn_mask is not None:
284
- # attn_mask=attn_mask[-1:,]
285
- # print(attn_mask.shape,attn_mask)
286
- cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
287
- # print(2333,cache)
288
- # prep attention mask
289
-
290
- attn_mask = _canonical_mask(
291
- mask=attn_mask,
292
- mask_name="attn_mask",
293
- other_type=None,
294
- other_name="",
295
- target_type=q.dtype,
296
- check_other=False,
297
- )
298
-
299
- if attn_mask is not None:
300
- # ensure attn_mask's dim is 3
301
- if attn_mask.dim() == 2:
302
- correct_2d_size = (tgt_len, src_len)
303
- if attn_mask.shape != correct_2d_size:
304
- raise RuntimeError(
305
- f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
306
- )
307
- attn_mask = attn_mask.unsqueeze(0)
308
- elif attn_mask.dim() == 3:
309
- correct_3d_size = (bsz * num_heads, tgt_len, src_len)
310
- if attn_mask.shape != correct_3d_size:
311
- raise RuntimeError(
312
- f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
313
- )
314
- else:
315
- raise RuntimeError(
316
- f"attn_mask's dimension {attn_mask.dim()} is not supported"
317
- )
318
-
319
- # add bias along batch dimension (currently second)
320
- if bias_k is not None and bias_v is not None:
321
- assert static_k is None, "bias cannot be added to static key."
322
- assert static_v is None, "bias cannot be added to static value."
323
- k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
324
- v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
325
- if attn_mask is not None:
326
- attn_mask = pad(attn_mask, (0, 1))
327
- if key_padding_mask is not None:
328
- key_padding_mask = pad(key_padding_mask, (0, 1))
329
- else:
330
- assert bias_k is None
331
- assert bias_v is None
332
-
333
- #
334
- # reshape q, k, v for multihead attention and make em batch first
335
- #
336
- q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
337
- if static_k is None:
338
- k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
339
- else:
340
- # TODO finish disentangling control flow so we don't do in-projections when statics are passed
341
- assert (
342
- static_k.size(0) == bsz * num_heads
343
- ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
344
- assert (
345
- static_k.size(2) == head_dim
346
- ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
347
- k = static_k
348
- if static_v is None:
349
- v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
350
- else:
351
- # TODO finish disentangling control flow so we don't do in-projections when statics are passed
352
- assert (
353
- static_v.size(0) == bsz * num_heads
354
- ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
355
- assert (
356
- static_v.size(2) == head_dim
357
- ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
358
- v = static_v
359
-
360
- # add zero attention along batch dimension (now first)
361
- if add_zero_attn:
362
- zero_attn_shape = (bsz * num_heads, 1, head_dim)
363
- k = torch.cat(
364
- [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
365
- )
366
- v = torch.cat(
367
- [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
368
- )
369
- if attn_mask is not None:
370
- attn_mask = pad(attn_mask, (0, 1))
371
- if key_padding_mask is not None:
372
- key_padding_mask = pad(key_padding_mask, (0, 1))
373
-
374
- # update source sequence length after adjustments
375
- src_len = k.size(1)
376
-
377
- # merge key padding and attention masks
378
- if key_padding_mask is not None:
379
- assert key_padding_mask.shape == (
380
- bsz,
381
- src_len,
382
- ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
383
- key_padding_mask = (
384
- key_padding_mask.view(bsz, 1, 1, src_len)
385
- .expand(-1, num_heads, -1, -1)
386
- .reshape(bsz * num_heads, 1, src_len)
387
- )
388
- if attn_mask is None:
389
- attn_mask = key_padding_mask
390
- else:
391
- attn_mask = attn_mask + key_padding_mask
392
-
393
- # adjust dropout probability
394
- if not training:
395
- dropout_p = 0.0
396
-
397
- #
398
- # (deep breath) calculate attention and out projection
399
- #
400
-
401
- if need_weights:
402
- B, Nt, E = q.shape
403
- q_scaled = q / math.sqrt(E)
404
-
405
- assert not (
406
- is_causal and attn_mask is None
407
- ), "FIXME: is_causal not implemented for need_weights"
408
-
409
- if attn_mask is not None:
410
- attn_output_weights = torch.baddbmm(
411
- attn_mask, q_scaled, k.transpose(-2, -1)
412
- )
413
- else:
414
- attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
415
- attn_output_weights = softmax(attn_output_weights, dim=-1)
416
- if dropout_p > 0.0:
417
- attn_output_weights = dropout(attn_output_weights, p=dropout_p)
418
-
419
- attn_output = torch.bmm(attn_output_weights, v)
420
-
421
- attn_output = (
422
- attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
423
- )
424
- attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
425
- attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
426
-
427
- # optionally average attention weights over heads
428
- attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
429
- if average_attn_weights:
430
- attn_output_weights = attn_output_weights.mean(dim=1)
431
-
432
- if not is_batched:
433
- # squeeze the output if input was unbatched
434
- attn_output = attn_output.squeeze(1)
435
- attn_output_weights = attn_output_weights.squeeze(0)
436
- return attn_output, attn_output_weights
437
- else:
438
- # attn_mask can be either (L,S) or (N*num_heads, L, S)
439
- # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
440
- # in order to match the input for SDPA of (N, num_heads, L, S)
441
- if attn_mask is not None:
442
- if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
443
- attn_mask = attn_mask.unsqueeze(0)
444
- else:
445
- attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
446
-
447
- q = q.view(bsz, num_heads, tgt_len, head_dim)
448
- k = k.view(bsz, num_heads, src_len, head_dim)
449
- v = v.view(bsz, num_heads, src_len, head_dim)
450
-
451
- # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
452
- attn_output = scaled_dot_product_attention(
453
- q, k, v, attn_mask, dropout_p, is_causal
454
- )
455
-
456
- attn_output = (
457
- attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
458
- )
459
-
460
- attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
461
- attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
462
- if not is_batched:
463
- # squeeze the output if input was unbatched
464
- attn_output = attn_output.squeeze(1)
465
- return attn_output, None
 
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _mha_shape_check,
4
+ _canonical_mask,
5
+ _none_or_dtype,
6
+ _in_projection_packed, _in_projection,
7
+ )
8
+ from torch.nn import functional as F
9
+ import torch
10
+ # Tensor = torch.Tensor
11
+ # from typing import Callable, List, Optional, Tuple, Union
12
+
13
+
14
+ def multi_head_attention_forward_patched(
15
+ query: Tensor,
16
+ key: Tensor,
17
+ value: Tensor,
18
+ embed_dim_to_check: int,
19
+ num_heads: int,
20
+ in_proj_weight: Optional[Tensor],
21
+ in_proj_bias: Optional[Tensor],
22
+ bias_k: Optional[Tensor],
23
+ bias_v: Optional[Tensor],
24
+ add_zero_attn: bool,
25
+ dropout_p: float,
26
+ out_proj_weight: Tensor,
27
+ out_proj_bias: Optional[Tensor],
28
+ training: bool = True,
29
+ key_padding_mask: Optional[Tensor] = None,
30
+ need_weights: bool = True,
31
+ attn_mask: Optional[Tensor] = None,
32
+ use_separate_proj_weight: bool = False,
33
+ q_proj_weight: Optional[Tensor] = None,
34
+ k_proj_weight: Optional[Tensor] = None,
35
+ v_proj_weight: Optional[Tensor] = None,
36
+ static_k: Optional[Tensor] = None,
37
+ static_v: Optional[Tensor] = None,
38
+ average_attn_weights: bool = True,
39
+ is_causal: bool = False,
40
+ cache=None,
41
+ ) -> Tuple[Tensor, Optional[Tensor]]:
42
+ r"""
43
+ Args:
44
+ query, key, value: map a query and a set of key-value pairs to an output.
45
+ See "Attention Is All You Need" for more details.
46
+ embed_dim_to_check: total dimension of the model.
47
+ num_heads: parallel attention heads.
48
+ in_proj_weight, in_proj_bias: input projection weight and bias.
49
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
50
+ add_zero_attn: add a new batch of zeros to the key and
51
+ value sequences at dim=1.
52
+ dropout_p: probability of an element to be zeroed.
53
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
54
+ training: apply dropout if set to ``True``.
55
+ key_padding_mask: if provided, specified padding elements in the key will
56
+ be ignored by the attention. This is an binary mask. When the value is True,
57
+ the corresponding value on the attention layer will be filled with -inf.
58
+ need_weights: output attn_output_weights.
59
+ Default: `True`
60
+ Note: `need_weights` defaults to `True`, but should be set to `False`
61
+ for best performance when attention weights are not needed.
62
+ *Setting need_weights to `True`
63
+ leads to a significant performance degradation.*
64
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
65
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
66
+ is_causal: If specified, applies a causal mask as attention mask, and ignores
67
+ attn_mask for computing scaled dot product attention.
68
+ Default: ``False``.
69
+ .. warning::
70
+ is_causal provides a hint that the attn_mask is the
72
+ causal mask. Providing incorrect hints can result in
72
+ incorrect execution, including forward and backward
73
+ compatibility.
74
+ use_separate_proj_weight: the function accepts the proj. weights for query, key,
75
+ and value in different forms. If false, in_proj_weight will be used, which is
76
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
77
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
78
+ static_k, static_v: static key and value used for attention operators.
79
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
80
+ Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
81
+ when ``need_weights=True``. Default: True
82
+
83
+
84
+ Shape:
85
+ Inputs:
86
+ - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
87
+ the embedding dimension.
88
+ - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
89
+ the embedding dimension.
90
+ - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
91
+ the embedding dimension.
92
+ - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
93
+ If a FloatTensor is provided, it will be directly added to the value.
94
+ If a BoolTensor is provided, the positions with the
95
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
96
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
97
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
98
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
99
+ positions. If a BoolTensor is provided, positions with ``True``
100
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
101
+ is provided, it will be added to the attention weight.
102
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
103
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
104
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
105
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
106
+
107
+ Outputs:
108
+ - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
109
+ E is the embedding dimension.
110
+ - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
111
+ attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
112
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
113
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
114
+ head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
115
+ """
116
+ tens_ops = (
117
+ query,
118
+ key,
119
+ value,
120
+ in_proj_weight,
121
+ in_proj_bias,
122
+ bias_k,
123
+ bias_v,
124
+ out_proj_weight,
125
+ out_proj_bias,
126
+ )
127
+ if has_torch_function(tens_ops):
128
+ return handle_torch_function(
129
+ multi_head_attention_forward,
130
+ tens_ops,
131
+ query,
132
+ key,
133
+ value,
134
+ embed_dim_to_check,
135
+ num_heads,
136
+ in_proj_weight,
137
+ in_proj_bias,
138
+ bias_k,
139
+ bias_v,
140
+ add_zero_attn,
141
+ dropout_p,
142
+ out_proj_weight,
143
+ out_proj_bias,
144
+ training=training,
145
+ key_padding_mask=key_padding_mask,
146
+ need_weights=need_weights,
147
+ attn_mask=attn_mask,
148
+ is_causal=is_causal,
149
+ use_separate_proj_weight=use_separate_proj_weight,
150
+ q_proj_weight=q_proj_weight,
151
+ k_proj_weight=k_proj_weight,
152
+ v_proj_weight=v_proj_weight,
153
+ static_k=static_k,
154
+ static_v=static_v,
155
+ average_attn_weights=average_attn_weights,
156
+ cache=cache,
157
+ )
158
+
159
+ is_batched = _mha_shape_check(
160
+ query, key, value, key_padding_mask, attn_mask, num_heads
161
+ )
162
+
163
+ # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
164
+ # is batched, run the computation and before returning squeeze the
165
+ # batch dimension so that the output doesn't carry this temporary batch dimension.
166
+ if not is_batched:
167
+ # unsqueeze if the input is unbatched
168
+ query = query.unsqueeze(1)
169
+ key = key.unsqueeze(1)
170
+ value = value.unsqueeze(1)
171
+ if key_padding_mask is not None:
172
+ key_padding_mask = key_padding_mask.unsqueeze(0)
173
+
174
+ # set up shape vars
175
+ tgt_len, bsz, embed_dim = query.shape
176
+ src_len, _, _ = key.shape
177
+
178
+ key_padding_mask = _canonical_mask(
179
+ mask=key_padding_mask,
180
+ mask_name="key_padding_mask",
181
+ other_type=_none_or_dtype(attn_mask),
182
+ other_name="attn_mask",
183
+ target_type=query.dtype,
184
+ )
185
+
186
+ if is_causal and attn_mask is None:
187
+ raise RuntimeError(
188
+ "Need attn_mask if specifying the is_causal hint. "
189
+ "You may use the Transformer module method "
190
+ "`generate_square_subsequent_mask` to create this mask."
191
+ )
192
+
193
+ if is_causal and key_padding_mask is None and not need_weights:
194
+ # when we have a kpm or need weights, we need attn_mask
195
+ # Otherwise, we use the is_causal hint go as is_causal
196
+ # indicator to SDPA.
197
+ attn_mask = None
198
+ else:
199
+ attn_mask = _canonical_mask(
200
+ mask=attn_mask,
201
+ mask_name="attn_mask",
202
+ other_type=None,
203
+ other_name="",
204
+ target_type=query.dtype,
205
+ check_other=False,
206
+ )
207
+
208
+ if key_padding_mask is not None:
209
+ # We have the attn_mask, and use that to merge kpm into it.
210
+ # Turn off use of is_causal hint, as the merged mask is no
211
+ # longer causal.
212
+ is_causal = False
213
+
214
+ assert (
215
+ embed_dim == embed_dim_to_check
216
+ ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
217
+ if isinstance(embed_dim, torch.Tensor):
218
+ # embed_dim can be a tensor when JIT tracing
219
+ head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
220
+ else:
221
+ head_dim = embed_dim // num_heads
222
+ assert (
223
+ head_dim * num_heads == embed_dim
224
+ ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
225
+ if use_separate_proj_weight:
226
+ # allow MHA to have different embedding dimensions when separate projection weights are used
227
+ assert (
228
+ key.shape[:2] == value.shape[:2]
229
+ ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
230
+ else:
231
+ assert (
232
+ key.shape == value.shape
233
+ ), f"key shape {key.shape} does not match value shape {value.shape}"
234
+
235
+ #
236
+ # compute in-projection
237
+ #
238
+ if not use_separate_proj_weight:
239
+ assert (
240
+ in_proj_weight is not None
241
+ ), "use_separate_proj_weight is False but in_proj_weight is None"
242
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
243
+ else:
244
+ assert (
245
+ q_proj_weight is not None
246
+ ), "use_separate_proj_weight is True but q_proj_weight is None"
247
+ assert (
248
+ k_proj_weight is not None
249
+ ), "use_separate_proj_weight is True but k_proj_weight is None"
250
+ assert (
251
+ v_proj_weight is not None
252
+ ), "use_separate_proj_weight is True but v_proj_weight is None"
253
+ if in_proj_bias is None:
254
+ b_q = b_k = b_v = None
255
+ else:
256
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
257
+ q, k, v = _in_projection(
258
+ query,
259
+ key,
260
+ value,
261
+ q_proj_weight,
262
+ k_proj_weight,
263
+ v_proj_weight,
264
+ b_q,
265
+ b_k,
266
+ b_v,
267
+ )
268
+ if cache is not None:
269
+ if cache["first_infer"] == 1:
270
+ cache["k"][cache["stage"]] = k
271
+ # print(0,cache["k"].shape)
272
+ cache["v"][cache["stage"]] = v
273
+ else: ### each of the 12 layers keeps its own k/v cache
274
+ # print(1,cache["k"].shape)
275
+ cache["k"][cache["stage"]] = torch.cat(
276
+ [cache["k"][cache["stage"]], k], 0
277
+ ) ## the time axis used to be dim 1, but the projection may have transposed it, so it is now dim 0
278
+ cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
279
+ # print(2, cache["k"].shape)
280
+ src_len = cache["k"][cache["stage"]].shape[0]
281
+ k = cache["k"][cache["stage"]]
282
+ v = cache["v"][cache["stage"]]
283
+ # if attn_mask is not None:
284
+ # attn_mask=attn_mask[-1:,]
285
+ # print(attn_mask.shape,attn_mask)
286
+ cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
287
+ # print(2333,cache)
288
+ # prep attention mask
289
+
290
+ attn_mask = _canonical_mask(
291
+ mask=attn_mask,
292
+ mask_name="attn_mask",
293
+ other_type=None,
294
+ other_name="",
295
+ target_type=q.dtype,
296
+ check_other=False,
297
+ )
298
+
299
+ if attn_mask is not None:
300
+ # ensure attn_mask's dim is 3
301
+ if attn_mask.dim() == 2:
302
+ correct_2d_size = (tgt_len, src_len)
303
+ if attn_mask.shape != correct_2d_size:
304
+ raise RuntimeError(
305
+ f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
306
+ )
307
+ attn_mask = attn_mask.unsqueeze(0)
308
+ elif attn_mask.dim() == 3:
309
+ correct_3d_size = (bsz * num_heads, tgt_len, src_len)
310
+ if attn_mask.shape != correct_3d_size:
311
+ raise RuntimeError(
312
+ f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
313
+ )
314
+ else:
315
+ raise RuntimeError(
316
+ f"attn_mask's dimension {attn_mask.dim()} is not supported"
317
+ )
318
+
319
+ # add bias along batch dimension (currently second)
320
+ if bias_k is not None and bias_v is not None:
321
+ assert static_k is None, "bias cannot be added to static key."
322
+ assert static_v is None, "bias cannot be added to static value."
323
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
324
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
325
+ if attn_mask is not None:
326
+ attn_mask = pad(attn_mask, (0, 1))
327
+ if key_padding_mask is not None:
328
+ key_padding_mask = pad(key_padding_mask, (0, 1))
329
+ else:
330
+ assert bias_k is None
331
+ assert bias_v is None
332
+
333
+ #
334
+ # reshape q, k, v for multihead attention and make em batch first
335
+ #
336
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
337
+ if static_k is None:
338
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
339
+ else:
340
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
341
+ assert (
342
+ static_k.size(0) == bsz * num_heads
343
+ ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
344
+ assert (
345
+ static_k.size(2) == head_dim
346
+ ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
347
+ k = static_k
348
+ if static_v is None:
349
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
350
+ else:
351
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
352
+ assert (
353
+ static_v.size(0) == bsz * num_heads
354
+ ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
355
+ assert (
356
+ static_v.size(2) == head_dim
357
+ ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
358
+ v = static_v
359
+
360
+ # add zero attention along batch dimension (now first)
361
+ if add_zero_attn:
362
+ zero_attn_shape = (bsz * num_heads, 1, head_dim)
363
+ k = torch.cat(
364
+ [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
365
+ )
366
+ v = torch.cat(
367
+ [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
368
+ )
369
+ if attn_mask is not None:
370
+ attn_mask = pad(attn_mask, (0, 1))
371
+ if key_padding_mask is not None:
372
+ key_padding_mask = pad(key_padding_mask, (0, 1))
373
+
374
+ # update source sequence length after adjustments
375
+ src_len = k.size(1)
376
+
377
+ # merge key padding and attention masks
378
+ if key_padding_mask is not None:
379
+ assert key_padding_mask.shape == (
380
+ bsz,
381
+ src_len,
382
+ ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
383
+ key_padding_mask = (
384
+ key_padding_mask.view(bsz, 1, 1, src_len)
385
+ .expand(-1, num_heads, -1, -1)
386
+ .reshape(bsz * num_heads, 1, src_len)
387
+ )
388
+ if attn_mask is None:
389
+ attn_mask = key_padding_mask
390
+ else:
391
+ attn_mask = attn_mask + key_padding_mask
392
+
393
+ # adjust dropout probability
394
+ if not training:
395
+ dropout_p = 0.0
396
+
397
+ #
398
+ # (deep breath) calculate attention and out projection
399
+ #
400
+
401
+ if need_weights:
402
+ B, Nt, E = q.shape
403
+ q_scaled = q / math.sqrt(E)
404
+
405
+ assert not (
406
+ is_causal and attn_mask is None
407
+ ), "FIXME: is_causal not implemented for need_weights"
408
+
409
+ if attn_mask is not None:
410
+ attn_output_weights = torch.baddbmm(
411
+ attn_mask, q_scaled, k.transpose(-2, -1)
412
+ )
413
+ else:
414
+ attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
415
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
416
+ if dropout_p > 0.0:
417
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p)
418
+
419
+ attn_output = torch.bmm(attn_output_weights, v)
420
+
421
+ attn_output = (
422
+ attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
423
+ )
424
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
425
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
426
+
427
+ # optionally average attention weights over heads
428
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
429
+ if average_attn_weights:
430
+ attn_output_weights = attn_output_weights.mean(dim=1)
431
+
432
+ if not is_batched:
433
+ # squeeze the output if input was unbatched
434
+ attn_output = attn_output.squeeze(1)
435
+ attn_output_weights = attn_output_weights.squeeze(0)
436
+ return attn_output, attn_output_weights
437
+ else:
438
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
439
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
440
+ # in order to match the input for SDPA of (N, num_heads, L, S)
441
+ if attn_mask is not None:
442
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
443
+ attn_mask = attn_mask.unsqueeze(0)
444
+ else:
445
+ attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
446
+
447
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
448
+ k = k.view(bsz, num_heads, src_len, head_dim)
449
+ v = v.view(bsz, num_heads, src_len, head_dim)
450
+
451
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
452
+ attn_output = scaled_dot_product_attention(
453
+ q, k, v, attn_mask, dropout_p, is_causal
454
+ )
455
+
456
+ attn_output = (
457
+ attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
458
+ )
459
+
460
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
461
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
462
+ if not is_batched:
463
+ # squeeze the output if input was unbatched
464
+ attn_output = attn_output.squeeze(1)
465
+ return attn_output, None
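
Editor's note: the patched forward above differs from the stock multi-head attention forward mainly in the hand-rolled per-layer KV cache threaded through the extra `cache` argument. Below is a minimal standalone sketch (not part of the repository; the layer count and tensor sizes are invented) of just that bookkeeping, so the update rule is easier to follow in isolation. On the first pass every layer stores the full prefix; on later passes only the newest token's k/v are projected and appended along the time axis, which is what makes incremental decoding cheap.

import torch

embed_dim, bsz = 512, 1   # illustrative sizes only
num_layers = 2            # tiny on purpose; the real value comes from the model config

cache = {
    "all_stage": num_layers,
    "k": [None] * num_layers,
    "v": [None] * num_layers,
    "y_emb": None,
    "first_infer": 1,
    "stage": 0,
}

def update_kv_cache(cache, k, v):
    # Same rule as the patched attention: store everything on the first pass,
    # append the new step along the time axis (dim 0) on every later pass.
    stage = cache["stage"]
    if cache["first_infer"] == 1:
        cache["k"][stage] = k
        cache["v"][stage] = v
    else:
        cache["k"][stage] = torch.cat([cache["k"][stage], k], 0)
        cache["v"][stage] = torch.cat([cache["v"][stage], v], 0)
    k_full, v_full = cache["k"][stage], cache["v"][stage]
    cache["stage"] = (stage + 1) % cache["all_stage"]
    return k_full, v_full

# first decoding step: the whole prompt/text prefix flows through every layer
for _ in range(num_layers):
    k = torch.randn(10, bsz, embed_dim)
    update_kv_cache(cache, k, k.clone())
cache["first_infer"] = 0  # the model flips this after the first pass

# later steps: only the newest token is projected and appended per layer
for _ in range(num_layers):
    k = torch.randn(1, bsz, embed_dim)
    k_full, v_full = update_kv_cache(cache, k, k.clone())
print(k_full.shape)  # torch.Size([11, 1, 512])
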
GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py CHANGED
@@ -1,92 +1,92 @@
1
- from torch.nn.functional import *
2
- from torch.nn.functional import (
3
- _mha_shape_check,
4
- _canonical_mask,
5
- _none_or_dtype,
6
- _in_projection_packed,
7
- )
8
-
9
- def multi_head_attention_forward_patched(
10
- query,
11
- key,
12
- value,
13
- embed_dim_to_check: int,
14
- num_heads: int,
15
- in_proj_weight,
16
- in_proj_bias: Optional[Tensor],
17
- bias_k: Optional[Tensor],
18
- bias_v: Optional[Tensor],
19
- add_zero_attn: bool,
20
- dropout_p: float,
21
- out_proj_weight: Tensor,
22
- out_proj_bias: Optional[Tensor],
23
- training: bool = True,
24
- key_padding_mask: Optional[Tensor] = None,
25
- need_weights: bool = True,
26
- attn_mask: Optional[Tensor] = None,
27
- use_separate_proj_weight: bool = False,
28
- q_proj_weight: Optional[Tensor] = None,
29
- k_proj_weight: Optional[Tensor] = None,
30
- v_proj_weight: Optional[Tensor] = None,
31
- static_k: Optional[Tensor] = None,
32
- static_v: Optional[Tensor] = None,
33
- average_attn_weights: bool = True,
34
- is_causal: bool = False,
35
- cache=None,
36
- ) -> Tuple[Tensor, Optional[Tensor]]:
37
-
38
- # set up shape vars
39
- _, _, embed_dim = query.shape
40
- attn_mask = _canonical_mask(
41
- mask=attn_mask,
42
- mask_name="attn_mask",
43
- other_type=None,
44
- other_name="",
45
- target_type=query.dtype,
46
- check_other=False,
47
- )
48
- head_dim = embed_dim // num_heads
49
-
50
- proj_qkv = linear(query, in_proj_weight, in_proj_bias)
51
- proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
52
- q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
53
-
54
- if cache["first_infer"] == 1:
55
- cache["k"][cache["stage"]] = k
56
- cache["v"][cache["stage"]] = v
57
- else:
58
- cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
59
- cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
60
- k = cache["k"][cache["stage"]]
61
- v = cache["v"][cache["stage"]]
62
- cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
63
-
64
- attn_mask = _canonical_mask(
65
- mask=attn_mask,
66
- mask_name="attn_mask",
67
- other_type=None,
68
- other_name="",
69
- target_type=q.dtype,
70
- check_other=False,
71
- )
72
- attn_mask = attn_mask.unsqueeze(0)
73
-
74
- q = q.view(-1, num_heads, head_dim).transpose(0, 1)
75
- k = k.view(-1, num_heads, head_dim).transpose(0, 1)
76
- v = v.view(-1, num_heads, head_dim).transpose(0, 1)
77
-
78
- dropout_p = 0.0
79
- attn_mask = attn_mask.unsqueeze(0)
80
- q = q.view(num_heads, -1, head_dim).unsqueeze(0)
81
- k = k.view(num_heads, -1, head_dim).unsqueeze(0)
82
- v = v.view(num_heads, -1, head_dim).unsqueeze(0)
83
- attn_output = scaled_dot_product_attention(
84
- q, k, v, attn_mask, dropout_p, is_causal
85
- )
86
- attn_output = (
87
- attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
88
- )
89
- attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
90
- attn_output = attn_output.view(-1, 1, attn_output.size(1))
91
-
92
- return attn_output
 
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _mha_shape_check,
4
+ _canonical_mask,
5
+ _none_or_dtype,
6
+ _in_projection_packed,
7
+ )
8
+
9
+ def multi_head_attention_forward_patched(
10
+ query,
11
+ key,
12
+ value,
13
+ embed_dim_to_check: int,
14
+ num_heads: int,
15
+ in_proj_weight,
16
+ in_proj_bias: Optional[Tensor],
17
+ bias_k: Optional[Tensor],
18
+ bias_v: Optional[Tensor],
19
+ add_zero_attn: bool,
20
+ dropout_p: float,
21
+ out_proj_weight: Tensor,
22
+ out_proj_bias: Optional[Tensor],
23
+ training: bool = True,
24
+ key_padding_mask: Optional[Tensor] = None,
25
+ need_weights: bool = True,
26
+ attn_mask: Optional[Tensor] = None,
27
+ use_separate_proj_weight: bool = False,
28
+ q_proj_weight: Optional[Tensor] = None,
29
+ k_proj_weight: Optional[Tensor] = None,
30
+ v_proj_weight: Optional[Tensor] = None,
31
+ static_k: Optional[Tensor] = None,
32
+ static_v: Optional[Tensor] = None,
33
+ average_attn_weights: bool = True,
34
+ is_causal: bool = False,
35
+ cache=None,
36
+ ) -> Tuple[Tensor, Optional[Tensor]]:
37
+
38
+ # set up shape vars
39
+ _, _, embed_dim = query.shape
40
+ attn_mask = _canonical_mask(
41
+ mask=attn_mask,
42
+ mask_name="attn_mask",
43
+ other_type=None,
44
+ other_name="",
45
+ target_type=query.dtype,
46
+ check_other=False,
47
+ )
48
+ head_dim = embed_dim // num_heads
49
+
50
+ proj_qkv = linear(query, in_proj_weight, in_proj_bias)
51
+ proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
52
+ q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
53
+
54
+ if cache["first_infer"] == 1:
55
+ cache["k"][cache["stage"]] = k
56
+ cache["v"][cache["stage"]] = v
57
+ else:
58
+ cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
59
+ cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
60
+ k = cache["k"][cache["stage"]]
61
+ v = cache["v"][cache["stage"]]
62
+ cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
63
+
64
+ attn_mask = _canonical_mask(
65
+ mask=attn_mask,
66
+ mask_name="attn_mask",
67
+ other_type=None,
68
+ other_name="",
69
+ target_type=q.dtype,
70
+ check_other=False,
71
+ )
72
+ attn_mask = attn_mask.unsqueeze(0)
73
+
74
+ q = q.view(-1, num_heads, head_dim).transpose(0, 1)
75
+ k = k.view(-1, num_heads, head_dim).transpose(0, 1)
76
+ v = v.view(-1, num_heads, head_dim).transpose(0, 1)
77
+
78
+ dropout_p = 0.0
79
+ attn_mask = attn_mask.unsqueeze(0)
80
+ q = q.view(num_heads, -1, head_dim).unsqueeze(0)
81
+ k = k.view(num_heads, -1, head_dim).unsqueeze(0)
82
+ v = v.view(num_heads, -1, head_dim).unsqueeze(0)
83
+ attn_output = scaled_dot_product_attention(
84
+ q, k, v, attn_mask, dropout_p, is_causal
85
+ )
86
+ attn_output = (
87
+ attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
88
+ )
89
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
90
+ attn_output = attn_output.view(-1, 1, attn_output.size(1))
91
+
92
+ return attn_output
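
Editor's note: the ONNX-oriented variant above keeps the exported graph simple: the packed QKV projection is split with an unflatten/transpose reshape instead of chunk(), and on incremental steps the cache overwrites its last cached position with the freshly projected k/v (torch.cat([cache[...][:-1], k], 0)), presumably so cache shapes stay fixed for export. The following standalone check (not part of the repository; sizes are arbitrary) confirms that the reshape yields the same q, k, v as a plain chunk along the last dimension.

import torch
import torch.nn.functional as F

L, N, E = 7, 1, 16                      # target length, batch, embed dim (made up)
query = torch.randn(L, N, E)
in_proj_weight = torch.randn(3 * E, E)  # packed [Wq; Wk; Wv]
in_proj_bias = torch.randn(3 * E)

proj = F.linear(query, in_proj_weight, in_proj_bias)          # (L, N, 3E)

# reshape used in the patched ONNX forward
packed = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2)
q1, k1, v1 = packed[0], packed[1], packed[2]                  # each (L, N, E)

# straightforward equivalent
q2, k2, v2 = proj.chunk(3, dim=-1)

assert torch.allclose(q1, q2) and torch.allclose(k1, k2) and torch.allclose(v1, v2)
print(q1.shape)  # torch.Size([7, 1, 16])
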
GPT_SoVITS/TTS_infer_pack/TTS.py ADDED
@@ -0,0 +1,932 @@
1
+ from copy import deepcopy
2
+ import math
3
+ import os, sys
4
+ import random
5
+ import traceback
6
+
7
+ from tqdm import tqdm
8
+ now_dir = os.getcwd()
9
+ sys.path.append(now_dir)
10
+ import ffmpeg
11
+ import os
12
+ from typing import Generator, List, Union
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import yaml
17
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
18
+
19
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
20
+ from feature_extractor.cnhubert import CNHubert
21
+ from module.models import SynthesizerTrn
22
+ import librosa
23
+ from time import time as ttime
24
+ from tools.i18n.i18n import I18nAuto
25
+ from my_utils import load_audio
26
+ from module.mel_processing import spectrogram_torch
27
+ from TTS_infer_pack.text_segmentation_method import splits
28
+ from TTS_infer_pack.TextPreprocessor import TextPreprocessor
29
+ i18n = I18nAuto()
30
+
31
+ # configs/tts_infer.yaml
32
+ """
33
+ default:
34
+ device: cpu
35
+ is_half: false
36
+ bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
37
+ cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
38
+ t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
39
+ vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
40
+
41
+ custom:
42
+ device: cuda
43
+ is_half: true
44
+ bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
45
+ cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
46
+ t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
47
+ vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
48
+
49
+
50
+ """
51
+
52
+ def set_seed(seed:int):
53
+ seed = int(seed)
54
+ seed = seed if seed != -1 else random.randrange(1 << 32)
55
+ print(f"Set seed to {seed}")
56
+ os.environ['PYTHONHASHSEED'] = str(seed)
57
+ random.seed(seed)
58
+ np.random.seed(seed)
59
+ torch.manual_seed(seed)
60
+ try:
61
+ if torch.cuda.is_available():
62
+ torch.cuda.manual_seed(seed)
63
+ torch.cuda.manual_seed_all(seed)
64
+ # torch.backends.cudnn.deterministic = True
65
+ # torch.backends.cudnn.benchmark = False
66
+ # torch.backends.cudnn.enabled = True
67
+ # enabling TF32 would affect precision
68
+ torch.backends.cuda.matmul.allow_tf32 = False
69
+ torch.backends.cudnn.allow_tf32 = False
70
+ except:
71
+ pass
72
+ return seed
73
+
74
+ class TTS_Config:
75
+ default_configs={
76
+ "device": "cpu",
77
+ "is_half": False,
78
+ "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
79
+ "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
80
+ "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
81
+ "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
82
+ }
83
+ configs:dict = None
84
+ def __init__(self, configs: Union[dict, str]=None):
85
+
86
+ # set the default config file path
87
+ configs_base_path:str = "GPT_SoVITS/configs/"
88
+ os.makedirs(configs_base_path, exist_ok=True)
89
+ self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
90
+
91
+ if configs in ["", None]:
92
+ if not os.path.exists(self.configs_path):
93
+ self.save_configs()
94
+ print(f"Create default config file at {self.configs_path}")
95
+ configs:dict = {"default": deepcopy(self.default_configs)}
96
+
97
+ if isinstance(configs, str):
98
+ self.configs_path = configs
99
+ configs:dict = self._load_configs(self.configs_path)
100
+
101
+ assert isinstance(configs, dict)
102
+ default_configs:dict = configs.get("default", None)
103
+ if default_configs is not None:
104
+ self.default_configs = default_configs
105
+
106
+ self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
107
+
108
+
109
+ self.device = self.configs.get("device", torch.device("cpu"))
110
+ self.is_half = self.configs.get("is_half", False)
111
+ self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
112
+ self.vits_weights_path = self.configs.get("vits_weights_path", None)
113
+ self.bert_base_path = self.configs.get("bert_base_path", None)
114
+ self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
115
+
116
+
117
+ if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
118
+ self.t2s_weights_path = self.default_configs['t2s_weights_path']
119
+ print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
120
+ if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
121
+ self.vits_weights_path = self.default_configs['vits_weights_path']
122
+ print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
123
+ if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
124
+ self.bert_base_path = self.default_configs['bert_base_path']
125
+ print(f"fall back to default bert_base_path: {self.bert_base_path}")
126
+ if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
127
+ self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
128
+ print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
129
+ self.update_configs()
130
+
131
+
132
+ self.max_sec = None
133
+ self.hz:int = 50
134
+ self.semantic_frame_rate:str = "25hz"
135
+ self.segment_size:int = 20480
136
+ self.filter_length:int = 2048
137
+ self.sampling_rate:int = 32000
138
+ self.hop_length:int = 640
139
+ self.win_length:int = 2048
140
+ self.n_speakers:int = 300
141
+
142
+ self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
143
+
144
+
145
+ def _load_configs(self, configs_path: str)->dict:
146
+ with open(configs_path, 'r') as f:
147
+ configs = yaml.load(f, Loader=yaml.FullLoader)
148
+
149
+ return configs
150
+
151
+ def save_configs(self, configs_path:str=None)->None:
152
+ configs={
153
+ "default":self.default_configs,
154
+ }
155
+ if self.configs is not None:
156
+ configs["custom"] = self.update_configs()
157
+
158
+ if configs_path is None:
159
+ configs_path = self.configs_path
160
+ with open(configs_path, 'w') as f:
161
+ yaml.dump(configs, f)
162
+
163
+ def update_configs(self):
164
+ self.config = {
165
+ "device" : str(self.device),
166
+ "is_half" : self.is_half,
167
+ "t2s_weights_path" : self.t2s_weights_path,
168
+ "vits_weights_path" : self.vits_weights_path,
169
+ "bert_base_path" : self.bert_base_path,
170
+ "cnhuhbert_base_path": self.cnhuhbert_base_path,
171
+ }
172
+ return self.config
173
+
174
+ def __str__(self):
175
+ self.configs = self.update_configs()
176
+ string = "TTS Config".center(100, '-') + '\n'
177
+ for k, v in self.configs.items():
178
+ string += f"{str(k).ljust(20)}: {str(v)}\n"
179
+ string += "-" * 100 + '\n'
180
+ return string
181
+
182
+ def __repr__(self):
183
+ return self.__str__()
184
+
185
+
186
+ class TTS:
187
+ def __init__(self, configs: Union[dict, str, TTS_Config]):
188
+ if isinstance(configs, TTS_Config):
189
+ self.configs = configs
190
+ else:
191
+ self.configs:TTS_Config = TTS_Config(configs)
192
+
193
+ self.t2s_model:Text2SemanticLightningModule = None
194
+ self.vits_model:SynthesizerTrn = None
195
+ self.bert_tokenizer:AutoTokenizer = None
196
+ self.bert_model:AutoModelForMaskedLM = None
197
+ self.cnhuhbert_model:CNHubert = None
198
+
199
+ self._init_models()
200
+
201
+ self.text_preprocessor:TextPreprocessor = \
202
+ TextPreprocessor(self.bert_model,
203
+ self.bert_tokenizer,
204
+ self.configs.device)
205
+
206
+
207
+ self.prompt_cache:dict = {
208
+ "ref_audio_path" : None,
209
+ "prompt_semantic": None,
210
+ "refer_spec" : None,
211
+ "prompt_text" : None,
212
+ "prompt_lang" : None,
213
+ "phones" : None,
214
+ "bert_features" : None,
215
+ "norm_text" : None,
216
+ }
217
+
218
+
219
+ self.stop_flag:bool = False
220
+ self.precision:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
221
+
222
+ def _init_models(self,):
223
+ self.init_t2s_weights(self.configs.t2s_weights_path)
224
+ self.init_vits_weights(self.configs.vits_weights_path)
225
+ self.init_bert_weights(self.configs.bert_base_path)
226
+ self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
227
+ # self.enable_half_precision(self.configs.is_half)
228
+
229
+
230
+
231
+ def init_cnhuhbert_weights(self, base_path: str):
232
+ print(f"Loading CNHuBERT weights from {base_path}")
233
+ self.cnhuhbert_model = CNHubert(base_path)
234
+ self.cnhuhbert_model=self.cnhuhbert_model.eval()
235
+ self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
236
+ if self.configs.is_half and str(self.configs.device)!="cpu":
237
+ self.cnhuhbert_model = self.cnhuhbert_model.half()
238
+
239
+
240
+
241
+ def init_bert_weights(self, base_path: str):
242
+ #print(f"Loading BERT weights from {base_path}")
243
+ self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
244
+ self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
245
+ self.bert_model=self.bert_model.eval()
246
+ self.bert_model = self.bert_model.to(self.configs.device)
247
+ if self.configs.is_half and str(self.configs.device)!="cpu":
248
+ self.bert_model = self.bert_model.half()
249
+
250
+ def init_vits_weights(self, weights_path: str):
251
+ #print(f"Loading VITS weights from {weights_path}")
252
+ self.configs.vits_weights_path = weights_path
253
+ self.configs.save_configs()
254
+ dict_s2 = torch.load(weights_path, map_location=self.configs.device)
255
+ hps = dict_s2["config"]
256
+ self.configs.filter_length = hps["data"]["filter_length"]
257
+ self.configs.segment_size = hps["train"]["segment_size"]
258
+ self.configs.sampling_rate = hps["data"]["sampling_rate"]
259
+ self.configs.hop_length = hps["data"]["hop_length"]
260
+ self.configs.win_length = hps["data"]["win_length"]
261
+ self.configs.n_speakers = hps["data"]["n_speakers"]
262
+ self.configs.semantic_frame_rate = "25hz"
263
+ kwargs = hps["model"]
264
+ vits_model = SynthesizerTrn(
265
+ self.configs.filter_length // 2 + 1,
266
+ self.configs.segment_size // self.configs.hop_length,
267
+ n_speakers=self.configs.n_speakers,
268
+ **kwargs
269
+ )
270
+ # if ("pretrained" not in weights_path):
271
+ if hasattr(vits_model, "enc_q"):
272
+ del vits_model.enc_q
273
+
274
+ vits_model = vits_model.to(self.configs.device)
275
+ vits_model = vits_model.eval()
276
+ vits_model.load_state_dict(dict_s2["weight"], strict=False)
277
+ self.vits_model = vits_model
278
+ if self.configs.is_half and str(self.configs.device)!="cpu":
279
+ self.vits_model = self.vits_model.half()
280
+
281
+
282
+ def init_t2s_weights(self, weights_path: str):
283
+ print(f"Loading Text2Semantic weights from {weights_path}")
284
+ self.configs.t2s_weights_path = weights_path
285
+ self.configs.save_configs()
286
+ self.configs.hz = 50
287
+ dict_s1 = torch.load(weights_path, map_location=self.configs.device)
288
+ config = dict_s1["config"]
289
+ self.configs.max_sec = config["data"]["max_sec"]
290
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
291
+ t2s_model.load_state_dict(dict_s1["weight"])
292
+ t2s_model = t2s_model.to(self.configs.device)
293
+ t2s_model = t2s_model.eval()
294
+ self.t2s_model = t2s_model
295
+ if self.configs.is_half and str(self.configs.device)!="cpu":
296
+ self.t2s_model = self.t2s_model.half()
297
+
298
+ def enable_half_precision(self, enable: bool = True):
299
+ '''
300
+ To enable half precision for the TTS model.
301
+ Args:
302
+ enable: bool, whether to enable half precision.
303
+
304
+ '''
305
+ if str(self.configs.device) == "cpu" and enable:
306
+ print("Half precision is not supported on CPU.")
307
+ return
308
+
309
+ self.configs.is_half = enable
310
+ self.precision = torch.float16 if enable else torch.float32
311
+ self.configs.save_configs()
312
+ if enable:
313
+ if self.t2s_model is not None:
314
+ self.t2s_model =self.t2s_model.half()
315
+ if self.vits_model is not None:
316
+ self.vits_model = self.vits_model.half()
317
+ if self.bert_model is not None:
318
+ self.bert_model =self.bert_model.half()
319
+ if self.cnhuhbert_model is not None:
320
+ self.cnhuhbert_model = self.cnhuhbert_model.half()
321
+ else:
322
+ if self.t2s_model is not None:
323
+ self.t2s_model = self.t2s_model.float()
324
+ if self.vits_model is not None:
325
+ self.vits_model = self.vits_model.float()
326
+ if self.bert_model is not None:
327
+ self.bert_model = self.bert_model.float()
328
+ if self.cnhuhbert_model is not None:
329
+ self.cnhuhbert_model = self.cnhuhbert_model.float()
330
+
331
+ def set_device(self, device: torch.device):
332
+ '''
333
+ To set the device for all models.
334
+ Args:
335
+ device: torch.device, the device to use for all models.
336
+ '''
337
+ self.configs.device = device
338
+ self.configs.save_configs()
339
+ if self.t2s_model is not None:
340
+ self.t2s_model = self.t2s_model.to(device)
341
+ if self.vits_model is not None:
342
+ self.vits_model = self.vits_model.to(device)
343
+ if self.bert_model is not None:
344
+ self.bert_model = self.bert_model.to(device)
345
+ if self.cnhuhbert_model is not None:
346
+ self.cnhuhbert_model = self.cnhuhbert_model.to(device)
347
+
348
+ def set_ref_audio(self, ref_audio_path:str):
349
+ '''
350
+ To set the reference audio for the TTS model,
351
+ including the prompt_semantic and refer_spec.
352
+ Args:
353
+ ref_audio_path: str, the path of the reference audio.
354
+ '''
355
+ self._set_prompt_semantic(ref_audio_path)
356
+ self._set_ref_spec(ref_audio_path)
357
+
358
+ def _set_ref_spec(self, ref_audio_path):
359
+ audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
360
+ audio = torch.FloatTensor(audio)
361
+ audio_norm = audio
362
+ audio_norm = audio_norm.unsqueeze(0)
363
+ spec = spectrogram_torch(
364
+ audio_norm,
365
+ self.configs.filter_length,
366
+ self.configs.sampling_rate,
367
+ self.configs.hop_length,
368
+ self.configs.win_length,
369
+ center=False,
370
+ )
371
+ spec = spec.to(self.configs.device)
372
+ if self.configs.is_half:
373
+ spec = spec.half()
374
+ # self.refer_spec = spec
375
+ self.prompt_cache["refer_spec"] = spec
376
+
377
+
378
+ def _set_prompt_semantic(self, ref_wav_path:str):
379
+ zero_wav = np.zeros(
380
+ int(self.configs.sampling_rate * 0.3),
381
+ dtype=np.float16 if self.configs.is_half else np.float32,
382
+ )
383
+ with torch.no_grad():
384
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
385
+ if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
386
+ raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
387
+ wav16k = torch.from_numpy(wav16k)
388
+ zero_wav_torch = torch.from_numpy(zero_wav)
389
+ wav16k = wav16k.to(self.configs.device)
390
+ zero_wav_torch = zero_wav_torch.to(self.configs.device)
391
+ if self.configs.is_half:
392
+ wav16k = wav16k.half()
393
+ zero_wav_torch = zero_wav_torch.half()
394
+
395
+ wav16k = torch.cat([wav16k, zero_wav_torch])
396
+ hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))[
397
+ "last_hidden_state"
398
+ ].transpose(
399
+ 1, 2
400
+ ) # .float()
401
+ codes = self.vits_model.extract_latent(hubert_feature)
402
+
403
+ prompt_semantic = codes[0, 0].to(self.configs.device)
404
+ self.prompt_cache["prompt_semantic"] = prompt_semantic
405
+
406
+ def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
407
+ seq = sequences[0]
408
+ ndim = seq.dim()
409
+ if axis < 0:
410
+ axis += ndim
411
+ dtype:torch.dtype = seq.dtype
412
+ pad_value = torch.tensor(pad_value, dtype=dtype)
413
+ seq_lengths = [seq.shape[axis] for seq in sequences]
414
+ if max_length is None:
415
+ max_length = max(seq_lengths)
416
+ else:
417
+ max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
418
+
419
+ padded_sequences = []
420
+ for seq, length in zip(sequences, seq_lengths):
421
+ padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
422
+ padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value)
423
+ padded_sequences.append(padded_seq)
424
+ batch = torch.stack(padded_sequences)
425
+ return batch
426
+
427
+ def to_batch(self, data:list,
428
+ prompt_data:dict=None,
429
+ batch_size:int=5,
430
+ threshold:float=0.75,
431
+ split_bucket:bool=True,
432
+ device:torch.device=torch.device("cpu"),
433
+ precision:torch.dtype=torch.float32,
434
+ ):
435
+ _data:list = []
436
+ index_and_len_list = []
437
+ for idx, item in enumerate(data):
438
+ norm_text_len = len(item["norm_text"])
439
+ index_and_len_list.append([idx, norm_text_len])
440
+
441
+ batch_index_list = []
442
+ if split_bucket:
443
+ index_and_len_list.sort(key=lambda x: x[1])
444
+ index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
445
+
446
+ batch_index_list_len = 0
447
+ pos = 0
448
+ while pos <index_and_len_list.shape[0]:
449
+ # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
450
+ pos_end = min(pos+batch_size,index_and_len_list.shape[0])
451
+ while pos < pos_end:
452
+ batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
453
+ score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
454
+ if (score>=threshold) or (pos_end-pos==1):
455
+ batch_index=index_and_len_list[pos:pos_end, 0].tolist()
456
+ batch_index_list_len += len(batch_index)
457
+ batch_index_list.append(batch_index)
458
+ pos = pos_end
459
+ break
460
+ pos_end=pos_end-1
461
+
462
+ assert batch_index_list_len == len(data)
463
+
464
+ else:
465
+ for i in range(len(data)):
466
+ if i%batch_size == 0:
467
+ batch_index_list.append([])
468
+ batch_index_list[-1].append(i)
469
+
470
+
471
+ for batch_idx, index_list in enumerate(batch_index_list):
472
+ item_list = [data[idx] for idx in index_list]
473
+ phones_list = []
474
+ phones_len_list = []
475
+ # bert_features_list = []
476
+ all_phones_list = []
477
+ all_phones_len_list = []
478
+ all_bert_features_list = []
479
+ norm_text_batch = []
480
+ bert_max_len = 0
481
+ phones_max_len = 0
482
+ for item in item_list:
483
+ if prompt_data is not None:
484
+ all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
485
+ .to(dtype=precision, device=device)
486
+ all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
487
+ phones = torch.LongTensor(item["phones"]).to(device)
488
+ # norm_text = prompt_data["norm_text"]+item["norm_text"]
489
+ else:
490
+ all_bert_features = item["bert_features"]\
491
+ .to(dtype=precision, device=device)
492
+ phones = torch.LongTensor(item["phones"]).to(device)
493
+ all_phones = phones
494
+ # norm_text = item["norm_text"]
495
+
496
+ bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
497
+ phones_max_len = max(phones_max_len, phones.shape[-1])
498
+
499
+ phones_list.append(phones)
500
+ phones_len_list.append(phones.shape[-1])
501
+ all_phones_list.append(all_phones)
502
+ all_phones_len_list.append(all_phones.shape[-1])
503
+ all_bert_features_list.append(all_bert_features)
504
+ norm_text_batch.append(item["norm_text"])
505
+
506
+ phones_batch = phones_list
507
+ all_phones_batch = all_phones_list
508
+ all_bert_features_batch = all_bert_features_list
509
+
510
+
511
+ max_len = max(bert_max_len, phones_max_len)
512
+ # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
513
+ #### Pad phones and bert_features directly. (The padding strategy affects what the T2S model generates, but it does not directly affect the repetition probability; the masking strategy is the main factor there.)
514
+ # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
515
+ # all_bert_features_batch = all_bert_features_list
516
+ # all_bert_features_batch = torch.zeros((len(all_bert_features_list), 1024, max_len), dtype=precision, device=device)
517
+ # for idx, item in enumerate(all_bert_features_list):
518
+ # all_bert_features_batch[idx, :, : item.shape[-1]] = item
519
+
520
+ # #### Alternative: run the phone embedding and the bert_features projection first, then pad both to the same length. (The padding strategy affects what the T2S model generates, but it does not directly affect the repetition probability; the masking strategy is the main factor there.)
521
+ # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
522
+ # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
523
+ # all_phones_batch = torch.stack(all_phones_list, dim=0)
524
+
525
+ # all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
526
+ # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
527
+ # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
528
+
529
+ batch = {
530
+ "phones": phones_batch,
531
+ "phones_len": torch.LongTensor(phones_len_list).to(device),
532
+ "all_phones": all_phones_batch,
533
+ "all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
534
+ "all_bert_features": all_bert_features_batch,
535
+ "norm_text": norm_text_batch,
536
+ "max_len": max_len,
537
+ }
538
+ _data.append(batch)
539
+
540
+ return _data, batch_index_list
541
+
542
+ def recovery_order(self, data:list, batch_index_list:list)->list:
543
+ '''
544
+ Recover the order of the audio according to the batch_index_list.
545
+
546
+ Args:
547
+ data (List[list(np.ndarray)]): the out-of-order audio.
548
+ batch_index_list (List[list[int]]): the batch index list.
549
+
550
+ Returns:
551
+ list (List[np.ndarray]): the data in the original order.
552
+ '''
553
+ length = len(sum(batch_index_list, []))
554
+ _data = [None]*length
555
+ for i, index_list in enumerate(batch_index_list):
556
+ for j, index in enumerate(index_list):
557
+ _data[index] = data[i][j]
558
+ return _data
559
+
560
+ def stop(self,):
561
+ '''
562
+ Stop the inference process.
563
+ '''
564
+ self.stop_flag = True
565
+
566
+ @torch.no_grad()
567
+ def run(self, inputs:dict):
568
+ """
569
+ Text to speech inference.
570
+
571
+ Args:
572
+ inputs (dict):
573
+ {
574
+ "text": "", # str.(required) text to be synthesized
575
+ "text_lang: "", # str.(required) language of the text to be synthesized
576
+ "ref_audio_path": "", # str.(required) reference audio path
577
+ "prompt_text": "", # str.(optional) prompt text for the reference audio
578
+ "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
579
+ "top_k": 5, # int. top k sampling
580
+ "top_p": 1, # float. top p sampling
581
+ "temperature": 1, # float. temperature for sampling
582
+ "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
583
+ "batch_size": 1, # int. batch size for inference
584
+ "batch_threshold": 0.75, # float. threshold for batch splitting.
585
+ "split_bucket: True, # bool. whether to split the batch into multiple buckets.
586
+ "return_fragment": False, # bool. step by step return the audio fragment.
587
+ "speed_factor":1.0, # float. control the speed of the synthesized audio.
588
+ "fragment_interval":0.3, # float. to control the interval of the audio fragment.
589
+ "seed": -1, # int. random seed for reproducibility.
590
+ "parallel_infer": True, # bool. whether to use parallel inference.
591
+ "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
592
+ }
593
+ returns:
594
+ tuple[int, np.ndarray]: sampling rate and audio data.
595
+ """
596
+ ########## variables initialization ###########
597
+ self.stop_flag:bool = False
598
+ text:str = inputs.get("text", "")
599
+ text_lang:str = inputs.get("text_lang", "")
600
+ ref_audio_path:str = inputs.get("ref_audio_path", "")
601
+ prompt_text:str = inputs.get("prompt_text", "")
602
+ prompt_lang:str = inputs.get("prompt_lang", "")
603
+ top_k:int = inputs.get("top_k", 5)
604
+ top_p:float = inputs.get("top_p", 1)
605
+ temperature:float = inputs.get("temperature", 1)
606
+ text_split_method:str = inputs.get("text_split_method", "cut0")
607
+ batch_size = inputs.get("batch_size", 1)
608
+ batch_threshold = inputs.get("batch_threshold", 0.75)
609
+ speed_factor = inputs.get("speed_factor", 1.0)
610
+ split_bucket = inputs.get("split_bucket", True)
611
+ return_fragment = inputs.get("return_fragment", False)
612
+ fragment_interval = inputs.get("fragment_interval", 0.3)
613
+ seed = inputs.get("seed", -1)
614
+ seed = -1 if seed in ["", None] else seed
615
+ actual_seed = set_seed(seed)
616
+ parallel_infer = inputs.get("parallel_infer", True)
617
+ repetition_penalty = inputs.get("repetition_penalty", 1.35)
618
+
619
+ if parallel_infer:
620
+ #print(i18n("并行推理模式已开启"))
621
+ self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn
622
+ else:
623
+ #print(i18n("并行推理模式已关闭"))
624
+ self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_0307
625
+
626
+ if return_fragment:
627
+ #print(i18n("分段返回模式已开启"))
628
+ if split_bucket:
629
+ split_bucket = False
630
+ #print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
631
+
632
+ #if split_bucket:
633
+ # print(i18n("分桶处理模式已开启"))
634
+
635
+ if fragment_interval<0.01:
636
+ fragment_interval = 0.01
637
+ #print(i18n("分段间隔过小,已自动设置为0.01"))
638
+
639
+ no_prompt_text = False
640
+ if prompt_text in [None, ""]:
641
+ no_prompt_text = True
642
+
643
+ assert text_lang in self.configs.languages
644
+ if not no_prompt_text:
645
+ assert prompt_lang in self.configs.languages
646
+
647
+ if ref_audio_path in [None, ""] and \
648
+ ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)):
649
+ raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
650
+
651
+ ###### setting reference audio and prompt text preprocessing ########
652
+ t0 = ttime()
653
+ if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
654
+ self.set_ref_audio(ref_audio_path)
655
+
656
+ if not no_prompt_text:
657
+ prompt_text = prompt_text.strip("\n")
658
+ if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
659
+ #print(i18n("实际输入的参考文本:"), prompt_text)
660
+ if self.prompt_cache["prompt_text"] != prompt_text:
661
+ self.prompt_cache["prompt_text"] = prompt_text
662
+ self.prompt_cache["prompt_lang"] = prompt_lang
663
+ phones, bert_features, norm_text = \
664
+ self.text_preprocessor.segment_and_extract_feature_for_text(
665
+ prompt_text,
666
+ prompt_lang)
667
+ self.prompt_cache["phones"] = phones
668
+ self.prompt_cache["bert_features"] = bert_features
669
+ self.prompt_cache["norm_text"] = norm_text
670
+
671
+ ###### text preprocessing ########
672
+ t1 = ttime()
673
+ data:list = None
674
+ if not return_fragment:
675
+ data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
676
+ if len(data) == 0:
677
+ yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
678
+ dtype=np.int16)
679
+ return
680
+
681
+ batch_index_list:list = None
682
+ data, batch_index_list = self.to_batch(data,
683
+ prompt_data=self.prompt_cache if not no_prompt_text else None,
684
+ batch_size=batch_size,
685
+ threshold=batch_threshold,
686
+ split_bucket=split_bucket,
687
+ device=self.configs.device,
688
+ precision=self.precision
689
+ )
690
+ else:
691
+ #print(i18n("############ 切分文本 ############"))
692
+ texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
693
+ data = []
694
+ for i in range(len(texts)):
695
+ if i%batch_size == 0:
696
+ data.append([])
697
+ data[-1].append(texts[i])
698
+
699
+ def make_batch(batch_texts):
700
+ batch_data = []
701
+ #print(i18n("############ 提取文本Bert特征 ############"))
702
+ for text in tqdm(batch_texts):
703
+ phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang)
704
+ if phones is None:
705
+ continue
706
+ res={
707
+ "phones": phones,
708
+ "bert_features": bert_features,
709
+ "norm_text": norm_text,
710
+ }
711
+ batch_data.append(res)
712
+ if len(batch_data) == 0:
713
+ return None
714
+ batch, _ = self.to_batch(batch_data,
715
+ prompt_data=self.prompt_cache if not no_prompt_text else None,
716
+ batch_size=batch_size,
717
+ threshold=batch_threshold,
718
+ split_bucket=False,
719
+ device=self.configs.device,
720
+ precision=self.precision
721
+ )
722
+ return batch[0]
723
+
724
+
725
+ t2 = ttime()
726
+ try:
727
+ #print("############ 推理 ############")
728
+ ###### inference ######
729
+ t_34 = 0.0
730
+ t_45 = 0.0
731
+ audio = []
732
+ for item in data:
733
+ t3 = ttime()
734
+ if return_fragment:
735
+ item = make_batch(item)
736
+ if item is None:
737
+ continue
738
+
739
+ batch_phones:List[torch.LongTensor] = item["phones"]
740
+ # batch_phones:torch.LongTensor = item["phones"]
741
+ batch_phones_len:torch.LongTensor = item["phones_len"]
742
+ all_phoneme_ids:torch.LongTensor = item["all_phones"]
743
+ all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
744
+ all_bert_features:torch.LongTensor = item["all_bert_features"]
745
+ norm_text:str = item["norm_text"]
746
+ max_len = item["max_len"]
747
+
748
+ #print(i18n("前端处理后的文本(每句):"), norm_text)
749
+ if no_prompt_text :
750
+ prompt = None
751
+ else:
752
+ prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
753
+
754
+
755
+ pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
756
+ all_phoneme_ids,
757
+ all_phoneme_lens,
758
+ prompt,
759
+ all_bert_features,
760
+ # prompt_phone_len=ph_offset,
761
+ top_k=top_k,
762
+ top_p=top_p,
763
+ temperature=temperature,
764
+ early_stop_num=self.configs.hz * self.configs.max_sec,
765
+ max_len=max_len,
766
+ repetition_penalty=repetition_penalty,
767
+ )
768
+ t4 = ttime()
769
+ t_34 += t4 - t3
770
+
771
+ refer_audio_spec:torch.Tensor = self.prompt_cache["refer_spec"]\
772
+ .to(dtype=self.precision, device=self.configs.device)
773
+
774
+ batch_audio_fragment = []
775
+
776
+ # Remember to add torch.no_grad() here, otherwise inference is a lot slower
777
+ # with torch.no_grad():
778
+
779
+ # ## VITS parallel inference, method 1
780
+ # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
781
+ # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
782
+ # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
783
+ # max_len = 0
784
+ # for i in range(0, len(batch_phones)):
785
+ # max_len = max(max_len, batch_phones[i].shape[-1])
786
+ # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
787
+ # batch_phones = batch_phones.to(self.configs.device)
788
+ # batch_audio_fragment = (self.vits_model.batched_decode(
789
+ # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
790
+ # ))
791
+
792
+ # ## VITS parallel inference, method 2
793
+ pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
794
+ upsample_rate = math.prod(self.vits_model.upsample_rates)
795
+ audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
796
+ audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
797
+ all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
798
+ _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
799
+ _batch_audio_fragment = (self.vits_model.decode(
800
+ all_pred_semantic, _batch_phones, refer_audio_spec
801
+ ).detach()[0, 0, :])
802
+ audio_frag_end_idx.insert(0, 0)
803
+ batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
804
+
805
+ # ## VITS sequential inference
806
+ # for i, idx in enumerate(idx_list):
807
+ # phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
808
+ # _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
809
+ # audio_fragment =(self.vits_model.decode(
810
+ # _pred_semantic, phones, refer_audio_spec
811
+ # ).detach()[0, 0, :])
812
+ # batch_audio_fragment.append(
813
+ # audio_fragment
814
+ # ) ###试试重建不带上prompt部分
815
+
816
+ t5 = ttime()
817
+ t_45 += t5 - t4
818
+ if return_fragment:
819
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
820
+ yield self.audio_postprocess([batch_audio_fragment],
821
+ self.configs.sampling_rate,
822
+ None,
823
+ speed_factor,
824
+ False,
825
+ fragment_interval
826
+ )
827
+ else:
828
+ audio.append(batch_audio_fragment)
829
+
830
+ if self.stop_flag:
831
+ yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
832
+ dtype=np.int16)
833
+ return
834
+
835
+ if not return_fragment:
836
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
837
+ yield self.audio_postprocess(audio,
838
+ self.configs.sampling_rate,
839
+ batch_index_list,
840
+ speed_factor,
841
+ split_bucket,
842
+ fragment_interval
843
+ )
844
+
845
+ except Exception as e:
846
+ traceback.print_exc()
847
+ # An empty audio clip must be yielded here, otherwise GPU memory is not released.
848
+ yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
849
+ dtype=np.int16)
850
+ # Reset the models, otherwise GPU memory is not fully released.
851
+ del self.t2s_model
852
+ del self.vits_model
853
+ self.t2s_model = None
854
+ self.vits_model = None
855
+ self.init_t2s_weights(self.configs.t2s_weights_path)
856
+ self.init_vits_weights(self.configs.vits_weights_path)
857
+ raise e
858
+ finally:
859
+ self.empty_cache()
860
+
861
+ def empty_cache(self):
862
+ try:
863
+ if "cuda" in str(self.configs.device):
864
+ torch.cuda.empty_cache()
865
+ elif str(self.configs.device) == "mps":
866
+ torch.mps.empty_cache()
867
+ except:
868
+ pass
869
+
870
+ def audio_postprocess(self,
871
+ audio:List[torch.Tensor],
872
+ sr:int,
873
+ batch_index_list:list=None,
874
+ speed_factor:float=1.0,
875
+ split_bucket:bool=True,
876
+ fragment_interval:float=0.3
877
+ )->tuple[int, np.ndarray]:
878
+ zero_wav = torch.zeros(
879
+ int(self.configs.sampling_rate * fragment_interval),
880
+ dtype=self.precision,
881
+ device=self.configs.device
882
+ )
883
+
884
+ for i, batch in enumerate(audio):
885
+ for j, audio_fragment in enumerate(batch):
886
+ max_audio=torch.abs(audio_fragment).max()  # simple guard against 16-bit clipping
887
+ if max_audio>1: audio_fragment/=max_audio
888
+ audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
889
+ audio[i][j] = audio_fragment.cpu().numpy()
890
+
891
+
892
+ if split_bucket:
893
+ audio = self.recovery_order(audio, batch_index_list)
894
+ else:
895
+ # audio = [item for batch in audio for item in batch]
896
+ audio = sum(audio, [])
897
+
898
+
899
+ audio = np.concatenate(audio, 0)
900
+ audio = (audio * 32768).astype(np.int16)
901
+
902
+ try:
903
+ if speed_factor != 1.0:
904
+ audio = speed_change(audio, speed=speed_factor, sr=int(sr))
905
+ except Exception as e:
906
+ print(f"Failed to change speed of audio: \n{e}")
907
+
908
+ return sr, audio
909
+
910
+
911
+
912
+
913
+ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
914
+ # Convert the NumPy array to a raw PCM byte stream
915
+ raw_audio = input_audio.astype(np.int16).tobytes()
916
+
917
+ # Set up the ffmpeg input stream
918
+ input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
919
+
920
+ # Apply the tempo change
921
+ output_stream = input_stream.filter('atempo', speed)
922
+
923
+ # Write the output stream to a pipe
924
+ out, _ = (
925
+ output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
926
+ .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
927
+ )
928
+
929
+ # Decode the piped output back into a NumPy array
930
+ processed_audio = np.frombuffer(out, np.int16)
931
+
932
+ return processed_audio
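For orientation, a minimal usage sketch of the TTS pipeline defined above, assuming the tts_infer.yaml from this commit and placeholder reference-audio and output paths (soundfile is used here only to write the result):

import soundfile as sf
from TTS_infer_pack.TTS import TTS, TTS_Config

config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
pipeline = TTS(config)

inputs = {
    "text": "Hello there! This is test audio of a new text to speech tool.",
    "text_lang": "en",
    "ref_audio_path": "reference.wav",                    # placeholder path
    "prompt_text": "Transcript of the reference audio.",  # placeholder transcript
    "prompt_lang": "en",
    "text_split_method": "cut5",
    "batch_size": 4,
    "return_fragment": False,
}

# run() is a generator yielding (sampling_rate, int16 waveform) tuples.
for sr, wav in pipeline.run(inputs):
    sf.write("output.wav", wav, sr)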
GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py ADDED
@@ -0,0 +1,210 @@
1
+
2
+ import os, sys
3
+
4
+ from tqdm import tqdm
5
+ now_dir = os.getcwd()
6
+ sys.path.append(now_dir)
7
+
8
+ import re
9
+ import torch
10
+ import LangSegment
11
+ from typing import Dict, List, Tuple
12
+ from text.cleaner import clean_text
13
+ from text import cleaned_text_to_sequence
14
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
15
+ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
16
+
17
+ from tools.i18n.i18n import I18nAuto
18
+
19
+ i18n = I18nAuto()
20
+
21
+ def get_first(text:str) -> str:
22
+ pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
23
+ text = re.split(pattern, text)[0].strip()
24
+ return text
25
+
26
+ def merge_short_text_in_array(texts:list, threshold:int) -> list:
27
+ if (len(texts)) < 2:
28
+ return texts
29
+ result = []
30
+ text = ""
31
+ for ele in texts:
32
+ text += ele
33
+ if len(text) >= threshold:
34
+ result.append(text)
35
+ text = ""
36
+ if (len(text) > 0):
37
+ if len(result) == 0:
38
+ result.append(text)
39
+ else:
40
+ result[len(result) - 1] += text
41
+ return result
42
+
43
+
44
+
45
+
46
+
47
+
48
+ class TextPreprocessor:
49
+ def __init__(self, bert_model:AutoModelForMaskedLM,
50
+ tokenizer:AutoTokenizer, device:torch.device):
51
+ self.bert_model = bert_model
52
+ self.tokenizer = tokenizer
53
+ self.device = device
54
+
55
+ def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
56
+ print(i18n("############ 切分文本 ############"))
57
+ texts = self.pre_seg_text(text, lang, text_split_method)
58
+ result = []
59
+ print(i18n("############ 提取文本Bert特征 ############"))
60
+ for text in tqdm(texts):
61
+ phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
62
+ if phones is None:
63
+ continue
64
+ res={
65
+ "phones": phones,
66
+ "bert_features": bert_features,
67
+ "norm_text": norm_text,
68
+ }
69
+ result.append(res)
70
+ return result
71
+
72
+ def pre_seg_text(self, text:str, lang:str, text_split_method:str):
73
+ text = text.strip("\n")
74
+ if (text[0] not in splits and len(get_first(text)) < 4):
75
+ text = "。" + text if lang != "en" else "." + text
76
+ print(i18n("实际输入的目标文本:"))
77
+ print(text)
78
+
79
+ seg_method = get_seg_method(text_split_method)
80
+ text = seg_method(text)
81
+
82
+ while "\n\n" in text:
83
+ text = text.replace("\n\n", "\n")
84
+
85
+ _texts = text.split("\n")
86
+ _texts = merge_short_text_in_array(_texts, 5)
87
+ texts = []
88
+
89
+
90
+ for text in _texts:
91
+ # Skip empty lines in the target text to avoid errors
92
+ if (len(text.strip()) == 0):
93
+ continue
94
+ if (text[-1] not in splits): text += "。" if lang != "en" else "."
95
+
96
+ # Split overly long sentences so BERT does not error out
97
+ if (len(text) > 510):
98
+ texts.extend(split_big_text(text))
99
+ else:
100
+ texts.append(text)
101
+
102
+ print(i18n("实际输入的目标文本(切句后):"))
103
+ print(texts)
104
+ return texts
105
+
106
+ def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
107
+ textlist, langlist = self.seg_text(texts, language)
108
+ if len(textlist) == 0:
109
+ return None, None, None
110
+
111
+ phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
112
+ return phones, bert_features, norm_text
113
+
114
+
115
+ def seg_text(self, text:str, language:str)->Tuple[list, list]:
116
+
117
+ textlist=[]
118
+ langlist=[]
119
+ if language in ["auto", "zh", "ja"]:
120
+ LangSegment.setfilters(["zh","ja","en","ko"])
121
+ for tmp in LangSegment.getTexts(text):
122
+ if tmp["text"] == "":
123
+ continue
124
+ if tmp["lang"] == "ko":
125
+ langlist.append("zh")
126
+ elif tmp["lang"] == "en":
127
+ langlist.append("en")
128
+ else:
129
+ # Chinese and Japanese kanji cannot be told apart, so defer to the user's input language
130
+ langlist.append(language if language!="auto" else tmp["lang"])
131
+ textlist.append(tmp["text"])
132
+ elif language == "en":
133
+ LangSegment.setfilters(["en"])
134
+ formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
135
+ while " " in formattext:
136
+ formattext = formattext.replace(" ", " ")
137
+ if formattext != "":
138
+ textlist.append(formattext)
139
+ langlist.append("en")
140
+
141
+ elif language in ["all_zh","all_ja"]:
142
+
143
+ formattext = text
144
+ while " " in formattext:
145
+ formattext = formattext.replace(" ", " ")
146
+ language = language.replace("all_","")
147
+ if text == "":
148
+ return [],[]
149
+ textlist.append(formattext)
150
+ langlist.append(language)
151
+
152
+ else:
153
+ raise ValueError(f"language {language} not supported")
154
+
155
+ return textlist, langlist
156
+
157
+
158
+ def extract_bert_feature(self, textlist:list, langlist:list):
159
+ phones_list = []
160
+ bert_feature_list = []
161
+ norm_text_list = []
162
+ for i in range(len(textlist)):
163
+ lang = langlist[i]
164
+ phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang)
165
+ _bert_feature = self.get_bert_inf(phones, word2ph, norm_text, lang)
166
+ # phones_list.append(phones)
167
+ phones_list.extend(phones)
168
+ norm_text_list.append(norm_text)
169
+ bert_feature_list.append(_bert_feature)
170
+ bert_feature = torch.cat(bert_feature_list, dim=1)
171
+ # phones = sum(phones_list, [])
172
+ norm_text = ''.join(norm_text_list)
173
+ return phones_list, bert_feature, norm_text
174
+
175
+
176
+ def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
177
+ with torch.no_grad():
178
+ inputs = self.tokenizer(text, return_tensors="pt")
179
+ for i in inputs:
180
+ inputs[i] = inputs[i].to(self.device)
181
+ res = self.bert_model(**inputs, output_hidden_states=True)
182
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
183
+ assert len(word2ph) == len(text)
184
+ phone_level_feature = []
185
+ for i in range(len(word2ph)):
186
+ repeat_feature = res[i].repeat(word2ph[i], 1)
187
+ phone_level_feature.append(repeat_feature)
188
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
189
+ return phone_level_feature.T
190
+
191
+ def clean_text_inf(self, text:str, language:str):
192
+ phones, word2ph, norm_text = clean_text(text, language)
193
+ phones = cleaned_text_to_sequence(phones)
194
+ return phones, word2ph, norm_text
195
+
196
+ def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
197
+ language=language.replace("all_","")
198
+ if language == "zh":
199
+ feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
200
+ else:
201
+ feature = torch.zeros(
202
+ (1024, len(phones)),
203
+ dtype=torch.float32,
204
+ ).to(self.device)
205
+
206
+ return feature
207
+
208
+
209
+
210
+
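A minimal sketch of driving this preprocessor on its own, assuming the pretrained BERT path from tts_infer.yaml (inside the full pipeline, the TTS class constructs and calls it for you):

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.TextPreprocessor import TextPreprocessor

bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path).to(device)

preprocessor = TextPreprocessor(bert_model, tokenizer, device)
# Each entry holds "phones", "bert_features" and "norm_text" for one text segment.
segments = preprocessor.preprocess("Hello there! This is a test.", "en", "cut5")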
GPT_SoVITS/TTS_infer_pack/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import TTS, text_segmentation_method
GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py ADDED
@@ -0,0 +1,157 @@
1
+
2
+
3
+
4
+
5
+ import re
6
+ from typing import Callable
7
+ from tools.i18n.i18n import I18nAuto
8
+
9
+ i18n = I18nAuto()
10
+
11
+ METHODS = dict()
12
+
13
+ def get_method(name:str)->Callable:
14
+ method = METHODS.get(name, None)
15
+ if method is None:
16
+ raise ValueError(f"Method {name} not found")
17
+ return method
18
+
19
+ def get_method_names()->list:
20
+ return list(METHODS.keys())
21
+
22
+ def register_method(name):
23
+ def decorator(func):
24
+ METHODS[name] = func
25
+ return func
26
+ return decorator
27
+
28
+ splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
29
+
30
+ def split_big_text(text, max_len=510):
31
+ # Full-width and half-width punctuation marks
32
+ punctuation = "".join(splits)
33
+
34
+ # Split the text on punctuation
35
+ segments = re.split('([' + punctuation + '])', text)
36
+
37
+ # Initialize the result list and the current segment
38
+ result = []
39
+ current_segment = ''
40
+
41
+ for segment in segments:
42
+ # If adding the new piece would exceed max_len, commit the current segment and start a new one
43
+ if len(current_segment + segment) > max_len:
44
+ result.append(current_segment)
45
+ current_segment = segment
46
+ else:
47
+ current_segment += segment
48
+
49
+ # Append the final segment to the result list
50
+ if current_segment:
51
+ result.append(current_segment)
52
+
53
+ return result
54
+
55
+
56
+
57
+ def split(todo_text):
58
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
59
+ if todo_text[-1] not in splits:
60
+ todo_text += "。"
61
+ i_split_head = i_split_tail = 0
62
+ len_text = len(todo_text)
63
+ todo_texts = []
64
+ while 1:
65
+ if i_split_head >= len_text:
66
+ break  # the text always ends with punctuation, so the last segment was already appended
67
+ if todo_text[i_split_head] in splits:
68
+ i_split_head += 1
69
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
70
+ i_split_tail = i_split_head
71
+ else:
72
+ i_split_head += 1
73
+ return todo_texts
74
+
75
+
76
+ # No splitting
77
+ @register_method("cut0")
78
+ def cut0(inp):
79
+ return inp
80
+
81
+
82
+ # Split every four sentences
83
+ @register_method("cut1")
84
+ def cut1(inp):
85
+ inp = inp.strip("\n")
86
+ inps = split(inp)
87
+ split_idx = list(range(0, len(inps), 4))
88
+ # split_idx[-1] = None
89
+ split_idx.append(None)
90
+ if len(split_idx) > 1:
91
+ opts = []
92
+ for idx in range(len(split_idx) - 1):
93
+ opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
94
+ else:
95
+ opts = [inp]
96
+ return "\n".join(opts)
97
+
98
+ # Split roughly every 50 characters
99
+ @register_method("cut2")
100
+ def cut2(inp):
101
+ inp = inp.strip("\n")
102
+ inps = split(inp)
103
+ if len(inps) < 2:
104
+ return inp
105
+ opts = []
106
+ summ = 0
107
+ tmp_str = ""
108
+ for i in range(len(inps)):
109
+ summ += len(inps[i])
110
+ tmp_str += inps[i]
111
+ if summ > 50:
112
+ summ = 0
113
+ opts.append(tmp_str)
114
+ tmp_str = ""
115
+ if tmp_str != "":
116
+ opts.append(tmp_str)
117
+ # print(opts)
118
+ if len(opts) > 1 and len(opts[-1]) < 50:  ## if the last chunk is too short, merge it with the previous one
119
+ opts[-2] = opts[-2] + opts[-1]
120
+ opts = opts[:-1]
121
+ return "\n".join(opts)
122
+
123
+ # Split on the Chinese full stop "。"
124
+ @register_method("cut3")
125
+ def cut3(inp):
126
+ inp = inp.strip("\n")
127
+ return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
128
+
129
+ # Split on the English period "."
130
+ @register_method("cut4")
131
+ def cut4(inp):
132
+ inp = inp.strip("\n")
133
+ return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
134
+
135
+ # Split on any punctuation mark
136
+ # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
137
+ @register_method("cut5")
138
+ def cut5(inp):
139
+ # if not re.search(r'[^\w\s]', inp[-1]):
140
+ # inp += '。'
141
+ inp = inp.strip("\n")
142
+ # punds = r'[,.;?!、,。?!;:…]'
143
+ punds = r'[,.;?!、,。?!;::…]'
144
+ items = re.split(f'({punds})', inp)
145
+ mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
146
+ # Keep the text intact when there is no punctuation, or no trailing punctuation at the end
147
+ if len(items)%2 == 1:
148
+ mergeitems.append(items[-1])
149
+ opt = "\n".join(mergeitems)
150
+ return opt
151
+
152
+
153
+
154
+ if __name__ == '__main__':
155
+ method = get_method("cut5")
156
+ print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
157
+
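Besides the built-in cut0 through cut5 splitters, register_method allows plugging in a custom one; a small sketch (the "cut_newline" name is hypothetical):

from TTS_infer_pack.text_segmentation_method import register_method, get_method

@register_method("cut_newline")
def cut_newline(inp):
    # Treat every non-empty line of the input as its own segment.
    return "\n".join(line for line in inp.strip("\n").split("\n") if line.strip())

method = get_method("cut_newline")
print(method("First sentence.\nSecond sentence."))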
GPT_SoVITS/configs/tts_infer.yaml ADDED
@@ -0,0 +1,14 @@
1
+ custom:
2
+ bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
3
+ cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
4
+ device: cuda
5
+ is_half: true
6
+ t2s_weights_path: GPT_SoVITS/GPT_weights//ShioriNovella_GPT.ckpt
7
+ vits_weights_path: GPT_SoVITS/SoVITS_weights//ShioriNovella_SoVITS.pth
8
+ default:
9
+ bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
10
+ cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
11
+ device: cpu
12
+ is_half: false
13
+ t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
14
+ vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
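The custom block carries the fine-tuned weights used by this Space, while default keeps the stock pretrained paths. inference_webui.py loads this file through TTS_Config and then overrides the runtime fields, roughly as in this sketch:

import torch
from TTS_infer_pack.TTS import TTS, TTS_Config

tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = "cuda" if torch.cuda.is_available() else "cpu"
tts_config.is_half = torch.cuda.is_available()
tts_pipeline = TTS(tts_config)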
GPT_SoVITS/feature_extractor/cnhubert.py CHANGED
@@ -20,13 +20,16 @@ cnhubert_base_path = None
20
 
21
 
22
  class CNHubert(nn.Module):
23
- def __init__(self):
24
  super().__init__()
25
- self.model = HubertModel.from_pretrained(cnhubert_base_path)
 
 
26
  self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
27
- cnhubert_base_path
28
  )
29
 
 
30
  def forward(self, x):
31
  input_values = self.feature_extractor(
32
  x, return_tensors="pt", sampling_rate=16000
 
20
 
21
 
22
  class CNHubert(nn.Module):
23
+ def __init__(self, base_path:str=None):
24
  super().__init__()
25
+ if base_path is None:
26
+ base_path = cnhubert_base_path
27
+ self.model = HubertModel.from_pretrained(base_path)
28
  self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
29
+ base_path
30
  )
31
 
32
+
33
  def forward(self, x):
34
  input_values = self.feature_extractor(
35
  x, return_tensors="pt", sampling_rate=16000
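With this change, CNHubert can be pointed at a checkpoint directory explicitly instead of relying on the module-level cnhubert_base_path global; both construction paths below should keep working (the path shown is the stock pretrained one from the configs):

from feature_extractor import cnhubert

# Legacy path: set the module-level global, then construct with no argument.
cnhubert.cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
ssl_model_a = cnhubert.CNHubert()

# New path: pass the base path explicitly.
ssl_model_b = cnhubert.CNHubert(base_path="GPT_SoVITS/pretrained_models/chinese-hubert-base")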
GPT_SoVITS/inference_gui.py CHANGED
@@ -7,7 +7,7 @@ import soundfile as sf
7
  from tools.i18n.i18n import I18nAuto
8
  i18n = I18nAuto()
9
 
10
- from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
11
 
12
 
13
  class GPTSoVITSGUI(QMainWindow):
 
7
  from tools.i18n.i18n import I18nAuto
8
  i18n = I18nAuto()
9
 
10
+ from GPT_SoVITS.inference_webui_old import change_gpt_weights, change_sovits_weights, get_tts_wav
11
 
12
 
13
  class GPTSoVITSGUI(QMainWindow):
GPT_SoVITS/inference_webui.py CHANGED
@@ -1,688 +1,269 @@
1
- # Based on GPT-SoVITS-emo by kevinwang676
2
- # I fucking hate this thing. Why does every GPT-SoVITS space have to suck balls?
3
-
4
- import os
5
- import torch
6
- from openvoice import se_extractor
7
- from openvoice.api import BaseSpeakerTTS, ToneColorConverter
8
-
9
- if torch.cuda.is_available():
10
- device = "cuda"
11
- else:
12
- device = "cpu"
13
-
14
- ckpt_base = 'checkpoints/base_speakers/EN'
15
- ckpt_converter = 'checkpoints/converter'
16
- base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)
17
- base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')
18
-
19
- tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
20
- tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
21
-
22
- #source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)
23
- #source_se_style = torch.load(f'{ckpt_base}/en_style_se.pth').to(device)
24
-
25
- def vc_en(audio_ref, style_mode):
26
- text = "We have always tried to be at the intersection of technology and liberal arts, to be able to get the best of both, to make extremely advanced products from a technology point of view."
27
- if style_mode=="default":
28
- source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)
29
- reference_speaker = audio_ref
30
- target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
31
- save_path = "output.wav"
32
-
33
- # Run the base speaker tts
34
- src_path = "tmp.wav"
35
- base_speaker_tts.tts(text, src_path, speaker='default', language='English', speed=1.0)
36
-
37
- # Run the tone color converter
38
- encode_message = "@MyShell"
39
- tone_color_converter.convert(
40
- audio_src_path=src_path,
41
- src_se=source_se,
42
- tgt_se=target_se,
43
- output_path=save_path,
44
- message=encode_message)
45
-
46
- else:
47
- source_se = torch.load(f'{ckpt_base}/en_style_se.pth').to(device)
48
- reference_speaker = audio_ref
49
- target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
50
-
51
- save_path = "output.wav"
52
-
53
- # Run the base speaker tts
54
- src_path = "tmp.wav"
55
- base_speaker_tts.tts(text, src_path, speaker=style_mode, language='English', speed=1.0)
56
-
57
- # Run the tone color converter
58
- encode_message = "@MyShell"
59
- tone_color_converter.convert(
60
- audio_src_path=src_path,
61
- src_se=source_se,
62
- tgt_se=target_se,
63
- output_path=save_path,
64
- message=encode_message)
65
-
66
- return "output.wav"
67
-
68
- # End
69
-
70
- import re, logging
71
- import LangSegment
72
- logging.getLogger("markdown_it").setLevel(logging.ERROR)
73
- logging.getLogger("urllib3").setLevel(logging.ERROR)
74
- logging.getLogger("httpcore").setLevel(logging.ERROR)
75
- logging.getLogger("httpx").setLevel(logging.ERROR)
76
- logging.getLogger("asyncio").setLevel(logging.ERROR)
77
- logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
78
- logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
79
- import pdb
80
- import json
81
-
82
- cnhubert_base_path = os.environ.get(
83
- "cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
84
- )
85
- bert_path = os.environ.get(
86
- "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
87
- )
88
- infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
89
- infer_ttswebui = int(infer_ttswebui)
90
- is_share = os.environ.get("is_share", "False")
91
- is_share = eval(is_share)
92
- if "_CUDA_VISIBLE_DEVICES" in os.environ:
93
- os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
94
- is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
95
- import gradio as gr
96
- from transformers import AutoModelForMaskedLM, AutoTokenizer
97
- import numpy as np
98
- import librosa
99
- from feature_extractor import cnhubert
100
-
101
- cnhubert.cnhubert_base_path = cnhubert_base_path
102
-
103
- from module.models import SynthesizerTrn
104
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
105
- from text import cleaned_text_to_sequence
106
- from text.cleaner import clean_text
107
- from time import time as ttime
108
- from module.mel_processing import spectrogram_torch
109
- from my_utils import load_audio
110
-
111
- # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
112
-
113
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
114
- bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
115
- if is_half == True:
116
- bert_model = bert_model.half().to(device)
117
- else:
118
- bert_model = bert_model.to(device)
119
-
120
-
121
- def get_bert_feature(text, word2ph):
122
- with torch.no_grad():
123
- inputs = tokenizer(text, return_tensors="pt")
124
- for i in inputs:
125
- inputs[i] = inputs[i].to(device)
126
- res = bert_model(**inputs, output_hidden_states=True)
127
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
128
- assert len(word2ph) == len(text)
129
- phone_level_feature = []
130
- for i in range(len(word2ph)):
131
- repeat_feature = res[i].repeat(word2ph[i], 1)
132
- phone_level_feature.append(repeat_feature)
133
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
134
- return phone_level_feature.T
135
-
136
-
137
- class DictToAttrRecursive(dict):
138
- def __init__(self, input_dict):
139
- super().__init__(input_dict)
140
- for key, value in input_dict.items():
141
- if isinstance(value, dict):
142
- value = DictToAttrRecursive(value)
143
- self[key] = value
144
- setattr(self, key, value)
145
-
146
- def __getattr__(self, item):
147
- try:
148
- return self[item]
149
- except KeyError:
150
- raise AttributeError(f"Attribute {item} not found")
151
-
152
- def __setattr__(self, key, value):
153
- if isinstance(value, dict):
154
- value = DictToAttrRecursive(value)
155
- super(DictToAttrRecursive, self).__setitem__(key, value)
156
- super().__setattr__(key, value)
157
-
158
- def __delattr__(self, item):
159
- try:
160
- del self[item]
161
- except KeyError:
162
- raise AttributeError(f"Attribute {item} not found")
163
-
164
-
165
- ssl_model = cnhubert.get_model()
166
- if is_half == True:
167
- ssl_model = ssl_model.half().to(device)
168
- else:
169
- ssl_model = ssl_model.to(device)
170
-
171
- clm = ""
172
-
173
- def change_sovits_weights(sovits_path):
174
- global vq_model, hps
175
- dict_s2 = torch.load(sovits_path, map_location="cpu")
176
- hps = dict_s2["config"]
177
- hps = DictToAttrRecursive(hps)
178
- hps.model.semantic_frame_rate = "25hz"
179
- vq_model = SynthesizerTrn(
180
- hps.data.filter_length // 2 + 1,
181
- hps.train.segment_size // hps.data.hop_length,
182
- n_speakers=hps.data.n_speakers,
183
- **hps.model
184
- )
185
- if ("pretrained" not in sovits_path):
186
- del vq_model.enc_q
187
- if is_half == True:
188
- vq_model = vq_model.half().to(device)
189
- else:
190
- vq_model = vq_model.to(device)
191
- vq_model.eval()
192
- print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
193
- #with open("./sweight.txt", "w", encoding="utf-8") as f:
194
- # f.write(sovits_path)
195
-
196
-
197
- #change_sovits_weights(sovits_path)
198
-
199
-
200
- def change_gpt_weights(gpt_path):
201
- global hz, max_sec, t2s_model, config
202
- hz = 50
203
- dict_s1 = torch.load(gpt_path, map_location="cpu")
204
- config = dict_s1["config"]
205
- max_sec = config["data"]["max_sec"]
206
- t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
207
- t2s_model.load_state_dict(dict_s1["weight"])
208
- if is_half == True:
209
- t2s_model = t2s_model.half()
210
- t2s_model = t2s_model.to(device)
211
- t2s_model.eval()
212
- total = sum([param.nelement() for param in t2s_model.parameters()])
213
- #print("Number of parameter: %.2fM" % (total / 1e6))
214
- #with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
215
-
216
-
217
- #change_gpt_weights(gpt_path)
218
-
219
-
220
- def get_spepc(hps, filename):
221
- audio = load_audio(filename, int(hps.data.sampling_rate))
222
- audio = torch.FloatTensor(audio)
223
- audio_norm = audio
224
- audio_norm = audio_norm.unsqueeze(0)
225
- spec = spectrogram_torch(
226
- audio_norm,
227
- hps.data.filter_length,
228
- hps.data.sampling_rate,
229
- hps.data.hop_length,
230
- hps.data.win_length,
231
- center=False,
232
- )
233
- return spec
234
-
235
-
236
- dict_language = {
237
- "ZH": "all_zh",#全部按中文识别
238
- "EN": "en",#全部按英文识别#######不变
239
- "JP": "all_ja",#全部按日文识别
240
- "ZH/EN": "zh",#按中英混合识别####不变
241
- "JP/EN": "ja",#按日英混合识别####不变
242
- "Automatic": "auto",#多语种启动切分识别语种
243
- }
244
-
245
-
246
- def clean_text_inf(text, language):
247
- phones, word2ph, norm_text = clean_text(text, language)
248
- phones = cleaned_text_to_sequence(phones)
249
- return phones, word2ph, norm_text
250
-
251
- dtype=torch.float16 if is_half == True else torch.float32
252
- def get_bert_inf(phones, word2ph, norm_text, language):
253
- language=language.replace("all_","")
254
- if language == "zh":
255
- bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
256
- else:
257
- bert = torch.zeros(
258
- (1024, len(phones)),
259
- dtype=torch.float16 if is_half == True else torch.float32,
260
- ).to(device)
261
-
262
- return bert
263
-
264
-
265
- splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
266
-
267
-
268
- def get_first(text):
269
- pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
270
- text = re.split(pattern, text)[0].strip()
271
- return text
272
-
273
-
274
- def get_phones_and_bert(text,language):
275
- if language in {"en","all_zh","all_ja"}:
276
- language = language.replace("all_","")
277
- if language == "en":
278
- LangSegment.setfilters(["en"])
279
- formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
280
- else:
281
- # 因无法区别中日文汉字,以用户输入为准
282
- formattext = text
283
- while " " in formattext:
284
- formattext = formattext.replace(" ", " ")
285
- phones, word2ph, norm_text = clean_text_inf(formattext, language)
286
- if language == "zh":
287
- bert = get_bert_feature(norm_text, word2ph).to(device)
288
- else:
289
- bert = torch.zeros(
290
- (1024, len(phones)),
291
- dtype=torch.float16 if is_half == True else torch.float32,
292
- ).to(device)
293
- elif language in {"zh", "ja","auto"}:
294
- textlist=[]
295
- langlist=[]
296
- LangSegment.setfilters(["zh","ja","en","ko"])
297
- if language == "auto":
298
- for tmp in LangSegment.getTexts(text):
299
- if tmp["lang"] == "ko":
300
- langlist.append("zh")
301
- textlist.append(tmp["text"])
302
- else:
303
- langlist.append(tmp["lang"])
304
- textlist.append(tmp["text"])
305
- else:
306
- for tmp in LangSegment.getTexts(text):
307
- if tmp["lang"] == "en":
308
- langlist.append(tmp["lang"])
309
- else:
310
- # 因无法区别中日文汉字,以用户输入为准
311
- langlist.append(language)
312
- textlist.append(tmp["text"])
313
- print(textlist)
314
- print(langlist)
315
- phones_list = []
316
- bert_list = []
317
- norm_text_list = []
318
- for i in range(len(textlist)):
319
- lang = langlist[i]
320
- phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
321
- bert = get_bert_inf(phones, word2ph, norm_text, lang)
322
- phones_list.append(phones)
323
- norm_text_list.append(norm_text)
324
- bert_list.append(bert)
325
- bert = torch.cat(bert_list, dim=1)
326
- phones = sum(phones_list, [])
327
- norm_text = ''.join(norm_text_list)
328
-
329
- return phones,bert.to(dtype),norm_text
330
-
331
-
332
- def merge_short_text_in_array(texts, threshold):
333
- if (len(texts)) < 2:
334
- return texts
335
- result = []
336
- text = ""
337
- for ele in texts:
338
- text += ele
339
- if len(text) >= threshold:
340
- result.append(text)
341
- text = ""
342
- if (len(text) > 0):
343
- if len(result) == 0:
344
- result.append(text)
345
- else:
346
- result[len(result) - 1] += text
347
- return result
348
-
349
- def get_tts_wav(name, gptmp, svmp, sty, ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut="None", top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
350
-
351
- global clm
352
- if(not ref_wav_path):
353
- ref_wav_path=f"referenceaudio/{name}/"+referencedata[name][0][sty]
354
- prompt_text=referencedata[name][1][sty]
355
- if clm!=name:
356
- print(f"Switching to model {name}")
357
- clm=name
358
- change_gpt_weights(gptmp)
359
- change_sovits_weights(svmp)
360
-
361
- if prompt_text is None or len(prompt_text) == 0:
362
- ref_free = True
363
- t0 = ttime()
364
- prompt_language = dict_language[prompt_language]
365
- text_language = dict_language[text_language]
366
- if not ref_free:
367
- prompt_text = prompt_text.strip("\n")
368
- if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
369
- text = text.strip("\n")
370
- if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
371
-
372
- print("Input text:", text)
373
- zero_wav = np.zeros(
374
- int(hps.data.sampling_rate * 0.3),
375
- dtype=np.float16 if is_half == True else np.float32,
376
- )
377
- with torch.no_grad():
378
- wav16k, sr = librosa.load(ref_wav_path, sr=16000)
379
- if (wav16k.shape[0] > 240000 or wav16k.shape[0] < 48000):
380
- raise OSError("Reference audio too long!!")
381
- wav16k = torch.from_numpy(wav16k)
382
- zero_wav_torch = torch.from_numpy(zero_wav)
383
- if is_half == True:
384
- wav16k = wav16k.half().to(device)
385
- zero_wav_torch = zero_wav_torch.half().to(device)
386
- else:
387
- wav16k = wav16k.to(device)
388
- zero_wav_torch = zero_wav_torch.to(device)
389
- wav16k = torch.cat([wav16k, zero_wav_torch])
390
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
391
- "last_hidden_state"
392
- ].transpose(
393
- 1, 2
394
- ) # .float()
395
- codes = vq_model.extract_latent(ssl_content)
396
-
397
- prompt_semantic = codes[0, 0]
398
- t1 = ttime()
399
-
400
- if (how_to_cut == "4 Sentences"):
401
- text = cut1(text)
402
- elif (how_to_cut == "50 Characters"):
403
- text = cut2(text)
404
- elif (how_to_cut == "Chinese/Japanese Punctuation"):
405
- text = cut3(text)
406
- elif (how_to_cut == "EN Punctuation"):
407
- text = cut4(text)
408
- elif (how_to_cut == "All Punctuation"):
409
- text = cut5(text)
410
- while "\n\n" in text:
411
- text = text.replace("\n\n", "\n")
412
- texts = text.split("\n")
413
- texts = merge_short_text_in_array(texts, 5)
414
- audio_opt = []
415
- if not ref_free:
416
- phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
417
-
418
- for text in texts:
419
- # 解决输入目标文本的空行导致报错的问题
420
- if (len(text.strip()) == 0):
421
- continue
422
- if (text[-1] not in splits): text += "。" if text_language != "en" else "."
423
- phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
424
- if not ref_free:
425
- bert = torch.cat([bert1, bert2], 1)
426
- all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
427
- else:
428
- bert = bert2
429
- all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
430
-
431
- bert = bert.to(device).unsqueeze(0)
432
- all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
433
- prompt = prompt_semantic.unsqueeze(0).to(device)
434
- t2 = ttime()
435
- with torch.no_grad():
436
- # pred_semantic = t2s_model.model.infer(
437
- pred_semantic, idx = t2s_model.model.infer_panel(
438
- all_phoneme_ids,
439
- all_phoneme_len,
440
- None if ref_free else prompt,
441
- bert,
442
- # prompt_phone_len=ph_offset,
443
- top_k=top_k,
444
- top_p=top_p,
445
- temperature=temperature,
446
- early_stop_num=hz * max_sec,
447
- )
448
- t3 = ttime()
449
- # print(pred_semantic.shape,idx)
450
- pred_semantic = pred_semantic[:, -idx:].unsqueeze(
451
- 0
452
- ) # .unsqueeze(0)#mq要多unsqueeze一次
453
- refer = get_spepc(hps, ref_wav_path) # .to(device)
454
- if is_half == True:
455
- refer = refer.half().to(device)
456
- else:
457
- refer = refer.to(device)
458
- # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
459
- audio = (
460
- vq_model.decode(
461
- pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
462
- )
463
- .detach()
464
- .cpu()
465
- .numpy()[0, 0]
466
- ) ###试试重建不带上prompt部分
467
- max_audio=np.abs(audio).max()#简单防止16bit爆音
468
- if max_audio>1:audio/=max_audio
469
- audio_opt.append(audio)
470
- audio_opt.append(zero_wav)
471
- t4 = ttime()
472
- print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
473
- yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
474
- np.int16
475
- )
476
-
477
-
478
- def split(todo_text):
479
- todo_text = todo_text.replace("……", "。").replace("——", ",")
480
- if todo_text[-1] not in splits:
481
- todo_text += "。"
482
- i_split_head = i_split_tail = 0
483
- len_text = len(todo_text)
484
- todo_texts = []
485
- while 1:
486
- if i_split_head >= len_text:
487
- break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
488
- if todo_text[i_split_head] in splits:
489
- i_split_head += 1
490
- todo_texts.append(todo_text[i_split_tail:i_split_head])
491
- i_split_tail = i_split_head
492
- else:
493
- i_split_head += 1
494
- return todo_texts
495
-
496
-
497
- def cut1(inp):
498
- inp = inp.strip("\n")
499
- inps = split(inp)
500
- split_idx = list(range(0, len(inps), 4))
501
- split_idx[-1] = None
502
- if len(split_idx) > 1:
503
- opts = []
504
- for idx in range(len(split_idx) - 1):
505
- opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
506
- else:
507
- opts = [inp]
508
- return "\n".join(opts)
509
-
510
-
511
- def cut2(inp):
512
- inp = inp.strip("\n")
513
- inps = split(inp)
514
- if len(inps) < 2:
515
- return inp
516
- opts = []
517
- summ = 0
518
- tmp_str = ""
519
- for i in range(len(inps)):
520
- summ += len(inps[i])
521
- tmp_str += inps[i]
522
- if summ > 50:
523
- summ = 0
524
- opts.append(tmp_str)
525
- tmp_str = ""
526
- if tmp_str != "":
527
- opts.append(tmp_str)
528
- # print(opts)
529
- if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
530
- opts[-2] = opts[-2] + opts[-1]
531
- opts = opts[:-1]
532
- return "\n".join(opts)
533
-
534
-
535
- def cut3(inp):
536
- inp = inp.strip("\n")
537
- return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
538
-
539
-
540
- def cut4(inp):
541
- inp = inp.strip("\n")
542
- return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
543
-
544
-
545
- # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
546
- def cut5(inp):
547
- # if not re.search(r'[^\w\s]', inp[-1]):
548
- # inp += '。'
549
- inp = inp.strip("\n")
550
- punds = r'[,.;?!、,。?!;:…]'
551
- items = re.split(f'({punds})', inp)
552
- mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
553
- # 在句子不存在符号或句尾无符号的时候保证文本完整
554
- if len(items)%2 == 1:
555
- mergeitems.append(items[-1])
556
- opt = "\n".join(mergeitems)
557
- return opt
558
-
559
-
560
- def custom_sort_key(s):
561
- # 使用正则表达式提取字符串中的数字部分和非数字部分
562
- parts = re.split('(\d+)', s)
563
- # 将数字部分转换为整数,非数字部分保持不变
564
- parts = [int(part) if part.isdigit() else part for part in parts]
565
- return parts
566
-
567
-
568
- def change_choices():
569
- SoVITS_names, GPT_names = get_weights_names()
570
- return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
571
-
572
-
573
- pretrained_sovits_name = "GPT_SoVITS/pretrained_models/s2G488k.pth"
574
- pretrained_gpt_name = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
575
- SoVITS_weight_root = "GPT_SoVITS/SoVITS_weights"
576
- GPT_weight_root = "GPT_SoVITS/GPT_weights"
577
- #os.makedirs(SoVITS_weight_root, exist_ok=True)
578
- #os.makedirs(GPT_weight_root, exist_ok=True)
579
-
580
-
581
- def get_weights_names():
582
- SoVITS_names = [pretrained_sovits_name]
583
- for name in os.listdir(SoVITS_weight_root):
584
- if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (SoVITS_weight_root, name))
585
- GPT_names = [pretrained_gpt_name]
586
- for name in os.listdir(GPT_weight_root):
587
- if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (GPT_weight_root, name))
588
- return SoVITS_names, GPT_names
589
-
590
- def load_models():
591
- print("Loading models...")
592
- voices=[]
593
- ustyles={}
594
- with open("voicelist.json", "r", encoding="utf-8") as f:
595
- voc_info = json.load(f)
596
- for name, info in voc_info.items():
597
- if not info['enable']:
598
- continue
599
- title= info['title']
600
- gptmodelpath= "%s/%s" % (GPT_weight_root, info['gpt_model_path'])
601
- sovitsmodelpath= "%s/%s" % (SoVITS_weight_root, info['sovits_model_path'])
602
- author= info['modelauthor']
603
- image = info['cover']
604
- styles = info['styles']
605
- styletrans = info['styletrans']
606
- st=[styles, styletrans]
607
- voices.append((name, title, gptmodelpath, sovitsmodelpath, author, image))
608
- ustyles[name]=st
609
- print(f"Indexed model {title}")
610
- return voices, ustyles
611
-
612
- modeldata, referencedata = load_models()
613
-
614
- SoVITS_names, GPT_names = get_weights_names()
615
-
616
- #print(os.getcwd())
617
- #for r, _, f in os.walk(os.getcwd()):
618
- # for n in f:
619
- # print(os.path.join(r, n))
620
-
621
- #Gradio preload
622
- text = gr.TextArea(label="Input Text", value="Hello there! This is test audio of a new text to speech tool.")
623
- text_language = gr.Dropdown(label="Language", choices=["EN", "JP", "ZH", "ZH/EN", "JP/EN", "Automatic"], value="EN")
624
- how_to_cut = gr.Dropdown(label="Slicing Method",
625
- choices=["None", "4 Sentences", "50 Characters", "ZH/JP Punctuation", "EN Punctuation", "All Punctuation" ],
626
- value="4 Sentences",
627
- interactive=True,
628
- )
629
- top_k = gr.Slider(minimum=1,maximum=100,step=1,label="top_k",value=5,interactive=True)
630
- top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label="top_p",value=1,interactive=True)
631
- temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label="temperature",value=1,interactive=True)
632
-
633
- #Main gradio
634
- with gr.Blocks(title="Lemonfoot GPT-SoVITS") as app:
635
- gr.Markdown(
636
- "# Lemonfoot GPT-SoVITS 🚀🍋\n"
637
- "### Space by Kit Lemonfoot / Noel Shirogane's High Flying Birds\n"
638
- "Based on code originally by RVC_Boss and kevinwang676\n\n"
639
- "Do no evil.\n\n"
640
- "**NOTE:** *This is more or less a test Space*. HuggingFace Spaces are not capable of running GPT-SoVITS efficiently; a single generation may take upwards of an hour to infer one sentence. "
641
- "If you wish to use these models for legitimate generation, it is recommended to [download the models individually](https://huggingface.co/Kit-Lemonfoot/kitlemonfoot_gptsovits_models) and run GPT-SoVITS locally."
642
- )
643
- for (name, title, gptmodelpath, sovitsmodelpath, author, image) in modeldata:
644
- with gr.TabItem(name):
645
- with gr.Row():
646
- with gr.Column():
647
- n = gr.Textbox(value=name, visible=False, interactive=False)
648
- gptmp = gr.Textbox(value=gptmodelpath, visible=False, interactive=False)
649
- svmp = gr.Textbox(value=sovitsmodelpath, visible=False, interactive=False)
650
- gr.Markdown(f"**{title}**\n\n Dataset author: {author}")
651
- gr.Image(f"images/{image}", label=None, show_label=False, width=300, show_download_button=False, container=False, show_share_button=False)
652
- with gr.Column():
653
- with gr.TabItem("Style using a preset"):
654
- sty = gr.Dropdown(
655
- label="Current style",
656
- choices=referencedata[name][0].keys(),
657
- value="Neutral",
658
- interactive=True
659
- )
660
- with gr.TabItem("Style using a different audio"):
661
- with gr.Column():
662
- ref_audio_path = gr.Audio(label="Reference Audio", type="filepath")
663
- ref_text_free = gr.Checkbox(label="Enables no text-reference mode.", value=False, interactive=True)
664
- prompt_text = gr.Textbox(label="Reference Audio Text", interactive=True)
665
- prompt_language = gr.Textbox(value="EN", visible=False, interactive=False)
666
- with gr.Column():
667
- inference_button = gr.Button("Synthesize", variant="primary")
668
- output = gr.Audio(label="Output")
669
-
670
- inference_button.click(
671
- get_tts_wav,
672
- inputs=[n, gptmp, svmp, sty, ref_audio_path, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
673
- outputs=[output]
674
- )
675
-
676
- #bottom info
677
- with gr.Row():
678
- with gr.Column():
679
- text.render()
680
- text_language.render()
681
- how_to_cut.render()
682
- with gr.Column():
683
- gr.Markdown("### GPT Sampling Parameters")
684
- top_k.render()
685
- top_p.render()
686
- temperature.render()
687
-
688
- app.queue().launch()
 
1
+ # Based on GPT-SoVITS-fast-inference by ChasonJiang
2
+
3
+ import random
4
+ import os
5
+ import torch
6
+
7
+ if torch.cuda.is_available():
8
+ device = "cuda"
9
+ else:
10
+ device = "cpu"
11
+
12
+ import re, logging
13
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
14
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
15
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
16
+ logging.getLogger("httpx").setLevel(logging.ERROR)
17
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
18
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
19
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
20
+ import pdb
21
+ import json
22
+
23
+ infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
24
+ infer_ttswebui = int(infer_ttswebui)
25
+ is_share = os.environ.get("is_share", "False")
26
+ is_share = eval(is_share)
27
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
28
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
29
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
30
+ gpt_path=None
31
+ sovits_path=None
32
+ #gpt_path = os.environ.get("gpt_path", None)
33
+ #sovits_path = os.environ.get("sovits_path", None)
34
+ cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
35
+ bert_path = os.environ.get("bert_path", None)
36
+
37
+ import gradio as gr
38
+ from TTS_infer_pack.TTS import TTS, TTS_Config
39
+ from TTS_infer_pack.text_segmentation_method import get_method
40
+
41
+
42
+ dict_language = {
43
+ "ZH": "all_zh",#全部按中文识别
44
+ "EN": "en",#全部按英文识别#######不变
45
+ "JP": "all_ja",#全部按日文识别
46
+ "ZH/EN": "zh",#按中英混合识别####不变
47
+ "JP/EN": "ja",#按日英混合识别####不变
48
+ "Automatic": "auto",#多语种启动切分识别语种
49
+ }
50
+
51
+ cut_method = {
52
+ "None":"cut0",
53
+ "4 Sentences": "cut1",
54
+ "50 Characters": "cut2",
55
+ "ZH/JP Punctuation": "cut3",
56
+ "EN Punctuation": "cut4",
57
+ "All Punctuation": "cut5",
58
+ }
59
+
60
+ tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
61
+ tts_config.device = device
62
+ tts_config.is_half = is_half
63
+ if gpt_path is not None:
64
+ tts_config.t2s_weights_path = gpt_path
65
+ if sovits_path is not None:
66
+ tts_config.vits_weights_path = sovits_path
67
+ if cnhubert_base_path is not None:
68
+ tts_config.cnhuhbert_base_path = cnhubert_base_path
69
+ if bert_path is not None:
70
+ tts_config.bert_base_path = bert_path
71
+
72
+ print(tts_config)
73
+ tts_pipeline = TTS(tts_config)
74
+ gpt_path = tts_config.t2s_weights_path
75
+ sovits_path = tts_config.vits_weights_path
76
+
77
+ clm= ""
78
+
79
+ def inference(name, gptmp, svmp, sty, text, text_lang,
80
+ ref_audio_path, prompt_text,
81
+ prompt_lang, top_k,
82
+ top_p, temperature,
83
+ text_split_method, batch_size,
84
+ speed_factor, ref_text_free,
85
+ split_bucket,fragment_interval,
86
+ seed, keep_random, parallel_infer,
87
+ repetition_penalty
88
+ ):
89
+
90
+ global clm
91
+ #Live switching
92
+ if(not ref_audio_path):
93
+ ref_audio_path=f"referenceaudio/{name}/"+referencedata[name][0][sty]
94
+ prompt_text=referencedata[name][1][sty]
95
+ if clm!=name:
96
+ print(f"Switching to model {name}")
97
+ clm=name
98
+ print(os.getcwd())
99
+ tts_pipeline.init_t2s_weights(gptmp)
100
+ tts_pipeline.init_vits_weights(svmp)
101
+
102
+
103
+ seed = -1 if keep_random else seed
104
+ actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
105
+ inputs={
106
+ "text": text,
107
+ "text_lang": dict_language[text_lang],
108
+ "ref_audio_path": ref_audio_path,
109
+ "prompt_text": prompt_text if not ref_text_free else "",
110
+ "prompt_lang": dict_language[prompt_lang],
111
+ "top_k": top_k,
112
+ "top_p": top_p,
113
+ "temperature": temperature,
114
+ "text_split_method": cut_method[text_split_method],
115
+ "batch_size":int(batch_size),
116
+ "speed_factor":float(speed_factor),
117
+ "split_bucket":split_bucket,
118
+ "return_fragment":False,
119
+ "fragment_interval":fragment_interval,
120
+ "seed":actual_seed,
121
+ "parallel_infer": parallel_infer,
122
+ "repetition_penalty": repetition_penalty,
123
+ }
124
+ for item in tts_pipeline.run(inputs):
125
+ yield item, actual_seed
126
+
127
+ def custom_sort_key(s):
128
+ # Use a regex to split the string into numeric and non-numeric parts
129
+ parts = re.split('(\d+)', s)
130
+ # Convert the numeric parts to integers and leave the rest unchanged
131
+ parts = [int(part) if part.isdigit() else part for part in parts]
132
+ return parts
133
+
134
+
135
+ def change_choices():
136
+ SoVITS_names, GPT_names = get_weights_names()
137
+ return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
138
+
139
+
140
+ pretrained_sovits_name = "GPT_SoVITS/pretrained_models/s2G488k.pth"
141
+ pretrained_gpt_name = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
142
+ SoVITS_weight_root = "GPT_SoVITS/SoVITS_weights/"
143
+ GPT_weight_root = "GPT_SoVITS/GPT_weights/"
144
+ #os.makedirs(SoVITS_weight_root, exist_ok=True)
145
+ #os.makedirs(GPT_weight_root, exist_ok=True)
146
+
147
+ def get_weights_names():
148
+ SoVITS_names = [pretrained_sovits_name]
149
+ for name in os.listdir(SoVITS_weight_root):
150
+ if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (SoVITS_weight_root, name))
151
+ GPT_names = [pretrained_gpt_name]
152
+ for name in os.listdir(GPT_weight_root):
153
+ if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (GPT_weight_root, name))
154
+ return SoVITS_names, GPT_names
155
+
156
+ def load_models():
157
+ print("Loading models...")
158
+ voices=[]
159
+ ustyles={}
160
+ with open("voicelist.json", "r", encoding="utf-8") as f:
161
+ voc_info = json.load(f)
162
+ for name, info in voc_info.items():
163
+ if not info['enable']:
164
+ continue
165
+ title= info['title']
166
+ #gptmodelpath= info['gpt_model_path']
167
+ #sovitsmodelpath= info['sovits_model_path']
168
+ gptmodelpath= "%s/%s" % (GPT_weight_root, info['gpt_model_path'])
169
+ sovitsmodelpath= "%s/%s" % (SoVITS_weight_root, info['sovits_model_path'])
170
+ author= info['modelauthor']
171
+ image = info['cover']
172
+ styles = info['styles']
173
+ styletrans = info['styletrans']
174
+ st=[styles, styletrans]
175
+ voices.append((name, title, gptmodelpath, sovitsmodelpath, author, image))
176
+ ustyles[name]=st
177
+ print(f"Indexed model {title}")
178
+ return voices, ustyles
179
+
180
+ modeldata, referencedata = load_models()
181
+
182
+ #print(os.getcwd())
183
+ #for r, _, f in os.walk(os.getcwd()):
184
+ # for n in f:
185
+ # print(os.path.join(r, n))
186
+
187
+ #Gradio preload
188
+ text = gr.TextArea(label="Input Text", value="Hello there! This is test audio of a new text to speech tool.")
189
+ text_language = gr.Dropdown(label="Language", choices=["EN", "JP", "ZH", "ZH/EN", "JP/EN", "Automatic"], value="EN")
190
+ how_to_cut = gr.Dropdown(label="Slicing Method",
191
+ choices=["None", "4 Sentences", "50 Characters", "ZH/JP Punctuation", "EN Punctuation", "All Punctuation" ],
192
+ value="4 Sentences",
193
+ interactive=True,
194
+ )
195
+ top_k = gr.Slider(minimum=1,maximum=100,step=1,label="Top_k",value=5,interactive=True)
196
+ top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label="Top_p",value=1,interactive=True)
197
+ temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label="Temperature",value=1,interactive=True)
198
+ batch_size = gr.Slider(minimum=1,maximum=200,step=1,label="Batch Size",value=20,interactive=True)
199
+ fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label="Fragment Interval",value=0.3,interactive=True)
200
+ speed_factor = gr.Slider(minimum=0.50,maximum=2,step=0.05,label="Speed Factor",value=1.0,interactive=True)
201
+ repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label="Repetition Penalty",value=1.35,interactive=True)
202
+ parallel_infer = gr.Checkbox(label="Parallel Infer", value=True, interactive=True, show_label=True)
203
+ split_bucket = gr.Checkbox(label="Split Bucket", value=True, interactive=True, show_label=True)
204
+ seed = gr.Number(label="Random Seed",value=-1, interactive=True, show_label=True)
205
+ keep_random = gr.Checkbox(label="Use Randomized Seed", value=True, interactive=True, show_label=True)
206
+
207
+ #Main gradio
208
+ with gr.Blocks(title="Lemonfoot GPT-SoVITS") as app:
209
+ gr.Markdown(
210
+ "# Lemonfoot GPT-SoVITS 🚀🍋\n"
211
+ "### Space by Kit Lemonfoot / Noel Shirogane's High Flying Birds\n"
212
+ "Based on code originally by RVC_Boss and ChasonJiang\n\n"
213
+ "Do no evil.\n\n"
214
+ )
215
+ for (name, title, gptmodelpath, sovitsmodelpath, author, image) in modeldata:
216
+ with gr.TabItem(name):
217
+ with gr.Row():
218
+ with gr.Column():
219
+ n = gr.Textbox(value=name, visible=False, interactive=False)
220
+ gptmp = gr.Textbox(value=gptmodelpath, visible=False, interactive=False)
221
+ svmp = gr.Textbox(value=sovitsmodelpath, visible=False, interactive=False)
222
+ gr.Markdown(f"**{title}**\n\n Dataset author: {author}")
223
+ gr.Image(f"images/{image}", label=None, show_label=False, width=300, show_download_button=False, container=False, show_share_button=False)
224
+ with gr.Column():
225
+ with gr.TabItem("Style using a preset"):
226
+ sty = gr.Dropdown(
227
+ label="Current style",
228
+ choices=list(referencedata[name][0].keys()),
229
+ value="Neutral",
230
+ interactive=True
231
+ )
232
+ with gr.TabItem("Style using a different audio"):
233
+ with gr.Column():
234
+ ref_audio_path = gr.Audio(label="Reference Audio", type="filepath")
235
+ ref_text_free = gr.Checkbox(label="Enable no-reference-text mode", value=False, interactive=True)
236
+ prompt_text = gr.Textbox(label="Reference Audio Text", interactive=True)
237
+ prompt_language = gr.Textbox(value="EN", visible=False, interactive=False)
238
+ with gr.Column():
239
+ inference_button = gr.Button("Synthesize", variant="primary")
240
+ output = gr.Audio(label="Output")
241
+
242
+ inference_button.click(
243
+ inference,
244
+ inputs=[n, gptmp, svmp, sty, text, text_language, ref_audio_path, prompt_text, prompt_language, top_k, top_p, temperature, how_to_cut, batch_size, speed_factor, ref_text_free, split_bucket, fragment_interval, seed, keep_random, parallel_infer, repetition_penalty],
245
+ outputs=[output, seed]
246
+ )
247
+
248
+ #bottom info
249
+ with gr.Row():
250
+ with gr.Column():
251
+ text.render()
252
+ text_language.render()
253
+ how_to_cut.render()
254
+ with gr.Column():
255
+ temperature.render()
256
+ speed_factor.render()
257
+ with gr.Accordion("Advanced Inference Parameters", open=False):
258
+ top_k.render()
259
+ top_p.render()
260
+ batch_size.render()
261
+ fragment_interval.render()
262
+ repetition_penalty.render()
263
+ parallel_infer.render()
264
+ split_bucket.render()
265
+ seed.render()
266
+ keep_random.render()
267
+
268
+
269
+ app.queue().launch()
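The UI above is driven entirely by voicelist.json. A minimal single-voice entry consistent with the fields load_models() reads might look like the sketch below; every value is a placeholder, and the exact shape of styles/styletrans (assumed here to map style names to reference audio paths and transcripts) is an assumption rather than a copy of the Space's real file.

import json

entry = {
    "ExampleVoice": {
        "enable": True,
        "title": "Example Voice",
        "gpt_model_path": "example-e15.ckpt",        # resolved under GPT_weight_root
        "sovits_model_path": "example_e8_s200.pth",  # resolved under SoVITS_weight_root
        "modelauthor": "Someone",
        "cover": "example.png",                      # displayed from images/<cover>
        "styles": {"Neutral": "referenceaudio/example_neutral.wav"},   # assumed layout
        "styletrans": {"Neutral": "This is a neutral reference line."} # assumed layout
    }
}
with open("voicelist.json", "w", encoding="utf-8") as f:
    json.dump(entry, f, ensure_ascii=False, indent=4)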
GPT_SoVITS/module/models.py CHANGED
@@ -1,5 +1,6 @@
1
  import copy
2
  import math
 
3
  import torch
4
  from torch import nn
5
  from torch.nn import functional as F
@@ -986,6 +987,55 @@ class SynthesizerTrn(nn.Module):
986
 
987
  o = self.dec((z * y_mask)[:, :, :], g=ge)
988
  return o
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
989
 
990
  def extract_latent(self, x):
991
  ssl = self.ssl_proj(x)
 
1
  import copy
2
  import math
3
+ from typing import List
4
  import torch
5
  from torch import nn
6
  from torch.nn import functional as F
 
987
 
988
  o = self.dec((z * y_mask)[:, :, :], g=ge)
989
  return o
990
+
991
+
992
+ @torch.no_grad()
993
+ def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
994
+ ge = None
995
+ if refer is not None:
996
+ refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
997
+ refer_mask = torch.unsqueeze(
998
+ commons.sequence_mask(refer_lengths, refer.size(2)), 1
999
+ ).to(refer.dtype)
1000
+ ge = self.ref_enc(refer * refer_mask, refer_mask)
1001
+
1002
+ # y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to(
1003
+ # codes.dtype
1004
+ # )
1005
+ y_lengths = (y_lengths * 2).long().to(codes.device)
1006
+ text_lengths = text_lengths.long().to(text.device)
1007
+ # y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
1008
+ # text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
1009
+
1010
+ # Assume decoding after padding is fine; the impact is unknown, but it sounds OK?
1011
+ quantized = self.quantizer.decode(codes)
1012
+ if self.semantic_frame_rate == "25hz":
1013
+ quantized = F.interpolate(
1014
+ quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
1015
+ )
1016
+
1017
+ x, m_p, logs_p, y_mask = self.enc_p(
1018
+ quantized, y_lengths, text, text_lengths, ge
1019
+ )
1020
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1021
+
1022
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
1023
+ z_masked = (z * y_mask)[:, :, :]
1024
+
1025
+ # Serial: strip the padding from each item before decoding
1026
+ o_list:List[torch.Tensor] = []
1027
+ for i in range(z_masked.shape[0]):
1028
+ z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)
1029
+ o = self.dec(z_slice, g=ge)[0, 0, :].detach()
1030
+ o_list.append(o)
1031
+
1032
+ # Parallel (problematic): decode first, then strip the padded part
1033
+ # o = self.dec(z_masked, g=ge)
1034
+ # upsample_rate = int(math.prod(self.upsample_rates))
1035
+ # o_lengths = y_lengths*upsample_rate
1036
+ # o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)]
1037
+
1038
+ return o_list
1039
 
1040
  def extract_latent(self, x):
1041
  ssl = self.ssl_proj(x)
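The serial branch of batched_decode boils down to decoding each item with its padding stripped. The sketch below reproduces that slicing pattern with a stand-in Conv1d in place of self.dec; the channel count and frame lengths are made up.

import torch
from torch import nn

dec = nn.Conv1d(192, 1, kernel_size=7, padding=3)   # stand-in for self.dec

z_masked = torch.randn(3, 192, 50)       # padded batch: 3 items, 50 frames max
y_lengths = torch.tensor([50, 32, 41])   # true (unpadded) frame count per item

o_list = []
for i in range(z_masked.shape[0]):
    z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)   # drop padded frames first
    o_list.append(dec(z_slice)[0, 0, :].detach())

print([tuple(o.shape) for o in o_list])  # per-item outputs of differing lengths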
GPT_SoVITS/my_utils.py CHANGED
@@ -1,21 +1,21 @@
1
- import ffmpeg
2
- import numpy as np
3
-
4
-
5
- def load_audio(file, sr):
6
- try:
7
- # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
8
- # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
9
- # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
10
- file = (
11
- file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
12
- ) # Guard against users pasting paths with stray spaces, quotes, or newlines
13
- out, _ = (
14
- ffmpeg.input(file, threads=0)
15
- .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
16
- .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
17
- )
18
- except Exception as e:
19
- raise RuntimeError(f"Failed to load audio: {e}")
20
-
21
- return np.frombuffer(out, np.float32).flatten()
 
1
+ import ffmpeg
2
+ import numpy as np
3
+
4
+
5
+ def load_audio(file, sr):
6
+ try:
7
+ # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
8
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
9
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
10
+ file = (
11
+ file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
12
+ ) # Guard against users pasting paths with stray spaces, quotes, or newlines
13
+ out, _ = (
14
+ ffmpeg.input(file, threads=0)
15
+ .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
16
+ .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
17
+ )
18
+ except Exception as e:
19
+ raise RuntimeError(f"Failed to load audio: {e}")
20
+
21
+ return np.frombuffer(out, np.float32).flatten()
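A minimal usage sketch of load_audio, assuming the ffmpeg CLI and the ffmpeg-python package are installed and that GPT_SoVITS/ is on sys.path; the file path is hypothetical.

import numpy as np
from my_utils import load_audio

audio = load_audio("/path/to/sample.wav", 32000)   # mono float32 at 32 kHz
print(audio.dtype, audio.shape, float(np.abs(audio).max()))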
GPT_SoVITS/onnx_export.py CHANGED
@@ -1,334 +1,334 @@
1
- from module.models_onnx import SynthesizerTrn, symbols
2
- from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
3
- import torch
4
- import torchaudio
5
- from torch import nn
6
- from feature_extractor import cnhubert
7
- cnhubert_base_path = "pretrained_models/chinese-hubert-base"
8
- cnhubert.cnhubert_base_path=cnhubert_base_path
9
- ssl_model = cnhubert.get_model()
10
- from text import cleaned_text_to_sequence
11
- import soundfile
12
- from my_utils import load_audio
13
- import os
14
- import json
15
-
16
- def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
17
- hann_window = torch.hann_window(win_size).to(
18
- dtype=y.dtype, device=y.device
19
- )
20
- y = torch.nn.functional.pad(
21
- y.unsqueeze(1),
22
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
23
- mode="reflect",
24
- )
25
- y = y.squeeze(1)
26
- spec = torch.stft(
27
- y,
28
- n_fft,
29
- hop_length=hop_size,
30
- win_length=win_size,
31
- window=hann_window,
32
- center=center,
33
- pad_mode="reflect",
34
- normalized=False,
35
- onesided=True,
36
- return_complex=False,
37
- )
38
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
39
- return spec
40
-
41
-
42
- class DictToAttrRecursive(dict):
43
- def __init__(self, input_dict):
44
- super().__init__(input_dict)
45
- for key, value in input_dict.items():
46
- if isinstance(value, dict):
47
- value = DictToAttrRecursive(value)
48
- self[key] = value
49
- setattr(self, key, value)
50
-
51
- def __getattr__(self, item):
52
- try:
53
- return self[item]
54
- except KeyError:
55
- raise AttributeError(f"Attribute {item} not found")
56
-
57
- def __setattr__(self, key, value):
58
- if isinstance(value, dict):
59
- value = DictToAttrRecursive(value)
60
- super(DictToAttrRecursive, self).__setitem__(key, value)
61
- super().__setattr__(key, value)
62
-
63
- def __delattr__(self, item):
64
- try:
65
- del self[item]
66
- except KeyError:
67
- raise AttributeError(f"Attribute {item} not found")
68
-
69
-
70
- class T2SEncoder(nn.Module):
71
- def __init__(self, t2s, vits):
72
- super().__init__()
73
- self.encoder = t2s.onnx_encoder
74
- self.vits = vits
75
-
76
- def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
77
- codes = self.vits.extract_latent(ssl_content)
78
- prompt_semantic = codes[0, 0]
79
- bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
80
- all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
81
- bert = bert.unsqueeze(0)
82
- prompt = prompt_semantic.unsqueeze(0)
83
- return self.encoder(all_phoneme_ids, bert), prompt
84
-
85
-
86
- class T2SModel(nn.Module):
87
- def __init__(self, t2s_path, vits_model):
88
- super().__init__()
89
- dict_s1 = torch.load(t2s_path, map_location="cpu")
90
- self.config = dict_s1["config"]
91
- self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
92
- self.t2s_model.load_state_dict(dict_s1["weight"])
93
- self.t2s_model.eval()
94
- self.vits_model = vits_model.vq_model
95
- self.hz = 50
96
- self.max_sec = self.config["data"]["max_sec"]
97
- self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
98
- self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
99
- self.t2s_model = self.t2s_model.model
100
- self.t2s_model.init_onnx()
101
- self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
102
- self.first_stage_decoder = self.t2s_model.first_stage_decoder
103
- self.stage_decoder = self.t2s_model.stage_decoder
104
- #self.t2s_model = torch.jit.script(self.t2s_model)
105
-
106
- def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
107
- early_stop_num = self.t2s_model.early_stop_num
108
-
109
- #[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
110
- x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
111
-
112
- prefix_len = prompts.shape[1]
113
-
114
- #[1,N,512] [1,N]
115
- y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
116
-
117
- stop = False
118
- for idx in range(1, 1500):
119
- #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
120
- enco = self.stage_decoder(y, k, v, y_emb, x_example)
121
- y, k, v, y_emb, logits, samples = enco
122
- if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
123
- stop = True
124
- if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
125
- stop = True
126
- if stop:
127
- break
128
- y[0, -1] = 0
129
-
130
- return y[:, -idx:].unsqueeze(0)
131
-
132
- def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
133
- #self.onnx_encoder = torch.jit.script(self.onnx_encoder)
134
- if dynamo:
135
- export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
136
- onnx_encoder_export_output = torch.onnx.dynamo_export(
137
- self.onnx_encoder,
138
- (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
139
- export_options=export_options
140
- )
141
- onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
142
- return
143
-
144
- torch.onnx.export(
145
- self.onnx_encoder,
146
- (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
147
- f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
148
- input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
149
- output_names=["x", "prompts"],
150
- dynamic_axes={
151
- "ref_seq": {1 : "ref_length"},
152
- "text_seq": {1 : "text_length"},
153
- "ref_bert": {0 : "ref_length"},
154
- "text_bert": {0 : "text_length"},
155
- "ssl_content": {2 : "ssl_length"},
156
- },
157
- opset_version=16
158
- )
159
- x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
160
-
161
- torch.onnx.export(
162
- self.first_stage_decoder,
163
- (x, prompts),
164
- f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
165
- input_names=["x", "prompts"],
166
- output_names=["y", "k", "v", "y_emb", "x_example"],
167
- dynamic_axes={
168
- "x": {1 : "x_length"},
169
- "prompts": {1 : "prompts_length"},
170
- },
171
- verbose=False,
172
- opset_version=16
173
- )
174
- y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
175
-
176
- torch.onnx.export(
177
- self.stage_decoder,
178
- (y, k, v, y_emb, x_example),
179
- f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
180
- input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
181
- output_names=["y", "k", "v", "y_emb", "logits", "samples"],
182
- dynamic_axes={
183
- "iy": {1 : "iy_length"},
184
- "ik": {1 : "ik_length"},
185
- "iv": {1 : "iv_length"},
186
- "iy_emb": {1 : "iy_emb_length"},
187
- "ix_example": {1 : "ix_example_length"},
188
- },
189
- verbose=False,
190
- opset_version=16
191
- )
192
-
193
-
194
- class VitsModel(nn.Module):
195
- def __init__(self, vits_path):
196
- super().__init__()
197
- dict_s2 = torch.load(vits_path,map_location="cpu")
198
- self.hps = dict_s2["config"]
199
- self.hps = DictToAttrRecursive(self.hps)
200
- self.hps.model.semantic_frame_rate = "25hz"
201
- self.vq_model = SynthesizerTrn(
202
- self.hps.data.filter_length // 2 + 1,
203
- self.hps.train.segment_size // self.hps.data.hop_length,
204
- n_speakers=self.hps.data.n_speakers,
205
- **self.hps.model
206
- )
207
- self.vq_model.eval()
208
- self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
209
-
210
- def forward(self, text_seq, pred_semantic, ref_audio):
211
- refer = spectrogram_torch(
212
- ref_audio,
213
- self.hps.data.filter_length,
214
- self.hps.data.sampling_rate,
215
- self.hps.data.hop_length,
216
- self.hps.data.win_length,
217
- center=False
218
- )
219
- return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
220
-
221
-
222
- class GptSoVits(nn.Module):
223
- def __init__(self, vits, t2s):
224
- super().__init__()
225
- self.vits = vits
226
- self.t2s = t2s
227
-
228
- def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
229
- pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
230
- audio = self.vits(text_seq, pred_semantic, ref_audio)
231
- if debug:
232
- import onnxruntime
233
- sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
234
- audio1 = sess.run(None, {
235
- "text_seq" : text_seq.detach().cpu().numpy(),
236
- "pred_semantic" : pred_semantic.detach().cpu().numpy(),
237
- "ref_audio" : ref_audio.detach().cpu().numpy()
238
- })
239
- return audio, audio1
240
- return audio
241
-
242
- def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
243
- self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
244
- pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
245
- torch.onnx.export(
246
- self.vits,
247
- (text_seq, pred_semantic, ref_audio),
248
- f"onnx/{project_name}/{project_name}_vits.onnx",
249
- input_names=["text_seq", "pred_semantic", "ref_audio"],
250
- output_names=["audio"],
251
- dynamic_axes={
252
- "text_seq": {1 : "text_length"},
253
- "pred_semantic": {2 : "pred_length"},
254
- "ref_audio": {1 : "audio_length"},
255
- },
256
- opset_version=17,
257
- verbose=False
258
- )
259
-
260
-
261
- class SSLModel(nn.Module):
262
- def __init__(self):
263
- super().__init__()
264
- self.ssl = ssl_model
265
-
266
- def forward(self, ref_audio_16k):
267
- return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
268
-
269
-
270
- def export(vits_path, gpt_path, project_name):
271
- vits = VitsModel(vits_path)
272
- gpt = T2SModel(gpt_path, vits)
273
- gpt_sovits = GptSoVits(vits, gpt)
274
- ssl = SSLModel()
275
- ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
276
- text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
277
- ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
278
- text_bert = torch.randn((text_seq.shape[1], 1024)).float()
279
- ref_audio = torch.randn((1, 48000 * 5)).float()
280
- # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
281
- ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
282
- ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
283
-
284
- try:
285
- os.mkdir(f"onnx/{project_name}")
286
- except:
287
- pass
288
-
289
- ssl_content = ssl(ref_audio_16k).float()
290
-
291
- debug = False
292
-
293
- if debug:
294
- a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
295
- soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
296
- soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
297
- return
298
-
299
- a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
300
-
301
- soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
302
-
303
- gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
304
-
305
- MoeVSConf = {
306
- "Folder" : f"{project_name}",
307
- "Name" : f"{project_name}",
308
- "Type" : "GPT-SoVits",
309
- "Rate" : vits.hps.data.sampling_rate,
310
- "NumLayers": gpt.t2s_model.num_layers,
311
- "EmbeddingDim": gpt.t2s_model.embedding_dim,
312
- "Dict": "BasicDict",
313
- "BertPath": "chinese-roberta-wwm-ext-large",
314
- "Symbol": symbols,
315
- "AddBlank": False
316
- }
317
-
318
- MoeVSConfJson = json.dumps(MoeVSConf)
319
- with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
320
- json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
321
-
322
-
323
- if __name__ == "__main__":
324
- try:
325
- os.mkdir("onnx")
326
- except:
327
- pass
328
-
329
- gpt_path = "GPT_weights/nahida-e25.ckpt"
330
- vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
331
- exp_path = "nahida"
332
- export(vits_path, gpt_path, exp_path)
333
-
334
  # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
 
1
+ from module.models_onnx import SynthesizerTrn, symbols
2
+ from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
3
+ import torch
4
+ import torchaudio
5
+ from torch import nn
6
+ from feature_extractor import cnhubert
7
+ cnhubert_base_path = "pretrained_models/chinese-hubert-base"
8
+ cnhubert.cnhubert_base_path=cnhubert_base_path
9
+ ssl_model = cnhubert.get_model()
10
+ from text import cleaned_text_to_sequence
11
+ import soundfile
12
+ from my_utils import load_audio
13
+ import os
14
+ import json
15
+
16
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
17
+ hann_window = torch.hann_window(win_size).to(
18
+ dtype=y.dtype, device=y.device
19
+ )
20
+ y = torch.nn.functional.pad(
21
+ y.unsqueeze(1),
22
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
23
+ mode="reflect",
24
+ )
25
+ y = y.squeeze(1)
26
+ spec = torch.stft(
27
+ y,
28
+ n_fft,
29
+ hop_length=hop_size,
30
+ win_length=win_size,
31
+ window=hann_window,
32
+ center=center,
33
+ pad_mode="reflect",
34
+ normalized=False,
35
+ onesided=True,
36
+ return_complex=False,
37
+ )
38
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
39
+ return spec
40
+
41
+
42
+ class DictToAttrRecursive(dict):
43
+ def __init__(self, input_dict):
44
+ super().__init__(input_dict)
45
+ for key, value in input_dict.items():
46
+ if isinstance(value, dict):
47
+ value = DictToAttrRecursive(value)
48
+ self[key] = value
49
+ setattr(self, key, value)
50
+
51
+ def __getattr__(self, item):
52
+ try:
53
+ return self[item]
54
+ except KeyError:
55
+ raise AttributeError(f"Attribute {item} not found")
56
+
57
+ def __setattr__(self, key, value):
58
+ if isinstance(value, dict):
59
+ value = DictToAttrRecursive(value)
60
+ super(DictToAttrRecursive, self).__setitem__(key, value)
61
+ super().__setattr__(key, value)
62
+
63
+ def __delattr__(self, item):
64
+ try:
65
+ del self[item]
66
+ except KeyError:
67
+ raise AttributeError(f"Attribute {item} not found")
68
+
69
+
70
+ class T2SEncoder(nn.Module):
71
+ def __init__(self, t2s, vits):
72
+ super().__init__()
73
+ self.encoder = t2s.onnx_encoder
74
+ self.vits = vits
75
+
76
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
77
+ codes = self.vits.extract_latent(ssl_content)
78
+ prompt_semantic = codes[0, 0]
79
+ bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
80
+ all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
81
+ bert = bert.unsqueeze(0)
82
+ prompt = prompt_semantic.unsqueeze(0)
83
+ return self.encoder(all_phoneme_ids, bert), prompt
84
+
85
+
86
+ class T2SModel(nn.Module):
87
+ def __init__(self, t2s_path, vits_model):
88
+ super().__init__()
89
+ dict_s1 = torch.load(t2s_path, map_location="cpu")
90
+ self.config = dict_s1["config"]
91
+ self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
92
+ self.t2s_model.load_state_dict(dict_s1["weight"])
93
+ self.t2s_model.eval()
94
+ self.vits_model = vits_model.vq_model
95
+ self.hz = 50
96
+ self.max_sec = self.config["data"]["max_sec"]
97
+ self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
98
+ self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
99
+ self.t2s_model = self.t2s_model.model
100
+ self.t2s_model.init_onnx()
101
+ self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
102
+ self.first_stage_decoder = self.t2s_model.first_stage_decoder
103
+ self.stage_decoder = self.t2s_model.stage_decoder
104
+ #self.t2s_model = torch.jit.script(self.t2s_model)
105
+
106
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
107
+ early_stop_num = self.t2s_model.early_stop_num
108
+
109
+ #[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
110
+ x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
111
+
112
+ prefix_len = prompts.shape[1]
113
+
114
+ #[1,N,512] [1,N]
115
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
116
+
117
+ stop = False
118
+ for idx in range(1, 1500):
119
+ #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
120
+ enco = self.stage_decoder(y, k, v, y_emb, x_example)
121
+ y, k, v, y_emb, logits, samples = enco
122
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
123
+ stop = True
124
+ if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
125
+ stop = True
126
+ if stop:
127
+ break
128
+ y[0, -1] = 0
129
+
130
+ return y[:, -idx:].unsqueeze(0)
131
+
132
+ def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
133
+ #self.onnx_encoder = torch.jit.script(self.onnx_encoder)
134
+ if dynamo:
135
+ export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
136
+ onnx_encoder_export_output = torch.onnx.dynamo_export(
137
+ self.onnx_encoder,
138
+ (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
139
+ export_options=export_options
140
+ )
141
+ onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
142
+ return
143
+
144
+ torch.onnx.export(
145
+ self.onnx_encoder,
146
+ (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
147
+ f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
148
+ input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
149
+ output_names=["x", "prompts"],
150
+ dynamic_axes={
151
+ "ref_seq": {1 : "ref_length"},
152
+ "text_seq": {1 : "text_length"},
153
+ "ref_bert": {0 : "ref_length"},
154
+ "text_bert": {0 : "text_length"},
155
+ "ssl_content": {2 : "ssl_length"},
156
+ },
157
+ opset_version=16
158
+ )
159
+ x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
160
+
161
+ torch.onnx.export(
162
+ self.first_stage_decoder,
163
+ (x, prompts),
164
+ f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
165
+ input_names=["x", "prompts"],
166
+ output_names=["y", "k", "v", "y_emb", "x_example"],
167
+ dynamic_axes={
168
+ "x": {1 : "x_length"},
169
+ "prompts": {1 : "prompts_length"},
170
+ },
171
+ verbose=False,
172
+ opset_version=16
173
+ )
174
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
175
+
176
+ torch.onnx.export(
177
+ self.stage_decoder,
178
+ (y, k, v, y_emb, x_example),
179
+ f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
180
+ input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
181
+ output_names=["y", "k", "v", "y_emb", "logits", "samples"],
182
+ dynamic_axes={
183
+ "iy": {1 : "iy_length"},
184
+ "ik": {1 : "ik_length"},
185
+ "iv": {1 : "iv_length"},
186
+ "iy_emb": {1 : "iy_emb_length"},
187
+ "ix_example": {1 : "ix_example_length"},
188
+ },
189
+ verbose=False,
190
+ opset_version=16
191
+ )
192
+
193
+
194
+ class VitsModel(nn.Module):
195
+ def __init__(self, vits_path):
196
+ super().__init__()
197
+ dict_s2 = torch.load(vits_path,map_location="cpu")
198
+ self.hps = dict_s2["config"]
199
+ self.hps = DictToAttrRecursive(self.hps)
200
+ self.hps.model.semantic_frame_rate = "25hz"
201
+ self.vq_model = SynthesizerTrn(
202
+ self.hps.data.filter_length // 2 + 1,
203
+ self.hps.train.segment_size // self.hps.data.hop_length,
204
+ n_speakers=self.hps.data.n_speakers,
205
+ **self.hps.model
206
+ )
207
+ self.vq_model.eval()
208
+ self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
209
+
210
+ def forward(self, text_seq, pred_semantic, ref_audio):
211
+ refer = spectrogram_torch(
212
+ ref_audio,
213
+ self.hps.data.filter_length,
214
+ self.hps.data.sampling_rate,
215
+ self.hps.data.hop_length,
216
+ self.hps.data.win_length,
217
+ center=False
218
+ )
219
+ return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
220
+
221
+
222
+ class GptSoVits(nn.Module):
223
+ def __init__(self, vits, t2s):
224
+ super().__init__()
225
+ self.vits = vits
226
+ self.t2s = t2s
227
+
228
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
229
+ pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
230
+ audio = self.vits(text_seq, pred_semantic, ref_audio)
231
+ if debug:
232
+ import onnxruntime
233
+ sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPUExecutionProvider"])
234
+ audio1 = sess.run(None, {
235
+ "text_seq" : text_seq.detach().cpu().numpy(),
236
+ "pred_semantic" : pred_semantic.detach().cpu().numpy(),
237
+ "ref_audio" : ref_audio.detach().cpu().numpy()
238
+ })
239
+ return audio, audio1
240
+ return audio
241
+
242
+ def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
243
+ self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
244
+ pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
245
+ torch.onnx.export(
246
+ self.vits,
247
+ (text_seq, pred_semantic, ref_audio),
248
+ f"onnx/{project_name}/{project_name}_vits.onnx",
249
+ input_names=["text_seq", "pred_semantic", "ref_audio"],
250
+ output_names=["audio"],
251
+ dynamic_axes={
252
+ "text_seq": {1 : "text_length"},
253
+ "pred_semantic": {2 : "pred_length"},
254
+ "ref_audio": {1 : "audio_length"},
255
+ },
256
+ opset_version=17,
257
+ verbose=False
258
+ )
259
+
260
+
261
+ class SSLModel(nn.Module):
262
+ def __init__(self):
263
+ super().__init__()
264
+ self.ssl = ssl_model
265
+
266
+ def forward(self, ref_audio_16k):
267
+ return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
268
+
269
+
270
+ def export(vits_path, gpt_path, project_name):
271
+ vits = VitsModel(vits_path)
272
+ gpt = T2SModel(gpt_path, vits)
273
+ gpt_sovits = GptSoVits(vits, gpt)
274
+ ssl = SSLModel()
275
+ ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
276
+ text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
277
+ ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
278
+ text_bert = torch.randn((text_seq.shape[1], 1024)).float()
279
+ ref_audio = torch.randn((1, 48000 * 5)).float()
280
+ # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
281
+ ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
282
+ ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
283
+
284
+ try:
285
+ os.mkdir(f"onnx/{project_name}")
286
+ except:
287
+ pass
288
+
289
+ ssl_content = ssl(ref_audio_16k).float()
290
+
291
+ debug = False
292
+
293
+ if debug:
294
+ a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
295
+ soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
296
+ soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
297
+ return
298
+
299
+ a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
300
+
301
+ soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
302
+
303
+ gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
304
+
305
+ MoeVSConf = {
306
+ "Folder" : f"{project_name}",
307
+ "Name" : f"{project_name}",
308
+ "Type" : "GPT-SoVits",
309
+ "Rate" : vits.hps.data.sampling_rate,
310
+ "NumLayers": gpt.t2s_model.num_layers,
311
+ "EmbeddingDim": gpt.t2s_model.embedding_dim,
312
+ "Dict": "BasicDict",
313
+ "BertPath": "chinese-roberta-wwm-ext-large",
314
+ "Symbol": symbols,
315
+ "AddBlank": False
316
+ }
317
+
318
+ MoeVSConfJson = json.dumps(MoeVSConf)
319
+ with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
320
+ json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
321
+
322
+
323
+ if __name__ == "__main__":
324
+ try:
325
+ os.mkdir("onnx")
326
+ except:
327
+ pass
328
+
329
+ gpt_path = "GPT_weights/nahida-e25.ckpt"
330
+ vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
331
+ exp_path = "nahida"
332
+ export(vits_path, gpt_path, exp_path)
333
+
334
  # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
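Once the export above has produced the per-project ONNX files, their I/O contracts can be checked with onnxruntime; the "nahida" path simply mirrors the example project name used in __main__, and this is only an inspection sketch, not part of the export script.

import onnxruntime

sess = onnxruntime.InferenceSession(
    "onnx/nahida/nahida_vits.onnx", providers=["CPUExecutionProvider"]
)
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)         # text_seq, pred_semantic, ref_audio
print([out.name for out in sess.get_outputs()])  # ["audio"]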
GPT_SoVITS/prepare_datasets/1-get-text.py CHANGED
@@ -1,131 +1,131 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- import os
4
-
5
- inp_text = os.environ.get("inp_text")
6
- inp_wav_dir = os.environ.get("inp_wav_dir")
7
- exp_name = os.environ.get("exp_name")
8
- i_part = os.environ.get("i_part")
9
- all_parts = os.environ.get("all_parts")
10
- os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
11
- opt_dir = os.environ.get("opt_dir")
12
- bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
13
- is_half = eval(os.environ.get("is_half", "True"))
14
- import sys, numpy as np, traceback, pdb
15
- import os.path
16
- from glob import glob
17
- from tqdm import tqdm
18
- from text.cleaner import clean_text
19
- import torch
20
- from transformers import AutoModelForMaskedLM, AutoTokenizer
21
- import numpy as np
22
-
23
- # inp_text=sys.argv[1]
24
- # inp_wav_dir=sys.argv[2]
25
- # exp_name=sys.argv[3]
26
- # i_part=sys.argv[4]
27
- # all_parts=sys.argv[5]
28
- # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
29
- # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
30
- # bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
31
-
32
- from time import time as ttime
33
- import shutil
34
-
35
-
36
- def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
37
- dir=os.path.dirname(path)
38
- name=os.path.basename(path)
39
- # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
40
- tmp_path="%s%s.pth"%(ttime(),i_part)
41
- torch.save(fea,tmp_path)
42
- shutil.move(tmp_path,"%s/%s"%(dir,name))
43
-
44
-
45
- txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
46
- if os.path.exists(txt_path) == False:
47
- bert_dir = "%s/3-bert" % (opt_dir)
48
- os.makedirs(opt_dir, exist_ok=True)
49
- os.makedirs(bert_dir, exist_ok=True)
50
- if torch.cuda.is_available():
51
- device = "cuda:0"
52
- # elif torch.backends.mps.is_available():
53
- # device = "mps"
54
- else:
55
- device = "cpu"
56
- tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
57
- bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
58
- if is_half == True:
59
- bert_model = bert_model.half().to(device)
60
- else:
61
- bert_model = bert_model.to(device)
62
-
63
- def get_bert_feature(text, word2ph):
64
- with torch.no_grad():
65
- inputs = tokenizer(text, return_tensors="pt")
66
- for i in inputs:
67
- inputs[i] = inputs[i].to(device)
68
- res = bert_model(**inputs, output_hidden_states=True)
69
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
70
-
71
- assert len(word2ph) == len(text)
72
- phone_level_feature = []
73
- for i in range(len(word2ph)):
74
- repeat_feature = res[i].repeat(word2ph[i], 1)
75
- phone_level_feature.append(repeat_feature)
76
-
77
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
78
-
79
- return phone_level_feature.T
80
-
81
- def process(data, res):
82
- for name, text, lan in data:
83
- try:
84
- name = os.path.basename(name)
85
- phones, word2ph, norm_text = clean_text(
86
- text.replace("%", "-").replace("¥", ","), lan
87
- )
88
- path_bert = "%s/%s.pt" % (bert_dir, name)
89
- if os.path.exists(path_bert) == False and lan == "zh":
90
- bert_feature = get_bert_feature(norm_text, word2ph)
91
- assert bert_feature.shape[-1] == len(phones)
92
- # torch.save(bert_feature, path_bert)
93
- my_save(bert_feature, path_bert)
94
- phones = " ".join(phones)
95
- # res.append([name,phones])
96
- res.append([name, phones, word2ph, norm_text])
97
- except:
98
- print(name, text, traceback.format_exc())
99
-
100
- todo = []
101
- res = []
102
- with open(inp_text, "r", encoding="utf8") as f:
103
- lines = f.read().strip("\n").split("\n")
104
-
105
- language_v1_to_language_v2 = {
106
- "ZH": "zh",
107
- "zh": "zh",
108
- "JP": "ja",
109
- "jp": "ja",
110
- "JA": "ja",
111
- "ja": "ja",
112
- "EN": "en",
113
- "en": "en",
114
- "En": "en",
115
- }
116
- for line in lines[int(i_part) :: int(all_parts)]:
117
- try:
118
- wav_name, spk_name, language, text = line.split("|")
119
- # todo.append([name,text,"zh"])
120
- todo.append(
121
- [wav_name, text, language_v1_to_language_v2.get(language, language)]
122
- )
123
- except:
124
- print(line, traceback.format_exc())
125
-
126
- process(todo, res)
127
- opt = []
128
- for name, phones, word2ph, norm_text in res:
129
- opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
130
- with open(txt_path, "w", encoding="utf8") as f:
131
- f.write("\n".join(opt) + "\n")
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+
5
+ inp_text = os.environ.get("inp_text")
6
+ inp_wav_dir = os.environ.get("inp_wav_dir")
7
+ exp_name = os.environ.get("exp_name")
8
+ i_part = os.environ.get("i_part")
9
+ all_parts = os.environ.get("all_parts")
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
11
+ opt_dir = os.environ.get("opt_dir")
12
+ bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
13
+ is_half = eval(os.environ.get("is_half", "True"))
14
+ import sys, numpy as np, traceback, pdb
15
+ import os.path
16
+ from glob import glob
17
+ from tqdm import tqdm
18
+ from text.cleaner import clean_text
19
+ import torch
20
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
21
+ import numpy as np
22
+
23
+ # inp_text=sys.argv[1]
24
+ # inp_wav_dir=sys.argv[2]
25
+ # exp_name=sys.argv[3]
26
+ # i_part=sys.argv[4]
27
+ # all_parts=sys.argv[5]
28
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
29
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
30
+ # bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
31
+
32
+ from time import time as ttime
33
+ import shutil
34
+
35
+
36
+ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
37
+ dir=os.path.dirname(path)
38
+ name=os.path.basename(path)
39
+ # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
40
+ tmp_path="%s%s.pth"%(ttime(),i_part)
41
+ torch.save(fea,tmp_path)
42
+ shutil.move(tmp_path,"%s/%s"%(dir,name))
43
+
44
+
45
+ txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
46
+ if os.path.exists(txt_path) == False:
47
+ bert_dir = "%s/3-bert" % (opt_dir)
48
+ os.makedirs(opt_dir, exist_ok=True)
49
+ os.makedirs(bert_dir, exist_ok=True)
50
+ if torch.cuda.is_available():
51
+ device = "cuda:0"
52
+ # elif torch.backends.mps.is_available():
53
+ # device = "mps"
54
+ else:
55
+ device = "cpu"
56
+ tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
57
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
58
+ if is_half == True:
59
+ bert_model = bert_model.half().to(device)
60
+ else:
61
+ bert_model = bert_model.to(device)
62
+
63
+ def get_bert_feature(text, word2ph):
64
+ with torch.no_grad():
65
+ inputs = tokenizer(text, return_tensors="pt")
66
+ for i in inputs:
67
+ inputs[i] = inputs[i].to(device)
68
+ res = bert_model(**inputs, output_hidden_states=True)
69
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
70
+
71
+ assert len(word2ph) == len(text)
72
+ phone_level_feature = []
73
+ for i in range(len(word2ph)):
74
+ repeat_feature = res[i].repeat(word2ph[i], 1)
75
+ phone_level_feature.append(repeat_feature)
76
+
77
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
78
+
79
+ return phone_level_feature.T
80
+
81
+ def process(data, res):
82
+ for name, text, lan in data:
83
+ try:
84
+ name = os.path.basename(name)
85
+ phones, word2ph, norm_text = clean_text(
86
+ text.replace("%", "-").replace("¥", ","), lan
87
+ )
88
+ path_bert = "%s/%s.pt" % (bert_dir, name)
89
+ if os.path.exists(path_bert) == False and lan == "zh":
90
+ bert_feature = get_bert_feature(norm_text, word2ph)
91
+ assert bert_feature.shape[-1] == len(phones)
92
+ # torch.save(bert_feature, path_bert)
93
+ my_save(bert_feature, path_bert)
94
+ phones = " ".join(phones)
95
+ # res.append([name,phones])
96
+ res.append([name, phones, word2ph, norm_text])
97
+ except:
98
+ print(name, text, traceback.format_exc())
99
+
100
+ todo = []
101
+ res = []
102
+ with open(inp_text, "r", encoding="utf8") as f:
103
+ lines = f.read().strip("\n").split("\n")
104
+
105
+ language_v1_to_language_v2 = {
106
+ "ZH": "zh",
107
+ "zh": "zh",
108
+ "JP": "ja",
109
+ "jp": "ja",
110
+ "JA": "ja",
111
+ "ja": "ja",
112
+ "EN": "en",
113
+ "en": "en",
114
+ "En": "en",
115
+ }
116
+ for line in lines[int(i_part) :: int(all_parts)]:
117
+ try:
118
+ wav_name, spk_name, language, text = line.split("|")
119
+ # todo.append([name,text,"zh"])
120
+ todo.append(
121
+ [wav_name, text, language_v1_to_language_v2.get(language, language)]
122
+ )
123
+ except:
124
+ print(line, traceback.format_exc())
125
+
126
+ process(todo, res)
127
+ opt = []
128
+ for name, phones, word2ph, norm_text in res:
129
+ opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
130
+ with open(txt_path, "w", encoding="utf8") as f:
131
+ f.write("\n".join(opt) + "\n")
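The phone-level BERT feature built in get_bert_feature is just each character's vector repeated word2ph[i] times; a toy version with made-up sizes:

import torch

res = torch.randn(4, 1024)   # one 1024-dim BERT vector per character
word2ph = [2, 1, 3, 2]       # phones per character (8 phones total)

phone_level_feature = torch.cat(
    [res[i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0
)
print(phone_level_feature.T.shape)   # torch.Size([1024, 8]), one column per phone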
GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py CHANGED
@@ -1,120 +1,120 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- import sys,os
4
- inp_text= os.environ.get("inp_text")
5
- inp_wav_dir= os.environ.get("inp_wav_dir")
6
- exp_name= os.environ.get("exp_name")
7
- i_part= os.environ.get("i_part")
8
- all_parts= os.environ.get("all_parts")
9
- os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
10
- from feature_extractor import cnhubert
11
- opt_dir= os.environ.get("opt_dir")
12
- cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
13
- is_half=eval(os.environ.get("is_half","True"))
14
-
15
- import pdb,traceback,numpy as np,logging
16
- from scipy.io import wavfile
17
- import librosa,torch
18
- now_dir = os.getcwd()
19
- sys.path.append(now_dir)
20
- from my_utils import load_audio
21
-
22
- # from config import cnhubert_base_path
23
- # cnhubert.cnhubert_base_path=cnhubert_base_path
24
- # inp_text=sys.argv[1]
25
- # inp_wav_dir=sys.argv[2]
26
- # exp_name=sys.argv[3]
27
- # i_part=sys.argv[4]
28
- # all_parts=sys.argv[5]
29
- # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
30
- # cnhubert.cnhubert_base_path=sys.argv[7]
31
- # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
32
-
33
- from time import time as ttime
34
- import shutil
35
- def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
36
- dir=os.path.dirname(path)
37
- name=os.path.basename(path)
38
- # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
39
- tmp_path="%s%s.pth"%(ttime(),i_part)
40
- torch.save(fea,tmp_path)
41
- shutil.move(tmp_path,"%s/%s"%(dir,name))
42
-
43
- hubert_dir="%s/4-cnhubert"%(opt_dir)
44
- wav32dir="%s/5-wav32k"%(opt_dir)
45
- os.makedirs(opt_dir,exist_ok=True)
46
- os.makedirs(hubert_dir,exist_ok=True)
47
- os.makedirs(wav32dir,exist_ok=True)
48
-
49
- maxx=0.95
50
- alpha=0.5
51
- if torch.cuda.is_available():
52
- device = "cuda:0"
53
- # elif torch.backends.mps.is_available():
54
- # device = "mps"
55
- else:
56
- device = "cpu"
57
- model=cnhubert.get_model()
58
- # is_half=False
59
- if(is_half==True):
60
- model=model.half().to(device)
61
- else:
62
- model = model.to(device)
63
-
64
- nan_fails=[]
65
- def name2go(wav_name,wav_path):
66
- hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
67
- if(os.path.exists(hubert_path)):return
68
- tmp_audio = load_audio(wav_path, 32000)
69
- tmp_max = np.abs(tmp_audio).max()
70
- if tmp_max > 2.2:
71
- print("%s-filtered,%s" % (wav_name, tmp_max))
72
- return
73
- tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
74
- tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
75
- tmp_audio = librosa.resample(
76
- tmp_audio32b, orig_sr=32000, target_sr=16000
77
- )# not a resampling issue
78
- tensor_wav16 = torch.from_numpy(tmp_audio)
79
- if (is_half == True):
80
- tensor_wav16=tensor_wav16.half().to(device)
81
- else:
82
- tensor_wav16 = tensor_wav16.to(device)
83
- ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
84
- if np.isnan(ssl.detach().numpy()).sum()!= 0:
85
- nan_fails.append(wav_name)
86
- print("nan filtered:%s"%wav_name)
87
- return
88
- wavfile.write(
89
- "%s/%s"%(wav32dir,wav_name),
90
- 32000,
91
- tmp_audio32.astype("int16"),
92
- )
93
- my_save(ssl,hubert_path )
94
-
95
- with open(inp_text,"r",encoding="utf8")as f:
96
- lines=f.read().strip("\n").split("\n")
97
-
98
- for line in lines[int(i_part)::int(all_parts)]:
99
- try:
100
- # wav_name,text=line.split("\t")
101
- wav_name, spk_name, language, text = line.split("|")
102
- if (inp_wav_dir != "" and inp_wav_dir != None):
103
- wav_name = os.path.basename(wav_name)
104
- wav_path = "%s/%s"%(inp_wav_dir, wav_name)
105
-
106
- else:
107
- wav_path=wav_name
108
- wav_name = os.path.basename(wav_name)
109
- name2go(wav_name,wav_path)
110
- except:
111
- print(line,traceback.format_exc())
112
-
113
- if(len(nan_fails)>0 and is_half==True):
114
- is_half=False
115
- model=model.float()
116
- for wav_name in nan_fails:
117
- try:
118
- name2go(wav_name)
119
- except:
120
- print(wav_name,traceback.format_exc())
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import sys,os
4
+ inp_text= os.environ.get("inp_text")
5
+ inp_wav_dir= os.environ.get("inp_wav_dir")
6
+ exp_name= os.environ.get("exp_name")
7
+ i_part= os.environ.get("i_part")
8
+ all_parts= os.environ.get("all_parts")
9
+ os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
10
+ from feature_extractor import cnhubert
11
+ opt_dir= os.environ.get("opt_dir")
12
+ cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
13
+ is_half=eval(os.environ.get("is_half","True"))
14
+
15
+ import pdb,traceback,numpy as np,logging
16
+ from scipy.io import wavfile
17
+ import librosa,torch
18
+ now_dir = os.getcwd()
19
+ sys.path.append(now_dir)
20
+ from my_utils import load_audio
21
+
22
+ # from config import cnhubert_base_path
23
+ # cnhubert.cnhubert_base_path=cnhubert_base_path
24
+ # inp_text=sys.argv[1]
25
+ # inp_wav_dir=sys.argv[2]
26
+ # exp_name=sys.argv[3]
27
+ # i_part=sys.argv[4]
28
+ # all_parts=sys.argv[5]
29
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
30
+ # cnhubert.cnhubert_base_path=sys.argv[7]
31
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
32
+
33
+ from time import time as ttime
34
+ import shutil
35
+ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
36
+ dir=os.path.dirname(path)
37
+ name=os.path.basename(path)
38
+ # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
39
+ tmp_path="%s%s.pth"%(ttime(),i_part)
40
+ torch.save(fea,tmp_path)
41
+ shutil.move(tmp_path,"%s/%s"%(dir,name))
42
+
43
+ hubert_dir="%s/4-cnhubert"%(opt_dir)
44
+ wav32dir="%s/5-wav32k"%(opt_dir)
45
+ os.makedirs(opt_dir,exist_ok=True)
46
+ os.makedirs(hubert_dir,exist_ok=True)
47
+ os.makedirs(wav32dir,exist_ok=True)
48
+
49
+ maxx=0.95
50
+ alpha=0.5
51
+ if torch.cuda.is_available():
52
+ device = "cuda:0"
53
+ # elif torch.backends.mps.is_available():
54
+ # device = "mps"
55
+ else:
56
+ device = "cpu"
57
+ model=cnhubert.get_model()
58
+ # is_half=False
59
+ if(is_half==True):
60
+ model=model.half().to(device)
61
+ else:
62
+ model = model.to(device)
63
+
64
+ nan_fails=[]
65
+ def name2go(wav_name,wav_path):
66
+ hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
67
+ if(os.path.exists(hubert_path)):return
68
+ tmp_audio = load_audio(wav_path, 32000)
69
+ tmp_max = np.abs(tmp_audio).max()
70
+ if tmp_max > 2.2:
71
+ print("%s-filtered,%s" % (wav_name, tmp_max))
72
+ return
73
+ tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
74
+ tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
75
+ tmp_audio = librosa.resample(
76
+ tmp_audio32b, orig_sr=32000, target_sr=16000
77
+ )# not a resampling issue
78
+ tensor_wav16 = torch.from_numpy(tmp_audio)
79
+ if (is_half == True):
80
+ tensor_wav16=tensor_wav16.half().to(device)
81
+ else:
82
+ tensor_wav16 = tensor_wav16.to(device)
83
+ ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
84
+ if np.isnan(ssl.detach().numpy()).sum()!= 0:
85
+ nan_fails.append(wav_name)
86
+ print("nan filtered:%s"%wav_name)
87
+ return
88
+ wavfile.write(
89
+ "%s/%s"%(wav32dir,wav_name),
90
+ 32000,
91
+ tmp_audio32.astype("int16"),
92
+ )
93
+ my_save(ssl,hubert_path )
94
+
95
+ with open(inp_text,"r",encoding="utf8")as f:
96
+ lines=f.read().strip("\n").split("\n")
97
+
98
+ for line in lines[int(i_part)::int(all_parts)]:
99
+ try:
100
+ # wav_name,text=line.split("\t")
101
+ wav_name, spk_name, language, text = line.split("|")
102
+ if (inp_wav_dir != "" and inp_wav_dir != None):
103
+ wav_name = os.path.basename(wav_name)
104
+ wav_path = "%s/%s"%(inp_wav_dir, wav_name)
105
+
106
+ else:
107
+ wav_path=wav_name
108
+ wav_name = os.path.basename(wav_name)
109
+ name2go(wav_name,wav_path)
110
+ except:
111
+ print(line,traceback.format_exc())
112
+
113
+ if(len(nan_fails)>0 and is_half==True):
114
+ is_half=False
115
+ model=model.float()
116
+ for wav_name, wav_path in nan_fails:
117
+ try:
118
+ name2go(wav_name, wav_path)
119
+ except:
120
+ print(wav_name,traceback.format_exc())
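The tmp_audio32 line above blends a peak-normalized copy of the clip with the raw signal before writing int16 WAVs; isolated below with a dummy clip (all values made up).

import numpy as np

maxx, alpha = 0.95, 0.5
x = np.random.uniform(-0.3, 0.3, 32000).astype(np.float32)   # 1 s dummy clip at 32 kHz
peak = np.abs(x).max()

# Half peak-normalized toward 0.95 full scale plus half the original signal, in int16 units.
audio32 = (x / peak * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * x
print(int(np.abs(audio32).max()))   # stays well inside the int16 range for this clip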
GPT_SoVITS/prepare_datasets/3-get-semantic.py CHANGED
@@ -1,95 +1,95 @@
1
- import os
2
-
3
- inp_text = os.environ.get("inp_text")
4
- exp_name = os.environ.get("exp_name")
5
- i_part = os.environ.get("i_part")
6
- all_parts = os.environ.get("all_parts")
7
- os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
8
- opt_dir = os.environ.get("opt_dir")
9
- pretrained_s2G = os.environ.get("pretrained_s2G")
10
- s2config_path = os.environ.get("s2config_path")
11
- is_half = eval(os.environ.get("is_half", "True"))
12
- import math, traceback
13
- import multiprocessing
14
- import sys, pdb
15
-
16
- now_dir = os.getcwd()
17
- sys.path.append(now_dir)
18
- from random import shuffle
19
- import torch.multiprocessing as mp
20
- from glob import glob
21
- from tqdm import tqdm
22
- import logging, librosa, utils, torch
23
- from module.models import SynthesizerTrn
24
-
25
- logging.getLogger("numba").setLevel(logging.WARNING)
26
- # from config import pretrained_s2G
27
-
28
- # inp_text=sys.argv[1]
29
- # exp_name=sys.argv[2]
30
- # i_part=sys.argv[3]
31
- # all_parts=sys.argv[4]
32
- # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[5]
33
- # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
34
-
35
-
36
- hubert_dir = "%s/4-cnhubert" % (opt_dir)
37
- semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
38
- if os.path.exists(semantic_path) == False:
39
- os.makedirs(opt_dir, exist_ok=True)
40
-
41
- if torch.cuda.is_available():
42
- device = "cuda"
43
- # elif torch.backends.mps.is_available():
44
- # device = "mps"
45
- else:
46
- device = "cpu"
47
- hps = utils.get_hparams_from_file(s2config_path)
48
- vq_model = SynthesizerTrn(
49
- hps.data.filter_length // 2 + 1,
50
- hps.train.segment_size // hps.data.hop_length,
51
- n_speakers=hps.data.n_speakers,
52
- **hps.model
53
- )
54
- if is_half == True:
55
- vq_model = vq_model.half().to(device)
56
- else:
57
- vq_model = vq_model.to(device)
58
- vq_model.eval()
59
- # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
60
- # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
61
- print(
62
- vq_model.load_state_dict(
63
- torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
64
- )
65
- )
66
-
67
- def name2go(wav_name, lines):
68
- hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
69
- if os.path.exists(hubert_path) == False:
70
- return
71
- ssl_content = torch.load(hubert_path, map_location="cpu")
72
- if is_half == True:
73
- ssl_content = ssl_content.half().to(device)
74
- else:
75
- ssl_content = ssl_content.to(device)
76
- codes = vq_model.extract_latent(ssl_content)
77
- semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
78
- lines.append("%s\t%s" % (wav_name, semantic))
79
-
80
- with open(inp_text, "r", encoding="utf8") as f:
81
- lines = f.read().strip("\n").split("\n")
82
-
83
- lines1 = []
84
- for line in lines[int(i_part) :: int(all_parts)]:
85
- # print(line)
86
- try:
87
- # wav_name,text=line.split("\t")
88
- wav_name, spk_name, language, text = line.split("|")
89
- wav_name = os.path.basename(wav_name)
90
- # name2go(name,lines1)
91
- name2go(wav_name, lines1)
92
- except:
93
- print(line, traceback.format_exc())
94
- with open(semantic_path, "w", encoding="utf8") as f:
95
- f.write("\n".join(lines1))
 
+import os
+
+inp_text = os.environ.get("inp_text")
+exp_name = os.environ.get("exp_name")
+i_part = os.environ.get("i_part")
+all_parts = os.environ.get("all_parts")
+os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
+opt_dir = os.environ.get("opt_dir")
+pretrained_s2G = os.environ.get("pretrained_s2G")
+s2config_path = os.environ.get("s2config_path")
+is_half = eval(os.environ.get("is_half", "True"))
+import math, traceback
+import multiprocessing
+import sys, pdb
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+from random import shuffle
+import torch.multiprocessing as mp
+from glob import glob
+from tqdm import tqdm
+import logging, librosa, utils, torch
+from module.models import SynthesizerTrn
+
+logging.getLogger("numba").setLevel(logging.WARNING)
+# from config import pretrained_s2G
+
+# inp_text=sys.argv[1]
+# exp_name=sys.argv[2]
+# i_part=sys.argv[3]
+# all_parts=sys.argv[4]
+# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[5]
+# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
+
+
+hubert_dir = "%s/4-cnhubert" % (opt_dir)
+semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
+if os.path.exists(semantic_path) == False:
+    os.makedirs(opt_dir, exist_ok=True)
+
+    if torch.cuda.is_available():
+        device = "cuda"
+    # elif torch.backends.mps.is_available():
+    #     device = "mps"
+    else:
+        device = "cpu"
+    hps = utils.get_hparams_from_file(s2config_path)
+    vq_model = SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model
+    )
+    if is_half == True:
+        vq_model = vq_model.half().to(device)
+    else:
+        vq_model = vq_model.to(device)
+    vq_model.eval()
+    # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
+    # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
+    print(
+        vq_model.load_state_dict(
+            torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
+        )
+    )
+
+    def name2go(wav_name, lines):
+        hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
+        if os.path.exists(hubert_path) == False:
+            return
+        ssl_content = torch.load(hubert_path, map_location="cpu")
+        if is_half == True:
+            ssl_content = ssl_content.half().to(device)
+        else:
+            ssl_content = ssl_content.to(device)
+        codes = vq_model.extract_latent(ssl_content)
+        semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
+        lines.append("%s\t%s" % (wav_name, semantic))
+
+    with open(inp_text, "r", encoding="utf8") as f:
+        lines = f.read().strip("\n").split("\n")
+
+    lines1 = []
+    for line in lines[int(i_part) :: int(all_parts)]:
+        # print(line)
+        try:
+            # wav_name,text=line.split("\t")
+            wav_name, spk_name, language, text = line.split("|")
+            wav_name = os.path.basename(wav_name)
+            # name2go(name,lines1)
+            name2go(wav_name, lines1)
+        except:
+            print(line, traceback.format_exc())
+    with open(semantic_path, "w", encoding="utf8") as f:
+        f.write("\n".join(lines1))
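
The rewritten script above takes all of its configuration from environment variables rather than command-line arguments, so the WebUI can launch one copy per data shard. Below is a minimal launcher sketch: the environment-variable names come from the script itself, while the paths and the experiment name are illustrative placeholders, not values from this commit.

import os
import subprocess

# One shard (i_part out of all_parts); the WebUI normally spawns several of these in parallel.
env = dict(
    os.environ,
    inp_text="output/asr_opt/annotations.list",   # assumed list file, one "wav|speaker|language|text" entry per line
    exp_name="my_experiment",                     # assumed experiment name
    i_part="0",                                   # index of this shard
    all_parts="1",                                # total number of shards
    _CUDA_VISIBLE_DEVICES="0",                    # copied into CUDA_VISIBLE_DEVICES by the script
    opt_dir="logs/my_experiment",                 # assumed output dir holding 4-cnhubert/ and 6-name2semantic-*.tsv
    pretrained_s2G="GPT_SoVITS/pretrained_models/s2G488k.pth",  # assumed pretrained SoVITS generator
    s2config_path="GPT_SoVITS/configs/s2.json",                 # assumed s2 config path
    is_half="True",
)
# Depending on the setup, GPT_SoVITS/ may also need to be on PYTHONPATH so that
# the script's `import utils` and `from module.models import ...` resolve.
subprocess.run(["python", "GPT_SoVITS/prepare_datasets/3-get-semantic.py"], env=env, check=True)
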
GPT_SoVITS/process_ckpt.py CHANGED
@@ -1,31 +1,31 @@
-import traceback
-from collections import OrderedDict
-from time import time as ttime
-import shutil,os
-import torch
-from tools.i18n.i18n import I18nAuto
-
-i18n = I18nAuto()
-
-def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
-    dir=os.path.dirname(path)
-    name=os.path.basename(path)
-    tmp_path="%s.pth"%(ttime())
-    torch.save(fea,tmp_path)
-    shutil.move(tmp_path,"%s/%s"%(dir,name))
-
-def savee(ckpt, name, epoch, steps, hps):
-    try:
-        opt = OrderedDict()
-        opt["weight"] = {}
-        for key in ckpt.keys():
-            if "enc_q" in key:
-                continue
-            opt["weight"][key] = ckpt[key].half()
-        opt["config"] = hps
-        opt["info"] = "%sepoch_%siteration" % (epoch, steps)
-        # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
-        my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
-        return "Success."
-    except:
-        return traceback.format_exc()
+import traceback
+from collections import OrderedDict
+from time import time as ttime
+import shutil,os
+import torch
+from tools.i18n.i18n import I18nAuto
+
+i18n = I18nAuto()
+
+def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
+    dir=os.path.dirname(path)
+    name=os.path.basename(path)
+    tmp_path="%s.pth"%(ttime())
+    torch.save(fea,tmp_path)
+    shutil.move(tmp_path,"%s/%s"%(dir,name))
+
+def savee(ckpt, name, epoch, steps, hps):
+    try:
+        opt = OrderedDict()
+        opt["weight"] = {}
+        for key in ckpt.keys():
+            if "enc_q" in key:
+                continue
+            opt["weight"][key] = ckpt[key].half()
+        opt["config"] = hps
+        opt["info"] = "%sepoch_%siteration" % (epoch, steps)
+        # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+        my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+        return "Success."
+    except:
+        return traceback.format_exc()
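
The removed and re-added blocks of process_ckpt.py are identical line for line, so this re-upload most likely only touches whitespace or line endings. The file is still worth a note: my_save works around torch.save failing on destination paths that contain Chinese characters by saving under an ASCII temporary name and moving the file afterwards, and savee exports a half-precision copy of the weights with the enc_q (posterior encoder) keys stripped. A minimal sketch of the my_save technique follows; the demo object, helper name, and target filename are illustrative, not part of this commit.

import shutil
from time import time as ttime

import torch

def save_via_tmp(obj, path):
    tmp_path = "%s.pth" % (ttime())   # ASCII-only temporary name in the current directory
    torch.save(obj, tmp_path)         # save never sees the non-ASCII path
    shutil.move(tmp_path, path)       # the move handles the non-ASCII destination

save_via_tmp({"demo": torch.zeros(2)}, "权重_demo.pth")
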
GPT_SoVITS/text/zh_normalization/num.py CHANGED
@@ -106,6 +106,29 @@ def replace_default_num(match):
     return verbalize_digit(number, alt_one=True)


+# 加减乘除
+RE_ASMD = re.compile(
+    r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
+asmd_map = {
+    '+': '加',
+    '-': '减',
+    '×': '乘',
+    '÷': '除',
+    '=': '等于'
+}
+
+
+def replace_asmd(match) -> str:
+    """
+    Args:
+        match (re.Match)
+    Returns:
+        str
+    """
+    result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
+    return result
+
+
 # 数字表达式
 # 纯小数
 RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
@@ -155,7 +178,13 @@ def replace_number(match) -> str:
 # match.group(1) and match.group(8) are copy from RE_NUMBER

 RE_RANGE = re.compile(
-    r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))[-~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
+    r"""
+    (?<![\d\+\-\×÷=])       # 使用反向前瞻以确保数字范围之前没有其他数字和操作符
+    ((-?)((\d+)(\.\d+)?))   # 匹配范围起始的负数或正数(整数或小数)
+    [-~]                    # 匹配范围分隔符
+    ((-?)((\d+)(\.\d+)?))   # 匹配范围结束的负数或正数(整数或小数)
+    (?![\d\+\-\×÷=])        # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
+    """, re.VERBOSE)


 def replace_range(match) -> str:
@@ -165,7 +194,7 @@ def replace_range(match) -> str:
     Returns:
         str
     """
-    first, second = match.group(1), match.group(8)
+    first, second = match.group(1), match.group(6)
     first = RE_NUMBER.sub(replace_number, first)
     second = RE_NUMBER.sub(replace_number, second)
     result = f"{first}到{second}"
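
The new arithmetic rule can be sanity-checked on its own. The regex and mapping below are copied from the hunk above; the loop mirrors how TextNormalizer applies the rule (see the text_normlization.py change that follows), and it is needed because a single re.sub pass cannot rewrite chained expressions such as 1+1=2, where consecutive matches share a digit.

import re

RE_ASMD = re.compile(
    r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
asmd_map = {'+': '加', '-': '减', '×': '乘', '÷': '除', '=': '等于'}

def replace_asmd(match):
    # group(1) = left operand, group(8) = operator symbol, group(9) = right operand
    return match.group(1) + asmd_map[match.group(8)] + match.group(9)

sentence = "1+1=2"
while RE_ASMD.search(sentence):                  # repeated passes, as in TextNormalizer
    sentence = RE_ASMD.sub(replace_asmd, sentence)
print(sentence)                                  # 1加1等于2; the digits are verbalized by the later number rules
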
GPT_SoVITS/text/zh_normalization/text_normlization.py CHANGED
@@ -34,6 +34,7 @@ from .num import RE_PERCENTAGE
 from .num import RE_POSITIVE_QUANTIFIERS
 from .num import RE_RANGE
 from .num import RE_TO_RANGE
+from .num import RE_ASMD
 from .num import replace_default_num
 from .num import replace_frac
 from .num import replace_negative_num
@@ -42,6 +43,7 @@ from .num import replace_percentage
 from .num import replace_positive_quantifier
 from .num import replace_range
 from .num import replace_to_range
+from .num import replace_asmd
 from .phonecode import RE_MOBILE_PHONE
 from .phonecode import RE_NATIONAL_UNIFORM_NUMBER
 from .phonecode import RE_TELEPHONE
@@ -67,7 +69,7 @@ class TextNormalizer():
         if lang == "zh":
             text = text.replace(" ", "")
             # 过滤掉特殊字符
-            text = re.sub(r'[——《》【】<=>{}()()#&@“”^_|\\]', '', text)
+            text = re.sub(r'[——《》【】<>{}()()#&@“”^_|\\]', '', text)
             text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
             text = text.strip()
             sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
@@ -142,6 +144,11 @@ class TextNormalizer():
         sentence = RE_NATIONAL_UNIFORM_NUMBER.sub(replace_phone, sentence)

         sentence = RE_RANGE.sub(replace_range, sentence)
+
+        # 处理加减乘除
+        while RE_ASMD.search(sentence):
+            sentence = RE_ASMD.sub(replace_asmd, sentence)
+
         sentence = RE_INTEGER.sub(replace_negative_num, sentence)
         sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
         sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
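
Taken together, the two files now separate numeric ranges from arithmetic: the tightened RE_RANGE refuses to match when a digit or operator sits directly next to a candidate range (which is also why replace_range now reads the second operand from group(6) instead of group(8)), the equals sign and '<>' are no longer stripped as special characters, and the RE_ASMD loop verbalizes the operators before the ordinary number rules run. A small illustrative check follows; the pattern mirrors the new RE_RANGE with its comments paraphrased in English, and the sample strings are made up.

import re

RE_RANGE = re.compile(
    r"""
    (?<![\d\+\-\×÷=])       # no digit or operator immediately before the range
    ((-?)((\d+)(\.\d+)?))   # start of the range
    [-~]                    # range separator
    ((-?)((\d+)(\.\d+)?))   # end of the range
    (?![\d\+\-\×÷=])        # no digit or operator immediately after the range
    """, re.VERBOSE)

print(RE_RANGE.search("3-5人"))   # matches: a genuine range, handed to replace_range as 3到5
print(RE_RANGE.search("5-3=2"))   # None: the trailing '=' blocks the match, so RE_ASMD handles the expression
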