| | """PyTorch TextSyncMimi model - Text-synchronous neural audio codec based on Mimi.""" |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from typing import Optional, Dict, List, Union |
| |
|
| | try: |
| | from .configuration_mimi import MimiConfig |
| | from .configuration_text_sync_mimi import TextSyncMimiConfig |
| | from .modeling_mimi_clean import MimiPreTrainedModel, MimiModel |
| | from .modeling_backbone_components import ( |
| | CrossAttentionTransformer, |
| | CausalAttentionTransformer |
| | ) |
| | except ImportError: |
| | from configuration_mimi import MimiConfig |
| | from configuration_text_sync_mimi import TextSyncMimiConfig |
| | from modeling_mimi_clean import MimiPreTrainedModel, MimiModel |
| | from modeling_backbone_components import ( |
| | CrossAttentionTransformer, |
| | CausalAttentionTransformer |
| | ) |
| |
|
| |
|
class TextSyncMimi(MimiPreTrainedModel):
    """
    TextSyncMimi: Text-Synchronous Neural Audio Codec Model

    A neural audio codec model that combines text and speech representations for
    high-quality text-to-speech synthesis. Features:

    - Learnable text embeddings
    - Cross-attention transformer for text-speech alignment
    - Autoregressive transformer for causal speech generation
    - BCE-based end token prediction for dynamic duration control

    Architecture:
    - Text Embedding Layer: Maps token IDs to 4,096-dim embeddings
    - Mimi Encoder: Pre-trained audio encoder (frozen)
    - Text Projection: Linear projection from 4,096 to 512 dimensions
    - Cross-Attention Transformer: Aligns text with speech features
    - Autoregressive Transformer: Generates speech representations
    - End Token Classifier: Predicts when to stop generating
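
    Example (illustrative sketch; assumes a ``TextSyncMimiConfig`` whose
    ``mimi_model_id`` points at the "kyutai/mimi" checkpoint, and pre-tokenized,
    pre-aligned inputs):

        >>> config = TextSyncMimiConfig(mimi_model_id="kyutai/mimi")
        >>> model = TextSyncMimi(config)
        >>> outputs = model(
        ...     text_token_ids=text_ids,            # (B, L) text token IDs
        ...     input_values=audio,                 # (B, 1, num_samples), 24 kHz
        ...     alignment_chunk_sizes=chunk_sizes,  # (B, L) speech frames per token
        ... )
        >>> outputs["loss"].backward()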
| | """ |
| | |
| | config_class = TextSyncMimiConfig |
| | |
    def __init__(
        self,
        config: Optional[Union[MimiConfig, 'TextSyncMimiConfig']] = None,
        model_id: Optional[str] = None,
        token: Optional[str] = None,
        alpha: Optional[float] = None,
        cross_attention_layers: Optional[int] = None,
        causal_attention_layers: Optional[int] = None,
        bce_threshold: Optional[float] = None,
        vocab_size: Optional[int] = None,
    ):
        """
        Initialize TextSyncMimi model.

        Args:
            config: Model configuration (TextSyncMimiConfig or MimiConfig)
            model_id: Mimi model ID (e.g., "kyutai/mimi"). If None, uses config.mimi_model_id
            token: Hugging Face authentication token
            alpha: Weight for BCE end token loss. If None, uses config.alpha
            cross_attention_layers: Number of cross-attention layers. If None, uses config
            causal_attention_layers: Number of autoregressive layers. If None, uses config
            bce_threshold: BCE loss threshold. If None, uses config.bce_threshold
            vocab_size: Text vocabulary size. If None, uses config.vocab_size
        """
        if config is None:
            if model_id is None:
                raise ValueError("Either config or model_id must be provided")
            config = MimiConfig.from_pretrained(model_id, token=token)

        super().__init__(config)

        # Resolve the Mimi checkpoint ID: an explicit argument takes precedence over the config.
        if hasattr(config, 'mimi_model_id'):
            model_id = model_id or config.mimi_model_id
        if model_id is None:
            raise ValueError("model_id must be provided either as argument or in config.mimi_model_id")

        # Explicit arguments override config values; otherwise fall back to defaults.
        alpha = alpha if alpha is not None else getattr(config, 'alpha', 1.0)
        cross_attention_layers = cross_attention_layers if cross_attention_layers is not None else getattr(config, 'cross_attention_layers', 2)
        causal_attention_layers = causal_attention_layers if causal_attention_layers is not None else getattr(config, 'causal_attention_layers', 2)
        bce_threshold = bce_threshold if bce_threshold is not None else getattr(config, 'bce_threshold', 0.1)
        vocab_size = vocab_size if vocab_size is not None else getattr(config, 'vocab_size', 128256)

        self.config = config
        # Load the pre-trained Mimi codec; its encoder-side modules are reused below.
        model = MimiModel.from_pretrained(model_id, token=token)

        self.alpha = alpha
        self.bce_threshold = bce_threshold

        # Learnable text embeddings (LLaMA-sized, 4096-dim) and projection into the
        # 512-dim Mimi latent space.
        self.text_token_embedding = nn.Embedding(vocab_size, 4096)
        self.text_proj = nn.Linear(4096, 512)

        # Cross-attention transformer: text queries attend over speech features.
        cross_attention_config = MimiConfig(**self.config.__dict__)
        cross_attention_config.num_hidden_layers = cross_attention_layers
        cross_attention_config.hidden_size = 512
        self.cross_attention_transformer = CrossAttentionTransformer(cross_attention_config)

        # Autoregressive transformer over the interleaved text/speech sequence.
        causal_attention_config = MimiConfig(**self.config.__dict__)
        causal_attention_config.num_hidden_layers = causal_attention_layers
        causal_attention_config.hidden_size = 512
        self.ar_transformer = CausalAttentionTransformer(causal_attention_config)

        # Special learned embeddings that delimit the interleaved sequence.
        self.text_speech_latent_embed = nn.Embedding(1, 512)
        self.time_speech_start_embed = nn.Embedding(1, 512)
        self.time_speech_end_embed = nn.Embedding(1, 512)

        # Binary classifier that predicts the end-of-chunk token.
        self.end_token_classifier = nn.Linear(512, 1)

        self.post_init()

        # Reuse the pre-trained Mimi encoder-side modules (assigned after post_init()
        # so they are not touched by weight re-initialization).
        self.encoder = model.encoder
        self.encoder_transformer = model.encoder_transformer
        self.quantizer = model.quantizer
        self.downsample = model.downsample
        self.upsample = model.upsample

        self._print_subnetwork_parameter_counts()

    def initialize_text_embeddings_from_weights(self, embedding_weight: torch.Tensor) -> None:
        """
        Initialize text embeddings from a weight matrix.

        Args:
            embedding_weight: Weight matrix of shape (vocab_size, 4096)
        """
        if embedding_weight.dim() != 2 or embedding_weight.size(1) != 4096:
            raise ValueError("embedding_weight must have shape (vocab_size, 4096)")
        if embedding_weight.size(0) != self.text_token_embedding.num_embeddings:
            raise ValueError("Provided vocab_size does not match model's text_token_embedding")
        with torch.no_grad():
            self.text_token_embedding.weight.copy_(embedding_weight)
        for p in self.text_token_embedding.parameters():
            p.requires_grad = True

    def initialize_text_embeddings_from_llama(self, llama_embeddings_module: torch.nn.Module) -> None:
        """
        Initialize text embeddings from a LLaMA embedding module.

        Args:
            llama_embeddings_module: LLaMA embedding module with weight shape (vocab_size, 4096)
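
        Example (illustrative; assumes a LLaMA-family checkpoint whose embedding
        matrix matches this model's (vocab_size, 4096) text embedding table,
        e.g. a LLaMA 3 8B model when vocab_size is 128256):

            >>> from transformers import AutoModelForCausalLM
            >>> llama = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
            >>> model.initialize_text_embeddings_from_llama(llama.get_input_embeddings())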
| | """ |
| | if not hasattr(llama_embeddings_module, 'weight'): |
| | raise ValueError("llama_embeddings_module must have a 'weight' attribute") |
| | weight = llama_embeddings_module.weight.data |
| | self.initialize_text_embeddings_from_weights(weight) |
| |
|
| | def _print_subnetwork_parameter_counts(self) -> None: |
| | """Print parameter counts for model subnetworks.""" |
| | print("=" * 70) |
| | print("TextSyncMimi Parameter Counts") |
| | print("=" * 70) |
| | print(f"Encoder: {sum(p.numel() for p in self.encoder.parameters()) / 1e6:.2f}M") |
| | print(f"Encoder Transformer: {sum(p.numel() for p in self.encoder_transformer.parameters()) / 1e6:.2f}M") |
| | print(f"Cross-Attention Transformer: {sum(p.numel() for p in self.cross_attention_transformer.parameters()) / 1e6:.2f}M") |
| | print(f"AR Transformer: {sum(p.numel() for p in self.ar_transformer.parameters()) / 1e6:.2f}M") |
| | print(f"Quantizer: {sum(p.numel() for p in self.quantizer.parameters()) / 1e6:.2f}M") |
| | print("=" * 70) |
| |
|
    def encode_audio_to_representation(
        self,
        input_values: torch.Tensor,
        audio_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode audio to speech representation.

        Args:
            input_values: Audio waveform (B, 1, audio_len)
            audio_attention_mask: Attention mask (B, audio_len)

        Returns:
            Speech embeddings (B, 512, T_frames) at the 12.5 Hz Mimi frame rate
            (12.5 latent frames per second of 24 kHz audio)
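
        Example (frame-count arithmetic only): 3 seconds of 24 kHz audio is
        72,000 samples, giving int(72000 * 12.5 / 24000) = 37 valid frames.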
| | """ |
| | batch_size = input_values.shape[0] |
| | device = input_values.device |
| | |
| | |
| | embeddings = self.encoder(input_values) |
| | encoder_outputs = self.encoder_transformer(embeddings.transpose(1, 2)) |
| | embeddings = encoder_outputs[0].transpose(1, 2) |
| | embeddings = self.downsample(embeddings) |
| | |
| | |
| | if audio_attention_mask is not None: |
| | speech_seq_len = embeddings.shape[-1] |
| | speech_attention_mask = torch.zeros(batch_size, speech_seq_len, device=device, dtype=torch.bool) |
| | |
| | for b in range(batch_size): |
| | actual_audio_len = audio_attention_mask[b].sum().item() |
| | actual_speech_len = int(actual_audio_len * 12.5 / 24000) |
| | actual_speech_len = min(actual_speech_len, speech_seq_len) |
| | if actual_speech_len > 0: |
| | speech_attention_mask[b, :actual_speech_len] = True |
| | |
| | speech_mask_expanded = speech_attention_mask.unsqueeze(1) |
| | embeddings = embeddings * speech_mask_expanded.float() |
| | |
| | return embeddings |
| |
|
    def generate_autoregressive(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        max_z_tokens: int = 50,
        end_token_threshold: float = 0.5,
        device: Optional[torch.device] = None,
    ) -> List[List[torch.Tensor]]:
        """
        Generate audio autoregressively.

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            audio_attention_mask: Audio mask (B, audio_seq_len) - for normal mode
            speech_attention_mask: Speech mask (B, speech_seq_len) - for cached mode
            text_attention_mask: Text mask (B, text_seq_len)
            max_z_tokens: Maximum z tokens per text position
            end_token_threshold: Probability threshold for stopping
            device: Device for computation

        Returns:
            List of z_tokens lists (one per batch item)
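
        Example (illustrative; assumes tokenized text and a 24 kHz prompt waveform):

            >>> z_tokens = model.generate_autoregressive(
            ...     text_token_ids=text_ids,  # (B, L)
            ...     input_values=audio,       # (B, 1, num_samples)
            ...     max_z_tokens=50,
            ... )
            >>> len(z_tokens[0])  # generated latent frames for batch item 0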
| | """ |
| | if device is None: |
| | device = text_token_ids.device |
| | |
| | self.eval() |
| | |
| | with torch.no_grad(): |
| | |
| | if speech_embeddings is not None: |
| | |
| | |
| | pass |
| | else: |
| | |
| | if input_values is None: |
| | raise ValueError("Either input_values or speech_embeddings must be provided") |
| | speech_embeddings = self.encode_audio_to_representation( |
| | input_values, |
| | audio_attention_mask=audio_attention_mask |
| | ) |
| | speech_embeddings = speech_embeddings.transpose(1, 2) |
| | |
| | |
| | text_embeddings_4096 = self.text_token_embedding(text_token_ids) |
| | text_embeddings_proj = self.text_proj(text_embeddings_4096) |
| | |
| | |
| | |
| | formatted_text_attention_mask = None |
| | formatted_speech_attention_mask = None |
| | |
| | batch_size, text_seq_len = text_embeddings_proj.shape[:2] |
| | |
| | if text_attention_mask is not None: |
| | causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype)) |
| | causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1) |
| | padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len) |
| | combined_mask = causal_mask * padding_mask |
| | formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf')) |
| | else: |
| | causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype)) |
| | causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1) |
| | formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf')) |
| | |
| | |
| | if speech_attention_mask is not None: |
| | |
| | speech_seq_len = speech_embeddings.shape[1] |
| | speech_mask = speech_attention_mask.bool() |
| | formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len) |
| | formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf')) |
| | elif audio_attention_mask is not None: |
| | |
| | speech_seq_len = speech_embeddings.shape[1] |
| | speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=device) |
| | for b in range(batch_size): |
| | audio_len = audio_attention_mask[b].sum().item() |
| | speech_len = int(audio_len * 12.5 / 24000) |
| | speech_len = min(speech_len, speech_seq_len) |
| | speech_mask[b, :speech_len] = True |
| | formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len) |
| | formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf')) |
| | else: |
| | formatted_speech_attention_mask = None |
| | |
| | |
| | cross_attention_outputs = self.cross_attention_transformer( |
| | hidden_states=text_embeddings_proj, |
| | encoder_hidden_states=speech_embeddings, |
| | attention_mask=formatted_text_attention_mask, |
| | encoder_attention_mask=formatted_speech_attention_mask, |
| | alignment_chunk_sizes=None, |
| | ) |
| | cross_attention_outputs = cross_attention_outputs.last_hidden_state |
| | |
| | |
| | text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device)) |
| | time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device)) |
| | time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device)) |
| | |
| | generated_z_tokens = [] |
| | |
| | |
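            # For each item: walk the text positions left to right. After each
            # (text, cross-attention) pair and a <speech-start> marker, decode
            # latent z tokens one at a time with the AR transformer until the
            # end-token classifier fires (or max_z_tokens is reached), then
            # append a <speech-end> marker and move on to the next position.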
            for b in range(batch_size):
                if text_attention_mask is not None:
                    valid_text_len = text_attention_mask[b].sum().item()
                else:
                    valid_text_len = text_embeddings_proj.shape[1]

                # Every sequence starts with the text/speech latent marker.
                sequence = [text_speech_latent_emb]
                batch_z_tokens = []

                for i in range(valid_text_len):
                    # Append the text embedding and its cross-attended speech feature.
                    t_i = text_embeddings_proj[b, i:i+1]
                    s_i = cross_attention_outputs[b, i:i+1]
                    sequence.extend([t_i, s_i])

                    # Open the speech chunk for this text position.
                    sequence.append(time_speech_start_emb)

                    z_count = 0
                    while z_count < max_z_tokens:
                        # Re-run the AR transformer on the full sequence so far.
                        current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)

                        seq_len = current_sequence.shape[1]
                        ar_attention_mask = torch.ones(1, seq_len, dtype=torch.bool, device=device)

                        ar_outputs = self.ar_transformer(
                            hidden_states=current_sequence,
                            attention_mask=ar_attention_mask,
                        )

                        # The next-step prediction is the last hidden state.
                        last_prediction = ar_outputs.last_hidden_state[0, -1:, :]

                        # The end-token probability decides whether to close the chunk.
                        end_token_logit = self.end_token_classifier(last_prediction).squeeze(-1)
                        end_token_prob = torch.sigmoid(end_token_logit).item()

                        if end_token_prob >= end_token_threshold:
                            break
                        else:
                            # Feed the prediction back as the next z token.
                            sequence.append(last_prediction)
                            batch_z_tokens.append(last_prediction.squeeze(0))
                            z_count += 1

                    # Close the speech chunk for this text position.
                    sequence.append(time_speech_end_emb)

                generated_z_tokens.append(batch_z_tokens)

        return generated_z_tokens

    def forward(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for training.

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            alignment_chunk_sizes: Number of speech frames aligned to each text token (B, L)
            audio_attention_mask: Audio mask (B, audio_seq_len)
            speech_attention_mask: Speech mask (B, speech_seq_len)
            text_attention_mask: Text mask (B, text_seq_len)

        Returns:
            Dictionary with 'loss', 'reconstruction_loss', and 'bce_end_token_loss'
        """
        # Obtain speech features: either use cached embeddings or encode the audio.
        if speech_embeddings is not None:
            pass
        elif input_values is not None:
            speech_embeddings_raw = self.encode_audio_to_representation(
                input_values,
                audio_attention_mask
            )
            # (B, 512, T) -> (B, T, 512)
            speech_embeddings = speech_embeddings_raw.transpose(1, 2)
        else:
            raise ValueError("Either input_values or speech_embeddings must be provided")

        # Embed and project text tokens into the 512-dim latent space.
        text_embeddings_4096 = self.text_token_embedding(text_token_ids)
        text_embeddings = self.text_proj(text_embeddings_4096)

        # Build additive attention masks (0.0 = attend, -inf = masked).
        formatted_text_attention_mask = None
        formatted_speech_attention_mask = None

        batch_size, text_seq_len = text_embeddings.shape[:2]

        if text_attention_mask is not None:
            # Causal mask combined with the text padding mask.
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)

            padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
            combined_mask = causal_mask * padding_mask

            formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
        else:
            # Causal mask only.
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
            formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))

        if speech_attention_mask is not None:
            # Cached mode: the speech mask is provided directly.
            speech_seq_len = speech_embeddings.shape[1]
            speech_mask = speech_attention_mask.bool()

            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        elif audio_attention_mask is not None:
            # Normal mode: derive the speech mask from the audio mask (12.5 frames per 24000 samples).
            speech_seq_len = speech_embeddings.shape[1]

            speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=speech_embeddings.device)

            for b in range(batch_size):
                audio_len = audio_attention_mask[b].sum().item()
                speech_len = int(audio_len * 12.5 / 24000)
                speech_len = min(speech_len, speech_seq_len)
                speech_mask[b, :speech_len] = True

            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        else:
            formatted_speech_attention_mask = None

        # Align text with speech features via cross-attention.
        cross_attention_outputs = self.cross_attention_transformer(
            hidden_states=text_embeddings,
            encoder_hidden_states=speech_embeddings,
            attention_mask=formatted_text_attention_mask,
            encoder_attention_mask=formatted_speech_attention_mask,
            alignment_chunk_sizes=None,
        )
        cross_attention_outputs = cross_attention_outputs.last_hidden_state

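        # Training targets: pass the speech latents through the quantizer and back
        # (quantize -> dequantize) without gradients, so the AR transformer learns
        # to predict the post-quantization speech representation.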
        with torch.no_grad():
            embeddings_bct = speech_embeddings.transpose(1, 2)
            codes_kbt = self.quantizer.encode(embeddings_bct)
            codes_bkt = codes_kbt.transpose(0, 1)
            decoder_input_emb = self.quantizer.decode(codes_bkt)
            target_representation = decoder_input_emb.transpose(1, 2)

        # Special learned embeddings delimiting the interleaved sequence.
        device = text_embeddings.device
        text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
        time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
        time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))

        batch_size = text_embeddings.shape[0]
        interleaved_sequences = []
        loss_masks = []
        bce_labels_batch = []
        bce_masks = []
        sequence_lengths = []
        all_z_tokens = []
        max_total_length = 0

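        # Build one interleaved sequence per batch item:
        #   [latent, t_0, s_0, <start>, z_0 ... z_{k0-1}, <end>, t_1, s_1, <start>, ..., <end>]
        # where t_i is the projected text embedding, s_i its cross-attended speech
        # feature, and the z tokens are the target speech frames assigned to text
        # position i by alignment_chunk_sizes. The reconstruction loss applies only
        # at z-token positions; the BCE end-token label is 1 at <end> markers and 0
        # at z-token positions.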
        for b in range(batch_size):
            # Every sequence starts with the text/speech latent marker.
            sequence_parts = [text_speech_latent_emb]
            loss_mask_parts = [False]
            bce_label_parts = [0]
            bce_mask_parts = [False]

            if text_attention_mask is not None:
                valid_text_len = text_attention_mask[b].sum().item()
            else:
                valid_text_len = text_embeddings.shape[1]

            # Running index into the target speech frames for this item.
            speech_position = 0

            for i in range(valid_text_len):
                # Text embedding t_i.
                t_i = text_embeddings[b, i:i+1]
                sequence_parts.append(t_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)
                bce_mask_parts.append(False)

                # Cross-attended speech feature s_i.
                s_i = cross_attention_outputs[b, i:i+1]
                sequence_parts.append(s_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)
                bce_mask_parts.append(False)

                # <speech-start> marker.
                sequence_parts.append(time_speech_start_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)
                bce_mask_parts.append(False)

                # Target z tokens assigned to this text position.
                chunk_size = alignment_chunk_sizes[b, i].item()
                if chunk_size > 0:
                    end_position = speech_position + chunk_size
                    end_position = min(end_position, target_representation.shape[1])
                    actual_chunk_size = end_position - speech_position

                    if actual_chunk_size > 0:
                        z_tokens = target_representation[b, speech_position:end_position]
                        sequence_parts.append(z_tokens)
                        loss_mask_parts.extend([True] * actual_chunk_size)
                        bce_label_parts.extend([0] * actual_chunk_size)
                        bce_mask_parts.extend([True] * actual_chunk_size)

                        all_z_tokens.append(z_tokens)

                    speech_position = end_position

                # <speech-end> marker: positive BCE label.
                sequence_parts.append(time_speech_end_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(1)
                bce_mask_parts.append(True)

            full_sequence = torch.cat(sequence_parts, dim=0)
            loss_mask = torch.tensor(loss_mask_parts, dtype=torch.bool, device=device)
            bce_labels = torch.tensor(bce_label_parts, dtype=torch.float, device=device)
            bce_mask = torch.tensor(bce_mask_parts, dtype=torch.bool, device=device)

            interleaved_sequences.append(full_sequence)
            loss_masks.append(loss_mask)
            bce_labels_batch.append(bce_labels)
            bce_masks.append(bce_mask)
            sequence_lengths.append(full_sequence.shape[0])
            max_total_length = max(max_total_length, full_sequence.shape[0])

        # Right-pad all sequences and masks to the longest sequence in the batch.
        padded_sequences = []
        padded_loss_masks = []
        padded_bce_labels = []
        padded_bce_masks = []

        for sequence, loss_mask, bce_labels, bce_mask in zip(interleaved_sequences, loss_masks, bce_labels_batch, bce_masks):
            current_length = sequence.shape[0]
            if current_length < max_total_length:
                padding = torch.zeros(max_total_length - current_length, 512, device=device, dtype=sequence.dtype)
                padded_sequence = torch.cat([sequence, padding], dim=0)

                mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_mask = torch.cat([loss_mask, mask_padding], dim=0)

                bce_label_padding = torch.zeros(max_total_length - current_length, dtype=torch.float, device=device)
                padded_bce_label = torch.cat([bce_labels, bce_label_padding], dim=0)

                bce_mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_bce_mask = torch.cat([bce_mask, bce_mask_padding], dim=0)
            else:
                padded_sequence = sequence
                padded_mask = loss_mask
                padded_bce_label = bce_labels
                padded_bce_mask = bce_mask

            padded_sequences.append(padded_sequence)
            padded_loss_masks.append(padded_mask)
            padded_bce_labels.append(padded_bce_label)
            padded_bce_masks.append(padded_bce_mask)

        # Stack into batch tensors of shape (B, max_total_length, ...).
        interleaved_batch = torch.stack(padded_sequences, dim=0)
        loss_mask_batch = torch.stack(padded_loss_masks, dim=0)
        bce_labels_batch_tensor = torch.stack(padded_bce_labels, dim=0)
        bce_mask_batch = torch.stack(padded_bce_masks, dim=0)

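        # Teacher-forced next-step prediction: the AR transformer reads positions
        # [0, T-2] and is trained to predict positions [1, T-1], so all masks and
        # labels are shifted by one to line up with the targets.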
        if max_total_length > 1:
            ar_input = interleaved_batch[:, :-1, :]
            ar_targets = interleaved_batch[:, 1:, :]
            ar_loss_mask = loss_mask_batch[:, 1:]
            ar_bce_labels = bce_labels_batch_tensor[:, 1:]
            ar_bce_mask = bce_mask_batch[:, 1:]

            # Key-padding mask over the (shifted) valid length of each sequence.
            ar_seq_len = ar_input.shape[1]
            ar_attention_mask = torch.zeros(batch_size, ar_seq_len, dtype=torch.bool, device=device)
            for b in range(batch_size):
                valid_len = min(ar_seq_len, sequence_lengths[b] - 1)
                if valid_len > 0:
                    ar_attention_mask[b, :valid_len] = True

            ar_outputs = self.ar_transformer(
                hidden_states=ar_input,
                attention_mask=ar_attention_mask,
            )
            ar_predictions = ar_outputs.last_hidden_state

            # End-token logits for every position.
            bce_logits = self.end_token_classifier(ar_predictions).squeeze(-1)

            # MSE reconstruction loss on z-token positions only.
            if ar_loss_mask.any():
                valid_predictions = ar_predictions[ar_loss_mask]
                valid_targets = ar_targets[ar_loss_mask]

                reconstruction_loss = nn.functional.mse_loss(
                    valid_predictions,
                    valid_targets,
                    reduction='mean'
                )
            else:
                reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)

            # BCE end-token loss on z-token and <end>-marker positions.
            if ar_bce_mask.any():
                valid_bce_logits = bce_logits[ar_bce_mask]
                valid_bce_labels = ar_bce_labels[ar_bce_mask]

                bce_end_token_loss = nn.functional.binary_cross_entropy_with_logits(
                    valid_bce_logits,
                    valid_bce_labels,
                    reduction='mean'
                )
            else:
                bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)

            # The BCE term only contributes once it exceeds bce_threshold.
            if self.bce_threshold > 0.0:
                clamped_bce_loss = torch.clamp(bce_end_token_loss - self.bce_threshold, min=0.0)
                total_loss = reconstruction_loss + self.alpha * clamped_bce_loss
            else:
                total_loss = reconstruction_loss + self.alpha * bce_end_token_loss
        else:
            # Degenerate batch with no targets: return zero losses that still carry gradients.
            reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
            bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
            total_loss = reconstruction_loss + bce_end_token_loss

        return {
            'loss': total_loss,
            'reconstruction_loss': reconstruction_loss,
            'bce_end_token_loss': bce_end_token_loss,
        }


__all__ = ["TextSyncMimi"]