diff --git a/AR/__pycache__/__init__.cpython-310.pyc b/AR/__pycache__/__init__.cpython-310.pyc index fd7187ec8b4b16ca3e6e92dd14ea32958b43c9f7..ebd4219806d43b78e59efb28f5acd0218e5deefc 100644 Binary files a/AR/__pycache__/__init__.cpython-310.pyc and b/AR/__pycache__/__init__.cpython-310.pyc differ diff --git a/AR/models/__pycache__/__init__.cpython-310.pyc b/AR/models/__pycache__/__init__.cpython-310.pyc index 5f1e016e879756497207171d82bff9551a47538e..8017dda15e23f3306d40d658723f11101d2c4185 100644 Binary files a/AR/models/__pycache__/__init__.cpython-310.pyc and b/AR/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc b/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc index 6455717a6fcc27a61598fcdc4abe9888ce6a44ea..ec2fc7ee22063cf8eb86080370cdda3da50cbece 100644 Binary files a/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc and b/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc differ diff --git a/AR/models/__pycache__/t2s_model.cpython-310.pyc b/AR/models/__pycache__/t2s_model.cpython-310.pyc index 70939924652e1de9a439787ce9ecfff1808272d4..14459cf978c6776d389f2eb30d2b8a6a904b3580 100644 Binary files a/AR/models/__pycache__/t2s_model.cpython-310.pyc and b/AR/models/__pycache__/t2s_model.cpython-310.pyc differ diff --git a/AR/models/__pycache__/utils.cpython-310.pyc b/AR/models/__pycache__/utils.cpython-310.pyc index f004f390c4256b6a784813dd4d904ea428254531..e41dc40d47b5b8296479b4481ae4326eacb41268 100644 Binary files a/AR/models/__pycache__/utils.cpython-310.pyc and b/AR/models/__pycache__/utils.cpython-310.pyc differ diff --git a/AR/models/t2s_lightning_module.py b/AR/models/t2s_lightning_module.py index 1b602629a1061cbc05525f74e0944e8bce51a2eb..2dd3f392893f1ea08a6e848f2ff2d9be1a425f15 100644 --- a/AR/models/t2s_lightning_module.py +++ b/AR/models/t2s_lightning_module.py @@ -13,11 +13,11 @@ from AR.modules.lr_schedulers import WarmupCosineLRSchedule from AR.modules.optim import ScaledAdam class Text2SemanticLightningModule(LightningModule): - def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False): + def __init__(self, config, output_dir, is_train=True): super().__init__() self.config = config self.top_k = 3 - self.model = Text2SemanticDecoder(config=config, top_k=self.top_k,flash_attn_enabled=flash_attn_enabled) + self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) pretrained_s1 = config.get("pretrained_s1") if pretrained_s1 and is_train: # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) diff --git a/AR/models/t2s_model.py b/AR/models/t2s_model.py index dfd6eb0d6871c7e0a16aecee045fea47e1faa71a..c8ad3d8252f08732ffeadbe244dc9dba5fd26f06 100644 --- a/AR/models/t2s_model.py +++ b/AR/models/t2s_model.py @@ -1,9 +1,5 @@ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py # reference: https://github.com/lifeiteng/vall-e -import os, sys -now_dir = os.getcwd() -sys.path.append(now_dir) -from typing import List import torch from tqdm import tqdm @@ -39,144 +35,8 @@ default_config = { } -@torch.jit.script -class T2SMLP: - def __init__(self, w1, b1, w2, b2): - self.w1 = w1 - self.b1 = b1 - self.w2 = w2 - self.b2 = b2 - - def forward(self, x): - x = F.relu(F.linear(x, self.w1, self.b1)) - x = F.linear(x, self.w2, self.b2) - return x - - -@torch.jit.script -class T2SBlock: - def __init__( - self, - num_heads, - hidden_dim: int, - mlp: T2SMLP, - qkv_w, - qkv_b, - out_w, - out_b, - norm_w1, - norm_b1, - norm_eps1, - norm_w2, - norm_b2, - norm_eps2, - ): - self.num_heads = num_heads - self.mlp = mlp - self.hidden_dim: int = hidden_dim - self.qkv_w = qkv_w - self.qkv_b = qkv_b - self.out_w = out_w - self.out_b = out_b - self.norm_w1 = norm_w1 - self.norm_b1 = norm_b1 - self.norm_eps1 = norm_eps1 - self.norm_w2 = norm_w2 - self.norm_b2 = norm_b2 - self.norm_eps2 = norm_eps2 - - def process_prompt(self, x, attn_mask : torch.Tensor): - q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) - - batch_size = q.shape[0] - q_len = q.shape[1] - kv_len = k.shape[1] - - k_cache = k - v_cache = v - - q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) - k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) - v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) - - attn = F.scaled_dot_product_attention(q, k, v, attn_mask) - - attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim) - attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) - attn = F.linear(attn, self.out_w, self.out_b) - - x = F.layer_norm( - x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 - ) - x = F.layer_norm( - x + self.mlp.forward(x), - [self.hidden_dim], - self.norm_w2, - self.norm_b2, - self.norm_eps2, - ) - return x, k_cache, v_cache - - def decode_next_token(self, x, k_cache, v_cache): - q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) - - k_cache = torch.cat([k_cache, k], dim=1) - v_cache = torch.cat([v_cache, v], dim=1) - - batch_size = q.shape[0] - q_len = q.shape[1] - kv_len = k_cache.shape[1] - - q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) - k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) - v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) - - - attn = F.scaled_dot_product_attention(q, k, v) - - attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim) - attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) - attn = F.linear(attn, self.out_w, self.out_b) - - x = F.layer_norm( - x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 - ) - x = F.layer_norm( - x + self.mlp.forward(x), - [self.hidden_dim], - self.norm_w2, - self.norm_b2, - self.norm_eps2, - ) - return x, k_cache, v_cache - - -@torch.jit.script -class T2STransformer: - def __init__(self, num_blocks : int, blocks: List[T2SBlock]): - self.num_blocks : int = num_blocks - self.blocks = blocks - - def process_prompt( - self, x, attn_mask : torch.Tensor): - k_cache : List[torch.Tensor] = [] - v_cache : List[torch.Tensor] = [] - for i in range(self.num_blocks): - x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask) - k_cache.append(k_cache_) - v_cache.append(v_cache_) - return x, k_cache, v_cache - - def decode_next_token( - self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor] - ): - for i in range(self.num_blocks): - x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i]) - return x, k_cache, v_cache - - class Text2SemanticDecoder(nn.Module): - def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled:bool=False): + def __init__(self, config, norm_first=False, top_k=3): super(Text2SemanticDecoder, self).__init__() self.model_dim = config["model"]["hidden_dim"] self.embedding_dim = config["model"]["embedding_dim"] @@ -228,47 +88,6 @@ class Text2SemanticDecoder(nn.Module): multidim_average="global", ignore_index=self.EOS, ) - - self.enable_flash_attn(flash_attn_enabled) - - def enable_flash_attn(self, enable:bool=True): - - if not enable: - print("Not Using Flash Attention") - self.infer_panel = self.infer_panel_batch_only - else: - self.infer_panel = self.infer_panel_batch_infer_with_flash_attn - print("Using Flash Attention") - blocks = [] - - for i in range(self.num_layers): - layer = self.h.layers[i] - t2smlp = T2SMLP( - layer.linear1.weight, - layer.linear1.bias, - layer.linear2.weight, - layer.linear2.bias - ) - - block = T2SBlock( - self.num_head, - self.model_dim, - t2smlp, - layer.self_attn.in_proj_weight, - layer.self_attn.in_proj_bias, - layer.self_attn.out_proj.weight, - layer.self_attn.out_proj.bias, - layer.norm1.weight, - layer.norm1.bias, - layer.norm1.eps, - layer.norm2.weight, - layer.norm2.bias, - layer.norm2.eps - ) - - blocks.append(block) - - self.t2s_transformer = T2STransformer(self.num_layers, blocks) def make_input_data(self, x, x_lens, y, y_lens, bert_feature): x = self.ar_text_embedding(x) @@ -502,161 +321,7 @@ class Text2SemanticDecoder(nn.Module): # 错位 return targets[:, :-1], targets[:, 1:] - def infer_panel_batch_infer_with_flash_attn( - self, - x, #####全部文本token - x_lens, - prompts, ####参考音频token - bert_feature, - top_k: int = -100, - top_p: int = 100, - early_stop_num: int = -1, - temperature: float = 1.0, - ): - - bert_feature = self.bert_proj(bert_feature.transpose(1, 2)) - x = self.ar_text_embedding(x) - x = x + bert_feature - x = self.ar_text_position(x) - - # AR Decoder - y = prompts - - x_len = x.shape[1] - x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) - stop = False - # print(1111111,self.num_layers) - - k_cache = None - v_cache = None - ################### first step ########################## - if y is not None: - y_emb = self.ar_audio_embedding(y) - y_len = y_emb.shape[1] - prefix_len = y.shape[1] - y_pos = self.ar_audio_position(y_emb) - xy_pos = torch.concat([x, y_pos], dim=1) - ref_free = False - else: - y_emb = None - y_len = 0 - prefix_len = 0 - y_pos = None - xy_pos = x - y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device) - ref_free = True - - - ##### create mask ##### - bsz = x.shape[0] - src_len = x_len + y_len - y_lens = torch.LongTensor([y_len]*bsz).to(x.device) - y_mask = make_pad_mask(y_lens) - x_mask = make_pad_mask(x_lens) - - # (bsz, x_len + y_len) - xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) - - x_mask = F.pad( - x_attn_mask, - (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) - value=True, - ) - y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) - torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), - (x_len, 0), - value=False, - ) - - xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device) - # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1) - xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len) - xy_attn_mask = xy_mask.logical_or(xy_padding_mask) - xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1) - new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) - xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf")) - - ###### decode ##### - y_list = [None]*y.shape[0] - batch_idx_map = list(range(y.shape[0])) - idx_list = [None]*y.shape[0] - for idx in tqdm(range(1500)): - if idx == 0: - xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask) - else: - xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache) - - logits = self.ar_predict_layer( - xy_dec[:, -1] - ) - - if idx == 0: - xy_attn_mask = None - logits = logits[:, :-1] - - samples = sample( - logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature - )[0] - - y = torch.concat([y, samples], dim=1) - - ####### 移除batch中已经生成完毕的序列,进一步优化计算量 - reserved_idx_of_batch_for_y = None - if (self.EOS in samples[:, 0]) or \ - (self.EOS in torch.argmax(logits, dim=-1)): ###如果生成到EOS,则停止 - l = samples[:, 0]==self.EOS - removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist() - reserved_idx_of_batch_for_y = torch.where(l==False)[0] - # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y] - for i in removed_idx_of_batch_for_y: - batch_index = batch_idx_map[i] - idx_list[batch_index] = idx - 1 - y_list[batch_index] = y[i, :-1] - - batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()] - - # 只保留batch中未生成完毕的序列 - if reserved_idx_of_batch_for_y is not None: - # index = torch.LongTensor(batch_idx_map).to(y.device) - y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) - if k_cache is not None : - for i in range(len(k_cache)): - k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y) - v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y) - - - if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499: - print("use early stop num:", early_stop_num) - stop = True - for i, batch_index in enumerate(batch_idx_map): - batch_index = batch_idx_map[i] - idx_list[batch_index] = idx - y_list[batch_index] = y[i, :-1] - - if not (None in idx_list): - stop = True - - if stop: - if y.shape[1]==0: - y = torch.concat([y, torch.zeros_like(samples)], dim=1) - print("bad zero prediction") - print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") - break - - ####################### update next step ################################### - y_emb = self.ar_audio_embedding(y[:, -1:]) - xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device) - - if (None in idx_list): - for i in range(x.shape[0]): - if idx_list[i] is None: - idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替 - - if ref_free: - return y_list, [0]*x.shape[0] - return y_list, idx_list - - def infer_panel_batch_only( + def infer_panel( self, x, #####全部文本token x_lens, @@ -721,9 +386,7 @@ class Text2SemanticDecoder(nn.Module): x.device ) - y_list = [None]*y.shape[0] - batch_idx_map = list(range(y.shape[0])) - idx_list = [None]*y.shape[0] + for idx in tqdm(range(1500)): xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache) @@ -734,45 +397,17 @@ class Text2SemanticDecoder(nn.Module): if(idx==0):###第一次跑不能EOS否则没有了 logits = logits[:, :-1] ###刨除1024终止符号的概率 samples = sample( - logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature - )[0] + logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature + )[0].unsqueeze(0) # 本次生成的 semantic_ids 和之前的 y 构成新的 y # print(samples.shape)#[1,1]#第一个1是bs y = torch.concat([y, samples], dim=1) - # 移除已经生成完毕的序列 - reserved_idx_of_batch_for_y = None - if (self.EOS in torch.argmax(logits, dim=-1)) or \ - (self.EOS in samples[:, 0]): ###如果生成到EOS,则停止 - l = samples[:, 0]==self.EOS - removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist() - reserved_idx_of_batch_for_y = torch.where(l==False)[0] - # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y] - for i in removed_idx_of_batch_for_y: - batch_index = batch_idx_map[i] - idx_list[batch_index] = idx - 1 - y_list[batch_index] = y[i, :-1] - - batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()] - - # 只保留未生成完毕的序列 - if reserved_idx_of_batch_for_y is not None: - # index = torch.LongTensor(batch_idx_map).to(y.device) - y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) - if cache["y_emb"] is not None: - cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y) - if cache["k"] is not None: - for i in range(self.num_layers): - # 因为kv转置了,所以batch dim是1 - cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y) - cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y) - - if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: print("use early stop num:", early_stop_num) stop = True - - if not (None in idx_list): + + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) stop = True if stop: @@ -808,12 +443,6 @@ class Text2SemanticDecoder(nn.Module): xy_attn_mask = torch.zeros( (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device ) - - if (None in idx_list): - for i in range(x.shape[0]): - if idx_list[i] is None: - idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替 - if ref_free: - return y_list, [0]*x.shape[0] - return y_list, idx_list \ No newline at end of file + return y[:, :-1], 0 + return y[:, :-1], idx-1 diff --git a/AR/models/t2s_model_batch_only.py b/AR/models/t2s_model_batch_only.py deleted file mode 100644 index 8c31f12abc27f9100c71b092132826bf55cf3b0c..0000000000000000000000000000000000000000 --- a/AR/models/t2s_model_batch_only.py +++ /dev/null @@ -1,483 +0,0 @@ -# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py -import torch -from tqdm import tqdm - -from AR.models.utils import make_pad_mask -from AR.models.utils import ( - topk_sampling, - sample, - logits_to_probs, - multinomial_sample_one_no_sync, - dpo_loss, - make_reject_y, - get_batch_logps -) -from AR.modules.embedding import SinePositionalEmbedding -from AR.modules.embedding import TokenEmbedding -from AR.modules.transformer import LayerNorm -from AR.modules.transformer import TransformerEncoder -from AR.modules.transformer import TransformerEncoderLayer -from torch import nn -from torch.nn import functional as F -from torchmetrics.classification import MulticlassAccuracy - -default_config = { - "embedding_dim": 512, - "hidden_dim": 512, - "num_head": 8, - "num_layers": 12, - "num_codebook": 8, - "p_dropout": 0.0, - "vocab_size": 1024 + 1, - "phoneme_vocab_size": 512, - "EOS": 1024, -} - - -class Text2SemanticDecoder(nn.Module): - def __init__(self, config, norm_first=False, top_k=3): - super(Text2SemanticDecoder, self).__init__() - self.model_dim = config["model"]["hidden_dim"] - self.embedding_dim = config["model"]["embedding_dim"] - self.num_head = config["model"]["head"] - self.num_layers = config["model"]["n_layer"] - self.norm_first = norm_first - self.vocab_size = config["model"]["vocab_size"] - self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"] - self.p_dropout = config["model"]["dropout"] - self.EOS = config["model"]["EOS"] - self.norm_first = norm_first - assert self.EOS == self.vocab_size - 1 - # should be same as num of kmeans bin - # assert self.EOS == 1024 - self.bert_proj = nn.Linear(1024, self.embedding_dim) - self.ar_text_embedding = TokenEmbedding( - self.embedding_dim, self.phoneme_vocab_size, self.p_dropout - ) - self.ar_text_position = SinePositionalEmbedding( - self.embedding_dim, dropout=0.1, scale=False, alpha=True - ) - self.ar_audio_embedding = TokenEmbedding( - self.embedding_dim, self.vocab_size, self.p_dropout - ) - self.ar_audio_position = SinePositionalEmbedding( - self.embedding_dim, dropout=0.1, scale=False, alpha=True - ) - - self.h = TransformerEncoder( - TransformerEncoderLayer( - d_model=self.model_dim, - nhead=self.num_head, - dim_feedforward=self.model_dim * 4, - dropout=0.1, - batch_first=True, - norm_first=norm_first, - ), - num_layers=self.num_layers, - norm=LayerNorm(self.model_dim) if norm_first else None, - ) - - self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) - self.loss_fct = nn.CrossEntropyLoss(reduction="sum") - - self.ar_accuracy_metric = MulticlassAccuracy( - self.vocab_size, - top_k=top_k, - average="micro", - multidim_average="global", - ignore_index=self.EOS, - ) - - def make_input_data(self, x, x_lens, y, y_lens, bert_feature): - x = self.ar_text_embedding(x) - x = x + self.bert_proj(bert_feature.transpose(1, 2)) - x = self.ar_text_position(x) - x_mask = make_pad_mask(x_lens) - - y_mask = make_pad_mask(y_lens) - y_mask_int = y_mask.type(torch.int64) - codes = y.type(torch.int64) * (1 - y_mask_int) - - # Training - # AR Decoder - y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS) - x_len = x_lens.max() - y_len = y_lens.max() - y_emb = self.ar_audio_embedding(y) - y_pos = self.ar_audio_position(y_emb) - - xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) - - ar_xy_padding_mask = xy_padding_mask - - x_attn_mask = F.pad( - torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), - (0, y_len), - value=True, - ) - - y_attn_mask = F.pad( - torch.triu( - torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), - diagonal=1, - ), - (x_len, 0), - value=False, - ) - - xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) - bsz, src_len = x.shape[0], x_len + y_len - _xy_padding_mask = ( - ar_xy_padding_mask.view(bsz, 1, 1, src_len) - .expand(-1, self.num_head, -1, -1) - .reshape(bsz * self.num_head, 1, src_len) - ) - xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) - new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) - new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) - xy_attn_mask = new_attn_mask - # x 和完整的 y 一次性输入模型 - xy_pos = torch.concat([x, y_pos], dim=1) - - return xy_pos, xy_attn_mask, targets - - def forward(self, x, x_lens, y, y_lens, bert_feature): - """ - x: phoneme_ids - y: semantic_ids - """ - - reject_y, reject_y_lens = make_reject_y(y, y_lens) - - xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature) - - xy_dec, _ = self.h( - (xy_pos, None), - mask=xy_attn_mask, - ) - x_len = x_lens.max() - logits = self.ar_predict_layer(xy_dec[:, x_len:]) - - ###### DPO ############# - reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature) - - reject_xy_dec, _ = self.h( - (reject_xy_pos, None), - mask=reject_xy_attn_mask, - ) - x_len = x_lens.max() - reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:]) - - # loss - # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum - - loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum") - acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item() - - A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets) - loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True) - - loss = loss_1 + loss_2 - - return loss, acc - - def forward_old(self, x, x_lens, y, y_lens, bert_feature): - """ - x: phoneme_ids - y: semantic_ids - """ - x = self.ar_text_embedding(x) - x = x + self.bert_proj(bert_feature.transpose(1, 2)) - x = self.ar_text_position(x) - x_mask = make_pad_mask(x_lens) - - y_mask = make_pad_mask(y_lens) - y_mask_int = y_mask.type(torch.int64) - codes = y.type(torch.int64) * (1 - y_mask_int) - - # Training - # AR Decoder - y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS) - x_len = x_lens.max() - y_len = y_lens.max() - y_emb = self.ar_audio_embedding(y) - y_pos = self.ar_audio_position(y_emb) - - xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) - ar_xy_padding_mask = xy_padding_mask - - x_attn_mask = F.pad( - torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), - (0, y_len), - value=True, - ) - y_attn_mask = F.pad( - torch.triu( - torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), - diagonal=1, - ), - (x_len, 0), - value=False, - ) - xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) - bsz, src_len = x.shape[0], x_len + y_len - _xy_padding_mask = ( - ar_xy_padding_mask.view(bsz, 1, 1, src_len) - .expand(-1, self.num_head, -1, -1) - .reshape(bsz * self.num_head, 1, src_len) - ) - xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) - new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) - new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) - xy_attn_mask = new_attn_mask - # x 和完整的 y 一次性输入模型 - xy_pos = torch.concat([x, y_pos], dim=1) - xy_dec, _ = self.h( - (xy_pos, None), - mask=xy_attn_mask, - ) - logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) - # loss - # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum - loss = F.cross_entropy(logits, targets, reduction="sum") - acc = self.ar_accuracy_metric(logits.detach(), targets).item() - return loss, acc - - # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么 - def infer( - self, - x, - x_lens, - prompts, - bert_feature, - top_k: int = -100, - early_stop_num: int = -1, - temperature: float = 1.0, - ): - x = self.ar_text_embedding(x) - x = x + self.bert_proj(bert_feature.transpose(1, 2)) - x = self.ar_text_position(x) - - # AR Decoder - y = prompts - prefix_len = y.shape[1] - x_len = x.shape[1] - x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) - stop = False - for _ in tqdm(range(1500)): - y_emb = self.ar_audio_embedding(y) - y_pos = self.ar_audio_position(y_emb) - # x 和逐渐增长的 y 一起输入给模型 - xy_pos = torch.concat([x, y_pos], dim=1) - y_len = y.shape[1] - x_attn_mask_pad = F.pad( - x_attn_mask, - (0, y_len), - value=True, - ) - y_attn_mask = F.pad( - torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), - (x_len, 0), - value=False, - ) - xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( - y.device - ) - - xy_dec, _ = self.h( - (xy_pos, None), - mask=xy_attn_mask, - ) - logits = self.ar_predict_layer(xy_dec[:, -1]) - samples = topk_sampling( - logits, top_k=top_k, top_p=1.0, temperature=temperature - ) - - if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: - print("use early stop num:", early_stop_num) - stop = True - - if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: - # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) - stop = True - if stop: - if prompts.shape[1] == y.shape[1]: - y = torch.concat([y, torch.zeros_like(samples)], dim=1) - print("bad zero prediction") - print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") - break - # 本次生成的 semantic_ids 和之前的 y 构成新的 y - # print(samples.shape)#[1,1]#第一个1是bs - # import os - # os._exit(2333) - y = torch.concat([y, samples], dim=1) - return y - - def pad_y_eos(self, y, y_mask_int, eos_id): - targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad( - y_mask_int, (0, 1), value=1 - ) - # 错位 - return targets[:, :-1], targets[:, 1:] - - def infer_panel( - self, - x, #####全部文本token - x_lens, - prompts, ####参考音频token - bert_feature, - top_k: int = -100, - top_p: int = 100, - early_stop_num: int = -1, - temperature: float = 1.0, - ): - x = self.ar_text_embedding(x) - x = x + self.bert_proj(bert_feature.transpose(1, 2)) - x = self.ar_text_position(x) - - # AR Decoder - y = prompts - - x_len = x.shape[1] - x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) - stop = False - # print(1111111,self.num_layers) - cache = { - "all_stage": self.num_layers, - "k": [None] * self.num_layers, ###根据配置自己手写 - "v": [None] * self.num_layers, - # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了 - "y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行 - # "logits":None,###原版就已经只对结尾求再拼接了,不用管 - # "xy_dec":None,###不需要,本来只需要最后一个做logits - "first_infer": 1, - "stage": 0, - } - ################### first step ########################## - if y is not None: - y_emb = self.ar_audio_embedding(y) - y_len = y_emb.shape[1] - prefix_len = y.shape[1] - y_pos = self.ar_audio_position(y_emb) - xy_pos = torch.concat([x, y_pos], dim=1) - cache["y_emb"] = y_emb - ref_free = False - else: - y_emb = None - y_len = 0 - prefix_len = 0 - y_pos = None - xy_pos = x - y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device) - ref_free = True - - x_attn_mask_pad = F.pad( - x_attn_mask, - (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) - value=True, - ) - y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) - torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), - (x_len, 0), - value=False, - ) - xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( - x.device - ) - - y_list = [None]*y.shape[0] - batch_idx_map = list(range(y.shape[0])) - idx_list = [None]*y.shape[0] - for idx in tqdm(range(1500)): - - xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache) - logits = self.ar_predict_layer( - xy_dec[:, -1] - ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的 - # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature) - if(idx==0):###第一次跑不能EOS否则没有了 - logits = logits[:, :-1] ###刨除1024终止符号的概率 - samples = sample( - logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature - )[0] - # 本次生成的 semantic_ids 和之前的 y 构成新的 y - # print(samples.shape)#[1,1]#第一个1是bs - y = torch.concat([y, samples], dim=1) - - # 移除已经生成完毕的序列 - reserved_idx_of_batch_for_y = None - if (self.EOS in torch.argmax(logits, dim=-1)) or \ - (self.EOS in samples[:, 0]): ###如果生成到EOS,则停止 - l = samples[:, 0]==self.EOS - removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist() - reserved_idx_of_batch_for_y = torch.where(l==False)[0] - # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y] - for i in removed_idx_of_batch_for_y: - batch_index = batch_idx_map[i] - idx_list[batch_index] = idx - 1 - y_list[batch_index] = y[i, :-1] - - batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()] - - # 只保留未生成完毕的序列 - if reserved_idx_of_batch_for_y is not None: - # index = torch.LongTensor(batch_idx_map).to(y.device) - y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) - if cache["y_emb"] is not None: - cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y) - if cache["k"] is not None: - for i in range(self.num_layers): - # 因为kv转置了,所以batch dim是1 - cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y) - cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y) - - - if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: - print("use early stop num:", early_stop_num) - stop = True - - if not (None in idx_list): - # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) - stop = True - if stop: - # if prompts.shape[1] == y.shape[1]: - # y = torch.concat([y, torch.zeros_like(samples)], dim=1) - # print("bad zero prediction") - if y.shape[1]==0: - y = torch.concat([y, torch.zeros_like(samples)], dim=1) - print("bad zero prediction") - print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") - break - - ####################### update next step ################################### - cache["first_infer"] = 0 - if cache["y_emb"] is not None: - y_emb = torch.cat( - [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1 - ) - cache["y_emb"] = y_emb - y_pos = self.ar_audio_position(y_emb) - xy_pos = y_pos[:, -1:] - else: - y_emb = self.ar_audio_embedding(y[:, -1:]) - cache["y_emb"] = y_emb - y_pos = self.ar_audio_position(y_emb) - xy_pos = y_pos - y_len = y_pos.shape[1] - - ###最右边一列(是错的) - # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device) - # xy_attn_mask[:,-1]=False - ###最下面一行(是对的) - xy_attn_mask = torch.zeros( - (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device - ) - - if (None in idx_list): - for i in range(x.shape[0]): - if idx_list[i] is None: - idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替 - - if ref_free: - return y_list, [0]*x.shape[0] - return y_list, idx_list diff --git a/AR/models/utils.py b/AR/models/utils.py index ce0a98b77bda70c1cf018f92d8d504c90346f7cb..9678c7e13d81a6e885187c86d9d11e49cf707957 100644 --- a/AR/models/utils.py +++ b/AR/models/utils.py @@ -115,17 +115,17 @@ def logits_to_probs( top_p: Optional[int] = None, repetition_penalty: float = 1.0, ): - # if previous_tokens is not None: - # previous_tokens = previous_tokens.squeeze() + if previous_tokens is not None: + previous_tokens = previous_tokens.squeeze() # print(logits.shape,previous_tokens.shape) # pdb.set_trace() if previous_tokens is not None and repetition_penalty != 1.0: previous_tokens = previous_tokens.long() - score = torch.gather(logits, dim=1, index=previous_tokens) + score = torch.gather(logits, dim=0, index=previous_tokens) score = torch.where( score < 0, score * repetition_penalty, score / repetition_penalty ) - logits.scatter_(dim=1, index=previous_tokens, src=score) + logits.scatter_(dim=0, index=previous_tokens, src=score) if top_p is not None and top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) @@ -133,9 +133,9 @@ def logits_to_probs( torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1 ) sorted_indices_to_remove = cum_probs > top_p - sorted_indices_to_remove[:, 0] = False # keep at least one option + sorted_indices_to_remove[0] = False # keep at least one option indices_to_remove = sorted_indices_to_remove.scatter( - dim=1, index=sorted_indices, src=sorted_indices_to_remove + dim=0, index=sorted_indices, src=sorted_indices_to_remove ) logits = logits.masked_fill(indices_to_remove, -float("Inf")) @@ -143,7 +143,7 @@ def logits_to_probs( if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - pivot = v[: , -1].unsqueeze(-1) + pivot = v.select(-1, -1).unsqueeze(-1) logits = torch.where(logits < pivot, -float("Inf"), logits) probs = torch.nn.functional.softmax(logits, dim=-1) diff --git a/AR/modules/__pycache__/__init__.cpython-310.pyc b/AR/modules/__pycache__/__init__.cpython-310.pyc index d2a84a5de2875a90ca49a0367fa7307bdd4438be..e585b39e850271ef14998dbbd5e2e17a512c40ef 100644 Binary files a/AR/modules/__pycache__/__init__.cpython-310.pyc and b/AR/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/AR/modules/__pycache__/activation.cpython-310.pyc b/AR/modules/__pycache__/activation.cpython-310.pyc index 41aef9ee88e40a568adfa39258043e42d7016642..1c33e316f187162f6409d9ac000617bb3e4e7394 100644 Binary files a/AR/modules/__pycache__/activation.cpython-310.pyc and b/AR/modules/__pycache__/activation.cpython-310.pyc differ diff --git a/AR/modules/__pycache__/embedding.cpython-310.pyc b/AR/modules/__pycache__/embedding.cpython-310.pyc index 968f5aae0425e75677848f9288c35ddb4bee6e56..ef5eb2ed67c8058690d1f383c71980eb1b6f3738 100644 Binary files a/AR/modules/__pycache__/embedding.cpython-310.pyc and b/AR/modules/__pycache__/embedding.cpython-310.pyc differ diff --git a/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc b/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc index b4b7fcfcef8e08513237f6f380b507c475a35c75..e4e1efe76ef159b2ec635f2e3ea478503de5b4ae 100644 Binary files a/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc and b/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc differ diff --git a/AR/modules/__pycache__/optim.cpython-310.pyc b/AR/modules/__pycache__/optim.cpython-310.pyc index 28595d4b85b39cab763e4dab573370f0292358e0..e872d2f54611e688d249bf4f63ae8f34f7d8c929 100644 Binary files a/AR/modules/__pycache__/optim.cpython-310.pyc and b/AR/modules/__pycache__/optim.cpython-310.pyc differ diff --git a/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc b/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc index 2814866038e90deb90412810a0d99e19864c6e2d..a1371ca2b0915268c398b458177524fce6867ba8 100644 Binary files a/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc and b/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc differ diff --git a/AR/modules/__pycache__/scaling.cpython-310.pyc b/AR/modules/__pycache__/scaling.cpython-310.pyc index c0b03815a14b7c1d32d936b6c46faca8f0730e6e..34aeb4de13c9c04cac0c479c7141663a5f091128 100644 Binary files a/AR/modules/__pycache__/scaling.cpython-310.pyc and b/AR/modules/__pycache__/scaling.cpython-310.pyc differ diff --git a/AR/modules/__pycache__/transformer.cpython-310.pyc b/AR/modules/__pycache__/transformer.cpython-310.pyc index 74c8cbbbab08be9b34581bf3e89ed7353c788501..265b3d108b65acc5f8bc579cb5414d503944ead8 100644 Binary files a/AR/modules/__pycache__/transformer.cpython-310.pyc and b/AR/modules/__pycache__/transformer.cpython-310.pyc differ diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml deleted file mode 100644 index d1fd8c2b25d497e0cae36dec5a7fde83782dd4dd..0000000000000000000000000000000000000000 --- a/GPT_SoVITS/configs/tts_infer.yaml +++ /dev/null @@ -1,16 +0,0 @@ -custom: - bert_base_path: pretrained_models/chinese-roberta-wwm-ext-large - cnhuhbert_base_path: pretrained_models/chinese-hubert-base - device: cpu - flash_attn_enabled: true - is_half: false - t2s_weights_path: /content/TTS_OWN/MODELS/22/22.ckpt - vits_weights_path: /content/TTS_OWN/MODELS/22/22.pth -default: - bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large - cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base - device: cpu - flash_attn_enabled: true - is_half: false - t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt - vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth diff --git a/TTS_infer_pack/TTS.py b/TTS_infer_pack/TTS.py deleted file mode 100644 index 6f46cbd23109ed9917aeabef6ec7af6bf0b0f235..0000000000000000000000000000000000000000 --- a/TTS_infer_pack/TTS.py +++ /dev/null @@ -1,848 +0,0 @@ -from copy import deepcopy -import math -import os, sys -import random -import traceback -now_dir = os.getcwd() -sys.path.append(now_dir) -import ffmpeg -import os -from typing import Generator, List, Union -import numpy as np -import torch -import torch.nn.functional as F -import yaml -from transformers import AutoModelForMaskedLM, AutoTokenizer -from timeit import default_timer as timer - -from AR.models.t2s_lightning_module import Text2SemanticLightningModule -from feature_extractor.cnhubert import CNHubert -from module.models import SynthesizerTrn -import librosa -from time import time as ttime -#from tools.i18n.i18n import I18nAuto -from my_utils import load_audio -from module.mel_processing import spectrogram_torch -from TTS_infer_pack.text_segmentation_method import splits -from TTS_infer_pack.TextPreprocessor import TextPreprocessor -#i18n = I18nAuto() -c1='' - -# configs/tts_infer.yaml -""" -default: - device: cpu - is_half: false - bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large - cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base - t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt - vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth - flash_attn_enabled: true - -custom: - device: cuda - is_half: true - bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large - cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base - t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt - vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth - flash_attn_enabled: true - - -""" - -# def set_seed(seed): -# random.seed(seed) -# os.environ['PYTHONHASHSEED'] = str(seed) -# np.random.seed(seed) -# torch.manual_seed(seed) -# torch.cuda.manual_seed(seed) -# torch.cuda.manual_seed_all(seed) -# torch.backends.cudnn.deterministic = True -# torch.backends.cudnn.benchmark = False -# torch.backends.cudnn.enabled = True -# set_seed(1234) - -class TTS_Config: - default_configs={ - "device": "cpu", - "is_half": False, - "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", - "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth", - "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", - "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", - "flash_attn_enabled": True - } - configs:dict = None - def __init__(self, configs: Union[dict, str]=None): - - # 设置默认配置文件路径 - configs_base_path:str = "GPT_SoVITS/configs/" - os.makedirs(configs_base_path, exist_ok=True) - self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml") - - if configs in ["", None]: - if not os.path.exists(self.configs_path): - self.save_configs() - print(f"Create default config file at {self.configs_path}") - configs:dict = {"default": deepcopy(self.default_configs)} - - if isinstance(configs, str): - self.configs_path = configs - configs:dict = self._load_configs(self.configs_path) - - assert isinstance(configs, dict) - default_configs:dict = configs.get("default", None) - if default_configs is not None: - self.default_configs = default_configs - - self.configs:dict = configs.get("custom", deepcopy(self.default_configs)) - - - self.device = self.configs.get("device", torch.device("cpu")) - self.is_half = self.configs.get("is_half", False) - self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True) - self.t2s_weights_path = self.configs.get("t2s_weights_path", None) - self.vits_weights_path = self.configs.get("vits_weights_path", None) - self.bert_base_path = self.configs.get("bert_base_path", None) - self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None) - - - if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)): - self.t2s_weights_path = self.default_configs['t2s_weights_path'] - print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}") - if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)): - self.vits_weights_path = self.default_configs['vits_weights_path'] - print(f"fall back to default vits_weights_path: {self.vits_weights_path}") - if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)): - self.bert_base_path = self.default_configs['bert_base_path'] - print(f"fall back to default bert_base_path: {self.bert_base_path}") - if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)): - self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path'] - print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}") - self.update_configs() - - - self.max_sec = None - self.hz:int = 50 - self.semantic_frame_rate:str = "25hz" - self.segment_size:int = 20480 - self.filter_length:int = 2048 - self.sampling_rate:int = 32000 - self.hop_length:int = 640 - self.win_length:int = 2048 - self.n_speakers:int = 300 - - self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] - # print(self) - - def _load_configs(self, configs_path: str)->dict: - with open(configs_path, 'r') as f: - configs = yaml.load(f, Loader=yaml.FullLoader) - - return configs - - def save_configs(self, configs_path:str=None)->None: - configs={ - "default":self.default_configs, - } - if self.configs is not None: - configs["custom"] = self.update_configs() - - if configs_path is None: - configs_path = self.configs_path - with open(configs_path, 'w') as f: - yaml.dump(configs, f) - - def update_configs(self): - self.config = { - "device" : str(self.device), - "is_half" : self.is_half, - "t2s_weights_path" : self.t2s_weights_path, - "vits_weights_path" : self.vits_weights_path, - "bert_base_path" : self.bert_base_path, - "cnhuhbert_base_path": self.cnhuhbert_base_path, - "flash_attn_enabled" : self.flash_attn_enabled - } - return self.config - - def __str__(self): - self.configs = self.update_configs() - string = "TTS Config".center(100, '-') + '\n' - for k, v in self.configs.items(): - string += f"{str(k).ljust(20)}: {str(v)}\n" - string += "-" * 100 + '\n' - return string - - def __repr__(self): - return self.__str__() - - -class TTS: - def __init__(self, configs: Union[dict, str, TTS_Config]): - if isinstance(configs, TTS_Config): - self.configs = configs - else: - self.configs:TTS_Config = TTS_Config(configs) - - self.t2s_model:Text2SemanticLightningModule = None - self.vits_model:SynthesizerTrn = None - self.bert_tokenizer:AutoTokenizer = None - self.bert_model:AutoModelForMaskedLM = None - self.cnhuhbert_model:CNHubert = None - - self._init_models() - - self.text_preprocessor:TextPreprocessor = \ - TextPreprocessor(self.bert_model, - self.bert_tokenizer, - self.configs.device) - - - self.prompt_cache:dict = { - "ref_audio_path":None, - "prompt_semantic":None, - "refer_spepc":None, - "prompt_text":None, - "prompt_lang":None, - "phones":None, - "bert_features":None, - "norm_text":None, - } - - - self.stop_flag:bool = False - self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32 - - def _init_models(self,): - self.init_t2s_weights(self.configs.t2s_weights_path) - self.init_vits_weights(self.configs.vits_weights_path) - self.init_bert_weights(self.configs.bert_base_path) - self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path) - # self.enable_half_precision(self.configs.is_half) - - - - def init_cnhuhbert_weights(self, base_path: str): - print(f"Loading CNHuBERT weights from {base_path}") - self.cnhuhbert_model = CNHubert(base_path) - self.cnhuhbert_model=self.cnhuhbert_model.eval() - self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device) - if self.configs.is_half: - self.cnhuhbert_model = self.cnhuhbert_model.half() - - - - def init_bert_weights(self, base_path: str): - print(f"Loading BERT weights from {base_path}") - self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path) - self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path) - self.bert_model=self.bert_model.eval() - self.bert_model = self.bert_model.to(self.configs.device) - if self.configs.is_half: - self.bert_model = self.bert_model.half() - - - - def init_vits_weights(self, weights_path: str): - - print(f"Loading VITS weights from {weights_path}") - self.configs.vits_weights_path = weights_path - self.configs.save_configs() - dict_s2 = torch.load(weights_path, map_location=self.configs.device) - hps = dict_s2["config"] - self.configs.filter_length = hps["data"]["filter_length"] - self.configs.segment_size = hps["train"]["segment_size"] - self.configs.sampling_rate = hps["data"]["sampling_rate"] - self.configs.hop_length = hps["data"]["hop_length"] - self.configs.win_length = hps["data"]["win_length"] - self.configs.n_speakers = hps["data"]["n_speakers"] - self.configs.semantic_frame_rate = "25hz" - kwargs = hps["model"] - vits_model = SynthesizerTrn( - self.configs.filter_length // 2 + 1, - self.configs.segment_size // self.configs.hop_length, - n_speakers=self.configs.n_speakers, - **kwargs - ) - # if ("pretrained" not in weights_path): - if hasattr(vits_model, "enc_q"): - del vits_model.enc_q - - vits_model = vits_model.to(self.configs.device) - vits_model = vits_model.eval() - vits_model.load_state_dict(dict_s2["weight"], strict=False) - self.vits_model = vits_model - if self.configs.is_half: - self.vits_model = self.vits_model.half() - - - def init_t2s_weights(self, weights_path: str): - print(f"Loading Text2Semantic weights from {weights_path}") - self.configs.t2s_weights_path = weights_path - self.configs.save_configs() - self.configs.hz = 50 - dict_s1 = torch.load(weights_path, map_location=self.configs.device) - config = dict_s1["config"] - self.configs.max_sec = config["data"]["max_sec"] - t2s_model = Text2SemanticLightningModule(config, "****", is_train=False, - flash_attn_enabled=self.configs.flash_attn_enabled) - t2s_model.load_state_dict(dict_s1["weight"]) - t2s_model = t2s_model.to(self.configs.device) - t2s_model = t2s_model.eval() - self.t2s_model = t2s_model - if self.configs.is_half: - self.t2s_model = self.t2s_model.half() - - def enable_half_precision(self, enable: bool = True): - ''' - To enable half precision for the TTS model. - Args: - enable: bool, whether to enable half precision. - - ''' - if self.configs.device == "cpu" and enable: - print("Half precision is not supported on CPU.") - return - - self.configs.is_half = enable - self.precison = torch.float16 if enable else torch.float32 - self.configs.save_configs() - if enable: - if self.t2s_model is not None: - self.t2s_model =self.t2s_model.half() - if self.vits_model is not None: - self.vits_model = self.vits_model.half() - if self.bert_model is not None: - self.bert_model =self.bert_model.half() - if self.cnhuhbert_model is not None: - self.cnhuhbert_model = self.cnhuhbert_model.half() - else: - if self.t2s_model is not None: - self.t2s_model = self.t2s_model.float() - if self.vits_model is not None: - self.vits_model = self.vits_model.float() - if self.bert_model is not None: - self.bert_model = self.bert_model.float() - if self.cnhuhbert_model is not None: - self.cnhuhbert_model = self.cnhuhbert_model.float() - - def set_device(self, device: torch.device): - ''' - To set the device for all models. - Args: - device: torch.device, the device to use for all models. - ''' - self.configs.device = device - self.configs.save_configs() - if self.t2s_model is not None: - self.t2s_model = self.t2s_model.to(device) - if self.vits_model is not None: - self.vits_model = self.vits_model.to(device) - if self.bert_model is not None: - self.bert_model = self.bert_model.to(device) - if self.cnhuhbert_model is not None: - self.cnhuhbert_model = self.cnhuhbert_model.to(device) - - def set_ref_audio(self, ref_audio_path:str): - ''' - To set the reference audio for the TTS model, - including the prompt_semantic and refer_spepc. - Args: - ref_audio_path: str, the path of the reference audio. - ''' - self._set_prompt_semantic(ref_audio_path) - self._set_ref_spepc(ref_audio_path) - - def _set_ref_spepc(self, ref_audio_path): - audio = load_audio(ref_audio_path, int(self.configs.sampling_rate)) - audio = torch.FloatTensor(audio) - audio_norm = audio - audio_norm = audio_norm.unsqueeze(0) - spec = spectrogram_torch( - audio_norm, - self.configs.filter_length, - self.configs.sampling_rate, - self.configs.hop_length, - self.configs.win_length, - center=False, - ) - spec = spec.to(self.configs.device) - if self.configs.is_half: - spec = spec.half() - # self.refer_spepc = spec - self.prompt_cache["refer_spepc"] = spec - - - def _set_prompt_semantic(self, ref_wav_path:str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000): - raise OSError("参考音频在3~10秒范围外,请更换!") - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))[ - "last_hidden_state" - ].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - self.prompt_cache["prompt_semantic"] = prompt_semantic - - def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None): - seq = sequences[0] - ndim = seq.dim() - if axis < 0: - axis += ndim - dtype:torch.dtype = seq.dtype - pad_value = torch.tensor(pad_value, dtype=dtype) - seq_lengths = [seq.shape[axis] for seq in sequences] - if max_length is None: - max_length = max(seq_lengths) - else: - max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length - - padded_sequences = [] - for seq, length in zip(sequences, seq_lengths): - padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1) - padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value) - padded_sequences.append(padded_seq) - batch = torch.stack(padded_sequences) - return batch - - def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold:float=0.75, split_bucket:bool=True): - - _data:list = [] - index_and_len_list = [] - for idx, item in enumerate(data): - norm_text_len = len(item["norm_text"]) - index_and_len_list.append([idx, norm_text_len]) - - batch_index_list = [] - if split_bucket: - index_and_len_list.sort(key=lambda x: x[1]) - index_and_len_list = np.array(index_and_len_list, dtype=np.int64) - - batch_index_list_len = 0 - pos = 0 - while pos =threshold) or (pos_end-pos==1): - batch_index=index_and_len_list[pos:pos_end, 0].tolist() - batch_index_list_len += len(batch_index) - batch_index_list.append(batch_index) - pos = pos_end - break - pos_end=pos_end-1 - - assert batch_index_list_len == len(data) - - else: - for i in range(len(data)): - if i%batch_size == 0: - batch_index_list.append([]) - batch_index_list[-1].append(i) - - - for batch_idx, index_list in enumerate(batch_index_list): - item_list = [data[idx] for idx in index_list] - phones_list = [] - phones_len_list = [] - # bert_features_list = [] - all_phones_list = [] - all_phones_len_list = [] - all_bert_features_list = [] - norm_text_batch = [] - bert_max_len = 0 - phones_max_len = 0 - for item in item_list: - if prompt_data is not None: - all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ - .to(dtype=self.precison) - all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]) - phones = torch.LongTensor(item["phones"]) - # norm_text = prompt_data["norm_text"]+item["norm_text"] - else: - all_bert_features = item["bert_features"]\ - .to(dtype=self.precison) - phones = torch.LongTensor(item["phones"]) - all_phones = phones - # norm_text = item["norm_text"] - - bert_max_len = max(bert_max_len, all_bert_features.shape[-1]) - phones_max_len = max(phones_max_len, phones.shape[-1]) - - phones_list.append(phones) - phones_len_list.append(phones.shape[-1]) - all_phones_list.append(all_phones) - all_phones_len_list.append(all_phones.shape[-1]) - all_bert_features_list.append(all_bert_features) - norm_text_batch.append(item["norm_text"]) - - phones_batch = phones_list - max_len = max(bert_max_len, phones_max_len) - # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) - all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) - # all_bert_features_batch = all_bert_features_list - all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison) - for idx, item in enumerate(all_bert_features_list): - all_bert_features_batch[idx, :, : item.shape[-1]] = item - - batch = { - "phones": phones_batch, - "phones_len": torch.LongTensor(phones_len_list), - "all_phones": all_phones_batch, - "all_phones_len": torch.LongTensor(all_phones_len_list), - "all_bert_features": all_bert_features_batch, - "norm_text": norm_text_batch - } - _data.append(batch) - - return _data, batch_index_list - - def recovery_order(self, data:list, batch_index_list:list)->list: - ''' - Recovery the order of the audio according to the batch_index_list. - - Args: - data (List[list(np.ndarray)]): the out of order audio . - batch_index_list (List[list[int]]): the batch index list. - - Returns: - list (List[np.ndarray]): the data in the original order. - ''' - lenght = len(sum(batch_index_list, [])) - _data = [None]*lenght - for i, index_list in enumerate(batch_index_list): - for j, index in enumerate(index_list): - _data[index] = data[i][j] - return _data - - def stop(self,): - ''' - Stop the inference process. - ''' - self.stop_flag = True - - - def run(self, inputs:dict): - """ - Text to speech inference. - - Args: - inputs (dict): - { - "text": "", # str. text to be synthesized - "text_lang: "", # str. language of the text to be synthesized - "ref_audio_path": "", # str. reference audio path - "prompt_text": "", # str. prompt text for the reference audio - "prompt_lang": "", # str. language of the prompt text for the reference audio - "top_k": 5, # int. top k sampling - "top_p": 1, # float. top p sampling - "temperature": 1, # float. temperature for sampling - "text_split_method": "", # str. text split method, see text_segmentaion_method.py for details. - "batch_size": 1, # int. batch size for inference - "batch_threshold": 0.75, # float. threshold for batch splitting. - "split_bucket: True, # bool. whether to split the batch into multiple buckets. - "return_fragment": False, # bool. step by step return the audio fragment. - "speed_factor":1.0, # float. control the speed of the synthesized audio. - } - returns: - tulpe[int, np.ndarray]: sampling rate and audio data. - """ - global c1 - c1=timer() - ########## variables initialization ########### - self.stop_flag:bool = False - text:str = inputs.get("text", "") - text_lang:str = inputs.get("text_lang", "") - ref_audio_path:str = inputs.get("ref_audio_path", "") - prompt_text:str = inputs.get("prompt_text", "") - prompt_lang:str = inputs.get("prompt_lang", "") - top_k:int = inputs.get("top_k", 5) - top_p:float = inputs.get("top_p", 1) - temperature:float = inputs.get("temperature", 1) - text_split_method:str = inputs.get("text_split_method", "") - batch_size = inputs.get("batch_size", 1) - batch_threshold = inputs.get("batch_threshold", 0.75) - speed_factor = inputs.get("speed_factor", 1.0) - split_bucket = inputs.get("split_bucket", True) - volume = inputs.get("volume", 1.0) - return_fragment = inputs.get("return_fragment", False) - - if return_fragment: - split_bucket = False - print("分段返回模式已开启") - if split_bucket: - split_bucket = False - print("分段返回模式不支持分桶处理,已自动关闭分桶处理") - - if split_bucket: - print("分桶处理模式已开启") - - - no_prompt_text = False - if prompt_text in [None, ""]: - no_prompt_text = True - - assert text_lang in self.configs.langauges - if not no_prompt_text: - assert prompt_lang in self.configs.langauges - - if ref_audio_path in [None, ""] and \ - ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spepc"] is None)): - raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") - - - ###### setting reference audio and prompt text preprocessing ######## - t0 = ttime() - if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]): - self.set_ref_audio(ref_audio_path) - - if not no_prompt_text: - prompt_text = prompt_text.strip("\n") - if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "." - print("实际输入的参考文本:", prompt_text) - if self.prompt_cache["prompt_text"] != prompt_text: - self.prompt_cache["prompt_text"] = prompt_text - self.prompt_cache["prompt_lang"] = prompt_lang - phones, bert_features, norm_text = \ - self.text_preprocessor.segment_and_extract_feature_for_text( - prompt_text, - prompt_lang) - self.prompt_cache["phones"] = phones - self.prompt_cache["bert_features"] = bert_features - self.prompt_cache["norm_text"] = norm_text - - - ###### text preprocessing ######## - data = self.text_preprocessor.preprocess(text, text_lang, text_split_method) - if len(data) == 0: - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3), - dtype=np.int16) - return - - t1 = ttime() - data, batch_index_list = self.to_batch(data, - prompt_data=self.prompt_cache if not no_prompt_text else None, - batch_size=batch_size, - threshold=batch_threshold, - split_bucket=split_bucket - ) - t2 = ttime() - try: - print("############ 推理 ############") - ###### inference ###### - t_34 = 0.0 - t_45 = 0.0 - audio = [] - for item in data: - t3 = ttime() - batch_phones = item["phones"] - batch_phones_len = item["phones_len"] - all_phoneme_ids = item["all_phones"] - all_phoneme_lens = item["all_phones_len"] - all_bert_features = item["all_bert_features"] - norm_text = item["norm_text"] - - # batch_phones = batch_phones.to(self.configs.device) - batch_phones_len = batch_phones_len.to(self.configs.device) - all_phoneme_ids = all_phoneme_ids.to(self.configs.device) - all_phoneme_lens = all_phoneme_lens.to(self.configs.device) - all_bert_features = all_bert_features.to(self.configs.device) - if self.configs.is_half: - all_bert_features = all_bert_features.half() - - print("前端处理后的文本(每句):", norm_text) - if no_prompt_text : - prompt = None - else: - prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device) - - with torch.no_grad(): - pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_lens, - prompt, - all_bert_features, - # prompt_phone_len=ph_offset, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=self.configs.hz * self.configs.max_sec, - ) - t4 = ttime() - t_34 += t4 - t3 - - refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\ - .to(dtype=self.precison, device=self.configs.device) - - batch_audio_fragment = [] - - # ## vits并行推理 method 1 - # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) - # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) - # max_len = 0 - # for i in range(0, len(batch_phones)): - # max_len = max(max_len, batch_phones[i].shape[-1]) - # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) - # batch_phones = batch_phones.to(self.configs.device) - # batch_audio_fragment = (self.vits_model.batched_decode( - # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc - # )) - - # ## vits并行推理 method 2 - pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - upsample_rate = math.prod(self.vits_model.upsample_rates) - audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] - audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] - all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) - _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - _batch_audio_fragment = (self.vits_model.decode( - all_pred_semantic, _batch_phones,refer_audio_spepc - ).detach()[0, 0, :]) - audio_frag_end_idx.insert(0, 0) - batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] - - - # ## vits串行推理 - # for i, idx in enumerate(idx_list): - # phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - # _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 - # audio_fragment =(self.vits_model.decode( - # _pred_semantic, phones, refer_audio_spepc - # ).detach()[0, 0, :]) - # batch_audio_fragment.append( - # audio_fragment - # ) ###试试重建不带上prompt部分 - - t5 = ttime() - t_45 += t5 - t4 - if return_fragment: - print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) - yield self.audio_postprocess([batch_audio_fragment], - self.configs.sampling_rate, - batch_index_list, - speed_factor, - split_bucket,volume) - else: - audio.append(batch_audio_fragment) - - if self.stop_flag: - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3), - dtype=np.int16) - return - - if not return_fragment: - print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) - yield self.audio_postprocess(audio, - self.configs.sampling_rate, - batch_index_list, - speed_factor, - split_bucket,volume) - except Exception as e: - traceback.print_exc() - # 必须返回一个空音频, 否则会导致显存不释放。 - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), - dtype=np.int16) - # 重置模型, 否则会导致显存释放不完全。 - del self.t2s_model - del self.vits_model - self.t2s_model = None - self.vits_model = None - self.init_t2s_weights(self.configs.t2s_weights_path) - self.init_vits_weights(self.configs.vits_weights_path) - finally: - self.empty_cache() - - def empty_cache(self): - try: - if str(self.configs.device) == "cuda": - torch.cuda.empty_cache() - elif str(self.configs.device) == "mps": - torch.mps.empty_cache() - except: - pass - - def audio_postprocess(self, - audio:List[torch.Tensor], - sr:int, - batch_index_list:list=None, - speed_factor:float=1.0, - split_bucket:bool=True, - volume: float = 1.0)->tuple[int, np.ndarray]: - zero_wav = torch.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=self.precison, - device=self.configs.device - ) - - for i, batch in enumerate(audio): - for j, audio_fragment in enumerate(batch): - max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 - if max_audio>1: audio_fragment/=max_audio - audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) - audio_fragment = audio_fragment * volume - audio[i][j] = audio_fragment.cpu().numpy() - - - if split_bucket: - audio = self.recovery_order(audio, batch_index_list) - else: - # audio = [item for batch in audio for item in batch] - audio = sum(audio, []) - - - audio = np.concatenate(audio, 0) - audio = (audio * 32768).astype(np.int16) - - try: - if speed_factor != 1.0: - audio = speed_change(audio, speed=speed_factor, sr=int(sr)) - except Exception as e: - print(f"Failed to change speed of audio: \n{e}") - c2=timer() - print(f'🆗TTS COMPLETE,{round(c2-c1,4)}s') - return sr, audio - - - - -def speed_change(input_audio:np.ndarray, speed:float, sr:int): - # 将 NumPy 数组转换为原始 PCM 流 - raw_audio = input_audio.astype(np.int16).tobytes() - - # 设置 ffmpeg 输入流 - input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1) - - # 变速处理 - output_stream = input_stream.filter('atempo', speed) - - # 输出流到管道 - out, _ = ( - output_stream.output('pipe:', format='s16le', acodec='pcm_s16le') - .run(input=raw_audio, capture_stdout=True, capture_stderr=True) - ) - - # 将管道输出解码为 NumPy 数组 - processed_audio = np.frombuffer(out, np.int16) - - return processed_audio \ No newline at end of file diff --git a/TTS_infer_pack/TextPreprocessor.py b/TTS_infer_pack/TextPreprocessor.py deleted file mode 100644 index 9586cb0619f779104ec5362589420638bd99c215..0000000000000000000000000000000000000000 --- a/TTS_infer_pack/TextPreprocessor.py +++ /dev/null @@ -1,209 +0,0 @@ - -import os, sys - -from tqdm import tqdm -now_dir = os.getcwd() -sys.path.append(now_dir) - -import re -import torch -import LangSegment -from typing import Dict, List, Tuple -from text.cleaner import clean_text -from text import cleaned_text_to_sequence -from transformers import AutoModelForMaskedLM, AutoTokenizer -from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method - -#from tools.i18n.i18n import I18nAuto -#i18n = I18nAuto() - -def get_first(text:str) -> str: - pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" - text = re.split(pattern, text)[0].strip() - return text - -def merge_short_text_in_array(texts:str, threshold:int) -> list: - if (len(texts)) < 2: - return texts - result = [] - text = "" - for ele in texts: - text += ele - if len(text) >= threshold: - result.append(text) - text = "" - if (len(text) > 0): - if len(result) == 0: - result.append(text) - else: - result[len(result) - 1] += text - return result - - - - - - -class TextPreprocessor: - def __init__(self, bert_model:AutoModelForMaskedLM, - tokenizer:AutoTokenizer, device:torch.device): - self.bert_model = bert_model - self.tokenizer = tokenizer - self.device = device - - def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]: - print("############ 切分文本 ############") - texts = self.pre_seg_text(text, lang, text_split_method) - result = [] - print("############ 提取文本Bert特征 ############") - for text in tqdm(texts): - phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang) - if phones is None: - continue - res={ - "phones": phones, - "bert_features": bert_features, - "norm_text": norm_text, - } - result.append(res) - return result - - def pre_seg_text(self, text:str, lang:str, text_split_method:str): - text = text.strip("\n") - if (text[0] not in splits and len(get_first(text)) < 4): - text = "。" + text if lang != "en" else "." + text - print("实际输入的目标文本:") - print(text) - - seg_method = get_seg_method(text_split_method) - text = seg_method(text) - - while "\n\n" in text: - text = text.replace("\n\n", "\n") - - _texts = text.split("\n") - _texts = merge_short_text_in_array(_texts, 5) - texts = [] - - - for text in _texts: - # 解决输入目标文本的空行导致报错的问题 - if (len(text.strip()) == 0): - continue - if (text[-1] not in splits): text += "。" if lang != "en" else "." - - # 解决句子过长导致Bert报错的问题 - if (len(text) > 510): - texts.extend(split_big_text(text)) - else: - texts.append(text) - - print("实际输入的目标文本(切句后):") - print(texts) - return texts - - def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]: - textlist, langlist = self.seg_text(texts, language) - if len(textlist) == 0: - return None, None, None - - phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist) - return phones, bert_features, norm_text - - - def seg_text(self, text:str, language:str)->Tuple[list, list]: - - textlist=[] - langlist=[] - if language in ["auto", "zh", "ja"]: - LangSegment.setfilters(["zh","ja","en","ko"]) - for tmp in LangSegment.getTexts(text): - if tmp["text"] == "": - continue - if tmp["lang"] == "ko": - langlist.append("zh") - elif tmp["lang"] == "en": - langlist.append("en") - else: - # 因无法区别中日文汉字,以用户输入为准 - langlist.append(language if language!="auto" else tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "en": - LangSegment.setfilters(["en"]) - formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) - while " " in formattext: - formattext = formattext.replace(" ", " ") - if formattext != "": - textlist.append(formattext) - langlist.append("en") - - elif language in ["all_zh","all_ja"]: - - formattext = text - while " " in formattext: - formattext = formattext.replace(" ", " ") - language = language.replace("all_","") - if text == "": - return [],[] - textlist.append(formattext) - langlist.append(language) - - else: - raise ValueError(f"language {language} not supported") - - return textlist, langlist - - - def extract_bert_feature(self, textlist:list, langlist:list): - phones_list = [] - bert_feature_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang) - _bert_feature = self.get_bert_inf(phones, word2ph, norm_text, lang) - # phones_list.append(phones) - phones_list.extend(phones) - norm_text_list.append(norm_text) - bert_feature_list.append(_bert_feature) - bert_feature = torch.cat(bert_feature_list, dim=1) - # phones = sum(phones_list, []) - norm_text = ''.join(norm_text_list) - return phones_list, bert_feature, norm_text - - - def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor: - with torch.no_grad(): - inputs = self.tokenizer(text, return_tensors="pt") - for i in inputs: - inputs[i] = inputs[i].to(self.device) - res = self.bert_model(**inputs, output_hidden_states=True) - res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] - assert len(word2ph) == len(text) - phone_level_feature = [] - for i in range(len(word2ph)): - repeat_feature = res[i].repeat(word2ph[i], 1) - phone_level_feature.append(repeat_feature) - phone_level_feature = torch.cat(phone_level_feature, dim=0) - return phone_level_feature.T - - def clean_text_inf(self, text:str, language:str): - phones, word2ph, norm_text = clean_text(text, language) - phones = cleaned_text_to_sequence(phones) - return phones, word2ph, norm_text - - def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str): - language=language.replace("all_","") - if language == "zh": - feature = self.get_bert_feature(norm_text, word2ph).to(self.device) - else: - feature = torch.zeros( - (1024, len(phones)), - dtype=torch.float32, - ).to(self.device) - - return feature - - - - diff --git a/TTS_infer_pack/__init__.py b/TTS_infer_pack/__init__.py deleted file mode 100644 index 74381982fb662196dcb258e56a965b7ebc846f5e..0000000000000000000000000000000000000000 --- a/TTS_infer_pack/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import TTS, text_segmentation_method \ No newline at end of file diff --git a/TTS_infer_pack/__pycache__/TTS.cpython-310.pyc b/TTS_infer_pack/__pycache__/TTS.cpython-310.pyc deleted file mode 100644 index 5a1e9960ef55f9d48945cdbaa0b404fe6d227fd2..0000000000000000000000000000000000000000 Binary files a/TTS_infer_pack/__pycache__/TTS.cpython-310.pyc and /dev/null differ diff --git a/TTS_infer_pack/__pycache__/TextPreprocessor.cpython-310.pyc b/TTS_infer_pack/__pycache__/TextPreprocessor.cpython-310.pyc deleted file mode 100644 index c136d104e2546fb2910c451ae2e92d523a99ebb8..0000000000000000000000000000000000000000 Binary files a/TTS_infer_pack/__pycache__/TextPreprocessor.cpython-310.pyc and /dev/null differ diff --git a/TTS_infer_pack/__pycache__/__init__.cpython-310.pyc b/TTS_infer_pack/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 3f382470f9391cee96f5ff1e6e44cebdf4579b52..0000000000000000000000000000000000000000 Binary files a/TTS_infer_pack/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/TTS_infer_pack/__pycache__/text_segmentation_method.cpython-310.pyc b/TTS_infer_pack/__pycache__/text_segmentation_method.cpython-310.pyc deleted file mode 100644 index b94b81ed3370034de536d7f0ea4a9f319fb929dd..0000000000000000000000000000000000000000 Binary files a/TTS_infer_pack/__pycache__/text_segmentation_method.cpython-310.pyc and /dev/null differ diff --git a/TTS_infer_pack/text_segmentation_method.py b/TTS_infer_pack/text_segmentation_method.py deleted file mode 100644 index fc69b051692e1e6e7bcbda6ff9a65d3cc68ad7bd..0000000000000000000000000000000000000000 --- a/TTS_infer_pack/text_segmentation_method.py +++ /dev/null @@ -1,152 +0,0 @@ - - - - -import re -from typing import Callable -#from tools.i18n.i18n import I18nAuto - -#i18n = I18nAuto() - -METHODS = dict() - -def get_method(name:str)->Callable: - method = METHODS.get(name, None) - if method is None: - raise ValueError(f"Method {name} not found") - return method - -def register_method(name): - def decorator(func): - METHODS[name] = func - return func - return decorator - -splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } - -def split_big_text(text, max_len=510): - # 定义全角和半角标点符号 - punctuation = "".join(splits) - - # 切割文本 - segments = re.split('([' + punctuation + '])', text) - - # 初始化结果列表和当前片段 - result = [] - current_segment = '' - - for segment in segments: - # 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段 - if len(current_segment + segment) > max_len: - result.append(current_segment) - current_segment = segment - else: - current_segment += segment - - # 将最后一个片段加入结果列表 - if current_segment: - result.append(current_segment) - - return result - - - -def split(todo_text): - todo_text = todo_text.replace("……", "。").replace("——", ",") - if todo_text[-1] not in splits: - todo_text += "。" - i_split_head = i_split_tail = 0 - len_text = len(todo_text) - todo_texts = [] - while 1: - if i_split_head >= len_text: - break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 - if todo_text[i_split_head] in splits: - i_split_head += 1 - todo_texts.append(todo_text[i_split_tail:i_split_head]) - i_split_tail = i_split_head - else: - i_split_head += 1 - return todo_texts - - -# 不切 -@register_method("cut0") -def cut0(inp): - return inp - - -# 凑四句一切 -@register_method("cut1") -def cut1(inp): - inp = inp.strip("\n") - inps = split(inp) - split_idx = list(range(0, len(inps), 4)) - split_idx[-1] = None - if len(split_idx) > 1: - opts = [] - for idx in range(len(split_idx) - 1): - opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]])) - else: - opts = [inp] - return "\n".join(opts) - -# 凑50字一切 -@register_method("cut2") -def cut2(inp): - inp = inp.strip("\n") - inps = split(inp) - if len(inps) < 2: - return inp - opts = [] - summ = 0 - tmp_str = "" - for i in range(len(inps)): - summ += len(inps[i]) - tmp_str += inps[i] - if summ > 50: - summ = 0 - opts.append(tmp_str) - tmp_str = "" - if tmp_str != "": - opts.append(tmp_str) - # print(opts) - if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 - opts[-2] = opts[-2] + opts[-1] - opts = opts[:-1] - return "\n".join(opts) - -# 按中文句号。切 -@register_method("cut3") -def cut3(inp): - inp = inp.strip("\n") - return "\n".join(["%s" % item for item in inp.strip("。").split("。")]) - -#按英文句号.切 -@register_method("cut4") -def cut4(inp): - inp = inp.strip("\n") - return "\n".join(["%s" % item for item in inp.strip(".").split(".")]) - -# 按标点符号切 -# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py -@register_method("cut5") -def cut5(inp): - # if not re.search(r'[^\w\s]', inp[-1]): - # inp += '。' - inp = inp.strip("\n") - punds = r'[,.;?!、,。?!;:…]' - items = re.split(f'({punds})', inp) - mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])] - # 在句子不存在符号或句尾无符号的时候保证文本完整 - if len(items)%2 == 1: - mergeitems.append(items[-1]) - opt = "\n".join(mergeitems) - return opt - - - -if __name__ == '__main__': - method = get_method("cut5") - print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。")) - \ No newline at end of file diff --git a/__pycache__/download.cpython-310.pyc b/__pycache__/download.cpython-310.pyc index 80ae597b4af8427385111e4a58ccbb111a0a94b9..897bd6479ca717c5f4d0724e262642e12e62029b 100644 Binary files a/__pycache__/download.cpython-310.pyc and b/__pycache__/download.cpython-310.pyc differ diff --git a/__pycache__/info.cpython-310.pyc b/__pycache__/info.cpython-310.pyc index 4eeb522064db56c0e3d592393b205bb2886dcb34..17380d5968419014f518cfbab05606e073ab9c81 100644 Binary files a/__pycache__/info.cpython-310.pyc and b/__pycache__/info.cpython-310.pyc differ diff --git a/__pycache__/my_utils.cpython-310.pyc b/__pycache__/my_utils.cpython-310.pyc index a2a95ee0e108e6add70ef60c0a2d39e0dff7870a..886bebf9ed497a75797464e0e58643a6a27252ea 100644 Binary files a/__pycache__/my_utils.cpython-310.pyc and b/__pycache__/my_utils.cpython-310.pyc differ diff --git a/__pycache__/utils.cpython-310.pyc b/__pycache__/utils.cpython-310.pyc index cb6f7df07b9f5670a053420fd31db131ee9e36cb..581427490d31b5e9dc7c378969d04f8159881504 100644 Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ diff --git a/app.py b/app.py index bd93a8eee5550f3fb9b798451537c48f5c3985e0..341bd8835426377a0d0cb9d7977245b45b42ddcd 100644 --- a/app.py +++ b/app.py @@ -29,8 +29,6 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) logging.getLogger("multipart").setLevel(logging.WARNING) from download import * download() -from TTS_infer_pack.TTS import TTS, TTS_Config -from TTS_infer_pack.text_segmentation_method import get_method if "_CUDA_VISIBLE_DEVICES" in os.environ: os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] @@ -66,90 +64,533 @@ is_half = eval( os.environ.get("is_half", "True" if torch.cuda.is_available() else "False") ) +tokenizer = AutoTokenizer.from_pretrained(bert_path) +bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) +if is_half == True: + bert_model = bert_model.half().to(device) +else: + bert_model = bert_model.to(device) + + +def get_bert_feature(text, word2ph): + with torch.no_grad(): + inputs = tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(device) + res = bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + assert len(word2ph) == len(text) + phone_level_feature = [] + for i in range(len(word2ph)): + repeat_feature = res[i].repeat(word2ph[i], 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + return phone_level_feature.T + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +ssl_model = cnhubert.get_model() +if is_half == True: + ssl_model = ssl_model.half().to(device) +else: + ssl_model = ssl_model.to(device) + + +def change_sovits_weights(sovits_path): + global vq_model, hps + dict_s2 = torch.load(sovits_path, map_location="cpu") + hps = dict_s2["config"] + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model + ) + if ("pretrained" not in sovits_path): + del vq_model.enc_q + if is_half == True: + vq_model = vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.eval() + print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) + with open("./sweight.txt", "w", encoding="utf-8") as f: + f.write(sovits_path) + + +change_sovits_weights(sovits_path) + + +def change_gpt_weights(gpt_path): + global hz, max_sec, t2s_model, config + hz = 50 + dict_s1 = torch.load(gpt_path, map_location="cpu") + config = dict_s1["config"] + max_sec = config["data"]["max_sec"] + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + if is_half == True: + t2s_model = t2s_model.half() + t2s_model = t2s_model.to(device) + t2s_model.eval() + total = sum([param.nelement() for param in t2s_model.parameters()]) + print("Number of parameter: %.2fM" % (total / 1e6)) + with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path) + + +change_gpt_weights(gpt_path) + + +def get_spepc(hps, filename): + audio = load_audio(filename, int(hps.data.sampling_rate)) + audio = torch.FloatTensor(audio) + audio_norm = audio + audio_norm = audio_norm.unsqueeze(0) + spec = spectrogram_torch( + audio_norm, + hps.data.filter_length, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + center=False, + ) + return spec + dict_language = { - "中文1": "all_zh", - "English": "en", - "日文1": "all_ja", - "中文": "zh", - "日本語": "ja", - "混合": "auto", + ("中文1"): "all_zh",#全部按中文识别 + ("English"): "en",#全部按英文识别#######不变 + ("日文1"): "all_ja",#全部按日文识别 + ("中文"): "zh",#按中英混合识别####不变 + ("日本語"): "ja",#按日英混合识别####不变 + ("混合"): "auto",#多语种启动切分识别语种 } -cut_method = { - "Do not split/不切":"cut0", - "Split into groups of 4 sentences/四句一切": "cut1", - "Split every 50 characters/50字一切": "cut2", - "Split at CN/JP periods (。)/按中日文句号切": "cut3", - "Split at English periods (.)/按英文句号切": "cut4", - "Split at punctuation marks/按标点切": "cut5", -} +def splite_en_inf(sentence, language): + pattern = re.compile(r'[a-zA-Z ]+') + textlist = [] + langlist = [] + pos = 0 + for match in pattern.finditer(sentence): + start, end = match.span() + if start > pos: + textlist.append(sentence[pos:start]) + langlist.append(language) + textlist.append(sentence[start:end]) + langlist.append("en") + pos = end + if pos < len(sentence): + textlist.append(sentence[pos:]) + langlist.append(language) + # Merge punctuation into previous word + for i in range(len(textlist)-1, 0, -1): + if re.match(r'^[\W_]+$', textlist[i]): + textlist[i-1] += textlist[i] + del textlist[i] + del langlist[i] + # Merge consecutive words with the same language tag + i = 0 + while i < len(langlist) - 1: + if langlist[i] == langlist[i+1]: + textlist[i] += textlist[i+1] + del textlist[i+1] + del langlist[i+1] + else: + i += 1 + + return textlist, langlist + + +def clean_text_inf(text, language): + formattext = "" + language = language.replace("all_","") + for tmp in LangSegment.getTexts(text): + if language == "ja": + if tmp["lang"] == language or tmp["lang"] == "zh": + formattext += tmp["text"] + " " + continue + if tmp["lang"] == language: + formattext += tmp["text"] + " " + while " " in formattext: + formattext = formattext.replace(" ", " ") + phones, word2ph, norm_text = clean_text(formattext, language) + phones = cleaned_text_to_sequence(phones) + return phones, word2ph, norm_text + +dtype=torch.float16 if is_half == True else torch.float32 +def get_bert_inf(phones, word2ph, norm_text, language): + language=language.replace("all_","") + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) -tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml") -tts_config.device = device -tts_config.is_half = is_half -if gpt_path is not None: - tts_config.t2s_weights_path = gpt_path -if sovits_path is not None: - tts_config.vits_weights_path = sovits_path -if cnhubert_base_path is not None: - tts_config.cnhuhbert_base_path = cnhubert_base_path -if bert_path is not None: - tts_config.bert_base_path = bert_path + return bert - -tts_pipline = TTS(tts_config) -gpt_path = tts_config.t2s_weights_path -sovits_path = tts_config.vits_weights_path - - -def inference(text, text_lang, - ref_audio_path, prompt_text, - prompt_lang, top_k, - top_p, temperature, - text_split_method, batch_size, - speed_factor, ref_text_free, - split_bucket, - volume - ): - - if not duration(ref_audio_path): + +def nonen_clean_text_inf(text, language): + if(language!="auto"): + textlist, langlist = splite_en_inf(text, language) + else: + textlist=[] + langlist=[] + for tmp in LangSegment.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + print(textlist) + print(langlist) + phones_list = [] + word2ph_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) + phones_list.append(phones) + if lang == "zh": + word2ph_list.append(word2ph) + norm_text_list.append(norm_text) + print(word2ph_list) + phones = sum(phones_list, []) + word2ph = sum(word2ph_list, []) + norm_text = ' '.join(norm_text_list) + + return phones, word2ph, norm_text + + +def nonen_get_bert_inf(text, language): + if(language!="auto"): + textlist, langlist = splite_en_inf(text, language) + else: + textlist=[] + langlist=[] + for tmp in LangSegment.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + print(textlist) + print(langlist) + bert_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + + return bert + + +splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } + + +def get_first(text): + pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" + text = re.split(pattern, text)[0].strip() + return text + + +def get_cleaned_text_final(text,language): + if language in {"en","all_zh","all_ja"}: + phones, word2ph, norm_text = clean_text_inf(text, language) + elif language in {"zh", "ja","auto"}: + phones, word2ph, norm_text = nonen_clean_text_inf(text, language) + return phones, word2ph, norm_text + +def get_bert_final(phones, word2ph, text,language,device): + if language == "en": + bert = get_bert_inf(phones, word2ph, text, language) + elif language in {"zh", "ja","auto"}: + bert = nonen_get_bert_inf(text, language) + elif language == "all_zh": + bert = get_bert_feature(text, word2ph).to(device) + else: + bert = torch.zeros((1024, len(phones))).to(device) + return bert + +def merge_short_text_in_array(texts, threshold): + if (len(texts)) < 2: + return texts + result = [] + text = "" + for ele in texts: + text += ele + if len(text) >= threshold: + result.append(text) + text = "" + if (len(text) > 0): + if len(result) == 0: + result.append(text) + else: + result[len(result) - 1] += text + return result + + +def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=("Do not split"), volume_scale=1.0): + if not duration(ref_wav_path): return None if text == '': - wprint("Please input text to generate/请输入生成文字") + wprint("Please enter text to generate/请输入生成文字") return None + t0 = ttime() + startTime=timer() text=trim_text(text,text_language) - tts_pipline.init_vits_weights(sovits_path) - tts_pipline.init_t2s_weights(gpt_path) - + change_sovits_weights(sovits_path) + tprint(f'🏕️LOADED SoVITS Model: {sovits_path}') + change_gpt_weights(gpt_path) + tprint(f'🏕️LOADED GPT Model: {gpt_path}') + + prompt_language = dict_language[prompt_language] try: - lang=dict_language[text_lang] - inputs={ - "text": text, - "text_lang": lang, - "ref_audio_path": ref_audio_path, - "prompt_text": prompt_text if not ref_text_free else "", - "prompt_lang": dict_language[prompt_lang], - "top_k": top_k, - "top_p": top_p, - "temperature": temperature, - "text_split_method": cut_method[text_split_method], - "batch_size":int(batch_size), - "speed_factor":float(speed_factor), - "split_bucket":split_bucket, - "volume":volume, - "return_fragment":False, - } - - yield next(tts_pipline.run(inputs)) + text_language = dict_language[text_language] except KeyError as e: - wprint(f'Unsupported language type:{e}') + wprint(f"Unsupported language type: {e}") return None + + prompt_text = prompt_text.strip("\n") + if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "." + text = text.strip("\n") + if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text + #print(("实际输入的参考文本:"), prompt_text) + #print(("📝实际输入的目标文本:"), text) + zero_wav = np.zeros( + int(hps.data.sampling_rate * 0.3), + dtype=np.float16 if is_half == True else np.float32, + ) + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000): + errinfo='参考音频在3~10秒范围外,请更换!' + raise OSError((errinfo)) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))[ + "last_hidden_state" + ].transpose( + 1, 2 + ) # .float() + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + t1 = ttime() + + phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language) + + if (how_to_cut == ("Split into groups of 4 sentences")): + text = cut1(text) + elif (how_to_cut == ("Split every 50 characters")): + text = cut2(text) + elif (how_to_cut == ("Split at CN/JP periods (。)")): + text = cut3(text) + elif (how_to_cut == ("Split at English periods (.)")): + text = cut4(text) + elif (how_to_cut == ("Split at punctuation marks")): + text = cut5(text) + while "\n\n" in text: + text = text.replace("\n\n", "\n") + print(f"🧨实际输入的目标文本(切句后):{text}\n") + texts = text.split("\n") + texts = merge_short_text_in_array(texts, 5) + audio_opt = [] + bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype) + + for text in texts: + if (len(text.strip()) == 0): + continue + if (text[-1] not in splits): text += "。" if text_language != "en" else "." + print(("\n🎈实际输入的目标文本(每句):"), text) + phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language) + try: + bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype) + except RuntimeError as e: + wprint(f"The input text does not match the language/输入文本与语言不匹配: {e}") + return None + bert = torch.cat([bert1, bert2], 1) + + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + prompt = prompt_semantic.unsqueeze(0).to(device) + t2 = ttime() + with torch.no_grad(): + # pred_semantic = t2s_model.model.infer( + pred_semantic, idx = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_len, + prompt, + bert, + # prompt_phone_len=ph_offset, + top_k=config["inference"]["top_k"], + early_stop_num=hz * max_sec, + ) + t3 = ttime() + # print(pred_semantic.shape,idx) + pred_semantic = pred_semantic[:, -idx:].unsqueeze( + 0 + ) # .unsqueeze(0)#mq要多unsqueeze一次 + refer = get_spepc(hps, ref_wav_path) # .to(device) + if is_half == True: + refer = refer.half().to(device) + else: + refer = refer.to(device) + # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] + try: + audio = ( + vq_model.decode( + pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer + ) + .detach() + .cpu() + .numpy()[0, 0] + ) + except RuntimeError as e: + wprint(f"The input text does not match the language/输入文本与语言不匹配: {e}") + return None + + max_audio=np.abs(audio).max() + if max_audio>1:audio/=max_audio + audio_opt.append(audio) + audio_opt.append(zero_wav) + t4 = ttime() + print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) + #yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16) + audio_data = (np.concatenate(audio_opt, 0) * 32768).astype(np.int16) + + audio_data = (audio_data.astype(np.float32) * volume_scale).astype(np.int16) + output_wav = "output_audio.wav" + sf.write(output_wav, audio_data, hps.data.sampling_rate) + endTime=timer() + tprint(f'🆗TTS COMPLETE,{round(endTime-startTime,4)}s') + return output_wav + +def split(todo_text): + todo_text = todo_text.replace("……", "。").replace("——", ",") + if todo_text[-1] not in splits: + todo_text += "。" + i_split_head = i_split_tail = 0 + len_text = len(todo_text) + todo_texts = [] + while 1: + if i_split_head >= len_text: + break + if todo_text[i_split_head] in splits: + i_split_head += 1 + todo_texts.append(todo_text[i_split_tail:i_split_head]) + i_split_tail = i_split_head + else: + i_split_head += 1 + return todo_texts + + +def cut1(inp): + inp = inp.strip("\n") + inps = split(inp) + split_idx = list(range(0, len(inps), 4)) + split_idx[-1] = None + if len(split_idx) > 1: + opts = [] + for idx in range(len(split_idx) - 1): + opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]])) + else: + opts = [inp] + return "\n".join(opts) + + +def cut2(inp): + inp = inp.strip("\n") + inps = split(inp) + if len(inps) < 2: + return inp + opts = [] + summ = 0 + tmp_str = "" + for i in range(len(inps)): + summ += len(inps[i]) + tmp_str += inps[i] + if summ > 50: + summ = 0 + opts.append(tmp_str) + tmp_str = "" + if tmp_str != "": + opts.append(tmp_str) + # print(opts) + if len(opts) > 1 and len(opts[-1]) < 50: + opts[-2] = opts[-2] + opts[-1] + opts = opts[:-1] + return "\n".join(opts) + + +def cut3(inp): + inp = inp.strip("\n") + return "\n".join(["%s" % item for item in inp.strip("。").split("。")]) + + +def cut4(inp): + inp = inp.strip("\n") + return "\n".join(["%s" % item for item in inp.strip(".").split(".")]) + + +# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py +def cut5(inp): + # if not re.search(r'[^\w\s]', inp[-1]): + # inp += '。' + inp = inp.strip("\n") + punds = r'[,.;?!、,。?!;:…]' + items = re.split(f'({punds})', inp) + mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])] + if len(items)%2 == 1: + mergeitems.append(items[-1]) + opt = "\n".join(mergeitems) + return opt + + + +def custom_sort_key(s): + # 使用正则表达式提取字符串中的数字部分和非数字部分 + parts = re.split('(\d+)', s) + # 将数字部分转换为整数,非数字部分保持不变 + parts = [int(part) if part.isdigit() else part for part in parts] + return parts #==========custom functions============ -splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } def tprint(text): now=datetime.now(tz).strftime('%H:%M:%S') print(f'UTC+8 - {now} - {text}') @@ -197,7 +638,7 @@ def trim_text(text,language): return ' '.join(words[:i+1]) return ' '.join(words[:limit_en]) - else: + else:#中文日文 if len(text) <= limit_cj: return text for i in range(limit_cj, -1, -1): @@ -222,11 +663,10 @@ def duration(audio_file_path): return False def update_model(choice): - global gpt_path,sovits_path + global gpt_path, sovits_path model_info = models[choice] gpt_path = abs_path(model_info["gpt_weight"]) sovits_path = abs_path(model_info["sovits_weight"]) - model_name = choice tone_info = model_info["tones"]["tone1"] tone_sample_path = abs_path(tone_info["sample"]) @@ -268,7 +708,7 @@ def transcribe(voice): time2=timer() tprint(f'transcribe COMPLETE,{round(time2-time1,4)}s') - tprint(f' \nTranscribe result:\n 🔣Language:{language} \n 🔣Text:{text}' ) + tprint(f'\nTRANSCRIBE RESULT:\n 🔣Language:{language} \n 🔣Text:{text}' ) return text,language def clone_voice(user_voice,user_text,user_lang): @@ -278,36 +718,29 @@ def clone_voice(user_voice,user_text,user_lang): wprint("Please enter text to generate/请输入生成文字") return None user_text=trim_text(user_text,user_lang) - #global gpt_path, sovits_path + time1=timer() + global gpt_path, sovits_path gpt_path = abs_path("pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt") #tprint(f'Model loaded:{gpt_path}') sovits_path = abs_path("pretrained_models/s2G488k.pth") #tprint(f'Model loaded:{sovits_path}') try: - prompt_text, prompt_lang = transcribe(user_voice) + prompt_text, prompt_language = transcribe(user_voice) except UnboundLocalError as e: wprint(f"The language in the audio cannot be recognized :{str(e)}") return None - tts_pipline.init_vits_weights(sovits_path) - tts_pipline.init_t2s_weights(gpt_path) - inputs={ - "text": user_text, - "text_lang": dict_language[user_lang], - "ref_audio_path": user_voice, - "prompt_text": prompt_text, - "prompt_lang": dict_language[prompt_lang], - "top_k": 5, - "top_p": 1, - "temperature": 1, - "text_split_method": "cut1", - "batch_size":20, - "speed_factor":1.0, - "split_bucket":True, - "volume":1.0, - "return_fragment":False, - } - - yield next(tts_pipline.run(inputs)) + + output_wav = get_tts_wav( + user_voice, + prompt_text, + prompt_language, + user_text, + user_lang, + how_to_cut="Do not split", + volume_scale=1.0) + time2=timer() + tprint(f'🆗CLONE COMPLETE,{round(time2-time1,4)}s') + return output_wav with open('dummy') as f: dummy_txt = f.read().strip().splitlines() @@ -395,26 +828,15 @@ with gr.Blocks(theme='Kasien/ali_theme_custom') as app: with gr.Accordion(label="Additional generation options/附加生成选项", open=False): - with gr.Row(): - how_to_cut = gr.Dropdown( - label=("How to split input text?/如何对输入文字切片"), - choices=[("Do not split/不切"), ("Split into groups of 4 sentences/四句一切"), ("Split every 50 characters/50字一切"), - ("Split at CN/JP periods (。)/按中日文句号切"), ("Split at English periods (.)/按英文句号切"), ("Split at punctuation marks/按标点切"), ], - value=("Split into groups of 4 sentences/四句一切"), + how_to_cut = gr.Dropdown( + label=("How to split?"), + choices=[("Do not split"), ("Split into groups of 4 sentences"), ("Split every 50 characters"), + ("Split at CN/JP periods (。)"), ("Split at English periods (.)"), ("Split at punctuation marks"), ], + value=("Split into groups of 4 sentences"), interactive=True, - info='A suitable splitting method can achieve better generation results/适合的切片方法会得到更好的效果' + info='A suitable splitting method can achieve better generation results' ) - split_bucket = gr.Checkbox(label="Split bucket/数据分桶", value=True, info='Speed up the inference process/提升推理速度') - with gr.Row(): - volume = gr.Slider(minimum=0.5, maximum=5, value=1, step=0.1, label='Volume/音量',info='audio distortion due to excessive volume/大了要爆音') - speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="Speed factor",value=1.0,info='Playback speed/播放速度') - batch_size = gr.Slider(minimum=1,maximum=100,step=1,label="Batch size",value=20,info='The number of sentences for batch inference./并行推理的句子数量') - with gr.Row(): - top_k = gr.Slider(minimum=1,maximum=100,step=1,label="top_k",value=5) - top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label="top_p",value=1) - temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label="temperature",value=1) - ref_text_free = gr.Checkbox(label="REF_TEXT_FREE", value=False, visible=False) - + volume = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.01, label='Volume/音量') gr.HTML(''' @@ -441,8 +863,7 @@ with gr.Blocks(theme='Kasien/ali_theme_custom') as app: user_text= gr.Textbox(label="Text for generation/输入想要生成语音的文字", lines=5,placeholder=plsh,info=limit) dddice= gr.Button('🎲', variant='tool',min_width=0,scale=0) - dddice.click(dice, outputs=[user_text, dddice]) - + dddice.click(dice, outputs=[user_text, dddice]) user_text.change( lang_detector, user_text, user_lang) user_button = gr.Button("✨Clone Voice", variant="primary") @@ -456,23 +877,9 @@ with gr.Blocks(theme='Kasien/ali_theme_custom') as app: tone_select.change(update_tone, inputs=[model_name, tone_select], outputs=[inp_ref, prompt_text, tone_sample]) main_button.click( - inference, - inputs=[text, - text_language, - inp_ref, - prompt_text, - prompt_language, - top_k, - top_p, - temperature, - how_to_cut, - batch_size, - speed_factor, - ref_text_free, - split_bucket, - volume], - outputs=[output] - ) + get_tts_wav, + inputs=[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,volume], + outputs=[output]) user_button.click( clone_voice, diff --git a/GPT_SoVITS/configs/s1.yaml b/configs/s1.yaml similarity index 100% rename from GPT_SoVITS/configs/s1.yaml rename to configs/s1.yaml diff --git a/GPT_SoVITS/configs/s1big.yaml b/configs/s1big.yaml similarity index 100% rename from GPT_SoVITS/configs/s1big.yaml rename to configs/s1big.yaml diff --git a/GPT_SoVITS/configs/s1big2.yaml b/configs/s1big2.yaml similarity index 100% rename from GPT_SoVITS/configs/s1big2.yaml rename to configs/s1big2.yaml diff --git a/GPT_SoVITS/configs/s1longer.yaml b/configs/s1longer.yaml similarity index 100% rename from GPT_SoVITS/configs/s1longer.yaml rename to configs/s1longer.yaml diff --git a/GPT_SoVITS/configs/s1mq.yaml b/configs/s1mq.yaml similarity index 100% rename from GPT_SoVITS/configs/s1mq.yaml rename to configs/s1mq.yaml diff --git a/GPT_SoVITS/configs/s2.json b/configs/s2.json similarity index 100% rename from GPT_SoVITS/configs/s2.json rename to configs/s2.json diff --git a/GPT_SoVITS/configs/train.yaml b/configs/train.yaml similarity index 100% rename from GPT_SoVITS/configs/train.yaml rename to configs/train.yaml diff --git a/feature_extractor/__pycache__/__init__.cpython-310.pyc b/feature_extractor/__pycache__/__init__.cpython-310.pyc index cf440a24d2590c466a33eadf7b4a89ba2834ca6f..f8fb609b79b3fca6909c0794427e19637f51234a 100644 Binary files a/feature_extractor/__pycache__/__init__.cpython-310.pyc and b/feature_extractor/__pycache__/__init__.cpython-310.pyc differ diff --git a/feature_extractor/__pycache__/cnhubert.cpython-310.pyc b/feature_extractor/__pycache__/cnhubert.cpython-310.pyc index 0736dfd47b86259c65debf60d058fd0d9d5d18ba..3f42f3baeeeee5c23c2e2142897e2e42b326b375 100644 Binary files a/feature_extractor/__pycache__/cnhubert.cpython-310.pyc and b/feature_extractor/__pycache__/cnhubert.cpython-310.pyc differ diff --git a/feature_extractor/__pycache__/whisper_enc.cpython-310.pyc b/feature_extractor/__pycache__/whisper_enc.cpython-310.pyc index b2758d6e2eb0810f156469146ef654102ebbb941..9d496437c5cb4437a515aca8bc477464287ab73f 100644 Binary files a/feature_extractor/__pycache__/whisper_enc.cpython-310.pyc and b/feature_extractor/__pycache__/whisper_enc.cpython-310.pyc differ diff --git a/feature_extractor/cnhubert.py b/feature_extractor/cnhubert.py index 7dffbdb22bddd6092a9e4c76c998e92dd3f90166..2e150e5c4a481ddd3137856fd92b3f432e0db75c 100644 --- a/feature_extractor/cnhubert.py +++ b/feature_extractor/cnhubert.py @@ -4,9 +4,9 @@ import librosa import torch import torch.nn.functional as F import soundfile as sf -import logging +#import logging -logging.getLogger("numba").setLevel(logging.WARNING) +#logging.getLogger("numba").setLevel(logging.WARNING) from transformers import ( Wav2Vec2FeatureExtractor, @@ -20,16 +20,13 @@ cnhubert_base_path = None class CNHubert(nn.Module): - def __init__(self, base_path:str=None): + def __init__(self): super().__init__() - if base_path is None: - base_path = cnhubert_base_path - self.model = HubertModel.from_pretrained(base_path) + self.model = HubertModel.from_pretrained(cnhubert_base_path) self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - base_path + cnhubert_base_path ) - def forward(self, x): input_values = self.feature_extractor( x, return_tensors="pt", sampling_rate=16000 diff --git a/gweight.txt b/gweight.txt new file mode 100644 index 0000000000000000000000000000000000000000..1339f1afd0afdf04d63f7dfcd28a413daa591da1 --- /dev/null +++ b/gweight.txt @@ -0,0 +1 @@ +/content/Multi-voice-TTS-GPT-SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt \ No newline at end of file diff --git a/module/__pycache__/__init__.cpython-310.pyc b/module/__pycache__/__init__.cpython-310.pyc index 7e5e75a923388a0753c1dc59135fb96f05c620b2..7ef77fd49527bcf047989fddfc099776013aa8c4 100644 Binary files a/module/__pycache__/__init__.cpython-310.pyc and b/module/__pycache__/__init__.cpython-310.pyc differ diff --git a/module/__pycache__/attentions.cpython-310.pyc b/module/__pycache__/attentions.cpython-310.pyc index ac46216c26f9aed000cfcaeef1afdde5e11b2c0b..fb3d4ac243b70c738040f4f700bec8c9105962fa 100644 Binary files a/module/__pycache__/attentions.cpython-310.pyc and b/module/__pycache__/attentions.cpython-310.pyc differ diff --git a/module/__pycache__/commons.cpython-310.pyc b/module/__pycache__/commons.cpython-310.pyc index 53fd3a938b3b677fd0682c01f0e92a0a67ae1e2f..69c515ae4729d3cf0e92c7cef4f959c1b47fc983 100644 Binary files a/module/__pycache__/commons.cpython-310.pyc and b/module/__pycache__/commons.cpython-310.pyc differ diff --git a/module/__pycache__/core_vq.cpython-310.pyc b/module/__pycache__/core_vq.cpython-310.pyc index afcb9b2b53c8bddd11440e87647c531dfd83124d..562eaf130dd5e9c5ae8c17c54e4cac36e25c7419 100644 Binary files a/module/__pycache__/core_vq.cpython-310.pyc and b/module/__pycache__/core_vq.cpython-310.pyc differ diff --git a/module/__pycache__/mel_processing.cpython-310.pyc b/module/__pycache__/mel_processing.cpython-310.pyc index a997560c0a25264286d85261f9958caaa7599bec..177c71c8efa27019bf1a28d1633978fed492ab3b 100644 Binary files a/module/__pycache__/mel_processing.cpython-310.pyc and b/module/__pycache__/mel_processing.cpython-310.pyc differ diff --git a/module/__pycache__/models.cpython-310.pyc b/module/__pycache__/models.cpython-310.pyc index 16790944e754c215ba6a6f65411b0205bfb5a341..2657b326c2b6cebf8620d1dafa012e952925ba88 100644 Binary files a/module/__pycache__/models.cpython-310.pyc and b/module/__pycache__/models.cpython-310.pyc differ diff --git a/module/__pycache__/modules.cpython-310.pyc b/module/__pycache__/modules.cpython-310.pyc index 1dca7c9e2660c4994ad6832c5e6704de2e833af3..fa5886e56500c4a727dfb1fe219c8e0c8c31e3de 100644 Binary files a/module/__pycache__/modules.cpython-310.pyc and b/module/__pycache__/modules.cpython-310.pyc differ diff --git a/module/__pycache__/mrte_model.cpython-310.pyc b/module/__pycache__/mrte_model.cpython-310.pyc index 30c796fa945d66291486d4c7ac64ee96144a7ceb..8b13a4d56180b457ff66114d11fe8a4e8e635570 100644 Binary files a/module/__pycache__/mrte_model.cpython-310.pyc and b/module/__pycache__/mrte_model.cpython-310.pyc differ diff --git a/module/__pycache__/quantize.cpython-310.pyc b/module/__pycache__/quantize.cpython-310.pyc index d8388af0851351a4d1a7ec95f6734af985f422a9..679444bebbb029657ad1e2b268d241cbb70e0000 100644 Binary files a/module/__pycache__/quantize.cpython-310.pyc and b/module/__pycache__/quantize.cpython-310.pyc differ diff --git a/module/__pycache__/transforms.cpython-310.pyc b/module/__pycache__/transforms.cpython-310.pyc index e61ad4119d04d457122f02a1f249ca9c3bb5e968..99d493d8af1eeb65a82e0b483ae0002209fd6bf0 100644 Binary files a/module/__pycache__/transforms.cpython-310.pyc and b/module/__pycache__/transforms.cpython-310.pyc differ diff --git a/module/models.py b/module/models.py index 75bc61778bc54c8c70a09129c017ce75b558305a..c99485cf352294a90b47aa85ba9989e3eb4da728 100644 --- a/module/models.py +++ b/module/models.py @@ -1,6 +1,5 @@ import copy import math -from typing import List import torch from torch import nn from torch.nn import functional as F @@ -229,7 +228,6 @@ class TextEncoder(nn.Module): ) y = self.ssl_proj(y * y_mask) * y_mask - y = self.encoder_ssl(y * y_mask, y_mask) text_mask = torch.unsqueeze( @@ -960,13 +958,11 @@ class SynthesizerTrn(nn.Module): @torch.no_grad() def decode(self, codes, text, refer, noise_scale=0.5): - ge = None - if refer is not None: - refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) - refer_mask = torch.unsqueeze( - commons.sequence_mask(refer_lengths, refer.size(2)), 1 - ).to(refer.dtype) - ge = self.ref_enc(refer * refer_mask, refer_mask) + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze( + commons.sequence_mask(refer_lengths, refer.size(2)), 1 + ).to(refer.dtype) + ge = self.ref_enc(refer * refer_mask, refer_mask) y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) @@ -986,55 +982,6 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o - - - @torch.no_grad() - def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5): - ge = None - if refer is not None: - refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) - refer_mask = torch.unsqueeze( - commons.sequence_mask(refer_lengths, refer.size(2)), 1 - ).to(refer.dtype) - ge = self.ref_enc(refer * refer_mask, refer_mask) - - # y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to( - # codes.dtype - # ) - y_lengths = (y_lengths * 2).long().to(codes.device) - text_lengths = text_lengths.long().to(text.device) - # y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) - # text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) - - # 假设padding之后再decode没有问题, 影响未知,但听起来好像没问题? - quantized = self.quantizer.decode(codes) - if self.semantic_frame_rate == "25hz": - quantized = F.interpolate( - quantized, size=int(quantized.shape[-1] * 2), mode="nearest" - ) - - x, m_p, logs_p, y_mask = self.enc_p( - quantized, y_lengths, text, text_lengths, ge - ) - z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale - - z = self.flow(z_p, y_mask, g=ge, reverse=True) - z_masked = (z * y_mask)[:, :, :] - - # 串行。把padding部分去掉再decode - o_list:List[torch.Tensor] = [] - for i in range(z_masked.shape[0]): - z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0) - o = self.dec(z_slice, g=ge)[0, 0, :].detach() - o_list.append(o) - - # 并行(会有问题)。先decode,再把padding的部分去掉 - # o = self.dec(z_masked, g=ge) - # upsample_rate = int(math.prod(self.upsample_rates)) - # o_lengths = y_lengths*upsample_rate - # o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)] - - return o_list def extract_latent(self, x): ssl = self.ssl_proj(x) diff --git a/output_audio.wav b/output_audio.wav new file mode 100644 index 0000000000000000000000000000000000000000..6476366edcbb37122c5d3c0052ff2515e0102886 --- /dev/null +++ b/output_audio.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b97115a0643627c7837d43c7a06496cff5cb85f72f2b540d282ded656808700 +size 311084 diff --git a/sweight.txt b/sweight.txt new file mode 100644 index 0000000000000000000000000000000000000000..d3d4392707e186ee4e4321da17ebe3e8e6449293 --- /dev/null +++ b/sweight.txt @@ -0,0 +1 @@ +/content/Multi-voice-TTS-GPT-SoVITS/pretrained_models/s2G488k.pth \ No newline at end of file diff --git a/text/__pycache__/__init__.cpython-310.pyc b/text/__pycache__/__init__.cpython-310.pyc index 4987782533dc2933081e2298978d2403ca7e1435..352df345f285b8784f27fa27d654c481dc9b615e 100644 Binary files a/text/__pycache__/__init__.cpython-310.pyc and b/text/__pycache__/__init__.cpython-310.pyc differ diff --git a/text/__pycache__/chinese.cpython-310.pyc b/text/__pycache__/chinese.cpython-310.pyc index 06c21c86e6880df19aa1b9c3ff2ffe57f37d2069..01a8f7ce975e7cb937960aa6d9b745c8ab9dd528 100644 Binary files a/text/__pycache__/chinese.cpython-310.pyc and b/text/__pycache__/chinese.cpython-310.pyc differ diff --git a/text/__pycache__/cleaner.cpython-310.pyc b/text/__pycache__/cleaner.cpython-310.pyc index d00fc61a10e1ea5aae34210a2748a10a0462f0a4..ed4b30008dd1d8275da56ac31d6519bffca2b18c 100644 Binary files a/text/__pycache__/cleaner.cpython-310.pyc and b/text/__pycache__/cleaner.cpython-310.pyc differ diff --git a/text/__pycache__/english.cpython-310.pyc b/text/__pycache__/english.cpython-310.pyc index cf0665012d1d72ab29977e17637a6eeb00a46922..fe29de50bac6ad56a300c06618491b6bad0bc3d9 100644 Binary files a/text/__pycache__/english.cpython-310.pyc and b/text/__pycache__/english.cpython-310.pyc differ diff --git a/text/__pycache__/japanese.cpython-310.pyc b/text/__pycache__/japanese.cpython-310.pyc index 88e29931718b529a849a7b891511e1955961e2ec..3536efa13c4c736a850617422e41f3bd45f1928a 100644 Binary files a/text/__pycache__/japanese.cpython-310.pyc and b/text/__pycache__/japanese.cpython-310.pyc differ diff --git a/text/__pycache__/symbols.cpython-310.pyc b/text/__pycache__/symbols.cpython-310.pyc index f46b1bf78262e9b3306fbb4c3f4b5f981f5d0218..136a61fff33a386f6bd0e657e3243007742e3e15 100644 Binary files a/text/__pycache__/symbols.cpython-310.pyc and b/text/__pycache__/symbols.cpython-310.pyc differ diff --git a/text/__pycache__/tone_sandhi.cpython-310.pyc b/text/__pycache__/tone_sandhi.cpython-310.pyc index e736b3c4e71caac14e1670152d1f81527926ca7f..64455f18682b6891880e94e1f35cf9ad94d00c6d 100644 Binary files a/text/__pycache__/tone_sandhi.cpython-310.pyc and b/text/__pycache__/tone_sandhi.cpython-310.pyc differ diff --git a/text/zh_normalization/__pycache__/__init__.cpython-310.pyc b/text/zh_normalization/__pycache__/__init__.cpython-310.pyc index 06d9aea88c760d2acc34bcc2e99d1eb567c4e432..724bfda7744dde8694bb6c69eaf0dca7ebb0eff1 100644 Binary files a/text/zh_normalization/__pycache__/__init__.cpython-310.pyc and b/text/zh_normalization/__pycache__/__init__.cpython-310.pyc differ diff --git a/text/zh_normalization/__pycache__/char_convert.cpython-310.pyc b/text/zh_normalization/__pycache__/char_convert.cpython-310.pyc index 8a5d6eb45cf11ac9377b583c8f9c1168a900315a..5d99e7da9c0e315a98b0a9f15c9d5844d5ccef00 100644 Binary files a/text/zh_normalization/__pycache__/char_convert.cpython-310.pyc and b/text/zh_normalization/__pycache__/char_convert.cpython-310.pyc differ diff --git a/text/zh_normalization/__pycache__/chronology.cpython-310.pyc b/text/zh_normalization/__pycache__/chronology.cpython-310.pyc index e246bd88e8f9440ea46826abff5dc47033b0b3d5..d281a6837b283ac39160187bab1e1f683d56092b 100644 Binary files a/text/zh_normalization/__pycache__/chronology.cpython-310.pyc and b/text/zh_normalization/__pycache__/chronology.cpython-310.pyc differ diff --git a/text/zh_normalization/__pycache__/constants.cpython-310.pyc b/text/zh_normalization/__pycache__/constants.cpython-310.pyc index ecd7e5b126da8469acf929a87c15d70c8255c0e2..e6633f14a2531f4f6f527ef3293a3926dab7e82e 100644 Binary files a/text/zh_normalization/__pycache__/constants.cpython-310.pyc and b/text/zh_normalization/__pycache__/constants.cpython-310.pyc differ diff --git a/text/zh_normalization/__pycache__/num.cpython-310.pyc b/text/zh_normalization/__pycache__/num.cpython-310.pyc index ea6e2e9ca69edcade0a8f0a5862c58afa375adcd..e64986cbffb39effe145712dcd766f2556360d81 100644 Binary files a/text/zh_normalization/__pycache__/num.cpython-310.pyc and b/text/zh_normalization/__pycache__/num.cpython-310.pyc differ diff --git a/text/zh_normalization/__pycache__/phonecode.cpython-310.pyc b/text/zh_normalization/__pycache__/phonecode.cpython-310.pyc index eaa92e116ce1e47f459498e3114af846620f5621..3c1aa496fbbe15f5048d35e7830bd693135968ab 100644 Binary files a/text/zh_normalization/__pycache__/phonecode.cpython-310.pyc and b/text/zh_normalization/__pycache__/phonecode.cpython-310.pyc differ diff --git a/text/zh_normalization/__pycache__/quantifier.cpython-310.pyc b/text/zh_normalization/__pycache__/quantifier.cpython-310.pyc index 8f9d379e87e513efac25fa6c8bd01e080543e45f..fefb83fb30142655b38c3667b1bd9874a5e0f352 100644 Binary files a/text/zh_normalization/__pycache__/quantifier.cpython-310.pyc and b/text/zh_normalization/__pycache__/quantifier.cpython-310.pyc differ diff --git a/text/zh_normalization/__pycache__/text_normlization.cpython-310.pyc b/text/zh_normalization/__pycache__/text_normlization.cpython-310.pyc index 327cda021bdfb51a717f57c679e15bcfe566e4ed..c042b1617e026b4c21c55016290f261ba96435ad 100644 Binary files a/text/zh_normalization/__pycache__/text_normlization.cpython-310.pyc and b/text/zh_normalization/__pycache__/text_normlization.cpython-310.pyc differ