| """ Modeling classes for MossTTSDelay. """ |
|
|
| from dataclasses import dataclass |
| from typing import List, Optional, Tuple, Union |
| from tqdm import tqdm |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import CrossEntropyLoss |
|
|
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| logging, |
| replace_return_docstrings, |
| ) |
| from transformers.cache_utils import Cache |
| from transformers.models.qwen3 import Qwen3Model |
| from transformers import initialization as init |
|
|
| from .configuration_moss_tts import MossTTSDelayConfig |
| from .inference_utils import sample_token, find_last_equal_C |
|
|
| try: |
| from .processing_moss_tts import UserMessage, AssistantMessage, MossTTSDelayProcessor |
| except Exception: |
| UserMessage = None |
| AssistantMessage = None |
| MossTTSDelayProcessor = None |
|
|
| logger = logging.get_logger(__name__) |
|
|
| _CONFIG_FOR_DOC = "MossTTSDelayConfig" |
|
|
|
|
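# The "delay" layout interleaves one text channel with n_vq audio VQ channels
# per position, with audio channel j lagging channel 0 by j steps. An
# illustrative sketch (inferred from `generate` below; the exact layout is
# produced by the processor) with n_vq = 3, pad = audio pad code:
#
#   step :    0       1       2       3     ...
#   vq_0 :  a0[0]   a0[1]   a0[2]   a0[3]
#   vq_1 :   pad    a1[0]   a1[1]   a1[2]
#   vq_2 :   pad     pad    a2[0]   a2[1]
#
# Once the text side finishes, trailing delay-slot text tokens flush the
# remaining lagged channels before <audio_end> closes the span.
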
@dataclass
class MossTTSDelayOutputWithPast(ModelOutput):
    """
    Base class for the model's outputs that may also contain past key/values (to speed up sequential decoding).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Weighted sum of channel losses.
        all_sum_losses (`torch.FloatTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Sum of losses for each sample and each channel before averaging.
        all_token_nums (`torch.LongTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Number of non-masked (label != -100) tokens per sample and channel.
        sample_losses (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
            Loss per sample.
        channel_losses (`torch.FloatTensor` of shape `(n_vq + 1,)`, *optional*):
            Loss per channel (text head + VQ heads).
        logits (`List[torch.FloatTensor]`, *optional*):
            List of prediction scores from each head.
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden states (keys and values in the self-attention blocks).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding
            layer, + one for the output of each layer).
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
            Tuple of `torch.FloatTensor` (one for each layer) of the attention weights.
    """

    loss: Optional[torch.FloatTensor] = None
    all_sum_losses: Optional[torch.FloatTensor] = None
    all_token_nums: Optional[torch.LongTensor] = None
    sample_losses: Optional[torch.FloatTensor] = None
    channel_losses: Optional[torch.FloatTensor] = None
    logits: Optional[List[torch.FloatTensor]] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class MossTTSDelayPreTrainedModel(PreTrainedModel):
    config_class = MossTTSDelayConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    # Both flags are kept so flash attention is detected across transformers versions.
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """
        Transformers 5.0+ safe init:
        - MUST use transformers.initialization helpers
        - MUST respect param._is_hf_initialized to avoid overwriting ckpt-loaded params
        """
        super()._init_weights(module)

        # Resolve the init std from the top-level config, then from the nested
        # language config, falling back to the common 0.02 default.
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        elif hasattr(self.config, "language_config") and hasattr(self.config.language_config, "initializer_range"):
            std = self.config.language_config.initializer_range
        else:
            std = 0.02

        if isinstance(module, nn.Embedding):
            # Only the extra audio-VQ embeddings (vocab = audio_vocab_size + 1)
            # are re-initialized here; the text embeddings are initialized by
            # the Qwen3 backbone itself.
            if getattr(module, "num_embeddings", None) == self.config.audio_vocab_size + 1:
                init.normal_(module.weight, mean=0.0, std=std)

        if isinstance(module, nn.Linear):
            # Linear layers (including the lm_heads) keep the initialization
            # applied by super()._init_weights(); nothing extra to do here.
            pass


MOSSTTS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`MossTTSDelayConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The MossTTSDelay Model architecture tailored for Text-to-Speech generation with multi-head VQ prediction.",
    MOSSTTS_START_DOCSTRING,
)
class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
    # Convenience re-exports so downstream code can reach the message and
    # processor classes through the model class (None if unavailable).
    UserMessage = UserMessage
    AssistantMessage = AssistantMessage
    Processor = MossTTSDelayProcessor

    def __init__(self, config: MossTTSDelayConfig):
        super().__init__(config)
        self.config = config

        # Propagate the requested dtype to the Qwen3 backbone config.
        config.language_config.torch_dtype = config.torch_dtype
        self.language_model = Qwen3Model(config.language_config)

        # One extra embedding table per audio VQ channel; the +1 entry is the
        # audio pad code appended on top of the audio vocabulary.
        self.emb_ext = nn.ModuleList()
        for vq_idx in range(self.config.n_vq):
            self.emb_ext.append(
                nn.Embedding(self.config.audio_vocab_size + 1, config.language_config.hidden_size, padding_idx=None)
            )

        # Head 0 predicts text tokens; heads 1..n_vq predict the audio VQ
        # channels (audio vocabulary + pad code).
        self.lm_heads = nn.ModuleList([
            nn.Linear(config.language_config.hidden_size, config.language_config.vocab_size, bias=False)
        ])
        for vq_idx in range(self.config.n_vq):
            self.lm_heads.append(
                nn.Linear(config.language_config.hidden_size, self.config.audio_vocab_size + 1, bias=False)
            )

        self.post_init()

    def get_input_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Computes the combined embeddings from the text channel and the audio VQ channels.

        Args:
            input_ids: Shape `(batch_size, sequence_length, 1 + n_vq)`.

        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`.
        """
        # The text channel goes through the backbone's embedding table.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids[..., 0])

        # Each audio VQ channel adds its own embedding on top.
        for i, embed_layer in enumerate(self.emb_ext):
            inputs_embeds = inputs_embeds + embed_layer(input_ids[..., i + 1])

        return inputs_embeds
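
    # Shape sketch (illustrative): for batch size B, sequence length S and
    # n_vq audio channels, input_ids is (B, S, 1 + n_vq); input_ids[..., 0]
    # indexes the Qwen3 text embeddings and each input_ids[..., j + 1] adds
    # the j-th VQ embedding, so inputs_embeds stays (B, S, hidden_size).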

    def set_input_embeddings(self, value):
        self.language_model.embed_tokens = value

    def get_output_embeddings(self):
        # Note: returns the whole ModuleList of heads (text head first), not a
        # single module; output weight tying is not used for this model.
        return self.lm_heads

    @add_start_docstrings_to_model_forward(MOSSTTS_START_DOCSTRING)
    @replace_return_docstrings(output_type=MossTTSDelayOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[List[int]] = None,
        channelwise_loss_weight: Optional[List[float]] = None,
        **kwargs,
    ) -> Union[Tuple, MossTTSDelayOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`):
                Indices of input sequence tokens in the vocabulary.
                Dimension 2 contains: [Text/Semantics, VQ_0, VQ_1, ..., VQ_N].
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`, *optional*):
                Labels for computing the masked language modeling loss.
            channelwise_loss_weight (`List[float]`, *optional*):
                Manual weights for summing losses across the different heads (text vs. audio channels).

        Returns:
        """
        # Validate the 3-D (text + VQ channels) layout; skip the check when the
        # caller passes precomputed `inputs_embeds` instead of `input_ids`.
        if input_ids is not None and (input_ids.dim() != 3 or input_ids.shape[-1] != self.config.n_vq + 1):
            raise ValueError("`input_ids` must have shape `(batch_size, sequence_length, 1 + n_vq)`.")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        last_hidden_state = outputs.last_hidden_state
        if hidden_out_layers is None:
            # All heads read from the final hidden state.
            hidden_states_for_heads = [last_hidden_state] * len(self.lm_heads)
        else:
            # Optionally feed individual heads from intermediate layers.
            all_hs = outputs.hidden_states
            hidden_states_for_heads = [all_hs[idx] for idx in hidden_out_layers]

        layer_logits = []
        for i, (hs, head) in enumerate(zip(hidden_states_for_heads, self.lm_heads)):
            logits = head(hs)
            # Audio heads (i > 0) never predict the appended last index (the
            # extra code added on top of audio_vocab_size).
            if i > 0:
                logits[..., -1] = float("-inf")
            layer_logits.append(logits)

        loss = None
        all_sum_losses = None
        all_token_nums = None
        sample_losses = None
        channel_losses = None

        if labels is not None:
            if labels.dim() != 3:
                raise ValueError(f"Labels must have rank 3 (B, S, C), got {labels.shape}")

            batch_size = labels.size(0)
            n_heads = len(layer_logits)

            all_sum_losses_list = []

            # (batch_size, n_heads): non-masked token count per sample and channel.
            all_token_nums = torch.sum(labels != -100, dim=1)

            for i, logits in enumerate(layer_logits):
                cur_labels = labels[..., i]

                # Per-token losses; ignore_index=-100 positions contribute 0.
                loss_fct = CrossEntropyLoss(reduction="none")
                vocab_size = logits.size(-1)

                reshaped_logits = logits.view(-1, vocab_size)
                reshaped_labels = cur_labels.contiguous().view(-1)

                per_token_loss = loss_fct(reshaped_logits, reshaped_labels)

                per_token_loss = per_token_loss.view(batch_size, -1)
                per_sample_loss = torch.sum(per_token_loss, dim=-1)

                all_sum_losses_list.append(per_sample_loss)

            # (batch_size, n_heads)
            all_sum_losses = torch.stack(all_sum_losses_list, dim=1)

            if channelwise_loss_weight is not None:
                if len(channelwise_loss_weight) != n_heads:
                    raise ValueError(f"channelwise_loss_weight length {len(channelwise_loss_weight)} != {n_heads}")

                w_tensor = torch.tensor(channelwise_loss_weight, device=all_sum_losses.device, dtype=all_sum_losses.dtype)

                # Clamp to avoid division by zero for channels with no labels.
                token_counts_safe = all_token_nums.float().clamp(min=1.0)

                normalized_losses = all_sum_losses / token_counts_safe
                sample_losses = (normalized_losses * w_tensor).sum(dim=1) / w_tensor.sum()

                total_loss_per_channel = all_sum_losses.sum(dim=0)
                total_tokens_per_channel = all_token_nums.sum(dim=0).float().clamp(min=1.0)
                channel_losses = total_loss_per_channel / total_tokens_per_channel

                loss = (channel_losses * w_tensor).sum() / w_tensor.sum()
            else:
                total_tokens = all_token_nums.sum().float().clamp(min=1.0)
                loss = all_sum_losses.sum() / total_tokens
                channel_losses = all_sum_losses.sum(dim=0) / all_token_nums.sum(dim=0).float().clamp(min=1.0)

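        # Worked example of the weighting (illustrative): with weights
        # w = [w_text, w_vq0, ...], the final loss is
        #   loss = sum_c w_c * (sum_b loss_{b,c}) / (sum_b tokens_{b,c}) / sum_c w_c,
        # i.e. a weighted mean of per-channel token-averaged losses; without
        # weights it is a plain token-weighted average over all channels.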
        return MossTTSDelayOutputWithPast(
            loss=loss,
            all_sum_losses=all_sum_losses,
            all_token_nums=all_token_nums,
            sample_losses=sample_losses,
            channel_losses=channel_losses,
            logits=layer_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 1000,
        text_temperature: float = 1.5,
        text_top_p: float = 1.0,
        text_top_k: int = 50,
        audio_temperature: float = 1.7,
        audio_top_p: float = 0.8,
        audio_top_k: int = 25,
        audio_repetition_penalty: float = 1.0,
    ):
        # A temperature of 0 (or below) means greedy decoding.
        if text_temperature > 0:
            text_do_sample = True
        else:
            text_temperature = 1
            text_do_sample = False
        if audio_temperature > 0:
            audio_do_sample = True
        else:
            audio_temperature = 1
            audio_do_sample = False

        past_key_values = None
        device = input_ids.device
        if attention_mask is None:
            # Default to an all-ones mask; it is extended step by step below.
            attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.bool, device=device)
        current_input_ids = input_ids
        current_attention_mask = attention_mask
        batch_size, seq_len, n_vq = input_ids.shape
        n_vq -= 1  # last dim is 1 (text) + n_vq (audio channels)

        # Running transcript of all channels (prompt + generated).
        generation_ids = input_ids[:]
        is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Steps elapsed since the current <audio_start> per sample.
        audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)
        # Sentinel INT64_MAX marks samples that are not flushing the delay.
        torch_int64_max = torch.iinfo(torch.int64).max
        delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)

        # If the prompt already ends inside an audio span, resume counting from
        # the last <audio_start>.
        is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
        audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
        audio_start_mask = is_continuation & (audio_start_indices != -1)
        audio_lengths[audio_start_mask] = seq_len - audio_start_indices[audio_start_mask]

        is_audio = audio_start_mask.clone()

        # Token ids the text head must not sample outside audio spans ...
        pre_exclude_mask0 = torch.tensor([self.config.pad_token_id, self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id, self.config.audio_end_token_id], device=device)
        # ... and, inside audio spans, everything except the two slot tokens.
        pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
        pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False

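        # Per-step decoding state (sketch): `is_audio` marks rows currently
        # inside an audio span, `audio_lengths` counts steps since the last
        # <audio_start> (channel j only samples once audio_lengths > j), and
        # `delayed_lengths` counts delay-slot steps while the remaining VQ
        # channels are flushed after the text side has finished (INT64_MAX
        # means "not flushing").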
        for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
            outputs = self(
                input_ids=current_input_ids,
                attention_mask=current_attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            # Last-position logits, temperature-scaled per modality.
            next_token_logits = [
                logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature
                for logit_idx, logit in enumerate(outputs.logits)
            ]
            # Clone before the in-place masking below.
            next_token_logits[0] = next_token_logits[0].clone()
            next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
            # Rows still flushing the delay emit delay-slot tokens; once the
            # flush reaches n_vq, the audio span is closed with <audio_end>.
            next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
            is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
            next_text_token[is_audio_eos] = self.config.audio_end_token_id
            is_audio[is_audio_eos] = False
            sampling_text_mask = ~is_stopping & (delayed_lengths > n_vq)
            next_token_logits[0][~is_audio] = next_token_logits[0][~is_audio].index_fill(-1, pre_exclude_mask0, float("-inf"))
            next_token_logits[0][is_audio] = next_token_logits[0][is_audio].masked_fill(pre_exclude_mask1, float("-inf"))
            if time_step == 0:
                # Hard-coded special token id that must not be the very first
                # generated text token.
                next_token_logits[0][..., 151662] = float("-inf")
            if time_step <= n_vq:
                # Do not allow <|im_end|> before a delay flush could complete.
                next_token_logits[0][..., self.config.im_end_token_id] = float("-inf")

            next_text_token[sampling_text_mask] = sample_token(
                logits=next_token_logits[0][sampling_text_mask],
                top_p=text_top_p,
                top_k=text_top_k,
                do_sample=text_do_sample,
            )
            is_audio[next_text_token == self.config.audio_start_token_id] = True
            is_stopping[next_text_token == self.config.im_end_token_id] = True
|
| next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device) |
| |
| pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) |
| post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1 |
| post_audio_mask[delayed_lengths == torch_int64_max] = True |
| sampling_audio_mask = pre_audio_mask & post_audio_mask |
| next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code |
| |
| if sampling_audio_mask.sum() > 0: |
| audio_ch0_logits = next_token_logits[1][sampling_audio_mask[:, 0]] |
| audio_logits = torch.stack(next_token_logits[2:], dim=1)[sampling_audio_mask[:, 1:]] |
| audio_ch0_logits[..., self.config.audio_pad_code] = float('-inf') |
| audio_logits[..., self.config.audio_pad_code] = float('-inf') |
| next_audio_tokens[:, 0][sampling_audio_mask[:, 0]] = sample_token( |
| logits=audio_ch0_logits, |
| prev_tokens=generation_ids[:, :, 1], |
| repetition_penalty=audio_repetition_penalty, |
| top_p=audio_top_p, |
| top_k=audio_top_k, |
| do_sample=audio_do_sample |
| ) |
| next_audio_tokens[:, 1:][sampling_audio_mask[:, 1:]] = sample_token( |
| logits=audio_logits, |
| prev_tokens=generation_ids[:, :, 2:], |
| repetition_penalty=audio_repetition_penalty, |
| top_p=audio_top_p, |
| top_k=audio_top_k, |
| do_sample=audio_do_sample |
| ) |
| |
| audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1 |
| audio_lengths[next_text_token == self.config.audio_end_token_id] = 0 |
| delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0 |
| delayed_lengths[delayed_lengths != torch_int64_max] += 1 |
| delayed_lengths[delayed_lengths > n_vq] = torch_int64_max |
| |
| current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2) |
| current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1) |
| generation_ids = torch.cat([generation_ids, current_input_ids], dim=1) |
| |
| if is_stopping.sum() == batch_size: |
| break |

        # Slice off everything before the last <|im_start|> header (the +3
        # skips the role-header tokens that follow it).
        start_indices = find_last_equal_C(input_ids[..., 0], self.config.im_start_token_id) + 3
        start_lengths = seq_len - start_indices

        output = []
        for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, generation_ids):
            output.append((start_length, cur_generation_ids[start_idx:]))

        return output
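
# Usage sketch (illustrative only; the checkpoint path and the processor call
# are assumptions -- see the repository README for the supported entry point):
#
#   processor = MossTTSDelayProcessor.from_pretrained("<checkpoint>")
#   model = MossTTSDelayModel.from_pretrained("<checkpoint>").eval()
#   batch = processor(...)  # builds (B, S, 1 + n_vq) input_ids + attention_mask
#   results = model.generate(batch["input_ids"], batch["attention_mask"])
#   # each item is (num_positions_from_last_im_start, token_ids) per sample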