# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear
from tqdm import tqdm

from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer
from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder

FS_ENCODERS = {
    'rel_fft': lambda hp, dict_size: RelTransformerEncoder(
        dict_size, hp['hidden_size'], hp['hidden_size'],
        hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
        hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln']),
}


def fill_with_neg_inf2(t):
    """FP16-compatible helper: fill a tensor with a large negative value (-1e8) as a stand-in for -inf."""
    return t.float().fill_(-1e8).type_as(t)


def expand_states(h, mel2token):
    """Expand token-level states h [B, T_txt, H] to frame level via mel2token [B, T_mel] (1-based; 0 = padding)."""
    h = F.pad(h, [0, 0, 1, 0])  # prepend a zero state so that index 0 maps to padding
    mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
    h = torch.gather(h, 1, mel2token_)  # [B, T_mel, H]
    return h
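
# A hedged usage sketch (not part of the original file): the tensors below are
# made-up toy values showing how expand_states() repeats each token state for
# its frames and maps mel2token == 0 to the prepended zero padding vector.
#
#   h = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]]])   # [B=1, T_txt=3, H=2]
#   mel2token = torch.tensor([[1, 1, 2, 3, 3, 0]])       # [B=1, T_mel=6]
#   expand_states(h, mel2token)
#   # -> tensor([[[1., 1.], [1., 1.], [2., 2.], [3., 3.], [3., 3.], [0., 0.]]])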


class CodePredictor(nn.Module):
    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size):
        super().__init__()
        self.hparams = deepcopy(hparams)
        self.hparams['hidden_size'] = hidden_size
        self.hidden_size = hidden_size
        char_dict_size = hparams.get('char_dict_size', 4000)
        if not hparams.get('lm_use_enc'):
            self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0)
            if hparams.get('mega_use_char', True):
                self.char_encoder = nn.Embedding(char_dict_size, self.hidden_size, padding_idx=0)
        else:
            self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size)
            if hparams.get('mega_use_char', True):
                self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size)
            if hparams['use_ph_pos_embed']:
                self.ph_pos_embed = PosEmb(self.hidden_size)

        self.char_empty_embed = nn.Embedding(1, self.hidden_size)
        if hparams.get('use_bert_input'):
            self.bert_input_proj = nn.Linear(768, self.hidden_size)

        self.ling_label_embed_layers = nn.ModuleDict()
        for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']):
            self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0)

        self.dec_hidden_size = dec_hidden_size
        self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size)
        self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0)
        self.use_pos_embed = hparams.get('use_pos_embed', False)
        if self.use_pos_embed:
            self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024)
        self.use_post_ln = hparams.get('use_post_ln', False)
        self.layers = None
        if not self.use_post_ln:
            self.layer_norm = LayerNorm(dec_hidden_size)
        self.code_size = code_size
        self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True)

    def forward_ling_encoder(
            self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre):
        ph_tokens = txt_tokens
        hparams = self.hparams
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]
        x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre)

        # enc_ph
        if not hparams.get('lm_use_enc'):
            x_ph = self.encoder(ph_tokens)
            # only add linguistic-label embeddings when such labels are configured
            x_ph = x_ph + (sum(
                [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']])
                if len(hparams['ling_labels']) > 0 else 0)
            x_ph = x_ph + x_spk
        else:
            ph_enc_oembed = sum(
                [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
                if len(hparams['ling_labels']) > 0 else 0
            ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
                torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
            ph_enc_oembed = ph_enc_oembed + x_spk
            ph_enc_oembed = ph_enc_oembed * ph_nonpadding
            x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed)

        # enc_char
        if char_tokens is not None and ph2char is not None:
            char_nonpadding = (char_tokens > 0).float()[:, :, None]
            x_char = self.char_encoder(char_tokens)
            empty_char = (ph2char > 100000).long()
            ph2char = ph2char * (1 - empty_char)
            x_char_phlevel = \
                expand_states(x_char * char_nonpadding, ph2char) \
                * (1 - empty_char)[..., None] + \
                self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None]
        else:
            x_char_phlevel = 0

        # x_ling
        x_ling = x_ph + x_char_phlevel
        x_ling = x_ling * ph_nonpadding
        x_ling = self.enc_proj(x_ling)
        return x_ling

    def sample_one_step(self, vq_pred):
        hparams = self.hparams
        if hparams.get('infer_top_k'):
            top_k = hparams.get('infer_top_k')
            temperature = hparams.get('infer_temperature', 1)
            vq_pred = vq_pred[:, -1] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(vq_pred, min(top_k, vq_pred.size(-1)))
                vq_pred[vq_pred < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(vq_pred, dim=-1)
            # sample from the distribution
            vq_pred = torch.multinomial(probs, num_samples=1)
        else:
            vq_pred = torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1)
        return vq_pred
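
    # A hedged illustration (not part of the original file), using made-up logits:
    # with last-step logits [[2.0, 1.0, 0.5, -1.0]], temperature 1 and
    # infer_top_k=2, the two largest logits are kept, the rest are set to -inf,
    # softmax gives roughly [0.73, 0.27, 0.0, 0.0], and torch.multinomial then
    # samples only index 0 or 1. Without infer_top_k the argmax branch is used.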

    def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
        # add spk embed
        style_embed = 0
        if self.hparams['use_spk_embed']:
            style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :]
        if self.hparams['use_spk_id']:
            style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :]
        if self.hparams['use_spk_enc']:
            style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :]
        return style_embed

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
                not hasattr(self, '_future_mask')
                or self._future_mask is None
                or self._future_mask.device != tensor.device
                or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]
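
    # A hedged illustration (not part of the original file): for dim=3 the cached
    # causal mask is
    #   [[0., -1e8, -1e8],
    #    [0.,   0., -1e8],
    #    [0.,   0.,   0.]]
    # so each position can only attend to itself and earlier positions; -1e8 is
    # used instead of -inf to stay finite under FP16 (see fill_with_neg_inf2).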


class ARDurPredictor(CodePredictor):
    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size,
                 use_rot_embed=True, op_version=1):
        super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size)
        self.use_rot_embed = use_rot_embed
        bias = hparams.get('lm_bias', True)
        if self.use_rot_embed:
            self.layers = nn.ModuleList([])
            self.layers.extend([
                RotTransformerDecoderLayer(
                    dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4,
                    post_ln=self.use_post_ln, op_version=op_version, bias=bias)
                for _ in range(lm_num_layers)
            ])
        if hparams['dur_model_type'] == 'ar_mse':
            self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus())
        else:
            self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1)

    def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None,
                incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None,
                prompt_length=None, cache_size=20, streaming=False):
        x = self.code_emb(prev_code)
        if x_ling is None:
            x_ling = self.forward_ling_encoder(
                txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre)
            x_ling = x_ling.flatten(0, 1)
            txt_tokens = txt_tokens.flatten(0, 1)
            x_ling = x_ling[txt_tokens > 0][None]

        # run decoder
        self_attn_padding_mask = None
        if self.use_pos_embed:
            positions = self.embed_positions(
                prev_code,
                incremental_state=incremental_state
            )
        if incremental_state is not None:
            # incremental decoding: only the newest step is fed through the decoder
            x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]]
            if spk_pos_ids_flat is not None:
                spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]]
            x = x[:, -1:]
            if self.use_pos_embed:
                positions = positions[:, -1:]
            if streaming:
                # Shift Pos: query pos is min(cache_size, idx)
                spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device),
                                             spk_pos_ids_flat)

        if self.use_pos_embed:
            x = x + positions
        x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous()
        T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1])
        x_ling = x_ling.reshape(-1, T, x_ling.shape[-1])
        x = x + x_ling
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        for idx, layer in enumerate(self.layers):
            if incremental_state is None:
                self_attn_mask = self.buffered_future_mask(x)
                if attn_mask is not None:
                    self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8
                    self_attn_mask = self_attn_mask.clamp_min(-1e8)
            else:
                self_attn_mask = None
            x, attn_weights = layer(
                x,
                incremental_state=incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                spk_pos_ids_flat=spk_pos_ids_flat
            )
            if streaming and incremental_state != {}:
                # keep the prompt plus the most recent cache_size steps in the KV cache
                for k, v in incremental_state.items():
                    if 'attn_state' in k:
                        prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value']
                        cur_length = prev_key.shape[2]
                        if cur_length - prompt_length > cache_size:
                            prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2)
                            prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]),
                                                   dim=2)
                        incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value
        if not self.use_post_ln:
            x = self.layer_norm(x)
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        x = self.project_out_dim(x)
        return x

    def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
              spk_id=None, spk_embed=None, mels_timbre=None,
              incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
              first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs):
        if incremental_state is None:
            incremental_state = {}
        x_ling = self.forward_ling_encoder(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        x_ling = x_ling.flatten(0, 1)
        txt_tokens_ori = txt_tokens
        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
        x_ling = x_ling[txt_tokens > 0][None]
        txt_tokens = txt_tokens[txt_tokens > 0][None]

        decoded = torch.zeros_like(txt_tokens)
        # prepend the start code (code_size + 1) as the first decoder input
        decoded = F.pad(decoded, [1, 0], value=self.code_size + 1)
        if incremental_state != {}:
            if first_decoder_inp is None:
                assert ctx_vqcodes is not None
                decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
                ctx_vqcodes = None
            else:
                decoded[:, :1] = first_decoder_inp
        probs = []
        for step in range(decoded.shape[1] - 1):
            vq_pred = self(txt_tokens, None, None, None, None,
                           decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, **kwargs)
            probs.append(vq_pred.cpu())
            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
                if self.hparams['dur_model_type'] == 'ar_mse':
                    d = vq_pred[:, -1, 0]
                    if dur_disturb > 0 and step >= 1:
                        # randomly scale the predicted duration up or down by up to dur_disturb
                        if random.random() > 0.5:
                            d = d * (1 + random.random() * dur_disturb)
                        else:
                            d = d / (1 + random.random() * dur_disturb)
                    d = torch.clamp_max(d, self.code_size - 1)
                    vq_pred = torch.round(d).long()
                else:
                    vq_pred = self.sample_one_step(vq_pred)
                decoded[:, step + 1] = torch.clamp_min(vq_pred, 1)
                if step == 0:
                    decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min)
            else:
                # teacher-forced steps taken from the provided context codes
                decoded[:, step + 1] = ctx_vqcodes[:, step]
        decoded = decoded[:, 1:]
        decoded_2d = torch.zeros_like(txt_tokens_ori)
        decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded
        if return_state:
            return decoded_2d, incremental_state
        if return_probs:
            return decoded_2d, torch.cat(probs, 1)
        return decoded_2d

    def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                        spk_id=None, spk_embed=None, mels_timbre=None,
                        incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
                        **kwargs):
        if incremental_state is None:
            incremental_state = {}
        x_ling = self.forward_ling_encoder(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        x_ling = x_ling.flatten(0, 1)
        txt_tokens_ori = txt_tokens
        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
        x_ling = x_ling[txt_tokens > 0][None]
        txt_tokens = txt_tokens[txt_tokens > 0][None]

        vq_decoded = torch.zeros_like(txt_tokens)
        vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1)
        if incremental_state != {}:
            assert ctx_vqcodes is not None
            vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
            ctx_vqcodes = None
        prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2]
        for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'):
            vq_pred = self(txt_tokens, None, None, None, None,
                           vq_decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs)
            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
                if self.hparams['dur_model_type'] == 'ar_mse':
                    vq_pred = torch.round(vq_pred[:, -1, 0]).long()
                else:
                    vq_pred = self.sample_one_step(vq_pred)
                vq_decoded[:, step + 1] = vq_pred
            else:
                vq_decoded[:, step + 1] = ctx_vqcodes[:, step]
        vq_decoded = vq_decoded[:, 1:]
        vq_decoded_2d = torch.zeros_like(txt_tokens_ori)
        vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded
        if return_state:
            return vq_decoded_2d, incremental_state
        return vq_decoded_2d
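
# A hedged usage sketch (illustration only, not from the original file). The hparams
# contents and sizes below are assumptions; the real configuration is supplied by the
# surrounding repo, and keys not shown here (e.g. the 'rel_fft' encoder settings) are
# elided with "...".
#
#   hparams = {
#       'lm_use_enc': True, 'encoder_type': 'rel_fft', 'use_ph_pos_embed': True,
#       'ling_labels': [], 'ling_label_dict_size': [],
#       'use_spk_embed': False, 'use_spk_id': False, 'use_spk_enc': False,
#       'dur_model_type': 'ar_mse', 'max_tokens': 60000, ...
#   }
#   dur_model = ARDurPredictor(hparams, hidden_size=512, dec_hidden_size=512,
#                              lm_num_layers=8, dict_size=..., code_size=...)
#   durs = dur_model.infer(txt_tokens, ling_feas, char_tokens, ph2char, bert_embed)
#   # durs has the same [B, T_txt] shape as txt_tokens, with one predicted duration
#   # value per non-padding phone token.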