Any-to-Any
Transformers
Safetensors
English
xoron
multimodal
Mixture of Experts
text-to-image
image editing
image to video
text-to-video
video editing
text-to-speech
speech-to-text
speech-to-speech
image-to-text
video-to-text
agentic
tool-use
flow-matching
3d-rope
titok
vidtok
dual-stream-attention
zero-shot-voice-cloning
bigvgan
snake-activation
multi-receptive-field-fusion
custom_code
| """ | |
| Xoron Model for HuggingFace Transformers - Self-Contained Implementation. | |
| AUTO-GENERATED FILE - Do not edit directly! | |
| This module provides a complete, self-contained HuggingFace-compatible model class | |
| for the Xoron multimodal model. All components are embedded directly in this file | |
| to enable loading via AutoModel with trust_remote_code=True WITHOUT requiring | |
| the full Xoron-Dev package to be installed. | |
| Usage: | |
| from transformers import AutoModel, AutoConfig | |
| config = AutoConfig.from_pretrained("your-repo/xoron-model", trust_remote_code=True) | |
| model = AutoModel.from_pretrained("your-repo/xoron-model", trust_remote_code=True) | |
| """ | |
import os
import math
import json
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Union, Tuple, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from safetensors.torch import save_file, load_file
except ImportError:
    save_file, load_file = None, None

from transformers import PreTrainedModel, LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast

try:
    from transformers.models.llama.modeling_llama import (
        LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaMLP,
        LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
    )
except ImportError:
    LlamaAttention = LlamaDecoderLayer = LlamaRMSNorm = LlamaMLP = None
    LlamaRotaryEmbedding = apply_rotary_pos_emb = repeat_kv = None

try:
    from .configuration_xoron import XoronConfig
except ImportError:
    from configuration_xoron import XoronConfig

logger = logging.getLogger(__name__)

# ==============================================================================
# MODELS.COMPONENTS.LORA
# ==============================================================================
class LoRALinear(nn.Module):
    """
    SOTA LoRA layer with multiple variants.

    Supports:
    - Standard LoRA
    - DoRA (Weight-Decomposed LoRA)
    - rsLoRA (rank-stabilized scaling)

    MEMORY OPTIMIZATION:
    - Does NOT clone base weights - shares them with the original module
    - Only LoRA params (A, B, magnitude) consume additional memory
    - Base weights are frozen and can be kept in lower precision
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        merge_weights: bool = False,
        use_dora: bool = False,
        use_rslora: bool = True,
        base_layer: nn.Linear = None,
    ):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.merge_weights = merge_weights
        self.merged = False
        self.use_dora = use_dora
        self.use_rslora = use_rslora
        self.in_features = in_features
        self.out_features = out_features
        if base_layer is not None:
            self.linear = base_layer
        else:
            self.linear = nn.Linear(in_features, out_features, bias=False)
        if r > 0:
            self.lora_A = nn.Parameter(torch.zeros(r, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
            if use_rslora:
                self.scaling = lora_alpha / math.sqrt(r)
            else:
                self.scaling = lora_alpha / r
            self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
            if use_dora:
                self.magnitude = nn.Parameter(torch.ones(out_features))
        # Base weights are always frozen; only the LoRA params train.
        self.linear.weight.requires_grad = False
        if hasattr(self.linear, 'bias') and self.linear.bias is not None:
            self.linear.bias.requires_grad = False
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.r > 0 and not self.merged:
            lora_out = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling
            if self.use_dora:
                # DoRA: decompose the adapted weight into direction and magnitude.
                weight = self.linear.weight + (self.lora_B @ self.lora_A) * self.scaling
                weight_norm = weight.norm(dim=1, keepdim=True)
                weight_normalized = weight / (weight_norm + 1e-6)
                # Pass the base bias through (it was silently dropped before).
                result = F.linear(x, weight_normalized * self.magnitude.unsqueeze(1), self.linear.bias)
            else:
                result = self.linear(x) + lora_out
        else:
            result = self.linear(x)
        return result
    def merge_lora_weights(self):
        """Merge LoRA weights into the main weights for inference."""
        if self.r > 0 and not self.merged:
            delta = (self.lora_B @ self.lora_A) * self.scaling
            if self.use_dora:
                weight = self.linear.weight + delta
                weight_norm = weight.norm(dim=1, keepdim=True)
                self.linear.weight.data = (weight / (weight_norm + 1e-6)) * self.magnitude.unsqueeze(1)
            else:
                self.linear.weight.data += delta
            self.merged = True

    def unmerge_lora_weights(self):
        """Unmerge LoRA weights for continued training.

        NOTE: exact only for standard LoRA/rsLoRA; a DoRA merge involves a
        non-linear renormalization that subtracting the delta does not invert.
        """
        if self.r > 0 and self.merged:
            self.linear.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
            self.merged = False
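
# Illustrative usage sketch, not invoked by the model code: shows how
# LoRALinear wraps an existing nn.Linear without cloning its weights. All
# shapes and hyperparameters below are arbitrary assumptions for demonstration.
def _example_lora_linear():
    base = nn.Linear(512, 512, bias=False)
    lora = LoRALinear(512, 512, r=8, lora_alpha=16, base_layer=base)
    x = torch.randn(2, 10, 512)
    y = lora(x)                      # base output + scaled low-rank update
    lora.merge_lora_weights()        # fold B @ A * scaling into the base weight
    y_merged = lora.linear(x)        # merged path gives (approximately) y
    lora.unmerge_lora_weights()      # restore the base weight for more training
    return y, y_merged
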
class LoRAConfig:
    """
    Configuration for SOTA LoRA adaptation.
    Supports multiple LoRA variants and configurations.
    """
    def __init__(
        self,
        r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        target_modules: Optional[List[str]] = None,
        enable_lora: bool = True,
        use_dora: bool = False,
        use_rslora: bool = True,
        lora_plus_lr_ratio: float = 16.0,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.target_modules = target_modules or [
            'q_proj', 'k_proj', 'v_proj', 'o_proj',
            'gate_proj', 'up_proj', 'down_proj',
        ]
        self.enable_lora = enable_lora
        self.use_dora = use_dora
        self.use_rslora = use_rslora
        self.lora_plus_lr_ratio = lora_plus_lr_ratio

def apply_lora_to_model(model: nn.Module, lora_config: LoRAConfig) -> nn.Module:
    """
    Apply LoRA to specified modules in a model.
    Returns the model with LoRA layers applied.

    MEMORY OPTIMIZATION:
    - Passes the original nn.Linear layer directly to LoRALinear
    - This SHARES weights instead of cloning them (saves ~50% memory for target modules)
    - Only LoRA parameters (A, B, magnitude) are newly allocated

    For a 16GB model with 30% of weights in target modules:
    - Old behavior: Clone ~5GB = 21GB total
    - New behavior: Share weights = 16GB + ~50MB LoRA params
    """
    if not lora_config.enable_lora:
        return model
    lora_layers_added = 0
    modules_to_replace = []
    total_base_params = 0
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        module_name = name.split('.')[-1]
        if module_name in lora_config.target_modules:
            modules_to_replace.append((name, module))
            total_base_params += module.weight.numel()
    for name, module in modules_to_replace:
        parts = name.split('.')
        attr_name = parts[-1]
        parent_name = '.'.join(parts[:-1])
        if parent_name:
            parent = model.get_submodule(parent_name)
        else:
            parent = model
        lora_layer = LoRALinear(
            in_features=module.in_features,
            out_features=module.out_features,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
            use_dora=lora_config.use_dora,
            use_rslora=lora_config.use_rslora,
            base_layer=module,
        )
        setattr(parent, attr_name, lora_layer)
        lora_layers_added += 1
    # Exact count of new LoRA parameters (A and B per replaced layer), rather
    # than extrapolating from the first module's dimensions.
    lora_params = sum(
        lora_config.r * (m.in_features + m.out_features) for _, m in modules_to_replace
    )
    base_mem_saved_mb = (total_base_params * 2) / (1024 * 1024)
    lora_mem_added_mb = (lora_params * 4) / (1024 * 1024)
    variant = "DoRA" if lora_config.use_dora else ("rsLoRA" if lora_config.use_rslora else "LoRA")
    print(f"✅ {variant} applied to {lora_layers_added} layers (r={lora_config.r}, alpha={lora_config.lora_alpha})")
    print(f"   💾 Memory optimization: {base_mem_saved_mb:.1f}MB base weights SHARED (not cloned)")
    print(f"   📊 New LoRA params: ~{lora_mem_added_mb:.1f}MB (trainable)")
    return model

def get_lora_parameters(model: nn.Module) -> List[nn.Parameter]:
    """
    Get only the LoRA parameters from a model.

    NOTE: This does NOT change requires_grad on any parameters!
    It simply returns the LoRA params (lora_A, lora_B, magnitude).

    Use this when you want to get LoRA params for separate optimizer groups
    or for LoRA-only training mode.
    """
    lora_params = []
    for name, param in model.named_parameters():
        if 'lora_A' in name or 'lora_B' in name or 'magnitude' in name:
            lora_params.append(param)
    return lora_params


def enable_lora_training(model: nn.Module) -> List[nn.Parameter]:
    """
    Enable training for LoRA parameters (ensure requires_grad=True).
    Returns the list of LoRA parameters.
    """
    lora_params = []
    for name, param in model.named_parameters():
        if 'lora_A' in name or 'lora_B' in name or 'magnitude' in name:
            param.requires_grad = True
            lora_params.append(param)
    return lora_params

def freeze_non_lora_params(model: nn.Module) -> int:
    """
    Freeze all non-LoRA parameters and clear their gradients.

    USE THIS ONLY FOR LORA-ONLY TRAINING MODE (train_lora_only=True).

    For normal training with parallel fine-tuning (LoRA + full weights on
    active components), use the model's freeze_components() method instead,
    which respects the training mode flags (--text, --video, --image, --voice).

    Returns:
        Number of frozen parameters
    """
    frozen_params = 0
    freed_memory = 0
    for name, param in model.named_parameters():
        is_lora = 'lora_A' in name or 'lora_B' in name or 'magnitude' in name
        if not is_lora:
            param.requires_grad = False
            frozen_params += param.numel()
            if param.grad is not None:
                freed_memory += param.grad.numel() * param.grad.element_size()
                param.grad = None
    print(f"   ❄️ Frozen {frozen_params:,} non-LoRA parameters")
    if freed_memory > 0:
        print(f"   🧹 Freed {freed_memory / (1024 ** 2):.1f}MB of gradient memory")
    return frozen_params

def get_lora_plus_param_groups(
    model: nn.Module,
    base_lr: float,
    lr_ratio: float = 16.0,
) -> List[Dict]:
    """
    Get parameter groups for LoRA+ training.

    LoRA+ uses different learning rates for the A and B matrices:
    - B matrix: base_lr * lr_ratio (learns faster)
    - A matrix: base_lr

    This improves convergence and final performance.
    """
    lora_a_params = []
    lora_b_params = []
    magnitude_params = []
    other_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'lora_A' in name:
            lora_a_params.append(param)
        elif 'lora_B' in name:
            lora_b_params.append(param)
        elif 'magnitude' in name:
            magnitude_params.append(param)
        else:
            other_params.append(param)
    param_groups = []
    if lora_a_params:
        param_groups.append({'params': lora_a_params, 'lr': base_lr, 'name': 'lora_A'})
    if lora_b_params:
        param_groups.append({'params': lora_b_params, 'lr': base_lr * lr_ratio, 'name': 'lora_B'})
    if magnitude_params:
        param_groups.append({'params': magnitude_params, 'lr': base_lr, 'name': 'magnitude'})
    if other_params:
        param_groups.append({'params': other_params, 'lr': base_lr, 'name': 'other'})
    return param_groups

def get_trainable_parameters(model: nn.Module, train_lora_only: bool = False) -> List[nn.Parameter]:
    """Get trainable parameters, optionally only LoRA params."""
    if train_lora_only:
        return get_lora_parameters(model)
    else:
        return [p for p in model.parameters() if p.requires_grad]


def count_lora_parameters(model: nn.Module) -> Tuple[int, int, float]:
    """
    Count LoRA parameters vs total parameters.

    Returns:
        (lora_params, total_params, percentage)
    """
    lora_params = 0
    total_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        if 'lora_A' in name or 'lora_B' in name or 'magnitude' in name:
            lora_params += param.numel()
    percentage = 100.0 * lora_params / total_params if total_params > 0 else 0.0
    return lora_params, total_params, percentage
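
# Illustrative usage sketch, not invoked by the model code: an end-to-end
# LoRA-only training setup on a toy two-layer module. The toy module, its
# dimensions, and the learning rates are assumptions for demonstration only.
def _example_lora_workflow():
    toy = nn.Sequential()
    toy.add_module('q_proj', nn.Linear(64, 64, bias=False))
    toy.add_module('o_proj', nn.Linear(64, 64, bias=False))
    cfg = LoRAConfig(r=4, lora_alpha=8, use_rslora=True)
    toy = apply_lora_to_model(toy, cfg)      # q_proj/o_proj become LoRALinear
    freeze_non_lora_params(toy)              # LoRA-only training mode
    enable_lora_training(toy)
    groups = get_lora_plus_param_groups(toy, base_lr=1e-4, lr_ratio=cfg.lora_plus_lr_ratio)
    optimizer = torch.optim.AdamW(groups)    # B matrices get the 16x learning rate
    lora_n, total_n, pct = count_lora_parameters(toy)
    return optimizer, (lora_n, total_n, pct)
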
# ==============================================================================
# MODELS.COMPONENTS.ATTENTION
# ==============================================================================
logger = logging.getLogger(__name__)

def flash_attention_available() -> bool:
    """Check if Flash Attention (via SDPA) is available."""
    try:
        from torch.nn.functional import scaled_dot_product_attention
        return True
    except ImportError:
        return False


def compute_qk_scale(head_dim: int) -> float:
    """Compute the Q/K pre-scaling factor for FP16 stability.

    By scaling both Q and K by head_dim^-0.25, the product Q@K^T
    is effectively scaled by head_dim^-0.5 (the standard attention scaling).
    This prevents overflow in FP16 when Q and K have large values.
    """
    return head_dim ** -0.25
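
# Illustrative sketch, not invoked by the model code: verifies the pre-scaling
# identity (q*s) @ (k*s)^T == (q @ k^T) / sqrt(d) for s = d**-0.25, on random
# tensors with an assumed head_dim of 64.
def _example_qk_prescaling_identity():
    d = 64
    q, k = torch.randn(2, 8, 16, d), torch.randn(2, 8, 16, d)
    s = compute_qk_scale(d)
    prescaled = (q * s) @ (k * s).transpose(-2, -1)
    standard = (q @ k.transpose(-2, -1)) / math.sqrt(d)
    assert torch.allclose(prescaled, standard, atol=1e-5)
    return prescaled
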
class AttentionKVCache:
    """Pre-allocated KV Cache - static buffer with index-based filling.

    Eliminates VRAM fragmentation from torch.cat during autoregressive generation.
    The buffer is allocated once at first use and reused via slice assignment.
    """
    __slots__ = ('key_cache', 'value_cache', 'seen_tokens', '_max_len')

    def __init__(self, max_seq_len: int = 131072):
        self.key_cache: Optional[torch.Tensor] = None
        self.value_cache: Optional[torch.Tensor] = None
        self.seen_tokens: int = 0
        self._max_len = max_seq_len

    def _allocate(self, batch: int, heads: int, head_dim: int, device: torch.device, dtype: torch.dtype):
        """Allocate the static buffer on first use."""
        self.key_cache = torch.zeros(batch, heads, self._max_len, head_dim, device=device, dtype=dtype)
        self.value_cache = torch.zeros(batch, heads, self._max_len, head_dim, device=device, dtype=dtype)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update the cache with new key/value states using index-based filling.

        Args:
            key_states: New key states [batch, num_heads, seq_len, head_dim]
            value_states: New value states [batch, num_heads, seq_len, head_dim]

        Returns:
            Updated key and value states including cache (views, no copy)
        """
        batch, heads, new_len, head_dim = key_states.shape
        if self.key_cache is None:
            self._allocate(batch, heads, head_dim, key_states.device, key_states.dtype)
            self.seen_tokens = 0
        if self.seen_tokens + new_len > self.key_cache.shape[2]:
            # Grow the buffer geometrically if the pre-allocated length is exceeded.
            new_max = max(self.key_cache.shape[2] * 2, self.seen_tokens + new_len)
            new_key = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
            new_val = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
            new_key[:, :, :self.seen_tokens] = self.key_cache[:, :, :self.seen_tokens]
            new_val[:, :, :self.seen_tokens] = self.value_cache[:, :, :self.seen_tokens]
            self.key_cache = new_key
            self.value_cache = new_val
        self.key_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = key_states
        self.value_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = value_states
        self.seen_tokens += new_len
        return self.key_cache[:, :, :self.seen_tokens], self.value_cache[:, :, :self.seen_tokens]

    def get_seq_length(self) -> int:
        """Get the current sequence length in the cache."""
        return self.seen_tokens

    def reset(self):
        """Reset the cache position without deallocating the buffer."""
        self.seen_tokens = 0
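
# Illustrative sketch, not invoked by the model code: simulates autoregressive
# decoding with the pre-allocated cache. A deliberately small max_seq_len is
# assumed so the geometric growth path is also exercised.
def _example_kv_cache_decode():
    cache = AttentionKVCache(max_seq_len=8)
    # Prefill with a 6-token prompt, then decode 4 single tokens.
    k = v = torch.randn(1, 4, 6, 32)
    keys, values = cache.update(k, v)
    for _ in range(4):
        step_k = step_v = torch.randn(1, 4, 1, 32)
        keys, values = cache.update(step_k, step_v)   # grows past 8 once
    assert cache.get_seq_length() == 10 and keys.shape[2] == 10
    cache.reset()                                     # reuse the buffer for the next prompt
    return keys, values
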
class FlashAttention(nn.Module):
    """
    SOTA Flash Attention with KV cache support and FP16-safe Q/K pre-scaling.

    Uses PyTorch's scaled_dot_product_attention when available,
    with a fallback to standard attention. Supports:
    - KV caching for efficient generation
    - Causal masking
    - Attention dropout
    - Pre-scaled Q/K for FP16 stability
    """
    def __init__(
        self,
        dropout: float = 0.0,
        causal: bool = False,
        head_dim: int = None,
    ):
        super().__init__()
        self.dropout = dropout
        self.causal = causal
        self._flash_available = flash_attention_available()
        self._head_dim = head_dim
        self._qk_scale = compute_qk_scale(head_dim) if head_dim else None

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor = None,
        is_causal: bool = None,
        past_key_value: Tuple[torch.Tensor, torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Forward pass with KV cache support.

        Args:
            query: Query tensor [batch, num_heads, seq_len, head_dim]
            key: Key tensor [batch, num_heads, seq_len, head_dim]
            value: Value tensor [batch, num_heads, seq_len, head_dim]
            attn_mask: Optional attention mask
            is_causal: Override the causal setting
            past_key_value: Optional tuple of (past_key, past_value) for KV cache
            use_cache: Whether to return the updated KV cache
            output_attentions: Whether to return attention weights

        Returns:
            Tuple of (output, present_key_value, attention_weights)
        """
        causal = is_causal if is_causal is not None else self.causal
        batch_size, num_heads, seq_len, head_dim = query.shape
        qk_scale = self._qk_scale if self._qk_scale else compute_qk_scale(head_dim)
        if past_key_value is not None:
            past_key, past_value = past_key_value
            key = torch.cat([past_key, key], dim=2)
            value = torch.cat([past_value, value], dim=2)
        present_key_value = (key, value) if use_cache else None
        kv_seq_len = key.shape[2]
        attn_weights = None
        if self._flash_available and not output_attentions:
            query_scaled = query * qk_scale
            key_scaled = key * qk_scale
            dropout_p = self.dropout if self.training else 0.0
            # SDPA's is_causal flag only matches the square no-cache case, so
            # restrict it to seq_len == kv_seq_len.
            use_causal = causal and attn_mask is None and seq_len > 1 and seq_len == kv_seq_len
            output = F.scaled_dot_product_attention(
                query_scaled, key_scaled, value,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=use_causal,
                scale=1.0,
            )
        else:
            scale = 1.0 / math.sqrt(head_dim)
            attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale
            if causal and attn_mask is None and seq_len > 1:
                # Offset the diagonal so cached prefix tokens stay fully visible.
                causal_mask = torch.triu(
                    torch.full((seq_len, kv_seq_len), float('-inf'), device=query.device, dtype=query.dtype),
                    diagonal=kv_seq_len - seq_len + 1,
                )
                attn_weights = attn_weights + causal_mask.unsqueeze(0).unsqueeze(0)
            if attn_mask is not None:
                attn_weights = attn_weights + attn_mask
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=query.dtype)
            if self.training and self.dropout > 0:
                attn_weights = F.dropout(attn_weights, p=self.dropout)
            output = torch.matmul(attn_weights, value)
        return output, present_key_value, attn_weights
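
# Illustrative sketch, not invoked by the model code: one prefill step followed
# by one cached decode step through FlashAttention. Tensor shapes follow the
# documented [batch, heads, seq, head_dim] layout and are assumptions.
def _example_flash_attention_decode():
    attn = FlashAttention(causal=True, head_dim=64)
    q = k = v = torch.randn(1, 8, 16, 64)
    out, past_kv, _ = attn(q, k, v, use_cache=True)            # prefill
    q1 = k1 = v1 = torch.randn(1, 8, 1, 64)
    out1, past_kv, _ = attn(q1, k1, v1, past_key_value=past_kv, use_cache=True)
    assert past_kv[0].shape[2] == 17                           # 16 prompt + 1 new
    return out1
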
class MultimodalCrossAttention(nn.Module):
    """
    SOTA Cross-attention layer for multimodal fusion with KV cache support.

    Allows text to attend to image/video/audio features with:
    - KV caching for efficient generation
    - Gated residual connection for stable training
    - Flash Attention support with pre-scaled Q/K for FP16 stability
    - Optional attention weight output
    """
    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_flash_attention: bool = True,
        gate_init: float = 0.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.use_flash_attention = use_flash_attention and flash_attention_available()
        self.dropout_p = dropout
        self.qk_scale = compute_qk_scale(self.head_dim)
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.gate = nn.Parameter(torch.tensor(gate_init))

    def forward(
        self,
        text_hidden: torch.Tensor,
        modality_hidden: torch.Tensor,
        modality_mask: torch.Tensor = None,
        past_key_value: Tuple[torch.Tensor, torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Cross-attention: text attends to modality features, with KV cache support.

        Args:
            text_hidden: Text hidden states [batch, text_len, hidden_size]
            modality_hidden: Modality features [batch, modality_len, hidden_size]
            modality_mask: Optional attention mask for the modality
            past_key_value: Optional cached (key, value) for this layer
            use_cache: Whether to return the updated KV cache
            output_attentions: Whether to return attention weights

        Returns:
            Tuple of (output, present_key_value, attention_weights)
        """
        batch_size, text_len, _ = text_hidden.shape
        query = self.q_proj(text_hidden)
        query = query.view(batch_size, text_len, self.num_heads, self.head_dim).transpose(1, 2)
        if past_key_value is not None:
            # Modality features are static across decode steps, so the cached
            # keys/values can be reused verbatim.
            key, value = past_key_value
        else:
            modality_len = modality_hidden.shape[1]
            key = self.k_proj(modality_hidden)
            value = self.v_proj(modality_hidden)
            key = key.view(batch_size, modality_len, self.num_heads, self.head_dim).transpose(1, 2)
            value = value.view(batch_size, modality_len, self.num_heads, self.head_dim).transpose(1, 2)
        present_key_value = (key, value) if use_cache else None
        attn_weights = None
        if self.use_flash_attention and not output_attentions:
            query_scaled = query * self.qk_scale
            key_scaled = key * self.qk_scale
            dropout_p = self.dropout_p if self.training else 0.0
            attn_output = F.scaled_dot_product_attention(
                query_scaled, key_scaled, value,
                attn_mask=modality_mask,
                dropout_p=dropout_p,
                is_causal=False,
                scale=1.0,
            )
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale
            if modality_mask is not None:
                attn_weights = attn_weights + modality_mask
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=text_hidden.dtype)
            if self.training and self.dropout_p > 0:
                attn_weights = F.dropout(attn_weights, p=self.dropout_p)
            attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, text_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        # Gated residual: sigmoid(gate) controls how much cross-modal signal
        # is injected into the text stream.
        gate = torch.sigmoid(self.gate)
        output = text_hidden + gate * self.dropout(attn_output)
        output = self.layer_norm(output)
        return output, present_key_value, attn_weights
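
# Illustrative sketch, not invoked by the model code: text tokens attend to a
# bank of image features; the image K/V are computed once, cached, and reused
# for a later single-token step. All sizes are assumptions for demonstration.
def _example_cross_attention_cache():
    xattn = MultimodalCrossAttention(hidden_size=256, num_heads=8)
    text = torch.randn(2, 12, 256)
    image = torch.randn(2, 64, 256)
    fused, image_kv, _ = xattn(text, image, use_cache=True)
    next_token = torch.randn(2, 1, 256)
    fused_step, _, _ = xattn(next_token, image, past_key_value=image_kv)
    return fused, fused_step
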
@dataclass
class MultimodalFusionCache:
    """Cache for multimodal fusion layer KV states."""
    image_kv: Tuple[torch.Tensor, torch.Tensor] = None
    video_kv: Tuple[torch.Tensor, torch.Tensor] = None
    audio_kv: Tuple[torch.Tensor, torch.Tensor] = None

class MultimodalFusionLayer(nn.Module):
    """
    SOTA Multimodal fusion layer with cross-attention for all modalities and KV cache support.

    Features:
    - Separate cross-attention for each modality (image, video, audio)
    - KV caching for efficient generation
    - Gated fusion MLP
    - Flash Attention support
    """
    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_flash_attention: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.image_cross_attn = MultimodalCrossAttention(
            hidden_size, num_heads, dropout, use_flash_attention
        )
        self.video_cross_attn = MultimodalCrossAttention(
            hidden_size, num_heads, dropout, use_flash_attention
        )
        self.audio_cross_attn = MultimodalCrossAttention(
            hidden_size, num_heads, dropout, use_flash_attention
        )
        self.fusion_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Dropout(dropout),
        )
        self.fusion_norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        text_hidden: torch.Tensor,
        image_hidden: torch.Tensor = None,
        video_hidden: torch.Tensor = None,
        audio_hidden: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        video_mask: torch.Tensor = None,
        audio_mask: torch.Tensor = None,
        past_key_values: MultimodalFusionCache = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, MultimodalFusionCache]:
        """
        Fuse text with the available modalities via cross-attention, with KV cache support.

        Args:
            text_hidden: Text hidden states [batch, text_len, hidden_size]
            image_hidden: Image features [batch, image_len, hidden_size]
            video_hidden: Video features [batch, video_len, hidden_size]
            audio_hidden: Audio features [batch, audio_len, hidden_size]
            image_mask: Attention mask for image
            video_mask: Attention mask for video
            audio_mask: Attention mask for audio
            past_key_values: Cached KV states from a previous forward pass
            use_cache: Whether to return the updated KV cache

        Returns:
            Tuple of (output, present_key_values)
        """
        present_key_values = MultimodalFusionCache() if use_cache else None
        past_image_kv = past_key_values.image_kv if past_key_values else None
        past_video_kv = past_key_values.video_kv if past_key_values else None
        past_audio_kv = past_key_values.audio_kv if past_key_values else None
        if self._has_content(image_hidden) or past_image_kv is not None:
            try:
                text_hidden, image_kv, _ = self.image_cross_attn(
                    text_hidden,
                    image_hidden if image_hidden is not None else torch.zeros(text_hidden.shape[0], 1, self.hidden_size, device=text_hidden.device),
                    image_mask,
                    past_key_value=past_image_kv,
                    use_cache=use_cache,
                )
                if use_cache:
                    present_key_values.image_kv = image_kv
            except Exception as e:
                logger.debug(f"Image cross-attention skipped: {e}")
        if self._has_content(video_hidden) or past_video_kv is not None:
            try:
                text_hidden, video_kv, _ = self.video_cross_attn(
                    text_hidden,
                    video_hidden if video_hidden is not None else torch.zeros(text_hidden.shape[0], 1, self.hidden_size, device=text_hidden.device),
                    video_mask,
                    past_key_value=past_video_kv,
                    use_cache=use_cache,
                )
                if use_cache:
                    present_key_values.video_kv = video_kv
            except Exception as e:
                logger.debug(f"Video cross-attention skipped: {e}")
        if self._has_content(audio_hidden) or past_audio_kv is not None:
            try:
                text_hidden, audio_kv, _ = self.audio_cross_attn(
                    text_hidden,
                    audio_hidden if audio_hidden is not None else torch.zeros(text_hidden.shape[0], 1, self.hidden_size, device=text_hidden.device),
                    audio_mask,
                    past_key_value=past_audio_kv,
                    use_cache=use_cache,
                )
                if use_cache:
                    present_key_values.audio_kv = audio_kv
            except Exception as e:
                logger.debug(f"Audio cross-attention skipped: {e}")
        residual = text_hidden
        text_hidden = self.fusion_mlp(text_hidden)
        text_hidden = self.fusion_norm(residual + text_hidden)
        return text_hidden, present_key_values
    @staticmethod
    def _has_content(tensor: torch.Tensor) -> bool:
        """Check if a tensor has meaningful (non-empty, non-zero) content.

        Declared as a staticmethod since the forward pass above calls it
        via self._has_content(...).
        """
        if tensor is None:
            return False
        if not isinstance(tensor, torch.Tensor):
            return False
        try:
            if tensor.numel() == 0:
                return False
            return bool(tensor.any())
        except Exception:
            return False
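
# Illustrative sketch, not invoked by the model code: fusing text with image
# and audio features in one pass, then reusing the per-modality KV cache for a
# single-token decode step. All sizes are assumptions for demonstration.
def _example_fusion_layer():
    fusion = MultimodalFusionLayer(hidden_size=256, num_heads=8)
    text = torch.randn(1, 20, 256)
    image = torch.randn(1, 64, 256)
    audio = torch.randn(1, 32, 256)
    fused, cache = fusion(text, image_hidden=image, audio_hidden=audio, use_cache=True)
    step = torch.randn(1, 1, 256)
    fused_step, _ = fusion(step, past_key_values=cache, use_cache=True)
    return fused, fused_step
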
# ==============================================================================
# MODELS.COMPONENTS.PROJECTORS
# ==============================================================================
def compute_2d_rope(height: int, width: int, dim: int, device: torch.device, dtype: torch.dtype, base: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute 2D Rotary Position Embeddings for spatial awareness.

    Args:
        height: Image height in patches
        width: Image width in patches
        dim: Embedding dimension (must be divisible by 4)
        device: Target device
        dtype: Target dtype
        base: RoPE base frequency

    Returns:
        cos, sin: [height*width, dim] position embeddings
    """
    assert dim % 4 == 0, "dim must be divisible by 4 for 2D RoPE"
    quarter_dim = dim // 4
    inv_freq = 1.0 / (base ** (torch.arange(0, quarter_dim, device=device, dtype=torch.float32) / quarter_dim))
    y_pos = torch.arange(height, device=device, dtype=torch.float32)
    x_pos = torch.arange(width, device=device, dtype=torch.float32)
    y_emb = torch.outer(y_pos, inv_freq)
    x_emb = torch.outer(x_pos, inv_freq)
    y_emb = y_emb.unsqueeze(1).expand(-1, width, -1)
    x_emb = x_emb.unsqueeze(0).expand(height, -1, -1)
    # Feature layout along the last dim: [y, y, x, x].
    emb = torch.cat([y_emb, y_emb, x_emb, x_emb], dim=-1)
    emb = emb.reshape(height * width, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)
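
# Illustrative sketch, not invoked by the model code: shape check for 2D RoPE
# on an assumed 8x8 patch grid with dim=64. The returned tables are indexed by
# flattened patch position.
def _example_2d_rope_shapes():
    cos, sin = compute_2d_rope(8, 8, 64, torch.device('cpu'), torch.float32)
    assert cos.shape == (64, 64) and sin.shape == (64, 64)   # [H*W, dim]
    return cos, sin
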
def compute_3d_rope(
    depth: int, height: int, width: int, dim: int,
    device: torch.device, dtype: torch.dtype, base: float = 10000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute 3D Rotary Position Embeddings for video/temporal awareness.

    Args:
        depth: Temporal depth (number of frames)
        height: Image height in patches
        width: Image width in patches
        dim: Embedding dimension (must be divisible by 6)
        device: Target device
        dtype: Target dtype
        base: RoPE base frequency

    Returns:
        cos, sin: [depth*height*width, dim] position embeddings
    """
    assert dim % 6 == 0, "dim must be divisible by 6 for 3D RoPE"
    sixth_dim = dim // 6
    inv_freq = 1.0 / (base ** (torch.arange(0, sixth_dim, device=device, dtype=torch.float32) / sixth_dim))
    t_pos = torch.arange(depth, device=device, dtype=torch.float32)
    y_pos = torch.arange(height, device=device, dtype=torch.float32)
    x_pos = torch.arange(width, device=device, dtype=torch.float32)
    t_emb = torch.outer(t_pos, inv_freq)
    y_emb = torch.outer(y_pos, inv_freq)
    x_emb = torch.outer(x_pos, inv_freq)
    t_emb = t_emb.unsqueeze(1).unsqueeze(2).expand(-1, height, width, -1)
    y_emb = y_emb.unsqueeze(0).unsqueeze(2).expand(depth, -1, width, -1)
    x_emb = x_emb.unsqueeze(0).unsqueeze(1).expand(depth, height, -1, -1)
    emb = torch.cat([t_emb, t_emb, y_emb, y_emb, x_emb, x_emb], dim=-1)
    emb = emb.reshape(depth * height * width, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings via the rotate-half convention."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin
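
# Illustrative sketch, not invoked by the model code: applying the 2D RoPE
# tables to a batch of patch features. The unsqueeze(0) broadcasts the
# per-position tables over the batch dimension; sizes are assumptions.
def _example_apply_rope():
    cos, sin = compute_2d_rope(4, 4, 32, torch.device('cpu'), torch.float32)
    x = torch.randn(2, 16, 32)                        # [B, H*W, dim]
    x_rot = apply_rope(x, cos.unsqueeze(0), sin.unsqueeze(0))
    assert x_rot.shape == x.shape
    return x_rot
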
class ResidualBottleneckBlock(nn.Module):
    """
    Residual Bottleneck Block for locality-enhanced feature extraction.
    Preserves small-scale features (OCR, fine audio events) during compression.
    """
    def __init__(self, in_channels: int, out_channels: int, bottleneck_ratio: float = 0.25):
        super().__init__()
        bottleneck_channels = int(out_channels * bottleneck_ratio)
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Identity() if in_channels == out_channels else nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = out + identity
        out = self.relu(out)
        return out

class LocalityEnhancedResNetAbstractor(nn.Module):
    """
    Locality-Enhanced ResNet Abstractor.

    Upgrades the C-Abstractor with residual bottleneck blocks to preserve
    small-scale features (OCR / fine audio events) during compression.
    """
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_tokens: int = 64,
        num_blocks: int = 3,
        use_2d_rope: bool = True,
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.use_2d_rope = use_2d_rope
        self.input_proj = nn.Linear(input_dim, output_dim)
        self.blocks = nn.ModuleList([
            ResidualBottleneckBlock(output_dim, output_dim)
            for _ in range(num_blocks)
        ])
        self.queries = nn.Parameter(torch.randn(1, num_tokens, output_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=output_dim,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )
        self.ff = nn.Sequential(
            nn.LayerNorm(output_dim),
            nn.Linear(output_dim, output_dim * 4),
            nn.GELU(),
            nn.Linear(output_dim * 4, output_dim),
        )
        self.norm = nn.LayerNorm(output_dim)
        print(f"   🏗️ LocalityEnhancedResNetAbstractor: {input_dim} -> {output_dim}, {num_tokens} tokens")

    def forward(self, features: torch.Tensor, spatial_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """
        Args:
            features: [B, seq_len, input_dim] or [B, H, W, input_dim]
            spatial_size: (H, W) if features are flattened

        Returns:
            abstracted: [B, num_tokens, output_dim]
        """
        batch_size = features.shape[0]
        x = self.input_proj(features)
        if features.dim() == 3:
            seq_len = features.shape[1]
            if spatial_size is None:
                h = w = int(math.sqrt(seq_len))
            else:
                h, w = spatial_size
            x = x.view(batch_size, h, w, -1)
        else:
            h, w = features.shape[1], features.shape[2]
        x = x.permute(0, 3, 1, 2)
        for block in self.blocks:
            x = block(x)
        x = x.permute(0, 2, 3, 1)
        x = x.reshape(batch_size, h * w, -1)
        if self.use_2d_rope:
            cos, sin = compute_2d_rope(h, w, x.shape[-1], x.device, x.dtype)
            x = apply_rope(x, cos.unsqueeze(0), sin.unsqueeze(0))
        queries = self.queries.expand(batch_size, -1, -1)
        abstracted, _ = self.cross_attn(queries, x, x)
        abstracted = abstracted + self.ff(abstracted)
        return self.norm(abstracted)

class MultiScaleFeatureFusion(nn.Module):
    """
    Multi-Scale Feature Fusion (MSFF).

    Extracts and weights features from multiple encoder depths (early, mid, late)
    to capture both low-level textures and high-level semantics.
    """
    def __init__(
        self,
        feature_dims: List[int],
        output_dim: int,
        num_scales: int = 3,
    ):
        super().__init__()
        self.num_scales = num_scales
        self.scale_projs = nn.ModuleList([
            nn.Linear(dim, output_dim) for dim in feature_dims
        ])
        self.scale_weights = nn.Parameter(torch.ones(num_scales) / num_scales)
        self.fusion = nn.Sequential(
            nn.Linear(output_dim, output_dim * 2),
            nn.GELU(),
            nn.Linear(output_dim * 2, output_dim),
        )
        self.norm = nn.LayerNorm(output_dim)
        print(f"   🔀 MultiScaleFeatureFusion: {feature_dims} -> {output_dim}")

    def forward(self, multi_scale_features: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            multi_scale_features: List of [B, seq_len, dim] features from different depths

        Returns:
            fused: [B, seq_len, output_dim]
        """
        assert len(multi_scale_features) == self.num_scales
        projected = [proj(features) for features, proj in zip(multi_scale_features, self.scale_projs)]
        weights = F.softmax(self.scale_weights, dim=0)
        fused = sum(w * p for w, p in zip(weights, projected))
        fused = fused + self.fusion(fused)
        return self.norm(fused)

class MultiScaleDeformableAttention(nn.Module):
    """
    Multi-Scale Deformable Attention.

    Replaces fixed-grid cross-attention in Perceiver Resamplers,
    allowing the projector to "look" at non-uniform regions of interest.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        num_levels: int = 4,
        num_points: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.head_dim = dim // num_heads
        self.sampling_offsets = nn.Linear(dim, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(dim, num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(dim, dim)
        self.output_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        self._reset_parameters()
        print(f"   🎯 MultiScaleDeformableAttention: {dim}d, {num_heads}H, {num_levels}L, {num_points}P")

    def _reset_parameters(self):
        nn.init.constant_(self.sampling_offsets.weight, 0.0)
        nn.init.constant_(self.sampling_offsets.bias, 0.0)
        nn.init.xavier_uniform_(self.attention_weights.weight)
        nn.init.constant_(self.attention_weights.bias, 0.0)
        nn.init.xavier_uniform_(self.value_proj.weight)
        nn.init.xavier_uniform_(self.output_proj.weight)

    def forward(
        self,
        query: torch.Tensor,
        reference_points: torch.Tensor,
        input_flatten: torch.Tensor,
        input_spatial_shapes: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            query: [B, num_queries, dim]
            reference_points: [B, num_queries, num_levels, 2] normalized reference points
            input_flatten: [B, sum(H*W), dim] flattened multi-scale features
            input_spatial_shapes: [num_levels, 2] spatial shapes of each level

        Returns:
            output: [B, num_queries, dim]
        """
        batch_size, num_queries, _ = query.shape
        offsets = self.sampling_offsets(query)
        offsets = offsets.view(batch_size, num_queries, self.num_heads, self.num_levels, self.num_points, 2)
        attn_weights = self.attention_weights(query)
        attn_weights = attn_weights.view(batch_size, num_queries, self.num_heads, self.num_levels * self.num_points)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.view(batch_size, num_queries, self.num_heads, self.num_levels, self.num_points)
        sampling_locations = reference_points.unsqueeze(2).unsqueeze(4) + offsets * 0.1
        sampling_locations = sampling_locations.clamp(0, 1)
        value = self.value_proj(input_flatten)
        value = value.view(batch_size, -1, self.num_heads, self.head_dim)
        output = torch.zeros(batch_size, num_queries, self.num_heads, self.head_dim, device=query.device, dtype=query.dtype)
        start_idx = 0
        for level_idx in range(self.num_levels):
            h, w = input_spatial_shapes[level_idx]
            h, w = int(h), int(w)
            end_idx = start_idx + h * w
            level_value = value[:, start_idx:end_idx]            # [B, h*w, H, D]
            level_locs = sampling_locations[:, :, :, level_idx]
            level_weights = attn_weights[:, :, :, level_idx]
            for point_idx in range(self.num_points):
                loc = level_locs[:, :, :, point_idx]                          # [B, Q, H, 2]
                weight = level_weights[:, :, :, point_idx:point_idx + 1]      # [B, Q, H, 1]
                y_idx = (loc[..., 0] * (h - 1)).long().clamp(0, h - 1)        # [B, Q, H]
                x_idx = (loc[..., 1] * (w - 1)).long().clamp(0, w - 1)
                # Vectorized nearest-neighbor gather; equivalent to looping
                # over batch/query/head, but orders of magnitude faster.
                flat_idx = y_idx * w + x_idx                                  # [B, Q, H]
                gather_idx = flat_idx.permute(0, 2, 1).unsqueeze(-1).expand(-1, -1, -1, self.head_dim)
                sampled = torch.gather(
                    level_value.permute(0, 2, 1, 3),                          # [B, H, h*w, D]
                    2,
                    gather_idx,                                               # [B, H, Q, D]
                ).permute(0, 2, 1, 3)                                         # [B, Q, H, D]
                output = output + weight * sampled
            start_idx = end_idx
        output = output.reshape(batch_size, num_queries, self.dim)
        output = self.output_proj(output)
        output = self.dropout(output)
        return output

class DynamicTokenRouter(nn.Module):
    """
    Dynamic Token Router.

    Implements a sparse gating mechanism to drop redundant "background" tokens,
    drastically reducing KV-cache pressure for Ring Attention.
    """
    def __init__(
        self,
        dim: int,
        num_tokens: int,
        keep_ratio: float = 0.5,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.dim = dim
        self.num_tokens = num_tokens
        self.keep_ratio = keep_ratio
        self.temperature = temperature
        self.scorer = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.GELU(),
            nn.Linear(dim // 2, 1),
        )
        self.threshold = nn.Parameter(torch.tensor(0.0))
        print(f"   🚦 DynamicTokenRouter: keep_ratio={keep_ratio}")

    def forward(self, tokens: torch.Tensor, return_mask: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            tokens: [B, num_tokens, dim]
            return_mask: Whether to return the selection mask

        Returns:
            selected_tokens: [B, num_kept, dim]
            mask: [B, num_tokens] selection mask (if return_mask=True)
        """
        batch_size, num_tokens, _ = tokens.shape
        num_keep = max(1, int(num_tokens * self.keep_ratio))
        scores = self.scorer(tokens).squeeze(-1)
        scores = scores / self.temperature
        _, indices = torch.topk(scores, num_keep, dim=-1)
        indices = indices.sort(dim=-1).values  # preserve the original token order
        indices_expanded = indices.unsqueeze(-1).expand(-1, -1, self.dim)
        selected_tokens = torch.gather(tokens, 1, indices_expanded)
        if return_mask:
            mask = torch.zeros(batch_size, num_tokens, device=tokens.device, dtype=torch.bool)
            mask.scatter_(1, indices, True)
            return selected_tokens, mask
        return selected_tokens, None
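
# Illustrative sketch, not invoked by the model code: routing 64 tokens down
# to the top-scoring 32 (keep_ratio=0.5). The mask marks which original
# positions survived. All sizes are assumptions for demonstration.
def _example_token_router():
    router = DynamicTokenRouter(dim=256, num_tokens=64, keep_ratio=0.5)
    tokens = torch.randn(2, 64, 256)
    kept, mask = router(tokens, return_mask=True)
    assert kept.shape == (2, 32, 256) and mask.sum(dim=1).eq(32).all()
    return kept, mask
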
class PerceiverAttention(nn.Module):
    """
    Perceiver-style cross-attention for resampling, with 2D/3D RoPE support.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        use_rope: bool = True,
    ):
        super().__init__()
        inner_dim = dim_head * num_heads
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.inner_dim = inner_dim
        self.scale = dim_head ** -0.5
        self.use_rope = use_rope
        self.norm_latents = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        context_rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        latents: [B, num_latents, dim] - learnable queries
        context: [B, seq_len, dim] - input features to attend to
        context_rope: Optional (cos, sin) for context positions
        """
        latents = self.norm_latents(latents)
        context = self.norm_context(context)
        b, n, _ = latents.shape
        ctx_len = context.shape[1]
        h = self.num_heads
        d = self.dim_head
        q = self.to_q(latents)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q = q.reshape(b, n, h, d).transpose(1, 2)
        k = k.reshape(b, ctx_len, h, d).transpose(1, 2)
        v = v.reshape(b, ctx_len, h, d).transpose(1, 2)
        if self.use_rope and context_rope is not None:
            cos, sin = context_rope
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)
            k = apply_rope(k, cos, sin)
        qk_scale = d ** -0.25
        out = F.scaled_dot_product_attention(
            q * qk_scale, k * qk_scale, v,
            is_causal=False, scale=1.0,
        )
        out = out.transpose(1, 2).reshape(b, n, self.inner_dim)
        return self.to_out(out)

class PerceiverResampler(nn.Module):
    """
    Perceiver Resampler with 2D/3D RoPE and Dynamic Token Routing.
    """
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_latents: int = 64,
        num_heads: int = 8,
        num_layers: int = 2,
        dropout: float = 0.0,
        use_rope: bool = True,
        use_dynamic_routing: bool = False,
        routing_keep_ratio: float = 0.5,
    ):
        super().__init__()
        self.num_latents = num_latents
        self.num_heads = num_heads
        self.use_rope = use_rope
        self.use_dynamic_routing = use_dynamic_routing
        self.input_proj = nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()
        self.latents = nn.Parameter(torch.randn(1, num_latents, output_dim) * 0.02)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(output_dim, num_heads, output_dim // num_heads, dropout, use_rope),
                nn.Sequential(
                    nn.LayerNorm(output_dim),
                    nn.Linear(output_dim, output_dim * 4),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(output_dim * 4, output_dim),
                    nn.Dropout(dropout),
                )
            ])
            for _ in range(num_layers)
        ])
        if use_dynamic_routing:
            self.token_router = DynamicTokenRouter(output_dim, num_latents, routing_keep_ratio)
        else:
            self.token_router = None
        self.norm_out = nn.LayerNorm(output_dim)

    def forward(
        self,
        x: torch.Tensor,
        spatial_size: Optional[Tuple[int, int]] = None,
        temporal_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        x: [B, seq_len, input_dim] - input features
        spatial_size: (H, W) for 2D RoPE
        temporal_size: T for 3D RoPE (video)
        returns: [B, num_latents, output_dim] - compressed features
        """
        batch_size = x.shape[0]
        x = self.input_proj(x)
        context_rope = None
        if self.use_rope and spatial_size is not None:
            h, w = spatial_size
            # RoPE tables are computed at per-head width so they broadcast over
            # the head dimension inside PerceiverAttention (a full-width table
            # would fail to broadcast against [B, heads, seq, head_dim] keys).
            head_dim = x.shape[-1] // self.num_heads
            if temporal_size is not None:
                cos, sin = compute_3d_rope(temporal_size, h, w, head_dim, x.device, x.dtype)
            else:
                cos, sin = compute_2d_rope(h, w, head_dim, x.device, x.dtype)
            context_rope = (cos, sin)
        latents = self.latents.expand(batch_size, -1, -1)
        for attn, ff in self.layers:
            latents = latents + attn(latents, x, context_rope)
            latents = latents + ff(latents)
        latents = self.norm_out(latents)
        if self.token_router is not None:
            latents, _ = self.token_router(latents)
        return latents
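
# Illustrative sketch, not invoked by the model code: compressing a 16x16 grid
# of ViT patch features (256 tokens) down to 64 latents, with 2D RoPE applied
# to the context. All dimensions are assumptions for demonstration.
def _example_perceiver_resampler():
    resampler = PerceiverResampler(input_dim=1024, output_dim=512, num_latents=64)
    patches = torch.randn(2, 256, 1024)
    latents = resampler(patches, spatial_size=(16, 16))
    assert latents.shape == (2, 64, 512)
    return latents
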
class SpatialAwareProjector(nn.Module):
    """
    Spatial-aware projector with 2D RoPE.
    """
    def __init__(
        self,
        vision_hidden_size: int,
        llm_hidden_size: int,
        num_tokens: int = 64,
        spatial_pool_size: int = 8,
        use_rope: bool = True,
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.spatial_pool_size = spatial_pool_size
        self.use_rope = use_rope
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(vision_hidden_size, llm_hidden_size, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(llm_hidden_size, llm_hidden_size, 3, padding=1),
            nn.GELU(),
        )
        self.adaptive_pool = nn.AdaptiveAvgPool2d((spatial_pool_size, spatial_pool_size))
        self.proj = nn.Sequential(
            nn.Linear(llm_hidden_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )
        self.norm = nn.LayerNorm(llm_hidden_size)

    def forward(self, vision_features: torch.Tensor, spatial_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        batch_size = vision_features.shape[0]
        if vision_features.dim() == 3:
            seq_len = vision_features.shape[1]
            if spatial_size is None:
                h = w = int(math.sqrt(seq_len))
            else:
                h, w = spatial_size
            vision_features = vision_features.view(batch_size, h, w, -1)
        x = vision_features.permute(0, 3, 1, 2)
        x = self.spatial_conv(x)
        x = self.adaptive_pool(x)
        x = x.flatten(2).transpose(1, 2)
        if self.use_rope:
            cos, sin = compute_2d_rope(self.spatial_pool_size, self.spatial_pool_size, x.shape[-1], x.device, x.dtype)
            x = apply_rope(x, cos.unsqueeze(0), sin.unsqueeze(0))
        x = self.proj(x)
        x = self.norm(x)
        return x

class CAbstractor(nn.Module):
    """
    C-Abstractor: Compressed Abstraction for efficient multimodal fusion.
    Now with 2D RoPE support.
    """
    def __init__(
        self,
        vision_hidden_size: int,
        llm_hidden_size: int,
        num_tokens: int = 64,
        num_heads: int = 8,
        compression_ratio: int = 4,
        use_rope: bool = True,
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.use_rope = use_rope
        self.input_proj = nn.Linear(vision_hidden_size, llm_hidden_size)
        self.compress = nn.Sequential(
            nn.Conv1d(llm_hidden_size, llm_hidden_size, kernel_size=compression_ratio, stride=compression_ratio),
            nn.GELU(),
        )
        self.queries = nn.Parameter(torch.randn(1, num_tokens, llm_hidden_size) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=llm_hidden_size,
            num_heads=num_heads,
            batch_first=True,
            dropout=0.1,
        )
        self.ff = nn.Sequential(
            nn.LayerNorm(llm_hidden_size),
            nn.Linear(llm_hidden_size, llm_hidden_size * 4),
            nn.GELU(),
            nn.Linear(llm_hidden_size * 4, llm_hidden_size),
        )
        self.norm = nn.LayerNorm(llm_hidden_size)

    def forward(self, vision_features: torch.Tensor, spatial_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        batch_size = vision_features.shape[0]
        x = self.input_proj(vision_features)
        if self.use_rope and spatial_size is not None:
            h, w = spatial_size
            cos, sin = compute_2d_rope(h, w, x.shape[-1], x.device, x.dtype)
            x = apply_rope(x, cos.unsqueeze(0), sin.unsqueeze(0))
        x = x.transpose(1, 2)
        x = self.compress(x)
        x = x.transpose(1, 2)
        queries = self.queries.expand(batch_size, -1, -1)
        abstracted, _ = self.cross_attn(queries, x, x)
        abstracted = abstracted + self.ff(abstracted)
        return self.norm(abstracted)

| class MultimodalProjector (nn .Module ): | |
| """ | |
| SOTA Multimodal Projector with all advanced features. | |
| Combines: | |
| - Locality-Enhanced ResNet Abstractor | |
| - Multi-Scale Feature Fusion | |
| - Multi-Scale Deformable Attention | |
| - Dynamic Token Router | |
| - 2D/3D RoPE | |
| - Perceiver Resampler | |
| """ | |
| def __init__ ( | |
| self , | |
| vision_hidden_size :int , | |
| llm_hidden_size :int , | |
| num_tokens :int =64 , | |
| projector_type :str ="perceiver", | |
| num_heads :int =8 , | |
| num_layers :int =2 , | |
| use_rope :bool =True , | |
| use_dynamic_routing :bool =False , | |
| use_locality_enhanced :bool =False , | |
| use_msff :bool =False , | |
| use_deformable_attn :bool =False , | |
| ): | |
| super ().__init__ () | |
| self .num_tokens =num_tokens | |
| self .projector_type =projector_type | |
| self .use_rope =use_rope | |
| if projector_type =="perceiver": | |
| self .projector =PerceiverResampler ( | |
| input_dim =vision_hidden_size , | |
| output_dim =llm_hidden_size , | |
| num_latents =num_tokens , | |
| num_heads =num_heads , | |
| num_layers =num_layers , | |
| use_rope =use_rope , | |
| use_dynamic_routing =use_dynamic_routing , | |
| ) | |
| elif projector_type =="spatial": | |
| self .projector =SpatialAwareProjector ( | |
| vision_hidden_size =vision_hidden_size , | |
| llm_hidden_size =llm_hidden_size , | |
| num_tokens =num_tokens , | |
| use_rope =use_rope , | |
| ) | |
| elif projector_type =="c_abstractor": | |
| self .projector =CAbstractor ( | |
| vision_hidden_size =vision_hidden_size , | |
| llm_hidden_size =llm_hidden_size , | |
| num_tokens =num_tokens , | |
| num_heads =num_heads , | |
| use_rope =use_rope , | |
| ) | |
| elif projector_type =="locality_enhanced": | |
| self .projector =LocalityEnhancedResNetAbstractor ( | |
| input_dim =vision_hidden_size , | |
| output_dim =llm_hidden_size , | |
| num_tokens =num_tokens , | |
| use_2d_rope =use_rope , | |
| ) | |
| else : | |
| self .projector =nn .Sequential ( | |
| nn .Linear (vision_hidden_size ,llm_hidden_size ), | |
| nn .GELU (), | |
| nn .Linear (llm_hidden_size ,llm_hidden_size ), | |
| ) | |
| self .query_tokens =nn .Parameter (torch .randn (1 ,num_tokens ,llm_hidden_size )*0.02 ) | |
| self .cross_attn =nn .MultiheadAttention ( | |
| embed_dim =llm_hidden_size , | |
| num_heads =num_heads , | |
| batch_first =True | |
| ) | |
| self .norm =nn .LayerNorm (llm_hidden_size ) | |
| if use_msff : | |
| self .msff =MultiScaleFeatureFusion ( | |
| feature_dims =[vision_hidden_size ]*3 , | |
| output_dim =vision_hidden_size , | |
| ) | |
| else : | |
| self .msff =None | |
| if use_deformable_attn : | |
| self .deformable_attn =MultiScaleDeformableAttention ( | |
| dim =llm_hidden_size , | |
| num_heads =num_heads , | |
| ) | |
| else : | |
| self .deformable_attn =None | |
| if use_dynamic_routing and projector_type !="perceiver": | |
| self .token_router =DynamicTokenRouter (llm_hidden_size ,num_tokens ) | |
| else : | |
| self .token_router =None | |
| def forward ( | |
| self , | |
| vision_features :torch .Tensor , | |
| multi_scale_features :Optional [List [torch .Tensor ]]=None , | |
| spatial_size :Optional [Tuple [int ,int ]]=None , | |
| temporal_size :Optional [int ]=None , | |
| )->torch .Tensor : | |
| """Project and resample vision features.""" | |
| if self .msff is not None and multi_scale_features is not None : | |
| vision_features =self .msff (multi_scale_features ) | |
| if self.projector_type == "perceiver": | |
| output =self .projector (vision_features ,spatial_size ,temporal_size ) | |
| elif self .projector_type in ["spatial","c_abstractor","locality_enhanced"]: | |
| output =self .projector (vision_features ,spatial_size ) | |
| else : | |
| batch_size =vision_features .shape [0 ] | |
| projected =self .projector (vision_features ) | |
| queries =self .query_tokens .expand (batch_size ,-1 ,-1 ) | |
| resampled ,_ =self .cross_attn (queries ,projected ,projected ) | |
| output =self .norm (resampled ) | |
| if self .token_router is not None : | |
| output ,_ =self .token_router (output ) | |
| return output | |
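| # Illustrative sketch (not from the original codebase): a minimal smoke test for | |
| # MultimodalProjector with assumed sizes (SigLIP-so400m-style 1152-dim patches on | |
| # a 24x24 grid, 2048-dim LLM). Defining the helper has no import side effects. | |
| def _demo_multimodal_projector(): | |
|     proj = MultimodalProjector(vision_hidden_size=1152, llm_hidden_size=2048, num_tokens=64) | |
|     feats = torch.randn(2, 576, 1152)  # [B, num_patches, vision_hidden] | |
|     out = proj(feats, spatial_size=(24, 24)) | |
|     assert out.shape == (2, 64, 2048) | |
|     return out | |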
| # ============================================================================== | |
| # MODELS.COMPONENTS.MOE | |
| # ============================================================================== | |
| EPS =1e-5 | |
| class ExpertUtilizationTracker : | |
| """ | |
| Tracks expert utilization across MoE layers. | |
| Attach to any MoE layer to log per-expert usage histograms. | |
| Every `report_interval` steps, prints a report showing: | |
| - Frequency of use per expert | |
| - Cold experts (used < 1% of tokens) | |
| - Count of experts offloaded to CPU (if ExpertOffloadManager is available) | |
| Usage: | |
| tracker = ExpertUtilizationTracker(num_experts=8, layer_name="layer.3.moe") | |
| """ | |
| def __init__ ( | |
| self , | |
| num_experts :int , | |
| layer_name :str ="moe", | |
| report_interval :int =100 , | |
| cold_threshold_pct :float =1.0 , | |
| ): | |
| self .num_experts =num_experts | |
| self .layer_name =layer_name | |
| self .report_interval =report_interval | |
| self .cold_threshold_pct =cold_threshold_pct | |
| self ._counts =torch .zeros (num_experts ,dtype =torch .long ) | |
| self ._total_tokens =0 | |
| self ._step =0 | |
| self ._offload_manager =None | |
| def link_offload_manager (self ,manager ): | |
| """Link an ExpertOffloadManager for cold-expert reporting.""" | |
| self ._offload_manager =manager | |
| def record (self ,expert_indices :torch .Tensor ): | |
| """ | |
| Record expert selections from a forward pass. | |
| Args: | |
| expert_indices: [num_tokens, top_k] tensor of selected expert indices | |
| """ | |
| indices_flat = expert_indices.detach().cpu().reshape(-1) | |
| # One vectorized histogram instead of a per-expert Python loop. | |
| self._counts += torch.bincount(indices_flat, minlength=self.num_experts) | |
| self ._total_tokens +=expert_indices .shape [0 ] | |
| def step (self ): | |
| """Advance step counter. Prints report and resets when interval is hit.""" | |
| self ._step +=1 | |
| if self ._step %self .report_interval ==0 : | |
| self ._print_report () | |
| self ._reset () | |
| def _reset (self ): | |
| """Reset accumulators for next interval.""" | |
| self ._counts .zero_ () | |
| self ._total_tokens =0 | |
| def _print_report (self ): | |
| """Print expert utilization histogram.""" | |
| if self ._total_tokens ==0 : | |
| return | |
| freqs =self ._counts .float () | |
| total_assignments =freqs .sum ().item () | |
| if total_assignments ==0 : | |
| return | |
| pcts =(freqs /total_assignments *100 ).tolist () | |
| cold_experts =[i for i ,p in enumerate (pcts )if p <self .cold_threshold_pct ] | |
| max_pct =max (pcts )if pcts else 0 | |
| bar_max =30 | |
| lines =[f"\n{'='*60 }"] | |
| lines .append (f" Expert Utilization — {self .layer_name } (step {self ._step })") | |
| lines .append (f" {self ._total_tokens :,} tokens, {int (total_assignments ):,} assignments") | |
| lines .append (f"{'─'*60 }") | |
| for i ,pct in enumerate (pcts ): | |
| bar_len =int (pct /max_pct *bar_max )if max_pct >0 else 0 | |
| bar ="█"*bar_len | |
| cold_tag =" ❄️"if pct <self .cold_threshold_pct else "" | |
| lines .append (f" Expert {i :2d} │{bar :<{bar_max }}│ {pct :5.1f}% ({int (self ._counts [i ]):>6d}){cold_tag }") | |
| lines .append (f"{'─'*60 }") | |
| if cold_experts : | |
| lines .append (f" ❄️ Cold experts (<{self .cold_threshold_pct }%): {cold_experts }") | |
| else : | |
| lines .append (f" ✅ All experts active (no cold experts)") | |
| if self ._offload_manager is not None : | |
| status =self ._offload_manager .get_status () | |
| lines .append (f" 💾 Offloaded to CPU: {status ['cpu']}/{status ['total']}") | |
| ideal_pct =100.0 /self .num_experts | |
| balance =1.0 -(sum (abs (p -ideal_pct )for p in pcts )/(2 *100 )) | |
| lines .append (f" ⚖️ Load balance score: {balance :.3f} (1.0 = perfect)") | |
| lines .append (f"{'='*60 }") | |
| print ("\n".join (lines )) | |
| def get_stats (self )->dict : | |
| """Return current stats as a dict (for programmatic access).""" | |
| total =self ._counts .sum ().item () | |
| if total ==0 : | |
| pcts =[0.0 ]*self .num_experts | |
| else : | |
| pcts =(self ._counts .float ()/total *100 ).tolist () | |
| cold =[i for i ,p in enumerate (pcts )if p <self .cold_threshold_pct ] | |
| ideal_pct =100.0 /self .num_experts | |
| balance =1.0 -(sum (abs (p -ideal_pct )for p in pcts )/(2 *100 ))if total >0 else 0.0 | |
| return { | |
| "step":self ._step , | |
| "layer_name":self .layer_name , | |
| "total_tokens":self ._total_tokens , | |
| "expert_counts":self ._counts .tolist (), | |
| "expert_pcts":pcts , | |
| "cold_experts":cold , | |
| "balance_score":balance , | |
| } | |
| def attach_utilization_trackers ( | |
| model :torch .nn .Module , | |
| report_interval :int =100 , | |
| )->list : | |
| """ | |
| Find all MoE layers in a model and attach ExpertUtilizationTrackers. | |
| Returns list of trackers for manual step() calls in the training loop. | |
| """ | |
| trackers =[] | |
| for name ,module in model .named_modules (): | |
| if hasattr (module ,'experts')and hasattr (module ,'router'): | |
| num_experts =len (module .experts ) | |
| tracker =ExpertUtilizationTracker ( | |
| num_experts =num_experts , | |
| layer_name =name , | |
| report_interval =report_interval , | |
| ) | |
| if hasattr (module ,'_expert_offload_manager'): | |
| tracker .link_offload_manager (module ._expert_offload_manager ) | |
| module ._utilization_tracker =tracker | |
| trackers .append (tracker ) | |
| if trackers : | |
| print (f" 📊 Attached {len (trackers )} expert utilization trackers (report every {report_interval } steps)") | |
| return trackers | |
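| # Illustrative sketch (model/optimizer/batch names are assumed): wiring the | |
| # trackers into a training loop so reports fire every `report_interval` steps. | |
| def _demo_attach_trackers(model, optimizer, batches): | |
|     trackers = attach_utilization_trackers(model, report_interval=100) | |
|     for batch in batches: | |
|         loss = model(**batch).loss | |
|         loss.backward() | |
|         optimizer.step() | |
|         optimizer.zero_grad() | |
|         for tracker in trackers: | |
|             tracker.step()  # prints a utilization histogram at each interval | |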
| class MoERouter (nn .Module ): | |
| """ | |
| SOTA Router for Mixture of Experts v2.0 - FP16 native. | |
| Supports both traditional aux-loss routing and aux-lossless routing. | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_experts :int ,top_k :int =2 , | |
| noise_std :float =0.01 ,capacity_factor :float =1.25 , | |
| aux_lossless :bool =True ): | |
| super ().__init__ () | |
| self .num_experts =num_experts | |
| self .top_k =top_k | |
| self .noise_std =noise_std | |
| self .capacity_factor =capacity_factor | |
| self .hidden_size =hidden_size | |
| self .aux_lossless =aux_lossless | |
| self .input_norm =nn .LayerNorm (hidden_size ,eps =1e-5 ) | |
| self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) | |
| nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) | |
| if aux_lossless : | |
| self .expert_bias =nn .Parameter (torch .zeros (num_experts )) | |
| def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: | |
| batch_size ,seq_len ,hidden_dim =hidden_states .shape | |
| hidden_flat =hidden_states .view (-1 ,hidden_dim ) | |
| hidden_norm =self .input_norm (hidden_flat ) | |
| router_logits =self .gate (hidden_norm ) | |
| if self .aux_lossless : | |
| router_logits =router_logits +self .expert_bias | |
| if self .training and self .noise_std >0 : | |
| noise =torch .randn_like (router_logits )*self .noise_std | |
| noisy_logits =router_logits +noise | |
| else : | |
| noisy_logits =router_logits | |
| router_probs =F .softmax (noisy_logits ,dim =-1 ,dtype =hidden_states .dtype ) | |
| top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 ) | |
| prob_sum =top_k_probs .sum (dim =-1 ,keepdim =True ).clamp (min =EPS ) | |
| top_k_probs =top_k_probs /prob_sum | |
| return top_k_probs ,top_k_indices ,router_logits | |
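| # Illustrative sketch: the router flattens tokens, returning [B*S, top_k] | |
| # normalized weights and indices plus raw [B*S, num_experts] logits. | |
| def _demo_moe_router(): | |
|     router = MoERouter(hidden_size=512, num_experts=8, top_k=2) | |
|     probs, indices, logits = router(torch.randn(2, 16, 512)) | |
|     assert probs.shape == (32, 2) and indices.shape == (32, 2) and logits.shape == (32, 8) | |
|     assert torch.allclose(probs.sum(-1), torch.ones(32), atol=1e-4) | |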
| class MoEExpert (nn .Module ): | |
| """ | |
| Single expert FFN with SwiGLU activation - FP16 native. | |
| """ | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ,dropout :float =0.0 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .intermediate_size =intermediate_size | |
| self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) | |
| self .act_fn =nn .SiLU () | |
| self .dropout =nn .Dropout (dropout )if dropout >0 else nn .Identity () | |
| self ._init_weights () | |
| def _init_weights (self ): | |
| std =0.02 | |
| nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| gate =self .act_fn (self .gate_proj (x )) | |
| up =self .up_proj (x ) | |
| out =self .down_proj (gate *up ) | |
| return self .dropout (out ) | |
| class SharedExpert (nn .Module ): | |
| """ | |
| Isolated Shared Expert (v2.0) - FP16 native. | |
| Always active, separate from routed experts. | |
| The shared expert processes all tokens independently of routing decisions. | |
| """ | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ,dropout :float =0.0 , | |
| isolated :bool =True ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .intermediate_size =intermediate_size | |
| self .isolated =isolated | |
| self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) | |
| self .act_fn =nn .SiLU () | |
| self .dropout =nn .Dropout (dropout )if dropout >0 else nn .Identity () | |
| self .shared_gate =nn .Parameter (torch .ones (1 )*0.5 ) | |
| if isolated : | |
| self .pre_norm =nn .LayerNorm (hidden_size ,eps =1e-5 ) | |
| self ._init_weights () | |
| def _init_weights (self ): | |
| std =0.02 | |
| nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| if self .isolated : | |
| x =self .pre_norm (x ) | |
| gate =self .act_fn (self .gate_proj (x )) | |
| up =self .up_proj (x ) | |
| out =self .down_proj (gate *up ) | |
| out =self .dropout (out ) | |
| return out *torch .sigmoid (self .shared_gate ) | |
| class MoELayer (nn .Module ): | |
| """ | |
| SOTA Mixture of Experts layer v2.0 - FP16 native. | |
| Supports Aux-Lossless MoE with Isolated Shared Expert. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int , | |
| intermediate_size :int , | |
| num_experts :int =8 , | |
| num_experts_per_tok :int =2 , | |
| use_shared_expert :bool =True , | |
| shared_expert_intermediate_size :Optional [int ]=None , | |
| capacity_factor :float =1.25 , | |
| expert_dropout :float =0.0 , | |
| aux_lossless :bool =True , | |
| isolated_shared :bool =True , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_experts =num_experts | |
| self .num_experts_per_tok =num_experts_per_tok | |
| self .use_shared_expert =use_shared_expert | |
| self .capacity_factor =capacity_factor | |
| self .aux_lossless =aux_lossless | |
| self .router =MoERouter ( | |
| hidden_size ,num_experts ,num_experts_per_tok , | |
| capacity_factor =capacity_factor ,aux_lossless =aux_lossless | |
| ) | |
| self .experts =nn .ModuleList ([ | |
| MoEExpert (hidden_size ,intermediate_size ,expert_dropout ) | |
| for _ in range (num_experts ) | |
| ]) | |
| if use_shared_expert : | |
| shared_size =shared_expert_intermediate_size or intermediate_size | |
| self .shared_expert =SharedExpert ( | |
| hidden_size ,shared_size ,expert_dropout ,isolated =isolated_shared | |
| ) | |
| else : | |
| self .shared_expert =None | |
| def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| batch_size ,seq_len ,hidden_size =hidden_states .shape | |
| hidden_flat =hidden_states .view (-1 ,hidden_size ) | |
| num_tokens =hidden_flat .shape [0 ] | |
| top_k_probs ,top_k_indices ,router_logits =self .router (hidden_states ) | |
| if hasattr (self ,'_utilization_tracker'): | |
| self ._utilization_tracker .record (top_k_indices ) | |
| final_output =torch .zeros_like (hidden_flat ) | |
| # Dispatch: for each expert, process the tokens that routed to it in any top-k slot. | |
| for expert_idx in range(self.num_experts): | |
| expert =self .experts [expert_idx ] | |
| for k in range (self .num_experts_per_tok ): | |
| mask =(top_k_indices [:,k ]==expert_idx ) | |
| if mask .any (): | |
| expert_input =hidden_flat [mask ] | |
| expert_output =expert (expert_input ) | |
| weight =top_k_probs [mask ,k :k +1 ] | |
| final_output [mask ]=final_output [mask ]+weight *expert_output | |
| if self .shared_expert is not None : | |
| shared_output =self .shared_expert (hidden_flat ) | |
| final_output =final_output +shared_output | |
| final_output =final_output .view (batch_size ,seq_len ,hidden_size ) | |
| aux_loss =self ._compute_aux_loss (router_logits ,top_k_indices ,num_tokens ) | |
| return final_output ,aux_loss | |
| def _compute_aux_loss (self ,router_logits :torch .Tensor ,top_k_indices :torch .Tensor , | |
| num_tokens :int )->torch .Tensor : | |
| device =router_logits .device | |
| dtype =router_logits .dtype | |
| if self .aux_lossless : | |
| z_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.0001 | |
| return z_loss | |
| router_probs =F .softmax (router_logits ,dim =-1 ,dtype =dtype ) | |
| expert_mask =F .one_hot (top_k_indices ,self .num_experts ).to (dtype ) | |
| denominator =max (num_tokens *self .num_experts_per_tok ,1 ) | |
| tokens_per_expert =expert_mask .sum (dim =(0 ,1 ))/denominator | |
| avg_probs =router_probs .mean (dim =0 ) | |
| load_balance_loss =self .num_experts *(tokens_per_expert *avg_probs ).sum () | |
| z_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.001 | |
| router_probs_safe =router_probs .clamp (EPS ,1.0 -EPS ) | |
| log_probs =torch .log (router_probs_safe ) | |
| entropy =-(router_probs_safe *log_probs ).sum (dim =-1 ).mean () | |
| max_entropy =torch .log (torch .tensor (float (self .num_experts ),device =device ,dtype =dtype )) | |
| entropy_loss =(max_entropy -entropy ).clamp (min =0.0 )*0.01 | |
| expert_usage =(tokens_per_expert >0.01 ).to (dtype ).mean () | |
| utilization_loss =(1.0 -expert_usage )*0.1 | |
| total_aux_loss =load_balance_loss +z_loss +entropy_loss +utilization_loss | |
| return total_aux_loss | |
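| # Illustrative sketch (the 0.01 weighting is an assumption, not a repo default): | |
| # fold the auxiliary routing loss into the task loss during training. | |
| def _demo_moe_layer(): | |
|     moe = MoELayer(hidden_size=512, intermediate_size=1024, num_experts=4, num_experts_per_tok=2) | |
|     out, aux_loss = moe(torch.randn(2, 16, 512)) | |
|     assert out.shape == (2, 16, 512) | |
|     task_loss = out.float().pow(2).mean()  # placeholder objective | |
|     (task_loss + 0.01 * aux_loss).backward() | |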
| class ExpertChoiceMoELayer (nn .Module ): | |
| """ | |
| Expert Choice MoE - FP16 native. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int , | |
| intermediate_size :int , | |
| num_experts :int =8 , | |
| capacity_factor :float =1.0 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_experts =num_experts | |
| self .capacity_factor =capacity_factor | |
| self .input_norm =nn .LayerNorm (hidden_size ,eps =1e-5 ) | |
| self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) | |
| nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) | |
| self .experts =nn .ModuleList ([ | |
| MoEExpert (hidden_size ,intermediate_size ) | |
| for _ in range (num_experts ) | |
| ]) | |
| def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| batch_size ,seq_len ,hidden_size =hidden_states .shape | |
| hidden_flat =hidden_states .view (-1 ,hidden_size ) | |
| num_tokens =hidden_flat .shape [0 ] | |
| hidden_norm =self .input_norm (hidden_flat ) | |
| router_logits =self .gate (hidden_norm ) | |
| # Expert choice: softmax over the token axis (dim=0), so each expert spreads its capacity over tokens. | |
| router_probs = F.softmax(router_logits, dim=0, dtype=hidden_states.dtype) | |
| capacity =int (num_tokens *self .capacity_factor /self .num_experts ) | |
| capacity =max (capacity ,1 ) | |
| final_output =torch .zeros_like (hidden_flat ) | |
| token_counts =torch .zeros (num_tokens ,device =hidden_flat .device ,dtype =hidden_flat .dtype ) | |
| for expert_idx in range (self .num_experts ): | |
| expert =self .experts [expert_idx ] | |
| expert_probs =router_probs [:,expert_idx ] | |
| top_probs ,top_indices =torch .topk (expert_probs ,min (capacity ,num_tokens )) | |
| expert_input =hidden_flat [top_indices ] | |
| expert_output =expert (expert_input ) | |
| final_output [top_indices ]=final_output [top_indices ]+top_probs .unsqueeze (-1 )*expert_output | |
| token_counts [top_indices ]=token_counts [top_indices ]+top_probs | |
| token_counts =token_counts .clamp (min =EPS ) | |
| final_output =final_output /token_counts .unsqueeze (-1 ) | |
| final_output =final_output .view (batch_size ,seq_len ,hidden_size ) | |
| aux_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.001 | |
| return final_output ,aux_loss | |
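| # Illustrative sketch: with expert choice, experts select tokens (capacity is | |
| # balanced by construction) rather than tokens selecting experts. | |
| def _demo_expert_choice(): | |
|     layer = ExpertChoiceMoELayer(hidden_size=512, intermediate_size=1024, num_experts=4) | |
|     out, _ = layer(torch.randn(2, 16, 512)) | |
|     assert out.shape == (2, 16, 512) | |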
| # ============================================================================== | |
| # MODELS.ENCODERS.VISION | |
| # ============================================================================== | |
| EPS =1e-5 | |
| class RoPE2DEncoder (nn .Module ): | |
| """ | |
| 2D Rotary Position Embedding for vision encoder patches. | |
| Matches the 2D-RoPE in image generator for seamless integration. | |
| """ | |
| def __init__ (self ,dim :int ,max_height :int =128 ,max_width :int =128 ,base :float =10000.0 ): | |
| super ().__init__ () | |
| self .dim =dim | |
| self .max_height =max_height | |
| self .max_width =max_width | |
| self .base =base | |
| self .dim_x =dim //2 | |
| self .dim_y =dim -self .dim_x | |
| inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x )) | |
| inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y )) | |
| self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False ) | |
| self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| device =x .device | |
| dtype =x .dtype | |
| pos_x =torch .arange (width ,device =device ,dtype =torch .float32 ) | |
| pos_y =torch .arange (height ,device =device ,dtype =torch .float32 ) | |
| freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device )) | |
| freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device )) | |
| freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 ) | |
| freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 ) | |
| cos_2d = torch.zeros(height, width, self.dim, device=device, dtype=dtype) | |
| sin_2d = torch.zeros(height, width, self.dim, device=device, dtype=dtype) | |
| # Broadcast the per-axis frequencies across the grid instead of looping per patch. | |
| cos_2d[:, :, :self.dim_x] = freqs_x.cos().to(dtype).unsqueeze(0) | |
| sin_2d[:, :, :self.dim_x] = freqs_x.sin().to(dtype).unsqueeze(0) | |
| cos_2d[:, :, self.dim_x:] = freqs_y.cos().to(dtype).unsqueeze(1) | |
| sin_2d[:, :, self.dim_x:] = freqs_y.sin().to(dtype).unsqueeze(1) | |
| cos_2d =cos_2d .view (height *width ,self .dim ) | |
| sin_2d =sin_2d .view (height *width ,self .dim ) | |
| return cos_2d ,sin_2d | |
| def apply_rope_2d_encoder (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor : | |
| """Apply 2D rotary position embedding to tensor.""" | |
| x1 =x [...,:x .shape [-1 ]//2 ] | |
| x2 =x [...,x .shape [-1 ]//2 :] | |
| rotated =torch .cat ((-x2 ,x1 ),dim =-1 ) | |
| return x *cos +rotated *sin | |
| class TiTokTokenizer (nn .Module ): | |
| """ | |
| TiTok-style 1D Tokenizer for efficient visual representation. | |
| Converts 2D patch grid to 1D token sequence with learnable compression. | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_tokens :int =256 ,num_patches :int =576 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_tokens =num_tokens | |
| self .num_patches =num_patches | |
| self .compress =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size ), | |
| nn .GELU (), | |
| nn .Linear (hidden_size ,hidden_size ), | |
| ) | |
| self .token_queries =nn .Parameter (torch .randn (1 ,num_tokens ,hidden_size )*0.02 ) | |
| self .compress_attn =nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =8 , | |
| batch_first =True , | |
| dropout =0.1 , | |
| ) | |
| self .compress_norm =nn .LayerNorm (hidden_size ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| """ | |
| Compress patch features to TiTok-style 1D tokens. | |
| Args: | |
| x: [B, num_patches, hidden_size] patch features | |
| Returns: | |
| [B, num_tokens, hidden_size] compressed token features | |
| """ | |
| batch_size =x .shape [0 ] | |
| queries =self .token_queries .expand (batch_size ,-1 ,-1 ) | |
| x_proj =self .compress (x ) | |
| tokens ,_ =self .compress_attn (queries ,x_proj ,x_proj ) | |
| tokens =self .compress_norm (queries +tokens ) | |
| return tokens | |
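| # Illustrative sketch: compress a 24x24 patch grid (576 patches) into 256 | |
| # learned 1D tokens, independent of the original 2D layout. | |
| def _demo_titok(): | |
|     tok = TiTokTokenizer(hidden_size=768, num_tokens=256, num_patches=576) | |
|     out = tok(torch.randn(2, 576, 768)) | |
|     assert out.shape == (2, 256, 768) | |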
| class DeepStack (nn .Module ): | |
| """ | |
| DeepStack: Fuses multi-level ViT features to capture fine-grained details and sharpen image-text alignment. | |
| SOTA: Instead of using only the final layer features, DeepStack combines features from | |
| multiple intermediate layers of the vision encoder, enabling: | |
| - Better fine-grained detail capture (early layers have high-resolution features) | |
| - Stronger image-text alignment (different layers capture different semantic levels) | |
| - Improved generation quality for both understanding and generation tasks | |
| Architecture: | |
| - Collects features from selected layers (typically: early, middle, late) | |
| - Projects each level to a common dimension | |
| - Combines via learned weighted sum or attention | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_layers :int =3 ,use_attention :bool =True ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_layers =num_layers | |
| self .use_attention =use_attention | |
| self .level_projs =nn .ModuleList ([ | |
| nn .Linear (hidden_size ,hidden_size ) | |
| for _ in range (num_layers ) | |
| ]) | |
| self .level_norms =nn .ModuleList ([ | |
| nn .LayerNorm (hidden_size ) | |
| for _ in range (num_layers ) | |
| ]) | |
| if use_attention : | |
| self .fusion_query =nn .Parameter (torch .randn (1 ,1 ,hidden_size )*0.02 ) | |
| self .fusion_attn =nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =8 , | |
| batch_first =True , | |
| dropout =0.1 , | |
| ) | |
| self .fusion_norm =nn .LayerNorm (hidden_size ) | |
| else : | |
| self .level_weights =nn .Parameter (torch .ones (num_layers )/num_layers ) | |
| self .output_proj =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size ), | |
| nn .GELU (), | |
| nn .Linear (hidden_size ,hidden_size ), | |
| ) | |
| def forward (self ,multi_level_features :list )->torch .Tensor : | |
| """ | |
| Fuse multi-level features. | |
| Args: | |
| multi_level_features: List of [B, seq_len, hidden_size] features from different layers | |
| Returns: | |
| [B, seq_len, hidden_size] fused features | |
| """ | |
| if len(multi_level_features) != self.num_layers: | |
|     # Keep the most recent levels; slicing also handles the shorter-list case. | |
|     multi_level_features = multi_level_features[-self.num_layers:] | |
| batch_size ,seq_len ,_ =multi_level_features [0 ].shape | |
| projected =[] | |
| for i ,(feat ,proj ,norm )in enumerate (zip (multi_level_features ,self .level_projs ,self .level_norms )): | |
| projected .append (norm (proj (feat ))) | |
| if self .use_attention : | |
| stacked =torch .cat (projected ,dim =1 ) | |
| query =self .fusion_query .expand (batch_size ,seq_len ,-1 ) | |
| fused ,_ =self .fusion_attn (query ,stacked ,stacked ) | |
| fused =self .fusion_norm (query +fused ) | |
| else : | |
| weights =F .softmax (self .level_weights ,dim =0 ) | |
| fused =sum (w *feat for w ,feat in zip (weights ,projected )) | |
| return self .output_proj (fused ) | |
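| # Illustrative sketch: fuse three same-shape feature maps (e.g. from early, | |
| # middle, and late encoder layers) into one. | |
| def _demo_deepstack(): | |
|     ds = DeepStack(hidden_size=768, num_layers=3) | |
|     fused = ds([torch.randn(2, 576, 768) for _ in range(3)]) | |
|     assert fused.shape == (2, 576, 768) | |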
| class DualStreamEncoderAttention (nn .Module ): | |
| """ | |
| Symmetric Dual-Stream Self-Attention for vision encoding. | |
| Matches the dual-stream architecture in image generator. | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_height :int =64 ,max_width :int =64 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_heads =num_heads | |
| self .head_dim =hidden_size //num_heads | |
| self .scale =self .head_dim **-0.5 | |
| self .to_qkv_a =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) | |
| self .to_qkv_b =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) | |
| self .to_out_a =nn .Linear (hidden_size ,hidden_size ,bias =False ) | |
| self .to_out_b =nn .Linear (hidden_size ,hidden_size ,bias =False ) | |
| self .norm_a =nn .LayerNorm (hidden_size ) | |
| self .norm_b =nn .LayerNorm (hidden_size ) | |
| self .rope_2d =RoPE2DEncoder (self .head_dim ,max_height ,max_width ) | |
| def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| batch_size ,seq_len ,_ =x_a .shape | |
| x_a =self .norm_a (x_a ) | |
| x_b =self .norm_b (x_b ) | |
| qkv_a =self .to_qkv_a (x_a ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim ) | |
| qkv_b =self .to_qkv_b (x_b ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim ) | |
| q_a ,k_a ,v_a =qkv_a .unbind (dim =2 ) | |
| q_b ,k_b ,v_b =qkv_b .unbind (dim =2 ) | |
| cos ,sin =self .rope_2d (x_a ,height ,width ) | |
| cos =cos .unsqueeze (0 ).unsqueeze (0 ) | |
| sin =sin .unsqueeze (0 ).unsqueeze (0 ) | |
| q_a =q_a .transpose (1 ,2 ) | |
| k_a =k_a .transpose (1 ,2 ) | |
| v_a =v_a .transpose (1 ,2 ) | |
| q_b =q_b .transpose (1 ,2 ) | |
| k_b =k_b .transpose (1 ,2 ) | |
| v_b =v_b .transpose (1 ,2 ) | |
| q_a =apply_rope_2d_encoder (q_a ,cos ,sin ) | |
| k_a =apply_rope_2d_encoder (k_a ,cos ,sin ) | |
| q_b =apply_rope_2d_encoder (q_b ,cos ,sin ) | |
| k_b =apply_rope_2d_encoder (k_b ,cos ,sin ) | |
| k_combined =torch .cat ([k_a ,k_b ],dim =2 ) | |
| v_combined =torch .cat ([v_a ,v_b ],dim =2 ) | |
| attn_a =F .scaled_dot_product_attention (q_a ,k_combined ,v_combined ) | |
| attn_b =F .scaled_dot_product_attention (q_b ,k_combined ,v_combined ) | |
| attn_a =attn_a .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size ) | |
| attn_b =attn_b .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size ) | |
| out_a =self .to_out_a (attn_a ) | |
| out_b =self .to_out_b (attn_b ) | |
| return out_a ,out_b | |
| class VisionEncoderBlock (nn .Module ): | |
| """Single block with dual-stream attention and FFN.""" | |
| def __init__ (self ,hidden_size :int ,num_heads :int =8 ,ff_mult :int =4 ,max_height :int =64 ,max_width :int =64 ): | |
| super ().__init__ () | |
| self .dual_attn =DualStreamEncoderAttention (hidden_size ,num_heads ,max_height ,max_width ) | |
| self .ffn_a =nn .Sequential ( | |
| nn .LayerNorm (hidden_size ), | |
| nn .Linear (hidden_size ,hidden_size *ff_mult ), | |
| nn .GELU (), | |
| nn .Linear (hidden_size *ff_mult ,hidden_size ), | |
| ) | |
| self .ffn_b =nn .Sequential ( | |
| nn .LayerNorm (hidden_size ), | |
| nn .Linear (hidden_size ,hidden_size *ff_mult ), | |
| nn .GELU (), | |
| nn .Linear (hidden_size *ff_mult ,hidden_size ), | |
| ) | |
| def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| attn_a ,attn_b =self .dual_attn (x_a ,x_b ,height ,width ) | |
| x_a =x_a +attn_a | |
| x_b =x_b +attn_b | |
| x_a =x_a +self .ffn_a (x_a ) | |
| x_b =x_b +self .ffn_b (x_b ) | |
| return x_a ,x_b | |
| class VisionEncoder (nn .Module ): | |
| """ | |
| SOTA Vision Encoder with 2D-RoPE, TiTok tokenization, and Dual-Stream Attention. | |
| Features: | |
| - SigLIP 2 / CLIP backbone for robust visual features | |
| - 2D-RoPE for flexible aspect ratios | |
| - TiTok-style 1D tokenization for efficient representation | |
| - Dual-stream attention for symmetric processing | |
| - FP16-native numerical stability | |
| """ | |
| def __init__ ( | |
| self , | |
| model_name :str ="google/siglip-so400m-patch14-384", | |
| freeze :bool =False , | |
| use_pooled_output :bool =False , | |
| use_dual_stream :bool =True , | |
| use_titok :bool =True , | |
| num_titok_tokens :int =256 , | |
| num_dual_stream_layers :int =2 , | |
| ): | |
| super ().__init__ () | |
| self .model_name =model_name | |
| self .use_pooled_output =use_pooled_output | |
| self .use_dual_stream =use_dual_stream | |
| self .use_titok =use_titok | |
| self ._is_siglip ="siglip"in model_name .lower () | |
| print (f"\n👁️ Loading Vision Encoder: {model_name }") | |
| if self ._is_siglip : | |
| self ._init_siglip (model_name ,freeze ) | |
| else : | |
| self ._init_clip (model_name ,freeze ) | |
| self .rope_2d =RoPE2DEncoder ( | |
| dim =self .hidden_size , | |
| max_height =64 , | |
| max_width =64 , | |
| ) | |
| print (f" 📐 2D-RoPE: Flexible aspect ratio support") | |
| if use_dual_stream : | |
| patch_size =getattr (self .vision_model .config ,'patch_size',14 ) | |
| image_size =getattr (self .vision_model .config ,'image_size',384 ) | |
| max_patches =(image_size //patch_size ) | |
| self .dual_stream_layers =nn .ModuleList ([ | |
| VisionEncoderBlock ( | |
| hidden_size =self .hidden_size , | |
| num_heads =8 , | |
| ff_mult =4 , | |
| max_height =max_patches , | |
| max_width =max_patches , | |
| ) | |
| for _ in range (num_dual_stream_layers ) | |
| ]) | |
| print (f" 🔄 Dual-Stream: {num_dual_stream_layers } layers") | |
| else : | |
| self .dual_stream_layers =None | |
| if use_titok : | |
| self .titok =TiTokTokenizer ( | |
| hidden_size =self .hidden_size , | |
| num_tokens =num_titok_tokens , | |
| num_patches =self .num_patches , | |
| ) | |
| print (f" 🎫 TiTok: {self .num_patches } patches -> {num_titok_tokens } tokens") | |
| else : | |
| self .titok =None | |
| def _init_siglip (self ,model_name :str ,freeze :bool ): | |
| """Initialize SigLIP 2 vision encoder.""" | |
| try : | |
| from transformers import SiglipVisionModel ,SiglipImageProcessor | |
| self .vision_model =SiglipVisionModel .from_pretrained (model_name ) | |
| self .image_processor =SiglipImageProcessor .from_pretrained (model_name ) | |
| self .hidden_size =self .vision_model .config .hidden_size | |
| print (f" 🎯 Using SigLIP 2 (recommended for MoE)") | |
| print (f" ✅ Hidden size: {self .hidden_size }") | |
| print (f" 📐 Native size: {self .vision_model .config .image_size } (multi-scale: 256-512px)") | |
| print (f" 🔲 Patch size: {self .vision_model .config .patch_size }") | |
| except ImportError : | |
| print (" ⚠️ SigLIP not available, falling back to CLIP") | |
| self ._is_siglip =False | |
| self ._init_clip ("openai/clip-vit-large-patch14",freeze ) | |
| return | |
| if freeze : | |
| for param in self .vision_model .parameters (): | |
| param .requires_grad =False | |
| print (f" ❄️ Vision encoder backbone frozen") | |
| else : | |
| print (f" 🔥 Vision encoder backbone trainable") | |
| def _init_clip (self ,model_name :str ,freeze :bool ): | |
| """Initialize CLIP vision encoder (legacy support).""" | |
| from transformers import CLIPVisionModel ,CLIPImageProcessor | |
| self .vision_model =CLIPVisionModel .from_pretrained (model_name ) | |
| self .image_processor =CLIPImageProcessor .from_pretrained (model_name ) | |
| self .hidden_size =self .vision_model .config .hidden_size | |
| print (f" 📎 Using CLIP") | |
| print (f" ✅ Hidden size: {self .hidden_size }") | |
| if freeze : | |
| for param in self .vision_model .parameters (): | |
| param .requires_grad =False | |
| print (f" ❄️ Vision encoder backbone frozen") | |
| else : | |
| print (f" 🔥 Vision encoder backbone trainable") | |
| def forward (self ,pixel_values :torch .Tensor ,return_titok :bool =None )->torch .Tensor : | |
| """ | |
| Extract vision features from images with SOTA enhancements. | |
| Args: | |
| pixel_values: [B, C, H, W] tensor of images | |
| return_titok: Override for TiTok output (None uses self.use_titok) | |
| Returns: | |
| [B, num_tokens, hidden_size] tensor (TiTok) or | |
| [B, num_patches, hidden_size] tensor (standard) or | |
| [B, hidden_size] if use_pooled_output=True | |
| """ | |
| outputs =self .vision_model (pixel_values =pixel_values ) | |
| features =outputs .last_hidden_state | |
| if self .use_pooled_output : | |
| if hasattr (outputs ,'pooler_output')and outputs .pooler_output is not None : | |
| return outputs .pooler_output | |
| else : | |
| return features .mean (dim =1 ) | |
| batch_size ,num_patches ,hidden_size =features .shape | |
| patch_size =getattr (self .vision_model .config ,'patch_size',14 ) | |
| image_size =getattr (self .vision_model .config ,'image_size',384 ) | |
| if num_patches == (image_size // patch_size) ** 2 + 1: | |
|     # Drop the leading CLS token so the features form a square patch grid. | |
|     features = features[:, 1:] | |
|     num_patches = num_patches - 1 | |
| height =width =int (math .sqrt (num_patches )) | |
| if self .dual_stream_layers is not None : | |
| x_a =features | |
| x_b =features .clone () | |
| for layer in self .dual_stream_layers : | |
| x_a ,x_b =layer (x_a ,x_b ,height ,width ) | |
| features =(x_a +x_b )/2 | |
| use_titok_now =return_titok if return_titok is not None else self .use_titok | |
| if use_titok_now and self .titok is not None : | |
| features =self .titok (features ) | |
| return features | |
| def get_image_processor (self ): | |
| """Return the image processor for preprocessing.""" | |
| return self .image_processor | |
| @property | |
| def num_patches(self) -> int: | |
|     """Number of patches produced by the vision backbone.""" | |
|     config = self.vision_model.config | |
|     return (config.image_size // config.patch_size) ** 2 | |
| @property | |
| def image_size(self) -> int: | |
|     """Expected input image size.""" | |
|     return self.vision_model.config.image_size | |
| @property | |
| def output_tokens(self) -> int: | |
|     """Number of output tokens (accounting for TiTok compression).""" | |
|     if self.use_titok and self.titok is not None: | |
|         return self.titok.num_tokens | |
|     return self.num_patches | |
| SIGLIP_MODELS ={ | |
| "siglip-base":"google/siglip-base-patch16-224", | |
| "siglip-base-384":"google/siglip-base-patch16-384", | |
| "siglip-large":"google/siglip-large-patch16-256", | |
| "siglip-large-384":"google/siglip-large-patch16-384", | |
| "siglip-so400m":"google/siglip-so400m-patch14-384", | |
| "siglip-so400m-224":"google/siglip-so400m-patch14-224", | |
| "clip-base":"openai/clip-vit-base-patch16", | |
| "clip-large":"openai/clip-vit-large-patch14", | |
| } | |
| def get_vision_encoder ( | |
| model_key :str ="siglip-so400m", | |
| freeze :bool =False , | |
| use_dual_stream :bool =True , | |
| use_titok :bool =True , | |
| **kwargs | |
| )->VisionEncoder : | |
| """ | |
| Get a vision encoder by key name with SOTA enhancements. | |
| Args: | |
| model_key: Key from SIGLIP_MODELS or full model name | |
| freeze: Whether to freeze encoder backbone weights | |
| use_dual_stream: Enable dual-stream attention | |
| use_titok: Enable TiTok 1D tokenization | |
| **kwargs: Additional arguments for VisionEncoder | |
| Returns: | |
| VisionEncoder instance | |
| """ | |
| model_name =SIGLIP_MODELS .get (model_key ,model_key ) | |
| return VisionEncoder ( | |
| model_name =model_name , | |
| freeze =freeze , | |
| use_dual_stream =use_dual_stream , | |
| use_titok =use_titok , | |
| **kwargs | |
| ) | |
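| # Illustrative sketch: build the default encoder and run one image through it. | |
| # Downloading the SigLIP backbone requires network access. | |
| def _demo_vision_encoder(): | |
|     encoder = get_vision_encoder("siglip-so400m", freeze=True) | |
|     pixels = torch.randn(1, 3, encoder.image_size, encoder.image_size) | |
|     return encoder(pixels)  # [1, output_tokens, hidden_size] | |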
| # ============================================================================== | |
| # MODELS.ENCODERS.VIDEO | |
| # ============================================================================== | |
| EPS =1e-5 | |
| class TextTimestampAlignment (nn .Module ): | |
| """ | |
| Text-Timestamp Alignment: Precise timestamp-grounded event localization for stronger video temporal modeling. | |
| SOTA: Moves beyond T-RoPE by explicitly aligning text descriptions with video timestamps, | |
| enabling: | |
| - Precise temporal localization of events described in text | |
| - Better video captioning with accurate time references | |
| - Improved video question-answering with temporal reasoning | |
| - Enhanced video generation with temporal control | |
| Architecture: | |
| - Cross-attention between text features and frame-level video features | |
| - Learnable timestamp embeddings for each frame | |
| - Temporal alignment loss during training | |
| """ | |
| def __init__ (self ,hidden_size :int ,max_frames :int =64 ,num_heads :int =8 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .max_frames =max_frames | |
| self .num_heads =num_heads | |
| self .timestamp_embedding =nn .Embedding (max_frames ,hidden_size ) | |
| self .video_proj =nn .Linear (hidden_size ,hidden_size ) | |
| self .text_proj =nn .Linear (hidden_size ,hidden_size ) | |
| self .cross_attn =nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =num_heads , | |
| batch_first =True , | |
| dropout =0.1 , | |
| ) | |
| self .text_norm =nn .LayerNorm (hidden_size ) | |
| self .video_norm =nn .LayerNorm (hidden_size ) | |
| self .alignment_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //2 ), | |
| nn .GELU (), | |
| nn .Linear (hidden_size //2 ,1 ), | |
| ) | |
| self .output_proj =nn .Linear (hidden_size ,hidden_size ) | |
| def forward ( | |
| self , | |
| video_features :torch .Tensor , | |
| text_features :torch .Tensor , | |
| num_frames :int , | |
| return_alignment_scores :bool =False , | |
| )->Tuple [torch .Tensor ,Optional [torch .Tensor ]]: | |
| """ | |
| Align text with video timestamps. | |
| Args: | |
| video_features: [B, T*H*W, hidden_size] video features | |
| text_features: [B, text_len, hidden_size] text features | |
| num_frames: Number of frames in the video | |
| return_alignment_scores: Whether to return alignment scores for loss | |
| Returns: | |
| aligned_text: [B, text_len, hidden_size] text features aligned to video timestamps | |
| alignment_scores: Optional [B, text_len, T] alignment scores | |
| """ | |
| batch_size =video_features .shape [0 ] | |
| total_tokens =video_features .shape [1 ] | |
| spatial_tokens =total_tokens //num_frames | |
| timestamp_ids =torch .arange (num_frames ,device =video_features .device ) | |
| timestamp_embeds =self .timestamp_embedding (timestamp_ids ) | |
| timestamp_embeds =timestamp_embeds .unsqueeze (1 ).expand (-1 ,spatial_tokens ,-1 ) | |
| timestamp_embeds =timestamp_embeds .reshape (1 ,total_tokens ,-1 ) | |
| timestamp_embeds =timestamp_embeds .expand (batch_size ,-1 ,-1 ) | |
| video_feat =self .video_norm (self .video_proj (video_features )+timestamp_embeds ) | |
| text_feat =self .text_norm (self .text_proj (text_features )) | |
| aligned ,attn_weights =self .cross_attn (text_feat ,video_feat ,video_feat ) | |
| alignment_scores =None | |
| if return_alignment_scores : | |
| attn_reshaped =attn_weights .view (batch_size ,text_features .shape [1 ],num_frames ,spatial_tokens ) | |
| alignment_scores =attn_reshaped .mean (dim =-1 ) | |
| aligned_text =text_features +self .output_proj (aligned ) | |
| return aligned_text ,alignment_scores | |
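| # Illustrative sketch (assumed sizes): align an 8-frame clip with 16 spatial | |
| # tokens per frame against a 12-token caption; scores average over spatial positions. | |
| def _demo_timestamp_alignment(): | |
|     align = TextTimestampAlignment(hidden_size=512, max_frames=64) | |
|     video = torch.randn(2, 8 * 16, 512) | |
|     text = torch.randn(2, 12, 512) | |
|     aligned_text, scores = align(video, text, num_frames=8, return_alignment_scores=True) | |
|     assert aligned_text.shape == (2, 12, 512) and scores.shape == (2, 12, 8) | |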
| class AlphaBlender (nn .Module ): | |
| """ | |
| AlphaBlender operator from VidTok for temporal blending. | |
| Blends two inputs with a fixed alpha parameter (alpha is a plain float here, not learnable). | |
| """ | |
| def __init__ (self ,alpha :float =0.55 ): | |
| super ().__init__ () | |
| self .alpha =alpha | |
| def forward (self ,x1 :torch .Tensor ,x2 :torch .Tensor )->torch .Tensor : | |
| return self .alpha *x1 +(1 -self .alpha )*x2 | |
| class VidTokEncoder (nn .Module ): | |
| """ | |
| VidTok-style Video Encoder following Microsoft's VidTok architecture. | |
| SOTA: Implements the VidTok encoder with: | |
| - 3D convolutions for input and bottleneck (information fusion) | |
| - 2D convolutions for spatial downsampling (efficiency) | |
| - AlphaBlender + 1D convolutions for temporal downsampling | |
| - Layer normalization for stability | |
| Compresses video [B, C, T, H, W] -> latent [B, latent_dim, t, h, w] | |
| """ | |
| def __init__ ( | |
| self , | |
| in_channels :int =3 , | |
| latent_channels :int =4 , | |
| base_channels :int =64 , | |
| temporal_downsample :int =4 , | |
| spatial_downsample :int =8 , | |
| causal :bool =True , | |
| ): | |
| super ().__init__ () | |
| self .in_channels =in_channels | |
| self .latent_channels =latent_channels | |
| self .base_channels =base_channels | |
| self .temporal_downsample =temporal_downsample | |
| self .spatial_downsample =spatial_downsample | |
| self .causal =causal | |
| self .num_spatial_downs =int (math .log2 (spatial_downsample )) | |
| self .num_temporal_downs =int (math .log2 (temporal_downsample )) | |
| self .input_block =nn .Sequential ( | |
| nn .Conv3d (in_channels ,base_channels ,kernel_size =3 ,padding =1 ), | |
| nn .GroupNorm (8 ,base_channels ), | |
| nn .SiLU (), | |
| ) | |
| self .spatial_down_blocks =nn .ModuleList () | |
| ch =base_channels | |
| for i in range (self .num_spatial_downs ): | |
| out_ch =min (ch *2 ,512 ) | |
| self .spatial_down_blocks .append ( | |
| self ._make_spatial_down_block (ch ,out_ch ) | |
| ) | |
| ch =out_ch | |
| self .temporal_down_blocks =nn .ModuleList () | |
| for i in range (self .num_temporal_downs ): | |
| self .temporal_down_blocks .append ( | |
| self ._make_temporal_down_block (ch ) | |
| ) | |
| self .bottleneck =nn .Sequential ( | |
| nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ), | |
| nn .GroupNorm (8 ,ch ), | |
| nn .SiLU (), | |
| nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ), | |
| nn .GroupNorm (8 ,ch ), | |
| nn .SiLU (), | |
| ) | |
| self .to_latent =nn .Conv3d (ch ,latent_channels ,kernel_size =1 ) | |
| print (f" 🎬 VidTokEncoder: {in_channels }ch -> {latent_channels }ch latent") | |
| print (f" Spatial: {spatial_downsample }x down ({self .num_spatial_downs } stages)") | |
| print (f" Temporal: {temporal_downsample }x down ({self .num_temporal_downs } stages)") | |
| def _make_spatial_down_block (self ,in_ch :int ,out_ch :int )->nn .Module : | |
| """Create a spatial downsampling block using 2D convolutions.""" | |
| return nn .Sequential ( | |
| Rearrange3Dto2D (), | |
| nn .Conv2d (in_ch ,out_ch ,kernel_size =3 ,stride =2 ,padding =1 ), | |
| nn .GroupNorm (8 ,out_ch ), | |
| nn .SiLU (), | |
| nn .Conv2d (out_ch ,out_ch ,kernel_size =3 ,padding =1 ), | |
| nn .GroupNorm (8 ,out_ch ), | |
| nn .SiLU (), | |
| Rearrange2Dto3D (), | |
| ) | |
| def _make_temporal_down_block (self ,channels :int )->nn .Module : | |
| """Create a temporal downsampling block using AlphaBlender + 1D conv.""" | |
| return TemporalDownBlock (channels ,causal =self .causal ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| """ | |
| Encode video to latent space. | |
| Args: | |
| x: [B, C, T, H, W] input video | |
| Returns: | |
| [B, latent_channels, t, h, w] latent representation | |
| """ | |
| B ,C ,T ,H ,W =x .shape | |
| x =self .input_block (x ) | |
| for block in self .spatial_down_blocks : | |
| if hasattr (block [0 ],'set_temporal_dim'): | |
| block [0 ].set_temporal_dim (x .shape [2 ]) | |
| if hasattr (block [-1 ],'set_temporal_dim'): | |
| block [-1 ].set_temporal_dim (x .shape [2 ]) | |
| x =block (x ) | |
| for block in self .temporal_down_blocks : | |
| x =block (x ) | |
| x =self .bottleneck (x ) | |
| x =self .to_latent (x ) | |
| return x | |
| class VidTokDecoder (nn .Module ): | |
| """ | |
| VidTok-style Video Decoder following Microsoft's VidTok architecture. | |
| Reconstructs video from latent [B, latent_dim, t, h, w] -> [B, C, T, H, W] | |
| """ | |
| def __init__ ( | |
| self , | |
| out_channels :int =3 , | |
| latent_channels :int =4 , | |
| base_channels :int =64 , | |
| temporal_upsample :int =4 , | |
| spatial_upsample :int =8 , | |
| causal :bool =True , | |
| ): | |
| super ().__init__ () | |
| self .out_channels =out_channels | |
| self .latent_channels =latent_channels | |
| self .base_channels =base_channels | |
| self .temporal_upsample =temporal_upsample | |
| self .spatial_upsample =spatial_upsample | |
| self .causal =causal | |
| self .num_spatial_ups =int (math .log2 (spatial_upsample )) | |
| self .num_temporal_ups =int (math .log2 (temporal_upsample )) | |
| ch =min (base_channels *(2 **self .num_spatial_ups ),512 ) | |
| self .from_latent =nn .Conv3d (latent_channels ,ch ,kernel_size =1 ) | |
| self .bottleneck =nn .Sequential ( | |
| nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ), | |
| nn .GroupNorm (8 ,ch ), | |
| nn .SiLU (), | |
| nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ), | |
| nn .GroupNorm (8 ,ch ), | |
| nn .SiLU (), | |
| ) | |
| self .temporal_up_blocks =nn .ModuleList () | |
| for i in range (self .num_temporal_ups ): | |
| self .temporal_up_blocks .append ( | |
| TemporalUpBlock (ch ,causal =self .causal ) | |
| ) | |
| self .spatial_up_blocks =nn .ModuleList () | |
| for i in range (self .num_spatial_ups ): | |
| out_ch =max (ch //2 ,base_channels ) | |
| self .spatial_up_blocks .append ( | |
| self ._make_spatial_up_block (ch ,out_ch ) | |
| ) | |
| ch =out_ch | |
| self .output_block =nn .Sequential ( | |
| nn .Conv3d (ch ,out_channels ,kernel_size =3 ,padding =1 ), | |
| nn .Tanh (), | |
| ) | |
| print (f" 🎬 VidTokDecoder: {latent_channels }ch latent -> {out_channels }ch") | |
| def _make_spatial_up_block (self ,in_ch :int ,out_ch :int )->nn .Module : | |
| """Create a spatial upsampling block using 2D convolutions.""" | |
| return nn .Sequential ( | |
| Rearrange3Dto2D (), | |
| nn .ConvTranspose2d (in_ch ,out_ch ,kernel_size =4 ,stride =2 ,padding =1 ), | |
| nn .GroupNorm (8 ,out_ch ), | |
| nn .SiLU (), | |
| nn .Conv2d (out_ch ,out_ch ,kernel_size =3 ,padding =1 ), | |
| nn .GroupNorm (8 ,out_ch ), | |
| nn .SiLU (), | |
| Rearrange2Dto3D (), | |
| ) | |
| def forward (self ,z :torch .Tensor )->torch .Tensor : | |
| """ | |
| Decode latent to video. | |
| Args: | |
| z: [B, latent_channels, t, h, w] latent representation | |
| Returns: | |
| [B, C, T, H, W] reconstructed video | |
| """ | |
| x =self .from_latent (z ) | |
| x =self .bottleneck (x ) | |
| for block in self .temporal_up_blocks : | |
| x =block (x ) | |
| for block in self.spatial_up_blocks: | |
|     # Mirror the encoder: tell the rearrange modules the current temporal length, | |
|     # otherwise Rearrange2Dto3D defaults to T=1 and folds time into the batch dim. | |
|     if hasattr(block[0], 'set_temporal_dim'): | |
|         block[0].set_temporal_dim(x.shape[2]) | |
|     if hasattr(block[-1], 'set_temporal_dim'): | |
|         block[-1].set_temporal_dim(x.shape[2]) | |
|     x = block(x) | |
| x =self .output_block (x ) | |
| return x | |
| class Rearrange3Dto2D (nn .Module ): | |
| """Reshape [B, C, T, H, W] -> [B*T, C, H, W] for 2D operations.""" | |
| def __init__ (self ): | |
| super ().__init__ () | |
| self .temporal_dim =None | |
| def set_temporal_dim (self ,t :int ): | |
| self .temporal_dim =t | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| B ,C ,T ,H ,W =x .shape | |
| self .temporal_dim =T | |
| return x .permute (0 ,2 ,1 ,3 ,4 ).reshape (B *T ,C ,H ,W ) | |
| class Rearrange2Dto3D (nn .Module ): | |
| """Reshape [B*T, C, H, W] -> [B, C, T, H, W] after 2D operations.""" | |
| def __init__ (self ): | |
| super ().__init__ () | |
| self .temporal_dim =None | |
| def set_temporal_dim (self ,t :int ): | |
| self .temporal_dim =t | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| BT ,C ,H ,W =x .shape | |
| T =self .temporal_dim if self .temporal_dim else 1 | |
| B =BT //T | |
| return x .reshape (B ,T ,C ,H ,W ).permute (0 ,2 ,1 ,3 ,4 ) | |
| class TemporalDownBlock (nn .Module ): | |
| """Temporal downsampling using AlphaBlender + 1D conv (VidTok style).""" | |
| def __init__ (self ,channels :int ,causal :bool =True ): | |
| super ().__init__ () | |
| self .channels =channels | |
| self .causal =causal | |
| self.alpha_blender = AlphaBlender()  # instantiated for API parity; not used in this forward pass | |
| self.temporal_conv = nn.Conv1d(channels, channels, kernel_size=2, stride=2, padding=0) | |
| self .norm =nn .GroupNorm (8 ,channels ) | |
| self .act =nn .SiLU () | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| """ | |
| Args: | |
| x: [B, C, T, H, W] | |
| Returns: | |
| [B, C, T//2, H, W] | |
| """ | |
| B ,C ,T ,H ,W =x .shape | |
| x =x .permute (0 ,3 ,4 ,1 ,2 ).reshape (B *H *W ,C ,T ) | |
| x =self .temporal_conv (x ) | |
| x =self .norm (x .unsqueeze (-1 )).squeeze (-1 ) | |
| x =self .act (x ) | |
| T_new =x .shape [2 ] | |
| x =x .reshape (B ,H ,W ,C ,T_new ).permute (0 ,3 ,4 ,1 ,2 ) | |
| return x | |
| class TemporalUpBlock (nn .Module ): | |
| """Temporal upsampling using AlphaBlender + 1D conv (VidTok style).""" | |
| def __init__ (self ,channels :int ,causal :bool =True ): | |
| super ().__init__ () | |
| self .channels =channels | |
| self .causal =causal | |
| self.alpha_blender = AlphaBlender()  # instantiated for API parity; not used in this forward pass | |
| self .temporal_conv =nn .ConvTranspose1d (channels ,channels ,kernel_size =2 ,stride =2 ) | |
| self .norm =nn .GroupNorm (8 ,channels ) | |
| self .act =nn .SiLU () | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| """ | |
| Args: | |
| x: [B, C, T, H, W] | |
| Returns: | |
| [B, C, T*2, H, W] | |
| """ | |
| B ,C ,T ,H ,W =x .shape | |
| x =x .permute (0 ,3 ,4 ,1 ,2 ).reshape (B *H *W ,C ,T ) | |
| x =self .temporal_conv (x ) | |
| x =self .norm (x .unsqueeze (-1 )).squeeze (-1 ) | |
| x =self .act (x ) | |
| T_new =x .shape [2 ] | |
| x =x .reshape (B ,H ,W ,C ,T_new ).permute (0 ,3 ,4 ,1 ,2 ) | |
| return x | |
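| # Illustrative sketch: the down block halves T and the up block restores it; | |
| # GroupNorm(8, C) requires the channel count to be divisible by 8. | |
| def _demo_temporal_blocks(): | |
|     x = torch.randn(1, 64, 8, 4, 4)  # [B, C, T, H, W] | |
|     y = TemporalDownBlock(64)(x) | |
|     assert y.shape == (1, 64, 4, 4, 4) | |
|     assert TemporalUpBlock(64)(y).shape == (1, 64, 8, 4, 4) | |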
| class VidTokTokenizer (nn .Module ): | |
| """ | |
| VidTok-style Video Tokenizer (3D VAE) following Microsoft's VidTok architecture. | |
| SOTA: Full encoder-decoder architecture for video compression to latent space. | |
| - Efficient 2D+1D architecture (separates spatial and temporal processing) | |
| - AlphaBlender for temporal blending | |
| - Supports both continuous (KL) and discrete (FSQ) tokenization | |
| - Causal mode for streaming/autoregressive applications | |
| Compresses video [B, C, T, H, W] -> latent [B, latent_dim, t, h, w] | |
| """ | |
| def __init__ ( | |
| self , | |
| in_channels :int =3 , | |
| latent_channels :int =4 , | |
| base_channels :int =64 , | |
| temporal_compression :int =4 , | |
| spatial_compression :int =8 , | |
| causal :bool =True , | |
| use_fsq :bool =False , | |
| fsq_levels :int =8 , | |
| ): | |
| super ().__init__ () | |
| self .in_channels =in_channels | |
| self .latent_channels =latent_channels | |
| self .temporal_compression =temporal_compression | |
| self .spatial_compression =spatial_compression | |
| self .causal =causal | |
| self .use_fsq =use_fsq | |
| self .fsq_levels =fsq_levels | |
| self .encoder =VidTokEncoder ( | |
| in_channels =in_channels , | |
| latent_channels =latent_channels *2 if not use_fsq else latent_channels , | |
| base_channels =base_channels , | |
| temporal_downsample =temporal_compression , | |
| spatial_downsample =spatial_compression , | |
| causal =causal , | |
| ) | |
| self .decoder =VidTokDecoder ( | |
| out_channels =in_channels , | |
| latent_channels =latent_channels , | |
| base_channels =base_channels , | |
| temporal_upsample =temporal_compression , | |
| spatial_upsample =spatial_compression , | |
| causal =causal , | |
| ) | |
| print (f" 🎬 VidTokTokenizer: {temporal_compression }x{spatial_compression }x{spatial_compression } compression") | |
| print (f" Mode: {'FSQ (discrete)'if use_fsq else 'KL (continuous)'}, Causal: {causal }") | |
| def encode (self ,x :torch .Tensor )->torch .Tensor : | |
| """Encode video to latent space.""" | |
| h =self .encoder (x ) | |
| if self .use_fsq : | |
| return self ._fsq_quantize (h ) | |
| else : | |
| mean ,logvar =h .chunk (2 ,dim =1 ) | |
| std =torch .exp (0.5 *logvar ) | |
| eps =torch .randn_like (std ) | |
| return mean +eps *std | |
| def decode (self ,z :torch .Tensor )->torch .Tensor : | |
| """Decode latent to video.""" | |
| return self .decoder (z ) | |
| def _fsq_quantize (self ,z :torch .Tensor )->torch .Tensor : | |
| """Finite Scalar Quantization - quantize each channel independently.""" | |
| z =torch .tanh (z ) | |
| z =torch .round ((z +1 )*(self .fsq_levels -1 )/2 )*2 /(self .fsq_levels -1 )-1 | |
| return z | |
| def forward (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| """ | |
| Full forward pass: encode then decode. | |
| Args: | |
| x: [B, C, T, H, W] input video | |
| Returns: | |
| Tuple of (reconstructed video, latent representation) | |
| """ | |
| z =self .encode (x ) | |
| x_recon =self .decode (z ) | |
| return x_recon ,z | |
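| # Illustrative sketch (assumed clip size): a 4x-temporal / 8x-spatial round trip, | |
| # so an 8x64x64 clip compresses to a 2x8x8 latent with 4 channels. | |
| def _demo_vidtok(): | |
|     vae = VidTokTokenizer(in_channels=3, latent_channels=4, base_channels=64) | |
|     recon, z = vae(torch.randn(1, 3, 8, 64, 64)) | |
|     assert z.shape == (1, 4, 2, 8, 8) | |
|     assert recon.shape == (1, 3, 8, 64, 64) | |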
| class RoPE3DEncoder (nn .Module ): | |
| """ | |
| 3D Rotary Position Embedding for (x, y, t) dimensions. | |
| Matches the 3D-RoPE in video generator for seamless integration. | |
| """ | |
| def __init__ (self ,dim :int ,max_height :int =64 ,max_width :int =64 ,max_frames :int =32 ,base :float =10000.0 ): | |
| super ().__init__ () | |
| self .dim =dim | |
| self .max_height =max_height | |
| self .max_width =max_width | |
| self .max_frames =max_frames | |
| self .base =base | |
| dim_per_axis =dim //3 | |
| self .dim_x =dim_per_axis | |
| self .dim_y =dim_per_axis | |
| self .dim_t =dim -2 *dim_per_axis | |
| inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x )) | |
| inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y )) | |
| inv_freq_t =1.0 /(base **(torch .arange (0 ,self .dim_t ,2 ,dtype =torch .float32 )/self .dim_t )) | |
| self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False ) | |
| self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False ) | |
| self .register_buffer ('inv_freq_t',inv_freq_t ,persistent =False ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| device =x .device | |
| dtype =x .dtype | |
| pos_x =torch .arange (width ,device =device ,dtype =torch .float32 ) | |
| pos_y =torch .arange (height ,device =device ,dtype =torch .float32 ) | |
| pos_t =torch .arange (frames ,device =device ,dtype =torch .float32 ) | |
| freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device )) | |
| freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device )) | |
| freqs_t =torch .outer (pos_t ,self .inv_freq_t .to (device )) | |
| freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 ) | |
| freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 ) | |
| freqs_t =torch .cat ([freqs_t ,freqs_t ],dim =-1 ) | |
| cos_x =freqs_x .cos ().to (dtype ) | |
| sin_x =freqs_x .sin ().to (dtype ) | |
| cos_y =freqs_y .cos ().to (dtype ) | |
| sin_y =freqs_y .sin ().to (dtype ) | |
| cos_t =freqs_t .cos ().to (dtype ) | |
| sin_t =freqs_t .sin ().to (dtype ) | |
| cos_3d = torch.zeros(frames, height, width, self.dim, device=device, dtype=dtype) | |
| sin_3d = torch.zeros(frames, height, width, self.dim, device=device, dtype=dtype) | |
| # Broadcast the per-axis frequencies across the (t, y, x) grid instead of a triple loop. | |
| cos_3d[..., :self.dim_x] = cos_x.view(1, 1, width, -1) | |
| sin_3d[..., :self.dim_x] = sin_x.view(1, 1, width, -1) | |
| cos_3d[..., self.dim_x:self.dim_x + self.dim_y] = cos_y.view(1, height, 1, -1) | |
| sin_3d[..., self.dim_x:self.dim_x + self.dim_y] = sin_y.view(1, height, 1, -1) | |
| cos_3d[..., self.dim_x + self.dim_y:] = cos_t.view(frames, 1, 1, -1) | |
| sin_3d[..., self.dim_x + self.dim_y:] = sin_t.view(frames, 1, 1, -1) | |
| cos_3d =cos_3d .view (frames *height *width ,self .dim ) | |
| sin_3d =sin_3d .view (frames *height *width ,self .dim ) | |
| return cos_3d ,sin_3d | |
| def apply_rope_3d_encoder (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor : | |
| """Apply 3D rotary position embedding to tensor.""" | |
| x1 =x [...,:x .shape [-1 ]//2 ] | |
| x2 =x [...,x .shape [-1 ]//2 :] | |
| rotated =torch .cat ((-x2 ,x1 ),dim =-1 ) | |
| return x *cos +rotated *sin | |
| class TemporalExpertRouterEncoder (nn .Module ): | |
| """ | |
| Temporal-Aware Expert Router for video encoding. | |
| Routes tokens based on temporal context and motion patterns. | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_experts :int =4 ,top_k :int =2 ): | |
| super ().__init__ () | |
| self .num_experts =num_experts | |
| self .top_k =top_k | |
| self .temporal_proj =nn .Linear (hidden_size ,hidden_size ) | |
| self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) | |
| nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) | |
| def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| if temporal_context is not None : | |
| x =x +self .temporal_proj (temporal_context ) | |
| router_logits =self .gate (x ) | |
| router_probs =F .softmax (router_logits ,dim =-1 ,dtype =x .dtype ) | |
| top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 ) | |
| top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS ) | |
| return top_k_probs ,top_k_indices | |
| class VideoExpertEncoder (nn .Module ): | |
| """Single expert for video encoding with SwiGLU.""" | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ): | |
| super ().__init__ () | |
| self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) | |
| self .act_fn =nn .SiLU () | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x )) | |
| class TemporalMoELayerEncoder (nn .Module ): | |
| """ | |
| Temporal-Aware MoE Layer for video encoding. | |
| Uses motion-aware routing for expert selection. | |
| """ | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ,num_experts :int =4 ,top_k :int =2 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_experts =num_experts | |
| self.top_k = min(top_k, num_experts)  # guard: cannot route to more experts than exist | |
| self.router = TemporalExpertRouterEncoder(hidden_size, num_experts, self.top_k) | |
| self .experts =nn .ModuleList ([ | |
| VideoExpertEncoder (hidden_size ,intermediate_size ) | |
| for _ in range (num_experts ) | |
| ]) | |
| self .shared_expert =VideoExpertEncoder (hidden_size ,intermediate_size ) | |
| def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->torch .Tensor : | |
| batch_size ,seq_len ,hidden_size =x .shape | |
| x_flat =x .view (-1 ,hidden_size ) | |
| top_k_probs ,top_k_indices =self .router (x_flat ,temporal_context .view (-1 ,hidden_size )if temporal_context is not None else None ) | |
| output =torch .zeros_like (x_flat ) | |
| for expert_idx in range (self .num_experts ): | |
| expert =self .experts [expert_idx ] | |
| for k in range (self .top_k ): | |
| mask =(top_k_indices [:,k ]==expert_idx ) | |
| if mask .any (): | |
| expert_input =x_flat [mask ] | |
| expert_output =expert (expert_input ) | |
| weight =top_k_probs [mask ,k :k +1 ] | |
| output [mask ]=output [mask ]+weight *expert_output | |
| shared_output =self .shared_expert (x_flat ) | |
| output =output +shared_output | |
| return output .view (batch_size ,seq_len ,hidden_size ) | |
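| # --- Usage sketch for the temporal MoE layer (illustrative; commented out) --- | |
| # moe = TemporalMoELayerEncoder(hidden_size=64, intermediate_size=256, num_experts=4, top_k=2) | |
| # x = torch.randn(2, 10, 64) | |
| # y = moe(x)  # [2, 10, 64]: each token blends its top-2 routed experts plus the always-on shared expert | |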
| class Causal3DAttentionEncoder (nn .Module ): | |
| """ | |
| 3D Self-Attention with 3D-RoPE for video encoding. | |
| Non-causal by default (full attention over all positions); pass causal=True for autoregressive use. | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_frames :int =32 ,max_height :int =64 ,max_width :int =64 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_heads =num_heads | |
| self .head_dim =hidden_size //num_heads | |
| self .scale =self .head_dim **-0.5 | |
| self .to_qkv =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) | |
| self .to_out =nn .Linear (hidden_size ,hidden_size ,bias =False ) | |
| self .norm =nn .LayerNorm (hidden_size ) | |
| self .rope_3d =RoPE3DEncoder (self .head_dim ,max_height ,max_width ,max_frames ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =False )->torch .Tensor : | |
| batch_size ,seq_len ,_ =x .shape | |
| x_norm =self .norm (x ) | |
| qkv =self .to_qkv (x_norm ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim ) | |
| q ,k ,v =qkv .unbind (dim =2 ) | |
| cos ,sin =self .rope_3d (x ,height ,width ,frames ) | |
| # cos/sin are [seq, head_dim]; reshape to [1, 1, seq, head_dim] so they broadcast | |
| # over the [B, num_heads, seq, head_dim] layout produced by the transpose below. | |
| cos = cos.unsqueeze(0).unsqueeze(1) | |
| sin = sin.unsqueeze(0).unsqueeze(1) | |
| q =q .transpose (1 ,2 ) | |
| k =k .transpose (1 ,2 ) | |
| v =v .transpose (1 ,2 ) | |
| q =apply_rope_3d_encoder (q ,cos ,sin ) | |
| k =apply_rope_3d_encoder (k ,cos ,sin ) | |
| if causal : | |
| attn_output =F .scaled_dot_product_attention (q ,k ,v ,is_causal =True ) | |
| else : | |
| attn_output =F .scaled_dot_product_attention (q ,k ,v ) | |
| attn_output =attn_output .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size ) | |
| return self .to_out (attn_output ) | |
| class VideoEncoderBlock (nn .Module ): | |
| """Single block with 3D causal attention and temporal MoE FFN.""" | |
| def __init__ ( | |
| self , | |
| hidden_size :int , | |
| num_heads :int =8 , | |
| num_experts :int =4 , | |
| max_frames :int =32 , | |
| max_height :int =64 , | |
| max_width :int =64 , | |
| ): | |
| super ().__init__ () | |
| self .attn =Causal3DAttentionEncoder (hidden_size ,num_heads ,max_frames ,max_height ,max_width ) | |
| self .moe =TemporalMoELayerEncoder (hidden_size ,hidden_size *4 ,num_experts ) | |
| self .norm =nn .LayerNorm (hidden_size ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =False )->torch .Tensor : | |
| x =x +self .attn (x ,height ,width ,frames ,causal ) | |
| x =self .norm (x +self .moe (x )) | |
| return x | |
| class VideoTiTokTokenizer (nn .Module ): | |
| """ | |
| SOTA TiTok-style 1D Tokenizer for video features with temporal awareness. | |
| This compresses encoded video features (from vision encoder) to a smaller | |
| number of tokens, similar to how TiTokTokenizer works for images but with | |
| proper temporal modeling. | |
| SOTA Features: | |
| - Multi-layer transformer with temporal-aware attention | |
| - 3D positional encoding (spatial + temporal) | |
| - Hierarchical compression: spatial first, then temporal | |
| - Causal temporal attention for streaming compatibility | |
| - Gated cross-attention for selective feature extraction | |
| Note: This is different from VidTokTokenizer which is a 3D VAE for raw video compression. | |
| This tokenizer operates on already-encoded features, not raw pixels. | |
| Converts [B, T*H*W, hidden_size] -> [B, num_tokens, hidden_size] | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int , | |
| num_tokens :int =64 , | |
| num_patches :int =576 , | |
| max_frames :int =32 , | |
| num_layers :int =2 , | |
| num_heads :int =8 , | |
| dropout :float =0.1 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_tokens =num_tokens | |
| self .num_patches =num_patches | |
| self .max_frames =max_frames | |
| self .num_heads =num_heads | |
| self .patches_per_frame =num_patches //max_frames if max_frames >0 else num_patches | |
| self .spatial_size =int (self .patches_per_frame **0.5 ) | |
| self .temporal_pos =nn .Parameter (torch .randn (1 ,max_frames ,1 ,hidden_size )*0.02 ) | |
| self .spatial_pos =nn .Parameter (torch .randn (1 ,1 ,self .patches_per_frame ,hidden_size )*0.02 ) | |
| self .input_norm =nn .LayerNorm (hidden_size ) | |
| self .input_proj =nn .Linear (hidden_size ,hidden_size ) | |
| self .num_temporal_tokens =min (num_tokens //4 ,max_frames ) | |
| self .num_content_tokens =num_tokens -self .num_temporal_tokens | |
| self .temporal_queries =nn .Parameter (torch .randn (1 ,self .num_temporal_tokens ,hidden_size )*0.02 ) | |
| self .content_queries =nn .Parameter (torch .randn (1 ,self .num_content_tokens ,hidden_size )*0.02 ) | |
| self .compress_layers =nn .ModuleList () | |
| for i in range (num_layers ): | |
| self .compress_layers .append (nn .ModuleDict ({ | |
| 'cross_attn':nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =num_heads , | |
| batch_first =True , | |
| dropout =dropout , | |
| ), | |
| 'cross_gate':nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size ), | |
| nn .Sigmoid (), | |
| ), | |
| 'cross_norm':nn .LayerNorm (hidden_size ), | |
| 'self_attn':nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =num_heads , | |
| batch_first =True , | |
| dropout =dropout , | |
| ), | |
| 'self_norm':nn .LayerNorm (hidden_size ), | |
| 'ffn':nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size *4 ), | |
| nn .GELU (), | |
| nn .Dropout (dropout ), | |
| nn .Linear (hidden_size *4 ,hidden_size ), | |
| nn .Dropout (dropout ), | |
| ), | |
| 'ffn_norm':nn .LayerNorm (hidden_size ), | |
| })) | |
| self .fusion_attn =nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =num_heads , | |
| batch_first =True , | |
| dropout =dropout , | |
| ) | |
| self .fusion_norm =nn .LayerNorm (hidden_size ) | |
| self .output_proj =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size ), | |
| nn .GELU (), | |
| nn .Linear (hidden_size ,hidden_size ), | |
| ) | |
| self .output_norm =nn .LayerNorm (hidden_size ) | |
| print (f" 🎬 VideoTiTokTokenizer: {num_patches } patches -> {num_tokens } tokens") | |
| print (f" Temporal tokens: {self .num_temporal_tokens }, Content tokens: {self .num_content_tokens }") | |
| print (f" Layers: {num_layers }, Heads: {num_heads }") | |
| def _load_from_state_dict (self ,state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs ): | |
| """Production-grade hook to handle dynamic frame counts and token counts when loading checkpoints.""" | |
| t_pos_key =prefix +'temporal_pos' | |
| if t_pos_key in state_dict : | |
| ckpt_pos =state_dict [t_pos_key ] | |
| if ckpt_pos .shape !=self .temporal_pos .shape : | |
| print (f" ⚠️ VideoTiTokTokenizer: Interpolating {t_pos_key } from {ckpt_pos .shape [1 ]} to {self .max_frames } frames.") | |
| ckpt_pos =ckpt_pos .squeeze (2 ).transpose (1 ,2 ) | |
| resized =F .interpolate (ckpt_pos ,size =self .max_frames ,mode ='linear',align_corners =False ) | |
| state_dict [t_pos_key ]=resized .transpose (1 ,2 ).unsqueeze (2 ) | |
| t_query_key =prefix +'temporal_queries' | |
| if t_query_key in state_dict : | |
| ckpt_query =state_dict [t_query_key ] | |
| if ckpt_query .shape !=self .temporal_queries .shape : | |
| print (f" ⚠️ VideoTiTokTokenizer: Interpolating {t_query_key } from {ckpt_query .shape [1 ]} to {self .num_temporal_tokens } tokens.") | |
| ckpt_query =ckpt_query .transpose (1 ,2 ) | |
| resized =F .interpolate (ckpt_query ,size =self .num_temporal_tokens ,mode ='linear',align_corners =False ) | |
| state_dict [t_query_key ]=resized .transpose (1 ,2 ) | |
| c_query_key =prefix +'content_queries' | |
| if c_query_key in state_dict : | |
| ckpt_query =state_dict [c_query_key ] | |
| if ckpt_query .shape !=self .content_queries .shape : | |
| print (f" ⚠️ VideoTiTokTokenizer: Interpolating {c_query_key } from {ckpt_query .shape [1 ]} to {self .num_content_tokens } tokens.") | |
| ckpt_query =ckpt_query .transpose (1 ,2 ) | |
| resized =F .interpolate (ckpt_query ,size =self .num_content_tokens ,mode ='linear',align_corners =False ) | |
| state_dict [c_query_key ]=resized .transpose (1 ,2 ) | |
| super ()._load_from_state_dict (state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs ) | |
| def _add_3d_pos_encoding (self ,x :torch .Tensor ,num_frames :int ,patches_per_frame :int )->torch .Tensor : | |
| """Add 3D positional encoding (temporal + spatial).""" | |
| B ,seq_len ,D =x .shape | |
| x =x .reshape (B ,num_frames ,patches_per_frame ,D ) | |
| temporal_pos =self .temporal_pos [:,:num_frames ,:,:] | |
| x =x +temporal_pos | |
| spatial_pos =self .spatial_pos [:,:,:patches_per_frame ,:] | |
| x =x +spatial_pos | |
| return x .reshape (B ,seq_len ,D ) | |
| def forward (self ,x :torch .Tensor ,num_frames :int =None )->torch .Tensor : | |
| """ | |
| Compress video patch features to TiTok-style 1D tokens. | |
| Args: | |
| x: [B, T*H*W, hidden_size] video patch features (flattened spatial-temporal) | |
| or [B, T, H*W, hidden_size] video patch features per frame | |
| num_frames: Number of frames (optional, for temporal embedding) | |
| Returns: | |
| [B, num_tokens, hidden_size] compressed token features | |
| """ | |
| batch_size =x .shape [0 ] | |
| if x .dim ()==4 : | |
| B ,T ,HW ,D =x .shape | |
| x =x .reshape (B ,T *HW ,D ) | |
| num_frames =T | |
| patches_per_frame =HW | |
| else : | |
| seq_len =x .shape [1 ] | |
| if num_frames is None : | |
| num_frames =min (self .max_frames ,seq_len //self .patches_per_frame ) | |
| num_frames =max (1 ,num_frames ) | |
| patches_per_frame =seq_len //num_frames if num_frames >0 else seq_len | |
| x =self .input_norm (x ) | |
| x =self .input_proj (x ) | |
| x =self ._add_3d_pos_encoding (x ,num_frames ,patches_per_frame ) | |
| temporal_queries =self .temporal_queries [:,:min (self .num_temporal_tokens ,num_frames ),:].expand (batch_size ,-1 ,-1 ) | |
| content_queries =self .content_queries .expand (batch_size ,-1 ,-1 ) | |
| queries =torch .cat ([temporal_queries ,content_queries ],dim =1 ) | |
| for layer in self .compress_layers : | |
| cross_out ,_ =layer ['cross_attn'](queries ,x ,x ) | |
| gate =layer ['cross_gate'](queries ) | |
| queries =layer ['cross_norm'](queries +gate *cross_out ) | |
| self_out ,_ =layer ['self_attn'](queries ,queries ,queries ) | |
| queries =layer ['self_norm'](queries +self_out ) | |
| ffn_out =layer ['ffn'](queries ) | |
| queries =layer ['ffn_norm'](queries +ffn_out ) | |
| actual_temporal =temporal_queries .shape [1 ] | |
| temporal_tokens =queries [:,:actual_temporal ,:] | |
| content_tokens =queries [:,actual_temporal :,:] | |
| fused ,_ =self .fusion_attn (content_tokens ,temporal_tokens ,temporal_tokens ) | |
| content_tokens =self .fusion_norm (content_tokens +fused ) | |
| tokens =torch .cat ([temporal_tokens ,content_tokens ],dim =1 ) | |
| if tokens .shape [1 ]<self .num_tokens : | |
| pad_size =self .num_tokens -tokens .shape [1 ] | |
| pad_tokens =self .content_queries [:,:pad_size ,:].expand (batch_size ,-1 ,-1 ) | |
| tokens =torch .cat ([tokens ,pad_tokens ],dim =1 ) | |
| elif tokens .shape [1 ]>self .num_tokens : | |
| tokens =tokens [:,:self .num_tokens ,:] | |
| tokens =self .output_proj (tokens ) | |
| tokens =self .output_norm (tokens ) | |
| return tokens | |
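| # --- Usage sketch for VideoTiTokTokenizer (illustrative; commented out) --- | |
| # tok = VideoTiTokTokenizer(hidden_size=64, num_tokens=64, num_patches=8 * 16, max_frames=8, num_layers=1, num_heads=4) | |
| # feats = torch.randn(2, 8 * 16, 64)        # 8 frames x 16 patches per frame | |
| # tokens = tok(feats, num_frames=8)         # [2, 64, 64]: 128 patch features compressed to 64 tokens | |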
| class VideoEncoder (nn .Module ): | |
| """ | |
| SOTA Video Encoder with 3D-RoPE, 3D Causal Attention, Temporal Expert Routing, and a VideoTiTokTokenizer. | |
| Features: | |
| - 3D-RoPE for flexible (x, y, t) positional encodings | |
| - 3D Causal Attention for temporal understanding | |
| - Temporal-Aware Expert Routing for motion patterns | |
| - VideoTiTokTokenizer for efficient 1D token compression (mirrors TiTokTokenizer for images) | |
| - Integrated with vision encoder backbone | |
| - FP16-native numerical stability | |
| """ | |
| def __init__ ( | |
| self , | |
| vision_encoder :VisionEncoder , | |
| max_frames :int =32 , | |
| num_encoder_layers :int =4 , | |
| num_experts :int =4 , | |
| use_3d_rope :bool =True , | |
| use_temporal_moe :bool =True , | |
| use_video_tokenizer :bool =True , | |
| num_video_tokens :int =64 , | |
| ): | |
| super ().__init__ () | |
| self .vision_encoder =vision_encoder | |
| self .max_frames =max_frames | |
| self .hidden_size =vision_encoder .hidden_size | |
| self .use_3d_rope =use_3d_rope | |
| self .use_temporal_moe =use_temporal_moe | |
| self .use_video_tokenizer =use_video_tokenizer | |
| self .image_size =getattr (vision_encoder ,'image_size',384 ) | |
| self .patch_size =getattr (vision_encoder .vision_model .config ,'patch_size',14 ) | |
| self .patches_per_side =self .image_size //self .patch_size | |
| self .num_spatial_tokens =self .patches_per_side **2 | |
| if use_3d_rope : | |
| # NOTE: retained for introspection/config export; each VideoEncoderBlock builds | |
| # its own per-head RoPE, so this module is not used in the forward pass. | |
| self .rope_3d =RoPE3DEncoder ( | |
| dim =self .hidden_size , | |
| max_height =self .patches_per_side , | |
| max_width =self .patches_per_side , | |
| max_frames =max_frames , | |
| ) | |
| print (f" 📐 3D-RoPE: (x,y,t) position encoding") | |
| else : | |
| self .rope_3d =None | |
| self .encoder_blocks =nn .ModuleList ([ | |
| VideoEncoderBlock ( | |
| hidden_size =self .hidden_size , | |
| num_heads =8 , | |
| num_experts =num_experts if use_temporal_moe else 1 , | |
| max_frames =max_frames , | |
| max_height =self .patches_per_side , | |
| max_width =self .patches_per_side , | |
| ) | |
| for _ in range (num_encoder_layers ) | |
| ]) | |
| print (f" 🎬 3D Causal Transformer: {num_encoder_layers } layers") | |
| if use_temporal_moe : | |
| print (f" 🎯 Temporal MoE: {num_experts } experts per layer") | |
| if use_video_tokenizer : | |
| self .vidtok =VideoTiTokTokenizer ( | |
| hidden_size =self .hidden_size , | |
| num_tokens =num_video_tokens , | |
| num_patches =self .num_spatial_tokens *max_frames , | |
| max_frames =max_frames , | |
| ) | |
| self .video_tokenizer =self .vidtok | |
| else : | |
| self .vidtok =None | |
| self .video_tokenizer =None | |
| self .temporal_pool_query =nn .Parameter (torch .randn (1 ,1 ,self .hidden_size )*0.02 ) | |
| self .temporal_pool_attn =nn .MultiheadAttention ( | |
| embed_dim =self .hidden_size , | |
| num_heads =8 , | |
| batch_first =True , | |
| dropout =0.1 , | |
| ) | |
| self .temporal_pool_norm =nn .LayerNorm (self .hidden_size ) | |
| self .frame_pos_embed =nn .Parameter (torch .randn (1 ,max_frames ,self .hidden_size )*0.02 ) | |
| print (f" 🎬 Video encoder: max {max_frames } frames (multi-scale enabled)") | |
| def _load_from_state_dict (self ,state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs ): | |
| """Production-grade hook to handle dynamic frame counts when loading checkpoints. | |
| Interpolates temporal embeddings if the checkpoint frames differ from max_frames. | |
| """ | |
| embed_key =prefix +'frame_pos_embed' | |
| if embed_key in state_dict : | |
| ckpt_embed =state_dict [embed_key ] | |
| if ckpt_embed .shape !=self .frame_pos_embed .shape : | |
| print (f" ⚠️ VideoEncoder: Interpolating {embed_key } from {ckpt_embed .shape [1 ]} to {self .max_frames } frames.") | |
| ckpt_embed =ckpt_embed .transpose (1 ,2 ) | |
| resized =F .interpolate (ckpt_embed ,size =self .max_frames ,mode ='linear',align_corners =False ) | |
| state_dict [embed_key ]=resized .transpose (1 ,2 ) | |
| super ()._load_from_state_dict (state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs ) | |
| def _extract_frame_features (self ,frames :torch .Tensor )->Tuple [torch .Tensor ,int ,int ]: | |
| """Extract per-frame features using vision encoder.""" | |
| batch_size ,num_frames =frames .shape [:2 ] | |
| frames_flat =frames .view (-1 ,*frames .shape [2 :]) | |
| if frames_flat .shape [-1 ]!=self .image_size or frames_flat .shape [-2 ]!=self .image_size : | |
| frames_flat =F .interpolate ( | |
| frames_flat , | |
| size =(self .image_size ,self .image_size ), | |
| mode ='bilinear', | |
| align_corners =False | |
| ) | |
| if not any (p .requires_grad for p in self .vision_encoder .parameters ()): | |
| with torch .no_grad (): | |
| frame_features =self .vision_encoder (frames_flat ,return_titok =False ) | |
| else : | |
| frame_features =self .vision_encoder (frames_flat ,return_titok =False ) | |
| return frame_features ,batch_size ,num_frames | |
| def forward ( | |
| self , | |
| frames :torch .Tensor , | |
| return_all_frames :bool =False , | |
| causal :bool =False , | |
| return_tokens :bool =False , | |
| )->torch .Tensor : | |
| """ | |
| Process video frames with 3D-RoPE and Causal Attention. | |
| Args: | |
| frames: [B, T, C, H, W] tensor of video frames | |
| return_all_frames: If True, return all frame features; else return pooled | |
| causal: If True, use causal attention (for autoregressive) | |
| return_tokens: If True, return VideoTokenizer compressed tokens | |
| Returns: | |
| If return_tokens: [B, num_tokens, hidden_size] video tokens | |
| If return_all_frames: [B, T, hidden_size] per-frame features | |
| Else: [B, hidden_size] pooled video representation | |
| """ | |
| frame_features ,batch_size ,num_frames =self ._extract_frame_features (frames ) | |
| _ ,num_patches ,hidden_size =frame_features .shape | |
| height =width =int (math .sqrt (num_patches )) | |
| frame_features =frame_features .view (batch_size ,num_frames ,num_patches ,hidden_size ) | |
| frame_features =frame_features +self .frame_pos_embed [:,:num_frames ].unsqueeze (2 ) | |
| x =frame_features .view (batch_size ,num_frames *num_patches ,hidden_size ) | |
| for block in self .encoder_blocks : | |
| x =block (x ,height ,width ,num_frames ,causal =causal ) | |
| if return_tokens and self .vidtok is not None : | |
| return self .vidtok (x ,num_frames ) | |
| if return_all_frames : | |
| x =x .view (batch_size ,num_frames ,num_patches ,hidden_size ) | |
| return x .mean (dim =2 ) | |
| else : | |
| query =self .temporal_pool_query .expand (batch_size ,-1 ,-1 ) | |
| pooled ,_ =self .temporal_pool_attn (query ,x ,x ) | |
| pooled =self .temporal_pool_norm (query +pooled ) | |
| return pooled .squeeze (1 ) | |
| def encode_frames_separately (self ,frames :torch .Tensor )->torch .Tensor : | |
| """ | |
| Encode frames without temporal attention (for generation conditioning). | |
| Args: | |
| frames: [B, T, C, H, W] tensor of video frames | |
| Returns: | |
| [B, T, hidden_size] tensor of frame features | |
| """ | |
| frame_features ,batch_size ,num_frames =self ._extract_frame_features (frames ) | |
| frame_features =frame_features .mean (dim =1 ) | |
| return frame_features .view (batch_size ,num_frames ,-1 ) | |
| def encode_with_spatial (self ,frames :torch .Tensor )->torch .Tensor : | |
| """ | |
| Encode frames preserving spatial structure (for video generation). | |
| Args: | |
| frames: [B, T, C, H, W] tensor of video frames | |
| Returns: | |
| [B, T, H, W, hidden_size] tensor of spatio-temporal features | |
| """ | |
| frame_features ,batch_size ,num_frames =self ._extract_frame_features (frames ) | |
| _ ,num_patches ,hidden_size =frame_features .shape | |
| height =width =int (math .sqrt (num_patches )) | |
| frame_features =frame_features .view (batch_size ,num_frames ,height ,width ,hidden_size ) | |
| return frame_features | |
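| # --- Usage sketch for VideoEncoder's three output modes. Illustrative and commented | |
| # out because it needs a VisionEncoder instance defined earlier in this file. --- | |
| # enc = VideoEncoder(vision_encoder, max_frames=8, num_video_tokens=64) | |
| # frames = torch.randn(1, 8, 3, 384, 384)           # [B, T, C, H, W] | |
| # pooled = enc(frames)                              # [1, hidden_size]   attention-pooled clip embedding | |
| # per_frame = enc(frames, return_all_frames=True)   # [1, 8, hidden_size] per-frame features | |
| # tokens = enc(frames, return_tokens=True)          # [1, 64, hidden_size] compressed video tokens | |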
| ============================================================================== | |
| MODELS.ENCODERS.AUDIO | |
| ============================================================================== | |
| EPS =1e-5 # numerical-stability epsilon, shared module-wide (also used by the video MoE router above) | |
| class RawWaveformTokenizer (nn .Module ): | |
| """ | |
| Raw Waveform Tokenizer - directly tokenizes audio waveforms without mel spectrograms. | |
| Uses a stack of strided 1D convolutions to progressively downsample the waveform | |
| into a compact feature sequence. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| num_codebooks :int =8 , | |
| codebook_size :int =1024 , | |
| sample_rate :int =16000 , | |
| hop_length :int =320 , | |
| num_conv_layers :int =6 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_codebooks =num_codebooks | |
| self .codebook_size =codebook_size | |
| self .sample_rate =sample_rate | |
| self .hop_length =hop_length # informational; the actual downsampling factor is the product of the conv strides | |
| self .conv_layers =nn .ModuleList () | |
| in_channels =1 | |
| channels =[32 ,64 ,128 ,256 ,512 ,hidden_size ] | |
| kernel_sizes =[7 ,5 ,5 ,3 ,3 ,3 ] | |
| strides =[2 ,2 ,2 ,2 ,2 ,2 ] | |
| for i in range (num_conv_layers ): | |
| out_channels =channels [i ]if i <len (channels )else hidden_size | |
| kernel_size =kernel_sizes [i ]if i <len (kernel_sizes )else 3 | |
| stride =strides [i ]if i <len (strides )else 2 | |
| self .conv_layers .append (nn .Sequential ( | |
| nn .Conv1d (in_channels ,out_channels ,kernel_size ,stride ,kernel_size //2 ), | |
| nn .GroupNorm (8 if out_channels >=8 else 1 ,out_channels ), | |
| nn .SiLU (), | |
| )) | |
| in_channels =out_channels | |
| self .codebooks =nn .ModuleList ([ | |
| nn .Embedding (codebook_size ,hidden_size ) | |
| for _ in range (num_codebooks ) | |
| ]) | |
| self .commitment_weight =0.25 | |
| self .output_proj =nn .Linear (hidden_size ,hidden_size ) | |
| print (f" 🎵 RawWaveformTokenizer: {num_codebooks } codebooks x {codebook_size } codes") | |
| def encode (self ,waveform :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| """ | |
| Encode waveform to continuous features. | |
| Args: | |
| waveform: [B, T] or [B, 1, T] raw audio waveform | |
| Returns: | |
| features: [B, T', hidden_size] encoded features | |
| indices: always None here; call quantize() to obtain codebook indices | |
| """ | |
| if waveform .dim ()==2 : | |
| waveform =waveform .unsqueeze (1 ) | |
| x =waveform | |
| for conv in self .conv_layers : | |
| x =conv (x ) | |
| x =x .transpose (1 ,2 ) | |
| return x ,None | |
| def quantize (self ,features :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: | |
| """ | |
| Residual Vector Quantization. | |
| Args: | |
| features: [B, T, hidden_size] continuous features | |
| Returns: | |
| quantized: [B, T, hidden_size] quantized features | |
| indices: [B, T, num_codebooks] codebook indices | |
| commitment_loss: scalar commitment loss | |
| """ | |
| batch_size ,seq_len ,_ =features .shape | |
| residual =features | |
| quantized =torch .zeros_like (features ) | |
| all_indices =[] | |
| total_commitment_loss =0.0 | |
| for codebook in self .codebooks : | |
| distances =torch .cdist (residual ,codebook .weight ) | |
| indices =distances .argmin (dim =-1 ) | |
| all_indices .append (indices ) | |
| quantized_step =codebook (indices ) | |
| quantized =quantized +residual +(quantized_step -residual ).detach () | |
| commitment_loss =F .mse_loss (residual .detach (),quantized_step ) | |
| total_commitment_loss =total_commitment_loss +commitment_loss | |
| residual =residual -quantized_step .detach () | |
| indices =torch .stack (all_indices ,dim =-1 ) | |
| commitment_loss =total_commitment_loss *self .commitment_weight | |
| return quantized ,indices ,commitment_loss | |
| def forward (self ,waveform :torch .Tensor ,quantize :bool =False )->Tuple [torch .Tensor ,Optional [torch .Tensor ]]: | |
| """ | |
| Forward pass. | |
| Args: | |
| waveform: [B, T] or [B, 1, T] raw audio | |
| quantize: Whether to apply vector quantization | |
| Returns: | |
| features: [B, T', hidden_size] encoded features | |
| commitment_loss: Optional commitment loss if quantize=True | |
| """ | |
| features ,_ =self .encode (waveform ) | |
| if quantize : | |
| features ,indices ,commitment_loss =self .quantize (features ) | |
| features =self .output_proj (features ) | |
| return features ,commitment_loss | |
| features =self .output_proj (features ) | |
| return features ,None | |
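| # --- Usage sketch (illustrative; commented out). With the default six stride-2 convs | |
| # the tokenizer downsamples 2**6 = 64x, so 1 s at 16 kHz -> 16000 / 64 = 250 frames. --- | |
| # tok = RawWaveformTokenizer(hidden_size=256) | |
| # wav = torch.randn(2, 16000)               # 1 second of 16 kHz audio | |
| # feats, _ = tok(wav)                       # [2, 250, 256] continuous features (no quantization) | |
| # q, idx, c_loss = tok.quantize(tok.encode(wav)[0])   # RVQ path: idx is [2, 250, 8] | |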
| class SnakeActivation (nn .Module ): | |
| """ | |
| Snake activation function from BigVGAN. | |
| x + (1/a) * sin^2(a * x) | |
| Better than ReLU/SiLU for audio generation - preserves periodicity. | |
| """ | |
| def __init__ (self ,channels :int ,alpha :float =1.0 ): | |
| super ().__init__ () | |
| self .alpha =nn .Parameter (torch .ones (1 ,channels ,1 )*alpha ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| return x +(1.0 /(self .alpha +1e-6 ))*torch .sin (self .alpha *x )**2 | |
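| # Quick check of the Snake identity at zero (illustrative; commented out): | |
| # snake(x) = x + (1/alpha) * sin(alpha * x)**2, and sin(0)**2 = 0 leaves zero input unchanged. | |
| # act = SnakeActivation(channels=4) | |
| # x = torch.zeros(1, 4, 10)                 # [B, C, T] | |
| # assert torch.allclose(act(x), x) | |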
| class ResidualBlock1D (nn .Module ): | |
| """Residual block with dilated convolutions for multi-receptive field.""" | |
| def __init__ (self ,channels :int ,kernel_size :int =3 ,dilation :int =1 ): | |
| super ().__init__ () | |
| padding =(kernel_size *dilation -dilation )//2 | |
| self .conv1 =nn .utils .parametrizations .weight_norm ( | |
| nn .Conv1d (channels ,channels ,kernel_size ,padding =padding ,dilation =dilation ) | |
| ) | |
| self .conv2 =nn .utils .parametrizations .weight_norm ( | |
| nn .Conv1d (channels ,channels ,kernel_size ,padding =kernel_size //2 ) | |
| ) | |
| self .activation =SnakeActivation (channels ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| residual =x | |
| x =self .activation (self .conv1 (x )) | |
| x =self .activation (self .conv2 (x )) | |
| return x +residual | |
| class MultiReceptiveFieldFusion (nn .Module ): | |
| """ | |
| Multi-Receptive Field Fusion (MRF) from HiFi-GAN. | |
| Processes input through multiple parallel residual stacks with different | |
| kernel sizes and dilations, then sums results. | |
| """ | |
| def __init__ (self ,channels :int ,kernel_sizes :List [int ]=[3 ,7 ,11 ], | |
| dilations :List [List [int ]]=[[1 ,3 ,5 ],[1 ,3 ,5 ],[1 ,3 ,5 ]]): | |
| super ().__init__ () | |
| self .num_kernels =len (kernel_sizes ) | |
| self .resblocks =nn .ModuleList () | |
| for k ,d_list in zip (kernel_sizes ,dilations ): | |
| blocks =nn .ModuleList ([ | |
| ResidualBlock1D (channels ,k ,d )for d in d_list | |
| ]) | |
| self .resblocks .append (blocks ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| out =None | |
| for blocks in self .resblocks : | |
| h =x | |
| for block in blocks : | |
| h =block (h ) | |
| out =h if out is None else out +h | |
| return out /self .num_kernels | |
| class RawWaveformDecoder (nn .Module ): | |
| """ | |
| SOTA Raw Waveform Decoder - BigVGAN/HiFi-GAN style architecture. | |
| Converts features directly to playable audio waveform without external vocoder. | |
| SOTA Features: | |
| - Snake activation (BigVGAN) - preserves audio periodicity | |
| - Multi-Receptive Field Fusion (HiFi-GAN) - captures patterns at multiple scales | |
| - Weight normalization - stable training | |
| - Efficient upsampling with careful kernel/stride ratios | |
| - Anti-aliased resampling | |
| - Streaming-capable architecture | |
| Speed optimizations: | |
| - Fewer layers with smarter architecture | |
| - Fused operations where possible | |
| - Efficient 256x total upsampling (vs 64x before) | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| sample_rate :int =16000 , | |
| upsample_rates :List [int ]=[8 ,8 ,2 ,2 ], | |
| upsample_kernel_sizes :List [int ]=[16 ,16 ,4 ,4 ], | |
| resblock_kernel_sizes :List [int ]=[3 ,7 ,11 ], | |
| resblock_dilations :List [List [int ]]=[[1 ,3 ,5 ],[1 ,3 ,5 ],[1 ,3 ,5 ]], | |
| initial_channels :int =512 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .sample_rate =sample_rate | |
| self .num_upsamples =len (upsample_rates ) | |
| self .input_proj =nn .utils .parametrizations .weight_norm ( | |
| nn .Conv1d (hidden_size ,initial_channels ,kernel_size =7 ,padding =3 ) | |
| ) | |
| self .upsamplers =nn .ModuleList () | |
| self .mrf_blocks =nn .ModuleList () | |
| channels =initial_channels | |
| for i ,(rate ,kernel )in enumerate (zip (upsample_rates ,upsample_kernel_sizes )): | |
| self .upsamplers .append ( | |
| nn .utils .parametrizations .weight_norm ( | |
| nn .ConvTranspose1d ( | |
| channels ,channels //2 , | |
| kernel_size =kernel ,stride =rate , | |
| padding =(kernel -rate )//2 | |
| ) | |
| ) | |
| ) | |
| channels =channels //2 | |
| self .mrf_blocks .append ( | |
| MultiReceptiveFieldFusion (channels ,resblock_kernel_sizes ,resblock_dilations ) | |
| ) | |
| self .final_activation =SnakeActivation (channels ) | |
| self .output_conv =nn .utils .parametrizations .weight_norm ( | |
| nn .Conv1d (channels ,1 ,kernel_size =7 ,padding =3 ) | |
| ) | |
| self .upsample_factor =1 | |
| for rate in upsample_rates : | |
| self .upsample_factor *=rate | |
| print (f" 🔊 RawWaveformDecoder (SOTA BigVGAN-style):") | |
| print (f" - Snake activation for audio periodicity") | |
| print (f" - Multi-Receptive Field Fusion") | |
| print (f" - {self .upsample_factor }x upsampling") | |
| print (f" - Weight normalized layers") | |
| def forward ( | |
| self , | |
| features :torch .Tensor , | |
| target_length :Optional [int ]=None , | |
| )->torch .Tensor : | |
| """ | |
| Decode features to raw waveform. | |
| Args: | |
| features: [B, T, hidden_size] encoded features | |
| target_length: Optional target waveform length (for matching input length) | |
| Returns: | |
| waveform: [B, T_audio] raw audio waveform in [-1, 1] | |
| """ | |
| x =features .transpose (1 ,2 ) | |
| x =self .input_proj (x ) | |
| for upsample ,mrf in zip (self .upsamplers ,self .mrf_blocks ): | |
| x =upsample (x ) | |
| x =mrf (x ) | |
| x =self .final_activation (x ) | |
| waveform =self .output_conv (x ) | |
| waveform =torch .tanh (waveform ) | |
| waveform =waveform .squeeze (1 ) | |
| if target_length is not None and waveform .shape [-1 ]!=target_length : | |
| waveform =F .interpolate ( | |
| waveform .unsqueeze (1 ), | |
| size =target_length , | |
| mode ='linear', | |
| align_corners =False | |
| ).squeeze (1 ) | |
| return waveform | |
| def decode_from_codes ( | |
| self , | |
| codes :torch .Tensor , | |
| codebooks :nn .ModuleList , | |
| target_length :Optional [int ]=None , | |
| )->torch .Tensor : | |
| """ | |
| Decode directly from codebook indices. | |
| Args: | |
| codes: [B, T, num_codebooks] codebook indices | |
| codebooks: List of nn.Embedding codebooks from encoder | |
| target_length: Optional target waveform length | |
| Returns: | |
| waveform: [B, T_audio] raw audio waveform | |
| """ | |
| features =torch .zeros ( | |
| codes .shape [0 ],codes .shape [1 ],codebooks [0 ].embedding_dim , | |
| device =codes .device ,dtype =codebooks [0 ].weight .dtype | |
| ) | |
| for i ,codebook in enumerate (codebooks ): | |
| features =features +codebook (codes [:,:,i ]) | |
| return self .forward (features ,target_length ) | |
| def stream_decode ( | |
| self , | |
| features :torch .Tensor , | |
| chunk_size :int =10 , | |
| )->torch .Tensor : | |
| """ | |
| Streaming decode for real-time speech synthesis. | |
| Processes features in chunks for low-latency output. | |
| Args: | |
| features: [B, T, hidden_size] encoded features | |
| chunk_size: Number of feature frames per chunk | |
| Returns: | |
| waveform: [B, T_audio] concatenated audio, decoded chunk-by-chunk (chunks are | |
| decoded independently, so boundaries may benefit from cross-fading) | |
| """ | |
| batch_size ,seq_len ,_ =features .shape | |
| audio_chunks =[] | |
| for start in range (0 ,seq_len ,chunk_size ): | |
| end =min (start +chunk_size ,seq_len ) | |
| chunk =features [:,start :end ,:] | |
| audio_chunk =self .forward (chunk ) | |
| audio_chunks .append (audio_chunk ) | |
| return torch .cat (audio_chunks ,dim =-1 ) | |
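| # --- Usage sketch (illustrative; commented out). The default upsample_rates give | |
| # 8 * 8 * 2 * 2 = 256x upsampling, so 50 feature frames -> 12800 samples (0.8 s at 16 kHz). --- | |
| # dec = RawWaveformDecoder(hidden_size=256, initial_channels=128) | |
| # feats = torch.randn(1, 50, 256) | |
| # wav = dec(feats)                          # [1, 12800], tanh-bounded to [-1, 1] | |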
| class SpeakerEncoder (nn .Module ): | |
| """ | |
| Zero-Shot Speaker Encoder for speaker cloning. | |
| Extracts speaker embeddings from reference audio that can be used | |
| to clone the speaker's voice characteristics. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =256 , | |
| output_size :int =256 , | |
| num_layers :int =3 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .output_size =output_size | |
| self .frame_encoder =nn .Sequential ( | |
| nn .Conv1d (80 ,hidden_size ,5 ,1 ,2 ), | |
| nn .ReLU (), | |
| nn .GroupNorm (1 ,hidden_size ), | |
| nn .Conv1d (hidden_size ,hidden_size ,5 ,1 ,2 ), | |
| nn .ReLU (), | |
| nn .GroupNorm (1 ,hidden_size ), | |
| nn .Conv1d (hidden_size ,hidden_size ,5 ,1 ,2 ), | |
| nn .ReLU (), | |
| nn .GroupNorm (1 ,hidden_size ), | |
| ) | |
| self .lstm =nn .LSTM ( | |
| hidden_size ,hidden_size , | |
| num_layers =num_layers , | |
| batch_first =True , | |
| bidirectional =True , | |
| ) | |
| self .attention =nn .Sequential ( | |
| nn .Linear (hidden_size *2 ,hidden_size ), | |
| nn .Tanh (), | |
| nn .Linear (hidden_size ,1 ), | |
| ) | |
| self .output_proj =nn .Linear (hidden_size *2 ,output_size ) | |
| print (f" 👤 SpeakerEncoder: {hidden_size }d -> {output_size }d speaker embedding") | |
| def forward (self ,mel_spectrogram :torch .Tensor )->torch .Tensor : | |
| """ | |
| Extract speaker embedding from mel spectrogram. | |
| Args: | |
| mel_spectrogram: [B, n_mels, T] mel spectrogram | |
| Returns: | |
| speaker_embedding: [B, output_size] speaker embedding | |
| """ | |
| x =self .frame_encoder (mel_spectrogram ) | |
| x =x .transpose (1 ,2 ) | |
| x ,_ =self .lstm (x ) | |
| attn_weights =self .attention (x ) | |
| attn_weights =F .softmax (attn_weights ,dim =1 ) | |
| x =(x *attn_weights ).sum (dim =1 ) | |
| speaker_embedding =self .output_proj (x ) | |
| speaker_embedding =F .normalize (speaker_embedding ,p =2 ,dim =-1 ) | |
| return speaker_embedding | |
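| # --- Usage sketch for zero-shot speaker verification (illustrative; commented out). | |
| # Embeddings are L2-normalized, so a dot product is cosine similarity. --- | |
| # spk = SpeakerEncoder(hidden_size=64, output_size=64, num_layers=1) | |
| # mel_a = torch.randn(1, 80, 120)           # [B, n_mels, T] | |
| # mel_b = torch.randn(1, 80, 120) | |
| # similarity = (spk(mel_a) * spk(mel_b)).sum(-1)   # in [-1, 1]; threshold for same-speaker decisions | |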
| class MonotonicAlignmentSearch (nn .Module ): | |
| """ | |
| Monotonic Alignment Search (MAS) for text-to-audio alignment. | |
| Implements both: | |
| 1. Hard MAS for inference (dynamic programming) | |
| 2. Soft/Fluid MAS for training (differentiable) | |
| """ | |
| def __init__ (self ,hidden_size :int =1024 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .alignment_proj =nn .Sequential ( | |
| nn .Linear (hidden_size *2 ,hidden_size ), | |
| nn .ReLU (), | |
| nn .Linear (hidden_size ,1 ), | |
| ) | |
| self .duration_predictor =nn .Sequential ( | |
| nn .Conv1d (hidden_size ,hidden_size ,3 ,padding =1 ), | |
| nn .ReLU (), | |
| nn .GroupNorm (1 ,hidden_size ), | |
| nn .Conv1d (hidden_size ,hidden_size ,3 ,padding =1 ), | |
| nn .ReLU (), | |
| nn .GroupNorm (1 ,hidden_size ), | |
| nn .Conv1d (hidden_size ,1 ,1 ), | |
| ) | |
| def hard_mas (self ,log_probs :torch .Tensor )->torch .Tensor : | |
| """ | |
| Hard Monotonic Alignment Search using dynamic programming. | |
| Args: | |
| log_probs: [B, T_text, T_audio] log alignment probabilities (assumes T_audio >= T_text) | |
| Returns: | |
| alignment: [B, T_text, T_audio] hard alignment matrix | |
| """ | |
| batch_size ,text_len ,audio_len =log_probs .shape | |
| device =log_probs .device | |
| Q =torch .full ((batch_size ,text_len ,audio_len ),float ('-inf'),device =device ) | |
| Q [:,0 ,0 ]=log_probs [:,0 ,0 ] | |
| for j in range (1 ,audio_len ): | |
| Q [:,0 ,j ]=Q [:,0 ,j -1 ]+log_probs [:,0 ,j ] | |
| for i in range (1 ,text_len ): | |
| Q [:,i ,i ]=Q [:,i -1 ,i -1 ]+log_probs [:,i ,i ] | |
| for j in range (i +1 ,audio_len ): | |
| Q [:,i ,j ]=torch .max (Q [:,i -1 ,j -1 ],Q [:,i ,j -1 ])+log_probs [:,i ,j ] | |
| alignment =torch .zeros_like (log_probs ) | |
| for b in range (batch_size ): | |
| i ,j =text_len -1 ,audio_len -1 | |
| while i >=0 and j >=0 : | |
| alignment [b ,i ,j ]=1 | |
| if i ==0 : | |
| j -=1 | |
| elif j ==0 : | |
| i -=1 | |
| elif Q [b ,i -1 ,j -1 ]>=Q [b ,i ,j -1 ]: | |
| i -=1 | |
| j -=1 | |
| else : | |
| j -=1 | |
| return alignment | |
| def soft_mas ( | |
| self , | |
| text_hidden :torch .Tensor , | |
| audio_hidden :torch .Tensor , | |
| temperature :float =1.0 , | |
| )->torch .Tensor : | |
| """ | |
| Soft/Differentiable Monotonic Alignment Search. | |
| Args: | |
| text_hidden: [B, T_text, hidden_size] text features | |
| audio_hidden: [B, T_audio, hidden_size] audio features | |
| temperature: Softmax temperature | |
| Returns: | |
| soft_alignment: [B, T_text, T_audio] soft alignment matrix | |
| """ | |
| batch_size ,text_len ,_ =text_hidden .shape | |
| audio_len =audio_hidden .shape [1 ] | |
| text_expanded =text_hidden .unsqueeze (2 ).expand (-1 ,-1 ,audio_len ,-1 ) | |
| audio_expanded =audio_hidden .unsqueeze (1 ).expand (-1 ,text_len ,-1 ,-1 ) | |
| combined =torch .cat ([text_expanded ,audio_expanded ],dim =-1 ) | |
| logits =self .alignment_proj (combined ).squeeze (-1 ) | |
| logits =logits /temperature | |
| position_bias =torch .arange (audio_len ,device =logits .device ).float () | |
| position_bias =position_bias .unsqueeze (0 ).unsqueeze (0 ) | |
| text_positions =torch .arange (text_len ,device =logits .device ).float () | |
| text_positions =text_positions .unsqueeze (0 ).unsqueeze (2 ) | |
| expected_pos =text_positions *(audio_len /text_len ) | |
| monotonic_bias =-0.1 *(position_bias -expected_pos ).abs () | |
| logits =logits +monotonic_bias | |
| soft_alignment =F .softmax (logits ,dim =-1 ) | |
| return soft_alignment | |
| def predict_durations (self ,text_hidden :torch .Tensor )->torch .Tensor : | |
| """ | |
| Predict durations for each text token. | |
| Args: | |
| text_hidden: [B, T_text, hidden_size] text features | |
| Returns: | |
| durations: [B, T_text] predicted durations | |
| """ | |
| x =text_hidden .transpose (1 ,2 ) | |
| durations =self .duration_predictor (x ).squeeze (1 ) | |
| durations =F .softplus (durations ) | |
| return durations | |
| def forward ( | |
| self , | |
| text_hidden :torch .Tensor , | |
| audio_hidden :Optional [torch .Tensor ]=None , | |
| use_hard :bool =False , | |
| )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| """ | |
| Compute alignment and durations. | |
| Args: | |
| text_hidden: [B, T_text, hidden_size] text features | |
| audio_hidden: [B, T_audio, hidden_size] audio features (optional for inference) | |
| use_hard: Use hard MAS instead of soft | |
| Returns: | |
| alignment: [B, T_text, T_audio] alignment matrix | |
| durations: [B, T_text] predicted durations | |
| """ | |
| durations =self .predict_durations (text_hidden ) | |
| if audio_hidden is None : | |
| return None ,durations | |
| if use_hard : | |
| text_norm =F .normalize (text_hidden ,dim =-1 ) | |
| audio_norm =F .normalize (audio_hidden ,dim =-1 ) | |
| log_probs =torch .bmm (text_norm ,audio_norm .transpose (1 ,2 )) # cosine similarities used as a log-prob surrogate | |
| alignment =self .hard_mas (log_probs ) | |
| else : | |
| alignment =self .soft_mas (text_hidden ,audio_hidden ) | |
| return alignment ,durations | |
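| # --- Usage sketch for MAS (illustrative; commented out) --- | |
| # mas = MonotonicAlignmentSearch(hidden_size=64) | |
| # text = torch.randn(1, 5, 64) | |
| # audio = torch.randn(1, 20, 64)            # T_audio >= T_text, as hard MAS assumes | |
| # align, dur = mas(text, audio)             # soft alignment [1, 5, 20], durations [1, 5] | |
| # _, dur_only = mas(text)                   # inference: durations only, alignment is None | |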
| class RotaryMultiHeadLatentAttention (nn .Module ): | |
| """ | |
| Rotary Multi-Head Latent Attention (RMLA). | |
| Combines: | |
| - Multi-Head Latent Attention (MLA) for compressed KV cache | |
| - Rotary Position Embeddings (RoPE) for position awareness | |
| - Efficient attention computation | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| num_heads :int =16 , | |
| num_kv_heads :int =4 , | |
| head_dim :int =64 , | |
| kv_lora_rank :int =256 , | |
| max_position_embeddings :int =8192 , | |
| dropout :float =0.1 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_heads =num_heads | |
| self .num_kv_heads =num_kv_heads | |
| self .head_dim =head_dim | |
| self .kv_lora_rank =kv_lora_rank | |
| self .num_key_value_groups =num_heads //num_kv_heads | |
| self .q_proj =nn .Linear (hidden_size ,num_heads *head_dim ,bias =False ) | |
| self .kv_a_proj =nn .Linear (hidden_size ,kv_lora_rank +head_dim ,bias =False ) | |
| self .kv_b_proj =nn .Linear (kv_lora_rank ,num_kv_heads *head_dim *2 ,bias =False ) | |
| self .kv_norm =nn .LayerNorm (kv_lora_rank ) | |
| self .o_proj =nn .Linear (num_heads *head_dim ,hidden_size ,bias =False ) | |
| self .rotary_emb =self ._create_rotary_embedding (head_dim ,max_position_embeddings ) | |
| self .dropout =nn .Dropout (dropout ) | |
| self .scale =head_dim **-0.5 | |
| def _create_rotary_embedding (self ,dim :int ,max_seq_len :int )->nn .Module : | |
| """Create rotary position embeddings.""" | |
| inv_freq =1.0 /(10000 **(torch .arange (0 ,dim ,2 ).float ()/dim )) | |
| self .register_buffer ('inv_freq',inv_freq ) | |
| t =torch .arange (max_seq_len ).float () | |
| freqs =torch .einsum ('i,j->ij',t ,inv_freq ) | |
| emb =torch .cat ([freqs ,freqs ],dim =-1 ) | |
| self .register_buffer ('cos_cached',emb .cos ()) | |
| self .register_buffer ('sin_cached',emb .sin ()) | |
| return None | |
| def _apply_rotary (self ,x :torch .Tensor ,seq_len :int )->torch .Tensor : | |
| """Apply rotary position embeddings.""" | |
| cos =self .cos_cached [:seq_len ].unsqueeze (0 ).unsqueeze (0 ) | |
| sin =self .sin_cached [:seq_len ].unsqueeze (0 ).unsqueeze (0 ) | |
| x1 ,x2 =x [...,:x .shape [-1 ]//2 ],x [...,x .shape [-1 ]//2 :] | |
| rotated =torch .cat ([-x2 ,x1 ],dim =-1 ) | |
| return x *cos .to (x .dtype )+rotated *sin .to (x .dtype ) | |
| def forward ( | |
| self , | |
| hidden_states :torch .Tensor , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| past_key_value :Optional [Tuple [torch .Tensor ,torch .Tensor ]]=None , | |
| use_cache :bool =False , | |
| )->Tuple [torch .Tensor ,Optional [Tuple [torch .Tensor ,torch .Tensor ]]]: | |
| """ | |
| Forward pass with RMLA. | |
| Args: | |
| hidden_states: [B, T, hidden_size] | |
| attention_mask: Optional attention mask | |
| past_key_value: Optional cached KV states | |
| use_cache: Whether to return updated cache | |
| Returns: | |
| output: [B, T, hidden_size] | |
| present_key_value: Optional updated cache | |
| """ | |
| batch_size ,seq_len ,_ =hidden_states .shape | |
| query =self .q_proj (hidden_states ) | |
| query =query .view (batch_size ,seq_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 ) | |
| kv_compressed =self .kv_a_proj (hidden_states ) | |
| kv_latent ,k_pe =kv_compressed .split ([self .kv_lora_rank ,self .head_dim ],dim =-1 ) # NOTE: the decoupled positional key k_pe is currently unused | |
| kv_latent =self .kv_norm (kv_latent ) | |
| kv =self .kv_b_proj (kv_latent ) | |
| key ,value =kv .split (self .num_kv_heads *self .head_dim ,dim =-1 ) | |
| key =key .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 ) | |
| value =value .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 ) | |
| query =self ._apply_rotary (query ,seq_len ) | |
| key =self ._apply_rotary (key ,seq_len ) | |
| if past_key_value is not None : | |
| past_key ,past_value =past_key_value | |
| key =torch .cat ([past_key ,key ],dim =2 ) | |
| value =torch .cat ([past_value ,value ],dim =2 ) | |
| present_key_value =(key ,value )if use_cache else None | |
| qk_scale =self .head_dim **-0.25 | |
| kv_len =key .shape [2 ] | |
| use_causal =(attention_mask is None and seq_len >1 and seq_len ==kv_len ) | |
| dropout_p =self .dropout .p if self .training else 0.0 | |
| output =F .scaled_dot_product_attention ( | |
| query *qk_scale , | |
| key *qk_scale , | |
| value , | |
| attn_mask =attention_mask , | |
| is_causal =use_causal , | |
| dropout_p =dropout_p , | |
| scale =1.0 , | |
| enable_gqa =(self .num_key_value_groups >1 ), | |
| ) | |
| output =output .transpose (1 ,2 ).contiguous ().view (batch_size ,seq_len ,-1 ) | |
| output =self .o_proj (output ) | |
| return output ,present_key_value | |
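| # --- Usage sketch for RMLA (illustrative; commented out; SDPA's enable_gqa flag needs | |
| # a recent PyTorch). With the class defaults the latent bottleneck emits kv_lora_rank | |
| # + head_dim = 256 + 64 = 320 values per token vs 2 * num_kv_heads * head_dim = 512 | |
| # for a direct KV projection; note this implementation caches the re-expanded key/value. --- | |
| # attn = RotaryMultiHeadLatentAttention(hidden_size=256, num_heads=8, num_kv_heads=2, | |
| #                                       head_dim=32, kv_lora_rank=64, max_position_embeddings=128) | |
| # x = torch.randn(2, 16, 256) | |
| # out, cache = attn(x, use_cache=True)      # out: [2, 16, 256]; cached k/v: [2, 2, 16, 32] each | |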
| class InContextAudioPrompting (nn .Module ): | |
| """ | |
| In-Context Audio Prompting for conditioning generation on reference audio. | |
| Allows the model to use a reference audio clip to guide the style, | |
| speaker characteristics, and prosody of generated audio. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| num_prompt_tokens :int =32 , | |
| num_heads :int =8 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_prompt_tokens =num_prompt_tokens | |
| self .prompt_tokens =nn .Parameter (torch .randn (1 ,num_prompt_tokens ,hidden_size )*0.02 ) | |
| self .cross_attn =nn .MultiheadAttention ( | |
| hidden_size ,num_heads , | |
| dropout =0.1 , | |
| batch_first =True , | |
| ) | |
| self .prompt_encoder =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size ,hidden_size ), | |
| ) | |
| self .gate =nn .Parameter (torch .zeros (1 )) | |
| self .norm =nn .LayerNorm (hidden_size ) | |
| def encode_prompt (self ,audio_features :torch .Tensor )->torch .Tensor : | |
| """ | |
| Encode reference audio into prompt tokens. | |
| Args: | |
| audio_features: [B, T, hidden_size] reference audio features | |
| Returns: | |
| prompt: [B, num_prompt_tokens, hidden_size] encoded prompt | |
| """ | |
| batch_size =audio_features .shape [0 ] | |
| prompt =self .prompt_tokens .expand (batch_size ,-1 ,-1 ) | |
| prompt ,_ =self .cross_attn (prompt ,audio_features ,audio_features ) | |
| prompt =self .prompt_encoder (prompt ) | |
| return prompt | |
| def forward ( | |
| self , | |
| hidden_states :torch .Tensor , | |
| prompt_features :Optional [torch .Tensor ]=None , | |
| audio_prompt :Optional [torch .Tensor ]=None , | |
| )->torch .Tensor : | |
| """ | |
| Apply in-context audio prompting. | |
| Args: | |
| hidden_states: [B, T, hidden_size] input features | |
| prompt_features: [B, num_prompt_tokens, hidden_size] pre-encoded prompt | |
| audio_prompt: [B, T_prompt, hidden_size] raw audio features to encode | |
| Returns: | |
| output: [B, T, hidden_size] conditioned features | |
| """ | |
| if prompt_features is None and audio_prompt is not None : | |
| prompt_features =self .encode_prompt (audio_prompt ) | |
| if prompt_features is None : | |
| return hidden_states | |
| attended ,_ =self .cross_attn (hidden_states ,prompt_features ,prompt_features ) | |
| gate =torch .sigmoid (self .gate ) | |
| output =hidden_states +gate *attended | |
| output =self .norm (output ) | |
| return output | |
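| # --- Usage sketch for in-context audio prompting (illustrative; commented out) --- | |
| # icp = InContextAudioPrompting(hidden_size=64, num_prompt_tokens=8, num_heads=4) | |
| # hidden = torch.randn(2, 30, 64)           # features being generated | |
| # ref = torch.randn(2, 50, 64)              # reference-audio features to imitate | |
| # out = icp(hidden, audio_prompt=ref)       # [2, 30, 64]; a learned sigmoid gate scales the prompt's influence | |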
| class ConvolutionModule (nn .Module ): | |
| """Conformer convolution module with gating.""" | |
| def __init__ (self ,channels :int ,kernel_size :int =31 ,dropout :float =0.1 ): | |
| super ().__init__ () | |
| self .layer_norm =nn .LayerNorm (channels ) | |
| self .pointwise_conv1 =nn .Conv1d (channels ,2 *channels ,kernel_size =1 ) | |
| self .depthwise_conv =nn .Conv1d ( | |
| channels ,channels ,kernel_size =kernel_size , | |
| padding =(kernel_size -1 )//2 ,groups =channels | |
| ) | |
| self .batch_norm =nn .GroupNorm (1 ,channels ) | |
| self .pointwise_conv2 =nn .Conv1d (channels ,channels ,kernel_size =1 ) | |
| self .dropout =nn .Dropout (dropout ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| """x: [B, T, C]""" | |
| x =self .layer_norm (x ) | |
| x =x .transpose (1 ,2 ) | |
| x =self .pointwise_conv1 (x ) | |
| x =F .glu (x ,dim =1 ) | |
| x =self .depthwise_conv (x ) | |
| x =self .batch_norm (x ) | |
| x =F .silu (x ) | |
| x =self .pointwise_conv2 (x ) | |
| x =self .dropout (x ) | |
| return x .transpose (1 ,2 ) | |
| class ConformerBlock (nn .Module ): | |
| """Single Conformer block with RMLA, feed-forward, and convolution.""" | |
| def __init__ ( | |
| self , | |
| d_model :int , | |
| num_heads :int =8 , | |
| ff_expansion :int =4 , | |
| conv_kernel_size :int =31 , | |
| dropout :float =0.1 , | |
| use_rmla :bool =True , | |
| ): | |
| super ().__init__ () | |
| self .use_rmla =use_rmla | |
| self .ff1_norm =nn .LayerNorm (d_model ) | |
| self .ff1 =nn .Sequential ( | |
| nn .Linear (d_model ,d_model *ff_expansion ), | |
| nn .SiLU (), | |
| nn .Dropout (dropout ), | |
| nn .Linear (d_model *ff_expansion ,d_model ), | |
| nn .Dropout (dropout ) | |
| ) | |
| if use_rmla : | |
| self .attn =RotaryMultiHeadLatentAttention ( | |
| hidden_size =d_model , | |
| num_heads =num_heads , | |
| num_kv_heads =max (1 ,num_heads //4 ), | |
| head_dim =d_model //num_heads , | |
| kv_lora_rank =d_model //4 , | |
| dropout =dropout , | |
| ) | |
| else : | |
| self .attn_norm =nn .LayerNorm (d_model ) | |
| self .attn =nn .MultiheadAttention (d_model ,num_heads ,dropout =dropout ,batch_first =True ) | |
| self .attn_dropout =nn .Dropout (dropout ) | |
| self .conv =ConvolutionModule (d_model ,conv_kernel_size ,dropout ) | |
| self .ff2_norm =nn .LayerNorm (d_model ) | |
| self .ff2 =nn .Sequential ( | |
| nn .Linear (d_model ,d_model *ff_expansion ), | |
| nn .SiLU (), | |
| nn .Dropout (dropout ), | |
| nn .Linear (d_model *ff_expansion ,d_model ), | |
| nn .Dropout (dropout ) | |
| ) | |
| self .final_norm =nn .LayerNorm (d_model ) | |
| def forward ( | |
| self , | |
| x :torch .Tensor , | |
| mask :Optional [torch .Tensor ]=None , | |
| past_key_value :Optional [Tuple ]=None , | |
| use_cache :bool =False , | |
| )->Tuple [torch .Tensor ,Optional [Tuple ]]: | |
| x =x +0.5 *self .ff1 (self .ff1_norm (x )) | |
| if self .use_rmla : | |
| attn_mask =None | |
| if mask is not None : | |
| attn_mask =mask .unsqueeze (1 ).unsqueeze (2 ) | |
| attn_mask =attn_mask .to (dtype =x .dtype ) | |
| attn_mask =attn_mask .masked_fill (attn_mask .bool (),float ('-inf')) | |
| attn_out ,present_kv =self .attn (x ,attention_mask =attn_mask ,past_key_value =past_key_value ,use_cache =use_cache ) | |
| else : | |
| attn_out ,_ =self .attn (self .attn_norm (x ),self .attn_norm (x ),self .attn_norm (x ),key_padding_mask =mask ) | |
| present_kv =None | |
| x =x +self .attn_dropout (attn_out ) | |
| x =x +self .conv (x ) | |
| x =x +0.5 *self .ff2 (self .ff2_norm (x )) | |
| return self .final_norm (x ),present_kv | |
| class AudioEncoder (nn .Module ): | |
| """ | |
| SOTA Audio Encoder with Raw Waveform Tokenization, RMLA, and Voice Enhancement. | |
| Features: | |
| - Raw waveform tokenization (no mel spectrogram) | |
| - Conformer blocks with RMLA | |
| - Zero-shot speaker encoding | |
| - In-context audio prompting | |
| - Gradient checkpointing support for memory efficiency | |
| Voice Enhancement Features (SOTA): | |
| - Prosody-aware EoT Prediction (interruption detection) | |
| - AVD Emotion Recognition (arousal/valence/dominance) | |
| - Dynamic Latent Vocalizations (singing/rapping) | |
| - Neural Sound Effects (beatboxing, breathing, expressions) | |
| - Speculative Decoding (mid-stream token rewriting) | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| n_mels :int =80 , | |
| max_audio_length :int =3000 , | |
| num_layers :int =6 , | |
| num_heads :int =8 , | |
| dropout :float =0.1 , | |
| use_raw_waveform :bool =True , | |
| enable_eot :bool =True , | |
| enable_emotion :bool =True , | |
| enable_singing :bool =True , | |
| enable_effects :bool =True , | |
| enable_speculative :bool =True , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .max_audio_length =max_audio_length | |
| self .use_raw_waveform =use_raw_waveform | |
| self .gradient_checkpointing =False | |
| self .enable_eot =enable_eot | |
| self .enable_emotion =enable_emotion | |
| self .enable_singing =enable_singing | |
| self .enable_effects =enable_effects | |
| self .enable_speculative =enable_speculative | |
| if use_raw_waveform : | |
| self .waveform_tokenizer =RawWaveformTokenizer ( | |
| hidden_size =hidden_size , | |
| num_codebooks =8 , | |
| codebook_size =1024 , | |
| ) | |
| else : | |
| self .waveform_tokenizer =None | |
| self .conv_subsample =nn .Sequential ( | |
| nn .Conv1d (n_mels ,hidden_size //2 ,kernel_size =3 ,stride =2 ,padding =1 ), | |
| nn .GELU (), | |
| nn .Conv1d (hidden_size //2 ,hidden_size ,kernel_size =3 ,stride =2 ,padding =1 ), | |
| nn .GELU (), | |
| ) | |
| self .speaker_encoder =SpeakerEncoder ( | |
| hidden_size =256 , | |
| output_size =hidden_size //4 , | |
| ) | |
| self .audio_prompting =InContextAudioPrompting ( | |
| hidden_size =hidden_size , | |
| num_prompt_tokens =32 , | |
| ) | |
| self .conformer_blocks =nn .ModuleList ([ | |
| ConformerBlock ( | |
| hidden_size ,num_heads , | |
| ff_expansion =4 , | |
| conv_kernel_size =31 , | |
| dropout =dropout , | |
| use_rmla =True , | |
| ) | |
| for _ in range (num_layers ) | |
| ]) | |
| self .output_proj =nn .Linear (hidden_size ,hidden_size ) | |
| if enable_eot : | |
| self .eot_predictor =ProsodyAwareEoTPredictor (hidden_size ,dropout =dropout ) | |
| else : | |
| self .eot_predictor =None | |
| if enable_emotion : | |
| self .emotion_recognizer =AVDEmotionRecognizer (hidden_size ,dropout =dropout ) | |
| else : | |
| self .emotion_recognizer =None | |
| if enable_singing : | |
| self .vocalizer =DynamicLatentVocalizer (hidden_size ) | |
| else : | |
| self .vocalizer =None | |
| if enable_effects : | |
| self .effects_generator =NeuralSoundEffectGenerator (hidden_size ) | |
| else : | |
| self .effects_generator =None | |
| if enable_speculative : | |
| self .speculative_decoder =SpeculativeAudioDecoder (hidden_size ) | |
| else : | |
| self .speculative_decoder =None | |
| print (f" 🎤 AudioEncoder (RMLA Conformer): {hidden_size }d, {num_layers } layers") | |
| if use_raw_waveform : | |
| print (f" - Raw Waveform Tokenizer enabled") | |
| print (f" - Zero-Shot Speaker Encoder enabled") | |
| print (f" - In-Context Audio Prompting enabled") | |
| print (f" - EoT/Interruption Detection: {enable_eot }") | |
| print (f" - Emotion Recognition (AVD): {enable_emotion }") | |
| print (f" - Singing/Rapping (Vocalizer): {enable_singing }") | |
| print (f" - Sound Effects Generator: {enable_effects }") | |
| print (f" - Speculative Decoding: {enable_speculative }") | |
| def gradient_checkpointing_enable (self ): | |
| """Enable gradient checkpointing to save memory during training.""" | |
| self .gradient_checkpointing =True | |
| if hasattr (self ,'waveform_tokenizer')and self .waveform_tokenizer is not None : | |
| if hasattr (self .waveform_tokenizer ,'gradient_checkpointing'): | |
| self .waveform_tokenizer .gradient_checkpointing =True | |
| if hasattr (self ,'speaker_encoder')and self .speaker_encoder is not None : | |
| if hasattr (self .speaker_encoder ,'gradient_checkpointing'): | |
| self .speaker_encoder .gradient_checkpointing =True | |
| def gradient_checkpointing_disable (self ): | |
| """Disable gradient checkpointing.""" | |
| self .gradient_checkpointing =False | |
| def forward ( | |
| self , | |
| audio_input :torch .Tensor , | |
| speaker_ref :Optional [torch .Tensor ]=None , | |
| audio_prompt :Optional [torch .Tensor ]=None , | |
| mask :Optional [torch .Tensor ]=None , | |
| return_eot :bool =False , | |
| return_emotion :bool =False , | |
| )->Tuple [torch .Tensor ,Optional [torch .Tensor ],Optional [dict ]]: | |
| """ | |
| Process audio to features with optional voice enhancement outputs. | |
| Args: | |
| audio_input: [B, T] raw waveform or [B, n_mels, T] mel spectrogram | |
| speaker_ref: [B, n_mels, T_ref] reference audio for speaker cloning | |
| audio_prompt: [B, T_prompt, hidden_size] audio prompt features | |
| mask: Optional attention mask | |
| return_eot: Whether to return EoT/interruption predictions | |
| return_emotion: Whether to return emotion/AVD predictions | |
| Returns: | |
| features: [B, T', hidden_size] audio features | |
| speaker_embedding: [B, hidden_size//4] speaker embedding (if speaker_ref provided) | |
| extras: dict with EoT/emotion predictions (if requested) | |
| """ | |
| commitment_loss =None | |
| if self .use_raw_waveform and self .waveform_tokenizer is not None : | |
| if audio_input .dim ()==3 and audio_input .shape [1 ]==1 : | |
| audio_input =audio_input .squeeze (1 ) | |
| elif audio_input .dim ()==3 : | |
| audio_input =audio_input .mean (dim =1 ) | |
| x ,commitment_loss =self .waveform_tokenizer (audio_input ) | |
| elif hasattr (self ,'conv_subsample')and self .conv_subsample is not None : | |
| if audio_input .dim ()==2 : | |
| audio_input =audio_input .unsqueeze (1 ) | |
| x =self .conv_subsample (audio_input ) | |
| x =x .transpose (1 ,2 ) | |
| else : | |
| raise RuntimeError ( | |
| f"AudioEncoder: Incompatible configuration. " | |
| f"use_raw_waveform={self .use_raw_waveform }, " | |
| f"waveform_tokenizer={self .waveform_tokenizer is not None }, " | |
| f"conv_subsample={hasattr (self ,'conv_subsample')and self .conv_subsample is not None }" | |
| ) | |
| speaker_embedding =None | |
| if speaker_ref is not None : | |
| speaker_embedding =self .speaker_encoder (speaker_ref ) | |
| if audio_prompt is not None : | |
| x =self .audio_prompting (x ,audio_prompt =audio_prompt ) | |
| if self .gradient_checkpointing and self .training : | |
| from torch .utils .checkpoint import checkpoint | |
| for block in self .conformer_blocks : | |
| def create_custom_forward (module ): | |
| def custom_forward (*inputs ): | |
| return module (*inputs ) | |
| return custom_forward | |
| x ,_ =checkpoint (create_custom_forward (block ),x ,mask ,use_reentrant =False ) | |
| else : | |
| for block in self .conformer_blocks : | |
| x ,_ =block (x ,mask ) | |
| x =self .output_proj (x ) | |
| extras ={} | |
| if return_eot and self .eot_predictor is not None : | |
| extras ["eot"]=self .eot_predictor (x ,mask ) | |
| if return_emotion and self .emotion_recognizer is not None : | |
| extras ["emotion"]=self .emotion_recognizer (x ,mask ) | |
| return x, speaker_embedding, (extras if extras else None) | |
| def detect_interruption ( | |
| self , | |
| audio_features :torch .Tensor , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| )->Optional [dict ]: | |
| """ | |
| Detect interruptions, backchannels, and turn-taking events. | |
| Args: | |
| audio_features: [B, T, hidden_size] encoded audio | |
| attention_mask: [B, T] optional mask | |
| Returns: | |
| dict with eot_logits, event_logits, vad_logits, backoff_prob | |
| """ | |
| if self .eot_predictor is None : | |
| return None | |
| return self .eot_predictor (audio_features ,attention_mask ) | |
| def recognize_emotion ( | |
| self , | |
| audio_features :torch .Tensor , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| )->Optional [dict ]: | |
| """ | |
| Recognize emotion with AVD (arousal/valence/dominance) values. | |
| Args: | |
| audio_features: [B, T, hidden_size] encoded audio | |
| attention_mask: [B, T] optional mask | |
| Returns: | |
| dict with emotion_logits, arousal, valence, dominance, response_mode | |
| """ | |
| if self .emotion_recognizer is None : | |
| return None | |
| return self .emotion_recognizer (audio_features ,attention_mask ) | |
| def generate_vocals ( | |
| self , | |
| text_features :torch .Tensor , | |
| style_id :Optional [torch .Tensor ]=None , | |
| mode_id :Optional [torch .Tensor ]=None , | |
| target_pitch :Optional [torch .Tensor ]=None , | |
| tempo_bpm :Optional [torch .Tensor ]=None , | |
| )->Optional [dict ]: | |
| """ | |
| Generate singing/rapping vocals from text/lyrics. | |
| Args: | |
| text_features: [B, T, hidden_size] text embeddings | |
| style_id: [B] style indices (pop, rock, jazz, etc.) | |
| mode_id: [B] mode indices (speak, sing, rap, hum, etc.) | |
| target_pitch: [B, T] pitch targets | |
| tempo_bpm: [B] tempo in BPM | |
| Returns: | |
| dict with vocal_features, pitch_logits, alignment, durations | |
| """ | |
| if self .vocalizer is None : | |
| return None | |
| return self .vocalizer (text_features ,style_id ,mode_id ,target_pitch ,tempo_bpm ) | |
| def generate_effects ( | |
| self , | |
| effect_ids :torch .Tensor , | |
| context :Optional [torch .Tensor ]=None , | |
| intensity :Optional [torch .Tensor ]=None , | |
| )->Optional [dict ]: | |
| """ | |
| Generate sound effects (beatbox, clicks, breathing, etc.). | |
| Args: | |
| effect_ids: [B] or [B, N] effect type indices | |
| context: [B, T, hidden_size] optional context | |
| intensity: [B] intensity values | |
| Returns: | |
| dict with effect_features, waveform, duration, intensity | |
| """ | |
| if self .effects_generator is None : | |
| return None | |
| return self .effects_generator (effect_ids ,context ,intensity ) | |
| def speculative_generate ( | |
| self , | |
| context :torch .Tensor , | |
| generate_draft :bool =True , | |
| verify_with :Optional [torch .Tensor ]=None , | |
| )->Optional [dict ]: | |
| """ | |
| Generate speculative draft tokens for mid-stream rewriting. | |
| Args: | |
| context: [B, T, hidden_size] current context | |
| generate_draft: whether to generate new draft | |
| verify_with: [B, T', hidden_size] new context to verify against | |
| Returns: | |
| dict with checkpoint, draft_tokens, confidence, accept_prob | |
| """ | |
| if self .speculative_decoder is None : | |
| return None | |
| return self .speculative_decoder (context ,generate_draft ,verify_with ) | |
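| # Usage sketch (illustrative; keyword arguments follow the constructor above, | |
| # remaining defaults are assumed). Encoding a mel-spectrogram batch: | |
| # | |
| #     enc = AudioEncoder(hidden_size=1024, n_mels=80, num_layers=4, | |
| #                        use_raw_waveform=False) | |
| #     mel = torch.randn(2, 80, 400)                    # [B, n_mels, T] | |
| #     feats, spk, extras = enc(mel, return_eot=True)   # feats: [2, 100, 1024] | |
| #     # conv_subsample has two stride-2 convs, so T' = ceil(T / 4). | |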
| class VariancePredictor (nn .Module ): | |
| """Variance predictor for duration, pitch, and energy.""" | |
| def __init__ (self ,hidden_size :int ,kernel_size :int =3 ,dropout :float =0.1 ): | |
| super ().__init__ () | |
| self .conv1 =nn .Conv1d (hidden_size ,hidden_size ,kernel_size ,padding =kernel_size //2 ) | |
| self .norm1 =nn .LayerNorm (hidden_size ) | |
| self .conv2 =nn .Conv1d (hidden_size ,hidden_size ,kernel_size ,padding =kernel_size //2 ) | |
| self .norm2 =nn .LayerNorm (hidden_size ) | |
| self .dropout =nn .Dropout (dropout ) | |
| self .linear =nn .Linear (hidden_size ,1 ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| """x: [B, T, C] -> [B, T]""" | |
| out =self .conv1 (x .transpose (1 ,2 )).transpose (1 ,2 ) | |
| out =F .relu (out ) | |
| out =self .norm1 (out ) | |
| out =self .dropout (out ) | |
| out =self .conv2 (out .transpose (1 ,2 )).transpose (1 ,2 ) | |
| out =F .relu (out ) | |
| out =self .norm2 (out ) | |
| out =self .dropout (out ) | |
| return self .linear (out ).squeeze (-1 ) | |
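| # Shape sketch (illustrative): one scalar per input position, | |
| # | |
| #     vp = VariancePredictor(hidden_size=1024) | |
| #     vp(torch.randn(2, 50, 1024)).shape   # torch.Size([2, 50]) | |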
| class FFTBlock (nn .Module ): | |
| """FFT block for mel decoder.""" | |
| def __init__ ( | |
| self , | |
| hidden_size :int , | |
| num_heads :int =4 , | |
| ff_expansion :int =4 , | |
| kernel_size :int =9 , | |
| dropout :float =0.1 , | |
| ): | |
| super ().__init__ () | |
| self .attn =RotaryMultiHeadLatentAttention ( | |
| hidden_size =hidden_size , | |
| num_heads =num_heads , | |
| num_kv_heads =max (1 ,num_heads //2 ), | |
| head_dim =hidden_size //num_heads , | |
| kv_lora_rank =hidden_size //4 , | |
| dropout =dropout , | |
| ) | |
| self .attn_norm =nn .LayerNorm (hidden_size ) | |
| self .attn_dropout =nn .Dropout (dropout ) | |
| self .ff_norm =nn .LayerNorm (hidden_size ) | |
| self .ff =nn .Sequential ( | |
| nn .Conv1d (hidden_size ,hidden_size *ff_expansion ,kernel_size ,padding =kernel_size //2 ), | |
| nn .ReLU (), | |
| nn .Conv1d (hidden_size *ff_expansion ,hidden_size ,kernel_size ,padding =kernel_size //2 ), | |
| nn .Dropout (dropout ) | |
| ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| residual =x | |
| x =self .attn_norm (x ) | |
| x ,_ =self .attn (x ) | |
| x =residual +self .attn_dropout (x ) | |
| residual =x | |
| x =self .ff_norm (x ) | |
| x =self .ff (x .transpose (1 ,2 )).transpose (1 ,2 ) | |
| x =residual +x | |
| return x | |
| class AudioDecoder (nn .Module ): | |
| """ | |
| SOTA Audio Decoder with MAS, Zero-Shot Speaker Cloning, and Voice Enhancement Support. | |
| Features: | |
| - Monotonic Alignment Search for text-to-audio alignment | |
| - Zero-shot speaker cloning via speaker embeddings | |
| - In-context audio prompting | |
| - Variance adaptor with duration, pitch, energy prediction | |
| - RMLA-based FFT blocks | |
| - Gradient checkpointing support for memory efficiency | |
| Voice Enhancement Features (matching AudioEncoder): | |
| - Emotion conditioning for emotional speech synthesis | |
| - Singing/vocal style synthesis support | |
| - Sound effect generation and integration | |
| - Raw waveform output support (optional) | |
| - Speculative decoding integration | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| n_mels :int =80 , | |
| max_audio_length :int =1000 , | |
| num_speakers :int =256 , | |
| num_decoder_layers :int =4 , | |
| dropout :float =0.1 , | |
| enable_emotion :bool =True , | |
| enable_singing :bool =True , | |
| enable_effects :bool =True , | |
| enable_raw_waveform :bool =True , | |
| enable_speculative :bool =True , | |
| num_emotions :int =10 , | |
| num_vocal_styles :int =8 , | |
| num_vocal_modes :int =6 , | |
| num_effect_types :int =20 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .n_mels =n_mels | |
| self .max_audio_length =max_audio_length | |
| self .gradient_checkpointing =False | |
| self .enable_emotion =enable_emotion | |
| self .enable_singing =enable_singing | |
| self .enable_effects =enable_effects | |
| self .enable_raw_waveform =enable_raw_waveform | |
| self .enable_speculative =enable_speculative | |
| self .mas =MonotonicAlignmentSearch (hidden_size ) | |
| self .speaker_embed =nn .Embedding (num_speakers ,hidden_size //4 ) | |
| self .speaker_proj =nn .Linear (hidden_size //4 ,hidden_size //4 ) | |
| self .audio_prompting =InContextAudioPrompting ( | |
| hidden_size =hidden_size , | |
| num_prompt_tokens =32 , | |
| ) | |
| if enable_emotion : | |
| self .emotion_embed =nn .Embedding (num_emotions ,hidden_size //4 ) | |
| self .avd_proj =nn .Sequential ( | |
| nn .Linear (3 ,hidden_size //8 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //8 ,hidden_size //4 ), | |
| ) | |
| self .emotion_cond_size =hidden_size //4 | |
| else : | |
| self .emotion_embed =None | |
| self .avd_proj =None | |
| self .emotion_cond_size =0 | |
| if enable_singing : | |
| self .vocal_style_embed =nn .Embedding (num_vocal_styles ,hidden_size //4 ) | |
| self .vocal_mode_embed =nn .Embedding (num_vocal_modes ,hidden_size //4 ) | |
| self .tempo_proj =nn .Sequential ( | |
| nn .Linear (1 ,hidden_size //8 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //8 ,hidden_size //4 ), | |
| ) | |
| self .singing_cond_size =hidden_size //4 | |
| else : | |
| self .vocal_style_embed =None | |
| self .vocal_mode_embed =None | |
| self .tempo_proj =None | |
| self .singing_cond_size =0 | |
| if enable_effects : | |
| self .effect_embed =nn .Embedding (num_effect_types ,hidden_size //4 ) | |
| self .effect_intensity_proj =nn .Sequential ( | |
| nn .Linear (1 ,hidden_size //8 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //8 ,hidden_size //4 ), | |
| ) | |
| self .effect_cond_size =hidden_size //4 | |
| else : | |
| self .effect_embed =None | |
| self .effect_intensity_proj =None | |
| self .effect_cond_size =0 | |
| total_cond_size =hidden_size //4 | |
| total_cond_size +=self .emotion_cond_size | |
| total_cond_size +=self .singing_cond_size | |
| total_cond_size +=self .effect_cond_size | |
| self .input_proj =nn .Linear (hidden_size +total_cond_size ,hidden_size ) | |
| self .encoder_blocks =nn .ModuleList ([ | |
| FFTBlock (hidden_size ,num_heads =4 ,ff_expansion =4 ,dropout =dropout ) | |
| for _ in range (4 ) | |
| ]) | |
| self .duration_predictor =VariancePredictor (hidden_size ,dropout =dropout ) | |
| self .pitch_predictor =VariancePredictor (hidden_size ,dropout =dropout ) | |
| self .energy_predictor =VariancePredictor (hidden_size ,dropout =dropout ) | |
| self .pitch_embed =nn .Conv1d (1 ,hidden_size ,kernel_size =9 ,padding =4 ) | |
| self .energy_embed =nn .Conv1d (1 ,hidden_size ,kernel_size =9 ,padding =4 ) | |
| self .decoder_blocks =nn .ModuleList ([ | |
| FFTBlock (hidden_size ,num_heads =4 ,ff_expansion =4 ,dropout =dropout ) | |
| for _ in range (num_decoder_layers ) | |
| ]) | |
| self .mel_linear =nn .Linear (hidden_size ,n_mels ) | |
| self .postnet =nn .ModuleList ([ | |
| nn .Sequential ( | |
| nn .Conv1d (n_mels ,256 ,kernel_size =5 ,padding =2 ), | |
| nn .GroupNorm (1 ,256 ), | |
| nn .Tanh (), | |
| ), | |
| nn .Sequential ( | |
| nn .Conv1d (256 ,256 ,kernel_size =5 ,padding =2 ), | |
| nn .GroupNorm (1 ,256 ), | |
| nn .Tanh (), | |
| ), | |
| nn .Sequential ( | |
| nn .Conv1d (256 ,256 ,kernel_size =5 ,padding =2 ), | |
| nn .GroupNorm (1 ,256 ), | |
| nn .Tanh (), | |
| ), | |
| nn .Sequential ( | |
| nn .Conv1d (256 ,256 ,kernel_size =5 ,padding =2 ), | |
| nn .GroupNorm (1 ,256 ), | |
| nn .Tanh (), | |
| ), | |
| nn .Conv1d (256 ,n_mels ,kernel_size =5 ,padding =2 ), | |
| ]) | |
| if enable_raw_waveform : | |
| self .waveform_decoder =RawWaveformDecoder ( | |
| hidden_size =hidden_size , | |
| sample_rate =16000 , | |
| ) | |
| else : | |
| self .waveform_decoder =None | |
| if enable_speculative : | |
| self .speculative_decoder =SpeculativeAudioDecoder ( | |
| hidden_size =hidden_size , | |
| draft_length =10 , | |
| ) | |
| else : | |
| self .speculative_decoder =None | |
| print (f" 🔊 AudioDecoder (MAS + RMLA): {hidden_size }d -> {n_mels } mels") | |
| print (f" - Monotonic Alignment Search enabled") | |
| print (f" - Zero-Shot Speaker Cloning enabled") | |
| print (f" - In-Context Audio Prompting enabled") | |
| print (f" - Emotion Conditioning: {enable_emotion }") | |
| print (f" - Singing/Vocal Styles: {enable_singing }") | |
| print (f" - Sound Effects: {enable_effects }") | |
| print (f" - Raw Waveform Output: {enable_raw_waveform }") | |
| print (f" - Speculative Decoding: {enable_speculative }") | |
| def gradient_checkpointing_enable (self ): | |
| """Enable gradient checkpointing to save memory during training.""" | |
| self .gradient_checkpointing =True | |
| def gradient_checkpointing_disable (self ): | |
| """Disable gradient checkpointing.""" | |
| self .gradient_checkpointing =False | |
| def forward ( | |
| self , | |
| text_embeds :torch .Tensor , | |
| target_length :Optional [int ]=None , | |
| speaker :Optional [torch .Tensor ]=None , | |
| speaker_embedding :Optional [torch .Tensor ]=None , | |
| audio_prompt :Optional [torch .Tensor ]=None , | |
| audio_features :Optional [torch .Tensor ]=None , | |
| duration_target :Optional [torch .Tensor ]=None , | |
| pitch_target :Optional [torch .Tensor ]=None , | |
| energy_target :Optional [torch .Tensor ]=None , | |
| use_mas :bool =True , | |
| emotion_id :Optional [torch .Tensor ]=None , | |
| avd_values :Optional [torch .Tensor ]=None , | |
| vocal_style_id :Optional [torch .Tensor ]=None , | |
| vocal_mode_id :Optional [torch .Tensor ]=None , | |
| tempo_bpm :Optional [torch .Tensor ]=None , | |
| effect_id :Optional [torch .Tensor ]=None , | |
| effect_intensity :Optional [torch .Tensor ]=None , | |
| output_waveform :bool =False , | |
| use_speculative :bool =False , | |
| )->Tuple [torch .Tensor ,torch .Tensor ,Optional [torch .Tensor ],Optional [dict ]]: | |
| """ | |
| Generate mel-spectrogram from text embeddings with voice enhancement support. | |
| Args: | |
| text_embeds: [B, T, hidden_size] text embeddings | |
| target_length: target mel length (for training) | |
| speaker: [B] speaker IDs (for multi-speaker) | |
| speaker_embedding: [B, hidden_size//4] zero-shot speaker embedding | |
| audio_prompt: [B, T_prompt, hidden_size] audio prompt features | |
| audio_features: [B, T_audio, hidden_size] target audio features (for MAS training) | |
| duration_target: [B, T] ground truth durations | |
| pitch_target: [B, T'] ground truth pitch | |
| energy_target: [B, T'] ground truth energy | |
| use_mas: Whether to use MAS for alignment | |
| Voice enhancement args: | |
| emotion_id: [B] discrete emotion category (0-9) | |
| avd_values: [B, 3] continuous arousal/valence/dominance values | |
| vocal_style_id: [B] singing style (0-7: pop, rock, jazz, etc.) | |
| vocal_mode_id: [B] vocal mode (0-5: speak, sing, rap, hum, whistle, chant) | |
| tempo_bpm: [B] tempo in BPM for singing/rapping | |
| effect_id: [B] sound effect type (0-19) | |
| effect_intensity: [B] effect intensity (0-1) | |
| output_waveform: Whether to also output raw waveform | |
| use_speculative: Whether to use speculative decoding | |
| Returns: | |
| mel: [B, n_mels, T'] generated mel spectrogram | |
| durations: [B, T] predicted durations | |
| alignment: [B, T_text, T_audio] alignment matrix (if use_mas and audio_features provided) | |
| extras: dict with optional outputs (waveform, speculative results) | |
| """ | |
| batch_size ,seq_len ,_ =text_embeds .shape | |
| device =text_embeds .device | |
| dtype =text_embeds .dtype | |
| extras ={} | |
| if speaker_embedding is not None : | |
| spk_emb =self .speaker_proj (speaker_embedding ) | |
| elif speaker is not None : | |
| spk_emb =self .speaker_embed (speaker ) | |
| else : | |
| speaker =torch .zeros (batch_size ,dtype =torch .long ,device =device ) | |
| spk_emb =self .speaker_embed (speaker ) | |
| spk_emb =spk_emb .unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype ) | |
| cond_embeds =[spk_emb ] | |
| if self .enable_emotion : | |
| if emotion_id is not None : | |
| emo_emb =self .emotion_embed (emotion_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype ) | |
| elif avd_values is not None : | |
| emo_emb =self .avd_proj (avd_values .to (dtype )).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ) | |
| else : | |
| # Fall back to a neutral emotion; index 6 is assumed to be "neutral" | |
| # in the default 10-way emotion set. | |
| neutral = torch.full((batch_size,), 6, dtype=torch.long, device=device) | |
| emo_emb =self .emotion_embed (neutral ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype ) | |
| cond_embeds .append (emo_emb ) | |
| if self .enable_singing : | |
| if vocal_style_id is not None : | |
| style_emb =self .vocal_style_embed (vocal_style_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype ) | |
| else : | |
| default_style =torch .zeros (batch_size ,dtype =torch .long ,device =device ) | |
| style_emb =self .vocal_style_embed (default_style ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype ) | |
| if vocal_mode_id is not None : | |
| mode_emb =self .vocal_mode_embed (vocal_mode_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype ) | |
| else : | |
| default_mode =torch .zeros (batch_size ,dtype =torch .long ,device =device ) | |
| mode_emb =self .vocal_mode_embed (default_mode ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype ) | |
| if tempo_bpm is not None : | |
| tempo_norm =(tempo_bpm .float ()-60 )/120 | |
| tempo_emb =self .tempo_proj (tempo_norm .unsqueeze (-1 ).to (dtype )).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ) | |
| else : | |
| tempo_emb =torch .zeros (batch_size ,seq_len ,self .hidden_size //4 ,device =device ,dtype =dtype ) | |
| singing_emb =style_emb +mode_emb +tempo_emb | |
| cond_embeds .append (singing_emb ) | |
| if self .enable_effects : | |
| if effect_id is not None : | |
| eff_emb =self .effect_embed (effect_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype ) | |
| if effect_intensity is not None : | |
| intensity_emb =self .effect_intensity_proj (effect_intensity .unsqueeze (-1 ).to (dtype )) | |
| eff_emb =eff_emb *intensity_emb .unsqueeze (1 ) | |
| else : | |
| eff_emb =torch .zeros (batch_size ,seq_len ,self .hidden_size //4 ,device =device ,dtype =dtype ) | |
| cond_embeds .append (eff_emb ) | |
| all_cond =torch .cat (cond_embeds ,dim =-1 ) | |
| x =torch .cat ([text_embeds ,all_cond ],dim =-1 ) | |
| x =self .input_proj (x ) | |
| if audio_prompt is not None : | |
| x =self .audio_prompting (x ,audio_prompt =audio_prompt ) | |
| if self .gradient_checkpointing and self .training : | |
| from torch .utils .checkpoint import checkpoint | |
| for block in self .encoder_blocks : | |
| def create_custom_forward (module ): | |
| def custom_forward (*inputs ): | |
| return module (*inputs ) | |
| return custom_forward | |
| x =checkpoint (create_custom_forward (block ),x ,use_reentrant =False ) | |
| else : | |
| for block in self .encoder_blocks : | |
| x =block (x ) | |
| alignment =None | |
| if use_mas and audio_features is not None : | |
| alignment ,durations =self .mas (x ,audio_features ,use_hard =not self .training ) | |
| else : | |
| _ ,durations =self .mas (x ) | |
| if duration_target is not None : | |
| durations =duration_target | |
| pitch_pred =self .pitch_predictor (x ) | |
| energy_pred =F .softplus (self .energy_predictor (x )) | |
| MIN_MEL_LENGTH =1 | |
| if target_length is not None : | |
| mel_length =max (MIN_MEL_LENGTH ,target_length ) | |
| else : | |
| mel_length =int (durations .sum (dim =1 ).max ().item ()) | |
| mel_length =max (16 ,min (mel_length ,self .max_audio_length )) | |
| x =F .interpolate (x .transpose (1 ,2 ),size =mel_length ,mode ='linear',align_corners =False ).transpose (1 ,2 ) | |
| pitch =pitch_target if pitch_target is not None else pitch_pred | |
| energy =energy_target if energy_target is not None else energy_pred | |
| pitch_up =F .interpolate (pitch .unsqueeze (1 ),size =mel_length ,mode ='linear',align_corners =False ) | |
| energy_up =F .interpolate (energy .unsqueeze (1 ),size =mel_length ,mode ='linear',align_corners =False ) | |
| pitch_emb =self .pitch_embed (pitch_up ).transpose (1 ,2 ) | |
| energy_emb =self .energy_embed (energy_up ).transpose (1 ,2 ) | |
| x =x +pitch_emb +energy_emb | |
| if self .gradient_checkpointing and self .training : | |
| from torch .utils .checkpoint import checkpoint | |
| for block in self .decoder_blocks : | |
| def create_custom_forward (module ): | |
| def custom_forward (*inputs ): | |
| return module (*inputs ) | |
| return custom_forward | |
| x =checkpoint (create_custom_forward (block ),x ,use_reentrant =False ) | |
| else : | |
| for block in self .decoder_blocks : | |
| x =block (x ) | |
| mel =self .mel_linear (x ).transpose (1 ,2 ) | |
| mel_post =mel | |
| for layer in self .postnet : | |
| mel_post =layer (mel_post ) | |
| mel =mel +mel_post | |
| if output_waveform and self .waveform_decoder is not None : | |
| waveform =self .waveform_decoder (x ) | |
| extras ["waveform"]=waveform | |
| if use_speculative and self .speculative_decoder is not None : | |
| spec_results =self .speculative_decoder (x ) | |
| extras ["speculative"]=spec_results | |
| return mel, durations, alignment, (extras if extras else None) | |
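| # Usage sketch (illustrative; shapes follow the docstring above). Without | |
| # target audio features, alignment is None and durations come from the MAS | |
| # module's free-running pass: | |
| # | |
| #     dec = AudioDecoder(hidden_size=1024, n_mels=80) | |
| #     text = torch.randn(2, 50, 1024) | |
| #     mel, durations, alignment, extras = dec(text, target_length=200) | |
| #     mel.shape   # [2, 80, 200] | |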
| class ProsodyAwareEoTPredictor (nn .Module ): | |
| """ | |
| Prosody-aware End-of-Turn (EoT) Prediction for real-time interruption detection. | |
| Detects when a speaker is about to finish their turn, allowing the model to: | |
| - Detect user interruptions (coughs, laughs, "uh-huh", etc.) | |
| - Yield the floor when appropriate | |
| - Adjust response mid-stream | |
| Uses prosodic features (pitch, energy, rhythm) combined with semantic features. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| num_eot_classes :int =5 , | |
| prosody_dim :int =128 , | |
| num_heads :int =4 , | |
| dropout :float =0.1 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_eot_classes =num_eot_classes | |
| self .pitch_conv =nn .Sequential ( | |
| nn .Conv1d (1 ,prosody_dim //2 ,kernel_size =5 ,padding =2 ), | |
| nn .SiLU (), | |
| nn .Conv1d (prosody_dim //2 ,prosody_dim ,kernel_size =3 ,padding =1 ), | |
| ) | |
| self .energy_conv =nn .Sequential ( | |
| nn .Conv1d (1 ,prosody_dim //2 ,kernel_size =5 ,padding =2 ), | |
| nn .SiLU (), | |
| nn .Conv1d (prosody_dim //2 ,prosody_dim ,kernel_size =3 ,padding =1 ), | |
| ) | |
| self .vad_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //2 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //2 ,2 ), | |
| ) | |
| self .event_classifier =nn .Sequential ( | |
| nn .Linear (hidden_size +prosody_dim *2 ,hidden_size ), | |
| nn .SiLU (), | |
| nn .Dropout (dropout ), | |
| nn .Linear (hidden_size ,hidden_size //2 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //2 ,8 ), | |
| ) | |
| self .temporal_attn =nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =num_heads , | |
| dropout =dropout , | |
| batch_first =True , | |
| ) | |
| self .eot_head =nn .Sequential ( | |
| nn .Linear (hidden_size +prosody_dim *2 ,hidden_size ), | |
| nn .SiLU (), | |
| nn .Dropout (dropout ), | |
| nn .Linear (hidden_size ,num_eot_classes ), | |
| ) | |
| self .backoff_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //4 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //4 ,1 ), | |
| nn .Sigmoid (), | |
| ) | |
| print (f" 🎙️ ProsodyAwareEoTPredictor: {num_eot_classes } turn states, {prosody_dim }d prosody") | |
| def extract_prosody (self ,audio_features :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| """Extract pitch and energy prosodic features.""" | |
| batch_size ,seq_len ,hidden =audio_features .shape | |
| x =audio_features .transpose (1 ,2 ) | |
| pitch_proxy =x [:,:1 ,:] | |
| energy_proxy =x .pow (2 ).mean (dim =1 ,keepdim =True ) | |
| pitch_features =self .pitch_conv (pitch_proxy ).transpose (1 ,2 ) | |
| energy_features =self .energy_conv (energy_proxy ).transpose (1 ,2 ) | |
| return pitch_features ,energy_features | |
| def forward ( | |
| self , | |
| audio_features :torch .Tensor , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| )->dict : | |
| """ | |
| Predict end-of-turn and interruption events. | |
| Args: | |
| audio_features: [B, T, hidden_size] encoded audio | |
| attention_mask: [B, T] optional mask | |
| Returns: | |
| dict with: | |
| - eot_logits: [B, T, num_eot_classes] turn state predictions | |
| - event_logits: [B, T, 8] interruption event predictions | |
| - vad_logits: [B, T, 2] voice activity predictions | |
| - backoff_prob: [B, T, 1] backoff probability | |
| """ | |
| batch_size ,seq_len ,_ =audio_features .shape | |
| pitch_features ,energy_features =self .extract_prosody (audio_features ) | |
| if attention_mask is not None : | |
| key_padding_mask =~attention_mask .bool () | |
| else : | |
| key_padding_mask =None | |
| contextualized ,_ =self .temporal_attn ( | |
| audio_features ,audio_features ,audio_features , | |
| key_padding_mask =key_padding_mask , | |
| ) | |
| combined =torch .cat ([contextualized ,pitch_features ,energy_features ],dim =-1 ) | |
| eot_logits =self .eot_head (combined ) | |
| event_logits =self .event_classifier (combined ) | |
| vad_logits =self .vad_head (contextualized ) | |
| backoff_prob =self .backoff_head (contextualized ) | |
| return { | |
| "eot_logits":eot_logits , | |
| "event_logits":event_logits , | |
| "vad_logits":vad_logits , | |
| "backoff_prob":backoff_prob , | |
| } | |
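| # Usage sketch (illustrative): | |
| # | |
| #     eot = ProsodyAwareEoTPredictor(hidden_size=1024) | |
| #     out = eot(torch.randn(2, 100, 1024)) | |
| #     out["eot_logits"].shape    # [2, 100, 5] turn states per frame | |
| #     out["backoff_prob"].shape  # [2, 100, 1]; high values suggest yielding the floor | |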
| class AVDEmotionRecognizer (nn .Module ): | |
| """ | |
| Continuous AVD (Arousal/Valence/Dominance) Emotion Recognition. | |
| Predicts both discrete emotion categories and continuous AVD values | |
| for nuanced emotion understanding and response adaptation. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| num_emotions :int =10 , | |
| num_layers :int =2 , | |
| dropout :float =0.1 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_emotions =num_emotions | |
| self .emotion_query =nn .Parameter (torch .randn (1 ,1 ,hidden_size )) | |
| self .emotion_attn =nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =8 , | |
| dropout =dropout , | |
| batch_first =True , | |
| ) | |
| self .temporal_conv =nn .Sequential ( | |
| nn .Conv1d (hidden_size ,hidden_size ,kernel_size =5 ,padding =2 ,groups =8 ), | |
| nn .SiLU (), | |
| nn .Conv1d (hidden_size ,hidden_size ,kernel_size =3 ,padding =1 ), | |
| ) | |
| self .emotion_classifier =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //2 ), | |
| nn .SiLU (), | |
| nn .Dropout (dropout ), | |
| nn .Linear (hidden_size //2 ,num_emotions ), | |
| ) | |
| self .arousal_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //4 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //4 ,1 ), | |
| nn .Sigmoid (), | |
| ) | |
| self .valence_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //4 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //4 ,1 ), | |
| nn .Tanh (), | |
| ) | |
| self .dominance_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //4 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //4 ,1 ), | |
| nn .Sigmoid (), | |
| ) | |
| self .response_adaptation =nn .Sequential ( | |
| nn .Linear (hidden_size +3 ,hidden_size //2 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //2 ,4 ), | |
| ) | |
| print (f" 😊 AVDEmotionRecognizer: {num_emotions } emotions + continuous AVD") | |
| def forward ( | |
| self , | |
| audio_features :torch .Tensor , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| )->dict : | |
| """ | |
| Recognize emotion from audio features. | |
| Args: | |
| audio_features: [B, T, hidden_size] encoded audio | |
| attention_mask: [B, T] optional mask | |
| Returns: | |
| dict with: | |
| - emotion_logits: [B, num_emotions] discrete emotion | |
| - arousal: [B, 1] arousal value (0-1) | |
| - valence: [B, 1] valence value (-1 to 1) | |
| - dominance: [B, 1] dominance value (0-1) | |
| - response_mode: [B, 4] response adaptation logits | |
| """ | |
| batch_size ,seq_len ,_ =audio_features .shape | |
| x_conv =self .temporal_conv (audio_features .transpose (1 ,2 )).transpose (1 ,2 ) | |
| x =audio_features +x_conv | |
| query =self .emotion_query .expand (batch_size ,-1 ,-1 ) | |
| if attention_mask is not None : | |
| key_padding_mask =~attention_mask .bool () | |
| else : | |
| key_padding_mask =None | |
| emotion_context ,_ =self .emotion_attn ( | |
| query ,x ,x , | |
| key_padding_mask =key_padding_mask , | |
| ) | |
| emotion_vec =emotion_context .squeeze (1 ) | |
| emotion_logits =self .emotion_classifier (emotion_vec ) | |
| arousal =self .arousal_head (emotion_vec ) | |
| valence =self .valence_head (emotion_vec ) | |
| dominance =self .dominance_head (emotion_vec ) | |
| avd_concat =torch .cat ([emotion_vec ,arousal ,valence ,dominance ],dim =-1 ) | |
| response_mode =self .response_adaptation (avd_concat ) | |
| return { | |
| "emotion_logits":emotion_logits , | |
| "arousal":arousal , | |
| "valence":valence , | |
| "dominance":dominance , | |
| "response_mode":response_mode , | |
| } | |
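| # Usage sketch (illustrative). The activation choices bound each dimension: | |
| # | |
| #     emo = AVDEmotionRecognizer(hidden_size=1024) | |
| #     out = emo(torch.randn(2, 100, 1024)) | |
| #     out["arousal"]    # [2, 1] in (0, 1)  (Sigmoid) | |
| #     out["valence"]    # [2, 1] in (-1, 1) (Tanh) | |
| #     out["dominance"]  # [2, 1] in (0, 1)  (Sigmoid) | |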
| class DynamicLatentVocalizer (nn .Module ): | |
| """ | |
| Dynamic Latent Vocalizations for singing, rapping, humming, etc. | |
| Extends speech synthesis to include: | |
| - Singing with pitch control | |
| - Rapping with rhythm control | |
| - Humming, whistling, chanting | |
| - Musical style transfer | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| num_styles :int =8 , | |
| num_vocal_modes :int =6 , | |
| pitch_bins :int =256 , | |
| tempo_range :Tuple [int ,int ]=(60 ,180 ), | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_styles =num_styles | |
| self .num_vocal_modes =num_vocal_modes | |
| self .pitch_bins =pitch_bins | |
| self .tempo_range =tempo_range | |
| self .style_embed =nn .Embedding (num_styles ,hidden_size //4 ) | |
| self .mode_embed =nn .Embedding (num_vocal_modes ,hidden_size //4 ) | |
| self .pitch_embed =nn .Embedding (pitch_bins ,hidden_size //4 ) | |
| self .pitch_predictor =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //2 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //2 ,pitch_bins ), | |
| ) | |
| self .tempo_encoder =nn .Sequential ( | |
| nn .Linear (1 ,hidden_size //8 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //8 ,hidden_size //4 ), | |
| ) | |
| self .rhythm_attn =nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =4 , | |
| dropout =0.1 , | |
| batch_first =True , | |
| ) | |
| # Note: four condition embeddings (style, mode, tempo, pitch), each of size | |
| # hidden_size // 4, are concatenated with the text features in forward(), | |
| # so the input width here must be 2 * hidden_size. | |
| self.style_transfer = nn.Sequential( | |
| nn.Linear(hidden_size * 2, hidden_size), | |
| nn.SiLU(), | |
| nn.Linear(hidden_size, hidden_size), | |
| ) | |
| self .lyrics_aligner =MonotonicAlignmentSearch (hidden_size ) | |
| self .output_proj =nn .Linear (hidden_size ,hidden_size ) | |
| print (f" 🎵 DynamicLatentVocalizer: {num_styles } styles, {num_vocal_modes } modes") | |
| def forward ( | |
| self , | |
| text_features :torch .Tensor , | |
| style_id :Optional [torch .Tensor ]=None , | |
| mode_id :Optional [torch .Tensor ]=None , | |
| target_pitch :Optional [torch .Tensor ]=None , | |
| tempo_bpm :Optional [torch .Tensor ]=None , | |
| )->dict : | |
| """ | |
| Generate vocalization features for singing/rapping/etc. | |
| Args: | |
| text_features: [B, T, hidden_size] text/lyrics embeddings | |
| style_id: [B] style indices (0-7) | |
| mode_id: [B] vocal mode indices (0-5) | |
| target_pitch: [B, T] optional pitch targets | |
| tempo_bpm: [B] optional tempo in BPM | |
| Returns: | |
| dict with: | |
| - vocal_features: [B, T', hidden_size] vocalization features | |
| - pitch_logits: [B, T, pitch_bins] predicted pitch | |
| - alignment: [B, T, T'] text-to-audio alignment | |
| """ | |
| batch_size ,seq_len ,_ =text_features .shape | |
| device =text_features .device | |
| if style_id is None : | |
| style_id =torch .zeros (batch_size ,dtype =torch .long ,device =device ) | |
| if mode_id is None : | |
| mode_id =torch .zeros (batch_size ,dtype =torch .long ,device =device ) | |
| style_emb =self .style_embed (style_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ) | |
| mode_emb =self .mode_embed (mode_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ) | |
| if tempo_bpm is not None : | |
| tempo_norm =(tempo_bpm .float ()-self .tempo_range [0 ])/(self .tempo_range [1 ]-self .tempo_range [0 ]) | |
| tempo_emb =self .tempo_encoder (tempo_norm .unsqueeze (-1 )).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ) | |
| else : | |
| tempo_emb =torch .zeros (batch_size ,seq_len ,self .hidden_size //4 ,device =device ) | |
| pitch_logits =self .pitch_predictor (text_features ) | |
| if target_pitch is not None : | |
| pitch_emb =self .pitch_embed (target_pitch ) | |
| else : | |
| pitch_idx =pitch_logits .argmax (dim =-1 ) | |
| pitch_emb =self .pitch_embed (pitch_idx ) | |
| conditions =torch .cat ([style_emb ,mode_emb ,tempo_emb ,pitch_emb ],dim =-1 ) | |
| combined =torch .cat ([text_features ,conditions ],dim =-1 ) | |
| vocal_features =self .style_transfer (combined ) | |
| vocal_features ,_ =self .rhythm_attn (vocal_features ,vocal_features ,vocal_features ) | |
| alignment ,durations =self .lyrics_aligner (text_features ) | |
| vocal_features =self .output_proj (vocal_features ) | |
| return { | |
| "vocal_features":vocal_features , | |
| "pitch_logits":pitch_logits , | |
| "alignment":alignment , | |
| "durations":durations , | |
| } | |
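| # Usage sketch (illustrative; mode indices follow the docstring ordering, so | |
| # sing=1 and rap=2 are assumptions): | |
| # | |
| #     voc = DynamicLatentVocalizer(hidden_size=1024) | |
| #     out = voc(torch.randn(2, 32, 1024), mode_id=torch.tensor([1, 2])) | |
| #     out["vocal_features"].shape   # [2, 32, 1024] | |
| #     out["pitch_logits"].shape     # [2, 32, 256] over pitch_bins | |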
| class NeuralSoundEffectGenerator (nn .Module ): | |
| """ | |
| Neural Style Transfer for Sound Effects and Non-verbal Vocalizations. | |
| Generates: | |
| - Beatboxing (kicks, snares, hi-hats) | |
| - Vocal clicks, pops, tongue sounds | |
| - Breathing, sighing, gasping | |
| - Non-verbal expressions (hmm, aha, wow, etc.) | |
| - Polyphonic ad-libs and harmonies | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| num_effect_types :int =20 , | |
| num_layers :int =3 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_effect_types =num_effect_types | |
| self .effect_embed =nn .Embedding (num_effect_types ,hidden_size ) | |
| self .generator =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size *4 ), | |
| nn .SiLU (), | |
| nn .Unflatten (1 ,(hidden_size ,4 )), | |
| nn .ConvTranspose1d (hidden_size ,hidden_size //2 ,4 ,2 ,1 ), | |
| nn .SiLU (), | |
| nn .ConvTranspose1d (hidden_size //2 ,hidden_size //4 ,4 ,2 ,1 ), | |
| nn .SiLU (), | |
| nn .ConvTranspose1d (hidden_size //4 ,hidden_size //8 ,4 ,2 ,1 ), | |
| nn .SiLU (), | |
| nn .ConvTranspose1d (hidden_size //8 ,1 ,4 ,2 ,1 ), | |
| nn .Tanh (), | |
| ) | |
| self .duration_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //4 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //4 ,1 ), | |
| nn .Softplus (), | |
| ) | |
| self .intensity_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //4 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //4 ,1 ), | |
| nn .Sigmoid (), | |
| ) | |
| self .blend_attn =nn .MultiheadAttention ( | |
| embed_dim =hidden_size , | |
| num_heads =4 , | |
| batch_first =True , | |
| ) | |
| print (f" 🥁 NeuralSoundEffectGenerator: {num_effect_types } effect types") | |
| def forward ( | |
| self , | |
| effect_ids :torch .Tensor , | |
| context :Optional [torch .Tensor ]=None , | |
| intensity :Optional [torch .Tensor ]=None , | |
| )->dict : | |
| """ | |
| Generate sound effect features. | |
| Args: | |
| effect_ids: [B] or [B, N] effect type indices | |
| context: [B, T, hidden_size] optional context features | |
| intensity: [B] or [B, N] optional intensity values | |
| Returns: | |
| dict with: | |
| - effect_features: [B, T', hidden_size] generated features | |
| - waveform: [B, 1, samples] raw waveform (if generating directly) | |
| - duration: [B, 1] predicted duration | |
| """ | |
| if effect_ids .dim ()==1 : | |
| effect_ids =effect_ids .unsqueeze (1 ) | |
| batch_size ,num_effects =effect_ids .shape | |
| device =effect_ids .device | |
| effect_emb =self .effect_embed (effect_ids ) | |
| if num_effects >1 : | |
| effect_emb ,_ =self .blend_attn (effect_emb ,effect_emb ,effect_emb ) | |
| effect_vec =effect_emb .mean (dim =1 ) | |
| if context is not None : | |
| context_vec =context .mean (dim =1 ) | |
| effect_vec =effect_vec +context_vec | |
| duration =self .duration_head (effect_vec ) | |
| pred_intensity =self .intensity_head (effect_vec ) | |
| if intensity is not None : | |
| pred_intensity =intensity .unsqueeze (-1 )if intensity .dim ()==1 else intensity | |
| effect_vec =effect_vec *pred_intensity | |
| waveform =self .generator (effect_vec ) | |
| return { | |
| "effect_features":effect_emb , | |
| "waveform":waveform , | |
| "duration":duration , | |
| "intensity":pred_intensity , | |
| } | |
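| # Usage sketch (illustrative): the generator upsamples a 4-frame seed by 2x | |
| # four times, so the direct waveform output is 64 samples long: | |
| # | |
| #     fx = NeuralSoundEffectGenerator(hidden_size=1024) | |
| #     out = fx(torch.tensor([3, 7]))   # one effect type per batch item | |
| #     out["waveform"].shape            # [2, 1, 64] | |
| #     out["duration"].shape            # [2, 1] (Softplus, positive) | |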
| class SpeculativeAudioDecoder (nn .Module ): | |
| """ | |
| Mid-stream Token Rewriting support for Speculative Decoding in audio. | |
| Allows the model to: | |
| - Generate draft audio tokens speculatively | |
| - Accept/reject based on user feedback or context change | |
| - Rollback and regenerate from checkpoints | |
| - Smooth transitions during rewrites | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int =1024 , | |
| draft_length :int =10 , | |
| num_heads :int =8 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .draft_length =draft_length | |
| self .draft_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size ,hidden_size ), | |
| ) | |
| self .verify_head =nn .Sequential ( | |
| nn .Linear (hidden_size *2 ,hidden_size ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size ,1 ), | |
| nn .Sigmoid (), | |
| ) | |
| self .checkpoint_encoder =nn .GRU ( | |
| input_size =hidden_size , | |
| hidden_size =hidden_size , | |
| num_layers =1 , | |
| batch_first =True , | |
| ) | |
| self .smoother =nn .Sequential ( | |
| nn .Linear (hidden_size *2 ,hidden_size ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size ,hidden_size ), | |
| ) | |
| self .confidence_head =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size //4 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size //4 ,1 ), | |
| nn .Sigmoid (), | |
| ) | |
| print (f" ⚡ SpeculativeAudioDecoder: draft_length={draft_length }") | |
| def generate_draft ( | |
| self , | |
| context :torch .Tensor , | |
| num_tokens: Optional[int] = None, | |
| )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| """ | |
| Generate draft tokens speculatively. | |
| Args: | |
| context: [B, T, hidden_size] context features | |
| num_tokens: number of draft tokens (default: self.draft_length) | |
| Returns: | |
| draft_tokens: [B, N, hidden_size] draft features | |
| confidence: [B, N, 1] confidence per token | |
| """ | |
| if num_tokens is None : | |
| num_tokens =self .draft_length | |
| batch_size =context .shape [0 ] | |
| device =context .device | |
| seed =context [:,-1 :,:] | |
| draft_tokens =[] | |
| confidences =[] | |
| current =seed | |
| for _ in range (num_tokens ): | |
| draft =self .draft_head (current ) | |
| conf =self .confidence_head (draft ) | |
| draft_tokens .append (draft ) | |
| confidences .append (conf ) | |
| current =draft | |
| draft_tokens =torch .cat (draft_tokens ,dim =1 ) | |
| confidences =torch .cat (confidences ,dim =1 ) | |
| return draft_tokens ,confidences | |
| def verify_draft ( | |
| self , | |
| draft_tokens :torch .Tensor , | |
| new_context :torch .Tensor , | |
| )->torch .Tensor : | |
| """ | |
| Verify if draft tokens should be accepted given new context. | |
| Args: | |
| draft_tokens: [B, N, hidden_size] draft features | |
| new_context: [B, T, hidden_size] updated context | |
| Returns: | |
| accept_prob: [B, N, 1] probability to accept each token | |
| """ | |
| context_summary =new_context .mean (dim =1 ,keepdim =True ).expand (-1 ,draft_tokens .shape [1 ],-1 ) | |
| combined =torch .cat ([draft_tokens ,context_summary ],dim =-1 ) | |
| accept_prob =self .verify_head (combined ) | |
| return accept_prob | |
| def create_checkpoint (self ,hidden_state :torch .Tensor )->torch .Tensor : | |
| """Save hidden state for potential rollback.""" | |
| _ ,checkpoint =self .checkpoint_encoder (hidden_state ) | |
| return checkpoint .squeeze (0 ) | |
| def smooth_transition ( | |
| self , | |
| old_features :torch .Tensor , | |
| new_features :torch .Tensor , | |
| )->torch .Tensor : | |
| """Create smooth transition between old and new features.""" | |
| combined =torch .cat ([old_features ,new_features ],dim =-1 ) | |
| return self .smoother (combined ) | |
| def forward ( | |
| self , | |
| context :torch .Tensor , | |
| generate_draft :bool =True , | |
| verify_with :Optional [torch .Tensor ]=None , | |
| )->dict : | |
| """ | |
| Full speculative decoding step. | |
| Args: | |
| context: [B, T, hidden_size] current context | |
| generate_draft: whether to generate new draft | |
| verify_with: [B, T', hidden_size] new context to verify against | |
| Returns: | |
| dict with draft tokens, confidence, verification results | |
| """ | |
| results ={} | |
| results ["checkpoint"]=self .create_checkpoint (context ) | |
| if generate_draft : | |
| draft ,confidence =self .generate_draft (context ) | |
| results ["draft_tokens"]=draft | |
| results ["confidence"]=confidence | |
| if verify_with is not None and "draft_tokens"in results : | |
| accept_prob =self .verify_draft (results ["draft_tokens"],verify_with ) | |
| results ["accept_prob"]=accept_prob | |
| return results | |
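| # Usage sketch (illustrative): draft, then verify against updated context; | |
| # low accept_prob tokens can be rolled back to the stored checkpoint: | |
| # | |
| #     spec = SpeculativeAudioDecoder(hidden_size=1024, draft_length=10) | |
| #     out = spec(torch.randn(1, 20, 1024)) | |
| #     out["draft_tokens"].shape   # [1, 10, 1024] | |
| #     accept = spec.verify_draft(out["draft_tokens"], torch.randn(1, 24, 1024)) | |
| #     accept.shape                # [1, 10, 1] | |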
| ============================================================================== | |
| MODELS.GENERATORS.IMAGE | |
| ============================================================================== | |
| EPS =1e-5 | |
| class RoPE2D (nn .Module ): | |
| """ | |
| 2D Rotary Position Embedding for flexible aspect ratios. | |
| Encodes (x, y) spatial positions for patch-based DiT. | |
| """ | |
| def __init__ (self ,dim :int ,max_height :int =128 ,max_width :int =128 ,base :float =10000.0 ): | |
| super ().__init__ () | |
| self .dim =dim | |
| self .max_height =max_height | |
| self .max_width =max_width | |
| self .base =base | |
| self .dim_x =dim //2 | |
| self .dim_y =dim -self .dim_x | |
| inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x )) | |
| inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y )) | |
| self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False ) | |
| self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| device =x .device | |
| dtype =x .dtype | |
| pos_x =torch .arange (width ,device =device ,dtype =torch .float32 ) | |
| pos_y =torch .arange (height ,device =device ,dtype =torch .float32 ) | |
| freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device )) | |
| freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device )) | |
| freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 ) | |
| freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 ) | |
| # Build the 2D tables by broadcasting instead of per-pixel Python loops | |
| # (same layout as before: x-axis freqs in the first dim_x channels, | |
| # y-axis freqs in the remaining dim_y channels). | |
| cos_2d = torch.cat([ | |
| freqs_x.cos().unsqueeze(0).expand(height, -1, -1), | |
| freqs_y.cos().unsqueeze(1).expand(-1, width, -1), | |
| ], dim=-1).to(dtype) | |
| sin_2d = torch.cat([ | |
| freqs_x.sin().unsqueeze(0).expand(height, -1, -1), | |
| freqs_y.sin().unsqueeze(1).expand(-1, width, -1), | |
| ], dim=-1).to(dtype) | |
| cos_2d =cos_2d .view (height *width ,self .dim ) | |
| sin_2d =sin_2d .view (height *width ,self .dim ) | |
| return cos_2d ,sin_2d | |
| def apply_rope_2d (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor : | |
| x1 =x [...,:x .shape [-1 ]//2 ] | |
| x2 =x [...,x .shape [-1 ]//2 :] | |
| rotated =torch .cat ((-x2 ,x1 ),dim =-1 ) | |
| return x *cos +rotated *sin | |
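| # Usage sketch (illustrative): the tables broadcast over batch and heads, | |
| # | |
| #     rope = RoPE2D(dim=64) | |
| #     q = torch.randn(1, 8, 16 * 16, 64)        # [B, heads, H*W, head_dim] | |
| #     cos, sin = rope(q, height=16, width=16)   # each [H*W, 64] | |
| #     q_rot = apply_rope_2d(q, cos[None, None], sin[None, None]) | |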
| class ImageExpert (nn .Module ): | |
| """Single expert for DiT with SwiGLU activation.""" | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ): | |
| super ().__init__ () | |
| self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) | |
| self .act_fn =nn .SiLU () | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x )) | |
| class ImageMoERouter (nn .Module ): | |
| """Router for Image MoE with spatial awareness.""" | |
| def __init__ (self ,hidden_size :int ,num_experts :int =4 ,top_k :int =2 ): | |
| super ().__init__ () | |
| self .num_experts =num_experts | |
| self .top_k =top_k | |
| self .norm =nn .LayerNorm (hidden_size ) | |
| self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) | |
| nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) | |
| def forward (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| x_norm =self .norm (x ) | |
| router_logits =self .gate (x_norm ) | |
| router_probs =F .softmax (router_logits ,dim =-1 ,dtype =x .dtype ) | |
| top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 ) | |
| top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS ) | |
| return top_k_probs ,top_k_indices | |
| class ImageMoELayer (nn .Module ): | |
| """MoE Layer for DiT with shared expert.""" | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ,num_experts :int =4 ,top_k :int =2 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_experts =num_experts | |
| self .top_k =top_k | |
| self .router =ImageMoERouter (hidden_size ,num_experts ,top_k ) | |
| self .experts =nn .ModuleList ([ | |
| ImageExpert (hidden_size ,intermediate_size ) | |
| for _ in range (num_experts ) | |
| ]) | |
| self .shared_expert =ImageExpert (hidden_size ,intermediate_size ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| batch_size ,seq_len ,hidden_size =x .shape | |
| x_flat =x .view (-1 ,hidden_size ) | |
| top_k_probs ,top_k_indices =self .router (x_flat ) | |
| output =torch .zeros_like (x_flat ) | |
| for expert_idx in range (self .num_experts ): | |
| expert =self .experts [expert_idx ] | |
| for k in range (self .top_k ): | |
| mask =(top_k_indices [:,k ]==expert_idx ) | |
| if mask .any (): | |
| expert_input =x_flat [mask ] | |
| expert_output =expert (expert_input ) | |
| weight =top_k_probs [mask ,k :k +1 ] | |
| output [mask ]=output [mask ]+weight *expert_output | |
| shared_output =self .shared_expert (x_flat ) | |
| output =output +shared_output | |
| return output .view (batch_size ,seq_len ,hidden_size ) | |
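| # Usage sketch (illustrative): each token mixes its renormalized top-k | |
| # experts plus the always-on shared expert; no auxiliary load-balancing | |
| # loss is computed in this layer: | |
| # | |
| #     moe = ImageMoELayer(hidden_size=512, intermediate_size=2048, | |
| #                         num_experts=4, top_k=2) | |
| #     moe(torch.randn(2, 256, 512)).shape   # [2, 256, 512] | |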
| class DualStreamSelfAttention (nn .Module ): | |
| """ | |
| Symmetric Dual-Stream Self-Attention (SD3/Flux-style). | |
| Two parallel streams with cross-stream information exchange. | |
| Uses PyTorch SDPA (scaled_dot_product_attention), which dispatches to | |
| FlashAttention-style kernels when available, for memory-efficient attention. | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_height :int =64 ,max_width :int =64 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_heads =num_heads | |
| self .head_dim =hidden_size //num_heads | |
| self .scale =self .head_dim **-0.5 | |
| self ._qk_scale =self .head_dim **-0.25 | |
| self .to_qkv_a =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) | |
| self .to_qkv_b =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) | |
| self .to_out_a =nn .Linear (hidden_size ,hidden_size ,bias =False ) | |
| self .to_out_b =nn .Linear (hidden_size ,hidden_size ,bias =False ) | |
| self .norm_a =nn .LayerNorm (hidden_size ) | |
| self .norm_b =nn .LayerNorm (hidden_size ) | |
| self .rope_2d =RoPE2D (self .head_dim ,max_height ,max_width ) | |
| def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| batch_size ,seq_len ,_ =x_a .shape | |
| x_a =self .norm_a (x_a ) | |
| x_b =self .norm_b (x_b ) | |
| qkv_a =self .to_qkv_a (x_a ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim ) | |
| qkv_b =self .to_qkv_b (x_b ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim ) | |
| q_a ,k_a ,v_a =qkv_a .unbind (dim =2 ) | |
| q_b ,k_b ,v_b =qkv_b .unbind (dim =2 ) | |
| cos ,sin =self .rope_2d (x_a ,height ,width ) | |
| cos =cos .unsqueeze (0 ).unsqueeze (1 ) | |
| sin =sin .unsqueeze (0 ).unsqueeze (1 ) | |
| q_a =q_a .transpose (1 ,2 ) | |
| k_a =k_a .transpose (1 ,2 ) | |
| v_a =v_a .transpose (1 ,2 ) | |
| q_b =q_b .transpose (1 ,2 ) | |
| k_b =k_b .transpose (1 ,2 ) | |
| v_b =v_b .transpose (1 ,2 ) | |
| q_a =apply_rope_2d (q_a ,cos ,sin ) | |
| k_a =apply_rope_2d (k_a ,cos ,sin ) | |
| q_b =apply_rope_2d (q_b ,cos ,sin ) | |
| k_b =apply_rope_2d (k_b ,cos ,sin ) | |
| k_combined =torch .cat ([k_a ,k_b ],dim =2 ) | |
| v_combined =torch .cat ([v_a ,v_b ],dim =2 ) | |
| out_a =F .scaled_dot_product_attention ( | |
| q_a *self ._qk_scale ,k_combined *self ._qk_scale ,v_combined , | |
| is_causal =False ,scale =1.0 , | |
| ) | |
| out_b =F .scaled_dot_product_attention ( | |
| q_b *self ._qk_scale ,k_combined *self ._qk_scale ,v_combined , | |
| is_causal =False ,scale =1.0 , | |
| ) | |
| out_a =out_a .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size ) | |
| out_b =out_b .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size ) | |
| out_a =self .to_out_a (out_a ) | |
| out_b =self .to_out_b (out_b ) | |
| return out_a ,out_b | |
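| # Usage sketch (illustrative): each stream's queries attend over the | |
| # concatenated keys/values of BOTH streams, which is where cross-stream | |
| # information exchange happens: | |
| # | |
| #     attn = DualStreamSelfAttention(hidden_size=512, num_heads=8) | |
| #     a, b = torch.randn(1, 256, 512), torch.randn(1, 256, 512) | |
| #     out_a, out_b = attn(a, b, height=16, width=16)   # each [1, 256, 512] | |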
| class CrossAttention (nn .Module ): | |
| """Cross-attention for text conditioning.""" | |
| def __init__ (self ,query_dim :int ,context_dim :int =None ,heads :int =8 ): | |
| super ().__init__ () | |
| self .heads =heads | |
| context_dim =context_dim or query_dim | |
| self .head_dim =query_dim //heads | |
| self .scale =self .head_dim **-0.5 | |
| self .norm =nn .LayerNorm (query_dim ) | |
| self .to_q =nn .Linear (query_dim ,query_dim ,bias =False ) | |
| self .to_k =nn .Linear (context_dim ,query_dim ,bias =False ) | |
| self .to_v =nn .Linear (context_dim ,query_dim ,bias =False ) | |
| self .to_out =nn .Linear (query_dim ,query_dim ,bias =False ) | |
| def forward (self ,x :torch .Tensor ,context :torch .Tensor )->torch .Tensor : | |
| batch_size ,seq_len ,_ =x .shape | |
| ctx_len =context .shape [1 ] | |
| x =self .norm (x ) | |
| q =self .to_q (x ).reshape (batch_size ,seq_len ,self .heads ,self .head_dim ).transpose (1 ,2 ) | |
| k =self .to_k (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 ) | |
| v =self .to_v (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 ) | |
| qk_scale =self .head_dim **-0.25 | |
| out =F .scaled_dot_product_attention ( | |
| q *qk_scale ,k *qk_scale ,v , | |
| is_causal =False ,scale =1.0 , | |
| ) | |
| out =out .transpose (1 ,2 ).reshape (batch_size ,seq_len ,-1 ) | |
| out =self .to_out (out ) | |
| return out | |
| class DiTBlock (nn .Module ): | |
| """ | |
| DiT Block with Dual-Stream Attention and MoE FFN. | |
| """ | |
| def __init__ (self ,hidden_size :int ,context_dim :int ,num_heads :int =8 ,num_experts :int =4 ,max_height :int =64 ,max_width :int =64 ): | |
| super ().__init__ () | |
| self .dual_attn =DualStreamSelfAttention (hidden_size ,num_heads ,max_height ,max_width ) | |
| self .cross_attn_a =CrossAttention (hidden_size ,context_dim ,num_heads ) | |
| self .cross_attn_b =CrossAttention (hidden_size ,context_dim ,num_heads ) | |
| self .moe_a =ImageMoELayer (hidden_size ,hidden_size *4 ,num_experts ) | |
| self .moe_b =ImageMoELayer (hidden_size ,hidden_size *4 ,num_experts ) | |
| self .adaLN_a =nn .Sequential ( | |
| nn .SiLU (), | |
| nn .Linear (hidden_size ,hidden_size *6 ), | |
| ) | |
| self .adaLN_b =nn .Sequential ( | |
| nn .SiLU (), | |
| nn .Linear (hidden_size ,hidden_size *6 ), | |
| ) | |
| self .norm1_a =nn .LayerNorm (hidden_size ,elementwise_affine =False ) | |
| self .norm1_b =nn .LayerNorm (hidden_size ,elementwise_affine =False ) | |
| self .norm2_a =nn .LayerNorm (hidden_size ,elementwise_affine =False ) | |
| self .norm2_b =nn .LayerNorm (hidden_size ,elementwise_affine =False ) | |
| def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,context :torch .Tensor ,t_emb :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| shift_a ,scale_a ,gate_a ,shift2_a ,scale2_a ,gate2_a =self .adaLN_a (t_emb ).chunk (6 ,dim =-1 ) | |
| shift_b ,scale_b ,gate_b ,shift2_b ,scale2_b ,gate2_b =self .adaLN_b (t_emb ).chunk (6 ,dim =-1 ) | |
| shift_a =shift_a .unsqueeze (1 ) | |
| scale_a =scale_a .unsqueeze (1 ) | |
| gate_a =gate_a .unsqueeze (1 ) | |
| shift2_a =shift2_a .unsqueeze (1 ) | |
| scale2_a =scale2_a .unsqueeze (1 ) | |
| gate2_a =gate2_a .unsqueeze (1 ) | |
| shift_b =shift_b .unsqueeze (1 ) | |
| scale_b =scale_b .unsqueeze (1 ) | |
| gate_b =gate_b .unsqueeze (1 ) | |
| shift2_b =shift2_b .unsqueeze (1 ) | |
| scale2_b =scale2_b .unsqueeze (1 ) | |
| gate2_b =gate2_b .unsqueeze (1 ) | |
| x_a_norm =self .norm1_a (x_a )*(1 +scale_a )+shift_a | |
| x_b_norm =self .norm1_b (x_b )*(1 +scale_b )+shift_b | |
| attn_out_a ,attn_out_b =self .dual_attn (x_a_norm ,x_b_norm ,height ,width ) | |
| x_a =x_a +gate_a *attn_out_a | |
| x_b =x_b +gate_b *attn_out_b | |
| x_a =x_a +self .cross_attn_a (x_a ,context ) | |
| x_b =x_b +self .cross_attn_b (x_b ,context ) | |
| x_a_norm =self .norm2_a (x_a )*(1 +scale2_a )+shift2_a | |
| x_b_norm =self .norm2_b (x_b )*(1 +scale2_b )+shift2_b | |
| x_a =x_a +gate2_a *self .moe_a (x_a_norm ) | |
| x_b =x_b +gate2_b *self .moe_b (x_b_norm ) | |
| return x_a ,x_b | |
| class FlowMatchingScheduler : | |
| """Flow Matching scheduler for image generation.""" | |
| def __init__ (self ,num_steps :int =50 ,sigma_min :float =0.002 ): | |
| self .num_steps =num_steps | |
| self .sigma_min =sigma_min | |
| self .timesteps =torch .linspace (1 ,0 ,num_steps +1 ) | |
| def get_velocity (self ,x_t :torch .Tensor ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor : | |
| return x_0 -x_t | |
| def step (self ,model_output :torch .Tensor ,t :torch .Tensor ,t_prev :torch .Tensor ,x_t :torch .Tensor )->torch .Tensor : | |
| dt =t -t_prev | |
| x_prev =x_t +model_output *dt .view (-1 ,1 ,1 ,1 ) | |
| return x_prev | |
| def add_noise (self ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor : | |
| noise =torch .randn_like (x_0 ) | |
| t =t .to (x_0 .dtype ).view (-1 ,1 ,1 ,1 ) | |
| x_t =t *noise +(1 -t )*x_0 | |
| return x_t | |
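# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). The training-time round trip for the scheduler above,
# assuming 4-channel image latents: interpolate toward noise, regress the
# velocity, then take one Euler step of the reverse ODE. Defined as a function
# so importing this module stays side-effect free.
def _example_flow_matching_roundtrip():
    sched = FlowMatchingScheduler(num_steps=4)
    x_0 = torch.randn(2, 4, 32, 32)           # clean latents
    t = torch.full((2,), 0.5)                 # per-sample timesteps in [0, 1]
    x_t = sched.add_noise(x_0, t)             # x_t = t * noise + (1 - t) * x_0
    v = sched.get_velocity(x_t, x_0, t)       # regression target: x_0 - x_t
    x_prev = sched.step(v, t, t - 0.25, x_t)  # integrate from t=0.5 to t=0.25
    return x_prev.shape                       # torch.Size([2, 4, 32, 32])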
| class PatchEmbed (nn .Module ): | |
| """Patch embedding for DiT.""" | |
| def __init__ (self ,patch_size :int =2 ,in_channels :int =4 ,hidden_size :int =512 ): | |
| super ().__init__ () | |
| self .patch_size =patch_size | |
| self .proj =nn .Conv2d (in_channels ,hidden_size ,kernel_size =patch_size ,stride =patch_size ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| x =self .proj (x ) | |
| x =x .flatten (2 ).transpose (1 ,2 ) | |
| return x | |
| class UnpatchEmbed (nn .Module ): | |
| """Unpatch embedding to reconstruct image from patches.""" | |
| def __init__ (self ,patch_size :int =2 ,out_channels :int =4 ,hidden_size :int =512 ): | |
| super ().__init__ () | |
| self .patch_size =patch_size | |
| self .out_channels =out_channels | |
| self .proj =nn .Linear (hidden_size ,patch_size *patch_size *out_channels ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int )->torch .Tensor : | |
| x =self .proj (x ) | |
| batch_size =x .shape [0 ] | |
| x =x .reshape (batch_size ,height ,width ,self .patch_size ,self .patch_size ,self .out_channels ) | |
| x =x .permute (0 ,5 ,1 ,3 ,2 ,4 ).reshape (batch_size ,self .out_channels ,height *self .patch_size ,width *self .patch_size ) | |
| return x | |
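# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). Shape round trip through the two modules above: patchify a
# 32x32 latent into a 16x16 grid of 256 tokens, then reassemble it. Note that
# UnpatchEmbed.forward takes the patch-grid size (H/patch_size, W/patch_size),
# not the pixel size.
def _example_patchify_roundtrip():
    patch, unpatch = PatchEmbed(2, 4, 64), UnpatchEmbed(2, 4, 64)
    x = torch.randn(1, 4, 32, 32)
    tokens = patch(x)                 # [1, 256, 64]
    x_rec = unpatch(tokens, 16, 16)   # [1, 4, 32, 32]
    return tokens.shape, x_rec.shape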
| class MoEDiT (nn .Module ): | |
| """ | |
| MoE Diffusion Transformer with Dual-Stream Attention. | |
| """ | |
| def __init__ ( | |
| self , | |
| in_channels :int =4 , | |
| out_channels :int =4 , | |
| hidden_size :int =512 , | |
| context_dim :int =1024 , | |
| num_layers :int =8 , | |
| num_heads :int =8 , | |
| num_experts :int =4 , | |
| patch_size :int =2 , | |
| max_image_size :int =64 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .patch_size =patch_size | |
| max_patches =max_image_size //patch_size | |
| self .time_embed =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size *4 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size *4 ,hidden_size ), | |
| ) | |
| self .patch_embed =PatchEmbed (patch_size ,in_channels ,hidden_size ) | |
| self .context_proj =nn .Linear (context_dim ,hidden_size ) | |
| self .blocks =nn .ModuleList ([ | |
| DiTBlock (hidden_size ,hidden_size ,num_heads ,num_experts ,max_patches ,max_patches ) | |
| for _ in range (num_layers ) | |
| ]) | |
| self .final_norm =nn .LayerNorm (hidden_size ) | |
| self .unpatch_embed =UnpatchEmbed (patch_size ,out_channels ,hidden_size ) | |
| self .gradient_checkpointing =False | |
| self ._init_weights () | |
| def _init_weights (self ): | |
| nn .init .zeros_ (self .unpatch_embed .proj .weight ) | |
| nn .init .zeros_ (self .unpatch_embed .proj .bias ) | |
| def enable_gradient_checkpointing (self ): | |
| """Enable gradient checkpointing for memory efficiency.""" | |
| self .gradient_checkpointing =True | |
| def forward (self ,x :torch .Tensor ,timesteps :torch .Tensor ,context :torch .Tensor ,mask :Optional [torch .Tensor ]=None )->torch .Tensor : | |
| batch_size ,channels ,height ,width =x .shape | |
| patch_height =height //self .patch_size | |
| patch_width =width //self .patch_size | |
| half_dim =self .hidden_size //2 | |
| t_emb =math .log (10000 )/(half_dim -1 ) | |
| t_emb =torch .exp (torch .arange (half_dim ,device =x .device ,dtype =x .dtype )*-t_emb ) | |
| t_emb =timesteps [:,None ].to (x .dtype )*t_emb [None ,:] | |
| t_emb =torch .cat ([torch .sin (t_emb ),torch .cos (t_emb )],dim =-1 ) | |
| t_emb =self .time_embed (t_emb ) | |
| x_patches =self .patch_embed (x ) | |
| context_proj =self .context_proj (context ) | |
| x_a =x_patches | |
| x_b =x_patches .clone () | |
| for block in self .blocks : | |
| if self .gradient_checkpointing and self .training : | |
| x_a ,x_b =torch .utils .checkpoint .checkpoint ( | |
| block ,x_a ,x_b ,context_proj ,t_emb ,patch_height ,patch_width , | |
| use_reentrant =False | |
| ) | |
| else : | |
| x_a ,x_b =block (x_a ,x_b ,context_proj ,t_emb ,patch_height ,patch_width ) | |
| x_combined =(x_a +x_b )/2 | |
| x_combined =self .final_norm (x_combined ) | |
| velocity =self .unpatch_embed (x_combined ,patch_height ,patch_width ) | |
| return velocity | |
| class ImageVAE (nn .Module ): | |
| """Lightweight VAE for image encoding/decoding.""" | |
| def __init__ (self ,in_channels :int =3 ,latent_channels :int =4 ,base_channels :int =64 ): | |
| super ().__init__ () | |
| self .encoder =nn .Sequential ( | |
| nn .Conv2d (in_channels ,base_channels ,3 ,padding =1 ), | |
| nn .SiLU (), | |
| nn .Conv2d (base_channels ,base_channels *2 ,3 ,stride =2 ,padding =1 ), | |
| nn .SiLU (), | |
| nn .Conv2d (base_channels *2 ,base_channels *4 ,3 ,stride =2 ,padding =1 ), | |
| nn .SiLU (), | |
| nn .Conv2d (base_channels *4 ,latent_channels *2 ,3 ,padding =1 ), | |
| ) | |
| self .decoder =nn .Sequential ( | |
| nn .Conv2d (latent_channels ,base_channels *4 ,3 ,padding =1 ), | |
| nn .SiLU (), | |
| nn .Upsample (scale_factor =2 ,mode ='bilinear',align_corners =False ), | |
| nn .Conv2d (base_channels *4 ,base_channels *2 ,3 ,padding =1 ), | |
| nn .SiLU (), | |
| nn .Upsample (scale_factor =2 ,mode ='bilinear',align_corners =False ), | |
| nn .Conv2d (base_channels *2 ,base_channels ,3 ,padding =1 ), | |
| nn .SiLU (), | |
| nn .Conv2d (base_channels ,in_channels ,3 ,padding =1 ), | |
| ) | |
| def encode (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: | |
| h =self .encoder (x ) | |
| mean ,logvar =h .chunk (2 ,dim =1 ) | |
| logvar =torch .clamp (logvar ,-30 ,20 ) | |
| std =torch .exp (0.5 *logvar ) | |
| z =mean +std *torch .randn_like (std ) | |
| return z ,mean ,logvar | |
| def decode (self ,z :torch .Tensor )->torch .Tensor : | |
| return self .decoder (z ) | |
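# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). The encoder above uses two stride-2 convolutions, so images
# are compressed 4x spatially into `latent_channels` channels; decode()
# mirrors this with two bilinear upsamples. A quick shape check assuming
# 3-channel 64x64 inputs:
def _example_image_vae_shapes():
    vae = ImageVAE(in_channels=3, latent_channels=4, base_channels=32)
    x = torch.randn(2, 3, 64, 64)
    z, mean, logvar = vae.encode(x)   # each [2, 4, 16, 16]
    x_rec = vae.decode(z)             # [2, 3, 64, 64]
    return z.shape, x_rec.shape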
| class MobileDiffusionGenerator (nn .Module ): | |
| """ | |
| SOTA Image Diffusion with MoE-DiT, Flow Matching, 2D-RoPE, Dual-Stream. | |
| Optimized for 2x T4 GPUs with FP16. | |
| """ | |
| def __init__ ( | |
| self , | |
| latent_channels :int =4 , | |
| base_channels :int =128 , | |
| context_dim :int =1024 , | |
| num_inference_steps :int =50 , | |
| image_size :int =256 , | |
| cfg_scale :float =7.5 , | |
| ): | |
| super ().__init__ () | |
| self .latent_channels =latent_channels | |
| self .context_dim =context_dim | |
| self .image_size =image_size | |
| self .latent_size =image_size //4 | |
| self .num_inference_steps =num_inference_steps | |
| self .cfg_scale =cfg_scale | |
| self .vae_encoder =ImageVAE (3 ,latent_channels ,base_channels //2 ) | |
| self .vae_decoder =self .vae_encoder | |
| self .unet =MoEDiT ( | |
| in_channels =latent_channels , | |
| out_channels =latent_channels , | |
| hidden_size =base_channels *4 , | |
| context_dim =context_dim , | |
| num_layers =8 , | |
| num_heads =8 , | |
| num_experts =4 , | |
| patch_size =2 , | |
| max_image_size =self .latent_size , | |
| ) | |
| self .scheduler =FlowMatchingScheduler (num_inference_steps ) | |
| def encode (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: | |
| return self .vae_encoder .encode (x ) | |
| def decode (self ,z :torch .Tensor )->torch .Tensor : | |
| return self .vae_decoder .decode (z ) | |
| def training_step (self ,images :torch .Tensor ,context :torch .Tensor ,mask :Optional [torch .Tensor ]=None )->dict : | |
| device =images .device | |
| dtype =images .dtype | |
| batch_size =images .shape [0 ] | |
| z ,mean ,logvar =self .encode (images *2 -1 ) | |
| del images | |
| t =torch .rand (batch_size ,device =device ,dtype =dtype ) | |
| x_t =self .scheduler .add_noise (z ,t ) | |
| target_velocity =self .scheduler .get_velocity (x_t ,z ,t ) | |
| if self .training : | |
| drop_mask =torch .rand (batch_size ,device =device )<0.1 | |
| drop_mask_expanded =drop_mask .view (batch_size ,1 ,1 ).expand_as (context ) | |
| null_ctx =torch .zeros_like (context ) | |
| context =torch .where (drop_mask_expanded ,null_ctx ,context ) | |
| del drop_mask ,drop_mask_expanded ,null_ctx | |
| pred_velocity =self .unet (x_t ,(t *1000 ).to (dtype ),context ,mask ) | |
| del x_t ,context | |
| flow_loss =F .mse_loss (pred_velocity ,target_velocity ) | |
| del pred_velocity ,target_velocity | |
| kl_loss =-0.5 *torch .mean (1 +logvar -mean .pow (2 )-logvar .exp ()) | |
| del z ,mean ,logvar | |
| total_loss =flow_loss +0.0001 *kl_loss | |
| return { | |
| 'flow_loss':flow_loss , | |
| 'kl_loss':kl_loss , | |
| 'total_loss':total_loss , | |
| } | |
| def generate (self ,context :torch .Tensor ,guidance_scale :float =None ,num_steps :int =None ,init_latents :Optional [torch .Tensor ]=None ,mask :Optional [torch .Tensor ]=None ,masked_image_latents :Optional [torch .Tensor ]=None )->torch .Tensor : | |
| device =context .device | |
| batch_size =context .shape [0 ] | |
| seq_len =context .shape [1 ] | |
| guidance_scale =guidance_scale or self .cfg_scale | |
| num_steps =num_steps or self .num_inference_steps | |
        if init_latents is not None:
            latents = init_latents
        else:
            # Match the conditioning dtype so FP16 inference stays in one precision.
            latents = torch.randn(
                batch_size, self.latent_channels, self.latent_size, self.latent_size,
                device=device, dtype=context.dtype,
            )
| timesteps =torch .linspace (1 ,0 ,num_steps +1 ,device =device ) | |
| if guidance_scale >1.0 : | |
| null_ctx =torch .zeros (batch_size ,seq_len ,self .context_dim ,device =device ,dtype =context .dtype ) | |
| context =torch .cat ([null_ctx ,context ]) | |
| for i in range (num_steps ): | |
| t =timesteps [i ] | |
| t_prev =timesteps [i +1 ] | |
| t_batch =t .expand (batch_size )*1000 | |
| if guidance_scale >1.0 : | |
| latent_input =torch .cat ([latents ,latents ]) | |
| t_input =torch .cat ([t_batch ,t_batch ]) | |
| velocity_pred =self .unet (latent_input ,t_input ,context ,mask ) | |
| velocity_uncond ,velocity_cond =velocity_pred .chunk (2 ) | |
| velocity_pred =velocity_uncond +guidance_scale *(velocity_cond -velocity_uncond ) | |
| else : | |
| velocity_pred =self .unet (latents ,t_batch ,context ,mask ) | |
| latents =self .scheduler .step (velocity_pred ,t ,t_prev ,latents ) | |
| if mask is not None and masked_image_latents is not None : | |
| latents =masked_image_latents *mask +latents *(1 -mask ) | |
| images =self .decode (latents ) | |
| images =(images +1 )/2 | |
| return torch .clamp (images ,0 ,1 ) | |
| def edit_image (self ,image :torch .Tensor ,context :torch .Tensor ,mask :torch .Tensor ,strength :float =0.8 ,guidance_scale :float =None )->torch .Tensor : | |
| device =image .device | |
| image_norm =image *2 -1 | |
| z ,_ ,_ =self .encode (image_norm ) | |
| mask_latent =F .interpolate (mask ,size =(self .latent_size ,self .latent_size ),mode ='nearest') | |
        num_steps = int(self.num_inference_steps * strength)
        # NOTE: generate() always integrates from t=1, while these latents are
        # noised only to level `strength`, so partial-strength edits are an
        # approximation of true SDEdit-style partial denoising.
        t = torch.tensor([strength], device=device)
        noisy_z = self.scheduler.add_noise(z, t.expand(z.shape[0]))
| return self .generate ( | |
| context , | |
| guidance_scale =guidance_scale , | |
| num_steps =num_steps , | |
| init_latents =noisy_z , | |
| mask =mask_latent , | |
| masked_image_latents =z , | |
| ) | |
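# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). End-to-end text-to-image call with classifier-free guidance,
# using a deliberately tiny configuration; `context` stands in for the
# text-conditioning hidden states of width `context_dim` that real callers
# pass from the LLM.
def _example_text_to_image():
    gen = MobileDiffusionGenerator(base_channels=32, context_dim=64, image_size=64, num_inference_steps=2)
    gen.eval()
    context = torch.randn(1, 8, 64)   # [batch, text_len, context_dim]
    with torch.no_grad():
        images = gen.generate(context, guidance_scale=3.0)  # [1, 3, 64, 64] in [0, 1]
    return images.shape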
# ==============================================================================
# MODELS.GENERATORS.VIDEO
# ==============================================================================
| EPS =1e-5 | |
| class InterleavedMRoPE (nn .Module ): | |
| """ | |
| Interleaved Multi-dimensional Rotary Position Embedding (MRoPE). | |
    Allocates the full rotary frequency spectrum over time, height, and width.
    Unlike separate spatial and temporal RoPE, Interleaved-MRoPE interleaves
    frequencies across all three dimensions within each head, enhancing
    long-horizon video reasoning.
| Key advantages: | |
| - Better temporal-spatial correlation modeling | |
| - More robust for variable aspect ratios and frame counts | |
| - Improved long-range video understanding | |
| """ | |
| def __init__ (self ,dim :int ,max_height :int =64 ,max_width :int =64 ,max_frames :int =64 ,base :float =10000.0 ): | |
| super ().__init__ () | |
| self .dim =dim | |
| self .max_height =max_height | |
| self .max_width =max_width | |
| self .max_frames =max_frames | |
| self .base =base | |
| self .dim_t =dim //3 | |
| self .dim_y =dim //3 | |
| self .dim_x =dim -self .dim_t -self .dim_y | |
| inv_freq_t =1.0 /(base **(torch .arange (0 ,self .dim_t ,2 ,dtype =torch .float32 )/self .dim_t )) | |
| inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y )) | |
| inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x )) | |
| self .register_buffer ('inv_freq_t',inv_freq_t ,persistent =False ) | |
| self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False ) | |
| self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int ,num_frames :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| """ | |
| Compute interleaved 3D positional embeddings. | |
| Args: | |
| x: Input tensor for device/dtype reference | |
| height: Spatial height | |
| width: Spatial width | |
| num_frames: Temporal frames | |
| Returns: | |
| cos, sin: [T * H * W, dim] positional embeddings | |
| """ | |
| device =x .device | |
| dtype =x .dtype | |
| pos_t =torch .arange (num_frames ,device =device ,dtype =torch .float32 ) | |
| pos_y =torch .arange (height ,device =device ,dtype =torch .float32 ) | |
| pos_x =torch .arange (width ,device =device ,dtype =torch .float32 ) | |
| freqs_t =torch .outer (pos_t ,self .inv_freq_t .to (device )) | |
| freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device )) | |
| freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device )) | |
| freqs_t =torch .cat ([freqs_t ,freqs_t ],dim =-1 ) | |
| freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 ) | |
| freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 ) | |
| seq_len =num_frames *height *width | |
        cos_3d = torch.zeros(num_frames, height, width, self.dim, device=device, dtype=dtype)
        sin_3d = torch.zeros(num_frames, height, width, self.dim, device=device, dtype=dtype)
        # Fill each axis' slice of the channel dimension by broadcasting rather
        # than iterating over every (t, h, w) position in Python loops.
        cos_3d[..., :self.dim_t] = freqs_t.cos().to(dtype).view(num_frames, 1, 1, -1)
        sin_3d[..., :self.dim_t] = freqs_t.sin().to(dtype).view(num_frames, 1, 1, -1)
        cos_3d[..., self.dim_t:self.dim_t + self.dim_y] = freqs_y.cos().to(dtype).view(1, height, 1, -1)
        sin_3d[..., self.dim_t:self.dim_t + self.dim_y] = freqs_y.sin().to(dtype).view(1, height, 1, -1)
        cos_3d[..., self.dim_t + self.dim_y:] = freqs_x.cos().to(dtype).view(1, 1, width, -1)
        sin_3d[..., self.dim_t + self.dim_y:] = freqs_x.sin().to(dtype).view(1, 1, width, -1)
| cos_3d =cos_3d .view (seq_len ,self .dim ) | |
| sin_3d =sin_3d .view (seq_len ,self .dim ) | |
| return cos_3d ,sin_3d | |
| class RoPE2D (nn .Module ): | |
| """ | |
| 2D Rotary Position Embedding for spatial dimensions (memory efficient). | |
| Used for spatial attention in factorized video attention. | |
| """ | |
| def __init__ (self ,dim :int ,max_height :int =64 ,max_width :int =64 ,base :float =10000.0 ): | |
| super ().__init__ () | |
| self .dim =dim | |
| self .dim_x =dim //2 | |
| self .dim_y =dim -self .dim_x | |
| inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x )) | |
| inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y )) | |
| self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False ) | |
| self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| device =x .device | |
| dtype =x .dtype | |
| pos_x =torch .arange (width ,device =device ,dtype =torch .float32 ) | |
| pos_y =torch .arange (height ,device =device ,dtype =torch .float32 ) | |
| freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device )) | |
| freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device )) | |
| cos_x =torch .cat ([freqs_x .cos (),freqs_x .cos ()],dim =-1 ) | |
| sin_x =torch .cat ([freqs_x .sin (),freqs_x .sin ()],dim =-1 ) | |
| cos_y =torch .cat ([freqs_y .cos (),freqs_y .cos ()],dim =-1 ) | |
| sin_y =torch .cat ([freqs_y .sin (),freqs_y .sin ()],dim =-1 ) | |
| cos_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype ) | |
| sin_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype ) | |
| cos_2d [:,:,:self .dim_x ]=cos_x .unsqueeze (0 ).expand (height ,-1 ,-1 ) | |
| sin_2d [:,:,:self .dim_x ]=sin_x .unsqueeze (0 ).expand (height ,-1 ,-1 ) | |
| cos_2d [:,:,self .dim_x :]=cos_y .unsqueeze (1 ).expand (-1 ,width ,-1 ) | |
| sin_2d [:,:,self .dim_x :]=sin_y .unsqueeze (1 ).expand (-1 ,width ,-1 ) | |
| return cos_2d .view (height *width ,self .dim ).to (dtype ),sin_2d .view (height *width ,self .dim ).to (dtype ) | |
| class RoPE1D (nn .Module ): | |
| """ | |
| 1D Rotary Position Embedding for temporal dimension. | |
| Used for temporal attention in factorized video attention. | |
| """ | |
| def __init__ (self ,dim :int ,max_len :int =64 ,base :float =10000.0 ): | |
| super ().__init__ () | |
| self .dim =dim | |
| inv_freq =1.0 /(base **(torch .arange (0 ,dim ,2 ,dtype =torch .float32 )/dim )) | |
| self .register_buffer ('inv_freq',inv_freq ,persistent =False ) | |
| def forward (self ,x :torch .Tensor ,seq_len :int )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| device =x .device | |
| dtype =x .dtype | |
| pos =torch .arange (seq_len ,device =device ,dtype =torch .float32 ) | |
| freqs =torch .outer (pos ,self .inv_freq .to (device )) | |
| freqs =torch .cat ([freqs ,freqs ],dim =-1 ) | |
| return freqs .cos ().to (dtype ),freqs .sin ().to (dtype ) | |
| def apply_rope (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor : | |
| """Apply rotary position embedding.""" | |
| x1 =x [...,:x .shape [-1 ]//2 ] | |
| x2 =x [...,x .shape [-1 ]//2 :] | |
| rotated =torch .cat ((-x2 ,x1 ),dim =-1 ) | |
| return x *cos +rotated *sin | |
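# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). apply_rope implements the standard rotate-half formulation:
# rotating queries and keys by the same per-position angles preserves their
# norms and makes the attention dot product depend only on relative position.
# A quick norm-preservation check:
def _example_apply_rope():
    rope = RoPE1D(dim=8)
    x = torch.randn(1, 2, 5, 8)       # [batch, heads, seq, head_dim]
    cos, sin = rope(x, seq_len=5)     # each [5, 8]
    x_rot = apply_rope(x, cos, sin)
    return torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5)  # True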
| class TemporalExpertRouter (nn .Module ): | |
| """ | |
| Temporal-Aware Expert Router for video generation. | |
| Routes tokens based on temporal context and motion patterns. | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_experts :int =4 ,top_k :int =2 ): | |
| super ().__init__ () | |
| self .num_experts =num_experts | |
| self .top_k =top_k | |
| self .temporal_proj =nn .Linear (hidden_size ,hidden_size ) | |
| self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) | |
| nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) | |
| def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| if temporal_context is not None : | |
| x =x +self .temporal_proj (temporal_context ) | |
| router_logits =self .gate (x ) | |
| router_probs =F .softmax (router_logits ,dim =-1 ,dtype =x .dtype ) | |
| top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 ) | |
| top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS ) | |
| return top_k_probs ,top_k_indices | |
| class VideoExpert (nn .Module ): | |
| """Single expert for video processing with SwiGLU.""" | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ): | |
| super ().__init__ () | |
| self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) | |
| self .act_fn =nn .SiLU () | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x )) | |
| class TemporalMoELayer (nn .Module ): | |
| """ | |
| Temporal-Aware MoE Layer for video generation. | |
| Uses motion-aware routing for expert selection. | |
| """ | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ,num_experts :int =4 ,top_k :int =2 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_experts =num_experts | |
| self .top_k =top_k | |
| self .router =TemporalExpertRouter (hidden_size ,num_experts ,top_k ) | |
| self .experts =nn .ModuleList ([ | |
| VideoExpert (hidden_size ,intermediate_size ) | |
| for _ in range (num_experts ) | |
| ]) | |
| self .shared_expert =VideoExpert (hidden_size ,intermediate_size ) | |
| def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->torch .Tensor : | |
| batch_size ,seq_len ,hidden_size =x .shape | |
| x_flat =x .view (-1 ,hidden_size ) | |
| top_k_probs ,top_k_indices =self .router (x_flat ,temporal_context .view (-1 ,hidden_size )if temporal_context is not None else None ) | |
| output =torch .zeros_like (x_flat ) | |
| for expert_idx in range (self .num_experts ): | |
| expert =self .experts [expert_idx ] | |
| for k in range (self .top_k ): | |
| mask =(top_k_indices [:,k ]==expert_idx ) | |
| if mask .any (): | |
| expert_input =x_flat [mask ] | |
| expert_output =expert (expert_input ) | |
| weight =top_k_probs [mask ,k :k +1 ] | |
| output [mask ]=output [mask ]+weight *expert_output | |
| shared_output =self .shared_expert (x_flat ) | |
| output =output +shared_output | |
| return output .view (batch_size ,seq_len ,hidden_size ) | |
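# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). Routing behaviour of the layer above: every token is
# dispatched to its top-2 experts with renormalized gate weights, the shared
# expert is always added, and the output shape matches the input exactly.
def _example_temporal_moe():
    moe = TemporalMoELayer(hidden_size=32, intermediate_size=64, num_experts=4, top_k=2)
    x = torch.randn(2, 10, 32)                # [batch, tokens, hidden]
    probs, idx = moe.router(x.view(-1, 32))   # [20, 2] gate weights and expert ids
    out = moe(x)                               # [2, 10, 32]
    return probs.sum(dim=-1), out.shape        # gate weights sum to ~1 per token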
| class SpatialAttention (nn .Module ): | |
| """ | |
| Spatial self-attention: each frame attends only within itself. | |
| Memory: O(T * (H*W)^2) instead of O((T*H*W)^2) | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_height :int =64 ,max_width :int =64 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_heads =num_heads | |
| self .head_dim =hidden_size //num_heads | |
| self .scale =self .head_dim **-0.5 | |
| self .to_qkv =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) | |
| self .to_out =nn .Linear (hidden_size ,hidden_size ,bias =False ) | |
| self .rope_2d =RoPE2D (self .head_dim ,max_height ,max_width ) | |
| self .norm =nn .LayerNorm (hidden_size ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int )->torch .Tensor : | |
| batch_size ,seq_len ,_ =x .shape | |
| spatial_len =height *width | |
| x =self .norm (x ) | |
| x =x .view (batch_size *frames ,spatial_len ,self .hidden_size ) | |
| qkv =self .to_qkv (x ).reshape (batch_size *frames ,spatial_len ,3 ,self .num_heads ,self .head_dim ) | |
| q ,k ,v =qkv .unbind (dim =2 ) | |
| cos ,sin =self .rope_2d (x ,height ,width ) | |
| cos =cos .unsqueeze (0 ).unsqueeze (1 ) | |
| sin =sin .unsqueeze (0 ).unsqueeze (1 ) | |
| q =q .transpose (1 ,2 ) | |
| k =k .transpose (1 ,2 ) | |
| v =v .transpose (1 ,2 ) | |
| q =apply_rope (q ,cos ,sin ) | |
| k =apply_rope (k ,cos ,sin ) | |
| qk_scale =self .head_dim **-0.25 | |
| out =F .scaled_dot_product_attention ( | |
| q *qk_scale ,k *qk_scale ,v , | |
| is_causal =False ,scale =1.0 , | |
| ) | |
| out =out .transpose (1 ,2 ).reshape (batch_size *frames ,spatial_len ,self .hidden_size ) | |
| out =self .to_out (out ) | |
| return out .view (batch_size ,seq_len ,self .hidden_size ) | |
| class TemporalAttention (nn .Module ): | |
| """ | |
| Temporal self-attention: each spatial position attends across time. | |
| Memory: O(H*W * T^2) instead of O((T*H*W)^2) | |
| """ | |
| def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_frames :int =32 ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_heads =num_heads | |
| self .head_dim =hidden_size //num_heads | |
| self .scale =self .head_dim **-0.5 | |
| self .to_qkv =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) | |
| self .to_out =nn .Linear (hidden_size ,hidden_size ,bias =False ) | |
| self .rope_1d =RoPE1D (self .head_dim ,max_frames ) | |
| self .norm =nn .LayerNorm (hidden_size ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =True )->torch .Tensor : | |
| batch_size ,seq_len ,_ =x .shape | |
| spatial_len =height *width | |
| x =self .norm (x ) | |
| x =x .view (batch_size ,frames ,spatial_len ,self .hidden_size ) | |
| x =x .permute (0 ,2 ,1 ,3 ).reshape (batch_size *spatial_len ,frames ,self .hidden_size ) | |
| qkv =self .to_qkv (x ).reshape (batch_size *spatial_len ,frames ,3 ,self .num_heads ,self .head_dim ) | |
| q ,k ,v =qkv .unbind (dim =2 ) | |
| cos ,sin =self .rope_1d (x ,frames ) | |
| cos =cos .unsqueeze (0 ).unsqueeze (1 ) | |
| sin =sin .unsqueeze (0 ).unsqueeze (1 ) | |
| q =q .transpose (1 ,2 ) | |
| k =k .transpose (1 ,2 ) | |
| v =v .transpose (1 ,2 ) | |
| q =apply_rope (q ,cos ,sin ) | |
| k =apply_rope (k ,cos ,sin ) | |
| qk_scale =self .head_dim **-0.25 | |
| out =F .scaled_dot_product_attention ( | |
| q *qk_scale ,k *qk_scale ,v , | |
| is_causal =causal ,scale =1.0 , | |
| ) | |
| out =out .transpose (1 ,2 ).reshape (batch_size *spatial_len ,frames ,self .hidden_size ) | |
| out =out .view (batch_size ,spatial_len ,frames ,self .hidden_size ) | |
| out =out .permute (0 ,2 ,1 ,3 ).reshape (batch_size ,seq_len ,self .hidden_size ) | |
| out =self .to_out (out ) | |
| return out | |
| class FactorizedSpatioTemporalAttention (nn .Module ): | |
| """ | |
| Factorized Spatial-Temporal Attention (like CogVideo, Open-Sora, SVD). | |
| Instead of full 3D attention O((T*H*W)^2), uses: | |
| 1. Spatial attention per frame: O(T * (H*W)^2) | |
| 2. Temporal attention per position: O(H*W * T^2) | |
| Total: O(T*(H*W)^2 + H*W*T^2) << O((T*H*W)^2) | |
| For T=8, H=W=64: | |
    - Full 3D: 32768^2 ≈ 1.07B attention scores
    - Factorized: 8*4096^2 + 4096*64 ≈ 134.5M attention scores (~8x fewer)
| """ | |
| def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_frames :int =32 ,max_height :int =64 ,max_width :int =64 ): | |
| super ().__init__ () | |
| self .spatial_attn =SpatialAttention (hidden_size ,num_heads ,max_height ,max_width ) | |
| self .temporal_attn =TemporalAttention (hidden_size ,num_heads ,max_frames ) | |
| def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =True )->torch .Tensor : | |
| x =x +self .spatial_attn (x ,height ,width ,frames ) | |
| x =x +self .temporal_attn (x ,height ,width ,frames ,causal ) | |
| return x | |
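# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). The attention-score counts quoted in the docstring above,
# computed explicitly so the asymptotics are easy to audit:
def _example_factorized_attention_cost(frames: int = 8, height: int = 64, width: int = 64):
    full_3d = (frames * height * width) ** 2                                       # ~1.07e9 for the defaults
    factorized = frames * (height * width) ** 2 + (height * width) * frames ** 2   # ~1.345e8
    return full_3d, factorized, full_3d / factorized                               # ratio ~8.0x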
| class CrossAttention3D (nn .Module ): | |
| """Cross-attention for text-to-video conditioning.""" | |
| def __init__ (self ,query_dim :int ,context_dim :int =None ,heads :int =8 ): | |
| super ().__init__ () | |
| self .heads =heads | |
| context_dim =context_dim or query_dim | |
| self .head_dim =query_dim //heads | |
| self .scale =self .head_dim **-0.5 | |
| self .norm =nn .LayerNorm (query_dim ) | |
| self .to_q =nn .Linear (query_dim ,query_dim ,bias =False ) | |
| self .to_k =nn .Linear (context_dim ,query_dim ,bias =False ) | |
| self .to_v =nn .Linear (context_dim ,query_dim ,bias =False ) | |
| self .to_out =nn .Linear (query_dim ,query_dim ,bias =False ) | |
| def forward (self ,x :torch .Tensor ,context :torch .Tensor )->torch .Tensor : | |
| batch_size ,seq_len ,_ =x .shape | |
| ctx_len =context .shape [1 ] | |
| x =self .norm (x ) | |
| q =self .to_q (x ).reshape (batch_size ,seq_len ,self .heads ,self .head_dim ).transpose (1 ,2 ) | |
| k =self .to_k (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 ) | |
| v =self .to_v (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 ) | |
| qk_scale =self .head_dim **-0.25 | |
| out =F .scaled_dot_product_attention ( | |
| q *qk_scale ,k *qk_scale ,v , | |
| is_causal =False ,scale =1.0 , | |
| ) | |
| out =out .transpose (1 ,2 ).reshape (batch_size ,seq_len ,-1 ) | |
| out =self .to_out (out ) | |
| return out | |
| class Causal3DTransformerBlock (nn .Module ): | |
| """ | |
| 3D Causal Transformer Block with Factorized Spatial-Temporal Attention. | |
| Uses memory-efficient factorized attention instead of full 3D attention: | |
| - Spatial: Each frame attends within itself O(T * (H*W)^2) | |
| - Temporal: Each position attends across frames O(H*W * T^2) | |
| This reduces memory from O((T*H*W)^2) to O(T*(H*W)^2 + H*W*T^2) | |
| """ | |
| def __init__ (self ,hidden_size :int ,context_dim :int ,num_heads :int =8 ,num_experts :int =4 ,max_frames :int =32 ,max_height :int =64 ,max_width :int =64 ): | |
| super ().__init__ () | |
| self .self_attn =FactorizedSpatioTemporalAttention (hidden_size ,num_heads ,max_frames ,max_height ,max_width ) | |
| self .cross_attn =CrossAttention3D (hidden_size ,context_dim ,num_heads ) | |
| self .moe =TemporalMoELayer (hidden_size ,hidden_size *4 ,num_experts ) | |
| self .norm1 =nn .LayerNorm (hidden_size ) | |
| self .norm2 =nn .LayerNorm (hidden_size ) | |
| self .norm3 =nn .LayerNorm (hidden_size ) | |
    def forward(self, x: torch.Tensor, context: torch.Tensor, height: int, width: int, frames: int, temporal_context: Optional[torch.Tensor] = None) -> torch.Tensor:
        # FactorizedSpatioTemporalAttention normalizes internally and adds its
        # own residual connections, so it receives the raw stream; routing it
        # through norm1 would double-normalize and drop the residual path.
        # (norm1 stays in __init__ for checkpoint compatibility.)
        x = self.self_attn(x, height, width, frames, causal=True)
        x = x + self.cross_attn(self.norm2(x), context)
        x = x + self.moe(self.norm3(x), temporal_context)
        return x
class VideoFlowMatchingScheduler:
    """
    Flow Matching scheduler for video generation (5D video latents).
    Uses optimal transport paths for superior generation quality.
    Renamed from FlowMatchingScheduler: reusing that name shadowed the
    image-space scheduler defined above and silently broke image generation.
    """
| def __init__ (self ,num_steps :int =50 ,sigma_min :float =0.002 ): | |
| self .num_steps =num_steps | |
| self .sigma_min =sigma_min | |
| self .timesteps =torch .linspace (1 ,0 ,num_steps +1 ) | |
| def get_velocity (self ,x_t :torch .Tensor ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor : | |
| """Compute target velocity for flow matching.""" | |
| return x_0 -x_t | |
| def step (self ,model_output :torch .Tensor ,t :torch .Tensor ,t_prev :torch .Tensor ,x_t :torch .Tensor )->torch .Tensor : | |
| """Single step of flow matching ODE.""" | |
| dt =t -t_prev | |
| x_prev =x_t +model_output *dt .view (-1 ,1 ,1 ,1 ,1 ) | |
| return x_prev | |
| def add_noise (self ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor : | |
| """Add noise for training (linear interpolation).""" | |
| noise =torch .randn_like (x_0 ) | |
| t =t .to (x_0 .dtype ).view (-1 ,1 ,1 ,1 ,1 ) | |
| x_t =t *noise +(1 -t )*x_0 | |
| return x_t | |
| class VideoUNet3D (nn .Module ): | |
| """ | |
| 3D U-Net for video generation with Factorized Spatial-Temporal Attention. | |
| Uses memory-efficient factorized attention that processes spatial and temporal | |
| dimensions separately, reducing memory from O((T*H*W)^2) to O(T*(H*W)^2 + H*W*T^2). | |
| """ | |
| def __init__ ( | |
| self , | |
| in_channels :int =4 , | |
| out_channels :int =4 , | |
| hidden_size :int =512 , | |
| context_dim :int =1024 , | |
| num_layers :int =4 , | |
| num_heads :int =8 , | |
| num_experts :int =4 , | |
| num_frames :int =16 , | |
| max_height :int =64 , | |
| max_width :int =64 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_frames =num_frames | |
| self .time_embed =nn .Sequential ( | |
| nn .Linear (hidden_size ,hidden_size *4 ), | |
| nn .SiLU (), | |
| nn .Linear (hidden_size *4 ,hidden_size ), | |
| ) | |
| self .input_proj =nn .Conv3d (in_channels ,hidden_size ,kernel_size =3 ,padding =1 ) | |
| self .transformer_blocks =nn .ModuleList ([ | |
| Causal3DTransformerBlock (hidden_size ,context_dim ,num_heads ,num_experts ,num_frames ,max_height ,max_width ) | |
| for _ in range (num_layers ) | |
| ]) | |
| self .output_proj =nn .Sequential ( | |
| nn .GroupNorm (32 ,hidden_size ), | |
| nn .SiLU (), | |
| nn .Conv3d (hidden_size ,out_channels ,kernel_size =3 ,padding =1 ), | |
| ) | |
| nn .init .zeros_ (self .output_proj [-1 ].weight ) | |
| nn .init .zeros_ (self .output_proj [-1 ].bias ) | |
| self .gradient_checkpointing =False | |
| def enable_gradient_checkpointing (self ): | |
| """Enable gradient checkpointing for memory efficiency.""" | |
| self .gradient_checkpointing =True | |
| def forward (self ,x :torch .Tensor ,timesteps :torch .Tensor ,context :torch .Tensor ,first_frame_latent :Optional [torch .Tensor ]=None )->torch .Tensor : | |
| batch_size ,channels ,frames ,height ,width =x .shape | |
| half_dim =self .hidden_size //2 | |
| t_emb =math .log (10000 )/(half_dim -1 ) | |
| t_emb =torch .exp (torch .arange (half_dim ,device =x .device ,dtype =x .dtype )*-t_emb ) | |
| t_emb =timesteps [:,None ].to (x .dtype )*t_emb [None ,:] | |
| t_emb =torch .cat ([torch .sin (t_emb ),torch .cos (t_emb )],dim =-1 ) | |
| t_emb =self .time_embed (t_emb ) | |
| h =self .input_proj (x ) | |
| h =h .permute (0 ,2 ,3 ,4 ,1 ).reshape (batch_size ,frames *height *width ,self .hidden_size ) | |
| temporal_context =t_emb .unsqueeze (1 ).expand (-1 ,frames *height *width ,-1 ) | |
| for block in self .transformer_blocks : | |
| if self .gradient_checkpointing and self .training : | |
| h =torch .utils .checkpoint .checkpoint ( | |
| block ,h ,context ,height ,width ,frames ,temporal_context , | |
| use_reentrant =False | |
| ) | |
| else : | |
| h =block (h ,context ,height ,width ,frames ,temporal_context ) | |
| h =h .reshape (batch_size ,frames ,height ,width ,self .hidden_size ).permute (0 ,4 ,1 ,2 ,3 ) | |
| velocity =self .output_proj (h ) | |
| return velocity | |
| class VideoVAE3D (nn .Module ): | |
| """ | |
| 3D VAE for video encoding/decoding using VidTok architecture. | |
| This replaces the simple placeholder with proper temporal+spatial compression | |
| following Microsoft's VidTok architecture for high-quality video tokenization. | |
| Features: | |
| - Proper temporal compression (4x default) | |
    - Proper spatial compression (8x default)
| - AlphaBlender for temporal blending | |
| - Causal mode support for streaming | |
| - Both KL (continuous) and FSQ (discrete) tokenization | |
| Compression: [B, C, T, H, W] -> [B, latent_ch, T/4, H/8, W/8] | |
| """ | |
| def __init__ ( | |
| self , | |
| in_channels :int =3 , | |
| latent_channels :int =4 , | |
| base_channels :int =64 , | |
| temporal_compression :int =4 , | |
| spatial_compression :int =8 , | |
| causal :bool =True , | |
| use_fsq :bool =False , | |
| ): | |
| super ().__init__ () | |
| self .in_channels =in_channels | |
| self .latent_channels =latent_channels | |
| self .temporal_compression =temporal_compression | |
| self .spatial_compression =spatial_compression | |
| self .causal =causal | |
| self .use_fsq =use_fsq | |
| self .temporal_stages =int (math .log2 (temporal_compression )) | |
| self .spatial_stages =int (math .log2 (spatial_compression )) | |
| encoder_layers =[] | |
| ch_in =in_channels | |
| ch_out =base_channels | |
| encoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,padding =1 )) | |
| encoder_layers .append (nn .SiLU ()) | |
| for i in range (self .spatial_stages -self .temporal_stages ): | |
| ch_in =ch_out | |
| ch_out =min (ch_out *2 ,base_channels *8 ) | |
| encoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,stride =(1 ,2 ,2 ),padding =1 )) | |
| encoder_layers .append (nn .SiLU ()) | |
| for i in range (self .temporal_stages ): | |
| ch_in =ch_out | |
| ch_out =min (ch_out *2 ,base_channels *8 ) | |
| encoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,stride =(2 ,2 ,2 ),padding =1 )) | |
| encoder_layers .append (nn .SiLU ()) | |
| out_ch =latent_channels *2 if not use_fsq else latent_channels | |
| encoder_layers .append (nn .Conv3d (ch_out ,out_ch ,3 ,padding =1 )) | |
| self .encoder =nn .Sequential (*encoder_layers ) | |
| decoder_layers =[] | |
| ch_in =latent_channels | |
| ch_out =base_channels *(2 **min (self .spatial_stages ,3 )) | |
| decoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,padding =1 )) | |
| decoder_layers .append (nn .SiLU ()) | |
| for i in range (self .temporal_stages ): | |
| ch_in =ch_out | |
| ch_out =max (ch_out //2 ,base_channels ) | |
| decoder_layers .append (nn .Upsample (scale_factor =(2 ,2 ,2 ),mode ='trilinear',align_corners =False )) | |
| decoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,padding =1 )) | |
| decoder_layers .append (nn .SiLU ()) | |
| for i in range (self .spatial_stages -self .temporal_stages ): | |
| ch_in =ch_out | |
| ch_out =max (ch_out //2 ,base_channels ) | |
| decoder_layers .append (nn .Upsample (scale_factor =(1 ,2 ,2 ),mode ='trilinear',align_corners =False )) | |
| decoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,padding =1 )) | |
| decoder_layers .append (nn .SiLU ()) | |
| decoder_layers .append (nn .Conv3d (ch_out ,in_channels ,3 ,padding =1 )) | |
| self .decoder =nn .Sequential (*decoder_layers ) | |
| print (f" 🎬 VideoVAE3D (VidTok): {temporal_compression }x{spatial_compression }x{spatial_compression } compression") | |
| print (f" Temporal stages: {self .temporal_stages }, Spatial stages: {self .spatial_stages }") | |
| print (f" Mode: {'FSQ (discrete)'if use_fsq else 'KL (continuous)'}, Causal: {causal }") | |
| def encode (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: | |
| """ | |
| Encode video to latent space. | |
| Args: | |
| x: [B, C, T, H, W] video tensor, values in [0, 1] or [-1, 1] | |
| Returns: | |
| Tuple of (z, mean, logvar) where z is the sampled latent | |
| """ | |
| h =self .encoder (x ) | |
| if self .use_fsq : | |
| z =self ._fsq_quantize (h ) | |
| return z ,z ,torch .zeros_like (z ) | |
| else : | |
| mean ,logvar =h .chunk (2 ,dim =1 ) | |
| logvar =torch .clamp (logvar ,-30 ,20 ) | |
| std =torch .exp (0.5 *logvar ) | |
| z =mean +std *torch .randn_like (std ) | |
| return z ,mean ,logvar | |
    def _fsq_quantize(self, z: torch.Tensor, levels: int = 8) -> torch.Tensor:
        """Finite Scalar Quantization with a straight-through estimator, so
        gradients reach the encoder despite the non-differentiable round."""
        z = torch.tanh(z)
        z_q = torch.round((z + 1) * (levels - 1) / 2) * 2 / (levels - 1) - 1
        return z + (z_q - z).detach()
| def decode (self ,z :torch .Tensor )->torch .Tensor : | |
| """ | |
| Decode latent to video. | |
| Args: | |
| z: [B, latent_ch, t, h, w] latent tensor | |
| Returns: | |
| [B, C, T, H, W] reconstructed video | |
| """ | |
| return self .decoder (z ) | |
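# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). Compression shape check for the VAE above with the default
# 4x temporal / 8x spatial factors, plus a look at FSQ: with levels=8 every
# latent value is snapped to one of 8 evenly spaced points in [-1, 1].
def _example_video_vae_shapes():
    vae = VideoVAE3D(base_channels=16)
    video = torch.randn(1, 3, 16, 64, 64)   # [B, C, T, H, W]
    z, mean, logvar = vae.encode(video)     # z: [1, 4, 4, 8, 8]
    recon = vae.decode(z)                   # [1, 3, 16, 64, 64]
    levels = torch.unique(vae._fsq_quantize(torch.randn(1000)))
    return z.shape, recon.shape, levels.numel()  # at most 8 distinct levels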
| class MobileVideoDiffusion (nn .Module ): | |
| """ | |
| SOTA Video Diffusion with Flow Matching, Factorized Attention, Temporal MoE. | |
| Uses memory-efficient factorized spatial-temporal attention: | |
| - Full 3D attention: O((T*H*W)^2) = 1B+ attention scores (OOM!) | |
    - Factorized: O(T*(H*W)^2 + H*W*T^2) = ~134.5M scores (~8x less memory)
| Optimized for 2x T4 GPUs (15GB each) with FP16. | |
| """ | |
| def __init__ ( | |
| self , | |
| latent_channels :int =4 , | |
| base_channels :int =64 , | |
| context_dim :int =1024 , | |
| num_frames :int =16 , | |
| image_size :int =256 , | |
| num_inference_steps :int =50 , | |
| cfg_scale :float =7.5 , | |
| temporal_compression :int =4 , | |
| spatial_compression :int =8 , | |
| causal :bool =True , | |
| use_fsq :bool =False , | |
| ): | |
| super ().__init__ () | |
| self .latent_channels =latent_channels | |
| self .context_dim =context_dim | |
| self .num_frames =num_frames | |
| self .image_size =image_size | |
| self .temporal_compression =temporal_compression | |
| self .spatial_compression =spatial_compression | |
| self .latent_size =image_size //spatial_compression | |
| self .latent_frames =num_frames //temporal_compression | |
| self .num_inference_steps =num_inference_steps | |
| self .cfg_scale =cfg_scale | |
| self .vae =VideoVAE3D ( | |
| in_channels =3 , | |
| latent_channels =latent_channels , | |
| base_channels =base_channels , | |
| temporal_compression =temporal_compression , | |
| spatial_compression =spatial_compression , | |
| causal =causal , | |
| use_fsq =use_fsq , | |
| ) | |
| self .unet =VideoUNet3D ( | |
| in_channels =latent_channels , | |
| out_channels =latent_channels , | |
| hidden_size =base_channels *4 , | |
| context_dim =context_dim , | |
| num_layers =4 , | |
| num_heads =8 , | |
| num_experts =4 , | |
| num_frames =num_frames , | |
| max_height =self .latent_size , | |
| max_width =self .latent_size , | |
| ) | |
        self.scheduler = VideoFlowMatchingScheduler(num_inference_steps)
| def encode_video (self ,video :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: | |
| return self .vae .encode (video *2 -1 ) | |
| def decode_video (self ,z :torch .Tensor )->torch .Tensor : | |
| return self .vae .decode (z ) | |
| def encode_image (self ,image :torch .Tensor )->torch .Tensor : | |
| image_expanded =image .unsqueeze (2 ) | |
| z ,_ ,_ =self .vae .encode (image_expanded ) | |
| return z .squeeze (2 ) | |
| def training_step (self ,video :torch .Tensor ,context :torch .Tensor ,first_frame :Optional [torch .Tensor ]=None )->dict : | |
| device =video .device | |
| dtype =video .dtype | |
| batch_size =video .shape [0 ] | |
| z ,mean ,logvar =self .encode_video (video ) | |
| del video | |
| t =torch .rand (batch_size ,device =device ,dtype =dtype ) | |
| x_t =self .scheduler .add_noise (z ,t ) | |
| target_velocity =self .scheduler .get_velocity (x_t ,z ,t ) | |
| if self .training : | |
| drop_mask =torch .rand (batch_size ,device =device )<0.1 | |
| drop_mask_expanded =drop_mask .view (batch_size ,1 ,1 ).expand_as (context ) | |
| null_ctx =torch .zeros_like (context ) | |
| context =torch .where (drop_mask_expanded ,null_ctx ,context ) | |
| del drop_mask ,drop_mask_expanded ,null_ctx | |
| pred_velocity =self .unet (x_t ,(t *1000 ).to (dtype ),context ,None ) | |
| del x_t ,context | |
| flow_loss =F .mse_loss (pred_velocity ,target_velocity ) | |
| del pred_velocity ,target_velocity | |
| kl_loss =-0.5 *torch .mean (1 +logvar -mean .pow (2 )-logvar .exp ()) | |
| temporal_loss =torch .tensor (0.0 ,device =device ,dtype =dtype ) | |
| if z .shape [2 ]>1 : | |
| z_diff =z [:,:,1 :]-z [:,:,:-1 ] | |
| temporal_loss =torch .mean (z_diff **2 ) | |
| del z_diff | |
| del z ,mean ,logvar | |
| total_loss =flow_loss +0.0001 *kl_loss +0.01 *temporal_loss | |
| return { | |
| 'flow_loss':flow_loss , | |
| 'kl_loss':kl_loss , | |
| 'temporal_loss':temporal_loss , | |
| 'total_loss':total_loss , | |
| } | |
| def generate_t2v (self ,context :torch .Tensor ,num_frames :int =None ,guidance_scale :float =None ,num_steps :int =None )->torch .Tensor : | |
| device =context .device | |
| batch_size =context .shape [0 ] | |
| seq_len =context .shape [1 ] | |
| num_frames =num_frames or self .num_frames | |
| guidance_scale =guidance_scale or self .cfg_scale | |
| num_steps =num_steps or self .num_inference_steps | |
        # The VAE upsamples time by `temporal_compression` in decode(), so the
        # requested number of output frames corresponds to fewer latent frames.
        # dtype follows the context so FP16 inference stays in one precision.
        latent_frames = max(1, num_frames // self.temporal_compression)
        latents = torch.randn(
            batch_size, self.latent_channels, latent_frames,
            self.latent_size, self.latent_size, device=device, dtype=context.dtype,
        )
| timesteps =torch .linspace (1 ,0 ,num_steps +1 ,device =device ) | |
| if guidance_scale >1.0 : | |
| null_ctx =torch .zeros (batch_size ,seq_len ,self .context_dim ,device =device ,dtype =context .dtype ) | |
| context =torch .cat ([null_ctx ,context ]) | |
| for i in range (num_steps ): | |
| t =timesteps [i ] | |
| t_prev =timesteps [i +1 ] | |
| t_batch =t .expand (batch_size )*1000 | |
| if guidance_scale >1.0 : | |
| latent_input =torch .cat ([latents ,latents ]) | |
| t_input =torch .cat ([t_batch ,t_batch ]) | |
| velocity_pred =self .unet (latent_input ,t_input ,context ,None ) | |
| velocity_uncond ,velocity_cond =velocity_pred .chunk (2 ) | |
| velocity_pred =velocity_uncond +guidance_scale *(velocity_cond -velocity_uncond ) | |
| else : | |
| velocity_pred =self .unet (latents ,t_batch ,context ,None ) | |
| latents =self .scheduler .step (velocity_pred ,t ,t_prev ,latents ) | |
| video =self .decode_video (latents ) | |
| return torch .clamp ((video +1 )/2 ,0 ,1 ) | |
| def generate_i2v (self ,first_frame :torch .Tensor ,context :Optional [torch .Tensor ]=None ,num_frames :int =None ,guidance_scale :float =None ,num_steps :int =None )->torch .Tensor : | |
| device =first_frame .device | |
| batch_size =first_frame .shape [0 ] | |
| num_frames =num_frames or self .num_frames | |
| guidance_scale =guidance_scale or self .cfg_scale | |
| num_steps =num_steps or self .num_inference_steps | |
| first_frame_latent =self .encode_image (first_frame *2 -1 ) | |
        # As in generate_t2v, sample in latent time: decode() upsamples frames
        # by `temporal_compression`. dtype follows the encoded first frame so
        # FP16 inference stays in one precision.
        latent_frames = max(1, num_frames // self.temporal_compression)
        latents = torch.randn(
            batch_size, self.latent_channels, latent_frames,
            self.latent_size, self.latent_size,
            device=device, dtype=first_frame_latent.dtype,
        )
        latents[:, :, 0] = first_frame_latent
        if context is None:
            context = torch.zeros(batch_size, 77, self.context_dim, device=device, dtype=first_frame.dtype)
| seq_len =context .shape [1 ] | |
| timesteps =torch .linspace (1 ,0 ,num_steps +1 ,device =device ) | |
| if guidance_scale >1.0 : | |
| null_ctx =torch .zeros (batch_size ,seq_len ,self .context_dim ,device =device ,dtype =context .dtype ) | |
| context =torch .cat ([null_ctx ,context ]) | |
| for i in range (num_steps ): | |
| t =timesteps [i ] | |
| t_prev =timesteps [i +1 ] | |
| t_batch =t .expand (batch_size )*1000 | |
| if guidance_scale >1.0 : | |
| latent_input =torch .cat ([latents ,latents ]) | |
| t_input =torch .cat ([t_batch ,t_batch ]) | |
| velocity_pred =self .unet (latent_input ,t_input ,context ,None ) | |
| velocity_uncond ,velocity_cond =velocity_pred .chunk (2 ) | |
| velocity_pred =velocity_uncond +guidance_scale *(velocity_cond -velocity_uncond ) | |
| else : | |
| velocity_pred =self .unet (latents ,t_batch ,context ,None ) | |
| latents =self .scheduler .step (velocity_pred ,t ,t_prev ,latents ) | |
| latents [:,:,0 ]=first_frame_latent | |
| video =self .decode_video (latents ) | |
| return torch .clamp ((video +1 )/2 ,0 ,1 ) | |
# ==============================================================================
# MODELS.LLM.MOE_LLAMA
# ==============================================================================
| EPS =1e-5 | |
| class YaRNRotaryEmbedding (nn .Module ): | |
| """ | |
| YaRN (Yet another RoPE extensioN) with LongRoPE-style improvements. | |
| Supports up to 128K+ context with proper frequency scaling. | |
| """ | |
| def __init__ ( | |
| self , | |
| dim :int , | |
| max_position_embeddings :int =131072 , | |
| base :float =500000.0 , | |
| original_max_position_embeddings :int =8192 , | |
| beta_fast :float =32.0 , | |
| beta_slow :float =1.0 , | |
| mscale :float =1.0 , | |
| ): | |
| super ().__init__ () | |
| self .dim =dim | |
| self .max_position_embeddings =max_position_embeddings | |
| self .base =base | |
| self .original_max_position =original_max_position_embeddings | |
| self .beta_fast =beta_fast | |
| self .beta_slow =beta_slow | |
| self .mscale =mscale | |
| self .scaling_factor =max_position_embeddings /original_max_position_embeddings | |
| inv_freq =self ._compute_yarn_inv_freq () | |
| self .register_buffer ('inv_freq',inv_freq ,persistent =False ) | |
| def _compute_yarn_inv_freq (self )->torch .Tensor : | |
| """Compute YaRN-scaled inverse frequencies.""" | |
| pos_freqs =self .base **(torch .arange (0 ,self .dim ,2 ,dtype =torch .float32 )/self .dim ) | |
| inv_freq_extrapolation =1.0 /pos_freqs | |
| inv_freq_interpolation =1.0 /(self .scaling_factor *pos_freqs ) | |
| low =max (math .floor (self .dim *math .log (self .original_max_position /(self .beta_fast *2 *math .pi ))/ | |
| (2 *math .log (self .base ))),0 ) | |
| high =min (math .ceil (self .dim *math .log (self .original_max_position /(self .beta_slow *2 *math .pi ))/ | |
| (2 *math .log (self .base ))),self .dim -1 ) | |
| inv_freq =torch .zeros (self .dim //2 ,dtype =torch .float32 ) | |
| for i in range (self .dim //2 ): | |
| if i <low : | |
| inv_freq [i ]=inv_freq_interpolation [i ] | |
| elif i >high : | |
| inv_freq [i ]=inv_freq_extrapolation [i ] | |
| else : | |
| smooth =(i -low )/max (high -low ,1 ) | |
| inv_freq [i ]=(1 -smooth )*inv_freq_interpolation [i ]+smooth *inv_freq_extrapolation [i ] | |
| return inv_freq | |
| def _get_mscale (self ,scale :float )->float : | |
| """Get attention scaling factor for YaRN.""" | |
| if scale <=1 : | |
| return 1.0 | |
| return 0.1 *math .log (scale )+1.0 | |
| def forward (self ,x :torch .Tensor ,position_ids :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| device =x .device | |
| inv_freq =self .inv_freq .to (device ) | |
| inv_freq_expanded =inv_freq [None ,:,None ].float ().expand (position_ids .shape [0 ],-1 ,1 ) | |
| position_ids_expanded =position_ids [:,None ,:].float () | |
| freqs =(inv_freq_expanded @position_ids_expanded ).transpose (1 ,2 ) | |
| emb =torch .cat ((freqs ,freqs ),dim =-1 ) | |
| mscale =self ._get_mscale (self .scaling_factor )*self .mscale | |
| cos =emb .cos ().to (dtype =x .dtype )*mscale | |
| sin =emb .sin ().to (dtype =x .dtype )*mscale | |
| return cos ,sin | |
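# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). Producing the cos/sin tables for a batch of positions. The
# YaRN mscale factor is folded into cos/sin directly, so callers apply them
# with a plain rotate-half, exactly as in apply_rotary_pos_emb below.
def _example_yarn_tables():
    rope = YaRNRotaryEmbedding(dim=64)
    x = torch.randn(2, 16, 512)   # only device/dtype are read from x
    position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)
    cos, sin = rope(x, position_ids)   # each [2, 16, 64]
    return cos.shape, sin.shape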
| LlamaRotaryEmbedding =YaRNRotaryEmbedding | |
| def rotate_half (x :torch .Tensor )->torch .Tensor : | |
| x1 =x [...,:x .shape [-1 ]//2 ] | |
| x2 =x [...,x .shape [-1 ]//2 :] | |
| return torch .cat ((-x2 ,x1 ),dim =-1 ) | |
| def apply_rotary_pos_emb ( | |
| q :torch .Tensor , | |
| k :torch .Tensor , | |
| cos :torch .Tensor , | |
| sin :torch .Tensor , | |
| position_ids :Optional [torch .Tensor ]=None , | |
| unsqueeze_dim :int =1 , | |
| )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| cos =cos .unsqueeze (unsqueeze_dim ) | |
| sin =sin .unsqueeze (unsqueeze_dim ) | |
| q_embed =(q *cos )+(rotate_half (q )*sin ) | |
| k_embed =(k *cos )+(rotate_half (k )*sin ) | |
| return q_embed ,k_embed | |
| class KVCache : | |
| """Pre-allocated KV Cache — static buffer with index-based filling. | |
| Eliminates VRAM fragmentation from torch.cat during autoregressive generation. | |
| Buffer is allocated once at first use and reused via slice assignment. | |
| """ | |
| __slots__ =('key_cache','value_cache','seen_tokens','_max_len') | |
| def __init__ ( | |
| self , | |
| key_cache :torch .Tensor =None , | |
| value_cache :torch .Tensor =None , | |
| seen_tokens :int =0 , | |
| max_seq_len :int =131072 , | |
| ): | |
| self .key_cache =key_cache | |
| self .value_cache =value_cache | |
| self .seen_tokens =seen_tokens | |
| self ._max_len =max_seq_len | |
| def _allocate (self ,batch :int ,heads :int ,head_dim :int ,device :torch .device ,dtype :torch .dtype ): | |
| """Allocate static buffer on first use.""" | |
| self .key_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype ) | |
| self .value_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype ) | |
| def update ( | |
| self , | |
| key_states :torch .Tensor , | |
| value_states :torch .Tensor , | |
| chunk_size :Optional [int ]=None , | |
| )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| batch ,heads ,new_len ,head_dim =key_states .shape | |
| if self .key_cache is None : | |
| self ._allocate (batch ,heads ,head_dim ,key_states .device ,key_states .dtype ) | |
| self .seen_tokens =0 | |
| if chunk_size is not None and self .seen_tokens +new_len >chunk_size *2 : | |
| keep =chunk_size | |
| if self .seen_tokens >keep : | |
| self .key_cache [:,:,:keep ]=self .key_cache [:,:,self .seen_tokens -keep :self .seen_tokens ].clone () | |
| self .value_cache [:,:,:keep ]=self .value_cache [:,:,self .seen_tokens -keep :self .seen_tokens ].clone () | |
| self .seen_tokens =keep | |
| if self .seen_tokens +new_len >self .key_cache .shape [2 ]: | |
| new_max =max (self .key_cache .shape [2 ]*2 ,self .seen_tokens +new_len ) | |
| new_key =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype ) | |
| new_val =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype ) | |
| new_key [:,:,:self .seen_tokens ]=self .key_cache [:,:,:self .seen_tokens ] | |
| new_val [:,:,:self .seen_tokens ]=self .value_cache [:,:,:self .seen_tokens ] | |
| self .key_cache =new_key | |
| self .value_cache =new_val | |
| self .key_cache [:,:,self .seen_tokens :self .seen_tokens +new_len ]=key_states | |
| self .value_cache [:,:,self .seen_tokens :self .seen_tokens +new_len ]=value_states | |
| self .seen_tokens +=new_len | |
| return self .key_cache [:,:,:self .seen_tokens ],self .value_cache [:,:,:self .seen_tokens ] | |
| def reset (self ): | |
| """Reset cache position without deallocating the buffer.""" | |
| self .seen_tokens =0 | |
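# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). Decode-style usage of the cache above: a 5-token prefill
# followed by two single-token steps. The returned views always cover exactly
# the tokens seen so far, and the backing buffer is allocated only once.
def _example_kv_cache():
    cache = KVCache(max_seq_len=64)
    k = torch.randn(1, 4, 5, 16)   # [batch, heads, seq, head_dim] prefill
    v = torch.randn(1, 4, 5, 16)
    keys, values = cache.update(k, v)   # [1, 4, 5, 16]
    for _ in range(2):                  # autoregressive steps
        k1, v1 = torch.randn(1, 4, 1, 16), torch.randn(1, 4, 1, 16)
        keys, values = cache.update(k1, v1)
    return keys.shape, cache.seen_tokens   # [1, 4, 7, 16], 7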
| def ring_attention ( | |
| query :torch .Tensor , | |
| key :torch .Tensor , | |
| value :torch .Tensor , | |
| chunk_size :int =4096 , | |
| causal :bool =True , | |
| )->torch .Tensor : | |
| """ | |
| Ring Attention for distributed long-context processing. | |
| Processes sequence in chunks with online softmax accumulation. | |
| Args: | |
| query: [batch, heads, seq_len, head_dim] | |
| key: [batch, heads, kv_len, head_dim] | |
| value: [batch, heads, kv_len, head_dim] | |
| chunk_size: Size of each attention chunk | |
| causal: Whether to apply causal masking | |
| Returns: | |
| Output tensor [batch, heads, seq_len, head_dim] | |
| """ | |
| batch_size ,num_heads ,seq_len ,head_dim =query .shape | |
| kv_len =key .shape [2 ] | |
| if seq_len <=chunk_size and kv_len <=chunk_size : | |
| qk_scale =head_dim **-0.25 | |
| use_causal =causal and seq_len ==kv_len and seq_len >1 | |
| if use_causal : | |
| return F .scaled_dot_product_attention ( | |
| query *qk_scale ,key *qk_scale ,value , | |
| is_causal =True ,scale =1.0 , | |
| ) | |
| elif causal and kv_len >seq_len : | |
            q_pos = torch.arange(seq_len, device=query.device) + (kv_len - seq_len)
            k_pos = torch.arange(kv_len, device=query.device)
            # Additive mask: -inf wherever the key lies in the future of the
            # (offset) query position. The earlier zeros-initialized mask was
            # dead code; cast to the query dtype for FP16 SDPA.
            causal_mask = torch.where(
                k_pos.unsqueeze(0) > q_pos.unsqueeze(1), float('-inf'), 0.0
            ).to(query.dtype)
| return F .scaled_dot_product_attention ( | |
| query *qk_scale ,key *qk_scale ,value , | |
| attn_mask =causal_mask ,scale =1.0 , | |
| ) | |
| else : | |
| return F .scaled_dot_product_attention ( | |
| query *qk_scale ,key *qk_scale ,value , | |
| is_causal =False ,scale =1.0 , | |
| ) | |
| scale =head_dim **-0.5 | |
| output =torch .zeros_like (query ) | |
| max_logits =torch .full ((batch_size ,num_heads ,seq_len ,1 ),float ('-inf'),device =query .device ,dtype =query .dtype ) | |
| sum_exp =torch .zeros ((batch_size ,num_heads ,seq_len ,1 ),device =query .device ,dtype =query .dtype ) | |
| if causal : | |
| q_positions =torch .arange (seq_len ,device =query .device ) | |
| if kv_len >seq_len : | |
| q_positions =q_positions +(kv_len -seq_len ) | |
| num_kv_chunks =(kv_len +chunk_size -1 )//chunk_size | |
| for kv_idx in range (num_kv_chunks ): | |
| kv_start =kv_idx *chunk_size | |
| kv_end =min ((kv_idx +1 )*chunk_size ,kv_len ) | |
| key_chunk =key [:,:,kv_start :kv_end ,:] | |
| value_chunk =value [:,:,kv_start :kv_end ,:] | |
| attn_chunk =torch .matmul (query ,key_chunk .transpose (-1 ,-2 ))*scale | |
| if causal : | |
| k_positions =torch .arange (kv_start ,kv_end ,device =query .device ) | |
| causal_mask =k_positions .unsqueeze (0 )>q_positions .unsqueeze (1 ) | |
| attn_chunk =attn_chunk .masked_fill (causal_mask .unsqueeze (0 ).unsqueeze (0 ),float ('-inf')) | |
| chunk_max =attn_chunk .max (dim =-1 ,keepdim =True )[0 ] | |
| new_max =torch .maximum (max_logits ,chunk_max ) | |
| exp_weights =torch .exp (attn_chunk -new_max ) | |
| exp_sum_chunk =exp_weights .sum (dim =-1 ,keepdim =True ) | |
| correction =torch .exp (max_logits -new_max ) | |
| output =output *correction +torch .matmul (exp_weights ,value_chunk ) | |
| sum_exp =sum_exp *correction +exp_sum_chunk | |
| max_logits =new_max | |
| output =output /(sum_exp +EPS ) | |
| return output | |
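# --- Editor's usage sketch (hypothetical `_example_*` helper, not part of the
# original API). Sanity check that the chunked online-softmax path above
# agrees with a dense reference; a chunk_size smaller than the sequence forces
# the accumulation loop, and fp32 keeps the comparison tight.
def _example_ring_attention_check():
    q = torch.randn(1, 2, 16, 8)
    k = torch.randn(1, 2, 16, 8)
    v = torch.randn(1, 2, 16, 8)
    chunked = ring_attention(q, k, v, chunk_size=4, causal=True)
    dense = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return torch.allclose(chunked, dense, atol=1e-4)  # True up to accumulation error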
| class MultiHeadLatentAttention (nn .Module ): | |
| """ | |
| Multi-Head Latent Attention (MLA) from DeepSeek-V2. | |
| Compresses KV cache using low-rank projections for memory efficiency. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int , | |
| num_heads :int , | |
| num_kv_heads :int =None , | |
| head_dim :int =None , | |
| kv_lora_rank :int =512 , | |
| q_lora_rank :int =0 , | |
| rope_theta :float =500000.0 , | |
| max_position_embeddings :int =131072 , | |
| use_ring_attention :bool =True , | |
| ring_chunk_size :int =4096 , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_heads =num_heads | |
| self .num_kv_heads =num_kv_heads or num_heads | |
| self .head_dim =head_dim or hidden_size //num_heads | |
| self .kv_lora_rank =kv_lora_rank | |
| self .q_lora_rank =q_lora_rank | |
| self .use_ring_attention =use_ring_attention | |
| self .ring_chunk_size =ring_chunk_size | |
| self .num_key_value_groups =self .num_heads //self .num_kv_heads | |
| self .scale =self .head_dim **-0.5 | |
| if q_lora_rank >0 : | |
| self .q_a_proj =nn .Linear (hidden_size ,q_lora_rank ,bias =False ) | |
| self .q_b_proj =nn .Linear (q_lora_rank ,num_heads *self .head_dim ,bias =False ) | |
| self .q_a_layernorm =LlamaRMSNorm (q_lora_rank ) | |
| else : | |
| self .q_proj =nn .Linear (hidden_size ,num_heads *self .head_dim ,bias =False ) | |
| self .kv_a_proj =nn .Linear (hidden_size ,kv_lora_rank +self .head_dim ,bias =False ) | |
| self .kv_b_proj =nn .Linear (kv_lora_rank ,self .num_kv_heads *self .head_dim *2 ,bias =False ) | |
| self .kv_a_layernorm =LlamaRMSNorm (kv_lora_rank ) | |
| self .o_proj =nn .Linear (num_heads *self .head_dim ,hidden_size ,bias =False ) | |
| self .rotary_emb =YaRNRotaryEmbedding ( | |
| dim =self .head_dim , | |
| max_position_embeddings =max_position_embeddings , | |
| base =rope_theta , | |
| ) | |
| self ._init_weights () | |
| def _init_weights (self ): | |
| std =0.02 | |
| for name ,module in self .named_modules (): | |
| if isinstance (module ,nn .Linear ): | |
| nn .init .normal_ (module .weight ,mean =0.0 ,std =std ) | |
| def forward ( | |
| self , | |
| hidden_states :torch .Tensor , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| position_ids :Optional [torch .Tensor ]=None , | |
| past_key_value :Optional [KVCache ]=None , | |
| output_attentions :bool =False , | |
| use_cache :bool =False , | |
| )->Tuple [torch .Tensor ,Optional [torch .Tensor ],Optional [KVCache ]]: | |
| batch_size ,seq_len ,_ =hidden_states .shape | |
| if self .q_lora_rank >0 : | |
| q_compressed =self .q_a_layernorm (self .q_a_proj (hidden_states )) | |
| query_states =self .q_b_proj (q_compressed ) | |
| else : | |
| query_states =self .q_proj (hidden_states ) | |
| kv_compressed =self .kv_a_proj (hidden_states ) | |
| # NOTE: k_pe is DeepSeek-V2's decoupled-RoPE slice; it is split off here but | |
| # unused in this simplified MLA (RoPE is applied to the full head_dim below). | |
| kv_latent ,k_pe =kv_compressed .split ([self .kv_lora_rank ,self .head_dim ],dim =-1 ) | |
| kv_latent =self .kv_a_layernorm (kv_latent ) | |
| kv_states =self .kv_b_proj (kv_latent ) | |
| query_states =query_states .view (batch_size ,seq_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 ) | |
| key_states ,value_states =kv_states .split (self .num_kv_heads *self .head_dim ,dim =-1 ) | |
| key_states =key_states .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 ) | |
| value_states =value_states .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 ) | |
| if position_ids is None : | |
| position_ids =torch .arange (seq_len ,device =hidden_states .device ).unsqueeze (0 ).expand (batch_size ,-1 ) | |
| if past_key_value is not None and past_key_value .seen_tokens >0 : | |
| position_ids =position_ids +past_key_value .seen_tokens | |
| cos ,sin =self .rotary_emb (hidden_states ,position_ids ) | |
| query_states ,key_states =apply_rotary_pos_emb (query_states ,key_states ,cos ,sin ) | |
| if past_key_value is not None : | |
| key_states ,value_states =past_key_value .update ( | |
| key_states ,value_states , | |
| self .ring_chunk_size if self .use_ring_attention else None | |
| ) | |
| if self .use_ring_attention : | |
| if self .num_key_value_groups >1 : | |
| key_expanded =key_states .repeat_interleave (self .num_key_value_groups ,dim =1 ) | |
| value_expanded =value_states .repeat_interleave (self .num_key_value_groups ,dim =1 ) | |
| else : | |
| key_expanded =key_states | |
| value_expanded =value_states | |
| attn_output =ring_attention ( | |
| query_states ,key_expanded ,value_expanded , | |
| chunk_size =self .ring_chunk_size , | |
| causal =True , | |
| ) | |
| else : | |
| qk_scale =self .head_dim **-0.25 | |
| kv_len =key_states .shape [2 ] | |
| use_causal =(attention_mask is None and seq_len >1 and seq_len ==kv_len ) | |
| attn_output =F .scaled_dot_product_attention ( | |
| query_states *qk_scale , | |
| key_states *qk_scale , | |
| value_states , | |
| attn_mask =attention_mask , | |
| is_causal =use_causal , | |
| scale =1.0 , | |
| enable_gqa =(self .num_key_value_groups >1 ), | |
| ) | |
| attn_output =attn_output .transpose (1 ,2 ).contiguous ().view (batch_size ,seq_len ,-1 ) | |
| attn_output =self .o_proj (attn_output ) | |
| return attn_output ,None ,past_key_value if use_cache else None | |
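| # --- Usage sketch (illustrative sizes; assumes the classes above are loaded) --- | |
| #   mla = MultiHeadLatentAttention(hidden_size=1024, num_heads=16, num_kv_heads=4, | |
| #                                  kv_lora_rank=256, use_ring_attention=False) | |
| #   out, _, _ = mla(torch.randn(2, 32, 1024))   # out: [2, 32, 1024] | |
| # The KV path is hidden -> kv_a_proj (rank 256 + head_dim) -> kv_b_proj, so the | |
| # latent rank, not hidden_size, bounds the K/V projection width. | |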
| class AuxLosslessMoERouter (nn .Module ): | |
| """ | |
| Aux-Lossless MoE Router with Shared Expert Isolation. | |
| Eliminates auxiliary loss while maintaining load balance through architecture. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int , | |
| num_experts :int , | |
| top_k :int =2 , | |
| norm_topk_prob :bool =True , | |
| ): | |
| super ().__init__ () | |
| self .num_experts =num_experts | |
| self .top_k =top_k | |
| self .norm_topk_prob =norm_topk_prob | |
| self .input_norm =LlamaRMSNorm (hidden_size ) | |
| self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) | |
| nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) | |
| self .expert_bias =nn .Parameter (torch .zeros (num_experts )) | |
| # Deep experts gate (4 deep experts) | |
| self .num_deep_experts = 4 | |
| self .deep_gate = nn .Linear (hidden_size , self .num_deep_experts , bias =False ) | |
| nn .init .normal_ (self .deep_gate .weight , mean =0.0 , std =0.01 ) | |
| self .deep_expert_bias = nn .Parameter (torch .zeros (self .num_deep_experts )) | |
| def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: | |
| batch_size ,seq_len ,hidden_dim =hidden_states .shape | |
| hidden_flat =hidden_states .view (-1 ,hidden_dim ) | |
| hidden_norm =self .input_norm (hidden_flat ) | |
| # Standard experts | |
| router_logits_std =self .gate (hidden_norm ) | |
| biased_logits_std =router_logits_std +self .expert_bias | |
| # Deep experts | |
| router_logits_deep = self .deep_gate (hidden_norm ) | |
| biased_logits_deep = router_logits_deep + self .deep_expert_bias | |
| # Concatenate: [batch*seq, num_experts + num_deep_experts] | |
| router_logits = torch .cat ([biased_logits_std , biased_logits_deep ], dim =-1 ) | |
| router_probs =F .softmax (router_logits ,dim =-1 ,dtype =hidden_states .dtype ) | |
| top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 ) | |
| if self .norm_topk_prob : | |
| top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS ) | |
| return top_k_probs ,top_k_indices ,router_logits | |
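| # --- Routing shapes (illustrative) --- | |
| # For hidden_states [B, S, H], the router flattens to [B*S, H] and scores the | |
| # standard and deep experts jointly, e.g.: | |
| #   router = AuxLosslessMoERouter(hidden_size=512, num_experts=8, top_k=2) | |
| #   probs, idx, logits = router(torch.randn(2, 16, 512)) | |
| #   probs: [32, 2]; idx: [32, 2] with values in [0, 12); logits: [32, 12] | |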
| class MoEExpert (nn .Module ): | |
| """Single MoE Expert with SwiGLU activation.""" | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ): | |
| super ().__init__ () | |
| self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) | |
| self .act_fn =nn .SiLU () | |
| self ._init_weights () | |
| def _init_weights (self ): | |
| std =0.02 | |
| nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x )) | |
| class DeepMoEExpert (nn .Module ): | |
| """Deep MoE Expert with multiple sequential SwiGLU transformations.""" | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ,depth :int =2 ): | |
| super ().__init__ () | |
| self .depth = depth | |
| self .gate_projs = nn .ModuleList ([nn .Linear (hidden_size if i == 0 else intermediate_size , intermediate_size , bias =False ) for i in range (depth )]) | |
| self .up_projs = nn .ModuleList ([nn .Linear (hidden_size if i == 0 else intermediate_size , intermediate_size , bias =False ) for i in range (depth )]) | |
| self .down_projs = nn .ModuleList ([nn .Linear (intermediate_size , intermediate_size if i < depth - 1 else hidden_size , bias =False ) for i in range (depth )]) | |
| self .act_fn = nn .SiLU () | |
| self ._init_weights () | |
| def _init_weights (self ): | |
| std =0.02 | |
| for g , u , d in zip (self .gate_projs , self .up_projs , self .down_projs ): | |
| nn .init .normal_ (g .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (u .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (d .weight ,mean =0.0 ,std =std *0.5 ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| for i in range (self .depth ): | |
| # Sequential application, no internal residuals: | |
| # input -> SwiGLU block 1 -> ... -> SwiGLU block depth -> output | |
| gate = self .act_fn (self .gate_projs [i ](x )) | |
| up = self .up_projs [i ](x ) | |
| x = self .down_projs [i ](gate * up ) | |
| return x | |
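| # Parameter count grows roughly linearly with depth, which is what the depth | |
| # penalty in _compute_aux_loss below prices in. Illustrative check: | |
| #   for d in range(2, 6): | |
| #       e = DeepMoEExpert(512, 1024, depth=d) | |
| #       print(d, sum(p.numel() for p in e.parameters())) | |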
| class IsolatedSharedExpert (nn .Module ): | |
| """ | |
| Isolated Shared Expert that always processes all tokens. | |
| Separate from routed experts to prevent competition. | |
| """ | |
| def __init__ (self ,hidden_size :int ,intermediate_size :int ): | |
| super ().__init__ () | |
| self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) | |
| self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) | |
| self .act_fn =nn .SiLU () | |
| self ._init_weights () | |
| def _init_weights (self ): | |
| std =0.02 | |
| nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std ) | |
| nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 ) | |
| def forward (self ,x :torch .Tensor )->torch .Tensor : | |
| return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x )) | |
| class AuxLosslessMoELayer (nn .Module ): | |
| """ | |
| Aux-Lossless MoE Layer with Isolated Shared Expert. | |
| No auxiliary loss needed - load balance maintained through isolation. | |
| """ | |
| def __init__ ( | |
| self , | |
| hidden_size :int , | |
| intermediate_size :int , | |
| num_experts :int =8 , | |
| num_experts_per_tok :int =2 , | |
| shared_expert_intermediate_size :int =None , | |
| ): | |
| super ().__init__ () | |
| self .hidden_size =hidden_size | |
| self .num_experts =num_experts | |
| self .num_experts_per_tok =num_experts_per_tok | |
| self .router =AuxLosslessMoERouter (hidden_size ,num_experts ,num_experts_per_tok ) | |
| self .experts =nn .ModuleList ([ | |
| MoEExpert (hidden_size ,intermediate_size ) | |
| for _ in range (num_experts ) | |
| ]) | |
| # Deep Experts: Depths 2, 3, 4, 5 | |
| self .num_deep_experts = 4 | |
| self .deep_experts = nn .ModuleList ([ | |
| DeepMoEExpert (hidden_size , intermediate_size , depth =d ) | |
| for d in range (2 , 6 ) | |
| ]) | |
| shared_size =shared_expert_intermediate_size or intermediate_size | |
| self .shared_expert =IsolatedSharedExpert (hidden_size ,shared_size ) | |
| def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: | |
| batch_size ,seq_len ,hidden_size =hidden_states .shape | |
| original_dtype =hidden_states .dtype | |
| hidden_flat =hidden_states .view (-1 ,hidden_size ) | |
| num_tokens =hidden_flat .shape [0 ] | |
| top_k_probs ,top_k_indices ,router_logits =self .router (hidden_states ) | |
| if hasattr (self ,'_utilization_tracker'): | |
| self ._utilization_tracker .record (top_k_indices ) | |
| final_output =torch .zeros_like (hidden_flat ) | |
| total_experts = self .num_experts + self .num_deep_experts | |
| for expert_idx in range (total_experts ): | |
| # Determine which expert list to use | |
| if expert_idx < self .num_experts : | |
| expert =self .experts [expert_idx ] | |
| else : | |
| expert =self .deep_experts [expert_idx - self .num_experts ] | |
| for k in range (self .num_experts_per_tok ): | |
| mask =(top_k_indices [:,k ]==expert_idx ) | |
| if mask .any (): | |
| expert_input =hidden_flat [mask ] | |
| expert_output =expert (expert_input ) | |
| weight =top_k_probs [mask ,k :k +1 ] | |
| weighted_output =(weight *expert_output ).to (original_dtype ) | |
| final_output [mask ]=final_output [mask ]+weighted_output | |
| shared_output =self .shared_expert (hidden_flat ) | |
| final_output =final_output +shared_output .to (original_dtype ) | |
| final_output =final_output .view (batch_size ,seq_len ,hidden_size ) | |
| aux_loss =self ._compute_aux_loss (router_logits ,top_k_indices ,num_tokens ) | |
| return final_output ,aux_loss | |
| def _compute_aux_loss ( | |
| self , | |
| router_logits :torch .Tensor , | |
| top_k_indices :torch .Tensor , | |
| num_tokens :int , | |
| )->torch .Tensor : | |
| """ | |
| Aux-lossless auxiliary loss. | |
| Uses z-loss to keep router logits from growing unboundedly (FP16 stability), | |
| plus a soft utilization penalty that activates only when experts go completely | |
| cold. The expert_bias parameter handles routine load balancing. | |
| """ | |
| z_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.0001 | |
| # Add penalty for choosing deep experts | |
| # Depths are 2, 3, 4, 5 for indices (num_experts) to (num_experts + 3) | |
| # Cost is roughly proportional to depth | |
| deep_penalty = torch .tensor (0.0 , device =router_logits .device , dtype =router_logits .dtype ) | |
| # Calculate how often each deep expert was selected | |
| # top_k_indices shape: [batch*seq, top_k] | |
| for i in range (self .num_deep_experts ): | |
| expert_idx = self .num_experts + i | |
| depth = i + 2 # depths 2, 3, 4, 5 | |
| # Count how many times this deep expert was chosen in top-k | |
| selection_count = (top_k_indices == expert_idx ).sum () | |
| # Simple penalty: deeper experts cost more | |
| # Multiplied by a small scalar to act as a soft deterrent | |
| # The model must truly need the depth to offset this loss increase | |
| deep_penalty += selection_count .float () * depth * 0.00005 | |
| return z_loss + deep_penalty | |
| expert_mask =F .one_hot (top_k_indices ,self .num_experts ).float () | |
| tokens_per_expert =expert_mask .sum (dim =(0 ,1 )) | |
| fraction_used =(tokens_per_expert >0 ).float ().mean () | |
| utilization_loss =(1.0 -fraction_used )*0.01 | |
| return z_loss +utilization_loss | |
| MoELayer =AuxLosslessMoELayer | |
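| # --- Layer usage sketch (illustrative sizes) --- | |
| #   moe = AuxLosslessMoELayer(hidden_size=512, intermediate_size=1024, | |
| #                             num_experts=8, num_experts_per_tok=2) | |
| #   y, aux = moe(torch.randn(2, 16, 512))   # y: [2, 16, 512], aux: scalar tensor | |
| # Every token always flows through the shared expert; the router only picks | |
| # the top-k among the 8 standard + 4 deep routed experts. | |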
| class MoELlamaDecoderLayer (nn .Module ): | |
| """Decoder layer with MLA and Aux-Lossless MoE.""" | |
| def __init__ (self ,config ,layer_idx :int ,moe_config :dict =None ): | |
| super ().__init__ () | |
| self .hidden_size =config .hidden_size | |
| self .layer_idx =layer_idx | |
| use_ring =getattr (config ,'use_ring_attention',True ) | |
| ring_chunk =getattr (config ,'ring_attention_chunk_size',4096 ) | |
| num_kv_heads =getattr (config ,'num_key_value_heads',config .num_attention_heads //4 ) | |
| self .self_attn =MultiHeadLatentAttention ( | |
| hidden_size =config .hidden_size , | |
| num_heads =config .num_attention_heads , | |
| num_kv_heads =num_kv_heads , | |
| rope_theta =getattr (config ,'rope_theta',500000.0 ), | |
| max_position_embeddings =config .max_position_embeddings , | |
| use_ring_attention =use_ring , | |
| ring_chunk_size =ring_chunk , | |
| ) | |
| self .input_layernorm =LlamaRMSNorm (config .hidden_size ,eps =config .rms_norm_eps ) | |
| self .post_attention_layernorm =LlamaRMSNorm (config .hidden_size ,eps =config .rms_norm_eps ) | |
| self .use_moe =moe_config and moe_config .get ('use_moe',False ) | |
| moe_freq =moe_config .get ('moe_layer_freq',2 )if moe_config else 2 | |
| if self .use_moe and layer_idx %moe_freq ==(moe_freq -1 ): | |
| self .mlp =AuxLosslessMoELayer ( | |
| hidden_size =config .hidden_size , | |
| intermediate_size =moe_config .get ('intermediate_size',config .intermediate_size ), | |
| num_experts =moe_config .get ('num_experts',8 ), | |
| num_experts_per_tok =moe_config .get ('num_experts_per_tok',2 ), | |
| ) | |
| self .is_moe_layer =True | |
| else : | |
| self .mlp =MoEExpert (config .hidden_size ,config .intermediate_size ) | |
| self .is_moe_layer =False | |
| def forward ( | |
| self , | |
| hidden_states :torch .Tensor , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| position_ids :Optional [torch .Tensor ]=None , | |
| past_key_value :Optional [KVCache ]=None , | |
| output_attentions :bool =False , | |
| use_cache :bool =False , | |
| )->Tuple [torch .Tensor ,Optional [torch .Tensor ],Optional [KVCache ],Optional [torch .Tensor ]]: | |
| residual =hidden_states | |
| hidden_states =self .input_layernorm (hidden_states ) | |
| hidden_states ,_ ,present_key_value =self .self_attn ( | |
| hidden_states =hidden_states , | |
| attention_mask =attention_mask , | |
| position_ids =position_ids , | |
| past_key_value =past_key_value , | |
| output_attentions =output_attentions , | |
| use_cache =use_cache , | |
| ) | |
| hidden_states =residual +hidden_states | |
| residual =hidden_states | |
| hidden_states =self .post_attention_layernorm (hidden_states ) | |
| aux_loss =None | |
| if self .is_moe_layer : | |
| hidden_states ,aux_loss =self .mlp (hidden_states ) | |
| else : | |
| hidden_states =self .mlp (hidden_states ) | |
| hidden_states =residual +hidden_states | |
| return hidden_states ,None ,present_key_value ,aux_loss | |
| @dataclass | |
| class MoELlamaModelOutput : | |
| last_hidden_state :torch .Tensor | |
| past_key_values :Optional [List [KVCache ]]=None | |
| hidden_states :Optional [Tuple [torch .Tensor ]]=None | |
| attentions :Optional [Tuple [torch .Tensor ]]=None | |
| aux_loss :Optional [torch .Tensor ]=None | |
| class MoELlamaModel (nn .Module ): | |
| """MoE LLaMA Model with MLA and Ring Attention.""" | |
| def __init__ (self ,config ,moe_config :dict =None ): | |
| super ().__init__ () | |
| self .config =config | |
| self .moe_config =moe_config | |
| self .gradient_checkpointing =False | |
| self .embed_tokens =nn .Embedding (config .vocab_size ,config .hidden_size ) | |
| self .layers =nn .ModuleList ([ | |
| MoELlamaDecoderLayer (config ,layer_idx ,moe_config ) | |
| for layer_idx in range (config .num_hidden_layers ) | |
| ]) | |
| self .norm =LlamaRMSNorm (config .hidden_size ,eps =config .rms_norm_eps ) | |
| self .num_moe_layers =sum (1 for layer in self .layers if layer .is_moe_layer ) | |
| # ── Coconut: Continuous Thought components ── | |
| # Learned gate controls how much recurrent thought vs original input | |
| # to retain at each thinking step. Sigmoid output in [0,1]. | |
| self .thought_gate = nn .Linear (config .hidden_size , 1 , bias =True ) | |
| nn .init .constant_ (self .thought_gate .bias , -2.0 ) # Initialize gate biased toward original (sigmoid(-2)≈0.12) | |
| self .thought_layernorm = LlamaRMSNorm (config .hidden_size , eps =config .rms_norm_eps ) | |
| # Halt head: dynamically decides when to stop thinking | |
| self .thought_halt_head = nn .Linear (config .hidden_size , 1 , bias =True ) | |
| nn .init .constant_ (self .thought_halt_head .bias , -2.0 ) # Biased toward continuing to think initially | |
| # Fast ponder block for latent reasoning: a deep SwiGLU MLP applied per position, | |
| # bypassing the O(N^2) self-attention stack (roughly 10x cheaper per thought step) | |
| self .fast_ponder_block = DeepMoEExpert (config .hidden_size , config .intermediate_size , depth =3 ) | |
| self ._init_weights () | |
| def _init_weights (self ): | |
| nn .init .normal_ (self .embed_tokens .weight ,mean =0.0 ,std =0.02 ) | |
| def gradient_checkpointing_enable (self ): | |
| """Enable gradient checkpointing for memory efficiency.""" | |
| self .gradient_checkpointing =True | |
| def gradient_checkpointing_disable (self ): | |
| """Disable gradient checkpointing.""" | |
| self .gradient_checkpointing =False | |
| def forward ( | |
| self , | |
| input_ids :Optional [torch .Tensor ]=None , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| position_ids :Optional [torch .Tensor ]=None , | |
| inputs_embeds :Optional [torch .Tensor ]=None , | |
| past_key_values :Optional [List [KVCache ]]=None , | |
| use_cache :bool =False , | |
| output_attentions :bool =False , | |
| output_hidden_states :bool =False , | |
| return_dict :bool =True , | |
| cache_position :Optional [torch .Tensor ]=None , | |
| thinking_depth :int =0 , | |
| )->Union [Tuple ,MoELlamaModelOutput ]: | |
| if inputs_embeds is None : | |
| inputs_embeds =self .embed_tokens (input_ids ) | |
| hidden_states =inputs_embeds | |
| batch_size ,seq_len =hidden_states .shape [:2 ] | |
| if position_ids is None : | |
| position_ids =torch .arange (seq_len ,device =hidden_states .device ).unsqueeze (0 ).expand (batch_size ,-1 ) | |
| if past_key_values is None : | |
| # Entries stay None here, so no KV history accumulates unless the caller | |
| # supplies initialized KVCache objects (one per layer) in past_key_values. | |
| past_key_values =[None ]*len (self .layers ) | |
| all_hidden_states =()if output_hidden_states else None | |
| all_attentions =()if output_attentions else None | |
| next_cache =[]if use_cache else None | |
| total_aux_loss =torch .tensor (0.0 ,device =hidden_states .device ,dtype =hidden_states .dtype ) | |
| for idx ,layer in enumerate (self .layers ): | |
| if output_hidden_states : | |
| all_hidden_states =all_hidden_states +(hidden_states ,) | |
| if self .gradient_checkpointing and self .training and not use_cache : | |
| def create_custom_forward (module ): | |
| def custom_forward (*inputs ): | |
| return module (*inputs ) | |
| return custom_forward | |
| layer_outputs =torch .utils .checkpoint .checkpoint ( | |
| create_custom_forward (layer ), | |
| hidden_states , | |
| attention_mask , | |
| position_ids , | |
| past_key_values [idx ], | |
| output_attentions , | |
| use_cache , | |
| use_reentrant =False , | |
| ) | |
| hidden_states ,attn_weights ,present_key_value ,aux_loss =layer_outputs | |
| else : | |
| hidden_states ,attn_weights ,present_key_value ,aux_loss =layer ( | |
| hidden_states =hidden_states , | |
| attention_mask =attention_mask , | |
| position_ids =position_ids , | |
| past_key_value =past_key_values [idx ], | |
| output_attentions =output_attentions , | |
| use_cache =use_cache , | |
| ) | |
| if use_cache : | |
| next_cache .append (present_key_value ) | |
| if aux_loss is not None : | |
| total_aux_loss =total_aux_loss +aux_loss | |
| if output_attentions and attn_weights is not None : | |
| all_attentions =all_attentions +(attn_weights ,) | |
| # ── Coconut: Continuous Thought Loop ── | |
| # After the normal pass, loop hidden states back through the | |
| # transformer layers for extra computation in latent space. | |
| # No tokens are decoded — pure continuous reasoning. | |
| if thinking_depth > 0 : | |
| original_hidden = hidden_states .clone () | |
| for thought_step in range (thinking_depth ): | |
| # Check if we should halt thinking (only during inference or if forced) | |
| # We evaluate the halt head on the *current* hidden state of the last token | |
| halt_logits = self .thought_halt_head (hidden_states [:, -1:, :]) | |
| halt_prob = torch .sigmoid (halt_logits ) | |
| # If during generation we decide to stop, break early | |
| if not self .training and (halt_prob > 0.5 ).all (): | |
| break | |
| # Normalize before processing | |
| hidden_states = self .thought_layernorm (hidden_states ) | |
| # Run purely through the attention-free fast ponder block | |
| # This achieves ~10x speedup by completely bypassing the O(N^2) self-attention stack | |
| hidden_states = self .fast_ponder_block (hidden_states ) | |
| # Gated residual: blend thought with original | |
| # gate ∈ [0,1], initialized small so early training | |
| # stays close to original behavior | |
| gate = torch .sigmoid (self .thought_gate (hidden_states )) | |
| hidden_states = gate * hidden_states + (1.0 - gate ) * original_hidden | |
| hidden_states =self .norm (hidden_states ) | |
| if output_hidden_states : | |
| all_hidden_states =all_hidden_states +(hidden_states ,) | |
| return MoELlamaModelOutput ( | |
| last_hidden_state =hidden_states , | |
| past_key_values =next_cache if use_cache else None , | |
| hidden_states =all_hidden_states , | |
| attentions =all_attentions , | |
| aux_loss =total_aux_loss , | |
| ) | |
| @dataclass | |
| class CausalLMOutput : | |
| loss :Optional [torch .Tensor ]=None | |
| logits :torch .Tensor =None | |
| past_key_values :Optional [List [KVCache ]]=None | |
| hidden_states :Optional [Tuple [torch .Tensor ]]=None | |
| attentions :Optional [Tuple [torch .Tensor ]]=None | |
| aux_loss :Optional [torch .Tensor ]=None | |
| class MoELlamaForCausalLM (nn .Module ): | |
| """MoE LLaMA for Causal Language Modeling with MLA and Ring Attention.""" | |
| def __init__ (self ,config ,moe_config :dict =None ): | |
| super ().__init__ () | |
| self .config =config | |
| self .moe_config =moe_config | |
| self .model =MoELlamaModel (config ,moe_config ) | |
| self .lm_head =nn .Linear (config .hidden_size ,config .vocab_size ,bias =False ) | |
| if getattr (config ,'tie_word_embeddings',True ): | |
| self .lm_head .weight =self .model .embed_tokens .weight | |
| self .apply (self ._init_weights ) | |
| def _init_weights (self ,module ): | |
| std =0.02 | |
| if isinstance (module ,nn .Linear ): | |
| nn .init .normal_ (module .weight ,mean =0.0 ,std =std ) | |
| if module .bias is not None : | |
| nn .init .zeros_ (module .bias ) | |
| elif isinstance (module ,nn .Embedding ): | |
| nn .init .normal_ (module .weight ,mean =0.0 ,std =std ) | |
| def get_input_embeddings (self )->nn .Embedding : | |
| return self .model .embed_tokens | |
| def set_input_embeddings (self ,value :nn .Embedding ): | |
| self .model .embed_tokens =value | |
| def get_output_embeddings (self )->nn .Linear : | |
| return self .lm_head | |
| def set_output_embeddings (self ,new_embeddings :nn .Linear ): | |
| self .lm_head =new_embeddings | |
| def gradient_checkpointing_enable (self ): | |
| """Enable gradient checkpointing for memory efficiency.""" | |
| self .model .gradient_checkpointing_enable () | |
| def gradient_checkpointing_disable (self ): | |
| """Disable gradient checkpointing.""" | |
| self .model .gradient_checkpointing_disable () | |
| def prepare_inputs_for_generation ( | |
| self , | |
| input_ids :torch .Tensor , | |
| past_key_values :Optional [List [KVCache ]]=None , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| inputs_embeds :Optional [torch .Tensor ]=None , | |
| **kwargs , | |
| )->dict : | |
| if past_key_values is not None : | |
| input_ids =input_ids [:,-1 :] | |
| position_ids =kwargs .get ("position_ids",None ) | |
| if attention_mask is not None and position_ids is None : | |
| position_ids =attention_mask .long ().cumsum (-1 )-1 | |
| position_ids .masked_fill_ (attention_mask ==0 ,1 ) | |
| if past_key_values is not None : | |
| position_ids =position_ids [:,-1 :] | |
| return { | |
| "input_ids":input_ids , | |
| "past_key_values":past_key_values , | |
| "use_cache":kwargs .get ("use_cache",True ), | |
| "position_ids":position_ids , | |
| "attention_mask":attention_mask , | |
| } | |
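| # During incremental decoding only the newest token is fed; position_ids are | |
| # derived from the mask so left-padded batches keep correct positions, e.g.: | |
| #   mask = torch.tensor([[0, 1, 1, 1]]) | |
| #   pos = mask.long().cumsum(-1) - 1      # [[-1, 0, 1, 2]] | |
| #   pos.masked_fill_(mask == 0, 1)        # [[ 1, 0, 1, 2]] -- pad slot is a dummy | |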
| def forward ( | |
| self , | |
| input_ids :Optional [torch .Tensor ]=None , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| position_ids :Optional [torch .Tensor ]=None , | |
| inputs_embeds :Optional [torch .Tensor ]=None , | |
| labels :Optional [torch .Tensor ]=None , | |
| past_key_values :Optional [List [KVCache ]]=None , | |
| use_cache :bool =False , | |
| output_attentions :bool =False , | |
| output_hidden_states :bool =False , | |
| return_dict :bool =True , | |
| cache_position :Optional [torch .Tensor ]=None , | |
| thinking_depth :int =0 , | |
| **kwargs , | |
| )->Union [Tuple ,CausalLMOutput ]: | |
| outputs =self .model ( | |
| input_ids =input_ids , | |
| attention_mask =attention_mask , | |
| position_ids =position_ids , | |
| inputs_embeds =inputs_embeds , | |
| past_key_values =past_key_values , | |
| use_cache =use_cache , | |
| output_attentions =output_attentions , | |
| output_hidden_states =output_hidden_states , | |
| return_dict =True , | |
| cache_position =cache_position , | |
| thinking_depth =thinking_depth , | |
| ) | |
| hidden_states =outputs .last_hidden_state | |
| aux_loss =outputs .aux_loss | |
| logits =self .lm_head (hidden_states ) | |
| loss =None | |
| if labels is not None : | |
| shift_logits =logits [...,:-1 ,:].contiguous () | |
| shift_labels =labels [...,1 :].contiguous () | |
| if shift_labels .dtype !=torch .long : | |
| shift_labels =shift_labels .long () | |
| valid_mask =(shift_labels !=-100 ) | |
| num_valid =valid_mask .sum ().item () | |
| if num_valid >0 : | |
| loss_fct =nn .CrossEntropyLoss (ignore_index =-100 ) | |
| loss =loss_fct ( | |
| shift_logits .view (-1 ,shift_logits .size (-1 )), | |
| shift_labels .view (-1 ) | |
| ) | |
| loss =torch .clamp (loss ,min =0.0 ,max =100.0 ) | |
| else : | |
| loss =torch .tensor (0.0 ,device =logits .device ,dtype =logits .dtype ,requires_grad =True ) | |
| return CausalLMOutput ( | |
| loss =loss , | |
| logits =logits , | |
| past_key_values =outputs .past_key_values , | |
| hidden_states =outputs .hidden_states , | |
| attentions =outputs .attentions , | |
| aux_loss =aux_loss , | |
| ) | |
| def generate ( | |
| self , | |
| input_ids :torch .Tensor , | |
| max_new_tokens :int =100 , | |
| temperature :float =1.0 , | |
| top_k :int =50 , | |
| top_p :float =0.9 , | |
| do_sample :bool =True , | |
| pad_token_id :Optional [int ]=None , | |
| eos_token_id :Optional [int ]=None , | |
| attention_mask :Optional [torch .Tensor ]=None , | |
| thinking_depth :int =0 , | |
| **kwargs , | |
| )->torch .Tensor : | |
| batch_size =input_ids .shape [0 ] | |
| device =input_ids .device | |
| past_key_values =None | |
| is_prefill =True # Deep thinking only on first pass (full context) | |
| if attention_mask is None : | |
| attention_mask =torch .ones_like (input_ids ) | |
| for _ in range (max_new_tokens ): | |
| model_inputs =self .prepare_inputs_for_generation ( | |
| input_ids , | |
| past_key_values =past_key_values , | |
| attention_mask =attention_mask , | |
| ) | |
| # Apply thinking depth only on prefill, not per-token steps | |
| current_depth = thinking_depth if is_prefill else 0 | |
| outputs =self .forward (**model_inputs ,use_cache =True ,return_dict =True ,thinking_depth =current_depth ) | |
| is_prefill =False | |
| next_token_logits =outputs .logits [:,-1 ,:] | |
| if temperature !=1.0 : | |
| next_token_logits =next_token_logits /temperature | |
| if do_sample : | |
| if top_k >0 : | |
| indices_to_remove =next_token_logits <torch .topk (next_token_logits ,top_k )[0 ][...,-1 ,None ] | |
| next_token_logits [indices_to_remove ]=float ('-inf') | |
| if top_p <1.0 : | |
| sorted_logits ,sorted_indices =torch .sort (next_token_logits ,descending =True ) | |
| cumulative_probs =torch .cumsum (F .softmax (sorted_logits ,dim =-1 ),dim =-1 ) | |
| sorted_indices_to_remove =cumulative_probs >top_p | |
| sorted_indices_to_remove [...,1 :]=sorted_indices_to_remove [...,:-1 ].clone () | |
| sorted_indices_to_remove [...,0 ]=0 | |
| indices_to_remove =sorted_indices_to_remove .scatter (1 ,sorted_indices ,sorted_indices_to_remove ) | |
| next_token_logits [indices_to_remove ]=float ('-inf') | |
| probs =F .softmax (next_token_logits ,dim =-1 ) | |
| next_tokens =torch .multinomial (probs ,num_samples =1 ).squeeze (-1 ) | |
| else : | |
| next_tokens =torch .argmax (next_token_logits ,dim =-1 ) | |
| input_ids =torch .cat ([input_ids ,next_tokens .unsqueeze (-1 )],dim =-1 ) | |
| attention_mask =torch .cat ([attention_mask ,torch .ones ((batch_size ,1 ),device =device )],dim =-1 ) | |
| past_key_values =outputs .past_key_values | |
| if eos_token_id is not None and (next_tokens ==eos_token_id ).all (): | |
| break | |
| return input_ids | |
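| # --- Generation sketch (illustrative config values) --- | |
| #   cfg = LlamaConfig(vocab_size=32000, hidden_size=512, intermediate_size=1024, | |
| #                     num_hidden_layers=4, num_attention_heads=8, | |
| #                     max_position_embeddings=4096) | |
| #   lm = MoELlamaForCausalLM(cfg, moe_config={'use_moe': True, 'num_experts': 8, | |
| #                            'num_experts_per_tok': 2, 'moe_layer_freq': 2, | |
| #                            'intermediate_size': 1024}) | |
| #   ids = torch.randint(0, 32000, (1, 16)) | |
| #   out = lm.generate(ids, max_new_tokens=8, thinking_depth=2) | |
| # Latent thinking runs on the prefill pass only; per-token steps use depth 0. | |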
| ============================================================================== | |
| MODELS.XORON | |
| ============================================================================== | |
| logger =logging .getLogger (__name__ ) | |
| MAX_HIDDEN =10000.0 | |
| def safe_clamp_tensor (x :torch .Tensor ,max_val :float =MAX_HIDDEN )->torch .Tensor : | |
| """Clamp tensor values for FP16 safety, handling NaN/Inf properly. | |
| WARNING: Only use for linear/hidden states, NOT for attention scores before softmax! | |
| For attention scores, use a max of ~11.0 to prevent exp() overflow. | |
| CRITICAL: torch.clamp does NOT fix NaN! clamp(nan, -10, 10) = nan | |
| Must use nan_to_num first. | |
| """ | |
| if x is None or x .numel ()==0 : | |
| return x | |
| x =torch .nan_to_num (x ,nan =0.0 ,posinf =max_val ,neginf =-max_val ) | |
| return x .clamp (-max_val ,max_val ) | |
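| # The docstring's warning is easy to verify -- clamp propagates NaN while | |
| # nan_to_num replaces it: | |
| #   t = torch.tensor([float('nan'), float('inf'), -float('inf'), 5.0]) | |
| #   torch.clamp(t, -10, 10)       # -> [nan, 10., -10., 5.] | |
| #   safe_clamp_tensor(t, 10.0)    # -> [0., 10., -10., 5.] | |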
| COMPONENT_GROUPS ={ | |
| 'vision':['vision_encoder','projector'], | |
| 'video':['video_encoder'], | |
| 'audio':['audio_encoder','audio_decoder','audio_projector','waveform_decoder'], | |
| 'speech':['waveform_decoder'], | |
| 'llm':['llm'], | |
| 'cross_attention':['cross_attention_layers'], | |
| 'image_generation':['generator'], | |
| 'video_generation':['video_generator'], | |
| 'modality_markers':['image_start','image_end','video_start','video_end','audio_start','audio_end'], | |
| } | |
| class MultimodalModelOutput (dict ): | |
| """Output class for multimodal model.""" | |
| def __getattr__ (self ,name ): | |
| try : | |
| return self [name ] | |
| except KeyError : | |
| raise AttributeError (f"'{type (self ).__name__ }' has no attribute '{name }'") | |
| def __setattr__ (self ,name ,value ): | |
| self [name ]=value | |
| class XoronMultimodalModel (nn .Module ): | |
| """ | |
| Xoron-Dev: Complete multimodal model with: | |
| - Image/video understanding (CLIP) | |
| - Text generation (MoE LLM) | |
| - Image/video generation (MobileDiffusion) | |
| - Voice understanding and generation (ASR/TTS) | |
| - Cross-attention for multimodal fusion | |
| - LoRA support for efficient fine-tuning | |
| - Flash Attention for faster training | |
| - Model Parallelism support for multi-GPU training | |
| """ | |
| def __init__ (self ,config :XoronConfig ,device_map :Dict [str ,str ]=None ): | |
| super ().__init__ () | |
| self .config =config | |
| self .device_map =device_map | |
| if device_map is not None : | |
| device_values =[v for v in device_map .values ()if isinstance (v ,str )] | |
| self ._model_parallel =len (set (device_values ))>1 | |
| else : | |
| self ._model_parallel =False | |
| logger .info ("Initializing Xoron-Dev Multimodal Model Build") | |
| if self ._model_parallel : | |
| logger .info (" ⚡ Model Parallelism: ENABLED") | |
| self .vision_encoder =VisionEncoder (config .vision_model_name ,freeze =config .freeze_vision ) | |
| self .video_encoder =VideoEncoder (self .vision_encoder ,max_frames =config .video_max_frames ) | |
| logger .info ("Building SOTA Audio Encoder...") | |
| self .audio_encoder =AudioEncoder ( | |
| hidden_size =config .hidden_size , | |
| n_mels =80 , | |
| max_audio_length =3000 , | |
| use_raw_waveform =getattr (config ,'use_raw_waveform',True ), | |
| ) | |
| logger .info ("Building SOTA Audio Decoder...") | |
| self .audio_decoder =AudioDecoder ( | |
| hidden_size =config .hidden_size , | |
| n_mels =80 , | |
| max_audio_length =1000 , | |
| ) | |
| logger .info ("Building Raw Waveform Decoder (Speech-to-Speech)...") | |
| self .waveform_decoder =RawWaveformDecoder ( | |
| hidden_size =config .hidden_size , | |
| sample_rate =getattr (config ,'audio_sample_rate',16000 ), | |
| ) | |
| llm_config =LlamaConfig ( | |
| vocab_size =config .vocab_size , | |
| hidden_size =config .hidden_size , | |
| intermediate_size =config .intermediate_size , | |
| num_hidden_layers =config .num_layers , | |
| num_attention_heads =config .num_heads , | |
| max_position_embeddings =config .max_position_embeddings , | |
| rms_norm_eps =1e-6 , | |
| tie_word_embeddings =getattr (config ,'tie_word_embeddings',True ), | |
| pad_token_id =0 , | |
| ) | |
| llm_config .use_flash_attention =config .use_flash_attention | |
| llm_config .use_ring_attention =getattr (config ,'use_ring_attention',True ) | |
| llm_config .ring_attention_chunk_size =getattr (config ,'ring_attention_chunk_size',4096 ) | |
| moe_config ={ | |
| 'use_moe':config .use_moe , | |
| 'num_experts':config .num_experts , | |
| 'num_experts_per_tok':config .num_experts_per_tok , | |
| 'moe_layer_freq':config .moe_layer_freq , | |
| 'intermediate_size':config .intermediate_size , | |
| } | |
| logger .info (f"Building LLM Core: {config .hidden_size }d, {config .num_layers }L") | |
| logger .info (f" 📏 Context: {config .max_position_embeddings //1024 }K positions") | |
| if config .use_ring_attention : | |
| logger .info (f" 🔄 Ring Attention Enabled (chunk size: {config .ring_attention_chunk_size })") | |
| logger .info (f" 🎯 MoE: {config .num_experts } experts, top-{config .num_experts_per_tok }") | |
| self .llm =MoELlamaForCausalLM (llm_config ,moe_config ) | |
| logger .info (f" ✅ MoE layers initialized: {self .llm .model .num_moe_layers }/{config .num_layers }") | |
| self .projector =MultimodalProjector ( | |
| self .vision_encoder .hidden_size , | |
| config .hidden_size , | |
| config .num_vision_tokens | |
| ) | |
| logger .info (f" 🔗 Projector initialized: {self .vision_encoder .hidden_size } -> {config .hidden_size }") | |
| self .audio_projector =nn .Linear (config .hidden_size ,config .hidden_size ) | |
| self .image_start =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 ) | |
| self .image_end =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 ) | |
| self .video_start =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 ) | |
| self .video_end =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 ) | |
| self .audio_start =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 ) | |
| self .audio_end =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 ) | |
| self .cross_attention_layers =None | |
| if config .use_cross_attention : | |
| logger .info (f"Building Cross-Attention Fusion ({config .cross_attention_layers } layers)...") | |
| self .cross_attention_layers =nn .ModuleList ([ | |
| MultimodalFusionLayer ( | |
| hidden_size =config .hidden_size , | |
| num_heads =config .cross_attention_heads , | |
| dropout =config .cross_attention_dropout , | |
| use_flash_attention =config .use_flash_attention , | |
| ) | |
| for _ in range (config .cross_attention_layers ) | |
| ]) | |
| logger .info (f" ✅ Cross-attention: {config .cross_attention_layers } layers, {config .cross_attention_heads } heads") | |
| self .generator =None | |
| if config .enable_generation : | |
| logger .info ("Building MobileDiffusion Generators (Image & Video)...") | |
| self .generator =MobileDiffusionGenerator ( | |
| latent_channels =config .generation_latent_channels , | |
| base_channels =config .generation_base_channels , | |
| context_dim =config .hidden_size , | |
| num_inference_steps =config .generation_inference_steps , | |
| image_size =config .image_max_size , | |
| ) | |
| self .video_generator =None | |
| if config .enable_generation : | |
| self .video_generator =MobileVideoDiffusion ( | |
| latent_channels =config .generation_latent_channels , | |
| base_channels =config .generation_base_channels //2 , | |
| context_dim =config .hidden_size , | |
| num_frames =config .video_max_frames , | |
| image_size =config .video_max_size , | |
| num_inference_steps =config .generation_inference_steps , | |
| ) | |
| self .num_vision_tokens =config .num_vision_tokens | |
| self .video_max_frames =config .video_max_frames | |
| self .lora_applied =False | |
| self ._print_stats () | |
| logger .info ("Xoron-Dev Multimodal Model Build Complete") | |
| def apply_model_parallel (self ,device_map :Dict [str ,str ]): | |
| """Apply Model Parallelism by sharding components across devices. | |
| Trained components get their layers split across all training GPUs. | |
| Frozen components go to CPU. Small components (projectors, markers) | |
| go to the primary GPU. | |
| """ | |
| self .device_map =device_map | |
| training_gpus = device_map .get ('training_gpus', ['cuda:0']) | |
| primary = device_map .get ('primary', 'cuda:0') | |
| if len (training_gpus ) <= 1 and not any (v == 'cpu' for v in device_map .values () if isinstance (v, str)): | |
| logger .info (" ℹ️ Single device - no model parallelism needed") | |
| return self | |
| self ._model_parallel = True | |
| logger .info ("Applying Model Parallelism (layer sharding)...") | |
| def _shard_module (module, name, gpus): | |
| """Shard a module's sub-layers across GPUs.""" | |
| # Find shardable sub-layers (nn.ModuleList children) | |
| layer_lists = [] | |
| for attr_name in dir (module): | |
| attr = getattr (module, attr_name, None) | |
| if isinstance (attr, nn .ModuleList) and len (attr) > 0: | |
| layer_lists .append ((attr_name, attr)) | |
| if layer_lists: | |
| # Shard the largest ModuleList across GPUs | |
| layer_lists .sort (key=lambda x: len (x[1]), reverse=True) | |
| list_name, layers = layer_lists [0] | |
| for i, layer in enumerate (layers): | |
| target_gpu = gpus [i % len (gpus)] | |
| layer .to (target_gpu) | |
| # Put remaining (non-sharded) params on the first training GPU | |
| for param_name, param in module .named_parameters (): | |
| if f'{list_name}.' not in param_name: | |
| param .data = param .data .to (gpus [0]) | |
| logger .info (f" ✅ {name}: {len(layers)} layers sharded across {gpus}") | |
| else: | |
| # No layers to shard — put whole module on first GPU | |
| module .to (gpus [0]) | |
| logger .info (f" ✅ {name} -> {gpus[0]}") | |
| # Map component names to actual attributes | |
| component_attrs = { | |
| 'vision_encoder': 'vision_encoder', | |
| 'video_encoder': 'video_encoder', | |
| 'audio_encoder': 'audio_encoder', | |
| 'audio_decoder': 'audio_decoder', | |
| 'waveform_decoder': 'waveform_decoder', | |
| 'projector': 'projector', | |
| 'audio_projector': 'audio_projector', | |
| 'llm': 'llm', | |
| 'cross_attention': 'cross_attention_layers', | |
| 'generator': 'generator', | |
| 'video_generator': 'video_generator', | |
| } | |
| for comp_name, attr_name in component_attrs .items (): | |
| comp = getattr (self, attr_name, None) | |
| if comp is None: | |
| continue | |
| target = device_map .get (comp_name, 'cpu') | |
| if target == 'cpu': | |
| comp .to ('cpu') | |
| logger .info (f" ❄️ {comp_name} -> cpu (frozen)") | |
| else: | |
| # Shard across all training GPUs | |
| _shard_module (comp, comp_name, training_gpus) | |
| # Modality markers → primary GPU | |
| marker_device = device_map .get ('modality_markers', primary) | |
| if marker_device != 'cpu': | |
| marker_device = primary | |
| for marker_name in ['image_start', 'image_end', 'video_start', 'video_end', 'audio_start', 'audio_end']: | |
| marker = getattr (self, marker_name, None) | |
| if marker is not None: | |
| setattr (self, marker_name, nn .Parameter (marker .data .to (marker_device))) | |
| logger .info (f" ✅ Modality markers -> {marker_device}") | |
| logger .info ("Model Parallelism applied successfully!") | |
| return self | |
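| # --- Example device_map (hypothetical layout) --- | |
| #   model.apply_model_parallel({ | |
| #       'training_gpus': ['cuda:0', 'cuda:1'],  # decoder layers alternate here | |
| #       'primary': 'cuda:0',                    # small modules + modality markers | |
| #       'llm': 'cuda:0',                        # non-'cpu' => sharded across GPUs | |
| #       'vision_encoder': 'cpu',                # frozen components offloaded | |
| #   }) | |
| # Components missing from the map default to 'cpu'. | |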
| def get_llm_device (self ): | |
| """Get the device where LLM is located.""" | |
| if self .device_map is not None : | |
| return torch .device (self .device_map ['llm']) | |
| return next (self .llm .parameters ()).device | |
| def generate (self ,*args ,**kwargs ): | |
| """ | |
| Delegates generation to the internal LLM. | |
| This allows the model to be treated as a causal LM in many pipelines. | |
| """ | |
| return self .llm .generate (*args ,**kwargs ) | |
| def get_encoder_device (self ): | |
| """Get the device where encoders are located.""" | |
| if self .device_map is not None : | |
| return torch .device (self .device_map ['vision_encoder']) | |
| return next (self .vision_encoder .parameters ()).device | |
| def apply_lora (self ): | |
| """ | |
| Apply LoRA to the LLM and optionally cross-attention layers. | |
| MEMORY OPTIMIZATION: | |
| - LoRA layers share base weights (no cloning) | |
| - Base weights in LoRA layers are frozen (requires_grad=False) | |
| - LoRA params (A, B, magnitude) are always trainable | |
| NOTE: This does NOT freeze other components! | |
| Component freezing is handled separately by freeze_components() based on | |
| training mode (--text, --video, --image, --voice flags). | |
| This allows PARALLEL FINE-TUNING: | |
| - LoRA adapters on LLM for efficient adaptation | |
| - Full weight training on active components (vision, audio, etc.) | |
| """ | |
| if self .lora_applied : | |
| logger .warning ("LoRA already applied") | |
| return | |
| if not self .config .use_lora : | |
| logger .info ("LoRA disabled in config") | |
| return | |
| lora_config =LoRAConfig ( | |
| r =self .config .lora_r , | |
| lora_alpha =self .config .lora_alpha , | |
| lora_dropout =self .config .lora_dropout , | |
| target_modules =list (self .config .lora_target_modules ), | |
| enable_lora =True , | |
| ) | |
| logger .info ("Applying LoRA to LLM Core...") | |
| self .llm =apply_lora_to_model (self .llm ,lora_config ) | |
| if self .cross_attention_layers is not None : | |
| logger .info ("Applying LoRA to cross-attention layers...") | |
| cross_attn_lora_config =LoRAConfig ( | |
| r =lora_config .r , | |
| lora_alpha =lora_config .lora_alpha , | |
| lora_dropout =lora_config .lora_dropout , | |
| target_modules =['q_proj','k_proj','v_proj','o_proj'], | |
| enable_lora =True , | |
| ) | |
| for i ,layer in enumerate (self .cross_attention_layers ): | |
| self .cross_attention_layers [i ]=apply_lora_to_model (layer ,cross_attn_lora_config ) | |
| self .lora_applied =True | |
| self ._print_stats () | |
| def get_trainable_params (self ): | |
| """ | |
| Get trainable parameters, respecting LoRA settings and component freezing. | |
| If train_lora_only=True and LoRA is applied: | |
| - Freezes all non-LoRA params | |
| - Returns only LoRA params | |
| Otherwise: | |
| - Returns all params with requires_grad=True | |
| - This includes both LoRA params AND unfrozen component weights | |
| - Allows parallel fine-tuning: LoRA + full weights on active components | |
| """ | |
| if self .config .train_lora_only and self .lora_applied : | |
| freeze_non_lora_params (self ) | |
| return get_lora_parameters (self ) | |
| return [p for p in self .parameters ()if p .requires_grad ] | |
| def _print_stats (self ): | |
| total =sum (p .numel ()for p in self .parameters ()) | |
| trainable =sum (p .numel ()for p in self .parameters ()if p .requires_grad ) | |
| logger .info ("Model Statistics:") | |
| logger .info (f" Total parameters: {total /1e6 :.1f}M") | |
| logger .info (f" Trainable parameters: {trainable /1e6 :.1f}M") | |
| if self .lora_applied : | |
| lora_params =sum (p .numel ()for n ,p in self .named_parameters ()if 'lora_'in n ) | |
| logger .info (f" LoRA parameters: {lora_params /1e6 :.2f}M") | |
| def encode_image (self ,pixel_values :torch .Tensor )->torch .Tensor : | |
| encoder_device =self .get_encoder_device () | |
| pixel_values =pixel_values .to (encoder_device ) | |
| vision_features =self .vision_encoder (pixel_values ) | |
| projected =self .projector (vision_features ) | |
| llm_device =self .get_llm_device () | |
| return projected .to (llm_device ) | |
| def encode_video (self ,video_frames :torch .Tensor )->torch .Tensor : | |
| encoder_device =self .get_encoder_device () | |
| video_frames =video_frames .to (encoder_device ) | |
| video_features =self .video_encoder (video_frames ) | |
| projected =self .projector (video_features ) | |
| llm_device =self .get_llm_device () | |
| return projected .to (llm_device ) | |
| def encode_audio (self ,audio_features :torch .Tensor )->torch .Tensor : | |
| encoder_device =self .get_encoder_device () | |
| audio_features =audio_features .to (encoder_device ) | |
| audio_embeds =self .audio_encoder (audio_features ) | |
| projected =self .audio_projector (audio_embeds ) | |
| llm_device =self .get_llm_device () | |
| return projected .to (llm_device ) | |
| def get_text_embeddings (self ,input_ids :torch .Tensor ,attention_mask :torch .Tensor =None )->torch .Tensor : | |
| llm_device =self .get_llm_device () | |
| input_ids =input_ids .to (llm_device ) | |
| embeddings =self .llm .model .embed_tokens (input_ids ) | |
| return embeddings | |
| def _apply_cross_attention ( | |
| self , | |
| text_embeds :torch .Tensor , | |
| image_embeds :torch .Tensor =None , | |
| video_embeds :torch .Tensor =None , | |
| audio_embeds :torch .Tensor =None , | |
| )->torch .Tensor : | |
| if self .cross_attention_layers is None : | |
| return text_embeds | |
| for fusion_layer in self .cross_attention_layers : | |
| text_embeds ,_ =fusion_layer ( | |
| text_hidden =text_embeds , | |
| image_hidden =image_embeds , | |
| video_hidden =video_embeds , | |
| audio_hidden =audio_embeds , | |
| use_cache =False , | |
| ) | |
| return text_embeds | |
| def forward ( | |
| self , | |
| input_ids :torch .Tensor , | |
| attention_mask :torch .Tensor =None , | |
| pixel_values :torch .Tensor =None , | |
| video_frames :torch .Tensor =None , | |
| audio_features :torch .Tensor =None , | |
| labels :torch .Tensor =None , | |
| ): | |
| """Forward pass - FP16 native.""" | |
| batch_size =input_ids .shape [0 ] | |
| llm_device =self .get_llm_device () | |
| input_ids_llm =input_ids .to (llm_device ) | |
| text_embeds =self .llm .model .embed_tokens (input_ids_llm ) | |
| text_embeds =safe_clamp_tensor (text_embeds ) | |
| device =text_embeds .device | |
| if attention_mask is not None : | |
| attention_mask =attention_mask .to (device ) | |
| if labels is not None : | |
| labels =labels .to (device ) | |
| image_embeds_for_cross =None | |
| video_embeds_for_cross =None | |
| audio_embeds_for_cross =None | |
| def has_content (tensor ): | |
| if tensor is None : | |
| return False | |
| if not isinstance (tensor ,torch .Tensor ): | |
| return False | |
| try : | |
| if tensor .numel ()==0 : | |
| return False | |
| return bool (tensor .any ()) | |
| except Exception : | |
| return False | |
| if has_content (pixel_values ): | |
| try : | |
| image_embeds =self .encode_image (pixel_values ) | |
| image_embeds =safe_clamp_tensor (image_embeds ) | |
| image_embeds_for_cross =image_embeds | |
| image_start =self .image_start .expand (batch_size ,-1 ,-1 ) | |
| image_end =self .image_end .expand (batch_size ,-1 ,-1 ) | |
| image_embeds =torch .cat ([image_start ,image_embeds ,image_end ],dim =1 ) | |
| text_embeds =torch .cat ([image_embeds ,text_embeds ],dim =1 ) | |
| text_embeds =safe_clamp_tensor (text_embeds ) | |
| if attention_mask is not None : | |
| image_mask =torch .ones (batch_size ,image_embeds .shape [1 ],device =device ) | |
| attention_mask =torch .cat ([image_mask ,attention_mask ],dim =1 ) | |
| if labels is not None : | |
| image_labels =torch .full ((batch_size ,image_embeds .shape [1 ]),-100 ,device =device ,dtype =labels .dtype ) | |
| labels =torch .cat ([image_labels ,labels ],dim =1 ) | |
| except Exception as e : | |
| logger .debug (f"Image encoding skipped: {e }") | |
| if has_content (video_frames ): | |
| try : | |
| video_embeds =self .encode_video (video_frames ) | |
| video_embeds =safe_clamp_tensor (video_embeds ) | |
| video_embeds_for_cross =video_embeds | |
| video_start =self .video_start .expand (batch_size ,-1 ,-1 ) | |
| video_end =self .video_end .expand (batch_size ,-1 ,-1 ) | |
| video_embeds =torch .cat ([video_start ,video_embeds ,video_end ],dim =1 ) | |
| text_embeds =torch .cat ([video_embeds ,text_embeds ],dim =1 ) | |
| text_embeds =safe_clamp_tensor (text_embeds ) | |
| if attention_mask is not None : | |
| video_mask =torch .ones (batch_size ,video_embeds .shape [1 ],device =device ) | |
| attention_mask =torch .cat ([video_mask ,attention_mask ],dim =1 ) | |
| if labels is not None : | |
| video_labels =torch .full ((batch_size ,video_embeds .shape [1 ]),-100 ,device =device ,dtype =labels .dtype ) | |
| labels =torch .cat ([video_labels ,labels ],dim =1 ) | |
| except Exception as e : | |
| logger .debug (f"Video encoding skipped: {e }") | |
| if has_content (audio_features ): | |
| try : | |
| audio_embeds =self .encode_audio (audio_features ) | |
| audio_embeds =safe_clamp_tensor (audio_embeds ) | |
| audio_embeds_for_cross =audio_embeds | |
| audio_start =self .audio_start .expand (batch_size ,-1 ,-1 ) | |
| audio_end =self .audio_end .expand (batch_size ,-1 ,-1 ) | |
| audio_embeds =torch .cat ([audio_start ,audio_embeds ,audio_end ],dim =1 ) | |
| text_embeds =torch .cat ([audio_embeds ,text_embeds ],dim =1 ) | |
| text_embeds =safe_clamp_tensor (text_embeds ) | |
| if attention_mask is not None : | |
| audio_mask =torch .ones (batch_size ,audio_embeds .shape [1 ],device =device ) | |
| attention_mask =torch .cat ([audio_mask ,attention_mask ],dim =1 ) | |
| if labels is not None : | |
| audio_labels =torch .full ((batch_size ,audio_embeds .shape [1 ]),-100 ,device =device ,dtype =labels .dtype ) | |
| labels =torch .cat ([audio_labels ,labels ],dim =1 ) | |
| except Exception as e : | |
| logger .debug (f"Audio encoding skipped: {e }") | |
| if self .cross_attention_layers is not None : | |
| try : | |
| text_embeds =self ._apply_cross_attention ( | |
| text_embeds , | |
| image_embeds =image_embeds_for_cross , | |
| video_embeds =video_embeds_for_cross , | |
| audio_embeds =audio_embeds_for_cross , | |
| ) | |
| text_embeds =safe_clamp_tensor (text_embeds ) | |
| except Exception as e : | |
| logger .debug (f"Cross-attention skipped: {e }") | |
| text_embeds =safe_clamp_tensor (text_embeds ) | |
| outputs =self .llm (inputs_embeds =text_embeds ,attention_mask =attention_mask ,labels =labels ) | |
| return MultimodalModelOutput ( | |
| loss =outputs .loss if hasattr (outputs ,'loss')else None , | |
| logits =outputs .logits if hasattr (outputs ,'logits')else None , | |
| aux_loss =outputs .aux_loss if hasattr (outputs ,'aux_loss')else None , | |
| ) | |
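| # --- Multimodal forward sketch (illustrative shapes; image size is an assumption) --- | |
| #   ids = torch.randint(0, config.vocab_size, (1, 32)) | |
| #   out = model(input_ids=ids, attention_mask=torch.ones_like(ids), | |
| #               pixel_values=torch.randn(1, 3, 224, 224)) | |
| # out.logits spans [image tokens + text tokens]; the image span is labeled -100, | |
| # so only text positions contribute to out.loss. | |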
| def generate_image (self ,input_ids :torch .Tensor ,attention_mask :torch .Tensor =None ): | |
| """Generate image from text.""" | |
| if self .generator is None : | |
| raise ValueError ("Image generator not enabled") | |
| context =self .get_text_embeddings (input_ids ,attention_mask ) | |
| images =self .generator .generate (context ) | |
| return images | |
| def generate_video (self ,input_ids :torch .Tensor ,attention_mask :torch .Tensor =None , | |
| first_frame :torch .Tensor =None ,num_frames :int =None ): | |
| """Generate video from text (T2V) or from image (I2V).""" | |
| if self .video_generator is None : | |
| raise ValueError ("Video generator not enabled") | |
| context =self .get_text_embeddings (input_ids ,attention_mask ) | |
| context =context .mean (dim =1 ) | |
| if first_frame is not None : | |
| video =self .video_generator .generate_i2v (first_frame ,context ,num_frames ) | |
| else : | |
| video =self .video_generator .generate_t2v (context ,num_frames ) | |
| return video | |
| def generate_speech (self ,input_ids :torch .Tensor ,attention_mask :torch .Tensor =None ): | |
| """Generate speech (mel-spectrogram) from text (TTS).""" | |
| text_embeds =self .get_text_embeddings (input_ids ,attention_mask ) | |
| mel ,durations ,_ ,_ =self .audio_decoder (text_embeds ) | |
| return mel ,durations | |
| def speak ( | |
| self , | |
| input_ids :torch .Tensor , | |
| attention_mask :torch .Tensor =None , | |
| speaker_embedding :torch .Tensor =None , | |
| return_mel :bool =False , | |
| )->torch .Tensor : | |
| """ | |
| Generate playable audio waveform from text (Speech-to-Speech TTS). | |
| This is the main method for making the model talk. It converts text | |
| directly to audio waveform without needing an external vocoder. | |
| Args: | |
| input_ids: [B, T] tokenized text input | |
| attention_mask: [B, T] attention mask | |
| speaker_embedding: [B, D] optional speaker embedding for voice cloning | |
| return_mel: If True, also return intermediate mel spectrogram | |
| Returns: | |
| waveform: [B, T_audio] raw audio waveform in [-1, 1] at config.audio_sample_rate | |
| (16 kHz by default). Can be played directly or saved as a WAV file. | |
| mel (optional): [B, 80, T_mel] mel spectrogram if return_mel=True | |
| """ | |
| text_embeds =self .get_text_embeddings (input_ids ,attention_mask ) | |
| mel ,durations ,_ ,_ =self .audio_decoder ( | |
| text_embeds , | |
| speaker_embedding =speaker_embedding , | |
| ) | |
| mel_features =mel .transpose (1 ,2 ) | |
| # Lazily build the mel->hidden bridge on first call. NOTE: since it is not | |
| # registered in __init__, this Linear starts randomly initialized and is | |
| # missing from checkpoints saved before speak() is first invoked. | |
| if not hasattr (self ,'_mel_to_hidden'): | |
| self ._mel_to_hidden =nn .Linear (80 ,self .config .hidden_size ).to (mel .device ) | |
| waveform =self .waveform_decoder (audio_features ) | |
| if return_mel : | |
| return waveform ,mel | |
| return waveform | |
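| # --- Speaking sketch (assumes torchaudio is installed) --- | |
| #   wav = model.speak(input_ids, attention_mask)       # [B, T_audio] in [-1, 1] | |
| #   import torchaudio | |
| #   torchaudio.save("reply.wav", wav[0:1].detach().cpu(), 16000) | |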
| def listen (self ,audio_waveform :torch .Tensor )->torch .Tensor : | |
| """ | |
| Transcribe audio to text embeddings (Speech-to-Speech ASR). | |
| This is the listening component - converts speech to embeddings | |
| that can be fed to the LLM for understanding. | |
| Args: | |
| audio_waveform: [B, T_audio] raw audio waveform | |
| Returns: | |
| audio_embeds: [B, T, hidden_size] encoded audio features | |
| """ | |
| return self .encode_audio (audio_waveform ) | |
| def listen_and_respond ( | |
| self , | |
| audio_waveform :torch .Tensor , | |
| tokenizer =None , | |
| max_new_tokens :int =512 , | |
| speaker_embedding :torch .Tensor =None , | |
| temperature :float =0.7 , | |
| top_p :float =0.9 , | |
| tool_executor =None , | |
| available_tools :list =None , | |
| system_prompt :str =None , | |
| max_tool_calls :int =5 , | |
| ) -> Dict [str ,Any ]: | |
| """ | |
| Agentic Speech-to-Speech: Listen, think, use tools, speak back. | |
| This is the full agentic pipeline for live voice conversations. | |
| The model can detect when the user is asking for actions (e.g. | |
| "write me a Python script") and execute tools mid-generation. | |
| Pipeline: | |
| 1. Encode input audio → audio embeddings (ASR) | |
| 2. Build context (system prompt with tools + audio embeddings) | |
| 3. Generate tokens, watching for <|tool_call|> sequences | |
| 4. When tool call detected: parse, execute, inject result, resume | |
| 5. Synthesize final spoken response from non-tool text | |
| Args: | |
| audio_waveform: [B, T_audio] input audio waveform | |
| tokenizer: Tokenizer for decoding tokens to text (required for tools) | |
| max_new_tokens: Maximum total tokens to generate | |
| speaker_embedding: [B, D] optional speaker embedding for voice cloning | |
| temperature: Sampling temperature | |
| top_p: Nucleus sampling probability | |
| tool_executor: Callable(tool_name, args_dict) -> str result. | |
| If None, tool calls are detected but not executed. | |
| available_tools: List of tool definition dicts for system prompt. | |
| system_prompt: Optional system prompt override. | |
| max_tool_calls: Maximum number of tool calls per response (safety limit). | |
| Returns: | |
| Dict with: | |
| 'waveform': [B, T_response] audio waveform tensor (in-memory, no file I/O) | |
| 'text': str full response text (excluding tool call markup) | |
| 'token_ids': [B, T_tokens] all generated token IDs | |
| 'mel': [B, 80, T_mel] intermediate mel spectrogram | |
| 'tool_calls': List[Dict] executed tool calls and their results | |
| 'speaking_text': str clean text that was spoken (no tool markup) | |
| """ | |
| import json as _json | |
| device = audio_waveform .device | |
| batch_size = audio_waveform .shape [0 ] | |
| llm_device = self .get_llm_device () | |
| # ── 1. Listen: encode input audio ── | |
| audio_embeds = self .encode_audio (audio_waveform ) | |
| # Wrap with start/end markers | |
| audio_start = self .audio_start .expand (batch_size , -1 , -1 ).to (llm_device ) | |
| audio_end = self .audio_end .expand (batch_size , -1 , -1 ).to (llm_device ) | |
| audio_embeds = audio_embeds .to (llm_device ) | |
| # ── 2. Build context with system prompt + tools ── | |
| context_parts = [] | |
| if tokenizer is not None and (system_prompt or tool_executor): | |
| sys_text = system_prompt or "You are Xoron, an intelligent voice assistant. You can use tools to help the user." | |
| if tool_executor and hasattr (tool_executor , 'get_tool_prompt' ): | |
| sys_text = sys_text + "\n\n" + tool_executor .get_tool_prompt () | |
| elif available_tools : | |
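| # NOTE: requires the full Xoron-Dev package; unavailable in the | |
| # self-contained HuggingFace deployment | |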
| from utils .tool_executor import format_tools_for_prompt | |
| sys_text = sys_text + "\n\n" + format_tools_for_prompt (available_tools ) | |
| # Encode system prompt and prepend | |
| sys_str = "<|system|>" + sys_text + "<|/system|>" | |
| sys_token_ids = tokenizer .encode (sys_str , return_tensors ="pt" ).to (llm_device ) | |
| sys_embeds = self .llm .model .embed_tokens (sys_token_ids ) | |
| context_parts .append (sys_embeds .squeeze (0 ) if sys_embeds .dim () == 3 else sys_embeds ) | |
| # Audio context | |
| context_parts .extend ([audio_start , audio_embeds , audio_end ]) | |
| # Assistant generation prompt | |
| if tokenizer is not None : | |
| asst_str = "<|assistant|>" | |
| asst_ids = tokenizer .encode (asst_str , return_tensors ="pt" ).to (llm_device ) | |
| asst_embeds = self .llm .model .embed_tokens (asst_ids ) | |
| context_parts .append (asst_embeds .squeeze (0 ) if asst_embeds .dim () == 3 else asst_embeds ) | |
| input_embeds = torch .cat (context_parts , dim =1 ) | |
| # ── 3. Agentic generation loop with tool call detection ── | |
| tool_call_start_token = "<|tool_call|>" | |
| tool_call_end_token = "<|/tool_call|>" | |
| fn_name_start = "<|function_name|>" | |
| fn_name_end = "<|/function_name|>" | |
| fn_args_start = "<|function_args|>" | |
| fn_args_end = "<|/function_args|>" | |
| tool_result_start = "<|tool_result|>" | |
| tool_result_end = "<|/tool_result|>" | |
| eos_token = "<|eos|>" | |
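| # A complete tool call in the generated stream looks like this | |
| # (tool name and arguments are illustrative): | |
| # <|tool_call|><|function_name|>get_weather<|/function_name|> | |
| # <|function_args|>{"city": "Paris"}<|/function_args|><|/tool_call|> | |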
| all_generated_ids = [] | |
| tool_calls_made = [] | |
| num_tool_calls = 0 | |
| generated_text = "" | |
| total_tokens = 0 | |
| # Use standard generation if no tool executor | |
| if tool_executor is None or tokenizer is None : | |
| gen_kwargs = { | |
| 'inputs_embeds': input_embeds , | |
| 'max_new_tokens': max_new_tokens , | |
| 'do_sample': True , | |
| 'temperature': temperature , | |
| 'top_p': top_p , | |
| 'use_cache': True , | |
| } | |
| generated_ids = self .llm .generate (**gen_kwargs ) | |
| all_generated_ids = [generated_ids ] | |
| if tokenizer is not None : | |
| generated_text = tokenizer .batch_decode (generated_ids , skip_special_tokens =True )[0 ] | |
| else : | |
| # Token-by-token generation with tool call detection | |
| current_embeds = input_embeds | |
| past_key_values = None | |
| in_tool_call = False | |
| tool_call_buffer = "" | |
| while total_tokens < max_new_tokens : | |
| outputs = self .llm ( | |
| inputs_embeds =current_embeds , | |
| past_key_values =past_key_values , | |
| use_cache =True , | |
| ) | |
| past_key_values = outputs .past_key_values | |
| logits = outputs .logits [:, -1 :, :] | |
| # Sample next token | |
| if temperature > 0 : | |
| logits = logits / temperature | |
| if top_p < 1.0 : | |
| sorted_logits , sorted_indices = torch .sort (logits , descending =True , dim =-1 ) | |
| cumulative_probs = torch .cumsum (F .softmax (sorted_logits , dim =-1 ), dim =-1 ) | |
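| # Shifted mask: subtracting each token's own probability keeps the first | |
| # token that crosses the top_p threshold (standard nucleus sampling) | |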
| sorted_mask = cumulative_probs - F .softmax (sorted_logits , dim =-1 ) >= top_p | |
| sorted_logits [sorted_mask ] = float ('-inf' ) | |
| logits .scatter_ (-1 , sorted_indices , sorted_logits ) | |
| probs = F .softmax (logits , dim =-1 ) | |
| next_token = torch .multinomial (probs .squeeze (1 ), num_samples =1 ) | |
| else : | |
| next_token = logits .argmax (dim =-1 ) | |
| total_tokens += 1 | |
| all_generated_ids .append (next_token ) | |
| # Decode the token | |
| token_text = tokenizer .decode (next_token [0 ], skip_special_tokens =False ) | |
| generated_text = generated_text + token_text | |
| # Check for EOS | |
| if eos_token in token_text or next_token .item () == tokenizer .eos_token_id : | |
| break | |
| # ── Tool call detection ── | |
| if tool_call_start_token in generated_text and not in_tool_call : | |
| in_tool_call = True | |
| # Extract everything after the tool_call_start | |
| tc_start_idx = generated_text .rfind (tool_call_start_token ) | |
| tool_call_buffer = generated_text [tc_start_idx :] | |
| elif in_tool_call : | |
| # Already inside a tool call: keep accumulating tokens (the opening | |
| # token was captured by the rfind slice above) | |
| tool_call_buffer += token_text | |
| # Check if we have a complete tool call | |
| if tool_call_end_token in tool_call_buffer : | |
| in_tool_call = False | |
| num_tool_calls += 1 | |
| # Parse the tool call | |
| tool_name = "" | |
| tool_args = {} | |
| try : | |
| # Extract function name | |
| name_start = tool_call_buffer .find (fn_name_start ) + len (fn_name_start ) | |
| name_end = tool_call_buffer .find (fn_name_end ) | |
| if name_start > 0 and name_end > 0 : | |
| tool_name = tool_call_buffer [name_start :name_end ].strip () | |
| # Extract arguments | |
| args_start = tool_call_buffer .find (fn_args_start ) + len (fn_args_start ) | |
| args_end = tool_call_buffer .find (fn_args_end ) | |
| if args_start > 0 and args_end > 0 : | |
| args_str = tool_call_buffer [args_start :args_end ].strip () | |
| try : | |
| tool_args = _json .loads (args_str ) | |
| except Exception : | |
| tool_args = {"raw": args_str } | |
| except Exception : | |
| pass | |
| # Execute the tool | |
| tool_result = "[error]: Failed to parse tool call" | |
| if tool_name : | |
| tool_result = tool_executor (tool_name , tool_args ) | |
| tool_calls_made .append ({ | |
| "name": tool_name , | |
| "arguments": tool_args , | |
| "result": tool_result , | |
| }) | |
| # Inject tool result back into generation context | |
| result_str = tool_result_start + tool_result + tool_result_end | |
| result_ids = tokenizer .encode (result_str , return_tensors ="pt" ).to (llm_device ) | |
| result_embeds = self .llm .model .embed_tokens (result_ids ) | |
| current_embeds = result_embeds | |
| past_key_values = None # Reset KV cache to include result | |
| all_generated_ids .append (result_ids .squeeze (0 )) | |
| generated_text = generated_text + result_str | |
| tool_call_buffer = "" | |
| if num_tool_calls >= max_tool_calls : | |
| break | |
| continue | |
| # Prepare next input | |
| next_embeds = self .llm .model .embed_tokens (next_token ) | |
| current_embeds = next_embeds | |
| # Combine all generated IDs | |
| if all_generated_ids : | |
| flat_ids = [] | |
| for t in all_generated_ids : | |
| if t .dim () == 0 : | |
| flat_ids .append (t .unsqueeze (0 )) | |
| elif t .dim () == 1 : | |
| flat_ids .append (t ) | |
| else : | |
| flat_ids .append (t .view (-1 )) | |
| generated_ids = torch .cat (flat_ids , dim =0 ).unsqueeze (0 ) | |
| else : | |
| generated_ids = torch .tensor ([[]], dtype =torch .long , device =llm_device ) | |
| # ── 4. Extract speaking text (strip tool call/result markup) ── | |
| speaking_text = generated_text | |
| # Remove tool call blocks | |
| while tool_call_start_token in speaking_text : | |
| tc_s = speaking_text .find (tool_call_start_token ) | |
| tc_e = speaking_text .find (tool_call_end_token ) | |
| if tc_e > tc_s : | |
| speaking_text = speaking_text [:tc_s ] + speaking_text [tc_e + len (tool_call_end_token ):] | |
| else : | |
| break | |
| # Remove tool result blocks | |
| while tool_result_start in speaking_text : | |
| tr_s = speaking_text .find (tool_result_start ) | |
| tr_e = speaking_text .find (tool_result_end ) | |
| if tr_e > tr_s : | |
| speaking_text = speaking_text [:tr_s ] + speaking_text [tr_e + len (tool_result_end ):] | |
| else : | |
| break | |
| speaking_text = speaking_text .strip () | |
| # ── 5. Speak: encode → mel → stream_decode → waveform ── | |
| # Per the docstring, synthesize from the cleaned speaking text (tool markup | |
| # stripped); fall back to the raw generated ids if no tokenizer is available | |
| if tokenizer is not None and speaking_text : | |
| speak_ids = tokenizer .encode (speaking_text , return_tensors ="pt" ).to (llm_device ) | |
| else : | |
| speak_ids = generated_ids .to (llm_device ) | |
| response_embeds = self .llm .model .embed_tokens (speak_ids ) | |
| mel , durations , _ , _ = self .audio_decoder ( | |
| response_embeds , | |
| speaker_embedding =speaker_embedding , | |
| ) | |
| mel_features = mel .transpose (1 , 2 ) | |
| if not hasattr (self , '_mel_to_hidden' ): | |
| self ._mel_to_hidden = nn .Linear (80 , self .config .hidden_size ).to (mel .device ) | |
| audio_features = self ._mel_to_hidden (mel_features ) | |
| waveform = self .waveform_decoder .stream_decode (audio_features ) | |
| return { | |
| 'waveform': waveform , | |
| 'text': generated_text , | |
| 'speaking_text': speaking_text , | |
| 'token_ids': generated_ids , | |
| 'mel': mel , | |
| 'tool_calls': tool_calls_made , | |
| } | |
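| # Example usage (illustrative sketch, not part of the generated file; the | |
| # `run_tool` callable and tool list below are hypothetical): | |
| # | |
| #     def run_tool(name, args): | |
| #         if name == "get_time": | |
| #             import datetime | |
| #             return datetime.datetime.now().isoformat() | |
| #         return f"[error]: unknown tool {name}" | |
| # | |
| #     out = model.listen_and_respond( | |
| #         audio_waveform=wav,          # [B, T_audio] | |
| #         tokenizer=tokenizer, | |
| #         tool_executor=run_tool, | |
| #         available_tools=[{"name": "get_time", "description": "Current time"}], | |
| #     ) | |
| #     print(out["speaking_text"], out["tool_calls"]) | |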
| def merge_lora_weights (self ): | |
| """Merge LoRA weights into main weights for inference.""" | |
| if not self .lora_applied : | |
| return | |
| for module in self .modules (): | |
| if isinstance (module ,LoRALinear ): | |
| module .merge_lora_weights () | |
| logger .info ("LoRA weights merged into base model") | |
| def unmerge_lora_weights (self ): | |
| """Unmerge LoRA weights for continued training.""" | |
| if not self .lora_applied : | |
| return | |
| for module in self .modules (): | |
| if isinstance (module ,LoRALinear ): | |
| module .unmerge_lora_weights () | |
| logger .info ("LoRA weights unmerged") | |
| def save_pretrained ( | |
| self , | |
| path :str , | |
| optimizer =None , | |
| scheduler =None , | |
| global_step :int =0 , | |
| epoch :int =0 , | |
| best_loss :float =float ('inf'), | |
| sharded :bool =False , | |
| max_shard_size :int =2 *1024 *1024 *1024 , | |
| save_separately :bool =True , | |
| ): | |
| """ | |
| Save model and optionally training state for resuming. | |
| Args: | |
| path: Directory to save the model | |
| optimizer: Optional optimizer to save state | |
| scheduler: Optional scheduler to save state | |
| global_step: Current training step | |
| epoch: Current epoch | |
| best_loss: Best loss achieved so far | |
| sharded: If True, save model in multiple .safetensors files | |
| max_shard_size: Maximum size per shard in bytes (default 2GB) | |
| save_separately: If True, save each component as a separate .safetensors file (default). | |
| This avoids safetensors issues with shared storage in LSTM weights | |
| """ | |
| os .makedirs (path ,exist_ok =True ) | |
| if save_separately : | |
| self ._save_components_safe (path ) | |
| elif sharded : | |
| self ._save_sharded (path ,max_shard_size ) | |
| else : | |
| self ._save_single_file_safe (path ) | |
| config_dict =self .config .to_dict () | |
| config_dict ['has_audio_encoder']=True | |
| config_dict ['has_audio_decoder']=True | |
| config_dict ['has_waveform_decoder']=hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None | |
| config_dict ['has_vision_encoder']=hasattr (self ,'vision_encoder')and self .vision_encoder is not None | |
| config_dict ['has_video_encoder']=hasattr (self ,'video_encoder')and self .video_encoder is not None | |
| config_dict ['has_generator']=hasattr (self ,'generator')and self .generator is not None | |
| config_dict ['has_video_generator']=hasattr (self ,'video_generator')and self .video_generator is not None | |
| config_dict ['has_cross_attention']=hasattr (self ,'cross_attention_layers')and self .cross_attention_layers is not None | |
| config_dict ['lora_applied']=self .lora_applied | |
| config_dict ['architecture_version']=2 | |
| config_dict ['auto_map']={ | |
| 'AutoConfig':'configuration_xoron.XoronConfig', | |
| 'AutoModel':'modeling_xoron.XoronModel', | |
| 'AutoModelForCausalLM':'modeling_xoron.XoronForCausalLM', | |
| } | |
| with open (os .path .join (path ,"config.json"),"w")as f : | |
| json .dump (config_dict ,f ,indent =2 ) | |
| self ._copy_huggingface_files (path ) | |
| if optimizer is not None or scheduler is not None : | |
| training_state ={ | |
| 'global_step':global_step , | |
| 'epoch':epoch , | |
| 'best_loss':best_loss , | |
| } | |
| if optimizer is not None : | |
| training_state ['optimizer_state_dict']=optimizer .state_dict () | |
| if scheduler is not None : | |
| training_state ['scheduler_state_dict']=scheduler .state_dict () | |
| torch .save (training_state ,os .path .join (path ,"training_state.pt")) | |
| logger .info (f"Training state saved (step {global_step }, epoch {epoch })") | |
| logger .info (f"Model saved to {path }") | |
| def _copy_huggingface_files (self ,path :str ): | |
| """ | |
| Build and copy HuggingFace custom code files for trust_remote_code support. | |
| This DYNAMICALLY BUILDS a self-contained modeling_xoron.py by combining | |
| all model components, so users can load from HuggingFace Hub with: | |
| model = AutoModel.from_pretrained("repo/model", trust_remote_code=True) | |
| WITHOUT needing to install the full Xoron-Dev package. | |
| Args: | |
| path: Directory to save the files | |
| """ | |
| import shutil | |
| current_dir =os .path .dirname (os .path .abspath (__file__ )) | |
| project_root =os .path .dirname (current_dir ) | |
| config_src =os .path .join (project_root ,'configuration_xoron.py') | |
| config_dst =os .path .join (path ,'configuration_xoron.py') | |
| if os .path .exists (config_src ): | |
| shutil .copy2 (config_src ,config_dst ) | |
| logger .info ("Copied configuration_xoron.py") | |
| modeling_dst =os .path .join (path ,'modeling_xoron.py') | |
| self ._build_self_contained_modeling_file (project_root ,modeling_dst ) | |
| logger .info ("HuggingFace custom code files ready") | |
| def _build_self_contained_modeling_file (self ,project_root :str ,output_path :str ): | |
| """ | |
| Build a self-contained modeling_xoron.py by combining all model components. | |
| This creates a single file with ALL model code embedded, removing internal | |
| imports so it works standalone on HuggingFace without the full package. | |
| """ | |
| import re | |
| component_files =[ | |
| "models/components/lora.py", | |
| "models/components/attention.py", | |
| "models/components/projectors.py", | |
| "models/components/moe.py", | |
| "models/encoders/vision.py", | |
| "models/encoders/video.py", | |
| "models/encoders/audio.py", | |
| "models/generators/image.py", | |
| "models/generators/video.py", | |
| "models/llm/moe_llama.py", | |
| "models/xoron.py", | |
| ] | |
| internal_import_patterns =[ | |
| r"^from config import.*$", | |
| r"^from config\..*import.*$", | |
| r"^from models\..*import.*$", | |
| r"^from models import.*$", | |
| ] | |
| def is_internal_import (line ): | |
| line =line .strip () | |
| for pattern in internal_import_patterns : | |
| if re .match (pattern ,line ): | |
| return True | |
| return False | |
| def is_module_level_import (line ): | |
| """Check if this is a module-level import (no indentation).""" | |
| stripped =line .strip () | |
| if line and not line [0 ].isspace (): | |
| return (stripped .startswith ("import ")or stripped .startswith ("from ")) | |
| return False | |
| def extract_code_body (content ): | |
| """Extract code body, removing module docstring and module-level imports only.""" | |
| lines =content .split ('\n') | |
| code_lines =[] | |
| i =0 | |
| in_multiline_import =False | |
| while i <len (lines )and not lines [i ].strip (): | |
| i +=1 | |
| if i <len (lines ): | |
| stripped =lines [i ].strip () | |
| if stripped .startswith ('"""')or stripped .startswith ("'''"): | |
| docstring_char =stripped [:3 ] | |
| if stripped .count (docstring_char )>=2 : | |
| i +=1 | |
| else : | |
| i +=1 | |
| while i <len (lines ): | |
| if docstring_char in lines [i ]: | |
| i +=1 | |
| break | |
| i +=1 | |
| for line in lines [i :]: | |
| stripped =line .strip () | |
| if not code_lines and not stripped : | |
| continue | |
| if in_multiline_import : | |
| if ')'in stripped : | |
| in_multiline_import =False | |
| continue | |
| if is_module_level_import (line ): | |
| if '('in stripped and ')'not in stripped : | |
| in_multiline_import =True | |
| continue | |
| if stripped .startswith ("logger = logging.getLogger")and not line [0 ].isspace (): | |
| continue | |
| code_lines .append (line ) | |
| while code_lines and not code_lines [-1 ].strip (): | |
| code_lines .pop () | |
| return '\n'.join (code_lines ) | |
| header ='''""" | |
| Xoron Model for HuggingFace Transformers - Self-Contained Implementation. | |
| AUTO-GENERATED FILE - Do not edit directly! | |
| This module provides a complete, self-contained HuggingFace-compatible model class | |
| for the Xoron multimodal model. All components are embedded directly in this file | |
| to enable loading via AutoModel with trust_remote_code=True WITHOUT requiring | |
| the full Xoron-Dev package to be installed. | |
| Usage: | |
| from transformers import AutoModel, AutoConfig | |
| config = AutoConfig.from_pretrained("your-repo/xoron-model", trust_remote_code=True) | |
| model = AutoModel.from_pretrained("your-repo/xoron-model", trust_remote_code=True) | |
| """ | |
| try: | |
| from safetensors.torch import save_file, load_file | |
| except ImportError: | |
| save_file, load_file = None, None | |
| try: | |
| from transformers.models.llama.modeling_llama import ( | |
| LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaMLP, | |
| LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv | |
| ) | |
| except ImportError: | |
| LlamaAttention = LlamaDecoderLayer = LlamaRMSNorm = LlamaMLP = None | |
| LlamaRotaryEmbedding = apply_rotary_pos_emb = repeat_kv = None | |
| try: | |
| from .configuration_xoron import XoronConfig | |
| except ImportError: | |
| from configuration_xoron import XoronConfig | |
| ''' | |
| all_code =[header ] | |
| for filepath in component_files : | |
| full_path =os .path .join (project_root ,filepath ) | |
| if not os .path .exists (full_path ): | |
| logger .warning (f"Component not found: {filepath }") | |
| continue | |
| with open (full_path ,'r',encoding ='utf-8')as f : | |
| content =f .read () | |
| code =extract_code_body (content ) | |
| if code .strip (): | |
| section_name =filepath .replace ('/','.').replace ('.py','').upper () | |
| section_header = f"""\n\n {'='*78}\n {section_name}\n {'='*78}\n\n""" | |
| all_code .append (section_header +code ) | |
| hf_wrapper =''' | |
| class XoronPreTrainedModel(PreTrainedModel): | |
| """Base class for Xoron models providing HuggingFace integration.""" | |
| config_class = XoronConfig | |
| base_model_prefix = "xoron" | |
| supports_gradient_checkpointing = True | |
| _no_split_modules = ["XoronMultimodalModel"] | |
| _skip_keys_device_placement = "past_key_values" | |
| _supports_flash_attn_2 = True | |
| def _init_weights(self, module): | |
| std = 0.02 | |
| if isinstance(module, nn.Linear): | |
| module.weight.data.normal_(mean=0.0, std=std) | |
| if module.bias is not None: | |
| module.bias.data.zero_() | |
| elif isinstance(module, nn.Embedding): | |
| module.weight.data.normal_(mean=0.0, std=std) | |
| if module.padding_idx is not None: | |
| module.weight.data[module.padding_idx].zero_() | |
| elif isinstance(module, nn.LayerNorm): | |
| module.bias.data.zero_() | |
| module.weight.data.fill_(1.0) | |
| class XoronModel(XoronPreTrainedModel): | |
| """Xoron Multimodal Model for HuggingFace.""" | |
| def __init__(self, config: XoronConfig): | |
| super().__init__(config) | |
| self.config = config | |
| self._internal_model = None | |
| self._model_initialized = False | |
| def _ensure_model_initialized(self): | |
| """Lazily initialize the internal model to avoid meta device conflicts.""" | |
| if not self._model_initialized: | |
| self._internal_model = XoronMultimodalModel(self.config) | |
| self._model_initialized = True | |
| @property | |
| def internal_model(self): | |
| self._ensure_model_initialized() | |
| return self._internal_model | |
| @classmethod | |
| def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |
| """ | |
| Load pretrained Xoron model from HuggingFace Hub or local path. | |
| This override ensures proper initialization without meta device conflicts. | |
| """ | |
| kwargs.pop('device_map', None) | |
| config = kwargs.pop('config', None) | |
| if config is None: | |
| config = XoronConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) | |
| model = cls(config) | |
| model._internal_model = XoronMultimodalModel(config) | |
| model._model_initialized = True | |
| import os | |
| from safetensors import safe_open | |
| if os.path.isdir(pretrained_model_name_or_path): | |
| model_path = pretrained_model_name_or_path | |
| else: | |
| from huggingface_hub import snapshot_download | |
| model_path = snapshot_download(repo_id=pretrained_model_name_or_path) | |
| components_json = os.path.join(model_path, "components.json") | |
| if os.path.exists(components_json): | |
| with open(components_json, 'r') as f: | |
| manifest = json.load(f) | |
| component_map = { | |
| 'llm': model._internal_model.llm, | |
| 'vision_encoder': model._internal_model.vision_encoder, | |
| 'video_encoder': model._internal_model.video_encoder, | |
| 'audio_encoder': model._internal_model.audio_encoder, | |
| 'audio_decoder': model._internal_model.audio_decoder, | |
| 'projector': model._internal_model.projector, | |
| 'audio_projector': model._internal_model.audio_projector, | |
| } | |
| if model._internal_model.cross_attention_layers is not None: | |
| component_map['cross_attention'] = model._internal_model.cross_attention_layers | |
| if model._internal_model.generator is not None: | |
| component_map['generator'] = model._internal_model.generator | |
| if model._internal_model.video_generator is not None: | |
| component_map['video_generator'] = model._internal_model.video_generator | |
| if hasattr(model._internal_model, 'waveform_decoder') and model._internal_model.waveform_decoder is not None: | |
| component_map['waveform_decoder'] = model._internal_model.waveform_decoder | |
| for comp_name in manifest.get('components', []): | |
| if comp_name == 'modality_markers': | |
| continue | |
| comp_path = os.path.join(model_path, f"{comp_name}.safetensors") | |
| if os.path.exists(comp_path) and comp_name in component_map: | |
| component = component_map[comp_name] | |
| if component is not None: | |
| with safe_open(comp_path, framework="pt") as f: | |
| state_dict = {k: f.get_tensor(k) for k in f.keys()} | |
| if comp_name == 'llm': | |
| embed_key = 'model.embed_tokens.weight' | |
| lm_head_key = 'lm_head.weight' | |
| if embed_key in state_dict: | |
| saved_vocab_size = state_dict[embed_key].shape[0] | |
| hidden_size = state_dict[embed_key].shape[1] | |
| current_vocab_size = component.model.embed_tokens.weight.shape[0] | |
| if saved_vocab_size != current_vocab_size: | |
| logger.info(f"Resizing embeddings: {current_vocab_size} -> {saved_vocab_size}") | |
| new_embed = nn.Embedding(saved_vocab_size, hidden_size) | |
| new_embed.weight.data = state_dict[embed_key] | |
| component.model.embed_tokens = new_embed | |
| if lm_head_key in state_dict: | |
| new_lm_head = nn.Linear(hidden_size, saved_vocab_size, bias=False) | |
| new_lm_head.weight.data = state_dict[lm_head_key] | |
| component.lm_head = new_lm_head | |
| del state_dict[embed_key] | |
| if lm_head_key in state_dict: | |
| del state_dict[lm_head_key] | |
| component.load_state_dict(state_dict, strict=False) | |
| logger.info(f"Loaded {comp_name}") | |
| markers_path = os.path.join(model_path, "modality_markers.safetensors") | |
| if os.path.exists(markers_path): | |
| with safe_open(markers_path, framework="pt") as f: | |
| model._internal_model.image_start.data = f.get_tensor('image_start') | |
| model._internal_model.image_end.data = f.get_tensor('image_end') | |
| model._internal_model.video_start.data = f.get_tensor('video_start') | |
| model._internal_model.video_end.data = f.get_tensor('video_end') | |
| model._internal_model.audio_start.data = f.get_tensor('audio_start') | |
| model._internal_model.audio_end.data = f.get_tensor('audio_end') | |
| logger.info("Loaded modality markers") | |
| logger.info(f"Xoron model loaded from {pretrained_model_name_or_path}") | |
| return model | |
| def forward( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| pixel_values: Optional[torch.Tensor] = None, | |
| video_frames: Optional[torch.Tensor] = None, | |
| audio_features: Optional[torch.Tensor] = None, | |
| labels: Optional[torch.Tensor] = None, | |
| output_attentions: Optional[bool] = None, | |
| output_hidden_states: Optional[bool] = None, | |
| return_dict: Optional[bool] = None, | |
| **kwargs, | |
| ) -> Union[Tuple, CausalLMOutputWithPast]: | |
| self._ensure_model_initialized() | |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
| outputs = self._internal_model( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| images=pixel_values, | |
| video=video_frames, | |
| audio=audio_features, | |
| labels=labels, | |
| ) | |
| if return_dict: | |
| return CausalLMOutputWithPast( | |
| loss=outputs.get("loss"), | |
| logits=outputs.get("logits"), | |
| past_key_values=outputs.get("past_key_values"), | |
| hidden_states=outputs.get("hidden_states"), | |
| attentions=outputs.get("attentions"), | |
| ) | |
| return (outputs.get("loss"), outputs.get("logits")) | |
| def generate_image(self, prompt_embeds: torch.Tensor, **kwargs): | |
| self._ensure_model_initialized() | |
| return self._internal_model.generate_image(prompt_embeds, **kwargs) | |
| def generate_video(self, prompt_embeds: torch.Tensor, **kwargs): | |
| self._ensure_model_initialized() | |
| return self._internal_model.generate_video(prompt_embeds, **kwargs) | |
| def generate_speech(self, text_embeds: torch.Tensor, **kwargs): | |
| self._ensure_model_initialized() | |
| return self._internal_model.generate_speech(text_embeds, **kwargs) | |
| class XoronForCausalLM(XoronModel): | |
| """Alias for XoronModel for compatibility.""" | |
| pass | |
| XoronConfig.register_for_auto_class() | |
| XoronModel.register_for_auto_class("AutoModel") | |
| XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM") | |
| ''' | |
| all_code .append (hf_wrapper ) | |
| final_content ='\n'.join (all_code ) | |
| with open (output_path ,'w',encoding ='utf-8')as f : | |
| f .write (final_content ) | |
| line_count =final_content .count ('\n') | |
| logger .info (f"Built self-contained modeling_xoron.py ({line_count :,} lines)") | |
| def _save_single_file_safe (self ,path :str ): | |
| """ | |
| Save model as single safetensors file with cloned tensors. | |
| Cloning breaks shared storage that causes safetensors errors. | |
| Args: | |
| path: Directory to save the model | |
| """ | |
| from safetensors .torch import save_file | |
| state_dict =self .state_dict () | |
| safe_state_dict ={} | |
| for key ,tensor in state_dict .items (): | |
| safe_state_dict [key ]=tensor .clone ().contiguous () | |
| save_file (safe_state_dict ,os .path .join (path ,"model.safetensors")) | |
| size_mb =sum (t .numel ()*t .element_size ()for t in safe_state_dict .values ())/(1024 *1024 ) | |
| logger .info (f"Saved model.safetensors ({size_mb :.1f} MB)") | |
| def _save_components_safe (self ,path :str ): | |
| """ | |
| Save model components as separate .safetensors files with cloned tensors. | |
| This is the default and most robust saving method that: | |
| 1. Handles LSTM weight sharing issues in safetensors | |
| 2. Allows surgical component loading/updates | |
| 3. Better for debugging and inspection | |
| Args: | |
| path: Directory to save component files | |
| """ | |
| from safetensors .torch import save_file | |
| os .makedirs (path ,exist_ok =True ) | |
| component_map ={ | |
| 'llm':self .llm , | |
| 'vision_encoder':self .vision_encoder , | |
| 'video_encoder':self .video_encoder , | |
| 'audio_encoder':self .audio_encoder , | |
| 'audio_decoder':self .audio_decoder , | |
| 'projector':self .projector , | |
| 'audio_projector':self .audio_projector , | |
| } | |
| if self .cross_attention_layers is not None : | |
| component_map ['cross_attention']=self .cross_attention_layers | |
| if self .generator is not None : | |
| component_map ['generator']=self .generator | |
| if self .video_generator is not None : | |
| component_map ['video_generator']=self .video_generator | |
| if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None : | |
| component_map ['waveform_decoder']=self .waveform_decoder | |
| saved_files =[] | |
| total_size =0 | |
| for comp_name ,component in component_map .items (): | |
| if component is None : | |
| continue | |
| comp_state =component .state_dict () | |
| if not comp_state : | |
| continue | |
| safe_comp_state ={} | |
| for key ,tensor in comp_state .items (): | |
| safe_comp_state [key ]=tensor .clone ().contiguous () | |
| comp_path =os .path .join (path ,f"{comp_name }.safetensors") | |
| save_file (safe_comp_state ,comp_path ) | |
| size_mb =sum (t .numel ()*t .element_size ()for t in safe_comp_state .values ())/(1024 *1024 ) | |
| total_size +=size_mb | |
| logger .info (f"Saved {comp_name }: {size_mb :.1f} MB") | |
| saved_files .append (comp_name ) | |
| markers ={ | |
| 'image_start':self .image_start .data .clone ().contiguous (), | |
| 'image_end':self .image_end .data .clone ().contiguous (), | |
| 'video_start':self .video_start .data .clone ().contiguous (), | |
| 'video_end':self .video_end .data .clone ().contiguous (), | |
| 'audio_start':self .audio_start .data .clone ().contiguous (), | |
| 'audio_end':self .audio_end .data .clone ().contiguous (), | |
| } | |
| save_file (markers ,os .path .join (path ,"modality_markers.safetensors")) | |
| logger .info ("Saved modality_markers") | |
| manifest ={ | |
| "components":saved_files +["modality_markers"], | |
| "save_format":"components", | |
| } | |
| with open (os .path .join (path ,"components.json"),"w")as f : | |
| json .dump (manifest ,f ,indent =2 ) | |
| weight_map ={} | |
| total_bytes =0 | |
| for comp_name ,component in component_map .items (): | |
| if component is None : | |
| continue | |
| comp_state =component .state_dict () | |
| if not comp_state : | |
| continue | |
| safetensor_file =f"{comp_name }.safetensors" | |
| for key in comp_state .keys (): | |
| full_key =f"{comp_name }.{key }" | |
| weight_map [full_key ]=safetensor_file | |
| total_bytes +=comp_state [key ].numel ()*comp_state [key ].element_size () | |
| marker_names =['image_start','image_end','video_start','video_end','audio_start','audio_end'] | |
| for marker_name in marker_names : | |
| weight_map [marker_name ]="modality_markers.safetensors" | |
| marker_tensor =getattr (self ,marker_name ) | |
| total_bytes +=marker_tensor .numel ()*marker_tensor .element_size () | |
| index ={ | |
| "metadata":{ | |
| "total_size":total_bytes , | |
| "format":"components", | |
| }, | |
| "weight_map":weight_map , | |
| } | |
| index_path =os .path .join (path ,"model.safetensors.index.json") | |
| with open (index_path ,"w")as f : | |
| json .dump (index ,f ,indent =2 ) | |
| logger .info ("Saved model.safetensors.index.json for HuggingFace compatibility") | |
| logger .info (f"Total size: {total_size :.1f} MB across {len (saved_files )} components") | |
| def _save_sharded (self ,path :str ,max_shard_size :int ): | |
| """ | |
| Save model weights in sharded .safetensors files. | |
| Components are surgically split across shards. | |
| Args: | |
| path: Directory to save shards | |
| max_shard_size: Maximum bytes per shard | |
| """ | |
| from safetensors .torch import save_file | |
| state_dict =self .state_dict () | |
| component_groups ={ | |
| 'llm':{}, | |
| 'vision_encoder':{}, | |
| 'video_encoder':{}, | |
| 'audio_encoder':{}, | |
| 'audio_decoder':{}, | |
| 'waveform_decoder':{}, | |
| 'generator':{}, | |
| 'video_generator':{}, | |
| 'projector':{}, | |
| 'audio_projector':{}, | |
| 'cross_attention_layers':{}, | |
| 'other':{}, | |
| } | |
| for key ,tensor in state_dict .items (): | |
| placed =False | |
| for comp_name in component_groups .keys (): | |
| if comp_name !='other'and key .startswith (comp_name ): | |
| component_groups [comp_name ][key ]=tensor | |
| placed =True | |
| break | |
| if not placed : | |
| component_groups ['other'][key ]=tensor | |
| shards =[] | |
| current_shard ={} | |
| current_size =0 | |
| shard_index_map ={} | |
| for comp_name ,comp_tensors in component_groups .items (): | |
| for key ,tensor in comp_tensors .items (): | |
| tensor_size =tensor .numel ()*tensor .element_size () | |
| if current_size +tensor_size >max_shard_size and current_shard : | |
| shards .append (current_shard ) | |
| current_shard ={} | |
| current_size =0 | |
| current_shard [key ]=tensor | |
| current_size +=tensor_size | |
| if current_shard : | |
| shards .append (current_shard ) | |
| total_shards =len (shards ) | |
| weight_map ={} | |
| for i ,shard in enumerate (shards ): | |
| shard_name =f"model-{i +1 :05d}-of-{total_shards :05d}.safetensors" | |
| shard_path =os .path .join (path ,shard_name ) | |
| shard_contiguous ={k :v .clone ().contiguous ()for k ,v in shard .items ()} | |
| save_file (shard_contiguous ,shard_path ) | |
| for key in shard .keys (): | |
| weight_map [key ]=shard_name | |
| shard_size_mb =sum (t .numel ()*t .element_size ()for t in shard .values ())/(1024 *1024 ) | |
| logger .info (f"Saved shard {i +1 }/{total_shards }: {shard_name } ({shard_size_mb :.1f} MB)") | |
| index ={ | |
| "metadata":{ | |
| "total_size":sum (t .numel ()*t .element_size ()for t in state_dict .values ()), | |
| "total_shards":total_shards , | |
| }, | |
| "weight_map":weight_map , | |
| } | |
| index_path =os .path .join (path ,"model.safetensors.index.json") | |
| with open (index_path ,"w")as f : | |
| json .dump (index ,f ,indent =2 ) | |
| logger .info ("Saved index: model.safetensors.index.json") | |
| def save_components_separately (self ,path :str ): | |
| """ | |
| Save model components as separate .safetensors files. | |
| Useful for surgical component updates and debugging. | |
| NOTE: This method now clones tensors to handle LSTM shared storage issues. | |
| Args: | |
| path: Directory to save component files | |
| """ | |
| from safetensors .torch import save_file | |
| os .makedirs (path ,exist_ok =True ) | |
| component_map ={ | |
| 'llm':self .llm , | |
| 'vision_encoder':self .vision_encoder , | |
| 'video_encoder':self .video_encoder , | |
| 'audio_encoder':self .audio_encoder , | |
| 'audio_decoder':self .audio_decoder , | |
| 'projector':self .projector , | |
| 'audio_projector':self .audio_projector , | |
| } | |
| if self .cross_attention_layers is not None : | |
| component_map ['cross_attention']=self .cross_attention_layers | |
| if self .generator is not None : | |
| component_map ['generator']=self .generator | |
| if self .video_generator is not None : | |
| component_map ['video_generator']=self .video_generator | |
| if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None : | |
| component_map ['waveform_decoder']=self .waveform_decoder | |
| saved_files =[] | |
| for comp_name ,component in component_map .items (): | |
| if component is None : | |
| continue | |
| comp_state =component .state_dict () | |
| if not comp_state : | |
| continue | |
| comp_state ={k :v .clone ().contiguous ()for k ,v in comp_state .items ()} | |
| comp_path =os .path .join (path ,f"{comp_name }.safetensors") | |
| save_file (comp_state ,comp_path ) | |
| size_mb =sum (t .numel ()*t .element_size ()for t in comp_state .values ())/(1024 *1024 ) | |
| logger .info (f"Saved {comp_name }: {size_mb :.1f} MB") | |
| saved_files .append (comp_name ) | |
| markers ={ | |
| 'image_start':self .image_start .data .clone ().contiguous (), | |
| 'image_end':self .image_end .data .clone ().contiguous (), | |
| 'video_start':self .video_start .data .clone ().contiguous (), | |
| 'video_end':self .video_end .data .clone ().contiguous (), | |
| 'audio_start':self .audio_start .data .clone ().contiguous (), | |
| 'audio_end':self .audio_end .data .clone ().contiguous (), | |
| } | |
| save_file (markers ,os .path .join (path ,"modality_markers.safetensors")) | |
| logger .info ("Saved modality_markers") | |
| manifest ={ | |
| "components":saved_files +["modality_markers"], | |
| "config":self .config .to_dict (), | |
| "lora_applied":self .lora_applied , | |
| } | |
| with open (os .path .join (path ,"components.json"),"w")as f : | |
| json .dump (manifest ,f ,indent =2 ) | |
| weight_map ={} | |
| total_bytes =0 | |
| for comp_name ,component in component_map .items (): | |
| if component is None : | |
| continue | |
| comp_state =component .state_dict () | |
| if not comp_state : | |
| continue | |
| safetensor_file =f"{comp_name }.safetensors" | |
| for key in comp_state .keys (): | |
| full_key =f"{comp_name }.{key }" | |
| weight_map [full_key ]=safetensor_file | |
| total_bytes +=comp_state [key ].numel ()*comp_state [key ].element_size () | |
| marker_names =['image_start','image_end','video_start','video_end','audio_start','audio_end'] | |
| for marker_name in marker_names : | |
| weight_map [marker_name ]="modality_markers.safetensors" | |
| marker_tensor =getattr (self ,marker_name ) | |
| total_bytes +=marker_tensor .numel ()*marker_tensor .element_size () | |
| index ={ | |
| "metadata":{ | |
| "total_size":total_bytes , | |
| "format":"components", | |
| }, | |
| "weight_map":weight_map , | |
| } | |
| index_path =os .path .join (path ,"model.safetensors.index.json") | |
| with open (index_path ,"w")as f : | |
| json .dump (index ,f ,indent =2 ) | |
| logger .info ("Saved model.safetensors.index.json for HuggingFace compatibility") | |
| logger .info (f"Components saved to {path }") | |
| @classmethod | |
| def from_pretrained ( | |
| cls , | |
| path :str , | |
| device :str =None , | |
| device_map :Dict [str ,str ]=None , | |
| apply_lora :bool =True , | |
| strict :bool =False , | |
| )->'XoronMultimodalModel': | |
| """ | |
| Load a pretrained Xoron model from a checkpoint or final model directory. | |
| Args: | |
| path: Path to the saved model directory | |
| device: Device to load the model to (if not using device_map) | |
| device_map: Device map for model parallelism | |
| apply_lora: Whether to apply LoRA after loading | |
| strict: If False, allows loading weights even if architecture changed | |
| Returns: | |
| Loaded XoronMultimodalModel instance | |
| """ | |
| from safetensors import safe_open | |
| from safetensors .torch import load_model | |
| logger .info (f"Loading model from {path }...") | |
| config_path =os .path .join (path ,"config.json") | |
| if not os .path .exists (config_path ): | |
| raise FileNotFoundError (f"Config file not found at {config_path }") | |
| with open (config_path ,'r')as f : | |
| config_dict =json .load (f ) | |
| lora_was_applied =config_dict .pop ('lora_applied',False ) | |
| architecture_version =config_dict .pop ('architecture_version',1 ) | |
| has_waveform_decoder =config_dict .pop ('has_waveform_decoder',False ) | |
| has_vision_encoder =config_dict .pop ('has_vision_encoder',True ) | |
| has_video_encoder =config_dict .pop ('has_video_encoder',True ) | |
| has_generator =config_dict .pop ('has_generator',True ) | |
| has_video_generator =config_dict .pop ('has_video_generator',True ) | |
| has_cross_attention =config_dict .pop ('has_cross_attention',True ) | |
| config_dict .pop ('has_audio_encoder',None ) | |
| config_dict .pop ('has_audio_decoder',None ) | |
| logger .info (f"Saved model architecture (version {architecture_version }):") | |
| logger .info (f" - Waveform Decoder: {'✅'if has_waveform_decoder else '❌ (will init randomly)'}") | |
| logger .info (f" - Vision Encoder: {'✅'if has_vision_encoder else '❌'}") | |
| logger .info (f" - Video Encoder: {'✅'if has_video_encoder else '❌'}") | |
| logger .info (f" - Image Generator: {'✅'if has_generator else '❌'}") | |
| logger .info (f" - Video Generator: {'✅'if has_video_generator else '❌'}") | |
| logger .info (f" - Cross Attention: {'✅'if has_cross_attention else '❌'}") | |
| logger .info (f" - LoRA Applied: {'✅'if lora_was_applied else '❌'}") | |
| config =XoronConfig .from_dict (config_dict ) | |
| model =cls (config ,device_map =device_map ) | |
| if lora_was_applied: | |
| logger .info ("Checkpoint has LoRA weights. Applying LoRA structure before loading...") | |
| model .apply_lora () | |
| components_json =os .path .join (path ,"components.json") | |
| model_path =os .path .join (path ,"model.safetensors") | |
| if os .path .exists (components_json ): | |
| logger .info ("Loading from component-based format...") | |
| model ._load_components (path ,strict =strict ) | |
| model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights) | |
| elif os .path .exists (model_path ): | |
| logger .info ("Loading weights from safetensors...") | |
| if strict : | |
| load_model (model ,model_path ) | |
| else : | |
| checkpoint_state_dict ={} | |
| with safe_open (model_path ,framework ="pt",device ="cpu")as f : | |
| for key in f .keys (): | |
| checkpoint_state_dict [key ]=f .get_tensor (key ) | |
| model .load_state_dict (checkpoint_state_dict ,strict =False ) | |
| logger .info ("Loaded weights from checkpoint") | |
| model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights) | |
| else : | |
| pytorch_path =os .path .join (path ,"pytorch_model.bin") | |
| if os .path .exists (pytorch_path ): | |
| logger .info ("Loading weights from pytorch_model.bin...") | |
| checkpoint_state_dict =torch .load (pytorch_path ,map_location ='cpu') | |
| model .load_state_dict (checkpoint_state_dict ,strict =False ) | |
| logger .info ("Loaded weights from checkpoint") | |
| model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights) | |
| else : | |
| raise FileNotFoundError (f"No model weights found at {path }") | |
| if apply_lora and config .use_lora and not model .lora_applied : | |
| model .apply_lora () | |
| if device_map is not None : | |
| model .apply_model_parallel (device_map ) | |
| elif device is not None : | |
| model =model .to (device ) | |
| logger .info ("Model loaded successfully!") | |
| model ._print_stats () | |
| return model | |
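| # Example usage (illustrative sketch; the path is hypothetical): | |
| # | |
| #     model = XoronMultimodalModel.from_pretrained( | |
| #         "checkpoints/step_1000", | |
| #         device="cuda", | |
| #         apply_lora=True, | |
| #     ) | |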
| def _load_components (self ,path :str ,strict :bool =False ): | |
| """ | |
| Load model from component-based safetensors files. | |
| Args: | |
| path: Directory containing component files | |
| strict: If True, require exact match; if False, allow partial loading | |
| """ | |
| from safetensors import safe_open | |
| component_map ={ | |
| 'llm':self .llm , | |
| 'vision_encoder':self .vision_encoder , | |
| 'video_encoder':self .video_encoder , | |
| 'audio_encoder':self .audio_encoder , | |
| 'audio_decoder':self .audio_decoder , | |
| 'projector':self .projector , | |
| 'audio_projector':self .audio_projector , | |
| } | |
| if self .cross_attention_layers is not None : | |
| component_map ['cross_attention']=self .cross_attention_layers | |
| if self .generator is not None : | |
| component_map ['generator']=self .generator | |
| if self .video_generator is not None : | |
| component_map ['video_generator']=self .video_generator | |
| if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None : | |
| component_map ['waveform_decoder']=self .waveform_decoder | |
| for comp_name ,component in component_map .items (): | |
| if component is None : | |
| continue | |
| comp_path =os .path .join (path ,f"{comp_name }.safetensors") | |
| if not os .path .exists (comp_path ): | |
| continue | |
| try : | |
| checkpoint_state ={} | |
| with safe_open (comp_path ,framework ="pt",device ="cpu")as f : | |
| for key in f .keys (): | |
| checkpoint_state [key ]=f .get_tensor (key ) | |
| component .load_state_dict (checkpoint_state ,strict =strict ) | |
| size_mb =sum (t .numel ()*t .element_size ()for t in checkpoint_state .values ())/(1024 *1024 ) | |
| logger .info (f"Loaded {comp_name } ({size_mb :.1f} MB)") | |
| except Exception as e : | |
| logger .warning (f"Error loading {comp_name }: {e }") | |
| markers_path =os .path .join (path ,"modality_markers.safetensors") | |
| if os .path .exists (markers_path ): | |
| try : | |
| with safe_open (markers_path ,framework ="pt",device ="cpu")as f : | |
| self .image_start .data =f .get_tensor ('image_start') | |
| self .image_end .data =f .get_tensor ('image_end') | |
| self .video_start .data =f .get_tensor ('video_start') | |
| self .video_end .data =f .get_tensor ('video_end') | |
| self .audio_start .data =f .get_tensor ('audio_start') | |
| self .audio_end .data =f .get_tensor ('audio_end') | |
| logger .info ("Loaded modality_markers") | |
| except Exception as e : | |
| logger .warning (f"Error loading modality_markers: {e }") | |
| logger .info ("Components loaded successfully") | |
| @staticmethod | |
| def load_training_state (path :str )->Optional [Dict ]: | |
| """ | |
| Load training state from a checkpoint. | |
| Args: | |
| path: Path to the checkpoint directory | |
| Returns: | |
| Dictionary with training state or None if not found | |
| """ | |
| state_path =os .path .join (path ,"training_state.pt") | |
| if os .path .exists (state_path ): | |
| logger .info (f"Loading training state from {state_path }...") | |
| return torch .load (state_path ,map_location ='cpu') | |
| return None | |
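| # Example usage (illustrative sketch; pairs with save_pretrained above): | |
| # | |
| #     state = XoronMultimodalModel.load_training_state("checkpoints/step_1000") | |
| #     if state and "optimizer_state_dict" in state: | |
| #         optimizer.load_state_dict(state["optimizer_state_dict"]) | |
| #     start_step = state["global_step"] if state else 0 | |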
| def freeze_components (self ,components :List [str ],hard_freeze :bool =True ): | |
| """ | |
| Freeze specific components of the model. | |
| IMPORTANT RULES: | |
| 1. LLM is NEVER frozen - it's trained from scratch and always needs full weight training | |
| 2. LoRA parameters are usually kept trainable, UNLESS hard_freeze=True | |
| Args: | |
| components: List of component group names to freeze. | |
| Valid groups: 'vision', 'video', 'audio', | |
| 'cross_attention', 'image_generation', 'video_generation', | |
| 'modality_markers' | |
| NOTE: 'llm' is NOT a valid group to freeze - will be ignored! | |
| hard_freeze: If True, completely freezes the component including its LoRA adapters. | |
| This prevents inactive components from updating via weight decay/momentum. | |
| """ | |
| if 'llm'in components : | |
| logger .warning ("Ignoring 'llm' in freeze list - LLM must always train (from scratch)") | |
| components =[c for c in components if c !='llm'] | |
| logger .info (f"Freezing components: {components } (hard_freeze={hard_freeze })") | |
| for group_name in components : | |
| if group_name not in COMPONENT_GROUPS : | |
| logger .warning (f" ⚠️ Unknown component group: {group_name }") | |
| continue | |
| for attr_name in COMPONENT_GROUPS [group_name ]: | |
| if hasattr (self ,attr_name ): | |
| component =getattr (self ,attr_name ) | |
| if component is not None : | |
| if isinstance (component ,nn .Parameter ): | |
| component .requires_grad =False | |
| elif isinstance (component ,nn .Module ): | |
| for name ,param in component .named_parameters (): | |
| is_lora ='lora_A'in name or 'lora_B'in name or 'magnitude'in name | |
| if hard_freeze or not is_lora : | |
| param .requires_grad =False | |
| logger .info (f"Frozen: {attr_name }") | |
| if self .lora_applied and not hard_freeze: | |
| enable_lora_training (self ) | |
| logger .info ("LoRA parameters remain trainable") | |
| self ._print_stats () | |
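| # Example usage (illustrative sketch): freeze the visual stacks while the | |
| # LLM (always trainable) and audio components keep training: | |
| # | |
| #     model.freeze_components(["vision", "video", "image_generation"], hard_freeze=True) | |
| #     print(model.get_component_status())   # (trainable, frozen) group names | |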
| def unfreeze_components (self ,components :List [str ]): | |
| """ | |
| Unfreeze specific components of the model. | |
| Args: | |
| components: List of component group names to unfreeze. | |
| """ | |
| logger .info (f"Unfreezing components: {components }") | |
| for group_name in components : | |
| if group_name not in COMPONENT_GROUPS : | |
| logger .warning (f" ⚠️ Unknown component group: {group_name }") | |
| continue | |
| for attr_name in COMPONENT_GROUPS [group_name ]: | |
| if hasattr (self ,attr_name ): | |
| component =getattr (self ,attr_name ) | |
| if component is not None : | |
| if isinstance (component ,nn .Parameter ): | |
| component .requires_grad =True | |
| elif isinstance (component ,nn .Module ): | |
| for param in component .parameters (): | |
| param .requires_grad =True | |
| logger .info (f"Unfrozen: {attr_name }") | |
| self ._print_stats () | |
| def freeze_all_except (self ,components :List [str ],hard_freeze :bool =True ): | |
| """ | |
| Freeze all components except the specified ones. | |
| NOTE: LLM is always kept trainable regardless of input - it's trained from scratch. | |
| Args: | |
| components: List of component group names to keep trainable. | |
| """ | |
| if 'llm'not in components : | |
| components =components +['llm'] | |
| all_groups =list (COMPONENT_GROUPS .keys ()) | |
| groups_to_freeze =[g for g in all_groups if g not in components ] | |
| self .freeze_components (groups_to_freeze ,hard_freeze =hard_freeze ) | |
| def get_trainable_component_names (self )->List [str ]: | |
| """Get list of component groups that have trainable parameters.""" | |
| trainable =[] | |
| for group_name ,attr_names in COMPONENT_GROUPS .items (): | |
| for attr_name in attr_names : | |
| if hasattr (self ,attr_name ): | |
| component =getattr (self ,attr_name ) | |
| if component is not None : | |
| if isinstance (component ,nn .Parameter ): | |
| if component .requires_grad : | |
| trainable .append (group_name ) | |
| break | |
| elif isinstance (component ,nn .Module ): | |
| if any (p .requires_grad for p in component .parameters ()): | |
| trainable .append (group_name ) | |
| break | |
| return trainable | |
| def get_frozen_component_names (self )->List [str ]: | |
| """Get list of component groups that are frozen (no trainable parameters).""" | |
| frozen =[] | |
| for group_name ,attr_names in COMPONENT_GROUPS .items (): | |
| has_component =False | |
| is_trainable =False | |
| for attr_name in attr_names : | |
| if hasattr (self ,attr_name ): | |
| component =getattr (self ,attr_name ) | |
| if component is not None : | |
| has_component =True | |
| if isinstance (component ,nn .Parameter ): | |
| if component .requires_grad : | |
| is_trainable =True | |
| break | |
| elif isinstance (component ,nn .Module ): | |
| if any (p .requires_grad for p in component .parameters ()): | |
| is_trainable =True | |
| break | |
| if has_component and not is_trainable : | |
| frozen .append (group_name ) | |
| return frozen | |
| def get_component_status (self )->tuple : | |
| """ | |
| Get tuple of (trainable_components, frozen_components) for display. | |
| Returns: | |
| tuple: (list of trainable component names, list of frozen component names) | |
| """ | |
| trainable =self .get_trainable_component_names () | |
| frozen =self .get_frozen_component_names () | |
| return trainable ,frozen | |
| class XoronPreTrainedModel(PreTrainedModel): | |
| """Base class for Xoron models providing HuggingFace integration.""" | |
| config_class = XoronConfig | |
| base_model_prefix = "xoron" | |
| supports_gradient_checkpointing = True | |
| _no_split_modules = ["XoronMultimodalModel"] | |
| _skip_keys_device_placement = "past_key_values" | |
| _supports_flash_attn_2 = True | |
| def _init_weights(self, module): | |
| std = 0.02 | |
| if isinstance(module, nn.Linear): | |
| module.weight.data.normal_(mean=0.0, std=std) | |
| if module.bias is not None: | |
| module.bias.data.zero_() | |
| elif isinstance(module, nn.Embedding): | |
| module.weight.data.normal_(mean=0.0, std=std) | |
| if module.padding_idx is not None: | |
| module.weight.data[module.padding_idx].zero_() | |
| elif isinstance(module, nn.LayerNorm): | |
| module.bias.data.zero_() | |
| module.weight.data.fill_(1.0) | |
| class XoronModel(XoronPreTrainedModel): | |
| """Xoron Multimodal Model for HuggingFace.""" | |
| def __init__(self, config: XoronConfig): | |
| super().__init__(config) | |
| self.config = config | |
| self._internal_model = None | |
| self._model_initialized = False | |
| def _ensure_model_initialized(self): | |
| """Lazily initialize the internal model to avoid meta device conflicts.""" | |
| if not self._model_initialized: | |
| self._internal_model = XoronMultimodalModel(self.config) | |
| self._model_initialized = True | |
| @property | |
| def internal_model(self): | |
| self._ensure_model_initialized() | |
| return self._internal_model | |
| @classmethod | |
| def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |
| """ | |
| Load pretrained Xoron model from HuggingFace Hub or local path. | |
| This override ensures proper initialization without meta device conflicts. | |
| """ | |
| kwargs.pop('device_map', None) | |
| config = kwargs.pop('config', None) | |
| if config is None: | |
| config = XoronConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) | |
| model = cls(config) | |
| model._internal_model = XoronMultimodalModel(config) | |
| model._model_initialized = True | |
| import os | |
| from safetensors import safe_open | |
| if os.path.isdir(pretrained_model_name_or_path): | |
| model_path = pretrained_model_name_or_path | |
| else: | |
| from huggingface_hub import snapshot_download | |
| model_path = snapshot_download(repo_id=pretrained_model_name_or_path) | |
| components_json = os.path.join(model_path, "components.json") | |
| if os.path.exists(components_json): | |
| with open(components_json, 'r') as f: | |
| manifest = json.load(f) | |
| component_map = { | |
| 'llm': model._internal_model.llm, | |
| 'vision_encoder': model._internal_model.vision_encoder, | |
| 'video_encoder': model._internal_model.video_encoder, | |
| 'audio_encoder': model._internal_model.audio_encoder, | |
| 'audio_decoder': model._internal_model.audio_decoder, | |
| 'projector': model._internal_model.projector, | |
| 'audio_projector': model._internal_model.audio_projector, | |
| } | |
| if model._internal_model.cross_attention_layers is not None: | |
| component_map['cross_attention'] = model._internal_model.cross_attention_layers | |
| if model._internal_model.generator is not None: | |
| component_map['generator'] = model._internal_model.generator | |
| if model._internal_model.video_generator is not None: | |
| component_map['video_generator'] = model._internal_model.video_generator | |
| if hasattr(model._internal_model, 'waveform_decoder') and model._internal_model.waveform_decoder is not None: | |
| component_map['waveform_decoder'] = model._internal_model.waveform_decoder | |
| for comp_name in manifest.get('components', []): | |
| if comp_name == 'modality_markers': | |
| continue  # handled separately after this loop | |
| comp_path = os.path.join(model_path, f"{comp_name}.safetensors") | |
| if os.path.exists(comp_path) and comp_name in component_map: | |
| component = component_map[comp_name] | |
| if component is not None: | |
| with safe_open(comp_path, framework="pt") as f: | |
| state_dict = {k: f.get_tensor(k) for k in f.keys()} | |
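| # Special-case the LLM: if the checkpoint was saved with a different | |
| # vocabulary size (e.g. after adding modality tokens), resize embed_tokens | |
| # and lm_head to the saved shapes before loading the remaining weights. | |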
| if comp_name == 'llm': | |
| embed_key = 'model.embed_tokens.weight' | |
| lm_head_key = 'lm_head.weight' | |
| if embed_key in state_dict: | |
| saved_vocab_size = state_dict[embed_key].shape[0] | |
| hidden_size = state_dict[embed_key].shape[1] | |
| current_vocab_size = component.model.embed_tokens.weight.shape[0] | |
| if saved_vocab_size != current_vocab_size: | |
| logger.info(f"Resizing embeddings: {current_vocab_size} -> {saved_vocab_size}") | |
| new_embed = nn.Embedding(saved_vocab_size, hidden_size) | |
| new_embed.weight.data = state_dict[embed_key] | |
| component.model.embed_tokens = new_embed | |
| if lm_head_key in state_dict: | |
| new_lm_head = nn.Linear(hidden_size, saved_vocab_size, bias=False) | |
| new_lm_head.weight.data = state_dict[lm_head_key] | |
| component.lm_head = new_lm_head | |
| del state_dict[embed_key] | |
| if lm_head_key in state_dict: | |
| del state_dict[lm_head_key] | |
| component.load_state_dict(state_dict, strict=False) | |
| logger.info(f"Loaded {comp_name}") | |
| markers_path = os.path.join(model_path, "modality_markers.safetensors") | |
| if os.path.exists(markers_path): | |
| with safe_open(markers_path, framework="pt") as f: | |
| model._internal_model.image_start.data = f.get_tensor('image_start') | |
| model._internal_model.image_end.data = f.get_tensor('image_end') | |
| model._internal_model.video_start.data = f.get_tensor('video_start') | |
| model._internal_model.video_end.data = f.get_tensor('video_end') | |
| model._internal_model.audio_start.data = f.get_tensor('audio_start') | |
| model._internal_model.audio_end.data = f.get_tensor('audio_end') | |
| logger.info("Loaded modality markers") | |
| logger.info(f"Xoron model loaded from {pretrained_model_name_or_path}") | |
| return model | |
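| # Hedged usage sketch. The on-disk layout is inferred from the loader above; | |
| # "./xoron-checkpoint" is a hypothetical local path: | |
| # | |
| #   ./xoron-checkpoint/ | |
| #     components.json               # manifest of component names | |
| #     llm.safetensors               # one file per listed component | |
| #     vision_encoder.safetensors | |
| #     modality_markers.safetensors  # per-modality start/end tensors | |
| # | |
| #   model = XoronModel.from_pretrained("./xoron-checkpoint") | |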
| def forward( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| pixel_values: Optional[torch.Tensor] = None, | |
| video_frames: Optional[torch.Tensor] = None, | |
| audio_features: Optional[torch.Tensor] = None, | |
| labels: Optional[torch.Tensor] = None, | |
| output_attentions: Optional[bool] = None, | |
| output_hidden_states: Optional[bool] = None, | |
| return_dict: Optional[bool] = None, | |
| **kwargs, | |
| ) -> Union[Tuple, CausalLMOutputWithPast]: | |
| self._ensure_model_initialized() | |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
| outputs = self._internal_model( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| images=pixel_values, | |
| video=video_frames, | |
| audio=audio_features, | |
| labels=labels, | |
| ) | |
| if return_dict: | |
| return CausalLMOutputWithPast( | |
| loss=outputs.get("loss"), | |
| logits=outputs.get("logits"), | |
| past_key_values=outputs.get("past_key_values"), | |
| hidden_states=outputs.get("hidden_states"), | |
| attentions=outputs.get("attentions"), | |
| ) | |
| return (outputs.get("loss"), outputs.get("logits")) | |
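| # Hedged call sketch for forward(); tensor shapes are assumptions based on | |
| # the argument names, not confirmed by this file: | |
| # | |
| #   out = model( | |
| #       input_ids=input_ids,            # (batch, seq_len) token ids | |
| #       attention_mask=attention_mask,  # 1 = attend, 0 = padding | |
| #       pixel_values=pixel_values,      # forwarded internally as `images` | |
| #       labels=input_ids,               # enables loss computation | |
| #   ) | |
| #   loss, logits = out.loss, out.logits | |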
| def generate_image(self, prompt_embeds: torch.Tensor, **kwargs): | |
| self._ensure_model_initialized() | |
| return self._internal_model.generate_image(prompt_embeds, **kwargs) | |
| def generate_video(self, prompt_embeds: torch.Tensor, **kwargs): | |
| self._ensure_model_initialized() | |
| return self._internal_model.generate_video(prompt_embeds, **kwargs) | |
| def generate_speech(self, text_embeds: torch.Tensor, **kwargs): | |
| self._ensure_model_initialized() | |
| return self._internal_model.generate_speech(text_embeds, **kwargs) | |
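| # The generate_* helpers take already-embedded prompts rather than raw text. | |
| # A hedged sketch (using the LLM's input embeddings here is an assumption): | |
| # | |
| #   embeds = model.internal_model().llm.get_input_embeddings()(input_ids) | |
| #   image = model.generate_image(embeds) | |
| #   speech = model.generate_speech(embeds) | |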
| class XoronForCausalLM(XoronModel): | |
| """Alias for XoronModel for compatibility.""" | |
| XoronConfig.register_for_auto_class() | |
| XoronModel.register_for_auto_class("AutoModel") | |
| XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM") | |
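| # With the auto-class registrations above, the checkpoint also loads through | |
| # the causal-LM entry point (placeholder repo id): | |
| # | |
| #   from transformers import AutoModelForCausalLM | |
| #   model = AutoModelForCausalLM.from_pretrained( | |
| #       "your-repo/xoron-model", trust_remote_code=True | |
| #   ) | |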