Ailyth committed on
Commit
b2a5005
1 Parent(s): f83e4e8

0317-165840

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
Files changed (50)
  1. AR/__pycache__/__init__.cpython-310.pyc +0 -0
  2. AR/models/__pycache__/__init__.cpython-310.pyc +0 -0
  3. AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc +0 -0
  4. AR/models/__pycache__/t2s_model.cpython-310.pyc +0 -0
  5. AR/models/__pycache__/utils.cpython-310.pyc +0 -0
  6. AR/models/t2s_lightning_module.py +2 -2
  7. AR/models/t2s_model.py +9 -380
  8. AR/models/t2s_model_batch_only.py +0 -483
  9. AR/models/utils.py +7 -7
  10. AR/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  11. AR/modules/__pycache__/activation.cpython-310.pyc +0 -0
  12. AR/modules/__pycache__/embedding.cpython-310.pyc +0 -0
  13. AR/modules/__pycache__/lr_schedulers.cpython-310.pyc +0 -0
  14. AR/modules/__pycache__/optim.cpython-310.pyc +0 -0
  15. AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc +0 -0
  16. AR/modules/__pycache__/scaling.cpython-310.pyc +0 -0
  17. AR/modules/__pycache__/transformer.cpython-310.pyc +0 -0
  18. GPT_SoVITS/configs/tts_infer.yaml +0 -16
  19. TTS_infer_pack/TTS.py +0 -848
  20. TTS_infer_pack/TextPreprocessor.py +0 -209
  21. TTS_infer_pack/__init__.py +0 -1
  22. TTS_infer_pack/__pycache__/TTS.cpython-310.pyc +0 -0
  23. TTS_infer_pack/__pycache__/TextPreprocessor.cpython-310.pyc +0 -0
  24. TTS_infer_pack/__pycache__/__init__.cpython-310.pyc +0 -0
  25. TTS_infer_pack/__pycache__/text_segmentation_method.cpython-310.pyc +0 -0
  26. TTS_infer_pack/text_segmentation_method.py +0 -152
  27. __pycache__/download.cpython-310.pyc +0 -0
  28. __pycache__/info.cpython-310.pyc +0 -0
  29. __pycache__/my_utils.cpython-310.pyc +0 -0
  30. __pycache__/utils.cpython-310.pyc +0 -0
  31. app.py +539 -132
  32. {GPT_SoVITS/configs → configs}/s1.yaml +0 -0
  33. {GPT_SoVITS/configs → configs}/s1big.yaml +0 -0
  34. {GPT_SoVITS/configs → configs}/s1big2.yaml +0 -0
  35. {GPT_SoVITS/configs → configs}/s1longer.yaml +0 -0
  36. {GPT_SoVITS/configs → configs}/s1mq.yaml +0 -0
  37. {GPT_SoVITS/configs → configs}/s2.json +0 -0
  38. {GPT_SoVITS/configs → configs}/train.yaml +0 -0
  39. feature_extractor/__pycache__/__init__.cpython-310.pyc +0 -0
  40. feature_extractor/__pycache__/cnhubert.cpython-310.pyc +0 -0
  41. feature_extractor/__pycache__/whisper_enc.cpython-310.pyc +0 -0
  42. feature_extractor/cnhubert.py +5 -8
  43. gweight.txt +1 -0
  44. module/__pycache__/__init__.cpython-310.pyc +0 -0
  45. module/__pycache__/attentions.cpython-310.pyc +0 -0
  46. module/__pycache__/commons.cpython-310.pyc +0 -0
  47. module/__pycache__/core_vq.cpython-310.pyc +0 -0
  48. module/__pycache__/mel_processing.cpython-310.pyc +0 -0
  49. module/__pycache__/models.cpython-310.pyc +0 -0
  50. module/__pycache__/modules.cpython-310.pyc +0 -0
AR/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/AR/__pycache__/__init__.cpython-310.pyc and b/AR/__pycache__/__init__.cpython-310.pyc differ
 
AR/models/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/__init__.cpython-310.pyc and b/AR/models/__pycache__/__init__.cpython-310.pyc differ
 
AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc and b/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc differ
 
AR/models/__pycache__/t2s_model.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/t2s_model.cpython-310.pyc and b/AR/models/__pycache__/t2s_model.cpython-310.pyc differ
 
AR/models/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/utils.cpython-310.pyc and b/AR/models/__pycache__/utils.cpython-310.pyc differ
 
AR/models/t2s_lightning_module.py CHANGED
@@ -13,11 +13,11 @@ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam

 class Text2SemanticLightningModule(LightningModule):
-    def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False):
+    def __init__(self, config, output_dir, is_train=True):
         super().__init__()
         self.config = config
         self.top_k = 3
-        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k, flash_attn_enabled=flash_attn_enabled)
+        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
         pretrained_s1 = config.get("pretrained_s1")
         if pretrained_s1 and is_train:
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
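Note on the change above: the flash_attn_enabled flag is removed from the constructor, so callers that passed it must drop the keyword (Python raises a TypeError for unexpected keyword arguments). A minimal usage sketch follows; the config layout is an assumption pieced together from default_config and the attribute reads in the t2s_model.py diff below, not a config shipped with this commit:

# Hedged sketch; the config dict is illustrative, not taken from this repo's checkpoints.
config = {
    "model": {
        "hidden_dim": 512, "embedding_dim": 512, "head": 8, "n_layer": 12,
        "vocab_size": 1024 + 1, "phoneme_vocab_size": 512, "dropout": 0.0, "EOS": 1024,
    }
}
# Before this commit: Text2SemanticLightningModule(config, "logs", is_train=False, flash_attn_enabled=True)
# After this commit:  Text2SemanticLightningModule(config, "logs", is_train=False)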
AR/models/t2s_model.py CHANGED
@@ -1,9 +1,5 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
-import os, sys
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-from typing import List
 import torch
 from tqdm import tqdm

@@ -39,144 +35,8 @@ default_config = {
 }


-@torch.jit.script
-class T2SMLP:
-    def __init__(self, w1, b1, w2, b2):
-        self.w1 = w1
-        self.b1 = b1
-        self.w2 = w2
-        self.b2 = b2
-
-    def forward(self, x):
-        x = F.relu(F.linear(x, self.w1, self.b1))
-        x = F.linear(x, self.w2, self.b2)
-        return x
-
-
-@torch.jit.script
-class T2SBlock:
-    def __init__(
-        self,
-        num_heads,
-        hidden_dim: int,
-        mlp: T2SMLP,
-        qkv_w,
-        qkv_b,
-        out_w,
-        out_b,
-        norm_w1,
-        norm_b1,
-        norm_eps1,
-        norm_w2,
-        norm_b2,
-        norm_eps2,
-    ):
-        self.num_heads = num_heads
-        self.mlp = mlp
-        self.hidden_dim: int = hidden_dim
-        self.qkv_w = qkv_w
-        self.qkv_b = qkv_b
-        self.out_w = out_w
-        self.out_b = out_b
-        self.norm_w1 = norm_w1
-        self.norm_b1 = norm_b1
-        self.norm_eps1 = norm_eps1
-        self.norm_w2 = norm_w2
-        self.norm_b2 = norm_b2
-        self.norm_eps2 = norm_eps2
-
-    def process_prompt(self, x, attn_mask: torch.Tensor):
-        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
-
-        batch_size = q.shape[0]
-        q_len = q.shape[1]
-        kv_len = k.shape[1]
-
-        k_cache = k
-        v_cache = v
-
-        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
-        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
-        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
-
-        attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
-
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
-        attn = F.linear(attn, self.out_w, self.out_b)
-
-        x = F.layer_norm(
-            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-        )
-        x = F.layer_norm(
-            x + self.mlp.forward(x),
-            [self.hidden_dim],
-            self.norm_w2,
-            self.norm_b2,
-            self.norm_eps2,
-        )
-        return x, k_cache, v_cache
-
-    def decode_next_token(self, x, k_cache, v_cache):
-        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
-
-        k_cache = torch.cat([k_cache, k], dim=1)
-        v_cache = torch.cat([v_cache, v], dim=1)
-
-        batch_size = q.shape[0]
-        q_len = q.shape[1]
-        kv_len = k_cache.shape[1]
-
-        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
-        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
-        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
-
-        attn = F.scaled_dot_product_attention(q, k, v)
-
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
-        attn = F.linear(attn, self.out_w, self.out_b)
-
-        x = F.layer_norm(
-            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-        )
-        x = F.layer_norm(
-            x + self.mlp.forward(x),
-            [self.hidden_dim],
-            self.norm_w2,
-            self.norm_b2,
-            self.norm_eps2,
-        )
-        return x, k_cache, v_cache
-
-
-@torch.jit.script
-class T2STransformer:
-    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
-        self.num_blocks: int = num_blocks
-        self.blocks = blocks
-
-    def process_prompt(
-        self, x, attn_mask: torch.Tensor):
-        k_cache: List[torch.Tensor] = []
-        v_cache: List[torch.Tensor] = []
-        for i in range(self.num_blocks):
-            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
-            k_cache.append(k_cache_)
-            v_cache.append(v_cache_)
-        return x, k_cache, v_cache
-
-    def decode_next_token(
-        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
-    ):
-        for i in range(self.num_blocks):
-            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
-        return x, k_cache, v_cache
-
-
 class Text2SemanticDecoder(nn.Module):
-    def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled:bool=False):
+    def __init__(self, config, norm_first=False, top_k=3):
         super(Text2SemanticDecoder, self).__init__()
         self.model_dim = config["model"]["hidden_dim"]
         self.embedding_dim = config["model"]["embedding_dim"]
@@ -228,47 +88,6 @@ class Text2SemanticDecoder(nn.Module):
             multidim_average="global",
             ignore_index=self.EOS,
         )
-
-        self.enable_flash_attn(flash_attn_enabled)
-
-    def enable_flash_attn(self, enable: bool = True):
-
-        if not enable:
-            print("Not Using Flash Attention")
-            self.infer_panel = self.infer_panel_batch_only
-        else:
-            self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
-            print("Using Flash Attention")
-            blocks = []
-
-            for i in range(self.num_layers):
-                layer = self.h.layers[i]
-                t2smlp = T2SMLP(
-                    layer.linear1.weight,
-                    layer.linear1.bias,
-                    layer.linear2.weight,
-                    layer.linear2.bias
-                )
-
-                block = T2SBlock(
-                    self.num_head,
-                    self.model_dim,
-                    t2smlp,
-                    layer.self_attn.in_proj_weight,
-                    layer.self_attn.in_proj_bias,
-                    layer.self_attn.out_proj.weight,
-                    layer.self_attn.out_proj.bias,
-                    layer.norm1.weight,
-                    layer.norm1.bias,
-                    layer.norm1.eps,
-                    layer.norm2.weight,
-                    layer.norm2.bias,
-                    layer.norm2.eps
-                )
-
-                blocks.append(block)
-
-            self.t2s_transformer = T2STransformer(self.num_layers, blocks)

     def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
         x = self.ar_text_embedding(x)
@@ -502,161 +321,7 @@ class Text2SemanticDecoder(nn.Module):
         # shift by one
         return targets[:, :-1], targets[:, 1:]

-    def infer_panel_batch_infer_with_flash_attn(
-        self,
-        x,  ##### all text tokens
-        x_lens,
-        prompts,  #### reference audio tokens
-        bert_feature,
-        top_k: int = -100,
-        top_p: int = 100,
-        early_stop_num: int = -1,
-        temperature: float = 1.0,
-    ):
-
-        bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
-        x = self.ar_text_embedding(x)
-        x = x + bert_feature
-        x = self.ar_text_position(x)
-
-        # AR Decoder
-        y = prompts
-
-        x_len = x.shape[1]
-        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
-        stop = False
-        # print(1111111,self.num_layers)
-
-        k_cache = None
-        v_cache = None
-        ################### first step ##########################
-        if y is not None:
-            y_emb = self.ar_audio_embedding(y)
-            y_len = y_emb.shape[1]
-            prefix_len = y.shape[1]
-            y_pos = self.ar_audio_position(y_emb)
-            xy_pos = torch.concat([x, y_pos], dim=1)
-            ref_free = False
-        else:
-            y_emb = None
-            y_len = 0
-            prefix_len = 0
-            y_pos = None
-            xy_pos = x
-            y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
-            ref_free = True
-
-
-        ##### create mask #####
-        bsz = x.shape[0]
-        src_len = x_len + y_len
-        y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
-        y_mask = make_pad_mask(y_lens)
-        x_mask = make_pad_mask(x_lens)
-
-        # (bsz, x_len + y_len)
-        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
-
-        x_mask = F.pad(
-            x_attn_mask,
-            (0, y_len),  ### extend the all-zero xx block into all-zero xx plus all-one xy, (x, x+y)
-            value=True,
-        )
-        y_mask = F.pad(  ### pad the upper-right ones of yy with zeros on the xy side, (y, x+y)
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
-            (x_len, 0),
-            value=False,
-        )
-
-        xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).expand(bsz, -1, -1).to(x.device)
-        # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
-        xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
-        xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
-        xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
-        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
-        xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
-
-        ###### decode #####
-        y_list = [None]*y.shape[0]
-        batch_idx_map = list(range(y.shape[0]))
-        idx_list = [None]*y.shape[0]
-        for idx in tqdm(range(1500)):
-            if idx == 0:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
-            else:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
-
-            logits = self.ar_predict_layer(
-                xy_dec[:, -1]
-            )
-
-            if idx == 0:
-                xy_attn_mask = None
-                logits = logits[:, :-1]
-
-            samples = sample(
-                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
-            )[0]
-
-            y = torch.concat([y, samples], dim=1)
-
-            ####### drop finished sequences from the batch to further reduce compute
-            reserved_idx_of_batch_for_y = None
-            if (self.EOS in samples[:, 0]) or \
-                (self.EOS in torch.argmax(logits, dim=-1)):  ### stop once EOS is generated
-                l = samples[:, 0]==self.EOS
-                removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
-                reserved_idx_of_batch_for_y = torch.where(l==False)[0]
-                # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
-                for i in removed_idx_of_batch_for_y:
-                    batch_index = batch_idx_map[i]
-                    idx_list[batch_index] = idx - 1
-                    y_list[batch_index] = y[i, :-1]
-
-                batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
-
-            # keep only the unfinished sequences in the batch
-            if reserved_idx_of_batch_for_y is not None:
-                # index = torch.LongTensor(batch_idx_map).to(y.device)
-                y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
-                if k_cache is not None:
-                    for i in range(len(k_cache)):
-                        k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
-                        v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
-
-
-            if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
-                print("use early stop num:", early_stop_num)
-                stop = True
-                for i, batch_index in enumerate(batch_idx_map):
-                    batch_index = batch_idx_map[i]
-                    idx_list[batch_index] = idx
-                    y_list[batch_index] = y[i, :-1]
-
-            if not (None in idx_list):
-                stop = True
-
-            if stop:
-                if y.shape[1]==0:
-                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
-                    print("bad zero prediction")
-                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
-                break
-
-            ####################### update next step ###################################
-            y_emb = self.ar_audio_embedding(y[:, -1:])
-            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype, device=y_emb.device)

-        if (None in idx_list):
-            for i in range(x.shape[0]):
-                if idx_list[i] is None:
-                    idx_list[i] = 1500-1  ### if EOS was never generated, substitute the max length
-
-        if ref_free:
-            return y_list, [0]*x.shape[0]
-        return y_list, idx_list
-
-    def infer_panel_batch_only(
+    def infer_panel(
         self,
         x,  ##### all text tokens
         x_lens,
@@ -721,9 +386,7 @@ class Text2SemanticDecoder(nn.Module):
             x.device
         )

-        y_list = [None]*y.shape[0]
-        batch_idx_map = list(range(y.shape[0]))
-        idx_list = [None]*y.shape[0]
+
        for idx in tqdm(range(1500)):

            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
@@ -734,45 +397,17 @@ class Text2SemanticDecoder(nn.Module):
             if(idx==0):  ### the first step must not emit EOS, otherwise nothing remains
                 logits = logits[:, :-1]  ### strip the probability of the 1024 stop token
             samples = sample(
-                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
-            )[0]
+                logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
+            )[0].unsqueeze(0)
             # the semantic_ids generated this step plus the previous y form the new y
             # print(samples.shape)#[1,1]# the first 1 is the batch size
             y = torch.concat([y, samples], dim=1)

-            # drop finished sequences
-            reserved_idx_of_batch_for_y = None
-            if (self.EOS in torch.argmax(logits, dim=-1)) or \
-                (self.EOS in samples[:, 0]):  ### stop once EOS is generated
-                l = samples[:, 0]==self.EOS
-                removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
-                reserved_idx_of_batch_for_y = torch.where(l==False)[0]
-                # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
-                for i in removed_idx_of_batch_for_y:
-                    batch_index = batch_idx_map[i]
-                    idx_list[batch_index] = idx - 1
-                    y_list[batch_index] = y[i, :-1]
-
-                batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
-
-            # keep only unfinished sequences
-            if reserved_idx_of_batch_for_y is not None:
-                # index = torch.LongTensor(batch_idx_map).to(y.device)
-                y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
-                if cache["y_emb"] is not None:
-                    cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
-                if cache["k"] is not None:
-                    for i in range(self.num_layers):
-                        # k/v are transposed, so the batch dim is 1
-                        cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
-                        cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
-
-
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
                 stop = True
-
-            if not (None in idx_list):
+
+            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                 stop = True
             if stop:
@@ -808,12 +443,6 @@ class Text2SemanticDecoder(nn.Module):
             xy_attn_mask = torch.zeros(
                 (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
             )
-
-        if (None in idx_list):
-            for i in range(x.shape[0]):
-                if idx_list[i] is None:
-                    idx_list[i] = 1500-1  ### if EOS was never generated, substitute the max length
-
         if ref_free:
-            return y_list, [0]*x.shape[0]
-        return y_list, idx_list
+            return y[:, :-1], 0
+        return y[:, :-1], idx-1
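To make the revert above easier to follow, here is a hedged, self-contained sketch of the control flow infer_panel returns to: one sequence, one sampled token per step, EOS forbidden on the first step, and a stop on EOS or early_stop_num. The predict() stand-in is illustrative only (the real model runs the cached transformer over the concatenated text/audio positions), and torch.multinomial stands in for the repo's sample() helper.

import torch

EOS = 1024

def predict(y: torch.Tensor) -> torch.Tensor:
    # Stand-in for ar_predict_layer(h(...)): random logits over 1025 tokens,
    # nudged toward EOS as the sequence grows so this demo terminates.
    logits = torch.randn(EOS + 1)
    logits[EOS] += 0.02 * y.shape[1]
    return logits

y = torch.zeros(1, 0, dtype=torch.long)  # "ref_free" start: no prompt tokens
prefix_len, early_stop_num = y.shape[1], -1

for idx in range(1500):
    logits = predict(y)
    if idx == 0:
        logits = logits[:-1]  # the first step must not emit EOS
    sample_id = torch.multinomial(torch.softmax(logits, dim=-1), 1)
    y = torch.cat([y, sample_id.view(1, 1)], dim=1)

    stop = False
    if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
        stop = True
    if torch.argmax(logits) == EOS or sample_id.item() == EOS:
        stop = True
    if stop:
        print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
        break

result = y[:, :-1]  # drop the trailing EOS token, as the reverted return does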
AR/models/t2s_model_batch_only.py DELETED
@@ -1,483 +0,0 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
-import torch
-from tqdm import tqdm
-
-from AR.models.utils import make_pad_mask
-from AR.models.utils import (
-    topk_sampling,
-    sample,
-    logits_to_probs,
-    multinomial_sample_one_no_sync,
-    dpo_loss,
-    make_reject_y,
-    get_batch_logps
-)
-from AR.modules.embedding import SinePositionalEmbedding
-from AR.modules.embedding import TokenEmbedding
-from AR.modules.transformer import LayerNorm
-from AR.modules.transformer import TransformerEncoder
-from AR.modules.transformer import TransformerEncoderLayer
-from torch import nn
-from torch.nn import functional as F
-from torchmetrics.classification import MulticlassAccuracy
-
-default_config = {
-    "embedding_dim": 512,
-    "hidden_dim": 512,
-    "num_head": 8,
-    "num_layers": 12,
-    "num_codebook": 8,
-    "p_dropout": 0.0,
-    "vocab_size": 1024 + 1,
-    "phoneme_vocab_size": 512,
-    "EOS": 1024,
-}
-
-
-class Text2SemanticDecoder(nn.Module):
-    def __init__(self, config, norm_first=False, top_k=3):
-        super(Text2SemanticDecoder, self).__init__()
-        self.model_dim = config["model"]["hidden_dim"]
-        self.embedding_dim = config["model"]["embedding_dim"]
-        self.num_head = config["model"]["head"]
-        self.num_layers = config["model"]["n_layer"]
-        self.norm_first = norm_first
-        self.vocab_size = config["model"]["vocab_size"]
-        self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
-        self.p_dropout = config["model"]["dropout"]
-        self.EOS = config["model"]["EOS"]
-        self.norm_first = norm_first
-        assert self.EOS == self.vocab_size - 1
-        # should be same as num of kmeans bin
-        # assert self.EOS == 1024
-        self.bert_proj = nn.Linear(1024, self.embedding_dim)
-        self.ar_text_embedding = TokenEmbedding(
-            self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
-        )
-        self.ar_text_position = SinePositionalEmbedding(
-            self.embedding_dim, dropout=0.1, scale=False, alpha=True
-        )
-        self.ar_audio_embedding = TokenEmbedding(
-            self.embedding_dim, self.vocab_size, self.p_dropout
-        )
-        self.ar_audio_position = SinePositionalEmbedding(
-            self.embedding_dim, dropout=0.1, scale=False, alpha=True
-        )
-
-        self.h = TransformerEncoder(
-            TransformerEncoderLayer(
-                d_model=self.model_dim,
-                nhead=self.num_head,
-                dim_feedforward=self.model_dim * 4,
-                dropout=0.1,
-                batch_first=True,
-                norm_first=norm_first,
-            ),
-            num_layers=self.num_layers,
-            norm=LayerNorm(self.model_dim) if norm_first else None,
-        )
-
-        self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
-        self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
-
-        self.ar_accuracy_metric = MulticlassAccuracy(
-            self.vocab_size,
-            top_k=top_k,
-            average="micro",
-            multidim_average="global",
-            ignore_index=self.EOS,
-        )
-
-    def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
-        x = self.ar_text_embedding(x)
-        x = x + self.bert_proj(bert_feature.transpose(1, 2))
-        x = self.ar_text_position(x)
-        x_mask = make_pad_mask(x_lens)
-
-        y_mask = make_pad_mask(y_lens)
-        y_mask_int = y_mask.type(torch.int64)
-        codes = y.type(torch.int64) * (1 - y_mask_int)
-
-        # Training
-        # AR Decoder
-        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
-        x_len = x_lens.max()
-        y_len = y_lens.max()
-        y_emb = self.ar_audio_embedding(y)
-        y_pos = self.ar_audio_position(y_emb)
-
-        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
-
-        ar_xy_padding_mask = xy_padding_mask
-
-        x_attn_mask = F.pad(
-            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
-            (0, y_len),
-            value=True,
-        )
-
-        y_attn_mask = F.pad(
-            torch.triu(
-                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
-                diagonal=1,
-            ),
-            (x_len, 0),
-            value=False,
-        )
-
-        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
-        bsz, src_len = x.shape[0], x_len + y_len
-        _xy_padding_mask = (
-            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
-            .expand(-1, self.num_head, -1, -1)
-            .reshape(bsz * self.num_head, 1, src_len)
-        )
-        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
-        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
-        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
-        xy_attn_mask = new_attn_mask
-        # feed x and the complete y into the model in one pass
-        xy_pos = torch.concat([x, y_pos], dim=1)
-
-        return xy_pos, xy_attn_mask, targets
-
-    def forward(self, x, x_lens, y, y_lens, bert_feature):
-        """
-        x: phoneme_ids
-        y: semantic_ids
-        """
-
-        reject_y, reject_y_lens = make_reject_y(y, y_lens)
-
-        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
-
-        xy_dec, _ = self.h(
-            (xy_pos, None),
-            mask=xy_attn_mask,
-        )
-        x_len = x_lens.max()
-        logits = self.ar_predict_layer(xy_dec[:, x_len:])
-
-        ###### DPO #############
-        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
-
-        reject_xy_dec, _ = self.h(
-            (reject_xy_pos, None),
-            mask=reject_xy_attn_mask,
-        )
-        x_len = x_lens.max()
-        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
-
-        # loss
-        # from feiteng: the longer the duration, the larger the gradient update should be, hence sum
-
-        loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
-        acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
-
-        A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
-        loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
-
-        loss = loss_1 + loss_2
-
-        return loss, acc
-
-    def forward_old(self, x, x_lens, y, y_lens, bert_feature):
-        """
-        x: phoneme_ids
-        y: semantic_ids
-        """
-        x = self.ar_text_embedding(x)
-        x = x + self.bert_proj(bert_feature.transpose(1, 2))
-        x = self.ar_text_position(x)
-        x_mask = make_pad_mask(x_lens)
-
-        y_mask = make_pad_mask(y_lens)
-        y_mask_int = y_mask.type(torch.int64)
-        codes = y.type(torch.int64) * (1 - y_mask_int)
-
-        # Training
-        # AR Decoder
-        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
-        x_len = x_lens.max()
-        y_len = y_lens.max()
-        y_emb = self.ar_audio_embedding(y)
-        y_pos = self.ar_audio_position(y_emb)
-
-        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
-        ar_xy_padding_mask = xy_padding_mask
-
-        x_attn_mask = F.pad(
-            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
-            (0, y_len),
-            value=True,
-        )
-        y_attn_mask = F.pad(
-            torch.triu(
-                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
-                diagonal=1,
-            ),
-            (x_len, 0),
-            value=False,
-        )
-        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
-        bsz, src_len = x.shape[0], x_len + y_len
-        _xy_padding_mask = (
-            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
-            .expand(-1, self.num_head, -1, -1)
-            .reshape(bsz * self.num_head, 1, src_len)
-        )
-        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
-        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
-        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
-        xy_attn_mask = new_attn_mask
-        # feed x and the complete y into the model in one pass
-        xy_pos = torch.concat([x, y_pos], dim=1)
-        xy_dec, _ = self.h(
-            (xy_pos, None),
-            mask=xy_attn_mask,
-        )
-        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
-        # loss
-        # from feiteng: the longer the duration, the larger the gradient update should be, hence sum
-        loss = F.cross_entropy(logits, targets, reduction="sum")
-        acc = self.ar_accuracy_metric(logits.detach(), targets).item()
-        return loss, acc
-
-    # need to check how this differs from forward, and what prompts receives when there is no semantic input
-    def infer(
-        self,
-        x,
-        x_lens,
-        prompts,
-        bert_feature,
-        top_k: int = -100,
-        early_stop_num: int = -1,
-        temperature: float = 1.0,
-    ):
-        x = self.ar_text_embedding(x)
-        x = x + self.bert_proj(bert_feature.transpose(1, 2))
-        x = self.ar_text_position(x)
-
-        # AR Decoder
-        y = prompts
-        prefix_len = y.shape[1]
-        x_len = x.shape[1]
-        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
-        stop = False
-        for _ in tqdm(range(1500)):
-            y_emb = self.ar_audio_embedding(y)
-            y_pos = self.ar_audio_position(y_emb)
-            # feed x together with the growing y into the model
-            xy_pos = torch.concat([x, y_pos], dim=1)
-            y_len = y.shape[1]
-            x_attn_mask_pad = F.pad(
-                x_attn_mask,
-                (0, y_len),
-                value=True,
-            )
-            y_attn_mask = F.pad(
-                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
-                (x_len, 0),
-                value=False,
-            )
-            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-                y.device
-            )
-
-            xy_dec, _ = self.h(
-                (xy_pos, None),
-                mask=xy_attn_mask,
-            )
-            logits = self.ar_predict_layer(xy_dec[:, -1])
-            samples = topk_sampling(
-                logits, top_k=top_k, top_p=1.0, temperature=temperature
-            )
-
-            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
-                print("use early stop num:", early_stop_num)
-                stop = True
-
-            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
-                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
-                stop = True
-            if stop:
-                if prompts.shape[1] == y.shape[1]:
-                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
-                    print("bad zero prediction")
-                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
-                break
-            # the semantic_ids generated this step plus the previous y form the new y
-            # print(samples.shape)#[1,1]# the first 1 is the batch size
-            # import os
-            # os._exit(2333)
-            y = torch.concat([y, samples], dim=1)
-        return y
-
-    def pad_y_eos(self, y, y_mask_int, eos_id):
-        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
-            y_mask_int, (0, 1), value=1
-        )
-        # shift by one
-        return targets[:, :-1], targets[:, 1:]
-
-    def infer_panel(
-        self,
-        x,  ##### all text tokens
-        x_lens,
-        prompts,  #### reference audio tokens
-        bert_feature,
-        top_k: int = -100,
-        top_p: int = 100,
-        early_stop_num: int = -1,
-        temperature: float = 1.0,
-    ):
-        x = self.ar_text_embedding(x)
-        x = x + self.bert_proj(bert_feature.transpose(1, 2))
-        x = self.ar_text_position(x)
-
-        # AR Decoder
-        y = prompts
-
-        x_len = x.shape[1]
-        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
-        stop = False
-        # print(1111111,self.num_layers)
-        cache = {
-            "all_stage": self.num_layers,
-            "k": [None] * self.num_layers,  ### hand-written from the config
-            "v": [None] * self.num_layers,
-            # "xy_pos": None,  ## the y_pos positional encoding differs every step, so it cannot be cached and xy_pos must be rebuilt each time; mostly a coding-style issue — it could be kept consistent with history, but the cost is negligible, so leave it
-            "y_emb": None,  ## only embed the newest samples and concatenate with history
-            # "logits": None,  ### the original already computes only the tail and concatenates, nothing to do
-            # "xy_dec": None,  ### not needed; only the last position is used for logits
-            "first_infer": 1,
-            "stage": 0,
-        }
-        ################### first step ##########################
-        if y is not None:
-            y_emb = self.ar_audio_embedding(y)
-            y_len = y_emb.shape[1]
-            prefix_len = y.shape[1]
-            y_pos = self.ar_audio_position(y_emb)
-            xy_pos = torch.concat([x, y_pos], dim=1)
-            cache["y_emb"] = y_emb
-            ref_free = False
-        else:
-            y_emb = None
-            y_len = 0
-            prefix_len = 0
-            y_pos = None
-            xy_pos = x
-            y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
-            ref_free = True
-
-        x_attn_mask_pad = F.pad(
-            x_attn_mask,
-            (0, y_len),  ### extend the all-zero xx block into all-zero xx plus all-one xy, (x, x+y)
-            value=True,
-        )
-        y_attn_mask = F.pad(  ### pad the upper-right ones of yy with zeros on the xy side, (y, x+y)
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
-            (x_len, 0),
-            value=False,
-        )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            x.device
-        )
-
-        y_list = [None]*y.shape[0]
-        batch_idx_map = list(range(y.shape[0]))
-        idx_list = [None]*y.shape[0]
-        for idx in tqdm(range(1500)):
-
-            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
-            logits = self.ar_predict_layer(
-                xy_dec[:, -1]
-            )  ## no change needed; with the cache there is only one frame by default, same as taking the last frame
-            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
-            if(idx==0):  ### the first step must not emit EOS, otherwise nothing remains
-                logits = logits[:, :-1]  ### strip the probability of the 1024 stop token
-            samples = sample(
-                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
-            )[0]
-            # the semantic_ids generated this step plus the previous y form the new y
-            # print(samples.shape)#[1,1]# the first 1 is the batch size
-            y = torch.concat([y, samples], dim=1)
-
-            # drop finished sequences
-            reserved_idx_of_batch_for_y = None
-            if (self.EOS in torch.argmax(logits, dim=-1)) or \
-                (self.EOS in samples[:, 0]):  ### stop once EOS is generated
-                l = samples[:, 0]==self.EOS
-                removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
-                reserved_idx_of_batch_for_y = torch.where(l==False)[0]
-                # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
-                for i in removed_idx_of_batch_for_y:
-                    batch_index = batch_idx_map[i]
-                    idx_list[batch_index] = idx - 1
-                    y_list[batch_index] = y[i, :-1]
-
-                batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
-
-            # keep only unfinished sequences
-            if reserved_idx_of_batch_for_y is not None:
-                # index = torch.LongTensor(batch_idx_map).to(y.device)
-                y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
-                if cache["y_emb"] is not None:
-                    cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
-                if cache["k"] is not None:
-                    for i in range(self.num_layers):
-                        # k/v are transposed, so the batch dim is 1
-                        cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
-                        cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
-
-
-            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
-                print("use early stop num:", early_stop_num)
-                stop = True
-
-            if not (None in idx_list):
-                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
-                stop = True
-            if stop:
-                # if prompts.shape[1] == y.shape[1]:
-                #     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
-                #     print("bad zero prediction")
-                if y.shape[1]==0:
-                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
-                    print("bad zero prediction")
-                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
-                break
-
-            ####################### update next step ###################################
-            cache["first_infer"] = 0
-            if cache["y_emb"] is not None:
-                y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim=1
-                )
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos[:, -1:]
-            else:
-                y_emb = self.ar_audio_embedding(y[:, -1:])
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos
-            y_len = y_pos.shape[1]
-
-            ### right-most column (wrong)
-            # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
-            # xy_attn_mask[:,-1]=False
-            ### bottom row (correct)
-            xy_attn_mask = torch.zeros(
-                (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
-            )
-
-        if (None in idx_list):
-            for i in range(x.shape[0]):
-                if idx_list[i] is None:
-                    idx_list[i] = 1500-1  ### if EOS was never generated, substitute the max length
-
-        if ref_free:
-            return y_list, [0]*x.shape[0]
-        return y_list, idx_list
AR/models/utils.py CHANGED
@@ -115,17 +115,17 @@ def logits_to_probs(
     top_p: Optional[int] = None,
     repetition_penalty: float = 1.0,
 ):
-    # if previous_tokens is not None:
-    #     previous_tokens = previous_tokens.squeeze()
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.squeeze()
     # print(logits.shape,previous_tokens.shape)
     # pdb.set_trace()
     if previous_tokens is not None and repetition_penalty != 1.0:
         previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=1, index=previous_tokens)
+        score = torch.gather(logits, dim=0, index=previous_tokens)
         score = torch.where(
             score < 0, score * repetition_penalty, score / repetition_penalty
         )
-        logits.scatter_(dim=1, index=previous_tokens, src=score)
+        logits.scatter_(dim=0, index=previous_tokens, src=score)

     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -133,9 +133,9 @@ def logits_to_probs(
         torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
     )
     sorted_indices_to_remove = cum_probs > top_p
-    sorted_indices_to_remove[:, 0] = False  # keep at least one option
+    sorted_indices_to_remove[0] = False  # keep at least one option
     indices_to_remove = sorted_indices_to_remove.scatter(
-        dim=1, index=sorted_indices, src=sorted_indices_to_remove
+        dim=0, index=sorted_indices, src=sorted_indices_to_remove
     )
     logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@@ -143,7 +143,7 @@ def logits_to_probs(

     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v[:, -1].unsqueeze(-1)
+        pivot = v.select(-1, -1).unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)

     probs = torch.nn.functional.softmax(logits, dim=-1)
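For context on the dim=1 to dim=0 changes above: after this commit, logits_to_probs receives a single sequence's 1-D logits (the new call in t2s_model.py passes logits[0]), so the repetition penalty, top-p, and top-k steps all index along dim 0. A hedged, self-contained sketch with toy numbers, mirroring the restored lines:

import torch

logits = torch.tensor([2.0, 0.5, 1.5, -1.0])  # shape [vocab=4], illustrative values
previous_tokens = torch.tensor([0, 2])        # tokens already generated
repetition_penalty, top_p, top_k = 1.35, 0.9, 3

# repetition penalty on a 1-D tensor: gather/scatter along dim=0
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
logits.scatter_(dim=0, index=previous_tokens, src=score)

# top-p: note sorted_indices_to_remove[0] (not [:, 0]) keeps the best option
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False  # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(dim=0, index=sorted_indices, src=sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))

# top-k: v.select(-1, -1) is the k-th largest value, used as the cut-off pivot
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)

probs = torch.softmax(logits, dim=-1)
print(probs)  # renormalized distribution over the surviving tokens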
AR/modules/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/__init__.cpython-310.pyc and b/AR/modules/__pycache__/__init__.cpython-310.pyc differ
 
AR/modules/__pycache__/activation.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/activation.cpython-310.pyc and b/AR/modules/__pycache__/activation.cpython-310.pyc differ
 
AR/modules/__pycache__/embedding.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/embedding.cpython-310.pyc and b/AR/modules/__pycache__/embedding.cpython-310.pyc differ
 
AR/modules/__pycache__/lr_schedulers.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc and b/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc differ
 
AR/modules/__pycache__/optim.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/optim.cpython-310.pyc and b/AR/modules/__pycache__/optim.cpython-310.pyc differ
 
AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc and b/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc differ
 
AR/modules/__pycache__/scaling.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/scaling.cpython-310.pyc and b/AR/modules/__pycache__/scaling.cpython-310.pyc differ
 
AR/modules/__pycache__/transformer.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/transformer.cpython-310.pyc and b/AR/modules/__pycache__/transformer.cpython-310.pyc differ
 
GPT_SoVITS/configs/tts_infer.yaml DELETED
@@ -1,16 +0,0 @@
-custom:
-  bert_base_path: pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: pretrained_models/chinese-hubert-base
-  device: cpu
-  flash_attn_enabled: true
-  is_half: false
-  t2s_weights_path: /content/TTS_OWN/MODELS/22/22.ckpt
-  vits_weights_path: /content/TTS_OWN/MODELS/22/22.pth
-default:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  flash_attn_enabled: true
-  is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
TTS_infer_pack/TTS.py DELETED
@@ -1,848 +0,0 @@
1
- from copy import deepcopy
2
- import math
3
- import os, sys
4
- import random
5
- import traceback
6
- now_dir = os.getcwd()
7
- sys.path.append(now_dir)
8
- import ffmpeg
9
- import os
10
- from typing import Generator, List, Union
11
- import numpy as np
12
- import torch
13
- import torch.nn.functional as F
14
- import yaml
15
- from transformers import AutoModelForMaskedLM, AutoTokenizer
16
- from timeit import default_timer as timer
17
-
18
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
19
- from feature_extractor.cnhubert import CNHubert
20
- from module.models import SynthesizerTrn
21
- import librosa
22
- from time import time as ttime
23
- #from tools.i18n.i18n import I18nAuto
24
- from my_utils import load_audio
25
- from module.mel_processing import spectrogram_torch
26
- from TTS_infer_pack.text_segmentation_method import splits
27
- from TTS_infer_pack.TextPreprocessor import TextPreprocessor
28
- #i18n = I18nAuto()
29
- c1=''
30
-
31
- # configs/tts_infer.yaml
32
- """
33
- default:
34
- device: cpu
35
- is_half: false
36
- bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
37
- cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
38
- t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
39
- vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
40
- flash_attn_enabled: true
41
-
42
- custom:
43
- device: cuda
44
- is_half: true
45
- bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
46
- cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
47
- t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
48
- vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
49
- flash_attn_enabled: true
50
-
51
-
52
- """
53
-
54
- # def set_seed(seed):
55
- # random.seed(seed)
56
- # os.environ['PYTHONHASHSEED'] = str(seed)
57
- # np.random.seed(seed)
58
- # torch.manual_seed(seed)
59
- # torch.cuda.manual_seed(seed)
60
- # torch.cuda.manual_seed_all(seed)
61
- # torch.backends.cudnn.deterministic = True
62
- # torch.backends.cudnn.benchmark = False
63
- # torch.backends.cudnn.enabled = True
64
- # set_seed(1234)
65
-
66
- class TTS_Config:
67
- default_configs={
68
- "device": "cpu",
69
- "is_half": False,
70
- "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
71
- "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
72
- "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
73
- "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
74
- "flash_attn_enabled": True
75
- }
76
- configs:dict = None
77
- def __init__(self, configs: Union[dict, str]=None):
78
-
79
- # 设置默认配置文件路径
80
- configs_base_path:str = "GPT_SoVITS/configs/"
81
- os.makedirs(configs_base_path, exist_ok=True)
82
- self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
83
-
84
- if configs in ["", None]:
85
- if not os.path.exists(self.configs_path):
86
- self.save_configs()
87
- print(f"Create default config file at {self.configs_path}")
88
- configs:dict = {"default": deepcopy(self.default_configs)}
89
-
90
- if isinstance(configs, str):
91
- self.configs_path = configs
92
- configs:dict = self._load_configs(self.configs_path)
93
-
94
- assert isinstance(configs, dict)
95
- default_configs:dict = configs.get("default", None)
96
- if default_configs is not None:
97
- self.default_configs = default_configs
98
-
99
- self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
100
-
101
-
102
- self.device = self.configs.get("device", torch.device("cpu"))
103
- self.is_half = self.configs.get("is_half", False)
104
- self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
105
- self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
106
- self.vits_weights_path = self.configs.get("vits_weights_path", None)
107
- self.bert_base_path = self.configs.get("bert_base_path", None)
108
- self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
109
-
110
-
111
- if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
112
- self.t2s_weights_path = self.default_configs['t2s_weights_path']
113
- print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
114
- if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
115
- self.vits_weights_path = self.default_configs['vits_weights_path']
116
- print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
117
- if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
118
- self.bert_base_path = self.default_configs['bert_base_path']
119
- print(f"fall back to default bert_base_path: {self.bert_base_path}")
120
- if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
121
- self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
122
- print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
123
- self.update_configs()
124
-
125
-
126
- self.max_sec = None
127
- self.hz:int = 50
128
- self.semantic_frame_rate:str = "25hz"
129
- self.segment_size:int = 20480
130
- self.filter_length:int = 2048
131
- self.sampling_rate:int = 32000
132
- self.hop_length:int = 640
133
- self.win_length:int = 2048
134
- self.n_speakers:int = 300
135
-
136
- self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
137
- # print(self)
138
-
139
- def _load_configs(self, configs_path: str)->dict:
140
- with open(configs_path, 'r') as f:
141
- configs = yaml.load(f, Loader=yaml.FullLoader)
142
-
143
- return configs
144
-
145
- def save_configs(self, configs_path:str=None)->None:
146
- configs={
147
- "default":self.default_configs,
148
- }
149
- if self.configs is not None:
150
- configs["custom"] = self.update_configs()
151
-
152
- if configs_path is None:
153
- configs_path = self.configs_path
154
- with open(configs_path, 'w') as f:
155
- yaml.dump(configs, f)
156
-
157
- def update_configs(self):
158
- self.config = {
159
- "device" : str(self.device),
160
- "is_half" : self.is_half,
161
- "t2s_weights_path" : self.t2s_weights_path,
162
- "vits_weights_path" : self.vits_weights_path,
163
- "bert_base_path" : self.bert_base_path,
164
- "cnhuhbert_base_path": self.cnhuhbert_base_path,
165
- "flash_attn_enabled" : self.flash_attn_enabled
166
- }
167
- return self.config
168
-
169
- def __str__(self):
170
- self.configs = self.update_configs()
171
- string = "TTS Config".center(100, '-') + '\n'
172
- for k, v in self.configs.items():
173
- string += f"{str(k).ljust(20)}: {str(v)}\n"
174
- string += "-" * 100 + '\n'
175
- return string
176
-
177
- def __repr__(self):
178
- return self.__str__()
179
-
180
-
181
- class TTS:
182
- def __init__(self, configs: Union[dict, str, TTS_Config]):
183
- if isinstance(configs, TTS_Config):
184
- self.configs = configs
185
- else:
186
- self.configs:TTS_Config = TTS_Config(configs)
187
-
188
- self.t2s_model:Text2SemanticLightningModule = None
189
- self.vits_model:SynthesizerTrn = None
190
- self.bert_tokenizer:AutoTokenizer = None
191
- self.bert_model:AutoModelForMaskedLM = None
192
- self.cnhuhbert_model:CNHubert = None
193
-
194
- self._init_models()
195
-
196
- self.text_preprocessor:TextPreprocessor = \
197
- TextPreprocessor(self.bert_model,
198
- self.bert_tokenizer,
199
- self.configs.device)
200
-
201
-
202
- self.prompt_cache:dict = {
203
- "ref_audio_path":None,
204
- "prompt_semantic":None,
205
- "refer_spepc":None,
206
- "prompt_text":None,
207
- "prompt_lang":None,
208
- "phones":None,
209
- "bert_features":None,
210
- "norm_text":None,
211
- }
212
-
213
-
214
- self.stop_flag:bool = False
215
- self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
216
-
217
- def _init_models(self,):
218
- self.init_t2s_weights(self.configs.t2s_weights_path)
219
- self.init_vits_weights(self.configs.vits_weights_path)
220
- self.init_bert_weights(self.configs.bert_base_path)
221
- self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
222
- # self.enable_half_precision(self.configs.is_half)
223
-
224
-
225
-
226
- def init_cnhuhbert_weights(self, base_path: str):
227
- print(f"Loading CNHuBERT weights from {base_path}")
228
- self.cnhuhbert_model = CNHubert(base_path)
229
- self.cnhuhbert_model=self.cnhuhbert_model.eval()
230
- self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
231
- if self.configs.is_half:
232
- self.cnhuhbert_model = self.cnhuhbert_model.half()
233
-
234
-
235
-
236
- def init_bert_weights(self, base_path: str):
237
- print(f"Loading BERT weights from {base_path}")
238
- self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
239
- self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
240
- self.bert_model=self.bert_model.eval()
241
- self.bert_model = self.bert_model.to(self.configs.device)
242
- if self.configs.is_half:
243
- self.bert_model = self.bert_model.half()
244
-
245
-
246
-
247
- def init_vits_weights(self, weights_path: str):
248
-
249
- print(f"Loading VITS weights from {weights_path}")
250
- self.configs.vits_weights_path = weights_path
251
- self.configs.save_configs()
252
- dict_s2 = torch.load(weights_path, map_location=self.configs.device)
253
- hps = dict_s2["config"]
254
- self.configs.filter_length = hps["data"]["filter_length"]
255
- self.configs.segment_size = hps["train"]["segment_size"]
256
- self.configs.sampling_rate = hps["data"]["sampling_rate"]
257
- self.configs.hop_length = hps["data"]["hop_length"]
258
- self.configs.win_length = hps["data"]["win_length"]
259
- self.configs.n_speakers = hps["data"]["n_speakers"]
260
- self.configs.semantic_frame_rate = "25hz"
261
- kwargs = hps["model"]
262
- vits_model = SynthesizerTrn(
263
- self.configs.filter_length // 2 + 1,
264
- self.configs.segment_size // self.configs.hop_length,
265
- n_speakers=self.configs.n_speakers,
266
- **kwargs
267
- )
268
- # if ("pretrained" not in weights_path):
269
- if hasattr(vits_model, "enc_q"):
270
- del vits_model.enc_q
271
-
272
- vits_model = vits_model.to(self.configs.device)
273
- vits_model = vits_model.eval()
274
- vits_model.load_state_dict(dict_s2["weight"], strict=False)
275
- self.vits_model = vits_model
276
- if self.configs.is_half:
277
- self.vits_model = self.vits_model.half()
278
-
279
-
280
- def init_t2s_weights(self, weights_path: str):
281
- print(f"Loading Text2Semantic weights from {weights_path}")
282
- self.configs.t2s_weights_path = weights_path
283
- self.configs.save_configs()
284
- self.configs.hz = 50
285
- dict_s1 = torch.load(weights_path, map_location=self.configs.device)
286
- config = dict_s1["config"]
287
- self.configs.max_sec = config["data"]["max_sec"]
288
- t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
289
- flash_attn_enabled=self.configs.flash_attn_enabled)
290
- t2s_model.load_state_dict(dict_s1["weight"])
291
- t2s_model = t2s_model.to(self.configs.device)
292
- t2s_model = t2s_model.eval()
293
- self.t2s_model = t2s_model
294
- if self.configs.is_half:
295
- self.t2s_model = self.t2s_model.half()
296
-
297
-     def enable_half_precision(self, enable: bool = True):
-         '''
-         Enable or disable half precision for all TTS models.
-         Args:
-             enable: bool, whether to enable half precision.
-         '''
-         if self.configs.device == "cpu" and enable:
-             print("Half precision is not supported on CPU.")
-             return
-
-         self.configs.is_half = enable
-         self.precison = torch.float16 if enable else torch.float32
-         self.configs.save_configs()
-         if enable:
-             if self.t2s_model is not None:
-                 self.t2s_model = self.t2s_model.half()
-             if self.vits_model is not None:
-                 self.vits_model = self.vits_model.half()
-             if self.bert_model is not None:
-                 self.bert_model = self.bert_model.half()
-             if self.cnhuhbert_model is not None:
-                 self.cnhuhbert_model = self.cnhuhbert_model.half()
-         else:
-             if self.t2s_model is not None:
-                 self.t2s_model = self.t2s_model.float()
-             if self.vits_model is not None:
-                 self.vits_model = self.vits_model.float()
-             if self.bert_model is not None:
-                 self.bert_model = self.bert_model.float()
-             if self.cnhuhbert_model is not None:
-                 self.cnhuhbert_model = self.cnhuhbert_model.float()
-
-     def set_device(self, device: torch.device):
-         '''
-         Set the device for all models.
-         Args:
-             device: torch.device, the device to use for all models.
-         '''
-         self.configs.device = device
-         self.configs.save_configs()
-         if self.t2s_model is not None:
-             self.t2s_model = self.t2s_model.to(device)
-         if self.vits_model is not None:
-             self.vits_model = self.vits_model.to(device)
-         if self.bert_model is not None:
-             self.bert_model = self.bert_model.to(device)
-         if self.cnhuhbert_model is not None:
-             self.cnhuhbert_model = self.cnhuhbert_model.to(device)
-
-     def set_ref_audio(self, ref_audio_path: str):
-         '''
-         Set the reference audio for the TTS model,
-         filling the prompt_semantic and refer_spepc caches.
-         Args:
-             ref_audio_path: str, the path of the reference audio.
-         '''
-         self._set_prompt_semantic(ref_audio_path)
-         self._set_ref_spepc(ref_audio_path)
-
-     def _set_ref_spepc(self, ref_audio_path):
-         audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
-         audio = torch.FloatTensor(audio)
-         audio_norm = audio
-         audio_norm = audio_norm.unsqueeze(0)
-         spec = spectrogram_torch(
-             audio_norm,
-             self.configs.filter_length,
-             self.configs.sampling_rate,
-             self.configs.hop_length,
-             self.configs.win_length,
-             center=False,
-         )
-         spec = spec.to(self.configs.device)
-         if self.configs.is_half:
-             spec = spec.half()
-         # self.refer_spepc = spec
-         self.prompt_cache["refer_spepc"] = spec
-
-
-     def _set_prompt_semantic(self, ref_wav_path: str):
-         zero_wav = np.zeros(
-             int(self.configs.sampling_rate * 0.3),
-             dtype=np.float16 if self.configs.is_half else np.float32,
-         )
-         with torch.no_grad():
-             wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-             if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
-                 raise OSError("参考音频在3~10秒范围外,请更换!")  # the reference audio must be 3~10 seconds long
-             wav16k = torch.from_numpy(wav16k)
-             zero_wav_torch = torch.from_numpy(zero_wav)
-             wav16k = wav16k.to(self.configs.device)
-             zero_wav_torch = zero_wav_torch.to(self.configs.device)
-             if self.configs.is_half:
-                 wav16k = wav16k.half()
-                 zero_wav_torch = zero_wav_torch.half()
-
-             wav16k = torch.cat([wav16k, zero_wav_torch])
-             hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))[
-                 "last_hidden_state"
-             ].transpose(
-                 1, 2
-             )  # .float()
-             codes = self.vits_model.extract_latent(hubert_feature)
-
-             prompt_semantic = codes[0, 0].to(self.configs.device)
-             self.prompt_cache["prompt_semantic"] = prompt_semantic
-
-     def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
-         seq = sequences[0]
-         ndim = seq.dim()
-         if axis < 0:
-             axis += ndim
-         dtype: torch.dtype = seq.dtype
-         pad_value = torch.tensor(pad_value, dtype=dtype)
-         seq_lengths = [seq.shape[axis] for seq in sequences]
-         if max_length is None:
-             max_length = max(seq_lengths)
-         else:
-             max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
-
-         padded_sequences = []
-         for seq, length in zip(sequences, seq_lengths):
-             padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
-             padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value)
-             padded_sequences.append(padded_seq)
-         batch = torch.stack(padded_sequences)
-         return batch
-
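batch_sequences left-aligns variable-length 1-D token sequences and pads them with pad_value; for this 1-D case torch.nn.utils.rnn.pad_sequence produces the same tensor, which makes the behavior easy to check:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    print(pad_sequence(seqs, batch_first=True, padding_value=0))
    # tensor([[1, 2, 3],
    #         [4, 5, 0]])  <- same as batch_sequences(seqs, axis=0, pad_value=0)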
-     def to_batch(self, data: list, prompt_data: dict = None, batch_size: int = 5, threshold: float = 0.75, split_bucket: bool = True):
-
-         _data: list = []
-         index_and_len_list = []
-         for idx, item in enumerate(data):
-             norm_text_len = len(item["norm_text"])
-             index_and_len_list.append([idx, norm_text_len])
-
-         batch_index_list = []
-         if split_bucket:
-             index_and_len_list.sort(key=lambda x: x[1])
-             index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
-
-             batch_index_list_len = 0
-             pos = 0
-             while pos < index_and_len_list.shape[0]:
-                 # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
-                 pos_end = min(pos + batch_size, index_and_len_list.shape[0])
-                 while pos < pos_end:
-                     batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
-                     score = batch[(pos_end - pos) // 2] / (batch.mean() + 1e-8)
-                     if (score >= threshold) or (pos_end - pos == 1):
-                         batch_index = index_and_len_list[pos:pos_end, 0].tolist()
-                         batch_index_list_len += len(batch_index)
-                         batch_index_list.append(batch_index)
-                         pos = pos_end
-                         break
-                     pos_end = pos_end - 1
-
-             assert batch_index_list_len == len(data)
-
-         else:
-             for i in range(len(data)):
-                 if i % batch_size == 0:
-                     batch_index_list.append([])
-                 batch_index_list[-1].append(i)
-
-
-         for batch_idx, index_list in enumerate(batch_index_list):
-             item_list = [data[idx] for idx in index_list]
-             phones_list = []
-             phones_len_list = []
-             # bert_features_list = []
-             all_phones_list = []
-             all_phones_len_list = []
-             all_bert_features_list = []
-             norm_text_batch = []
-             bert_max_len = 0
-             phones_max_len = 0
-             for item in item_list:
-                 if prompt_data is not None:
-                     all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
-                         .to(dtype=self.precison)
-                     all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"])
-                     phones = torch.LongTensor(item["phones"])
-                     # norm_text = prompt_data["norm_text"]+item["norm_text"]
-                 else:
-                     all_bert_features = item["bert_features"]\
-                         .to(dtype=self.precison)
-                     phones = torch.LongTensor(item["phones"])
-                     all_phones = phones
-                     # norm_text = item["norm_text"]
-
-                 bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
-                 phones_max_len = max(phones_max_len, phones.shape[-1])
-
-                 phones_list.append(phones)
-                 phones_len_list.append(phones.shape[-1])
-                 all_phones_list.append(all_phones)
-                 all_phones_len_list.append(all_phones.shape[-1])
-                 all_bert_features_list.append(all_bert_features)
-                 norm_text_batch.append(item["norm_text"])
-
-             phones_batch = phones_list
-             max_len = max(bert_max_len, phones_max_len)
-             # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
-             all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
-             # all_bert_features_batch = all_bert_features_list
-             all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
-             for idx, item in enumerate(all_bert_features_list):
-                 all_bert_features_batch[idx, :, : item.shape[-1]] = item
-
-             batch = {
-                 "phones": phones_batch,
-                 "phones_len": torch.LongTensor(phones_len_list),
-                 "all_phones": all_phones_batch,
-                 "all_phones_len": torch.LongTensor(all_phones_len_list),
-                 "all_bert_features": all_bert_features_batch,
-                 "norm_text": norm_text_batch
-             }
-             _data.append(batch)
-
-         return _data, batch_index_list
-
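The bucketing loop above shrinks each candidate batch until the median sentence length divided by the mean reaches the threshold, so one stray long sentence cannot dominate a batch of short ones. A standalone restatement of just that rule (my own sketch, not code from the deleted file):

    import numpy as np

    def bucket(lengths, batch_size=5, threshold=0.75):
        order = np.argsort(lengths)
        lens = np.asarray(lengths, dtype=np.float32)[order]
        batches, pos = [], 0
        while pos < len(order):
            end = min(pos + batch_size, len(order))
            while pos < end:
                window = lens[pos:end]
                # median/mean near 1 means the window is homogeneous
                if window[(end - pos) // 2] / (window.mean() + 1e-8) >= threshold or end - pos == 1:
                    batches.append(order[pos:end].tolist())
                    pos = end
                    break
                end -= 1
        return batches

    print(bucket([3, 4, 5, 6, 200]))  # [[0, 1, 2, 3], [4]]: the outlier gets its own batch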
-     def recovery_order(self, data: list, batch_index_list: list) -> list:
-         '''
-         Recover the original order of the audio according to the batch_index_list.
-
-         Args:
-             data (List[list(np.ndarray)]): the out-of-order audio.
-             batch_index_list (List[list[int]]): the batch index list.
-
-         Returns:
-             list (List[np.ndarray]): the data in the original order.
-         '''
-         length = len(sum(batch_index_list, []))
-         _data = [None] * length
-         for i, index_list in enumerate(batch_index_list):
-             for j, index in enumerate(index_list):
-                 _data[index] = data[i][j]
-         return _data
-
-     def stop(self,):
-         '''
-         Stop the inference process.
-         '''
-         self.stop_flag = True
-
-
-     def run(self, inputs: dict):
-         """
-         Text-to-speech inference.
-
-         Args:
-             inputs (dict):
-                 {
-                     "text": "",                  # str. text to be synthesized
-                     "text_lang": "",             # str. language of the text to be synthesized
-                     "ref_audio_path": "",        # str. reference audio path
-                     "prompt_text": "",           # str. prompt text for the reference audio
-                     "prompt_lang": "",           # str. language of the prompt text for the reference audio
-                     "top_k": 5,                  # int. top k sampling
-                     "top_p": 1,                  # float. top p sampling
-                     "temperature": 1,            # float. temperature for sampling
-                     "text_split_method": "",     # str. text split method, see text_segmentation_method.py for details.
-                     "batch_size": 1,             # int. batch size for inference
-                     "batch_threshold": 0.75,     # float. threshold for batch splitting.
-                     "split_bucket": True,        # bool. whether to split the batch into multiple buckets.
-                     "return_fragment": False,    # bool. return the audio fragments one by one as they are generated.
-                     "speed_factor": 1.0,         # float. control the speed of the synthesized audio.
-                 }
-         Returns:
-             tuple[int, np.ndarray]: sampling rate and audio data.
-         """
-         global c1
-         c1 = timer()
-         ########## variables initialization ###########
-         self.stop_flag: bool = False
-         text: str = inputs.get("text", "")
-         text_lang: str = inputs.get("text_lang", "")
-         ref_audio_path: str = inputs.get("ref_audio_path", "")
-         prompt_text: str = inputs.get("prompt_text", "")
-         prompt_lang: str = inputs.get("prompt_lang", "")
-         top_k: int = inputs.get("top_k", 5)
-         top_p: float = inputs.get("top_p", 1)
-         temperature: float = inputs.get("temperature", 1)
-         text_split_method: str = inputs.get("text_split_method", "")
-         batch_size = inputs.get("batch_size", 1)
-         batch_threshold = inputs.get("batch_threshold", 0.75)
-         speed_factor = inputs.get("speed_factor", 1.0)
-         split_bucket = inputs.get("split_bucket", True)
-         volume = inputs.get("volume", 1.0)
-         return_fragment = inputs.get("return_fragment", False)
-
-         if return_fragment:
-             split_bucket = False
-             print("分段返回模式已开启")  # fragment streaming mode enabled
-             if split_bucket:
-                 split_bucket = False
-                 print("分段返回模式不支持分桶处理,已自动关闭分桶处理")  # fragment mode does not support bucketing; bucketing disabled
-
-         if split_bucket:
-             print("分桶处理模式已开启")  # bucketing mode enabled
-
-
-         no_prompt_text = False
-         if prompt_text in [None, ""]:
-             no_prompt_text = True
-
-         assert text_lang in self.configs.langauges
-         if not no_prompt_text:
-             assert prompt_lang in self.configs.langauges
-
-         if ref_audio_path in [None, ""] and \
-             ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spepc"] is None)):
-             raise ValueError("ref_audio_path cannot be empty when the reference audio has not been set via set_ref_audio()")
-
-
-         ###### set reference audio and preprocess the prompt text ########
-         t0 = ttime()
-         if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
-             self.set_ref_audio(ref_audio_path)
-
-         if not no_prompt_text:
-             prompt_text = prompt_text.strip("\n")
-             if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
-             print("实际输入的参考文本:", prompt_text)  # reference text actually used
-             if self.prompt_cache["prompt_text"] != prompt_text:
-                 self.prompt_cache["prompt_text"] = prompt_text
-                 self.prompt_cache["prompt_lang"] = prompt_lang
-                 phones, bert_features, norm_text = \
-                     self.text_preprocessor.segment_and_extract_feature_for_text(
-                         prompt_text,
-                         prompt_lang)
-                 self.prompt_cache["phones"] = phones
-                 self.prompt_cache["bert_features"] = bert_features
-                 self.prompt_cache["norm_text"] = norm_text
-
-
-         ###### text preprocessing ########
-         data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
-         if len(data) == 0:
-             yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
-                                                        dtype=np.int16)
-             return
-
-         t1 = ttime()
-         data, batch_index_list = self.to_batch(data,
-                                                prompt_data=self.prompt_cache if not no_prompt_text else None,
-                                                batch_size=batch_size,
-                                                threshold=batch_threshold,
-                                                split_bucket=split_bucket
-                                                )
-         t2 = ttime()
-         try:
-             print("############ 推理 ############")  # inference
-             ###### inference ######
-             t_34 = 0.0
-             t_45 = 0.0
-             audio = []
-             for item in data:
-                 t3 = ttime()
-                 batch_phones = item["phones"]
-                 batch_phones_len = item["phones_len"]
-                 all_phoneme_ids = item["all_phones"]
-                 all_phoneme_lens = item["all_phones_len"]
-                 all_bert_features = item["all_bert_features"]
-                 norm_text = item["norm_text"]
-
-                 # batch_phones = batch_phones.to(self.configs.device)
-                 batch_phones_len = batch_phones_len.to(self.configs.device)
-                 all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
-                 all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
-                 all_bert_features = all_bert_features.to(self.configs.device)
-                 if self.configs.is_half:
-                     all_bert_features = all_bert_features.half()
-
-                 print("前端处理后的文本(每句):", norm_text)  # normalized text per sentence
-                 if no_prompt_text:
-                     prompt = None
-                 else:
-                     prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
-
-                 with torch.no_grad():
-                     pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
-                         all_phoneme_ids,
-                         all_phoneme_lens,
-                         prompt,
-                         all_bert_features,
-                         # prompt_phone_len=ph_offset,
-                         top_k=top_k,
-                         top_p=top_p,
-                         temperature=temperature,
-                         early_stop_num=self.configs.hz * self.configs.max_sec,
-                     )
-                 t4 = ttime()
-                 t_34 += t4 - t3
-
-                 refer_audio_spepc: torch.Tensor = self.prompt_cache["refer_spepc"]\
-                     .to(dtype=self.precison, device=self.configs.device)
-
-                 batch_audio_fragment = []
-
-                 # ## VITS batched inference, method 1
-                 # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-                 # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
-                 # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
-                 # max_len = 0
-                 # for i in range(0, len(batch_phones)):
-                 #     max_len = max(max_len, batch_phones[i].shape[-1])
-                 # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
-                 # batch_phones = batch_phones.to(self.configs.device)
-                 # batch_audio_fragment = (self.vits_model.batched_decode(
-                 #         pred_semantic, pred_semantic_len, batch_phones, batch_phones_len, refer_audio_spepc
-                 #     ))
-
-                 # ## VITS batched inference, method 2: concatenate, decode once, then slice
-                 pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-                 upsample_rate = math.prod(self.vits_model.upsample_rates)
-                 audio_frag_idx = [pred_semantic_list[i].shape[0] * 2 * upsample_rate for i in range(0, len(pred_semantic_list))]
-                 audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
-                 all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
-                 _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                 _batch_audio_fragment = (self.vits_model.decode(
-                     all_pred_semantic, _batch_phones, refer_audio_spepc
-                 ).detach()[0, 0, :])
-                 audio_frag_end_idx.insert(0, 0)
-                 batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i - 1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
-
-
-                 # ## VITS sequential inference
-                 # for i, idx in enumerate(idx_list):
-                 #     phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-                 #     _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # mq needs one extra unsqueeze
-                 #     audio_fragment = (self.vits_model.decode(
-                 #             _pred_semantic, phones, refer_audio_spepc
-                 #         ).detach()[0, 0, :])
-                 #     batch_audio_fragment.append(
-                 #         audio_fragment
-                 #     )  # try reconstructing without the prompt part
-
-                 t5 = ttime()
-                 t_45 += t5 - t4
-                 if return_fragment:
-                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
-                     yield self.audio_postprocess([batch_audio_fragment],
-                                                  self.configs.sampling_rate,
-                                                  batch_index_list,
-                                                  speed_factor,
-                                                  split_bucket, volume)
-                 else:
-                     audio.append(batch_audio_fragment)
-
-                 if self.stop_flag:
-                     yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
-                                                                dtype=np.int16)
-                     return
-
-             if not return_fragment:
-                 print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
-                 yield self.audio_postprocess(audio,
-                                              self.configs.sampling_rate,
-                                              batch_index_list,
-                                              speed_factor,
-                                              split_bucket, volume)
-         except Exception as e:
-             traceback.print_exc()
-             # Must yield an empty audio clip, otherwise GPU memory is not released.
-             yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                        dtype=np.int16)
-             # Reset the models, otherwise GPU memory is not fully released.
-             del self.t2s_model
-             del self.vits_model
-             self.t2s_model = None
-             self.vits_model = None
-             self.init_t2s_weights(self.configs.t2s_weights_path)
-             self.init_vits_weights(self.configs.vits_weights_path)
-         finally:
-             self.empty_cache()
-
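run() is a generator, so the old app.py drained it with next(); a minimal call sketch mirroring the docstring (texts and the reference path are placeholders):

    inputs = {
        "text": "今天天气真不错。",
        "text_lang": "zh",
        "ref_audio_path": "ref.wav",        # a 3~10 s reference clip
        "prompt_text": "参考音频的文本。",
        "prompt_lang": "zh",
        "top_k": 5,
        "top_p": 1,
        "temperature": 1,
        "text_split_method": "cut1",
        "batch_size": 20,
        "split_bucket": True,
        "return_fragment": False,
    }
    sr, wav = next(tts_pipline.run(inputs))  # tts_pipline as constructed in app.py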
-     def empty_cache(self):
-         try:
-             if str(self.configs.device) == "cuda":
-                 torch.cuda.empty_cache()
-             elif str(self.configs.device) == "mps":
-                 torch.mps.empty_cache()
-         except:
-             pass
-
-     def audio_postprocess(self,
-                           audio: List[torch.Tensor],
-                           sr: int,
-                           batch_index_list: list = None,
-                           speed_factor: float = 1.0,
-                           split_bucket: bool = True,
-                           volume: float = 1.0) -> tuple[int, np.ndarray]:
-         zero_wav = torch.zeros(
-             int(self.configs.sampling_rate * 0.3),
-             dtype=self.precison,
-             device=self.configs.device
-         )
-
-         for i, batch in enumerate(audio):
-             for j, audio_fragment in enumerate(batch):
-                 max_audio = torch.abs(audio_fragment).max()  # simple guard against 16-bit clipping
-                 if max_audio > 1: audio_fragment /= max_audio
-                 audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
-                 audio_fragment = audio_fragment * volume
-                 audio[i][j] = audio_fragment.cpu().numpy()
-
-
-         if split_bucket:
-             audio = self.recovery_order(audio, batch_index_list)
-         else:
-             # audio = [item for batch in audio for item in batch]
-             audio = sum(audio, [])
-
-
-         audio = np.concatenate(audio, 0)
-         audio = (audio * 32768).astype(np.int16)
-
-         try:
-             if speed_factor != 1.0:
-                 audio = speed_change(audio, speed=speed_factor, sr=int(sr))
-         except Exception as e:
-             print(f"Failed to change speed of audio: \n{e}")
-         c2 = timer()
-         print(f'🆗TTS COMPLETE,{round(c2-c1,4)}s')
-         return sr, audio
-
-
-
-
- def speed_change(input_audio: np.ndarray, speed: float, sr: int):
-     # convert the NumPy array into a raw PCM stream
-     raw_audio = input_audio.astype(np.int16).tobytes()
-
-     # set up the ffmpeg input stream
-     input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
-
-     # apply the tempo change
-     output_stream = input_stream.filter('atempo', speed)
-
-     # pipe the output
-     out, _ = (
-         output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
-         .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
-     )
-
-     # decode the piped output back into a NumPy array
-     processed_audio = np.frombuffer(out, np.int16)
-
-     return processed_audio
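speed_change pipes raw PCM through ffmpeg's atempo filter, which changes tempo without shifting pitch. Note that a single atempo instance only accepts factors in a limited range (roughly 0.5~2.0 on older ffmpeg builds, an assumption worth checking), so extreme speed factors may need chained filters:

    import numpy as np
    wav = np.zeros(32000, dtype=np.int16)             # stand-in for 1 s of 32 kHz audio
    faster = speed_change(wav, speed=1.25, sr=32000)
    print(len(faster))                                # about 32000 / 1.25 = 25600 samples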
 
 
TTS_infer_pack/TextPreprocessor.py DELETED
@@ -1,209 +0,0 @@
-
- import os, sys
-
- from tqdm import tqdm
- now_dir = os.getcwd()
- sys.path.append(now_dir)
-
- import re
- import torch
- import LangSegment
- from typing import Dict, List, Tuple
- from text.cleaner import clean_text
- from text import cleaned_text_to_sequence
- from transformers import AutoModelForMaskedLM, AutoTokenizer
- from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
-
- # from tools.i18n.i18n import I18nAuto
- # i18n = I18nAuto()
-
- def get_first(text: str) -> str:
-     pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
-     text = re.split(pattern, text)[0].strip()
-     return text
-
- def merge_short_text_in_array(texts: list, threshold: int) -> list:
-     if (len(texts)) < 2:
-         return texts
-     result = []
-     text = ""
-     for ele in texts:
-         text += ele
-         if len(text) >= threshold:
-             result.append(text)
-             text = ""
-     if (len(text) > 0):
-         if len(result) == 0:
-             result.append(text)
-         else:
-             result[len(result) - 1] += text
-     return result
-
-
-
-
-
-
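merge_short_text_in_array greedily concatenates consecutive segments until each chunk reaches the threshold length, and folds any leftover into the last chunk:

    print(merge_short_text_in_array(["你好。", "早。", "今天天气不错。", "嗯。"], 5))
    # ['你好。早。', '今天天气不错。嗯。']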
- class TextPreprocessor:
-     def __init__(self, bert_model: AutoModelForMaskedLM,
-                  tokenizer: AutoTokenizer, device: torch.device):
-         self.bert_model = bert_model
-         self.tokenizer = tokenizer
-         self.device = device
-
-     def preprocess(self, text: str, lang: str, text_split_method: str) -> List[Dict]:
-         print("############ 切分文本 ############")  # split the text
-         texts = self.pre_seg_text(text, lang, text_split_method)
-         result = []
-         print("############ 提取文本Bert特征 ############")  # extract BERT features
-         for text in tqdm(texts):
-             phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
-             if phones is None:
-                 continue
-             res = {
-                 "phones": phones,
-                 "bert_features": bert_features,
-                 "norm_text": norm_text,
-             }
-             result.append(res)
-         return result
-
-     def pre_seg_text(self, text: str, lang: str, text_split_method: str):
-         text = text.strip("\n")
-         if (text[0] not in splits and len(get_first(text)) < 4):
-             text = "。" + text if lang != "en" else "." + text
-         print("实际输入的目标文本:")  # target text actually used
-         print(text)
-
-         seg_method = get_seg_method(text_split_method)
-         text = seg_method(text)
-
-         while "\n\n" in text:
-             text = text.replace("\n\n", "\n")
-
-         _texts = text.split("\n")
-         _texts = merge_short_text_in_array(_texts, 5)
-         texts = []
-
-
-         for text in _texts:
-             # skip empty lines in the input so they do not raise errors
-             if (len(text.strip()) == 0):
-                 continue
-             if (text[-1] not in splits): text += "。" if lang != "en" else "."
-
-             # split overly long sentences so BERT does not fail on them
-             if (len(text) > 510):
-                 texts.extend(split_big_text(text))
-             else:
-                 texts.append(text)
-
-         print("实际输入的目标文本(切句后):")  # target text after sentence splitting
-         print(texts)
-         return texts
-
-     def segment_and_extract_feature_for_text(self, texts: str, language: str) -> Tuple[list, torch.Tensor, str]:
-         textlist, langlist = self.seg_text(texts, language)
-         if len(textlist) == 0:
-             return None, None, None
-
-         phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
-         return phones, bert_features, norm_text
-
-
-     def seg_text(self, text: str, language: str) -> Tuple[list, list]:
-
-         textlist = []
-         langlist = []
-         if language in ["auto", "zh", "ja"]:
-             LangSegment.setfilters(["zh", "ja", "en", "ko"])
-             for tmp in LangSegment.getTexts(text):
-                 if tmp["text"] == "":
-                     continue
-                 if tmp["lang"] == "ko":
-                     langlist.append("zh")
-                 elif tmp["lang"] == "en":
-                     langlist.append("en")
-                 else:
-                     # Chinese and Japanese kanji cannot be told apart, so trust the user's language choice
-                     langlist.append(language if language != "auto" else tmp["lang"])
-                 textlist.append(tmp["text"])
-         elif language == "en":
-             LangSegment.setfilters(["en"])
-             formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
-             while "  " in formattext:
-                 formattext = formattext.replace("  ", " ")
-             if formattext != "":
-                 textlist.append(formattext)
-                 langlist.append("en")
-
-         elif language in ["all_zh", "all_ja"]:
-
-             formattext = text
-             while "  " in formattext:
-                 formattext = formattext.replace("  ", " ")
-             language = language.replace("all_", "")
-             if text == "":
-                 return [], []
-             textlist.append(formattext)
-             langlist.append(language)
-
-         else:
-             raise ValueError(f"language {language} not supported")
-
-         return textlist, langlist
-
-
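seg_text delegates mixed-language tagging to LangSegment; the exact segmentation depends on the installed LangSegment version, but the output has this shape (illustrative values only):

    tp = TextPreprocessor(bert_model, tokenizer, device)
    textlist, langlist = tp.seg_text("今天我们测试GPT-SoVITS的效果。", "zh")
    # e.g. textlist == ['今天我们测试', 'GPT-SoVITS', '的效果。']
    #      langlist == ['zh', 'en', 'zh']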
-     def extract_bert_feature(self, textlist: list, langlist: list):
-         phones_list = []
-         bert_feature_list = []
-         norm_text_list = []
-         for i in range(len(textlist)):
-             lang = langlist[i]
-             phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang)
-             _bert_feature = self.get_bert_inf(phones, word2ph, norm_text, lang)
-             # phones_list.append(phones)
-             phones_list.extend(phones)
-             norm_text_list.append(norm_text)
-             bert_feature_list.append(_bert_feature)
-         bert_feature = torch.cat(bert_feature_list, dim=1)
-         # phones = sum(phones_list, [])
-         norm_text = ''.join(norm_text_list)
-         return phones_list, bert_feature, norm_text
-
-
-     def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
-         with torch.no_grad():
-             inputs = self.tokenizer(text, return_tensors="pt")
-             for i in inputs:
-                 inputs[i] = inputs[i].to(self.device)
-             res = self.bert_model(**inputs, output_hidden_states=True)
-             res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
-         assert len(word2ph) == len(text)
-         phone_level_feature = []
-         for i in range(len(word2ph)):
-             repeat_feature = res[i].repeat(word2ph[i], 1)
-             phone_level_feature.append(repeat_feature)
-         phone_level_feature = torch.cat(phone_level_feature, dim=0)
-         return phone_level_feature.T
-
-     def clean_text_inf(self, text: str, language: str):
-         phones, word2ph, norm_text = clean_text(text, language)
-         phones = cleaned_text_to_sequence(phones)
-         return phones, word2ph, norm_text
-
-     def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
-         language = language.replace("all_", "")
-         if language == "zh":
-             feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
-         else:
-             feature = torch.zeros(
-                 (1024, len(phones)),
-                 dtype=torch.float32,
-             ).to(self.device)
-
-         return feature
-
 
 
TTS_infer_pack/__init__.py DELETED
@@ -1 +0,0 @@
- from . import TTS, text_segmentation_method
 
 
TTS_infer_pack/__pycache__/TTS.cpython-310.pyc DELETED
Binary file (21.7 kB)
 
TTS_infer_pack/__pycache__/TextPreprocessor.cpython-310.pyc DELETED
Binary file (6.15 kB)
 
TTS_infer_pack/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (199 Bytes)
 
TTS_infer_pack/__pycache__/text_segmentation_method.cpython-310.pyc DELETED
Binary file (3.67 kB)
 
TTS_infer_pack/text_segmentation_method.py DELETED
@@ -1,152 +0,0 @@
-
-
-
-
- import re
- from typing import Callable
- # from tools.i18n.i18n import I18nAuto
-
- # i18n = I18nAuto()
-
- METHODS = dict()
-
- def get_method(name: str) -> Callable:
-     method = METHODS.get(name, None)
-     if method is None:
-         raise ValueError(f"Method {name} not found")
-     return method
-
- def register_method(name):
-     def decorator(func):
-         METHODS[name] = func
-         return func
-     return decorator
-
- splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
-
- def split_big_text(text, max_len=510):
-     # full-width and half-width punctuation marks
-     punctuation = "".join(splits)
-
-     # split the text at punctuation
-     segments = re.split('([' + punctuation + '])', text)
-
-     # initialize the result list and the current chunk
-     result = []
-     current_segment = ''
-
-     for segment in segments:
-         # if adding the next piece would exceed max_len, flush the current chunk
-         if len(current_segment + segment) > max_len:
-             result.append(current_segment)
-             current_segment = segment
-         else:
-             current_segment += segment
-
-     # append the final chunk
-     if current_segment:
-         result.append(current_segment)
-
-     return result
-
-
-
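register_method keeps a plain name-to-function registry, which is how cut0 through cut5 below are looked up by get_method. Registering an extra splitter is one decorator away (a hypothetical example, not part of the deleted file):

    @register_method("cut_lines")
    def cut_lines(inp):
        # treat every non-empty input line as its own segment
        return "\n".join(line for line in inp.splitlines() if line.strip())

    seg = get_method("cut_lines")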
- def split(todo_text):
-     todo_text = todo_text.replace("……", "。").replace("——", ",")
-     if todo_text[-1] not in splits:
-         todo_text += "。"
-     i_split_head = i_split_tail = 0
-     len_text = len(todo_text)
-     todo_texts = []
-     while 1:
-         if i_split_head >= len_text:
-             break  # the text is guaranteed to end with punctuation, so the last segment was already appended
-         if todo_text[i_split_head] in splits:
-             i_split_head += 1
-             todo_texts.append(todo_text[i_split_tail:i_split_head])
-             i_split_tail = i_split_head
-         else:
-             i_split_head += 1
-     return todo_texts
-
-
- # no splitting
- @register_method("cut0")
- def cut0(inp):
-     return inp
-
-
- # split every 4 sentences
- @register_method("cut1")
- def cut1(inp):
-     inp = inp.strip("\n")
-     inps = split(inp)
-     split_idx = list(range(0, len(inps), 4))
-     split_idx[-1] = None
-     if len(split_idx) > 1:
-         opts = []
-         for idx in range(len(split_idx) - 1):
-             opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
-     else:
-         opts = [inp]
-     return "\n".join(opts)
-
- # split every 50 characters
- @register_method("cut2")
- def cut2(inp):
-     inp = inp.strip("\n")
-     inps = split(inp)
-     if len(inps) < 2:
-         return inp
-     opts = []
-     summ = 0
-     tmp_str = ""
-     for i in range(len(inps)):
-         summ += len(inps[i])
-         tmp_str += inps[i]
-         if summ > 50:
-             summ = 0
-             opts.append(tmp_str)
-             tmp_str = ""
-     if tmp_str != "":
-         opts.append(tmp_str)
-     # print(opts)
-     if len(opts) > 1 and len(opts[-1]) < 50:  # if the last chunk is too short, merge it into the previous one
-         opts[-2] = opts[-2] + opts[-1]
-         opts = opts[:-1]
-     return "\n".join(opts)
-
- # split at Chinese/Japanese periods (。)
- @register_method("cut3")
- def cut3(inp):
-     inp = inp.strip("\n")
-     return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
-
- # split at English periods (.)
- @register_method("cut4")
- def cut4(inp):
-     inp = inp.strip("\n")
-     return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
-
- # split at any punctuation mark
- # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
- @register_method("cut5")
- def cut5(inp):
-     # if not re.search(r'[^\w\s]', inp[-1]):
-     #     inp += '。'
-     inp = inp.strip("\n")
-     punds = r'[,.;?!、,。?!;:…]'
-     items = re.split(f'({punds})', inp)
-     mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
-     # keep the text intact when a sentence has no punctuation or no trailing punctuation
-     if len(items) % 2 == 1:
-         mergeitems.append(items[-1])
-     opt = "\n".join(mergeitems)
-     return opt
-
-
-
- if __name__ == '__main__':
-     method = get_method("cut5")
-     print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
-
 
 
__pycache__/download.cpython-310.pyc CHANGED
Binary files a/__pycache__/download.cpython-310.pyc and b/__pycache__/download.cpython-310.pyc differ
 
__pycache__/info.cpython-310.pyc CHANGED
Binary files a/__pycache__/info.cpython-310.pyc and b/__pycache__/info.cpython-310.pyc differ
 
__pycache__/my_utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/my_utils.cpython-310.pyc and b/__pycache__/my_utils.cpython-310.pyc differ
 
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -29,8 +29,6 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
  logging.getLogger("multipart").setLevel(logging.WARNING)
  from download import *
  download()
- from TTS_infer_pack.TTS import TTS, TTS_Config
- from TTS_infer_pack.text_segmentation_method import get_method
 
  if "_CUDA_VISIBLE_DEVICES" in os.environ:
      os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
@@ -66,90 +64,533 @@ is_half = eval(
      os.environ.get("is_half", "True" if torch.cuda.is_available() else "False")
  )
 
 
  dict_language = {
-     "中文1": "all_zh",
-     "English": "en",
-     "日文1": "all_ja",
-     "中文": "zh",
-     "日本語": "ja",
-     "混合": "auto",
  }
 
- cut_method = {
-     "Do not split/不切": "cut0",
-     "Split into groups of 4 sentences/四句一切": "cut1",
-     "Split every 50 characters/50字一切": "cut2",
-     "Split at CN/JP periods (。)/按中日文句号切": "cut3",
-     "Split at English periods (.)/按英文句号切": "cut4",
-     "Split at punctuation marks/按标点切": "cut5",
- }
 
 
- tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
- tts_config.device = device
- tts_config.is_half = is_half
- if gpt_path is not None:
-     tts_config.t2s_weights_path = gpt_path
- if sovits_path is not None:
-     tts_config.vits_weights_path = sovits_path
- if cnhubert_base_path is not None:
-     tts_config.cnhuhbert_base_path = cnhubert_base_path
- if bert_path is not None:
-     tts_config.bert_base_path = bert_path
-
-
- tts_pipline = TTS(tts_config)
- gpt_path = tts_config.t2s_weights_path
- sovits_path = tts_config.vits_weights_path
-
-
- def inference(text, text_lang,
-               ref_audio_path, prompt_text,
-               prompt_lang, top_k,
-               top_p, temperature,
-               text_split_method, batch_size,
-               speed_factor, ref_text_free,
-               split_bucket,
-               volume
-               ):
-
-     if not duration(ref_audio_path):
          return None
      if text == '':
-         wprint("Please input text to generate/请输入生成文字")
          return None
      text = trim_text(text, text_language)
-     tts_pipline.init_vits_weights(sovits_path)
-     tts_pipline.init_t2s_weights(gpt_path)
-
      try:
-         lang = dict_language[text_lang]
-         inputs = {
-             "text": text,
-             "text_lang": lang,
-             "ref_audio_path": ref_audio_path,
-             "prompt_text": prompt_text if not ref_text_free else "",
-             "prompt_lang": dict_language[prompt_lang],
-             "top_k": top_k,
-             "top_p": top_p,
-             "temperature": temperature,
-             "text_split_method": cut_method[text_split_method],
-             "batch_size": int(batch_size),
-             "speed_factor": float(speed_factor),
-             "split_bucket": split_bucket,
-             "volume": volume,
-             "return_fragment": False,
-         }
-
-         yield next(tts_pipline.run(inputs))
      except KeyError as e:
-         wprint(f'Unsupported language type:{e}')
          return None
 
  #==========custom functions============
 
- splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
  def tprint(text):
      now = datetime.now(tz).strftime('%H:%M:%S')
      print(f'UTC+8 - {now} - {text}')
@@ -197,7 +638,7 @@ def trim_text(text,language):
              return ' '.join(words[:i+1])
          return ' '.join(words[:limit_en])
 
-     else:
          if len(text) <= limit_cj:
              return text
          for i in range(limit_cj, -1, -1):
@@ -222,11 +663,10 @@ def duration(audio_file_path):
          return False
 
  def update_model(choice):
-     global gpt_path, sovits_path
      model_info = models[choice]
      gpt_path = abs_path(model_info["gpt_weight"])
      sovits_path = abs_path(model_info["sovits_weight"])
-
      model_name = choice
      tone_info = model_info["tones"]["tone1"]
      tone_sample_path = abs_path(tone_info["sample"])
@@ -268,7 +708,7 @@ def transcribe(voice):
 
      time2 = timer()
      tprint(f'transcribe COMPLETE,{round(time2-time1,4)}s')
-     tprint(f' \nTranscribe result:\n 🔣Language:{language} \n 🔣Text:{text}' )
      return text, language
 
  def clone_voice(user_voice, user_text, user_lang):
@@ -278,36 +718,29 @@ def clone_voice(user_voice,user_text,user_lang):
          wprint("Please enter text to generate/请输入生成文字")
          return None
      user_text = trim_text(user_text, user_lang)
-     #global gpt_path, sovits_path
      gpt_path = abs_path("pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
      #tprint(f'Model loaded:{gpt_path}')
      sovits_path = abs_path("pretrained_models/s2G488k.pth")
      #tprint(f'Model loaded:{sovits_path}')
      try:
-         prompt_text, prompt_lang = transcribe(user_voice)
      except UnboundLocalError as e:
          wprint(f"The language in the audio cannot be recognized :{str(e)}")
          return None
-     tts_pipline.init_vits_weights(sovits_path)
-     tts_pipline.init_t2s_weights(gpt_path)
-     inputs = {
-         "text": user_text,
-         "text_lang": dict_language[user_lang],
-         "ref_audio_path": user_voice,
-         "prompt_text": prompt_text,
-         "prompt_lang": dict_language[prompt_lang],
-         "top_k": 5,
-         "top_p": 1,
-         "temperature": 1,
-         "text_split_method": "cut1",
-         "batch_size": 20,
-         "speed_factor": 1.0,
-         "split_bucket": True,
-         "volume": 1.0,
-         "return_fragment": False,
-     }
-
-     yield next(tts_pipline.run(inputs))
 
  with open('dummy') as f:
      dummy_txt = f.read().strip().splitlines()
@@ -395,26 +828,15 @@ with gr.Blocks(theme='Kasien/ali_theme_custom') as app:
 
 
      with gr.Accordion(label="Additional generation options/附加生成选项", open=False):
-         with gr.Row():
-             how_to_cut = gr.Dropdown(
-                 label=("How to split input text?/如何对输入文字切片"),
-                 choices=[("Do not split/不切"), ("Split into groups of 4 sentences/四句一切"), ("Split every 50 characters/50字一切"),
-                          ("Split at CN/JP periods (。)/按中日文句号切"), ("Split at English periods (.)/按英文句号切"), ("Split at punctuation marks/按标点切"), ],
-                 value=("Split into groups of 4 sentences/四句一切"),
                  interactive=True,
-                 info='A suitable splitting method can achieve better generation results/适合的切片方法会得到更好的效果'
              )
-             split_bucket = gr.Checkbox(label="Split bucket/数据分桶", value=True, info='Speed up the inference process/提升推理速度')
-         with gr.Row():
-             volume = gr.Slider(minimum=0.5, maximum=5, value=1, step=0.1, label='Volume/音量', info='audio distortion due to excessive volume/大了要爆音')
-             speed_factor = gr.Slider(minimum=0.25, maximum=4, step=0.05, label="Speed factor", value=1.0, info='Playback speed/播放速度')
-             batch_size = gr.Slider(minimum=1, maximum=100, step=1, label="Batch size", value=20, info='The number of sentences for batch inference./并行推理的句子数量')
-         with gr.Row():
-             top_k = gr.Slider(minimum=1, maximum=100, step=1, label="top_k", value=5)
-             top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label="top_p", value=1)
-             temperature = gr.Slider(minimum=0, maximum=1, step=0.05, label="temperature", value=1)
-         ref_text_free = gr.Checkbox(label="REF_TEXT_FREE", value=False, visible=False)
-
 
      gr.HTML('''
@@ -441,8 +863,7 @@ with gr.Blocks(theme='Kasien/ali_theme_custom') as app:
          user_text = gr.Textbox(label="Text for generation/输入想要生成语音的文字", lines=5, placeholder=plsh, info=limit)
          dddice = gr.Button('🎲', variant='tool', min_width=0, scale=0)
 
-     dddice.click(dice, outputs=[user_text, dddice])
-
      user_text.change( lang_detector, user_text, user_lang)
 
      user_button = gr.Button("✨Clone Voice", variant="primary")
@@ -456,23 +877,9 @@ with gr.Blocks(theme='Kasien/ali_theme_custom') as app:
      tone_select.change(update_tone, inputs=[model_name, tone_select], outputs=[inp_ref, prompt_text, tone_sample])
 
      main_button.click(
-         inference,
-         inputs=[text,
-                 text_language,
-                 inp_ref,
-                 prompt_text,
-                 prompt_language,
-                 top_k,
-                 top_p,
-                 temperature,
-                 how_to_cut,
-                 batch_size,
-                 speed_factor,
-                 ref_text_free,
-                 split_bucket,
-                 volume],
-         outputs=[output]
-     )
 
      user_button.click(
          clone_voice,
  logging.getLogger("multipart").setLevel(logging.WARNING)
  from download import *
  download()
 
  if "_CUDA_VISIBLE_DEVICES" in os.environ:
      os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
 
      os.environ.get("is_half", "True" if torch.cuda.is_available() else "False")
  )
 
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
+ if is_half == True:
+     bert_model = bert_model.half().to(device)
+ else:
+     bert_model = bert_model.to(device)
+
+
+ def get_bert_feature(text, word2ph):
+     with torch.no_grad():
+         inputs = tokenizer(text, return_tensors="pt")
+         for i in inputs:
+             inputs[i] = inputs[i].to(device)
+         res = bert_model(**inputs, output_hidden_states=True)
+         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+     assert len(word2ph) == len(text)
+     phone_level_feature = []
+     for i in range(len(word2ph)):
+         repeat_feature = res[i].repeat(word2ph[i], 1)
+         phone_level_feature.append(repeat_feature)
+     phone_level_feature = torch.cat(phone_level_feature, dim=0)
+     return phone_level_feature.T
+
+
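get_bert_feature repeats each character-level BERT vector word2ph[i] times so the feature sequence lines up one-to-one with the phoneme sequence; a shape-only check:

    import torch
    res = torch.randn(3, 1024)   # one 1024-dim vector per character
    word2ph = [2, 1, 3]          # phonemes per character
    feat = torch.cat([res[i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0)
    print(feat.T.shape)          # torch.Size([1024, 6])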
+ class DictToAttrRecursive(dict):
+     def __init__(self, input_dict):
+         super().__init__(input_dict)
+         for key, value in input_dict.items():
+             if isinstance(value, dict):
+                 value = DictToAttrRecursive(value)
+             self[key] = value
+             setattr(self, key, value)
+
+     def __getattr__(self, item):
+         try:
+             return self[item]
+         except KeyError:
+             raise AttributeError(f"Attribute {item} not found")
+
+     def __setattr__(self, key, value):
+         if isinstance(value, dict):
+             value = DictToAttrRecursive(value)
+         super(DictToAttrRecursive, self).__setitem__(key, value)
+         super().__setattr__(key, value)
+
+     def __delattr__(self, item):
+         try:
+             del self[item]
+         except KeyError:
+             raise AttributeError(f"Attribute {item} not found")
+
+
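DictToAttrRecursive wraps the nested checkpoint config so it can be read with attribute access, which is how hps is used below:

    hps = DictToAttrRecursive({"data": {"sampling_rate": 32000}})
    print(hps.data.sampling_rate)   # 32000, same as hps["data"]["sampling_rate"]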
+ ssl_model = cnhubert.get_model()
+ if is_half == True:
+     ssl_model = ssl_model.half().to(device)
+ else:
+     ssl_model = ssl_model.to(device)
+
+
+ def change_sovits_weights(sovits_path):
+     global vq_model, hps
+     dict_s2 = torch.load(sovits_path, map_location="cpu")
+     hps = dict_s2["config"]
+     hps = DictToAttrRecursive(hps)
+     hps.model.semantic_frame_rate = "25hz"
+     vq_model = SynthesizerTrn(
+         hps.data.filter_length // 2 + 1,
+         hps.train.segment_size // hps.data.hop_length,
+         n_speakers=hps.data.n_speakers,
+         **hps.model
+     )
+     if ("pretrained" not in sovits_path):
+         del vq_model.enc_q
+     if is_half == True:
+         vq_model = vq_model.half().to(device)
+     else:
+         vq_model = vq_model.to(device)
+     vq_model.eval()
+     print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
+     with open("./sweight.txt", "w", encoding="utf-8") as f:
+         f.write(sovits_path)
+
+
+ change_sovits_weights(sovits_path)
+
+
+ def change_gpt_weights(gpt_path):
+     global hz, max_sec, t2s_model, config
+     hz = 50
+     dict_s1 = torch.load(gpt_path, map_location="cpu")
+     config = dict_s1["config"]
+     max_sec = config["data"]["max_sec"]
+     t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+     t2s_model.load_state_dict(dict_s1["weight"])
+     if is_half == True:
+         t2s_model = t2s_model.half()
+     t2s_model = t2s_model.to(device)
+     t2s_model.eval()
+     total = sum([param.nelement() for param in t2s_model.parameters()])
+     print("Number of parameters: %.2fM" % (total / 1e6))
+     with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
+
+
+ change_gpt_weights(gpt_path)
+
+
+ def get_spepc(hps, filename):
+     audio = load_audio(filename, int(hps.data.sampling_rate))
+     audio = torch.FloatTensor(audio)
+     audio_norm = audio
+     audio_norm = audio_norm.unsqueeze(0)
+     spec = spectrogram_torch(
+         audio_norm,
+         hps.data.filter_length,
+         hps.data.sampling_rate,
+         hps.data.hop_length,
+         hps.data.win_length,
+         center=False,
+     )
+     return spec
+
 
  dict_language = {
+     ("中文1"): "all_zh",   # treat all text as Chinese
+     ("English"): "en",     # treat all text as English  ### unchanged
+     ("日文1"): "all_ja",   # treat all text as Japanese
+     ("中文"): "zh",        # mixed Chinese/English recognition  ## unchanged
+     ("日本語"): "ja",      # mixed Japanese/English recognition  ## unchanged
+     ("混合"): "auto",      # multilingual: detect the language per segment
  }
 
+ def splite_en_inf(sentence, language):
+     pattern = re.compile(r'[a-zA-Z ]+')
+     textlist = []
+     langlist = []
+     pos = 0
+     for match in pattern.finditer(sentence):
+         start, end = match.span()
+         if start > pos:
+             textlist.append(sentence[pos:start])
+             langlist.append(language)
+         textlist.append(sentence[start:end])
+         langlist.append("en")
+         pos = end
+     if pos < len(sentence):
+         textlist.append(sentence[pos:])
+         langlist.append(language)
+     # Merge punctuation into previous word
+     for i in range(len(textlist) - 1, 0, -1):
+         if re.match(r'^[\W_]+$', textlist[i]):
+             textlist[i - 1] += textlist[i]
+             del textlist[i]
+             del langlist[i]
+     # Merge consecutive words with the same language tag
+     i = 0
+     while i < len(langlist) - 1:
+         if langlist[i] == langlist[i + 1]:
+             textlist[i] += textlist[i + 1]
+             del textlist[i + 1]
+             del langlist[i + 1]
+         else:
+             i += 1
+
+     return textlist, langlist
+
+
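splite_en_inf peels ASCII runs out of a CJK sentence and tags them as English, then merges punctuation and same-language neighbors; roughly:

    textlist, langlist = splite_en_inf("我喜欢用PyTorch写代码。", "zh")
    # textlist == ['我喜欢用', 'PyTorch', '写代码。']
    # langlist == ['zh', 'en', 'zh']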
+ def clean_text_inf(text, language):
+     formattext = ""
+     language = language.replace("all_", "")
+     for tmp in LangSegment.getTexts(text):
+         if language == "ja":
+             if tmp["lang"] == language or tmp["lang"] == "zh":
+                 formattext += tmp["text"] + " "
+             continue
+         if tmp["lang"] == language:
+             formattext += tmp["text"] + " "
+     while "  " in formattext:
+         formattext = formattext.replace("  ", " ")
+     phones, word2ph, norm_text = clean_text(formattext, language)
+     phones = cleaned_text_to_sequence(phones)
+     return phones, word2ph, norm_text
+
+ dtype = torch.float16 if is_half == True else torch.float32
+ def get_bert_inf(phones, word2ph, norm_text, language):
+     language = language.replace("all_", "")
+     if language == "zh":
+         bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
+     else:
+         bert = torch.zeros(
+             (1024, len(phones)),
+             dtype=torch.float16 if is_half == True else torch.float32,
+         ).to(device)
+
+     return bert
+
+
+ def nonen_clean_text_inf(text, language):
+     if (language != "auto"):
+         textlist, langlist = splite_en_inf(text, language)
+     else:
+         textlist = []
+         langlist = []
+         for tmp in LangSegment.getTexts(text):
+             langlist.append(tmp["lang"])
+             textlist.append(tmp["text"])
+     print(textlist)
+     print(langlist)
+     phones_list = []
+     word2ph_list = []
+     norm_text_list = []
+     for i in range(len(textlist)):
+         lang = langlist[i]
+         phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
+         phones_list.append(phones)
+         if lang == "zh":
+             word2ph_list.append(word2ph)
+         norm_text_list.append(norm_text)
+     print(word2ph_list)
+     phones = sum(phones_list, [])
+     word2ph = sum(word2ph_list, [])
+     norm_text = ' '.join(norm_text_list)
+
+     return phones, word2ph, norm_text
+
+
+ def nonen_get_bert_inf(text, language):
+     if (language != "auto"):
+         textlist, langlist = splite_en_inf(text, language)
+     else:
+         textlist = []
+         langlist = []
+         for tmp in LangSegment.getTexts(text):
+             langlist.append(tmp["lang"])
+             textlist.append(tmp["text"])
+     print(textlist)
+     print(langlist)
+     bert_list = []
+     for i in range(len(textlist)):
+         lang = langlist[i]
+         phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
+         bert = get_bert_inf(phones, word2ph, norm_text, lang)
+         bert_list.append(bert)
+     bert = torch.cat(bert_list, dim=1)
+
+     return bert
+
+
+ splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
+
+
+ def get_first(text):
+     pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
+     text = re.split(pattern, text)[0].strip()
+     return text
+
+
+ def get_cleaned_text_final(text, language):
+     if language in {"en", "all_zh", "all_ja"}:
+         phones, word2ph, norm_text = clean_text_inf(text, language)
+     elif language in {"zh", "ja", "auto"}:
+         phones, word2ph, norm_text = nonen_clean_text_inf(text, language)
+     return phones, word2ph, norm_text
+
+ def get_bert_final(phones, word2ph, text, language, device):
+     if language == "en":
+         bert = get_bert_inf(phones, word2ph, text, language)
+     elif language in {"zh", "ja", "auto"}:
+         bert = nonen_get_bert_inf(text, language)
+     elif language == "all_zh":
+         bert = get_bert_feature(text, word2ph).to(device)
+     else:
+         bert = torch.zeros((1024, len(phones))).to(device)
+     return bert
+
+ def merge_short_text_in_array(texts, threshold):
+     if (len(texts)) < 2:
+         return texts
+     result = []
+     text = ""
+     for ele in texts:
+         text += ele
+         if len(text) >= threshold:
+             result.append(text)
+             text = ""
+     if (len(text) > 0):
+         if len(result) == 0:
+             result.append(text)
+         else:
+             result[len(result) - 1] += text
+     return result
+
+
+ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=("Do not split"), volume_scale=1.0):
361
+ if not duration(ref_wav_path):
362
  return None
363
  if text == '':
364
+ wprint("Please enter text to generate/请输入生成文字")
365
  return None
366
+ t0 = ttime()
367
+ startTime=timer()
368
  text=trim_text(text,text_language)
369
+ change_sovits_weights(sovits_path)
370
+ tprint(f'🏕️LOADED SoVITS Model: {sovits_path}')
371
+ change_gpt_weights(gpt_path)
372
+ tprint(f'🏕️LOADED GPT Model: {gpt_path}')
373
+
374
+ prompt_language = dict_language[prompt_language]
375
  try:
376
+ text_language = dict_language[text_language]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  except KeyError as e:
378
+ wprint(f"Unsupported language type: {e}")
379
  return None
380
+
381
+ prompt_text = prompt_text.strip("\n")
382
+ if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
383
+ text = text.strip("\n")
384
+ if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
385
+ #print(("实际输入的参考文本:"), prompt_text)
386
+ #print(("📝实际输入的目标文本:"), text)
387
+ zero_wav = np.zeros(
388
+ int(hps.data.sampling_rate * 0.3),
389
+ dtype=np.float16 if is_half == True else np.float32,
390
+ )
391
+ with torch.no_grad():
392
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
393
+ if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
394
+ errinfo='参考音频在3~10秒范围外,请更换!'
395
+ raise OSError((errinfo))
396
+ wav16k = torch.from_numpy(wav16k)
397
+ zero_wav_torch = torch.from_numpy(zero_wav)
398
+ if is_half == True:
399
+ wav16k = wav16k.half().to(device)
400
+ zero_wav_torch = zero_wav_torch.half().to(device)
401
+ else:
402
+ wav16k = wav16k.to(device)
403
+ zero_wav_torch = zero_wav_torch.to(device)
404
+ wav16k = torch.cat([wav16k, zero_wav_torch])
405
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
406
+ "last_hidden_state"
407
+ ].transpose(
408
+ 1, 2
409
+ ) # .float()
410
+ codes = vq_model.extract_latent(ssl_content)
411
+ prompt_semantic = codes[0, 0]
412
+ t1 = ttime()
413
+
414
+ phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language)
415
+
416
+ if (how_to_cut == ("Split into groups of 4 sentences")):
417
+ text = cut1(text)
418
+ elif (how_to_cut == ("Split every 50 characters")):
419
+ text = cut2(text)
420
+ elif (how_to_cut == ("Split at CN/JP periods (。)")):
421
+ text = cut3(text)
422
+ elif (how_to_cut == ("Split at English periods (.)")):
423
+ text = cut4(text)
424
+ elif (how_to_cut == ("Split at punctuation marks")):
425
+ text = cut5(text)
426
+ while "\n\n" in text:
427
+ text = text.replace("\n\n", "\n")
428
+ print(f"🧨实际输入的目标文本(切句后):{text}\n")
429
+ texts = text.split("\n")
430
+ texts = merge_short_text_in_array(texts, 5)
431
+ audio_opt = []
432
+ bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)
433
+
434
+ for text in texts:
435
+ if (len(text.strip()) == 0):
436
+ continue
437
+ if (text[-1] not in splits): text += "。" if text_language != "en" else "."
438
+ print(("\n🎈实际输入的目标文本(每句):"), text)
439
+ phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language)
440
+ try:
441
+ bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype)
442
+ except RuntimeError as e:
443
+ wprint(f"The input text does not match the language/输入文本与语言不匹配: {e}")
444
+ return None
445
+ bert = torch.cat([bert1, bert2], 1)
446
+
447
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
448
+ bert = bert.to(device).unsqueeze(0)
449
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
450
+ prompt = prompt_semantic.unsqueeze(0).to(device)
451
+ t2 = ttime()
452
+ with torch.no_grad():
453
+ # pred_semantic = t2s_model.model.infer(
454
+ pred_semantic, idx = t2s_model.model.infer_panel(
455
+ all_phoneme_ids,
456
+ all_phoneme_len,
457
+ prompt,
458
+ bert,
459
+ # prompt_phone_len=ph_offset,
460
+ top_k=config["inference"]["top_k"],
461
+ early_stop_num=hz * max_sec,
462
+ )
463
+        t3 = ttime()
+        # print(pred_semantic.shape, idx)
+        pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # .unsqueeze(0)  # mq needs one extra unsqueeze
+        refer = get_spepc(hps, ref_wav_path)  # .to(device)
+        if is_half == True:
+            refer = refer.half().to(device)
+        else:
+            refer = refer.to(device)
+        # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
+        try:
+            audio = (
+                vq_model.decode(
+                    pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
+                )
+                .detach()
+                .cpu()
+                .numpy()[0, 0]
+            )
+        except RuntimeError as e:
+            wprint(f"The input text does not match the language/输入文本与语言不匹配: {e}")
+            return None
+
+        max_audio = np.abs(audio).max()
+        if max_audio > 1:
+            audio /= max_audio
+        audio_opt.append(audio)
+        audio_opt.append(zero_wav)
+        t4 = ttime()
+    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+    #yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
+    audio_data = (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
+
+    audio_data = (audio_data.astype(np.float32) * volume_scale).astype(np.int16)
+    output_wav = "output_audio.wav"
+    sf.write(output_wav, audio_data, hps.data.sampling_rate)
+    endTime = timer()
+    tprint(f'🆗TTS COMPLETE,{round(endTime-startTime,4)}s')
+    return output_wav
+
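get_tts_wav now returns the path of the written file rather than yielding raw audio. A minimal usage sketch, mirroring how clone_voice invokes it further down (the reference path and both texts are placeholders):

```python
# Minimal usage sketch; the reference path and both texts are placeholders.
wav_path = get_tts_wav(
    "ref.wav",                   # 3-10 s reference clip
    "Reference transcript.",     # prompt_text
    "en",                        # prompt_language
    "Text to synthesize.",       # target text
    "en",                        # text_language
    how_to_cut="Do not split",
    volume_scale=1.0,
)
# Returns "output_audio.wav" on success, or None on a language mismatch.
```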
+def split(todo_text):
+    todo_text = todo_text.replace("……", "。").replace("——", ",")
+    if todo_text[-1] not in splits:
+        todo_text += "。"
+    i_split_head = i_split_tail = 0
+    len_text = len(todo_text)
+    todo_texts = []
+    while True:
+        if i_split_head >= len_text:
+            break
+        if todo_text[i_split_head] in splits:
+            i_split_head += 1
+            todo_texts.append(todo_text[i_split_tail:i_split_head])
+            i_split_tail = i_split_head
+        else:
+            i_split_head += 1
+    return todo_texts
+
+
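split() scans the string once and closes a chunk each time it steps past a delimiter, so every chunk keeps its trailing punctuation (a 。 is appended first if the text ends without one). For example:

```python
# split() keeps the trailing delimiter on every chunk; this assumes
# "。" and "!" are members of the module-level `splits` collection.
print(split("今天天气不错。我们出门吧!好的"))
# -> ['今天天气不错。', '我们出门吧!', '好的。']
```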
+def cut1(inp):
+    inp = inp.strip("\n")
+    inps = split(inp)
+    split_idx = list(range(0, len(inps), 4))
+    split_idx[-1] = None
+    if len(split_idx) > 1:
+        opts = []
+        for idx in range(len(split_idx) - 1):
+            opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
+    else:
+        opts = [inp]
+    return "\n".join(opts)
+
+
+def cut2(inp):
+    inp = inp.strip("\n")
+    inps = split(inp)
+    if len(inps) < 2:
+        return inp
+    opts = []
+    summ = 0
+    tmp_str = ""
+    for i in range(len(inps)):
+        summ += len(inps[i])
+        tmp_str += inps[i]
+        if summ > 50:
+            summ = 0
+            opts.append(tmp_str)
+            tmp_str = ""
+    if tmp_str != "":
+        opts.append(tmp_str)
+    # print(opts)
+    if len(opts) > 1 and len(opts[-1]) < 50:
+        opts[-2] = opts[-2] + opts[-1]
+        opts = opts[:-1]
+    return "\n".join(opts)
+
+
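cut2 greedily packs whole sentences until the running total passes 50 characters, then folds a short final group back into the previous one, so a trailing fragment never stands alone. A quick check of that merge:

```python
# Three 30-character sentences: the first group closes after sentence 2
# (60 > 50), and the 30-character remainder, being < 50, is merged back,
# so everything ends up in a single group (no "\n" in the result).
sample = ("甲" * 29 + "。") * 3
assert "\n" not in cut2(sample)
```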
+def cut3(inp):
+    inp = inp.strip("\n")
+    return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
+
+
+def cut4(inp):
+    inp = inp.strip("\n")
+    return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
+
+
+# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
+def cut5(inp):
+    # if not re.search(r'[^\w\s]', inp[-1]):
+    #     inp += '。'
+    inp = inp.strip("\n")
+    punds = r'[,.;?!、,。?!;:…]'
+    items = re.split(f'({punds})', inp)
+    mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
+    if len(items) % 2 == 1:
+        mergeitems.append(items[-1])
+    opt = "\n".join(mergeitems)
+    return opt
+
+
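Because the punctuation class is wrapped in a capture group, re.split returns an alternating text/punctuation list; zipping the even and odd slices glues each mark back onto its sentence. The mechanics in isolation:

```python
# The regex mechanics behind cut5, shown on a plain ASCII example.
import re

punds = r'[,.;?!、,。?!;:…]'
items = re.split(f'({punds})', "Hi, there. Bye")
# items == ['Hi', ',', ' there', '.', ' Bye']  (odd length: trailing text)
merged = ["".join(pair) for pair in zip(items[::2], items[1::2])]
if len(items) % 2 == 1:
    merged.append(items[-1])
print(merged)  # -> ['Hi,', ' there.', ' Bye']
```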
+def custom_sort_key(s):
+    # Use a regex to split the string into numeric and non-numeric parts
+    parts = re.split(r'(\d+)', s)
+    # Convert the numeric parts to integers; leave the rest unchanged
+    parts = [int(part) if part.isdigit() else part for part in parts]
+    return parts
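The mixed string/int key gives a natural sort, so "tone2" orders before "tone10". Usage:

```python
# Natural-sort demo: numeric runs compare as integers, not characters.
names = ["tone10", "tone2", "tone1"]
print(sorted(names, key=custom_sort_key))  # -> ['tone1', 'tone2', 'tone10']
```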
 
 #==========custom functions============
 
 def tprint(text):
     now = datetime.now(tz).strftime('%H:%M:%S')
     print(f'UTC+8 - {now} - {text}')
 
                 return ' '.join(words[:i+1])
         return ' '.join(words[:limit_en])
 
+    else:  # Chinese/Japanese
         if len(text) <= limit_cj:
             return text
         for i in range(limit_cj, -1, -1):
 
     return False
 
 def update_model(choice):
+    global gpt_path, sovits_path
     model_info = models[choice]
     gpt_path = abs_path(model_info["gpt_weight"])
     sovits_path = abs_path(model_info["sovits_weight"])
     model_name = choice
     tone_info = model_info["tones"]["tone1"]
     tone_sample_path = abs_path(tone_info["sample"])
 
     time2 = timer()
     tprint(f'transcribe COMPLETE,{round(time2-time1,4)}s')
+    tprint(f'\nTRANSCRIBE RESULT:\n 🔣Language:{language} \n 🔣Text:{text}')
     return text, language
 
 def clone_voice(user_voice, user_text, user_lang):
 
         wprint("Please enter text to generate/请输入生成文字")
         return None
     user_text = trim_text(user_text, user_lang)
+    time1 = timer()
+    global gpt_path, sovits_path
     gpt_path = abs_path("pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
     #tprint(f'Model loaded:{gpt_path}')
     sovits_path = abs_path("pretrained_models/s2G488k.pth")
     #tprint(f'Model loaded:{sovits_path}')
     try:
+        prompt_text, prompt_language = transcribe(user_voice)
     except UnboundLocalError as e:
         wprint(f"The language in the audio cannot be recognized :{str(e)}")
         return None
+
+    output_wav = get_tts_wav(
+        user_voice,
+        prompt_text,
+        prompt_language,
+        user_text,
+        user_lang,
+        how_to_cut="Do not split",
+        volume_scale=1.0)
+    time2 = timer()
+    tprint(f'🆗CLONE COMPLETE,{round(time2-time1,4)}s')
+    return output_wav
 
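End to end, clone_voice transcribes the uploaded clip with transcribe() and then reuses get_tts_wav with splitting disabled. A hypothetical call (the file path is a placeholder):

```python
# Hypothetical invocation of the clone flow; the path is a placeholder.
wav = clone_voice("my_voice_sample.wav", "Hello from my cloned voice.", "en")
if wav is not None:
    print(f"Cloned audio written to {wav}")  # -> output_audio.wav
```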
 with open('dummy') as f:
     dummy_txt = f.read().strip().splitlines()
 
 with gr.Accordion(label="Additional generation options/附加生成选项", open=False):
+    how_to_cut = gr.Dropdown(
+        label=("How to split?"),
+        choices=[("Do not split"), ("Split into groups of 4 sentences"), ("Split every 50 characters"),
+                 ("Split at CN/JP periods (。)"), ("Split at English periods (.)"), ("Split at punctuation marks"), ],
+        value=("Split into groups of 4 sentences"),
         interactive=True,
+        info='A suitable splitting method can achieve better generation results'
     )
+    volume = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.01, label='Volume/音量')
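Note that these choice strings are compared verbatim against how_to_cut inside get_tts_wav, so the dropdown labels and the dispatch strings must stay byte-for-byte in sync. One hypothetical guard is a single shared constant:

```python
# Hypothetical shared constant so the UI labels and the how_to_cut
# dispatch in get_tts_wav cannot drift apart.
CUT_CHOICES = (
    "Do not split",
    "Split into groups of 4 sentences",
    "Split every 50 characters",
    "Split at CN/JP periods (。)",
    "Split at English periods (.)",
    "Split at punctuation marks",
)
# how_to_cut = gr.Dropdown(choices=list(CUT_CHOICES), value=CUT_CHOICES[1], ...)
```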
 
 gr.HTML('''
 
 user_text = gr.Textbox(label="Text for generation/输入想要生成语音的文字", lines=5, placeholder=plsh, info=limit)
 dddice = gr.Button('🎲', variant='tool', min_width=0, scale=0)
 
+dddice.click(dice, outputs=[user_text, dddice])
 user_text.change(lang_detector, user_text, user_lang)
 
 user_button = gr.Button("✨Clone Voice", variant="primary")
 
877
  tone_select.change(update_tone, inputs=[model_name, tone_select], outputs=[inp_ref, prompt_text, tone_sample])
878
 
879
  main_button.click(
880
+ get_tts_wav,
881
+ inputs=[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,volume],
882
+ outputs=[output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883
 
884
  user_button.click(
885
  clone_voice,
{GPT_SoVITS/configs → configs}/s1.yaml RENAMED
File without changes
{GPT_SoVITS/configs → configs}/s1big.yaml RENAMED
File without changes
{GPT_SoVITS/configs → configs}/s1big2.yaml RENAMED
File without changes
{GPT_SoVITS/configs → configs}/s1longer.yaml RENAMED
File without changes
{GPT_SoVITS/configs → configs}/s1mq.yaml RENAMED
File without changes
{GPT_SoVITS/configs → configs}/s2.json RENAMED
File without changes
{GPT_SoVITS/configs → configs}/train.yaml RENAMED
File without changes
feature_extractor/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/feature_extractor/__pycache__/__init__.cpython-310.pyc and b/feature_extractor/__pycache__/__init__.cpython-310.pyc differ
 
feature_extractor/__pycache__/cnhubert.cpython-310.pyc CHANGED
Binary files a/feature_extractor/__pycache__/cnhubert.cpython-310.pyc and b/feature_extractor/__pycache__/cnhubert.cpython-310.pyc differ
 
feature_extractor/__pycache__/whisper_enc.cpython-310.pyc CHANGED
Binary files a/feature_extractor/__pycache__/whisper_enc.cpython-310.pyc and b/feature_extractor/__pycache__/whisper_enc.cpython-310.pyc differ
 
feature_extractor/cnhubert.py CHANGED
@@ -4,9 +4,9 @@ import librosa
 import torch
 import torch.nn.functional as F
 import soundfile as sf
-import logging
+#import logging
 
-logging.getLogger("numba").setLevel(logging.WARNING)
+#logging.getLogger("numba").setLevel(logging.WARNING)
 
 from transformers import (
     Wav2Vec2FeatureExtractor,
@@ -20,16 +20,13 @@ cnhubert_base_path = None
 
 
 class CNHubert(nn.Module):
-    def __init__(self, base_path:str=None):
+    def __init__(self):
         super().__init__()
-        if base_path is None:
-            base_path = cnhubert_base_path
-        self.model = HubertModel.from_pretrained(base_path)
+        self.model = HubertModel.from_pretrained(cnhubert_base_path)
         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            base_path
+            cnhubert_base_path
         )
 
-
     def forward(self, x):
         input_values = self.feature_extractor(
             x, return_tensors="pt", sampling_rate=16000
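With the constructor argument gone, callers must set the module-level cnhubert_base_path before instantiating CNHubert. A minimal sketch (the model directory is a placeholder):

```python
# Minimal usage sketch after the constructor change; the path is a placeholder.
from feature_extractor import cnhubert

cnhubert.cnhubert_base_path = "pretrained_models/chinese-hubert-base"
ssl_model = cnhubert.CNHubert()  # reads cnhubert_base_path at construction
ssl_model.eval()
```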
gweight.txt ADDED
@@ -0,0 +1 @@
+/content/Multi-voice-TTS-GPT-SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
module/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/__init__.cpython-310.pyc and b/module/__pycache__/__init__.cpython-310.pyc differ
 
module/__pycache__/attentions.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/attentions.cpython-310.pyc and b/module/__pycache__/attentions.cpython-310.pyc differ
 
module/__pycache__/commons.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/commons.cpython-310.pyc and b/module/__pycache__/commons.cpython-310.pyc differ
 
module/__pycache__/core_vq.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/core_vq.cpython-310.pyc and b/module/__pycache__/core_vq.cpython-310.pyc differ
 
module/__pycache__/mel_processing.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/mel_processing.cpython-310.pyc and b/module/__pycache__/mel_processing.cpython-310.pyc differ
 
module/__pycache__/models.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/models.cpython-310.pyc and b/module/__pycache__/models.cpython-310.pyc differ
 
module/__pycache__/modules.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/modules.cpython-310.pyc and b/module/__pycache__/modules.cpython-310.pyc differ