"""
Xoron Model for HuggingFace Transformers - Self-Contained Implementation.
AUTO-GENERATED FILE - Do not edit directly!
This module provides a complete, self-contained HuggingFace-compatible model class
for the Xoron multimodal model. All components are embedded directly in this file
to enable loading via AutoModel with trust_remote_code=True WITHOUT requiring
the full Xoron-Dev package to be installed.
Usage:
from transformers import AutoModel, AutoConfig
config = AutoConfig.from_pretrained("your-repo/xoron-model", trust_remote_code=True)
model = AutoModel.from_pretrained("your-repo/xoron-model", trust_remote_code=True)
"""
import os
import math
import json
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Union, Tuple, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from safetensors.torch import save_file, load_file
except ImportError:
save_file, load_file = None, None
from transformers import PreTrainedModel, LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
try:
from transformers.models.llama.modeling_llama import (
LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaMLP,
LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
)
except ImportError:
LlamaAttention = LlamaDecoderLayer = LlamaRMSNorm = LlamaMLP = None
LlamaRotaryEmbedding = apply_rotary_pos_emb = repeat_kv = None
try:
from .configuration_xoron import XoronConfig
except ImportError:
from configuration_xoron import XoronConfig
logger = logging.getLogger(__name__)
# ==============================================================================
# MODELS.COMPONENTS.LORA
# ==============================================================================
class LoRALinear (nn .Module ):
"""
SOTA LoRA layer with multiple variants.
Supports:
- Standard LoRA
- DoRA (Weight-Decomposed LoRA)
- rsLoRA (rank-stabilized scaling)
MEMORY OPTIMIZATION:
- Does NOT clone base weights - shares them with original module
- Only LoRA params (A, B, magnitude) consume additional memory
- Base weights are frozen and can be kept in lower precision
"""
def __init__ (
self ,
in_features :int ,
out_features :int ,
r :int =8 ,
lora_alpha :int =16 ,
lora_dropout :float =0.05 ,
merge_weights :bool =False ,
use_dora :bool =False ,
use_rslora :bool =True ,
base_layer :nn .Linear =None ,
):
super ().__init__ ()
self .r =r
self .lora_alpha =lora_alpha
self .merge_weights =merge_weights
self .merged =False
self .use_dora =use_dora
self .use_rslora =use_rslora
self .in_features =in_features
self .out_features =out_features
if base_layer is not None :
self .linear =base_layer
else :
self .linear =nn .Linear (in_features ,out_features ,bias =False )
if r >0 :
self .lora_A =nn .Parameter (torch .zeros (r ,in_features ))
self .lora_B =nn .Parameter (torch .zeros (out_features ,r ))
if use_rslora :
self .scaling =lora_alpha /math .sqrt (r )
else :
self .scaling =lora_alpha /r
self .lora_dropout =nn .Dropout (p =lora_dropout )if lora_dropout >0 else nn .Identity ()
nn .init .kaiming_uniform_ (self .lora_A ,a =math .sqrt (5 ))
nn .init .zeros_ (self .lora_B )
if use_dora :
self .magnitude =nn .Parameter (torch .ones (out_features ))
self .linear .weight .requires_grad =False
if hasattr (self .linear ,'bias')and self .linear .bias is not None :
self .linear .bias .requires_grad =False
def forward(self, x: torch.Tensor) -> torch.Tensor:
    if self.r > 0 and not self.merged:
        if self.use_dora:
            # DoRA: re-compose the adapted weight, normalize its direction per
            # output row, then rescale by the learned magnitude vector.
            weight = self.linear.weight + (self.lora_B @ self.lora_A) * self.scaling
            weight_norm = weight.norm(dim=1, keepdim=True)
            weight_normalized = weight / (weight_norm + 1e-6)
            result = F.linear(x, weight_normalized * self.magnitude.unsqueeze(1))
        else:
            # Standard/rsLoRA: frozen base output plus the scaled low-rank update.
            lora_out = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling
            result = self.linear(x) + lora_out
    else:
        result = self.linear(x)
    return result
def merge_lora_weights (self ):
"""Merge LoRA weights into the main weights for inference."""
if self .r >0 and not self .merged :
delta =(self .lora_B @self .lora_A )*self .scaling
if self .use_dora :
weight =self .linear .weight +delta
weight_norm =weight .norm (dim =1 ,keepdim =True )
self .linear .weight .data =(weight /(weight_norm +1e-6 ))*self .magnitude .unsqueeze (1 )
else :
self .linear .weight .data +=delta
self .merged =True
def unmerge_lora_weights (self ):
"""Unmerge LoRA weights for continued training."""
if self .r >0 and self .merged :
self .linear .weight .data -=(self .lora_B @self .lora_A )*self .scaling
self .merged =False
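# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the generated model): wrapping a plain
# nn.Linear so the frozen base weight is shared and only the low-rank A/B
# factors are trained. All sizes below are placeholder assumptions.
def _example_lora_linear():
    base = nn.Linear(512, 512, bias=False)                 # pretrained projection, stays frozen
    lora = LoRALinear(512, 512, r=8, lora_alpha=16, use_rslora=True, base_layer=base)
    x = torch.randn(2, 16, 512)
    y = lora(x)                                            # base output + scaled (x @ A^T @ B^T)
    lora.merge_lora_weights()                              # fold the delta into base.weight
    y_merged = lora(x)                                     # now a single dense matmul
    return y, y_merged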
class LoRAConfig :
"""
Configuration for SOTA LoRA adaptation.
Supports multiple LoRA variants and configurations.
"""
def __init__ (
self ,
r :int =8 ,
lora_alpha :int =16 ,
lora_dropout :float =0.05 ,
target_modules :Optional [List [str ]]=None ,
enable_lora :bool =True ,
use_dora :bool =False ,
use_rslora :bool =True ,
lora_plus_lr_ratio :float =16.0 ,
):
self .r =r
self .lora_alpha =lora_alpha
self .lora_dropout =lora_dropout
self .target_modules =target_modules or [
'q_proj','k_proj','v_proj','o_proj',
'gate_proj','up_proj','down_proj',
]
self .enable_lora =enable_lora
self .use_dora =use_dora
self .use_rslora =use_rslora
self .lora_plus_lr_ratio =lora_plus_lr_ratio
def apply_lora_to_model (model :nn .Module ,lora_config :LoRAConfig )->nn .Module :
"""
Apply LoRA to specified modules in a model.
Returns the model with LoRA layers applied.
MEMORY OPTIMIZATION:
- Passes the original nn.Linear layer directly to LoRALinear
- This SHARES weights instead of cloning them (saves ~50% memory for target modules)
- Only LoRA parameters (A, B, magnitude) are newly allocated
For a 16GB model with 30% of weights in target modules:
- Old behavior: Clone ~5GB = 21GB total
- New behavior: Share weights = 16GB + ~50MB LoRA params
"""
if not lora_config .enable_lora :
return model
lora_layers_added =0
modules_to_replace =[]
total_base_params =0
for name ,module in model .named_modules ():
if not isinstance (module ,nn .Linear ):
continue
module_name =name .split ('.')[-1 ]
if module_name in lora_config .target_modules :
modules_to_replace .append ((name ,module ))
total_base_params +=module .weight .numel ()
for name ,module in modules_to_replace :
parts =name .split ('.')
attr_name =parts [-1 ]
parent_name ='.'.join (parts [:-1 ])
if parent_name :
parent =model .get_submodule (parent_name )
else :
parent =model
lora_layer =LoRALinear (
in_features =module .in_features ,
out_features =module .out_features ,
r =lora_config .r ,
lora_alpha =lora_config .lora_alpha ,
lora_dropout =lora_config .lora_dropout ,
use_dora =lora_config .use_dora ,
use_rslora =lora_config .use_rslora ,
base_layer =module ,
)
setattr (parent ,attr_name ,lora_layer )
lora_layers_added +=1
lora_params =lora_layers_added *(lora_config .r *(modules_to_replace [0 ][1 ].in_features +modules_to_replace [0 ][1 ].out_features ))if modules_to_replace else 0
base_mem_saved_mb =(total_base_params *2 )/(1024 *1024 )
lora_mem_added_mb =(lora_params *4 )/(1024 *1024 )
variant ="DoRA"if lora_config .use_dora else ("rsLoRA"if lora_config .use_rslora else "LoRA")
print (f"✅ {variant } applied to {lora_layers_added } layers (r={lora_config .r }, alpha={lora_config .lora_alpha })")
print (f" 💾 Memory optimization: {base_mem_saved_mb :.1f}MB base weights SHARED (not cloned)")
print (f" 📊 New LoRA params: ~{lora_mem_added_mb :.1f}MB (trainable)")
return model
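# ---------------------------------------------------------------------------
# Illustrative sketch: applying LoRA to a toy module whose submodule names match
# the default target list. The real caller passes the Llama backbone here; the
# toy module and sizes are placeholder assumptions for a quick shape check.
def _example_apply_lora():
    class _ToyAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(256, 256, bias=False)
            self.v_proj = nn.Linear(256, 256, bias=False)
        def forward(self, x):
            return self.v_proj(self.q_proj(x))
    toy = _ToyAttention()
    cfg = LoRAConfig(r=4, lora_alpha=8, target_modules=['q_proj', 'v_proj'])
    toy = apply_lora_to_model(toy, cfg)                    # q_proj/v_proj are now LoRALinear
    lora_count, total_count, pct = count_lora_parameters(toy)
    return toy(torch.randn(1, 10, 256)).shape, pct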
def get_lora_parameters (model :nn .Module )->List [nn .Parameter ]:
"""
Get only the LoRA parameters from a model.
NOTE: This does NOT change requires_grad on any parameters!
It simply returns the LoRA params (lora_A, lora_B, magnitude).
Use this when you want to get LoRA params for separate optimizer groups
or for LoRA-only training mode.
"""
lora_params =[]
for name ,param in model .named_parameters ():
if 'lora_A'in name or 'lora_B'in name or 'magnitude'in name :
lora_params .append (param )
return lora_params
def enable_lora_training (model :nn .Module )->List [nn .Parameter ]:
"""
Enable training for LoRA parameters (ensure requires_grad=True).
Returns list of LoRA parameters.
"""
lora_params =[]
for name ,param in model .named_parameters ():
if 'lora_A'in name or 'lora_B'in name or 'magnitude'in name :
param .requires_grad =True
lora_params .append (param )
return lora_params
def freeze_non_lora_params (model :nn .Module )->int :
"""
Freeze all non-LoRA parameters and clear their gradients.
USE THIS ONLY FOR LORA-ONLY TRAINING MODE (train_lora_only=True).
For normal training with parallel fine-tuning (LoRA + full weights on
active components), use the model's freeze_components() method instead,
which respects the training mode flags (--text, --video, --image, --voice).
Returns:
Number of frozen parameters
"""
frozen_params =0
freed_memory =0
for name ,param in model .named_parameters ():
is_lora ='lora_A'in name or 'lora_B'in name or 'magnitude'in name
if not is_lora :
param .requires_grad =False
frozen_params +=param .numel ()
if param .grad is not None :
freed_memory +=param .grad .numel ()*param .grad .element_size ()
param .grad =None
print (f" ❄️ Frozen {frozen_params :,} non-LoRA parameters")
if freed_memory >0 :
print (f" 🧹 Freed {freed_memory /(1024 **2 ):.1f}MB of gradient memory")
return frozen_params
def get_lora_plus_param_groups (
model :nn .Module ,
base_lr :float ,
lr_ratio :float =16.0
)->List [Dict ]:
"""
Get parameter groups for LoRA+ training.
LoRA+ uses different learning rates for A and B matrices:
- B matrix: base_lr * lr_ratio (learns faster)
- A matrix: base_lr
This improves convergence and final performance.
"""
lora_a_params =[]
lora_b_params =[]
magnitude_params =[]
other_params =[]
for name ,param in model .named_parameters ():
if not param .requires_grad :
continue
if 'lora_A'in name :
lora_a_params .append (param )
elif 'lora_B'in name :
lora_b_params .append (param )
elif 'magnitude'in name :
magnitude_params .append (param )
else :
other_params .append (param )
param_groups =[]
if lora_a_params :
param_groups .append ({'params':lora_a_params ,'lr':base_lr ,'name':'lora_A'})
if lora_b_params :
param_groups .append ({'params':lora_b_params ,'lr':base_lr *lr_ratio ,'name':'lora_B'})
if magnitude_params :
param_groups .append ({'params':magnitude_params ,'lr':base_lr ,'name':'magnitude'})
if other_params :
param_groups .append ({'params':other_params ,'lr':base_lr ,'name':'other'})
return param_groups
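# ---------------------------------------------------------------------------
# Illustrative sketch: building a LoRA+ optimizer where the B matrices train at
# a 16x higher learning rate than the A matrices, as described above. The model
# and base learning rate are placeholders; each group carries its own 'lr'.
def _example_lora_plus_optimizer(model: nn.Module, base_lr: float = 1e-4):
    enable_lora_training(model)                            # ensure lora_A/lora_B/magnitude require grad
    groups = get_lora_plus_param_groups(model, base_lr=base_lr, lr_ratio=16.0)
    return torch.optim.AdamW(groups, weight_decay=0.0)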
def get_trainable_parameters (model :nn .Module ,train_lora_only :bool =False )->List [nn .Parameter ]:
"""Get trainable parameters, optionally only LoRA params."""
if train_lora_only :
return get_lora_parameters (model )
else :
return [p for p in model .parameters ()if p .requires_grad ]
def count_lora_parameters (model :nn .Module )->Tuple [int ,int ,float ]:
"""
Count LoRA parameters vs total parameters.
Returns:
(lora_params, total_params, percentage)
"""
lora_params =0
total_params =0
for name ,param in model .named_parameters ():
total_params +=param .numel ()
if 'lora_A'in name or 'lora_B'in name or 'magnitude'in name :
lora_params +=param .numel ()
percentage =100.0 *lora_params /total_params if total_params >0 else 0.0
return lora_params ,total_params ,percentage
# ==============================================================================
# MODELS.COMPONENTS.ATTENTION
# ==============================================================================
logger =logging .getLogger (__name__ )
def flash_attention_available ()->bool :
"""Check if Flash Attention (via SDPA) is available."""
try :
from torch .nn .functional import scaled_dot_product_attention
return True
except ImportError :
return False
def compute_qk_scale (head_dim :int )->float :
"""Compute the Q/K pre-scaling factor for FP16 stability.
By scaling both Q and K by head_dim^-0.25, the product Q@K^T
is effectively scaled by head_dim^-0.5 (the standard attention scaling).
This prevents overflow in FP16 when Q and K have large values.
"""
return head_dim **-0.25
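# ---------------------------------------------------------------------------
# Tiny numeric check of the pre-scaling trick above: scaling Q and K by
# head_dim ** -0.25 each reproduces the standard 1/sqrt(head_dim) attention
# scaling while keeping the intermediate Q/K values small enough for FP16.
def _example_qk_scale_check(head_dim: int = 64) -> bool:
    q = torch.randn(1, 1, 4, head_dim)
    k = torch.randn(1, 1, 4, head_dim)
    s = compute_qk_scale(head_dim)
    prescaled = (q * s) @ (k * s).transpose(-2, -1)
    standard = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
    return torch.allclose(prescaled, standard, atol=1e-5)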
class AttentionKVCache :
"""Pre-allocated KV Cache — static buffer with index-based filling.
Eliminates VRAM fragmentation from torch.cat during autoregressive generation.
Buffer is allocated once at first use and reused via slice assignment.
"""
__slots__ =('key_cache','value_cache','seen_tokens','_max_len')
def __init__ (self ,max_seq_len :int =131072 ):
self .key_cache :torch .Tensor =None
self .value_cache :torch .Tensor =None
self .seen_tokens :int =0
self ._max_len =max_seq_len
def _allocate (self ,batch :int ,heads :int ,head_dim :int ,device :torch .device ,dtype :torch .dtype ):
"""Allocate static buffer on first use."""
self .key_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype )
self .value_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype )
def update (
self ,
key_states :torch .Tensor ,
value_states :torch .Tensor ,
)->Tuple [torch .Tensor ,torch .Tensor ]:
"""
Update cache with new key/value states using index-based filling.
Args:
key_states: New key states [batch, num_heads, seq_len, head_dim]
value_states: New value states [batch, num_heads, seq_len, head_dim]
Returns:
Updated key and value states including cache (views, no copy)
"""
batch ,heads ,new_len ,head_dim =key_states .shape
if self .key_cache is None :
self ._allocate (batch ,heads ,head_dim ,key_states .device ,key_states .dtype )
self .seen_tokens =0
if self .seen_tokens +new_len >self .key_cache .shape [2 ]:
new_max =max (self .key_cache .shape [2 ]*2 ,self .seen_tokens +new_len )
new_key =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype )
new_val =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype )
new_key [:,:,:self .seen_tokens ]=self .key_cache [:,:,:self .seen_tokens ]
new_val [:,:,:self .seen_tokens ]=self .value_cache [:,:,:self .seen_tokens ]
self .key_cache =new_key
self .value_cache =new_val
self .key_cache [:,:,self .seen_tokens :self .seen_tokens +new_len ]=key_states
self .value_cache [:,:,self .seen_tokens :self .seen_tokens +new_len ]=value_states
self .seen_tokens +=new_len
return self .key_cache [:,:,:self .seen_tokens ],self .value_cache [:,:,:self .seen_tokens ]
def get_seq_length (self )->int :
"""Get current sequence length in cache."""
return self .seen_tokens
def reset (self ):
"""Reset cache position without deallocating the buffer."""
self .seen_tokens =0
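# ---------------------------------------------------------------------------
# Illustrative decode-loop sketch for the pre-allocated cache above: a 6-token
# prefill followed by two single-token steps, all written into the same static
# buffer (no torch.cat). Head counts and dims are placeholder assumptions.
def _example_kv_cache():
    cache = AttentionKVCache(max_seq_len=64)
    keys, values = cache.update(torch.randn(1, 8, 6, 32), torch.randn(1, 8, 6, 32))
    for _ in range(2):                                     # autoregressive steps: one token each
        keys, values = cache.update(torch.randn(1, 8, 1, 32), torch.randn(1, 8, 1, 32))
    return cache.get_seq_length(), keys.shape              # 8 tokens seen, keys [1, 8, 8, 32]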
class FlashAttention (nn .Module ):
"""
SOTA Flash Attention with KV cache support and FP16-safe Q/K pre-scaling.
Uses PyTorch's scaled_dot_product_attention when available,
with fallback to standard attention. Supports:
- KV caching for efficient generation
- Causal masking
- Attention dropout
- Pre-scaled Q/K for FP16 stability
"""
def __init__ (
self ,
dropout :float =0.0 ,
causal :bool =False ,
head_dim :int =None ,
):
super ().__init__ ()
self .dropout =dropout
self .causal =causal
self ._flash_available =flash_attention_available ()
self ._head_dim =head_dim
self ._qk_scale =compute_qk_scale (head_dim )if head_dim else None
def forward (
self ,
query :torch .Tensor ,
key :torch .Tensor ,
value :torch .Tensor ,
attn_mask :torch .Tensor =None ,
is_causal :bool =None ,
past_key_value :Tuple [torch .Tensor ,torch .Tensor ]=None ,
use_cache :bool =False ,
output_attentions :bool =False ,
)->Tuple [torch .Tensor ,Tuple [torch .Tensor ,torch .Tensor ],torch .Tensor ]:
"""
Forward pass with KV cache support.
Args:
query: Query tensor [batch, num_heads, seq_len, head_dim]
key: Key tensor [batch, num_heads, seq_len, head_dim]
value: Value tensor [batch, num_heads, seq_len, head_dim]
attn_mask: Optional attention mask
is_causal: Override causal setting
past_key_value: Optional tuple of (past_key, past_value) for KV cache
use_cache: Whether to return updated KV cache
output_attentions: Whether to return attention weights
Returns:
Tuple of (output, present_key_value, attention_weights)
"""
causal =is_causal if is_causal is not None else self .causal
batch_size ,num_heads ,seq_len ,head_dim =query .shape
qk_scale =self ._qk_scale if self ._qk_scale else compute_qk_scale (head_dim )
if past_key_value is not None :
past_key ,past_value =past_key_value
key =torch .cat ([past_key ,key ],dim =2 )
value =torch .cat ([past_value ,value ],dim =2 )
present_key_value =(key ,value )if use_cache else None
kv_seq_len =key .shape [2 ]
attn_weights =None
if self ._flash_available and not output_attentions :
query_scaled =query *qk_scale
key_scaled =key *qk_scale
dropout_p =self .dropout if self .training else 0.0
use_causal =causal and attn_mask is None and seq_len >1 and seq_len ==kv_seq_len
output =F .scaled_dot_product_attention (
query_scaled ,key_scaled ,value ,
attn_mask =attn_mask ,
dropout_p =dropout_p ,
is_causal =use_causal ,
scale =1.0 ,
)
else :
scale =1.0 /math .sqrt (head_dim )
attn_weights =torch .matmul (query ,key .transpose (-2 ,-1 ))*scale
if causal and attn_mask is None and seq_len >1 :
causal_mask =torch .triu (
torch .full ((seq_len ,kv_seq_len ),float ('-inf'),device =query .device ,dtype =query .dtype ),
diagonal =kv_seq_len -seq_len +1
)
attn_weights =attn_weights +causal_mask .unsqueeze (0 ).unsqueeze (0 )
if attn_mask is not None :
attn_weights =attn_weights +attn_mask
attn_weights =F .softmax (attn_weights ,dim =-1 ,dtype =query .dtype )
if self .training and self .dropout >0 :
attn_weights =F .dropout (attn_weights ,p =self .dropout )
output =torch .matmul (attn_weights ,value )
return output ,present_key_value ,attn_weights
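# ---------------------------------------------------------------------------
# Illustrative sketch of the wrapper above: causal self-attention over a short
# sequence, then one decoding step that reuses the returned KV pair. Shapes are
# placeholder assumptions ([batch, heads, seq, head_dim]).
def _example_flash_attention():
    attn = FlashAttention(dropout=0.0, causal=True, head_dim=32)
    q, k, v = (torch.randn(1, 8, 6, 32) for _ in range(3))
    out, present_kv, _ = attn(q, k, v, use_cache=True)
    q1, k1, v1 = (torch.randn(1, 8, 1, 32) for _ in range(3))
    out_step, present_kv, _ = attn(q1, k1, v1, past_key_value=present_kv, use_cache=True)
    return out.shape, out_step.shape                       # [1, 8, 6, 32], [1, 8, 1, 32]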
class MultimodalCrossAttention (nn .Module ):
"""
SOTA Cross-attention layer for multimodal fusion with KV cache support.
Allows text to attend to image/video/audio features with:
- KV caching for efficient generation
- Gated residual connection for stable training
- Flash Attention support with pre-scaled Q/K for FP16 stability
- Optional attention weight output
"""
def __init__ (
self ,
hidden_size :int ,
num_heads :int =8 ,
dropout :float =0.1 ,
use_flash_attention :bool =True ,
gate_init :float =0.0 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_heads =num_heads
self .head_dim =hidden_size //num_heads
self .use_flash_attention =use_flash_attention and flash_attention_available ()
self .dropout_p =dropout
self .qk_scale =compute_qk_scale (self .head_dim )
self .q_proj =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .k_proj =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .v_proj =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .o_proj =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .dropout =nn .Dropout (dropout )
self .layer_norm =nn .LayerNorm (hidden_size )
self .gate =nn .Parameter (torch .tensor (gate_init ))
def forward (
self ,
text_hidden :torch .Tensor ,
modality_hidden :torch .Tensor ,
modality_mask :torch .Tensor =None ,
past_key_value :Tuple [torch .Tensor ,torch .Tensor ]=None ,
use_cache :bool =False ,
output_attentions :bool =False ,
)->Tuple [torch .Tensor ,Tuple [torch .Tensor ,torch .Tensor ],torch .Tensor ]:
"""
Cross-attention: text attends to modality features with KV cache support.
Args:
text_hidden: Text hidden states [batch, text_len, hidden_size]
modality_hidden: Modality features [batch, modality_len, hidden_size]
modality_mask: Optional attention mask for modality
past_key_value: Optional cached (key, value) for this layer
use_cache: Whether to return updated KV cache
output_attentions: Whether to return attention weights
Returns:
Tuple of (output, present_key_value, attention_weights)
"""
batch_size ,text_len ,_ =text_hidden .shape
query =self .q_proj (text_hidden )
query =query .view (batch_size ,text_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 )
if past_key_value is not None :
key ,value =past_key_value
else :
modality_len =modality_hidden .shape [1 ]
key =self .k_proj (modality_hidden )
value =self .v_proj (modality_hidden )
key =key .view (batch_size ,modality_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 )
value =value .view (batch_size ,modality_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 )
present_key_value =(key ,value )if use_cache else None
attn_weights =None
if self .use_flash_attention and not output_attentions :
query_scaled =query *self .qk_scale
key_scaled =key *self .qk_scale
dropout_p =self .dropout_p if self .training else 0.0
attn_output =F .scaled_dot_product_attention (
query_scaled ,key_scaled ,value ,
attn_mask =modality_mask ,
dropout_p =dropout_p ,
is_causal =False ,
scale =1.0 ,
)
else :
scale =1.0 /math .sqrt (self .head_dim )
attn_weights =torch .matmul (query ,key .transpose (-2 ,-1 ))*scale
if modality_mask is not None :
attn_weights =attn_weights +modality_mask
attn_weights =F .softmax (attn_weights ,dim =-1 ,dtype =text_hidden .dtype )
if self .training and self .dropout_p >0 :
attn_weights =F .dropout (attn_weights ,p =self .dropout_p )
attn_output =torch .matmul (attn_weights ,value )
attn_output =attn_output .transpose (1 ,2 ).contiguous ().view (batch_size ,text_len ,self .hidden_size )
attn_output =self .o_proj (attn_output )
gate =torch .sigmoid (self .gate )
output =text_hidden +gate *self .dropout (attn_output )
output =self .layer_norm (output )
return output ,present_key_value ,attn_weights
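# ---------------------------------------------------------------------------
# Illustrative sketch: 12 text tokens attending to 64 image tokens through the
# gated cross-attention block above, then one generation step that reuses the
# cached image K/V. Hidden size and lengths are placeholder assumptions.
def _example_cross_attention():
    xattn = MultimodalCrossAttention(hidden_size=256, num_heads=8, dropout=0.0)
    text = torch.randn(2, 12, 256)
    image = torch.randn(2, 64, 256)
    fused, image_kv, _ = xattn(text, image, use_cache=True)
    next_token = torch.randn(2, 1, 256)
    fused_step, _, _ = xattn(next_token, image, past_key_value=image_kv)
    return fused.shape, fused_step.shape                   # [2, 12, 256], [2, 1, 256]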
@dataclass
class MultimodalFusionCache :
"""Cache for multimodal fusion layer KV states."""
image_kv :Tuple [torch .Tensor ,torch .Tensor ]=None
video_kv :Tuple [torch .Tensor ,torch .Tensor ]=None
audio_kv :Tuple [torch .Tensor ,torch .Tensor ]=None
class MultimodalFusionLayer (nn .Module ):
"""
SOTA Multimodal fusion layer with cross-attention for all modalities and KV cache support.
Features:
- Separate cross-attention for each modality (image, video, audio)
- KV caching for efficient generation
- Gated fusion MLP
- Flash Attention support
"""
def __init__ (
self ,
hidden_size :int ,
num_heads :int =8 ,
dropout :float =0.1 ,
use_flash_attention :bool =True ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .image_cross_attn =MultimodalCrossAttention (
hidden_size ,num_heads ,dropout ,use_flash_attention
)
self .video_cross_attn =MultimodalCrossAttention (
hidden_size ,num_heads ,dropout ,use_flash_attention
)
self .audio_cross_attn =MultimodalCrossAttention (
hidden_size ,num_heads ,dropout ,use_flash_attention
)
self .fusion_mlp =nn .Sequential (
nn .Linear (hidden_size ,hidden_size *4 ),
nn .GELU (),
nn .Dropout (dropout ),
nn .Linear (hidden_size *4 ,hidden_size ),
nn .Dropout (dropout ),
)
self .fusion_norm =nn .LayerNorm (hidden_size )
def forward (
self ,
text_hidden :torch .Tensor ,
image_hidden :torch .Tensor =None ,
video_hidden :torch .Tensor =None ,
audio_hidden :torch .Tensor =None ,
image_mask :torch .Tensor =None ,
video_mask :torch .Tensor =None ,
audio_mask :torch .Tensor =None ,
past_key_values :MultimodalFusionCache =None ,
use_cache :bool =False ,
)->Tuple [torch .Tensor ,MultimodalFusionCache ]:
"""
Fuse text with available modalities via cross-attention with KV cache support.
Args:
text_hidden: Text hidden states [batch, text_len, hidden_size]
image_hidden: Image features [batch, image_len, hidden_size]
video_hidden: Video features [batch, video_len, hidden_size]
audio_hidden: Audio features [batch, audio_len, hidden_size]
image_mask: Attention mask for image
video_mask: Attention mask for video
audio_mask: Attention mask for audio
past_key_values: Cached KV states from previous forward pass
use_cache: Whether to return updated KV cache
Returns:
Tuple of (output, present_key_values)
"""
present_key_values =MultimodalFusionCache ()if use_cache else None
past_image_kv =past_key_values .image_kv if past_key_values else None
past_video_kv =past_key_values .video_kv if past_key_values else None
past_audio_kv =past_key_values .audio_kv if past_key_values else None
if self ._has_content (image_hidden )or past_image_kv is not None :
try :
text_hidden ,image_kv ,_ =self .image_cross_attn (
text_hidden ,
image_hidden if image_hidden is not None else torch .zeros (text_hidden .shape [0 ],1 ,self .hidden_size ,device =text_hidden .device ),
image_mask ,
past_key_value =past_image_kv ,
use_cache =use_cache ,
)
if use_cache :
present_key_values .image_kv =image_kv
except Exception as e :
logger .debug (f"Image cross-attention skipped: {e }")
if self ._has_content (video_hidden )or past_video_kv is not None :
try :
text_hidden ,video_kv ,_ =self .video_cross_attn (
text_hidden ,
video_hidden if video_hidden is not None else torch .zeros (text_hidden .shape [0 ],1 ,self .hidden_size ,device =text_hidden .device ),
video_mask ,
past_key_value =past_video_kv ,
use_cache =use_cache ,
)
if use_cache :
present_key_values .video_kv =video_kv
except Exception as e :
logger .debug (f"Video cross-attention skipped: {e }")
if self ._has_content (audio_hidden )or past_audio_kv is not None :
try :
text_hidden ,audio_kv ,_ =self .audio_cross_attn (
text_hidden ,
audio_hidden if audio_hidden is not None else torch .zeros (text_hidden .shape [0 ],1 ,self .hidden_size ,device =text_hidden .device ),
audio_mask ,
past_key_value =past_audio_kv ,
use_cache =use_cache ,
)
if use_cache :
present_key_values .audio_kv =audio_kv
except Exception as e :
logger .debug (f"Audio cross-attention skipped: {e }")
residual =text_hidden
text_hidden =self .fusion_mlp (text_hidden )
text_hidden =self .fusion_norm (residual +text_hidden )
return text_hidden ,present_key_values
@staticmethod
def _has_content (tensor :torch .Tensor )->bool :
"""Check if tensor has meaningful content."""
if tensor is None :
return False
if not isinstance (tensor ,torch .Tensor ):
return False
try :
if tensor .numel ()==0 :
return False
return bool (tensor .any ())
except Exception :
return False
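# ---------------------------------------------------------------------------
# Illustrative sketch of the fusion layer above: text fused with image and audio
# features (video omitted), returning the per-modality KV cache that a decoding
# loop would pass back in. Sizes are placeholder assumptions.
def _example_fusion_layer():
    fusion = MultimodalFusionLayer(hidden_size=256, num_heads=8, dropout=0.0)
    text = torch.randn(2, 12, 256)
    image = torch.randn(2, 64, 256)
    audio = torch.randn(2, 32, 256)
    fused, kv_cache = fusion(text, image_hidden=image, audio_hidden=audio, use_cache=True)
    return fused.shape, kv_cache.image_kv[0].shape         # [2, 12, 256], [2, 8, 64, 32]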
# ==============================================================================
# MODELS.COMPONENTS.PROJECTORS
# ==============================================================================
def compute_2d_rope (height :int ,width :int ,dim :int ,device :torch .device ,dtype :torch .dtype ,base :float =10000.0 )->Tuple [torch .Tensor ,torch .Tensor ]:
"""
Compute 2D Rotary Position Embeddings for spatial awareness.
Args:
height: Image height in patches
width: Image width in patches
dim: Embedding dimension (must be divisible by 4)
device: Target device
dtype: Target dtype
base: RoPE base frequency
Returns:
cos, sin: [height*width, dim] position embeddings
"""
assert dim %4 ==0 ,"dim must be divisible by 4 for 2D RoPE"
half_dim =dim //2
quarter_dim =dim //4
inv_freq =1.0 /(base **(torch .arange (0 ,quarter_dim ,device =device ,dtype =torch .float32 )/quarter_dim ))
y_pos =torch .arange (height ,device =device ,dtype =torch .float32 )
x_pos =torch .arange (width ,device =device ,dtype =torch .float32 )
y_emb =torch .outer (y_pos ,inv_freq )
x_emb =torch .outer (x_pos ,inv_freq )
y_emb =y_emb .unsqueeze (1 ).expand (-1 ,width ,-1 )
x_emb =x_emb .unsqueeze (0 ).expand (height ,-1 ,-1 )
emb =torch .cat ([y_emb ,y_emb ,x_emb ,x_emb ],dim =-1 )
emb =emb .reshape (height *width ,dim )
return emb .cos ().to (dtype ),emb .sin ().to (dtype )
def compute_3d_rope (
depth :int ,height :int ,width :int ,dim :int ,
device :torch .device ,dtype :torch .dtype ,base :float =10000.0
)->Tuple [torch .Tensor ,torch .Tensor ]:
"""
Compute 3D Rotary Position Embeddings for video/temporal awareness.
Args:
depth: Temporal depth (number of frames)
height: Image height in patches
width: Image width in patches
dim: Embedding dimension (must be divisible by 6)
device: Target device
dtype: Target dtype
base: RoPE base frequency
Returns:
cos, sin: [depth*height*width, dim] position embeddings
"""
assert dim %6 ==0 ,"dim must be divisible by 6 for 3D RoPE"
sixth_dim =dim //6
inv_freq =1.0 /(base **(torch .arange (0 ,sixth_dim ,device =device ,dtype =torch .float32 )/sixth_dim ))
t_pos =torch .arange (depth ,device =device ,dtype =torch .float32 )
y_pos =torch .arange (height ,device =device ,dtype =torch .float32 )
x_pos =torch .arange (width ,device =device ,dtype =torch .float32 )
t_emb =torch .outer (t_pos ,inv_freq )
y_emb =torch .outer (y_pos ,inv_freq )
x_emb =torch .outer (x_pos ,inv_freq )
t_emb =t_emb .unsqueeze (1 ).unsqueeze (2 ).expand (-1 ,height ,width ,-1 )
y_emb =y_emb .unsqueeze (0 ).unsqueeze (2 ).expand (depth ,-1 ,width ,-1 )
x_emb =x_emb .unsqueeze (0 ).unsqueeze (1 ).expand (depth ,height ,-1 ,-1 )
emb =torch .cat ([t_emb ,t_emb ,y_emb ,y_emb ,x_emb ,x_emb ],dim =-1 )
emb =emb .reshape (depth *height *width ,dim )
return emb .cos ().to (dtype ),emb .sin ().to (dtype )
def apply_rope (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor :
"""Apply rotary position embeddings."""
x1 =x [...,:x .shape [-1 ]//2 ]
x2 =x [...,x .shape [-1 ]//2 :]
rotated =torch .cat ((-x2 ,x1 ),dim =-1 )
return x *cos +rotated *sin
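# ---------------------------------------------------------------------------
# Illustrative sketch: computing 2D rotary embeddings for an 8x8 patch grid and
# rotating a matching feature map. The feature dim (divisible by 4) and batch
# size are placeholder assumptions.
def _example_2d_rope():
    h, w, dim = 8, 8, 64
    feats = torch.randn(2, h * w, dim)
    cos, sin = compute_2d_rope(h, w, dim, feats.device, feats.dtype)
    rotated = apply_rope(feats, cos.unsqueeze(0), sin.unsqueeze(0))
    return rotated.shape                                   # [2, 64, 64]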
class ResidualBottleneckBlock (nn .Module ):
"""
Residual Bottleneck Block for locality-enhanced feature extraction.
Preserves small-scale features (OCR, fine audio events) during compression.
"""
def __init__ (self ,in_channels :int ,out_channels :int ,bottleneck_ratio :float =0.25 ):
super ().__init__ ()
bottleneck_channels =int (out_channels *bottleneck_ratio )
self .conv1 =nn .Conv2d (in_channels ,bottleneck_channels ,1 ,bias =False )
self .bn1 =nn .BatchNorm2d (bottleneck_channels )
self .conv2 =nn .Conv2d (bottleneck_channels ,bottleneck_channels ,3 ,padding =1 ,bias =False )
self .bn2 =nn .BatchNorm2d (bottleneck_channels )
self .conv3 =nn .Conv2d (bottleneck_channels ,out_channels ,1 ,bias =False )
self .bn3 =nn .BatchNorm2d (out_channels )
self .shortcut =nn .Identity ()if in_channels ==out_channels else nn .Sequential (
nn .Conv2d (in_channels ,out_channels ,1 ,bias =False ),
nn .BatchNorm2d (out_channels ),
)
self .relu =nn .ReLU (inplace =True )
def forward (self ,x :torch .Tensor )->torch .Tensor :
identity =self .shortcut (x )
out =self .relu (self .bn1 (self .conv1 (x )))
out =self .relu (self .bn2 (self .conv2 (out )))
out =self .bn3 (self .conv3 (out ))
out =out +identity
out =self .relu (out )
return out
class LocalityEnhancedResNetAbstractor (nn .Module ):
"""
Locality-Enhanced ResNet Abstractor.
Upgrades the C-Abstractor with residual bottleneck blocks to preserve
small-scale features (OCR/fine audio events) during compression.
"""
def __init__ (
self ,
input_dim :int ,
output_dim :int ,
num_tokens :int =64 ,
num_blocks :int =3 ,
use_2d_rope :bool =True ,
):
super ().__init__ ()
self .num_tokens =num_tokens
self .use_2d_rope =use_2d_rope
self .input_proj =nn .Linear (input_dim ,output_dim )
self .blocks =nn .ModuleList ([
ResidualBottleneckBlock (output_dim ,output_dim )
for _ in range (num_blocks )
])
self .queries =nn .Parameter (torch .randn (1 ,num_tokens ,output_dim )*0.02 )
self .cross_attn =nn .MultiheadAttention (
embed_dim =output_dim ,
num_heads =8 ,
batch_first =True ,
dropout =0.1 ,
)
self .ff =nn .Sequential (
nn .LayerNorm (output_dim ),
nn .Linear (output_dim ,output_dim *4 ),
nn .GELU (),
nn .Linear (output_dim *4 ,output_dim ),
)
self .norm =nn .LayerNorm (output_dim )
print (f" 🏗️ LocalityEnhancedResNetAbstractor: {input_dim } -> {output_dim }, {num_tokens } tokens")
def forward (self ,features :torch .Tensor ,spatial_size :Optional [Tuple [int ,int ]]=None )->torch .Tensor :
"""
Args:
features: [B, seq_len, input_dim] or [B, H, W, input_dim]
spatial_size: (H, W) if features are flattened
Returns:
abstracted: [B, num_tokens, output_dim]
"""
batch_size =features .shape [0 ]
x =self .input_proj (features )
if features .dim ()==3 :
seq_len =features .shape [1 ]
if spatial_size is None :
h =w =int (math .sqrt (seq_len ))
else :
h ,w =spatial_size
x =x .view (batch_size ,h ,w ,-1 )
else :
h ,w =features .shape [1 ],features .shape [2 ]
x =x .permute (0 ,3 ,1 ,2 )
for block in self .blocks :
x =block (x )
x =x .permute (0 ,2 ,3 ,1 )
x =x .reshape (batch_size ,h *w ,-1 )
if self .use_2d_rope :
cos ,sin =compute_2d_rope (h ,w ,x .shape [-1 ],x .device ,x .dtype )
x =apply_rope (x ,cos .unsqueeze (0 ),sin .unsqueeze (0 ))
queries =self .queries .expand (batch_size ,-1 ,-1 )
abstracted ,_ =self .cross_attn (queries ,x ,x )
abstracted =abstracted +self .ff (abstracted )
return self .norm (abstracted )
class MultiScaleFeatureFusion (nn .Module ):
"""
Multi-Scale Feature Fusion (MSFF).
Extracts and weights features from multiple encoder depths (early, mid, late)
to capture both low-level textures and high-level semantics.
"""
def __init__ (
self ,
feature_dims :List [int ],
output_dim :int ,
num_scales :int =3 ,
):
super ().__init__ ()
self .num_scales =num_scales
self .scale_projs =nn .ModuleList ([
nn .Linear (dim ,output_dim )for dim in feature_dims
])
self .scale_weights =nn .Parameter (torch .ones (num_scales )/num_scales )
self .fusion =nn .Sequential (
nn .Linear (output_dim ,output_dim *2 ),
nn .GELU (),
nn .Linear (output_dim *2 ,output_dim ),
)
self .norm =nn .LayerNorm (output_dim )
print (f" 🔀 MultiScaleFeatureFusion: {feature_dims } -> {output_dim }")
def forward (self ,multi_scale_features :List [torch .Tensor ])->torch .Tensor :
"""
Args:
multi_scale_features: List of [B, seq_len, dim] features from different depths
Returns:
fused: [B, seq_len, output_dim]
"""
assert len (multi_scale_features )==self .num_scales
projected =[]
for i ,(features ,proj )in enumerate (zip (multi_scale_features ,self .scale_projs )):
projected .append (proj (features ))
weights =F .softmax (self .scale_weights ,dim =0 )
fused =sum (w *p for w ,p in zip (weights ,projected ))
fused =fused +self .fusion (fused )
return self .norm (fused )
class MultiScaleDeformableAttention (nn .Module ):
"""
Multi-Scale Deformable Attention.
Replaces fixed-grid cross-attention in Perceiver Resamplers,
allowing the projector to "look" at non-uniform regions of interest.
"""
def __init__ (
self ,
dim :int ,
num_heads :int =8 ,
num_levels :int =4 ,
num_points :int =4 ,
dropout :float =0.1 ,
):
super ().__init__ ()
self .dim =dim
self .num_heads =num_heads
self .num_levels =num_levels
self .num_points =num_points
self .head_dim =dim //num_heads
self .sampling_offsets =nn .Linear (dim ,num_heads *num_levels *num_points *2 )
self .attention_weights =nn .Linear (dim ,num_heads *num_levels *num_points )
self .value_proj =nn .Linear (dim ,dim )
self .output_proj =nn .Linear (dim ,dim )
self .dropout =nn .Dropout (dropout )
self ._reset_parameters ()
print (f" 🎯 MultiScaleDeformableAttention: {dim }d, {num_heads }H, {num_levels }L, {num_points }P")
def _reset_parameters (self ):
nn .init .constant_ (self .sampling_offsets .weight ,0.0 )
nn .init .constant_ (self .sampling_offsets .bias ,0.0 )
nn .init .xavier_uniform_ (self .attention_weights .weight )
nn .init .constant_ (self .attention_weights .bias ,0.0 )
nn .init .xavier_uniform_ (self .value_proj .weight )
nn .init .xavier_uniform_ (self .output_proj .weight )
def forward (
self ,
query :torch .Tensor ,
reference_points :torch .Tensor ,
input_flatten :torch .Tensor ,
input_spatial_shapes :torch .Tensor ,
)->torch .Tensor :
"""
Args:
query: [B, num_queries, dim]
reference_points: [B, num_queries, num_levels, 2] normalized reference points
input_flatten: [B, sum(H*W), dim] flattened multi-scale features
input_spatial_shapes: [num_levels, 2] spatial shapes of each level
Returns:
output: [B, num_queries, dim]
"""
batch_size ,num_queries ,_ =query .shape
offsets =self .sampling_offsets (query )
offsets =offsets .view (batch_size ,num_queries ,self .num_heads ,self .num_levels ,self .num_points ,2 )
attn_weights =self .attention_weights (query )
attn_weights =attn_weights .view (batch_size ,num_queries ,self .num_heads ,self .num_levels *self .num_points )
attn_weights =F .softmax (attn_weights ,dim =-1 )
attn_weights =attn_weights .view (batch_size ,num_queries ,self .num_heads ,self .num_levels ,self .num_points )
sampling_locations =reference_points .unsqueeze (2 ).unsqueeze (4 )+offsets *0.1
sampling_locations =sampling_locations .clamp (0 ,1 )
value =self .value_proj (input_flatten )
value =value .view (batch_size ,-1 ,self .num_heads ,self .head_dim )
output =torch .zeros (batch_size ,num_queries ,self .num_heads ,self .head_dim ,device =query .device ,dtype =query .dtype )
start_idx =0
for level_idx in range (self .num_levels ):
h ,w =input_spatial_shapes [level_idx ]
end_idx =start_idx +h *w
level_value =value [:,start_idx :end_idx ]
level_value =level_value .view (batch_size ,h ,w ,self .num_heads ,self .head_dim )
level_locs =sampling_locations [:,:,:,level_idx ]
level_weights =attn_weights [:,:,:,level_idx ]
for point_idx in range (self .num_points ):
loc =level_locs [:,:,:,point_idx ]
weight =level_weights [:,:,:,point_idx :point_idx +1 ]
y_idx =(loc [...,0 ]*(h -1 )).long ().clamp (0 ,h -1 )
x_idx =(loc [...,1 ]*(w -1 )).long ().clamp (0 ,w -1 )
for b in range (batch_size ):
for q in range (num_queries ):
for head in range (self .num_heads ):
y ,x =y_idx [b ,q ,head ].item (),x_idx [b ,q ,head ].item ()
output [b ,q ,head ]+=weight [b ,q ,head ]*level_value [b ,y ,x ,head ]
start_idx =end_idx
output =output .view (batch_size ,num_queries ,self .dim )
output =self .output_proj (output )
output =self .dropout (output )
return output
class DynamicTokenRouter (nn .Module ):
"""
Dynamic Token Router.
Implements a sparse gating mechanism to drop redundant "background" tokens,
drastically reducing KV-cache pressure for Ring Attention.
"""
def __init__ (
self ,
dim :int ,
num_tokens :int ,
keep_ratio :float =0.5 ,
temperature :float =1.0 ,
):
super ().__init__ ()
self .dim =dim
self .num_tokens =num_tokens
self .keep_ratio =keep_ratio
self .temperature =temperature
self .scorer =nn .Sequential (
nn .Linear (dim ,dim //2 ),
nn .GELU (),
nn .Linear (dim //2 ,1 ),
)
self .threshold =nn .Parameter (torch .tensor (0.0 ))
print (f" 🚦 DynamicTokenRouter: keep_ratio={keep_ratio }")
def forward (self ,tokens :torch .Tensor ,return_mask :bool =False )->Tuple [torch .Tensor ,Optional [torch .Tensor ]]:
"""
Args:
tokens: [B, num_tokens, dim]
return_mask: Whether to return the selection mask
Returns:
selected_tokens: [B, num_kept, dim]
mask: [B, num_tokens] selection mask (if return_mask=True)
"""
batch_size ,num_tokens ,_ =tokens .shape
num_keep =max (1 ,int (num_tokens *self .keep_ratio ))
scores =self .scorer (tokens ).squeeze (-1 )
scores =scores /self .temperature
_ ,indices =torch .topk (scores ,num_keep ,dim =-1 )
indices =indices .sort (dim =-1 ).values
indices_expanded =indices .unsqueeze (-1 ).expand (-1 ,-1 ,self .dim )
selected_tokens =torch .gather (tokens ,1 ,indices_expanded )
if return_mask :
mask =torch .zeros (batch_size ,num_tokens ,device =tokens .device ,dtype =torch .bool )
mask .scatter_ (1 ,indices ,True )
return selected_tokens ,mask
return selected_tokens ,None
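# ---------------------------------------------------------------------------
# Illustrative sketch: keeping the highest-scoring half of 64 projector tokens
# with the router above, plus the boolean mask of surviving positions. Sizes
# are placeholder assumptions.
def _example_token_router():
    router = DynamicTokenRouter(dim=256, num_tokens=64, keep_ratio=0.5)
    tokens = torch.randn(2, 64, 256)
    kept, mask = router(tokens, return_mask=True)
    return kept.shape, mask.sum(dim=-1)                    # [2, 32, 256], 32 kept per sample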
class PerceiverAttention (nn .Module ):
"""
Perceiver-style cross-attention for resampling with 2D/3D RoPE support.
"""
def __init__ (
self ,
dim :int ,
num_heads :int =8 ,
dim_head :int =64 ,
dropout :float =0.0 ,
use_rope :bool =True ,
):
super ().__init__ ()
inner_dim =dim_head *num_heads
self .num_heads =num_heads
self .dim_head =dim_head
self .inner_dim =inner_dim
self .scale =dim_head **-0.5
self .use_rope =use_rope
self .norm_latents =nn .LayerNorm (dim )
self .norm_context =nn .LayerNorm (dim )
self .to_q =nn .Linear (dim ,inner_dim ,bias =False )
self .to_kv =nn .Linear (dim ,inner_dim *2 ,bias =False )
self .to_out =nn .Sequential (
nn .Linear (inner_dim ,dim ),
nn .Dropout (dropout )
)
def forward (
self ,
latents :torch .Tensor ,
context :torch .Tensor ,
context_rope :Optional [Tuple [torch .Tensor ,torch .Tensor ]]=None ,
)->torch .Tensor :
"""
latents: [B, num_latents, dim] - learnable queries
context: [B, seq_len, dim] - input features to attend to
context_rope: Optional (cos, sin) for context positions
"""
latents =self .norm_latents (latents )
context =self .norm_context (context )
b ,n ,_ =latents .shape
ctx_len =context .shape [1 ]
h =self .num_heads
d =self .dim_head
q = self.to_q(latents)
k, v = self.to_kv(context).chunk(2, dim=-1)
if self.use_rope and context_rope is not None:
    # Rotate the keys while their last dim still equals inner_dim, so the
    # [ctx_len, dim] cos/sin tables from the resampler broadcast correctly;
    # rotating after the head split would mismatch dim_head against dim.
    cos, sin = context_rope
    k = apply_rope(k, cos.unsqueeze(0), sin.unsqueeze(0))
q = q.reshape(b, n, h, d).transpose(1, 2)
k = k.reshape(b, ctx_len, h, d).transpose(1, 2)
v = v.reshape(b, ctx_len, h, d).transpose(1, 2)
qk_scale =d **-0.25
out =F .scaled_dot_product_attention (
q *qk_scale ,k *qk_scale ,v ,
is_causal =False ,scale =1.0 ,
)
out =out .transpose (1 ,2 ).reshape (b ,n ,self .inner_dim )
return self .to_out (out )
class PerceiverResampler (nn .Module ):
"""
Perceiver Resampler with 2D/3D RoPE and Dynamic Token Routing.
"""
def __init__ (
self ,
input_dim :int ,
output_dim :int ,
num_latents :int =64 ,
num_heads :int =8 ,
num_layers :int =2 ,
dropout :float =0.0 ,
use_rope :bool =True ,
use_dynamic_routing :bool =False ,
routing_keep_ratio :float =0.5 ,
):
super ().__init__ ()
self .num_latents =num_latents
self .use_rope =use_rope
self .use_dynamic_routing =use_dynamic_routing
self .input_proj =nn .Linear (input_dim ,output_dim )if input_dim !=output_dim else nn .Identity ()
self .latents =nn .Parameter (torch .randn (1 ,num_latents ,output_dim )*0.02 )
self .layers =nn .ModuleList ([
nn .ModuleList ([
PerceiverAttention (output_dim ,num_heads ,output_dim //num_heads ,dropout ,use_rope ),
nn .Sequential (
nn .LayerNorm (output_dim ),
nn .Linear (output_dim ,output_dim *4 ),
nn .GELU (),
nn .Dropout (dropout ),
nn .Linear (output_dim *4 ,output_dim ),
nn .Dropout (dropout ),
)
])
for _ in range (num_layers )
])
if use_dynamic_routing :
self .token_router =DynamicTokenRouter (output_dim ,num_latents ,routing_keep_ratio )
else :
self .token_router =None
self .norm_out =nn .LayerNorm (output_dim )
def forward (
self ,
x :torch .Tensor ,
spatial_size :Optional [Tuple [int ,int ]]=None ,
temporal_size :Optional [int ]=None ,
)->torch .Tensor :
"""
x: [B, seq_len, input_dim] - input features
spatial_size: (H, W) for 2D RoPE
temporal_size: T for 3D RoPE (video)
returns: [B, num_latents, output_dim] - compressed features
"""
batch_size =x .shape [0 ]
x =self .input_proj (x )
context_rope =None
if self .use_rope and spatial_size is not None :
h ,w =spatial_size
if temporal_size is not None :
cos ,sin =compute_3d_rope (temporal_size ,h ,w ,x .shape [-1 ],x .device ,x .dtype )
else :
cos ,sin =compute_2d_rope (h ,w ,x .shape [-1 ],x .device ,x .dtype )
context_rope =(cos ,sin )
latents =self .latents .expand (batch_size ,-1 ,-1 )
for attn ,ff in self .layers :
latents =latents +attn (latents ,x ,context_rope )
latents =latents +ff (latents )
latents =self .norm_out (latents )
if self .token_router is not None :
latents ,_ =self .token_router (latents )
return latents
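# ---------------------------------------------------------------------------
# Illustrative sketch: compressing a 24x24 grid of vision patch features into 64
# latent tokens with the resampler above, using 2D RoPE over the patch grid.
# Dimensions are placeholder assumptions.
def _example_perceiver_resampler():
    resampler = PerceiverResampler(input_dim=1024, output_dim=512, num_latents=64,
                                   num_heads=8, num_layers=2, use_rope=True)
    patches = torch.randn(2, 24 * 24, 1024)                # e.g. ViT patch embeddings
    latents = resampler(patches, spatial_size=(24, 24))
    return latents.shape                                   # [2, 64, 512]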
class SpatialAwareProjector (nn .Module ):
"""
Spatial-aware projector with 2D RoPE.
"""
def __init__ (
self ,
vision_hidden_size :int ,
llm_hidden_size :int ,
num_tokens :int =64 ,
spatial_pool_size :int =8 ,
use_rope :bool =True ,
):
super ().__init__ ()
self .num_tokens =num_tokens
self .spatial_pool_size =spatial_pool_size
self .use_rope =use_rope
self .spatial_conv =nn .Sequential (
nn .Conv2d (vision_hidden_size ,llm_hidden_size ,3 ,padding =1 ),
nn .GELU (),
nn .Conv2d (llm_hidden_size ,llm_hidden_size ,3 ,padding =1 ),
nn .GELU (),
)
self .adaptive_pool =nn .AdaptiveAvgPool2d ((spatial_pool_size ,spatial_pool_size ))
self .proj =nn .Sequential (
nn .Linear (llm_hidden_size ,llm_hidden_size ),
nn .GELU (),
nn .Linear (llm_hidden_size ,llm_hidden_size ),
)
self .norm =nn .LayerNorm (llm_hidden_size )
def forward (self ,vision_features :torch .Tensor ,spatial_size :Optional [Tuple [int ,int ]]=None )->torch .Tensor :
batch_size =vision_features .shape [0 ]
if vision_features .dim ()==3 :
seq_len =vision_features .shape [1 ]
if spatial_size is None :
h =w =int (math .sqrt (seq_len ))
else :
h ,w =spatial_size
vision_features =vision_features .view (batch_size ,h ,w ,-1 )
x =vision_features .permute (0 ,3 ,1 ,2 )
x =self .spatial_conv (x )
x =self .adaptive_pool (x )
x =x .flatten (2 ).transpose (1 ,2 )
if self .use_rope :
cos ,sin =compute_2d_rope (self .spatial_pool_size ,self .spatial_pool_size ,x .shape [-1 ],x .device ,x .dtype )
x =apply_rope (x ,cos .unsqueeze (0 ),sin .unsqueeze (0 ))
x =self .proj (x )
x =self .norm (x )
return x
class CAbstractor (nn .Module ):
"""
C-Abstractor: Compressed Abstraction for efficient multimodal fusion.
Now with 2D RoPE support.
"""
def __init__ (
self ,
vision_hidden_size :int ,
llm_hidden_size :int ,
num_tokens :int =64 ,
num_heads :int =8 ,
compression_ratio :int =4 ,
use_rope :bool =True ,
):
super ().__init__ ()
self .num_tokens =num_tokens
self .use_rope =use_rope
self .input_proj =nn .Linear (vision_hidden_size ,llm_hidden_size )
self .compress =nn .Sequential (
nn .Conv1d (llm_hidden_size ,llm_hidden_size ,kernel_size =compression_ratio ,stride =compression_ratio ),
nn .GELU (),
)
self .queries =nn .Parameter (torch .randn (1 ,num_tokens ,llm_hidden_size )*0.02 )
self .cross_attn =nn .MultiheadAttention (
embed_dim =llm_hidden_size ,
num_heads =num_heads ,
batch_first =True ,
dropout =0.1 ,
)
self .ff =nn .Sequential (
nn .LayerNorm (llm_hidden_size ),
nn .Linear (llm_hidden_size ,llm_hidden_size *4 ),
nn .GELU (),
nn .Linear (llm_hidden_size *4 ,llm_hidden_size ),
)
self .norm =nn .LayerNorm (llm_hidden_size )
def forward (self ,vision_features :torch .Tensor ,spatial_size :Optional [Tuple [int ,int ]]=None )->torch .Tensor :
batch_size =vision_features .shape [0 ]
x =self .input_proj (vision_features )
if self .use_rope and spatial_size is not None :
h ,w =spatial_size
cos ,sin =compute_2d_rope (h ,w ,x .shape [-1 ],x .device ,x .dtype )
x =apply_rope (x ,cos .unsqueeze (0 ),sin .unsqueeze (0 ))
x =x .transpose (1 ,2 )
x =self .compress (x )
x =x .transpose (1 ,2 )
queries =self .queries .expand (batch_size ,-1 ,-1 )
abstracted ,_ =self .cross_attn (queries ,x ,x )
abstracted =abstracted +self .ff (abstracted )
return self .norm (abstracted )
class MultimodalProjector (nn .Module ):
"""
SOTA Multimodal Projector with all advanced features.
Combines:
- Locality-Enhanced ResNet Abstractor
- Multi-Scale Feature Fusion
- Multi-Scale Deformable Attention
- Dynamic Token Router
- 2D/3D RoPE
- Perceiver Resampler
"""
def __init__ (
self ,
vision_hidden_size :int ,
llm_hidden_size :int ,
num_tokens :int =64 ,
projector_type :str ="perceiver",
num_heads :int =8 ,
num_layers :int =2 ,
use_rope :bool =True ,
use_dynamic_routing :bool =False ,
use_locality_enhanced :bool =False ,
use_msff :bool =False ,
use_deformable_attn :bool =False ,
):
super ().__init__ ()
self .num_tokens =num_tokens
self .projector_type =projector_type
self .use_rope =use_rope
if projector_type =="perceiver":
self .projector =PerceiverResampler (
input_dim =vision_hidden_size ,
output_dim =llm_hidden_size ,
num_latents =num_tokens ,
num_heads =num_heads ,
num_layers =num_layers ,
use_rope =use_rope ,
use_dynamic_routing =use_dynamic_routing ,
)
elif projector_type =="spatial":
self .projector =SpatialAwareProjector (
vision_hidden_size =vision_hidden_size ,
llm_hidden_size =llm_hidden_size ,
num_tokens =num_tokens ,
use_rope =use_rope ,
)
elif projector_type =="c_abstractor":
self .projector =CAbstractor (
vision_hidden_size =vision_hidden_size ,
llm_hidden_size =llm_hidden_size ,
num_tokens =num_tokens ,
num_heads =num_heads ,
use_rope =use_rope ,
)
elif projector_type =="locality_enhanced":
self .projector =LocalityEnhancedResNetAbstractor (
input_dim =vision_hidden_size ,
output_dim =llm_hidden_size ,
num_tokens =num_tokens ,
use_2d_rope =use_rope ,
)
else :
self .projector =nn .Sequential (
nn .Linear (vision_hidden_size ,llm_hidden_size ),
nn .GELU (),
nn .Linear (llm_hidden_size ,llm_hidden_size ),
)
self .query_tokens =nn .Parameter (torch .randn (1 ,num_tokens ,llm_hidden_size )*0.02 )
self .cross_attn =nn .MultiheadAttention (
embed_dim =llm_hidden_size ,
num_heads =num_heads ,
batch_first =True
)
self .norm =nn .LayerNorm (llm_hidden_size )
if use_msff :
self .msff =MultiScaleFeatureFusion (
feature_dims =[vision_hidden_size ]*3 ,
output_dim =vision_hidden_size ,
)
else :
self .msff =None
if use_deformable_attn :
self .deformable_attn =MultiScaleDeformableAttention (
dim =llm_hidden_size ,
num_heads =num_heads ,
)
else :
self .deformable_attn =None
if use_dynamic_routing and projector_type !="perceiver":
self .token_router =DynamicTokenRouter (llm_hidden_size ,num_tokens )
else :
self .token_router =None
def forward (
self ,
vision_features :torch .Tensor ,
multi_scale_features :Optional [List [torch .Tensor ]]=None ,
spatial_size :Optional [Tuple [int ,int ]]=None ,
temporal_size :Optional [int ]=None ,
)->torch .Tensor :
"""Project and resample vision features."""
if self .msff is not None and multi_scale_features is not None :
vision_features =self .msff (multi_scale_features )
if self .projector_type in ["perceiver"]:
output =self .projector (vision_features ,spatial_size ,temporal_size )
elif self .projector_type in ["spatial","c_abstractor","locality_enhanced"]:
output =self .projector (vision_features ,spatial_size )
else :
batch_size =vision_features .shape [0 ]
projected =self .projector (vision_features )
queries =self .query_tokens .expand (batch_size ,-1 ,-1 )
resampled ,_ =self .cross_attn (queries ,projected ,projected )
output =self .norm (resampled )
if self .token_router is not None :
output ,_ =self .token_router (output )
return output
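# ---------------------------------------------------------------------------
# Illustrative sketch: projecting 1024-d vision patch features into 64 tokens in
# a 2048-d LLM embedding space via the Perceiver variant of the projector above.
# Dimensions are placeholder assumptions.
def _example_multimodal_projector():
    projector = MultimodalProjector(
        vision_hidden_size=1024,
        llm_hidden_size=2048,
        num_tokens=64,
        projector_type="perceiver",
    )
    patches = torch.randn(2, 24 * 24, 1024)
    tokens = projector(patches, spatial_size=(24, 24))
    return tokens.shape                                    # [2, 64, 2048]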
# ==============================================================================
# MODELS.COMPONENTS.MOE
# ==============================================================================
EPS =1e-5
class ExpertUtilizationTracker :
"""
Tracks expert utilization across MoE layers.
Attach to any MoE layer to log per-expert usage histograms.
Every `report_interval` steps, prints a report showing:
- Frequency of use per expert
- Cold experts (used < 1% of tokens)
- Count of experts offloaded to CPU (if ExpertOffloadManager is available)
Usage:
tracker = ExpertUtilizationTracker(num_experts=8, layer_name="layer.3.moe")
"""
def __init__ (
self ,
num_experts :int ,
layer_name :str ="moe",
report_interval :int =100 ,
cold_threshold_pct :float =1.0 ,
):
self .num_experts =num_experts
self .layer_name =layer_name
self .report_interval =report_interval
self .cold_threshold_pct =cold_threshold_pct
self ._counts =torch .zeros (num_experts ,dtype =torch .long )
self ._total_tokens =0
self ._step =0
self ._offload_manager =None
def link_offload_manager (self ,manager ):
"""Link an ExpertOffloadManager for cold-expert reporting."""
self ._offload_manager =manager
def record (self ,expert_indices :torch .Tensor ):
"""
Record expert selections from a forward pass.
Args:
expert_indices: [num_tokens, top_k] tensor of selected expert indices
"""
indices_flat =expert_indices .detach ().cpu ().reshape (-1 )
for idx in range (self .num_experts ):
self ._counts [idx ]+=(indices_flat ==idx ).sum ().item ()
self ._total_tokens +=expert_indices .shape [0 ]
def step (self ):
"""Advance step counter. Prints report and resets when interval is hit."""
self ._step +=1
if self ._step %self .report_interval ==0 :
self ._print_report ()
self ._reset ()
def _reset (self ):
"""Reset accumulators for next interval."""
self ._counts .zero_ ()
self ._total_tokens =0
def _print_report (self ):
"""Print expert utilization histogram."""
if self ._total_tokens ==0 :
return
freqs =self ._counts .float ()
total_assignments =freqs .sum ().item ()
if total_assignments ==0 :
return
pcts =(freqs /total_assignments *100 ).tolist ()
cold_experts =[i for i ,p in enumerate (pcts )if p <self .cold_threshold_pct ]
max_pct =max (pcts )if pcts else 0
bar_max =30
lines =[f"\n{'='*60 }"]
lines .append (f" Expert Utilization — {self .layer_name } (step {self ._step })")
lines .append (f" {self ._total_tokens :,} tokens, {int (total_assignments ):,} assignments")
lines .append (f"{'─'*60 }")
for i ,pct in enumerate (pcts ):
bar_len =int (pct /max_pct *bar_max )if max_pct >0 else 0
bar ="█"*bar_len
cold_tag =" ❄️"if pct <self .cold_threshold_pct else ""
lines .append (f" Expert {i :2d}{bar :<{bar_max }}│ {pct :5.1f}% ({int (self ._counts [i ]):>6d}){cold_tag }")
lines .append (f"{'─'*60 }")
if cold_experts :
lines .append (f" ❄️ Cold experts (<{self .cold_threshold_pct }%): {cold_experts }")
else :
lines .append (f" ✅ All experts active (no cold experts)")
if self ._offload_manager is not None :
status =self ._offload_manager .get_status ()
lines .append (f" 💾 Offloaded to CPU: {status ['cpu']}/{status ['total']}")
ideal_pct =100.0 /self .num_experts
balance =1.0 -(sum (abs (p -ideal_pct )for p in pcts )/(2 *100 ))
lines .append (f" ⚖️ Load balance score: {balance :.3f} (1.0 = perfect)")
lines .append (f"{'='*60 }")
print ("\n".join (lines ))
def get_stats (self )->dict :
"""Return current stats as a dict (for programmatic access)."""
total =self ._counts .sum ().item ()
if total ==0 :
pcts =[0.0 ]*self .num_experts
else :
pcts =(self ._counts .float ()/total *100 ).tolist ()
cold =[i for i ,p in enumerate (pcts )if p <self .cold_threshold_pct ]
ideal_pct =100.0 /self .num_experts
balance =1.0 -(sum (abs (p -ideal_pct )for p in pcts )/(2 *100 ))if total >0 else 0.0
return {
"step":self ._step ,
"layer_name":self .layer_name ,
"total_tokens":self ._total_tokens ,
"expert_counts":self ._counts .tolist (),
"expert_pcts":pcts ,
"cold_experts":cold ,
"balance_score":balance ,
}
def attach_utilization_trackers (
model :torch .nn .Module ,
report_interval :int =100 ,
)->list :
"""
Find all MoE layers in a model and attach ExpertUtilizationTrackers.
Returns list of trackers for manual step() calls in the training loop.
"""
trackers =[]
for name ,module in model .named_modules ():
if hasattr (module ,'experts')and hasattr (module ,'router'):
num_experts =len (module .experts )
tracker =ExpertUtilizationTracker (
num_experts =num_experts ,
layer_name =name ,
report_interval =report_interval ,
)
if hasattr (module ,'_expert_offload_manager'):
tracker .link_offload_manager (module ._expert_offload_manager )
module ._utilization_tracker =tracker
trackers .append (tracker )
if trackers :
print (f" 📊 Attached {len (trackers )} expert utilization trackers (report every {report_interval } steps)")
return trackers
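# ---------------------------------------------------------------------------
# Illustrative training-loop sketch: attach utilization trackers to every MoE
# layer, then advance them once per step so a histogram is printed every
# `report_interval` steps. `model` and `batches` are placeholder assumptions.
def _example_track_expert_utilization(model: nn.Module, batches):
    trackers = attach_utilization_trackers(model, report_interval=100)
    for batch in batches:
        _ = model(**batch)                                 # MoE layers record routing internally
        for tracker in trackers:
            tracker.step()
    return trackers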
class MoERouter (nn .Module ):
"""
SOTA Router for Mixture of Experts v2.0 - FP16 native.
Supports both traditional aux-loss routing and aux-lossless routing.
"""
def __init__ (self ,hidden_size :int ,num_experts :int ,top_k :int =2 ,
noise_std :float =0.01 ,capacity_factor :float =1.25 ,
aux_lossless :bool =True ):
super ().__init__ ()
self .num_experts =num_experts
self .top_k =top_k
self .noise_std =noise_std
self .capacity_factor =capacity_factor
self .hidden_size =hidden_size
self .aux_lossless =aux_lossless
self .input_norm =nn .LayerNorm (hidden_size ,eps =1e-5 )
self .gate =nn .Linear (hidden_size ,num_experts ,bias =False )
nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 )
if aux_lossless :
self .expert_bias =nn .Parameter (torch .zeros (num_experts ))
def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]:
batch_size ,seq_len ,hidden_dim =hidden_states .shape
hidden_flat =hidden_states .view (-1 ,hidden_dim )
hidden_norm =self .input_norm (hidden_flat )
router_logits =self .gate (hidden_norm )
if self .aux_lossless :
router_logits =router_logits +self .expert_bias
if self .training and self .noise_std >0 :
noise =torch .randn_like (router_logits )*self .noise_std
noisy_logits =router_logits +noise
else :
noisy_logits =router_logits
router_probs =F .softmax (noisy_logits ,dim =-1 ,dtype =hidden_states .dtype )
top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 )
prob_sum =top_k_probs .sum (dim =-1 ,keepdim =True ).clamp (min =EPS )
top_k_probs =top_k_probs /prob_sum
return top_k_probs ,top_k_indices ,router_logits
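# ---------------------------------------------------------------------------
# Illustrative sketch of the router above: top-2 routing over 8 experts for a
# small batch, returning renormalized top-k probabilities, expert indices, and
# the raw logits. Sizes are placeholder assumptions.
def _example_moe_router():
    router = MoERouter(hidden_size=256, num_experts=8, top_k=2)
    hidden = torch.randn(2, 10, 256)
    top_k_probs, top_k_indices, router_logits = router(hidden)
    return top_k_probs.shape, top_k_indices.shape, router_logits.shape   # [20, 2], [20, 2], [20, 8]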
class MoEExpert (nn .Module ):
"""
Single expert FFN with SwiGLU activation - FP16 native.
"""
def __init__ (self ,hidden_size :int ,intermediate_size :int ,dropout :float =0.0 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .intermediate_size =intermediate_size
self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False )
self .act_fn =nn .SiLU ()
self .dropout =nn .Dropout (dropout )if dropout >0 else nn .Identity ()
self ._init_weights ()
def _init_weights (self ):
std =0.02
nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std )
nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std )
nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 )
def forward (self ,x :torch .Tensor )->torch .Tensor :
gate =self .act_fn (self .gate_proj (x ))
up =self .up_proj (x )
out =self .down_proj (gate *up )
return self .dropout (out )
class SharedExpert (nn .Module ):
"""
Isolated Shared Expert (v2.0) - FP16 native.
Always active, separate from routed experts.
The shared expert processes all tokens independently of routing decisions.
"""
def __init__ (self ,hidden_size :int ,intermediate_size :int ,dropout :float =0.0 ,
isolated :bool =True ):
super ().__init__ ()
self .hidden_size =hidden_size
self .intermediate_size =intermediate_size
self .isolated =isolated
self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False )
self .act_fn =nn .SiLU ()
self .dropout =nn .Dropout (dropout )if dropout >0 else nn .Identity ()
self .shared_gate =nn .Parameter (torch .ones (1 )*0.5 )
if isolated :
self .pre_norm =nn .LayerNorm (hidden_size ,eps =1e-5 )
self ._init_weights ()
def _init_weights (self ):
std =0.02
nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std )
nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std )
nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 )
def forward (self ,x :torch .Tensor )->torch .Tensor :
if self .isolated :
x =self .pre_norm (x )
gate =self .act_fn (self .gate_proj (x ))
up =self .up_proj (x )
out =self .down_proj (gate *up )
out =self .dropout (out )
return out *torch .sigmoid (self .shared_gate )
class MoELayer (nn .Module ):
"""
SOTA Mixture of Experts layer v2.0 - FP16 native.
Supports Aux-Lossless MoE with Isolated Shared Expert.
"""
def __init__ (
self ,
hidden_size :int ,
intermediate_size :int ,
num_experts :int =8 ,
num_experts_per_tok :int =2 ,
use_shared_expert :bool =True ,
shared_expert_intermediate_size :Optional [int ]=None ,
capacity_factor :float =1.25 ,
expert_dropout :float =0.0 ,
aux_lossless :bool =True ,
isolated_shared :bool =True ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_experts =num_experts
self .num_experts_per_tok =num_experts_per_tok
self .use_shared_expert =use_shared_expert
self .capacity_factor =capacity_factor
self .aux_lossless =aux_lossless
self .router =MoERouter (
hidden_size ,num_experts ,num_experts_per_tok ,
capacity_factor =capacity_factor ,aux_lossless =aux_lossless
)
self .experts =nn .ModuleList ([
MoEExpert (hidden_size ,intermediate_size ,expert_dropout )
for _ in range (num_experts )
])
if use_shared_expert :
shared_size =shared_expert_intermediate_size or intermediate_size
self .shared_expert =SharedExpert (
hidden_size ,shared_size ,expert_dropout ,isolated =isolated_shared
)
else :
self .shared_expert =None
def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]:
batch_size ,seq_len ,hidden_size =hidden_states .shape
hidden_flat =hidden_states .view (-1 ,hidden_size )
num_tokens =hidden_flat .shape [0 ]
top_k_probs ,top_k_indices ,router_logits =self .router (hidden_states )
if hasattr (self ,'_utilization_tracker'):
self ._utilization_tracker .record (top_k_indices )
final_output =torch .zeros_like (hidden_flat )
for expert_idx in range (self .num_experts ):
expert =self .experts [expert_idx ]
for k in range (self .num_experts_per_tok ):
mask =(top_k_indices [:,k ]==expert_idx )
if mask .any ():
expert_input =hidden_flat [mask ]
expert_output =expert (expert_input )
weight =top_k_probs [mask ,k :k +1 ]
final_output [mask ]=final_output [mask ]+weight *expert_output
if self .shared_expert is not None :
shared_output =self .shared_expert (hidden_flat )
final_output =final_output +shared_output
final_output =final_output .view (batch_size ,seq_len ,hidden_size )
aux_loss =self ._compute_aux_loss (router_logits ,top_k_indices ,num_tokens )
return final_output ,aux_loss
def _compute_aux_loss (self ,router_logits :torch .Tensor ,top_k_indices :torch .Tensor ,
num_tokens :int )->torch .Tensor :
device =router_logits .device
dtype =router_logits .dtype
if self .aux_lossless :
z_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.0001
return z_loss
router_probs =F .softmax (router_logits ,dim =-1 ,dtype =dtype )
expert_mask =F .one_hot (top_k_indices ,self .num_experts ).to (dtype )
denominator =max (num_tokens *self .num_experts_per_tok ,1 )
tokens_per_expert =expert_mask .sum (dim =(0 ,1 ))/denominator
avg_probs =router_probs .mean (dim =0 )
load_balance_loss =self .num_experts *(tokens_per_expert *avg_probs ).sum ()
z_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.001
router_probs_safe =router_probs .clamp (EPS ,1.0 -EPS )
log_probs =torch .log (router_probs_safe )
entropy =-(router_probs_safe *log_probs ).sum (dim =-1 ).mean ()
max_entropy =torch .log (torch .tensor (float (self .num_experts ),device =device ,dtype =dtype ))
entropy_loss =(max_entropy -entropy ).clamp (min =0.0 )*0.01
expert_usage =(tokens_per_expert >0.01 ).to (dtype ).mean ()
utilization_loss =(1.0 -expert_usage )*0.1
total_aux_loss =load_balance_loss +z_loss +entropy_loss +utilization_loss
return total_aux_loss
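# ---------------------------------------------------------------------------
# Hedged note and sketch for the auxiliary losses above (toy sizes, never
# called at import time). In aux-lossless mode only the router z-loss
#     L_z = 1e-4 * mean( logsumexp(router_logits, dim=-1) ** 2 )
# is returned; otherwise _compute_aux_loss adds load-balance, z-loss, entropy
# and utilization terms.
def _example_moe_layer_aux_loss():
    layer = MoELayer(hidden_size=16, intermediate_size=32,
                     num_experts=4, num_experts_per_tok=2, aux_lossless=True)
    out, aux = layer(torch.randn(2, 5, 16))
    assert out.shape == (2, 5, 16)
    assert aux.dim() == 0                           # scalar z-loss in aux-lossless mode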
class ExpertChoiceMoELayer (nn .Module ):
"""
Expert Choice MoE - FP16 native.
"""
def __init__ (
self ,
hidden_size :int ,
intermediate_size :int ,
num_experts :int =8 ,
capacity_factor :float =1.0 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_experts =num_experts
self .capacity_factor =capacity_factor
self .input_norm =nn .LayerNorm (hidden_size ,eps =1e-5 )
self .gate =nn .Linear (hidden_size ,num_experts ,bias =False )
nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 )
self .experts =nn .ModuleList ([
MoEExpert (hidden_size ,intermediate_size )
for _ in range (num_experts )
])
def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]:
batch_size ,seq_len ,hidden_size =hidden_states .shape
hidden_flat =hidden_states .view (-1 ,hidden_size )
num_tokens =hidden_flat .shape [0 ]
hidden_norm =self .input_norm (hidden_flat )
router_logits =self .gate (hidden_norm )
router_probs =F .softmax (router_logits ,dim =0 ,dtype =hidden_states .dtype )
capacity =int (num_tokens *self .capacity_factor /self .num_experts )
capacity =max (capacity ,1 )
final_output =torch .zeros_like (hidden_flat )
token_counts =torch .zeros (num_tokens ,device =hidden_flat .device ,dtype =hidden_flat .dtype )
for expert_idx in range (self .num_experts ):
expert =self .experts [expert_idx ]
expert_probs =router_probs [:,expert_idx ]
top_probs ,top_indices =torch .topk (expert_probs ,min (capacity ,num_tokens ))
expert_input =hidden_flat [top_indices ]
expert_output =expert (expert_input )
final_output [top_indices ]=final_output [top_indices ]+top_probs .unsqueeze (-1 )*expert_output
token_counts [top_indices ]=token_counts [top_indices ]+top_probs
token_counts =token_counts .clamp (min =EPS )
final_output =final_output /token_counts .unsqueeze (-1 )
final_output =final_output .view (batch_size ,seq_len ,hidden_size )
aux_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.001
return final_output ,aux_loss
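# ---------------------------------------------------------------------------
# Worked capacity arithmetic for ExpertChoiceMoELayer: with 1024 tokens,
# capacity_factor = 1.0 and 8 experts, each expert selects
#     capacity = int(1024 * 1.0 / 8) = 128
# tokens, so compute per expert stays fixed regardless of router skew
# (a token may still be picked by several experts, or by none).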
# ==============================================================================
# MODELS.ENCODERS.VISION
# ==============================================================================
EPS =1e-5
class RoPE2DEncoder (nn .Module ):
"""
2D Rotary Position Embedding for vision encoder patches.
Matches the 2D-RoPE in image generator for seamless integration.
"""
def __init__ (self ,dim :int ,max_height :int =128 ,max_width :int =128 ,base :float =10000.0 ):
super ().__init__ ()
self .dim =dim
self .max_height =max_height
self .max_width =max_width
self .base =base
self .dim_x =dim //2
self .dim_y =dim -self .dim_x
inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x ))
inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y ))
self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False )
self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False )
def forward (self ,x :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]:
device =x .device
dtype =x .dtype
pos_x =torch .arange (width ,device =device ,dtype =torch .float32 )
pos_y =torch .arange (height ,device =device ,dtype =torch .float32 )
freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device ))
freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device ))
freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 )
freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 )
cos_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype )
sin_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype )
for y in range (height ):
for w in range (width ):
cos_2d [y ,w ,:self .dim_x ]=freqs_x [w ].cos ().to (dtype )
sin_2d [y ,w ,:self .dim_x ]=freqs_x [w ].sin ().to (dtype )
cos_2d [y ,w ,self .dim_x :]=freqs_y [y ].cos ().to (dtype )
sin_2d [y ,w ,self .dim_x :]=freqs_y [y ].sin ().to (dtype )
cos_2d =cos_2d .view (height *width ,self .dim )
sin_2d =sin_2d .view (height *width ,self .dim )
return cos_2d ,sin_2d
def apply_rope_2d_encoder (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor :
"""Apply 2D rotary position embedding to tensor."""
x1 =x [...,:x .shape [-1 ]//2 ]
x2 =x [...,x .shape [-1 ]//2 :]
rotated =torch .cat ((-x2 ,x1 ),dim =-1 )
return x *cos +rotated *sin
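# ---------------------------------------------------------------------------
# Hedged shape sketch for 2D-RoPE (toy sizes, never called at import time).
# The encoder returns per-position cos/sin tables of shape [height*width, dim];
# apply_rope_2d_encoder rotates the two channel halves with them.
def _example_rope_2d():
    rope = RoPE2DEncoder(dim=8)
    x = torch.randn(1, 1, 6 * 6, 8)                 # [batch, heads, h*w, head_dim]
    cos, sin = rope(x, height=6, width=6)
    assert cos.shape == (36, 8) and sin.shape == (36, 8)
    x_rot = apply_rope_2d_encoder(x, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0))
    assert x_rot.shape == x.shape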
class TiTokTokenizer (nn .Module ):
"""
TiTok-style 1D Tokenizer for efficient visual representation.
Converts 2D patch grid to 1D token sequence with learnable compression.
"""
def __init__ (self ,hidden_size :int ,num_tokens :int =256 ,num_patches :int =576 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_tokens =num_tokens
self .num_patches =num_patches
self .compress =nn .Sequential (
nn .Linear (hidden_size ,hidden_size ),
nn .GELU (),
nn .Linear (hidden_size ,hidden_size ),
)
self .token_queries =nn .Parameter (torch .randn (1 ,num_tokens ,hidden_size )*0.02 )
self .compress_attn =nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =8 ,
batch_first =True ,
dropout =0.1 ,
)
self .compress_norm =nn .LayerNorm (hidden_size )
def forward (self ,x :torch .Tensor )->torch .Tensor :
"""
Compress patch features to TiTok-style 1D tokens.
Args:
x: [B, num_patches, hidden_size] patch features
Returns:
[B, num_tokens, hidden_size] compressed token features
"""
batch_size =x .shape [0 ]
queries =self .token_queries .expand (batch_size ,-1 ,-1 )
x_proj =self .compress (x )
tokens ,_ =self .compress_attn (queries ,x_proj ,x_proj )
tokens =self .compress_norm (queries +tokens )
return tokens
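# ---------------------------------------------------------------------------
# Hedged usage sketch for TiTokTokenizer (toy hidden size, random features,
# never called at import time): a 24x24 patch grid (576 patches) is compressed
# to 256 learned 1D tokens via cross-attention from the token queries.
def _example_titok():
    titok = TiTokTokenizer(hidden_size=64, num_tokens=256, num_patches=576)
    patches = torch.randn(2, 576, 64)
    tokens = titok(patches)
    assert tokens.shape == (2, 256, 64)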
class DeepStack (nn .Module ):
"""
DeepStack: Fuses multi-level ViT features to capture fine-grained details and sharpen image-text alignment.
SOTA: Instead of using only the final layer features, DeepStack combines features from
multiple intermediate layers of the vision encoder, enabling:
- Better fine-grained detail capture (early layers have high-resolution features)
- Stronger image-text alignment (different layers capture different semantic levels)
- Improved generation quality for both understanding and generation tasks
Architecture:
- Collects features from selected layers (typically: early, middle, late)
- Projects each level to a common dimension
- Combines via learned weighted sum or attention
"""
def __init__ (self ,hidden_size :int ,num_layers :int =3 ,use_attention :bool =True ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_layers =num_layers
self .use_attention =use_attention
self .level_projs =nn .ModuleList ([
nn .Linear (hidden_size ,hidden_size )
for _ in range (num_layers )
])
self .level_norms =nn .ModuleList ([
nn .LayerNorm (hidden_size )
for _ in range (num_layers )
])
if use_attention :
self .fusion_query =nn .Parameter (torch .randn (1 ,1 ,hidden_size )*0.02 )
self .fusion_attn =nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =8 ,
batch_first =True ,
dropout =0.1 ,
)
self .fusion_norm =nn .LayerNorm (hidden_size )
else :
self .level_weights =nn .Parameter (torch .ones (num_layers )/num_layers )
self .output_proj =nn .Sequential (
nn .Linear (hidden_size ,hidden_size ),
nn .GELU (),
nn .Linear (hidden_size ,hidden_size ),
)
def forward (self ,multi_level_features :list )->torch .Tensor :
"""
Fuse multi-level features.
Args:
multi_level_features: List of [B, seq_len, hidden_size] features from different layers
Returns:
[B, seq_len, hidden_size] fused features
"""
if len(multi_level_features) > self.num_layers:
    # Keep only the deepest num_layers feature maps; fewer levels are used as-is.
    multi_level_features = multi_level_features[-self.num_layers:]
batch_size ,seq_len ,_ =multi_level_features [0 ].shape
projected =[]
for i ,(feat ,proj ,norm )in enumerate (zip (multi_level_features ,self .level_projs ,self .level_norms )):
projected .append (norm (proj (feat )))
if self .use_attention :
stacked =torch .cat (projected ,dim =1 )
query =self .fusion_query .expand (batch_size ,seq_len ,-1 )
fused ,_ =self .fusion_attn (query ,stacked ,stacked )
fused =self .fusion_norm (query +fused )
else :
weights =F .softmax (self .level_weights ,dim =0 )
fused =sum (w *feat for w ,feat in zip (weights ,projected ))
return self .output_proj (fused )
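# ---------------------------------------------------------------------------
# Hedged usage sketch for DeepStack (toy sizes, never called at import time):
# three per-layer ViT feature maps are projected, normalized and fused back
# into a single [B, seq_len, hidden] tensor.
def _example_deepstack():
    deepstack = DeepStack(hidden_size=64, num_layers=3)
    levels = [torch.randn(2, 100, 64) for _ in range(3)]  # early / middle / late features
    fused = deepstack(levels)
    assert fused.shape == (2, 100, 64)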
class DualStreamEncoderAttention (nn .Module ):
"""
Symmetric Dual-Stream Self-Attention for vision encoding.
Matches the dual-stream architecture in image generator.
"""
def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_height :int =64 ,max_width :int =64 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_heads =num_heads
self .head_dim =hidden_size //num_heads
self .scale =self .head_dim **-0.5
self .to_qkv_a =nn .Linear (hidden_size ,hidden_size *3 ,bias =False )
self .to_qkv_b =nn .Linear (hidden_size ,hidden_size *3 ,bias =False )
self .to_out_a =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .to_out_b =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .norm_a =nn .LayerNorm (hidden_size )
self .norm_b =nn .LayerNorm (hidden_size )
self .rope_2d =RoPE2DEncoder (self .head_dim ,max_height ,max_width )
def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]:
batch_size ,seq_len ,_ =x_a .shape
x_a =self .norm_a (x_a )
x_b =self .norm_b (x_b )
qkv_a =self .to_qkv_a (x_a ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim )
qkv_b =self .to_qkv_b (x_b ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim )
q_a ,k_a ,v_a =qkv_a .unbind (dim =2 )
q_b ,k_b ,v_b =qkv_b .unbind (dim =2 )
cos ,sin =self .rope_2d (x_a ,height ,width )
cos =cos .unsqueeze (0 ).unsqueeze (0 )
sin =sin .unsqueeze (0 ).unsqueeze (0 )
q_a =q_a .transpose (1 ,2 )
k_a =k_a .transpose (1 ,2 )
v_a =v_a .transpose (1 ,2 )
q_b =q_b .transpose (1 ,2 )
k_b =k_b .transpose (1 ,2 )
v_b =v_b .transpose (1 ,2 )
q_a =apply_rope_2d_encoder (q_a ,cos ,sin )
k_a =apply_rope_2d_encoder (k_a ,cos ,sin )
q_b =apply_rope_2d_encoder (q_b ,cos ,sin )
k_b =apply_rope_2d_encoder (k_b ,cos ,sin )
k_combined =torch .cat ([k_a ,k_b ],dim =2 )
v_combined =torch .cat ([v_a ,v_b ],dim =2 )
attn_a =F .scaled_dot_product_attention (q_a ,k_combined ,v_combined )
attn_b =F .scaled_dot_product_attention (q_b ,k_combined ,v_combined )
attn_a =attn_a .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size )
attn_b =attn_b .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size )
out_a =self .to_out_a (attn_a )
out_b =self .to_out_b (attn_b )
return out_a ,out_b
class VisionEncoderBlock (nn .Module ):
"""Single block with dual-stream attention and FFN."""
def __init__ (self ,hidden_size :int ,num_heads :int =8 ,ff_mult :int =4 ,max_height :int =64 ,max_width :int =64 ):
super ().__init__ ()
self .dual_attn =DualStreamEncoderAttention (hidden_size ,num_heads ,max_height ,max_width )
self .ffn_a =nn .Sequential (
nn .LayerNorm (hidden_size ),
nn .Linear (hidden_size ,hidden_size *ff_mult ),
nn .GELU (),
nn .Linear (hidden_size *ff_mult ,hidden_size ),
)
self .ffn_b =nn .Sequential (
nn .LayerNorm (hidden_size ),
nn .Linear (hidden_size ,hidden_size *ff_mult ),
nn .GELU (),
nn .Linear (hidden_size *ff_mult ,hidden_size ),
)
def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]:
attn_a ,attn_b =self .dual_attn (x_a ,x_b ,height ,width )
x_a =x_a +attn_a
x_b =x_b +attn_b
x_a =x_a +self .ffn_a (x_a )
x_b =x_b +self .ffn_b (x_b )
return x_a ,x_b
class VisionEncoder (nn .Module ):
"""
SOTA Vision Encoder with 2D-RoPE, TiTok tokenization, and Dual-Stream Attention.
Features:
- SigLIP 2 / CLIP backbone for robust visual features
- 2D-RoPE for flexible aspect ratios
- TiTok-style 1D tokenization for efficient representation
- Dual-stream attention for symmetric processing
- FP16-native numerical stability
"""
def __init__ (
self ,
model_name :str ="google/siglip-so400m-patch14-384",
freeze :bool =False ,
use_pooled_output :bool =False ,
use_dual_stream :bool =True ,
use_titok :bool =True ,
num_titok_tokens :int =256 ,
num_dual_stream_layers :int =2 ,
):
super ().__init__ ()
self .model_name =model_name
self .use_pooled_output =use_pooled_output
self .use_dual_stream =use_dual_stream
self .use_titok =use_titok
self ._is_siglip ="siglip"in model_name .lower ()
print (f"\n👁️ Loading Vision Encoder: {model_name }")
if self ._is_siglip :
self ._init_siglip (model_name ,freeze )
else :
self ._init_clip (model_name ,freeze )
self .rope_2d =RoPE2DEncoder (
dim =self .hidden_size ,
max_height =64 ,
max_width =64 ,
)
print (f" 📐 2D-RoPE: Flexible aspect ratio support")
if use_dual_stream :
patch_size =getattr (self .vision_model .config ,'patch_size',14 )
image_size =getattr (self .vision_model .config ,'image_size',384 )
max_patches =(image_size //patch_size )
self .dual_stream_layers =nn .ModuleList ([
VisionEncoderBlock (
hidden_size =self .hidden_size ,
num_heads =8 ,
ff_mult =4 ,
max_height =max_patches ,
max_width =max_patches ,
)
for _ in range (num_dual_stream_layers )
])
print (f" 🔄 Dual-Stream: {num_dual_stream_layers } layers")
else :
self .dual_stream_layers =None
if use_titok :
self .titok =TiTokTokenizer (
hidden_size =self .hidden_size ,
num_tokens =num_titok_tokens ,
num_patches =self .num_patches ,
)
print (f" 🎫 TiTok: {self .num_patches } patches -> {num_titok_tokens } tokens")
else :
self .titok =None
def _init_siglip (self ,model_name :str ,freeze :bool ):
"""Initialize SigLIP 2 vision encoder."""
try :
from transformers import SiglipVisionModel ,SiglipImageProcessor
self .vision_model =SiglipVisionModel .from_pretrained (model_name )
self .image_processor =SiglipImageProcessor .from_pretrained (model_name )
self .hidden_size =self .vision_model .config .hidden_size
print (f" 🎯 Using SigLIP 2 (recommended for MoE)")
print (f" ✅ Hidden size: {self .hidden_size }")
print (f" 📐 Native size: {self .vision_model .config .image_size } (multi-scale: 256-512px)")
print (f" 🔲 Patch size: {self .vision_model .config .patch_size }")
except ImportError :
print (" ⚠️ SigLIP not available, falling back to CLIP")
self ._is_siglip =False
self ._init_clip ("openai/clip-vit-large-patch14",freeze )
return
if freeze :
for param in self .vision_model .parameters ():
param .requires_grad =False
print (f" ❄️ Vision encoder backbone frozen")
else :
print (f" 🔥 Vision encoder backbone trainable")
def _init_clip (self ,model_name :str ,freeze :bool ):
"""Initialize CLIP vision encoder (legacy support)."""
from transformers import CLIPVisionModel ,CLIPImageProcessor
self .vision_model =CLIPVisionModel .from_pretrained (model_name )
self .image_processor =CLIPImageProcessor .from_pretrained (model_name )
self .hidden_size =self .vision_model .config .hidden_size
print (f" 📎 Using CLIP")
print (f" ✅ Hidden size: {self .hidden_size }")
if freeze :
for param in self .vision_model .parameters ():
param .requires_grad =False
print (f" ❄️ Vision encoder backbone frozen")
else :
print (f" 🔥 Vision encoder backbone trainable")
def forward(self, pixel_values: torch.Tensor, return_titok: Optional[bool] = None) -> torch.Tensor:
"""
Extract vision features from images with SOTA enhancements.
Args:
pixel_values: [B, C, H, W] tensor of images
return_titok: Override for TiTok output (None uses self.use_titok)
Returns:
[B, num_tokens, hidden_size] tensor (TiTok) or
[B, num_patches, hidden_size] tensor (standard) or
[B, hidden_size] if use_pooled_output=True
"""
outputs =self .vision_model (pixel_values =pixel_values )
features =outputs .last_hidden_state
if self .use_pooled_output :
if hasattr (outputs ,'pooler_output')and outputs .pooler_output is not None :
return outputs .pooler_output
else :
return features .mean (dim =1 )
batch_size ,num_patches ,hidden_size =features .shape
patch_size =getattr (self .vision_model .config ,'patch_size',14 )
image_size =getattr (self .vision_model .config ,'image_size',384 )
if num_patches ==(image_size //patch_size )**2 +1 :
cls_token =features [:,:1 ]
features =features [:,1 :]
num_patches =num_patches -1
has_cls =True
else :
cls_token =None
has_cls =False
height =width =int (math .sqrt (num_patches ))
if self .dual_stream_layers is not None :
x_a =features
x_b =features .clone ()
for layer in self .dual_stream_layers :
x_a ,x_b =layer (x_a ,x_b ,height ,width )
features =(x_a +x_b )/2
use_titok_now =return_titok if return_titok is not None else self .use_titok
if use_titok_now and self .titok is not None :
features =self .titok (features )
return features
def get_image_processor (self ):
"""Return the image processor for preprocessing."""
return self .image_processor
@property
def num_patches (self )->int :
"""Get number of patches for the vision model."""
config =self .vision_model .config
image_size =config .image_size
patch_size =config .patch_size
return (image_size //patch_size )**2
@property
def image_size (self )->int :
"""Get expected image size."""
return self .vision_model .config .image_size
@property
def output_tokens (self )->int :
"""Get number of output tokens (considering TiTok compression)."""
if self .use_titok and self .titok is not None :
return self .titok .num_tokens
return self .num_patches
SIGLIP_MODELS ={
"siglip-base":"google/siglip-base-patch16-224",
"siglip-base-384":"google/siglip-base-patch16-384",
"siglip-large":"google/siglip-large-patch16-256",
"siglip-large-384":"google/siglip-large-patch16-384",
"siglip-so400m":"google/siglip-so400m-patch14-384",
"siglip-so400m-224":"google/siglip-so400m-patch14-224",
"clip-base":"openai/clip-vit-base-patch16",
"clip-large":"openai/clip-vit-large-patch14",
}
def get_vision_encoder (
model_key :str ="siglip-so400m",
freeze :bool =False ,
use_dual_stream :bool =True ,
use_titok :bool =True ,
**kwargs
)->VisionEncoder :
"""
Get a vision encoder by key name with SOTA enhancements.
Args:
model_key: Key from SIGLIP_MODELS or full model name
freeze: Whether to freeze encoder backbone weights
use_dual_stream: Enable dual-stream attention
use_titok: Enable TiTok 1D tokenization
**kwargs: Additional arguments for VisionEncoder
Returns:
VisionEncoder instance
"""
model_name =SIGLIP_MODELS .get (model_key ,model_key )
return VisionEncoder (
model_name =model_name ,
freeze =freeze ,
use_dual_stream =use_dual_stream ,
use_titok =use_titok ,
**kwargs
)
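# ---------------------------------------------------------------------------
# Hedged usage sketch for get_vision_encoder (never called at import time).
# Building the encoder downloads pretrained SigLIP weights from the Hub, so
# this is illustration only; freeze=True keeps the backbone fixed while the
# dual-stream and TiTok heads remain trainable.
def _example_vision_encoder():
    encoder = get_vision_encoder("siglip-so400m", freeze=True)
    images = torch.randn(1, 3, encoder.image_size, encoder.image_size)
    features = encoder(images)                      # [1, output_tokens, hidden_size]
    return features.shape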
# ==============================================================================
# MODELS.ENCODERS.VIDEO
# ==============================================================================
EPS =1e-5
class TextTimestampAlignment (nn .Module ):
"""
Text-Timestamp Alignment: Precise timestamp-grounded event localization for stronger video temporal modeling.
SOTA: Moves beyond T-RoPE by explicitly aligning text descriptions with video timestamps,
enabling:
- Precise temporal localization of events described in text
- Better video captioning with accurate time references
- Improved video question-answering with temporal reasoning
- Enhanced video generation with temporal control
Architecture:
- Cross-attention between text features and frame-level video features
- Learnable timestamp embeddings for each frame
- Temporal alignment loss during training
"""
def __init__ (self ,hidden_size :int ,max_frames :int =64 ,num_heads :int =8 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .max_frames =max_frames
self .num_heads =num_heads
self .timestamp_embedding =nn .Embedding (max_frames ,hidden_size )
self .video_proj =nn .Linear (hidden_size ,hidden_size )
self .text_proj =nn .Linear (hidden_size ,hidden_size )
self .cross_attn =nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =num_heads ,
batch_first =True ,
dropout =0.1 ,
)
self .text_norm =nn .LayerNorm (hidden_size )
self .video_norm =nn .LayerNorm (hidden_size )
self .alignment_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //2 ),
nn .GELU (),
nn .Linear (hidden_size //2 ,1 ),
)
self .output_proj =nn .Linear (hidden_size ,hidden_size )
def forward (
self ,
video_features :torch .Tensor ,
text_features :torch .Tensor ,
num_frames :int ,
return_alignment_scores :bool =False ,
)->Tuple [torch .Tensor ,Optional [torch .Tensor ]]:
"""
Align text with video timestamps.
Args:
video_features: [B, T*H*W, hidden_size] video features
text_features: [B, text_len, hidden_size] text features
num_frames: Number of frames in the video
return_alignment_scores: Whether to return alignment scores for loss
Returns:
aligned_text: [B, text_len, hidden_size] text features aligned to video timestamps
alignment_scores: Optional [B, text_len, T] alignment scores
"""
batch_size =video_features .shape [0 ]
total_tokens =video_features .shape [1 ]
spatial_tokens =total_tokens //num_frames
timestamp_ids =torch .arange (num_frames ,device =video_features .device )
timestamp_embeds =self .timestamp_embedding (timestamp_ids )
timestamp_embeds =timestamp_embeds .unsqueeze (1 ).expand (-1 ,spatial_tokens ,-1 )
timestamp_embeds =timestamp_embeds .reshape (1 ,total_tokens ,-1 )
timestamp_embeds =timestamp_embeds .expand (batch_size ,-1 ,-1 )
video_feat =self .video_norm (self .video_proj (video_features )+timestamp_embeds )
text_feat =self .text_norm (self .text_proj (text_features ))
aligned ,attn_weights =self .cross_attn (text_feat ,video_feat ,video_feat )
alignment_scores =None
if return_alignment_scores :
attn_reshaped =attn_weights .view (batch_size ,text_features .shape [1 ],num_frames ,spatial_tokens )
alignment_scores =attn_reshaped .mean (dim =-1 )
aligned_text =text_features +self .output_proj (aligned )
return aligned_text ,alignment_scores
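# ---------------------------------------------------------------------------
# Hedged shape sketch for TextTimestampAlignment (toy sizes, never called at
# import time): 8 frames x 16 spatial tokens of video features are aligned
# with 12 text tokens; the returned scores average attention over the spatial
# tokens of each frame.
def _example_text_timestamp_alignment():
    align = TextTimestampAlignment(hidden_size=64, max_frames=8)
    video = torch.randn(2, 8 * 16, 64)
    text = torch.randn(2, 12, 64)
    aligned, scores = align(video, text, num_frames=8, return_alignment_scores=True)
    assert aligned.shape == (2, 12, 64)             # text features, timestamp-aligned
    assert scores.shape == (2, 12, 8)               # per text token, per frame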
class AlphaBlender (nn .Module ):
"""
AlphaBlender operator from VidTok for temporal blending.
Blends two inputs with a fixed mixing coefficient: out = alpha * x1 + (1 - alpha) * x2.
"""
def __init__ (self ,alpha :float =0.55 ):
super ().__init__ ()
self .alpha =alpha
def forward (self ,x1 :torch .Tensor ,x2 :torch .Tensor )->torch .Tensor :
return self .alpha *x1 +(1 -self .alpha )*x2
class VidTokEncoder (nn .Module ):
"""
VidTok-style Video Encoder following Microsoft's VidTok architecture.
SOTA: Implements the VidTok encoder with:
- 3D convolutions for input and bottleneck (information fusion)
- 2D convolutions for spatial downsampling (efficiency)
- AlphaBlender + 1D convolutions for temporal downsampling
- Layer normalization for stability
Compresses video [B, C, T, H, W] -> latent [B, latent_dim, t, h, w]
"""
def __init__ (
self ,
in_channels :int =3 ,
latent_channels :int =4 ,
base_channels :int =64 ,
temporal_downsample :int =4 ,
spatial_downsample :int =8 ,
causal :bool =True ,
):
super ().__init__ ()
self .in_channels =in_channels
self .latent_channels =latent_channels
self .base_channels =base_channels
self .temporal_downsample =temporal_downsample
self .spatial_downsample =spatial_downsample
self .causal =causal
self .num_spatial_downs =int (math .log2 (spatial_downsample ))
self .num_temporal_downs =int (math .log2 (temporal_downsample ))
self .input_block =nn .Sequential (
nn .Conv3d (in_channels ,base_channels ,kernel_size =3 ,padding =1 ),
nn .GroupNorm (8 ,base_channels ),
nn .SiLU (),
)
self .spatial_down_blocks =nn .ModuleList ()
ch =base_channels
for i in range (self .num_spatial_downs ):
out_ch =min (ch *2 ,512 )
self .spatial_down_blocks .append (
self ._make_spatial_down_block (ch ,out_ch )
)
ch =out_ch
self .temporal_down_blocks =nn .ModuleList ()
for i in range (self .num_temporal_downs ):
self .temporal_down_blocks .append (
self ._make_temporal_down_block (ch )
)
self .bottleneck =nn .Sequential (
nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ),
nn .GroupNorm (8 ,ch ),
nn .SiLU (),
nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ),
nn .GroupNorm (8 ,ch ),
nn .SiLU (),
)
self .to_latent =nn .Conv3d (ch ,latent_channels ,kernel_size =1 )
print (f" 🎬 VidTokEncoder: {in_channels }ch -> {latent_channels }ch latent")
print (f" Spatial: {spatial_downsample }x down ({self .num_spatial_downs } stages)")
print (f" Temporal: {temporal_downsample }x down ({self .num_temporal_downs } stages)")
def _make_spatial_down_block (self ,in_ch :int ,out_ch :int )->nn .Module :
"""Create a spatial downsampling block using 2D convolutions."""
return nn .Sequential (
Rearrange3Dto2D (),
nn .Conv2d (in_ch ,out_ch ,kernel_size =3 ,stride =2 ,padding =1 ),
nn .GroupNorm (8 ,out_ch ),
nn .SiLU (),
nn .Conv2d (out_ch ,out_ch ,kernel_size =3 ,padding =1 ),
nn .GroupNorm (8 ,out_ch ),
nn .SiLU (),
Rearrange2Dto3D (),
)
def _make_temporal_down_block (self ,channels :int )->nn .Module :
"""Create a temporal downsampling block using AlphaBlender + 1D conv."""
return TemporalDownBlock (channels ,causal =self .causal )
def forward (self ,x :torch .Tensor )->torch .Tensor :
"""
Encode video to latent space.
Args:
x: [B, C, T, H, W] input video
Returns:
[B, latent_channels, t, h, w] latent representation
"""
B ,C ,T ,H ,W =x .shape
x =self .input_block (x )
for block in self .spatial_down_blocks :
if hasattr (block [0 ],'set_temporal_dim'):
block [0 ].set_temporal_dim (x .shape [2 ])
if hasattr (block [-1 ],'set_temporal_dim'):
block [-1 ].set_temporal_dim (x .shape [2 ])
x =block (x )
for block in self .temporal_down_blocks :
x =block (x )
x =self .bottleneck (x )
x =self .to_latent (x )
return x
class VidTokDecoder (nn .Module ):
"""
VidTok-style Video Decoder following Microsoft's VidTok architecture.
Reconstructs video from latent [B, latent_dim, t, h, w] -> [B, C, T, H, W]
"""
def __init__ (
self ,
out_channels :int =3 ,
latent_channels :int =4 ,
base_channels :int =64 ,
temporal_upsample :int =4 ,
spatial_upsample :int =8 ,
causal :bool =True ,
):
super ().__init__ ()
self .out_channels =out_channels
self .latent_channels =latent_channels
self .base_channels =base_channels
self .temporal_upsample =temporal_upsample
self .spatial_upsample =spatial_upsample
self .causal =causal
self .num_spatial_ups =int (math .log2 (spatial_upsample ))
self .num_temporal_ups =int (math .log2 (temporal_upsample ))
ch =min (base_channels *(2 **self .num_spatial_ups ),512 )
self .from_latent =nn .Conv3d (latent_channels ,ch ,kernel_size =1 )
self .bottleneck =nn .Sequential (
nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ),
nn .GroupNorm (8 ,ch ),
nn .SiLU (),
nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ),
nn .GroupNorm (8 ,ch ),
nn .SiLU (),
)
self .temporal_up_blocks =nn .ModuleList ()
for i in range (self .num_temporal_ups ):
self .temporal_up_blocks .append (
TemporalUpBlock (ch ,causal =self .causal )
)
self .spatial_up_blocks =nn .ModuleList ()
for i in range (self .num_spatial_ups ):
out_ch =max (ch //2 ,base_channels )
self .spatial_up_blocks .append (
self ._make_spatial_up_block (ch ,out_ch )
)
ch =out_ch
self .output_block =nn .Sequential (
nn .Conv3d (ch ,out_channels ,kernel_size =3 ,padding =1 ),
nn .Tanh (),
)
print (f" 🎬 VidTokDecoder: {latent_channels }ch latent -> {out_channels }ch")
def _make_spatial_up_block (self ,in_ch :int ,out_ch :int )->nn .Module :
"""Create a spatial upsampling block using 2D convolutions."""
return nn .Sequential (
Rearrange3Dto2D (),
nn .ConvTranspose2d (in_ch ,out_ch ,kernel_size =4 ,stride =2 ,padding =1 ),
nn .GroupNorm (8 ,out_ch ),
nn .SiLU (),
nn .Conv2d (out_ch ,out_ch ,kernel_size =3 ,padding =1 ),
nn .GroupNorm (8 ,out_ch ),
nn .SiLU (),
Rearrange2Dto3D (),
)
def forward (self ,z :torch .Tensor )->torch .Tensor :
"""
Decode latent to video.
Args:
z: [B, latent_channels, t, h, w] latent representation
Returns:
[B, C, T, H, W] reconstructed video
"""
x =self .from_latent (z )
x =self .bottleneck (x )
for block in self .temporal_up_blocks :
x =block (x )
for block in self.spatial_up_blocks:
    # Mirror the encoder path: the trailing Rearrange2Dto3D needs the current temporal length.
    block[-1].set_temporal_dim(x.shape[2])
    x = block(x)
x =self .output_block (x )
return x
class Rearrange3Dto2D (nn .Module ):
"""Reshape [B, C, T, H, W] -> [B*T, C, H, W] for 2D operations."""
def __init__ (self ):
super ().__init__ ()
self .temporal_dim =None
def set_temporal_dim (self ,t :int ):
self .temporal_dim =t
def forward (self ,x :torch .Tensor )->torch .Tensor :
B ,C ,T ,H ,W =x .shape
self .temporal_dim =T
return x .permute (0 ,2 ,1 ,3 ,4 ).reshape (B *T ,C ,H ,W )
class Rearrange2Dto3D (nn .Module ):
"""Reshape [B*T, C, H, W] -> [B, C, T, H, W] after 2D operations."""
def __init__ (self ):
super ().__init__ ()
self .temporal_dim =None
def set_temporal_dim (self ,t :int ):
self .temporal_dim =t
def forward (self ,x :torch .Tensor )->torch .Tensor :
BT ,C ,H ,W =x .shape
T =self .temporal_dim if self .temporal_dim else 1
B =BT //T
return x .reshape (B ,T ,C ,H ,W ).permute (0 ,2 ,1 ,3 ,4 )
class TemporalDownBlock (nn .Module ):
"""Temporal downsampling using AlphaBlender + 1D conv (VidTok style)."""
def __init__ (self ,channels :int ,causal :bool =True ):
super ().__init__ ()
self .channels =channels
self .causal =causal
self .alpha_blender =AlphaBlender ()
# Stride-2 conv halves T; no extra padding is applied (kernel_size=2, stride=2).
self.temporal_conv = nn.Conv1d(channels, channels, kernel_size=2, stride=2, padding=0)
self .norm =nn .GroupNorm (8 ,channels )
self .act =nn .SiLU ()
def forward (self ,x :torch .Tensor )->torch .Tensor :
"""
Args:
x: [B, C, T, H, W]
Returns:
[B, C, T//2, H, W]
"""
B ,C ,T ,H ,W =x .shape
x =x .permute (0 ,3 ,4 ,1 ,2 ).reshape (B *H *W ,C ,T )
x =self .temporal_conv (x )
x =self .norm (x .unsqueeze (-1 )).squeeze (-1 )
x =self .act (x )
T_new =x .shape [2 ]
x =x .reshape (B ,H ,W ,C ,T_new ).permute (0 ,3 ,4 ,1 ,2 )
return x
class TemporalUpBlock (nn .Module ):
"""Temporal upsampling using AlphaBlender + 1D conv (VidTok style)."""
def __init__ (self ,channels :int ,causal :bool =True ):
super ().__init__ ()
self .channels =channels
self .causal =causal
self .alpha_blender =AlphaBlender ()
self .temporal_conv =nn .ConvTranspose1d (channels ,channels ,kernel_size =2 ,stride =2 )
self .norm =nn .GroupNorm (8 ,channels )
self .act =nn .SiLU ()
def forward (self ,x :torch .Tensor )->torch .Tensor :
"""
Args:
x: [B, C, T, H, W]
Returns:
[B, C, T*2, H, W]
"""
B ,C ,T ,H ,W =x .shape
x =x .permute (0 ,3 ,4 ,1 ,2 ).reshape (B *H *W ,C ,T )
x =self .temporal_conv (x )
x =self .norm (x .unsqueeze (-1 )).squeeze (-1 )
x =self .act (x )
T_new =x .shape [2 ]
x =x .reshape (B ,H ,W ,C ,T_new ).permute (0 ,3 ,4 ,1 ,2 )
return x
class VidTokTokenizer (nn .Module ):
"""
VidTok-style Video Tokenizer (3D VAE) following Microsoft's VidTok architecture.
SOTA: Full encoder-decoder architecture for video compression to latent space.
- Efficient 2D+1D architecture (separates spatial and temporal processing)
- AlphaBlender for temporal blending
- Supports both continuous (KL) and discrete (FSQ) tokenization
- Causal mode for streaming/autoregressive applications
Compresses video [B, C, T, H, W] -> latent [B, latent_dim, t, h, w]
"""
def __init__ (
self ,
in_channels :int =3 ,
latent_channels :int =4 ,
base_channels :int =64 ,
temporal_compression :int =4 ,
spatial_compression :int =8 ,
causal :bool =True ,
use_fsq :bool =False ,
fsq_levels :int =8 ,
):
super ().__init__ ()
self .in_channels =in_channels
self .latent_channels =latent_channels
self .temporal_compression =temporal_compression
self .spatial_compression =spatial_compression
self .causal =causal
self .use_fsq =use_fsq
self .fsq_levels =fsq_levels
self .encoder =VidTokEncoder (
in_channels =in_channels ,
latent_channels =latent_channels *2 if not use_fsq else latent_channels ,
base_channels =base_channels ,
temporal_downsample =temporal_compression ,
spatial_downsample =spatial_compression ,
causal =causal ,
)
self .decoder =VidTokDecoder (
out_channels =in_channels ,
latent_channels =latent_channels ,
base_channels =base_channels ,
temporal_upsample =temporal_compression ,
spatial_upsample =spatial_compression ,
causal =causal ,
)
print (f" 🎬 VidTokTokenizer: {temporal_compression }x{spatial_compression }x{spatial_compression } compression")
print (f" Mode: {'FSQ (discrete)'if use_fsq else 'KL (continuous)'}, Causal: {causal }")
def encode (self ,x :torch .Tensor )->torch .Tensor :
"""Encode video to latent space."""
h =self .encoder (x )
if self .use_fsq :
return self ._fsq_quantize (h )
else :
mean ,logvar =h .chunk (2 ,dim =1 )
std =torch .exp (0.5 *logvar )
eps =torch .randn_like (std )
return mean +eps *std
def decode (self ,z :torch .Tensor )->torch .Tensor :
"""Decode latent to video."""
return self .decoder (z )
def _fsq_quantize (self ,z :torch .Tensor )->torch .Tensor :
"""Finite Scalar Quantization - quantize each channel independently."""
z =torch .tanh (z )
z =torch .round ((z +1 )*(self .fsq_levels -1 )/2 )*2 /(self .fsq_levels -1 )-1
return z
def forward (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]:
"""
Full forward pass: encode then decode.
Args:
x: [B, C, T, H, W] input video
Returns:
Tuple of (reconstructed video, latent representation)
"""
z =self .encode (x )
x_recon =self .decode (z )
return x_recon ,z
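# ---------------------------------------------------------------------------
# Hedged round-trip sketch for VidTokTokenizer (toy sizes, never called at
# import time): an 8-frame 64x64 clip is compressed 4x temporally and 8x
# spatially, then reconstructed to the original shape.
def _example_vidtok_roundtrip():
    vae = VidTokTokenizer(in_channels=3, latent_channels=4, base_channels=32,
                          temporal_compression=4, spatial_compression=8)
    video = torch.randn(1, 3, 8, 64, 64)            # [B, C, T, H, W]
    recon, z = vae(video)
    assert z.shape == (1, 4, 2, 8, 8)               # T/4, H/8, W/8
    assert recon.shape == video.shape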
class RoPE3DEncoder (nn .Module ):
"""
3D Rotary Position Embedding for (x, y, t) dimensions.
Matches the 3D-RoPE in video generator for seamless integration.
"""
def __init__ (self ,dim :int ,max_height :int =64 ,max_width :int =64 ,max_frames :int =32 ,base :float =10000.0 ):
super ().__init__ ()
self .dim =dim
self .max_height =max_height
self .max_width =max_width
self .max_frames =max_frames
self .base =base
dim_per_axis =dim //3
self .dim_x =dim_per_axis
self .dim_y =dim_per_axis
self .dim_t =dim -2 *dim_per_axis
inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x ))
inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y ))
inv_freq_t =1.0 /(base **(torch .arange (0 ,self .dim_t ,2 ,dtype =torch .float32 )/self .dim_t ))
self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False )
self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False )
self .register_buffer ('inv_freq_t',inv_freq_t ,persistent =False )
def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int )->Tuple [torch .Tensor ,torch .Tensor ]:
device =x .device
dtype =x .dtype
pos_x =torch .arange (width ,device =device ,dtype =torch .float32 )
pos_y =torch .arange (height ,device =device ,dtype =torch .float32 )
pos_t =torch .arange (frames ,device =device ,dtype =torch .float32 )
freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device ))
freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device ))
freqs_t =torch .outer (pos_t ,self .inv_freq_t .to (device ))
freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 )
freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 )
freqs_t =torch .cat ([freqs_t ,freqs_t ],dim =-1 )
cos_x =freqs_x .cos ().to (dtype )
sin_x =freqs_x .sin ().to (dtype )
cos_y =freqs_y .cos ().to (dtype )
sin_y =freqs_y .sin ().to (dtype )
cos_t =freqs_t .cos ().to (dtype )
sin_t =freqs_t .sin ().to (dtype )
cos_3d =torch .zeros (frames ,height ,width ,self .dim ,device =device ,dtype =dtype )
sin_3d =torch .zeros (frames ,height ,width ,self .dim ,device =device ,dtype =dtype )
for t in range (frames ):
for y in range (height ):
for w in range (width ):
cos_3d [t ,y ,w ,:self .dim_x ]=cos_x [w ]
sin_3d [t ,y ,w ,:self .dim_x ]=sin_x [w ]
cos_3d [t ,y ,w ,self .dim_x :self .dim_x +self .dim_y ]=cos_y [y ]
sin_3d [t ,y ,w ,self .dim_x :self .dim_x +self .dim_y ]=sin_y [y ]
cos_3d [t ,y ,w ,self .dim_x +self .dim_y :]=cos_t [t ]
sin_3d [t ,y ,w ,self .dim_x +self .dim_y :]=sin_t [t ]
cos_3d =cos_3d .view (frames *height *width ,self .dim )
sin_3d =sin_3d .view (frames *height *width ,self .dim )
return cos_3d ,sin_3d
def apply_rope_3d_encoder (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor :
"""Apply 3D rotary position embedding to tensor."""
x1 =x [...,:x .shape [-1 ]//2 ]
x2 =x [...,x .shape [-1 ]//2 :]
rotated =torch .cat ((-x2 ,x1 ),dim =-1 )
return x *cos +rotated *sin
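# ---------------------------------------------------------------------------
# Worked channel split for the 3D-RoPE above: with head_dim = 48 the split is
# dim_x = dim_y = 48 // 3 = 16 and dim_t = 48 - 2*16 = 16, so the x, y and t
# axes each get their own 16-channel rotary band covering the full head dim.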
class TemporalExpertRouterEncoder (nn .Module ):
"""
Temporal-Aware Expert Router for video encoding.
Routes tokens based on temporal context and motion patterns.
"""
def __init__ (self ,hidden_size :int ,num_experts :int =4 ,top_k :int =2 ):
super ().__init__ ()
self .num_experts =num_experts
self .top_k =top_k
self .temporal_proj =nn .Linear (hidden_size ,hidden_size )
self .gate =nn .Linear (hidden_size ,num_experts ,bias =False )
nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 )
def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->Tuple [torch .Tensor ,torch .Tensor ]:
if temporal_context is not None :
x =x +self .temporal_proj (temporal_context )
router_logits =self .gate (x )
router_probs =F .softmax (router_logits ,dim =-1 ,dtype =x .dtype )
top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 )
top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS )
return top_k_probs ,top_k_indices
class VideoExpertEncoder (nn .Module ):
"""Single expert for video encoding with SwiGLU."""
def __init__ (self ,hidden_size :int ,intermediate_size :int ):
super ().__init__ ()
self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False )
self .act_fn =nn .SiLU ()
def forward (self ,x :torch .Tensor )->torch .Tensor :
return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x ))
class TemporalMoELayerEncoder (nn .Module ):
"""
Temporal-Aware MoE Layer for video encoding.
Uses motion-aware routing for expert selection.
"""
def __init__ (self ,hidden_size :int ,intermediate_size :int ,num_experts :int =4 ,top_k :int =2 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_experts =num_experts
self .top_k =top_k
self .router =TemporalExpertRouterEncoder (hidden_size ,num_experts ,top_k )
self .experts =nn .ModuleList ([
VideoExpertEncoder (hidden_size ,intermediate_size )
for _ in range (num_experts )
])
self .shared_expert =VideoExpertEncoder (hidden_size ,intermediate_size )
def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->torch .Tensor :
batch_size ,seq_len ,hidden_size =x .shape
x_flat =x .view (-1 ,hidden_size )
top_k_probs ,top_k_indices =self .router (x_flat ,temporal_context .view (-1 ,hidden_size )if temporal_context is not None else None )
output =torch .zeros_like (x_flat )
for expert_idx in range (self .num_experts ):
expert =self .experts [expert_idx ]
for k in range (self .top_k ):
mask =(top_k_indices [:,k ]==expert_idx )
if mask .any ():
expert_input =x_flat [mask ]
expert_output =expert (expert_input )
weight =top_k_probs [mask ,k :k +1 ]
output [mask ]=output [mask ]+weight *expert_output
shared_output =self .shared_expert (x_flat )
output =output +shared_output
return output .view (batch_size ,seq_len ,hidden_size )
class Causal3DAttentionEncoder (nn .Module ):
"""
3D Causal Self-Attention with 3D-RoPE for video encoding.
Attends to all positions for encoding (non-causal during encoding).
"""
def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_frames :int =32 ,max_height :int =64 ,max_width :int =64 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_heads =num_heads
self .head_dim =hidden_size //num_heads
self .scale =self .head_dim **-0.5
self .to_qkv =nn .Linear (hidden_size ,hidden_size *3 ,bias =False )
self .to_out =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .norm =nn .LayerNorm (hidden_size )
self .rope_3d =RoPE3DEncoder (self .head_dim ,max_height ,max_width ,max_frames )
def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =False )->torch .Tensor :
batch_size ,seq_len ,_ =x .shape
x_norm =self .norm (x )
qkv =self .to_qkv (x_norm ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim )
q ,k ,v =qkv .unbind (dim =2 )
cos ,sin =self .rope_3d (x ,height ,width ,frames )
# Broadcast over batch and heads: [seq_len, head_dim] -> [1, 1, seq_len, head_dim]
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
q =q .transpose (1 ,2 )
k =k .transpose (1 ,2 )
v =v .transpose (1 ,2 )
q =apply_rope_3d_encoder (q ,cos ,sin )
k =apply_rope_3d_encoder (k ,cos ,sin )
if causal :
attn_output =F .scaled_dot_product_attention (q ,k ,v ,is_causal =True )
else :
attn_output =F .scaled_dot_product_attention (q ,k ,v )
attn_output =attn_output .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size )
return self .to_out (attn_output )
class VideoEncoderBlock (nn .Module ):
"""Single block with 3D causal attention and temporal MoE FFN."""
def __init__ (
self ,
hidden_size :int ,
num_heads :int =8 ,
num_experts :int =4 ,
max_frames :int =32 ,
max_height :int =64 ,
max_width :int =64 ,
):
super ().__init__ ()
self .attn =Causal3DAttentionEncoder (hidden_size ,num_heads ,max_frames ,max_height ,max_width )
self .moe =TemporalMoELayerEncoder (hidden_size ,hidden_size *4 ,num_experts )
self .norm =nn .LayerNorm (hidden_size )
def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =False )->torch .Tensor :
x =x +self .attn (x ,height ,width ,frames ,causal )
x =self .norm (x +self .moe (x ))
return x
class VideoTiTokTokenizer (nn .Module ):
"""
SOTA TiTok-style 1D Tokenizer for video features with temporal awareness.
This compresses encoded video features (from vision encoder) to a smaller
number of tokens, similar to how TiTokTokenizer works for images but with
proper temporal modeling.
SOTA Features:
- Multi-layer transformer with temporal-aware attention
- 3D positional encoding (spatial + temporal)
- Hierarchical compression: spatial first, then temporal
- Causal temporal attention for streaming compatibility
- Gated cross-attention for selective feature extraction
Note: This is different from VidTokTokenizer which is a 3D VAE for raw video compression.
This tokenizer operates on already-encoded features, not raw pixels.
Converts [B, T*H*W, hidden_size] -> [B, num_tokens, hidden_size]
"""
def __init__ (
self ,
hidden_size :int ,
num_tokens :int =64 ,
num_patches :int =576 ,
max_frames :int =32 ,
num_layers :int =2 ,
num_heads :int =8 ,
dropout :float =0.1 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_tokens =num_tokens
self .num_patches =num_patches
self .max_frames =max_frames
self .num_heads =num_heads
self .patches_per_frame =num_patches //max_frames if max_frames >0 else num_patches
self .spatial_size =int (self .patches_per_frame **0.5 )
self .temporal_pos =nn .Parameter (torch .randn (1 ,max_frames ,1 ,hidden_size )*0.02 )
self .spatial_pos =nn .Parameter (torch .randn (1 ,1 ,self .patches_per_frame ,hidden_size )*0.02 )
self .input_norm =nn .LayerNorm (hidden_size )
self .input_proj =nn .Linear (hidden_size ,hidden_size )
self .num_temporal_tokens =min (num_tokens //4 ,max_frames )
self .num_content_tokens =num_tokens -self .num_temporal_tokens
self .temporal_queries =nn .Parameter (torch .randn (1 ,self .num_temporal_tokens ,hidden_size )*0.02 )
self .content_queries =nn .Parameter (torch .randn (1 ,self .num_content_tokens ,hidden_size )*0.02 )
self .compress_layers =nn .ModuleList ()
for i in range (num_layers ):
self .compress_layers .append (nn .ModuleDict ({
'cross_attn':nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =num_heads ,
batch_first =True ,
dropout =dropout ,
),
'cross_gate':nn .Sequential (
nn .Linear (hidden_size ,hidden_size ),
nn .Sigmoid (),
),
'cross_norm':nn .LayerNorm (hidden_size ),
'self_attn':nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =num_heads ,
batch_first =True ,
dropout =dropout ,
),
'self_norm':nn .LayerNorm (hidden_size ),
'ffn':nn .Sequential (
nn .Linear (hidden_size ,hidden_size *4 ),
nn .GELU (),
nn .Dropout (dropout ),
nn .Linear (hidden_size *4 ,hidden_size ),
nn .Dropout (dropout ),
),
'ffn_norm':nn .LayerNorm (hidden_size ),
}))
self .fusion_attn =nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =num_heads ,
batch_first =True ,
dropout =dropout ,
)
self .fusion_norm =nn .LayerNorm (hidden_size )
self .output_proj =nn .Sequential (
nn .Linear (hidden_size ,hidden_size ),
nn .GELU (),
nn .Linear (hidden_size ,hidden_size ),
)
self .output_norm =nn .LayerNorm (hidden_size )
print (f" 🎬 VideoTiTokTokenizer: {num_patches } patches -> {num_tokens } tokens")
print (f" Temporal tokens: {self .num_temporal_tokens }, Content tokens: {self .num_content_tokens }")
print (f" Layers: {num_layers }, Heads: {num_heads }")
def _load_from_state_dict (self ,state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs ):
"""Production-grade hook to handle dynamic frame counts and token counts when loading checkpoints."""
t_pos_key =prefix +'temporal_pos'
if t_pos_key in state_dict :
ckpt_pos =state_dict [t_pos_key ]
if ckpt_pos .shape !=self .temporal_pos .shape :
print (f" ⚠️ VideoTiTokTokenizer: Interpolating {t_pos_key } from {ckpt_pos .shape [1 ]} to {self .max_frames } frames.")
ckpt_pos =ckpt_pos .squeeze (2 ).transpose (1 ,2 )
resized =F .interpolate (ckpt_pos ,size =self .max_frames ,mode ='linear',align_corners =False )
state_dict [t_pos_key ]=resized .transpose (1 ,2 ).unsqueeze (2 )
t_query_key =prefix +'temporal_queries'
if t_query_key in state_dict :
ckpt_query =state_dict [t_query_key ]
if ckpt_query .shape !=self .temporal_queries .shape :
print (f" ⚠️ VideoTiTokTokenizer: Interpolating {t_query_key } from {ckpt_query .shape [1 ]} to {self .num_temporal_tokens } tokens.")
ckpt_query =ckpt_query .transpose (1 ,2 )
resized =F .interpolate (ckpt_query ,size =self .num_temporal_tokens ,mode ='linear',align_corners =False )
state_dict [t_query_key ]=resized .transpose (1 ,2 )
c_query_key =prefix +'content_queries'
if c_query_key in state_dict :
ckpt_query =state_dict [c_query_key ]
if ckpt_query .shape !=self .content_queries .shape :
print (f" ⚠️ VideoTiTokTokenizer: Interpolating {c_query_key } from {ckpt_query .shape [1 ]} to {self .num_content_tokens } tokens.")
ckpt_query =ckpt_query .transpose (1 ,2 )
resized =F .interpolate (ckpt_query ,size =self .num_content_tokens ,mode ='linear',align_corners =False )
state_dict [c_query_key ]=resized .transpose (1 ,2 )
super ()._load_from_state_dict (state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs )
def _add_3d_pos_encoding (self ,x :torch .Tensor ,num_frames :int ,patches_per_frame :int )->torch .Tensor :
"""Add 3D positional encoding (temporal + spatial)."""
B ,seq_len ,D =x .shape
x =x .reshape (B ,num_frames ,patches_per_frame ,D )
temporal_pos =self .temporal_pos [:,:num_frames ,:,:]
x =x +temporal_pos
spatial_pos =self .spatial_pos [:,:,:patches_per_frame ,:]
x =x +spatial_pos
return x .reshape (B ,seq_len ,D )
def forward(self, x: torch.Tensor, num_frames: Optional[int] = None) -> torch.Tensor:
"""
Compress video patch features to TiTok-style 1D tokens.
Args:
x: [B, T*H*W, hidden_size] video patch features (flattened spatial-temporal)
or [B, T, H*W, hidden_size] video patch features per frame
num_frames: Number of frames (optional, for temporal embedding)
Returns:
[B, num_tokens, hidden_size] compressed token features
"""
batch_size =x .shape [0 ]
if x .dim ()==4 :
B ,T ,HW ,D =x .shape
x =x .reshape (B ,T *HW ,D )
num_frames =T
patches_per_frame =HW
else :
seq_len =x .shape [1 ]
if num_frames is None :
num_frames =min (self .max_frames ,seq_len //self .patches_per_frame )
num_frames =max (1 ,num_frames )
patches_per_frame =seq_len //num_frames if num_frames >0 else seq_len
x =self .input_norm (x )
x =self .input_proj (x )
x =self ._add_3d_pos_encoding (x ,num_frames ,patches_per_frame )
temporal_queries =self .temporal_queries [:,:min (self .num_temporal_tokens ,num_frames ),:].expand (batch_size ,-1 ,-1 )
content_queries =self .content_queries .expand (batch_size ,-1 ,-1 )
queries =torch .cat ([temporal_queries ,content_queries ],dim =1 )
for layer in self .compress_layers :
cross_out ,_ =layer ['cross_attn'](queries ,x ,x )
gate =layer ['cross_gate'](queries )
queries =layer ['cross_norm'](queries +gate *cross_out )
self_out ,_ =layer ['self_attn'](queries ,queries ,queries )
queries =layer ['self_norm'](queries +self_out )
ffn_out =layer ['ffn'](queries )
queries =layer ['ffn_norm'](queries +ffn_out )
actual_temporal =temporal_queries .shape [1 ]
temporal_tokens =queries [:,:actual_temporal ,:]
content_tokens =queries [:,actual_temporal :,:]
fused ,_ =self .fusion_attn (content_tokens ,temporal_tokens ,temporal_tokens )
content_tokens =self .fusion_norm (content_tokens +fused )
tokens =torch .cat ([temporal_tokens ,content_tokens ],dim =1 )
if tokens .shape [1 ]<self .num_tokens :
pad_size =self .num_tokens -tokens .shape [1 ]
pad_tokens =self .content_queries [:,:pad_size ,:].expand (batch_size ,-1 ,-1 )
tokens =torch .cat ([tokens ,pad_tokens ],dim =1 )
elif tokens .shape [1 ]>self .num_tokens :
tokens =tokens [:,:self .num_tokens ,:]
tokens =self .output_proj (tokens )
tokens =self .output_norm (tokens )
return tokens
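# ---------------------------------------------------------------------------
# Hedged usage sketch for VideoTiTokTokenizer (toy sizes, never called at
# import time): 8 frames x 16 spatial patches of encoded features are
# compressed to 64 tokens (8 temporal + 56 content queries).
def _example_video_titok():
    tok = VideoTiTokTokenizer(hidden_size=64, num_tokens=64,
                              num_patches=8 * 16, max_frames=8)
    feats = torch.randn(2, 8 * 16, 64)
    tokens = tok(feats, num_frames=8)
    assert tokens.shape == (2, 64, 64)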
class VideoEncoder (nn .Module ):
"""
SOTA Video Encoder with 3D-RoPE, 3D Causal Attention, Temporal Expert Routing, and VidTokTokenizer.
Features:
- 3D-RoPE for flexible (x, y, t) positional encodings
- 3D Causal Attention for temporal understanding
- Temporal-Aware Expert Routing for motion patterns
- VidTokTokenizer for efficient 1D token compression (mirrors TiTokTokenizer for images)
- Integrated with vision encoder backbone
- FP16-native numerical stability
"""
def __init__ (
self ,
vision_encoder :VisionEncoder ,
max_frames :int =32 ,
num_encoder_layers :int =4 ,
num_experts :int =4 ,
use_3d_rope :bool =True ,
use_temporal_moe :bool =True ,
use_video_tokenizer :bool =True ,
num_video_tokens :int =64 ,
):
super ().__init__ ()
self .vision_encoder =vision_encoder
self .max_frames =max_frames
self .hidden_size =vision_encoder .hidden_size
self .use_3d_rope =use_3d_rope
self .use_temporal_moe =use_temporal_moe
self .use_video_tokenizer =use_video_tokenizer
self .image_size =getattr (vision_encoder ,'image_size',384 )
self .patch_size =getattr (vision_encoder .vision_model .config ,'patch_size',14 )
self .patches_per_side =self .image_size //self .patch_size
self .num_spatial_tokens =self .patches_per_side **2
if use_3d_rope :
self .rope_3d =RoPE3DEncoder (
dim =self .hidden_size ,
max_height =self .patches_per_side ,
max_width =self .patches_per_side ,
max_frames =max_frames ,
)
print (f" 📐 3D-RoPE: (x,y,t) position encoding")
else :
self .rope_3d =None
self .encoder_blocks =nn .ModuleList ([
VideoEncoderBlock (
hidden_size =self .hidden_size ,
num_heads =8 ,
num_experts =num_experts if use_temporal_moe else 1 ,
max_frames =max_frames ,
max_height =self .patches_per_side ,
max_width =self .patches_per_side ,
)
for _ in range (num_encoder_layers )
])
print (f" 🎬 3D Causal Transformer: {num_encoder_layers } layers")
if use_temporal_moe :
print (f" 🎯 Temporal MoE: {num_experts } experts per layer")
if use_video_tokenizer :
self .vidtok =VideoTiTokTokenizer (
hidden_size =self .hidden_size ,
num_tokens =num_video_tokens ,
num_patches =self .num_spatial_tokens *max_frames ,
max_frames =max_frames ,
)
self .video_tokenizer =self .vidtok
else :
self .vidtok =None
self .video_tokenizer =None
self .temporal_pool_query =nn .Parameter (torch .randn (1 ,1 ,self .hidden_size )*0.02 )
self .temporal_pool_attn =nn .MultiheadAttention (
embed_dim =self .hidden_size ,
num_heads =8 ,
batch_first =True ,
dropout =0.1 ,
)
self .temporal_pool_norm =nn .LayerNorm (self .hidden_size )
self .frame_pos_embed =nn .Parameter (torch .randn (1 ,max_frames ,self .hidden_size )*0.02 )
print (f" 🎬 Video encoder: max {max_frames } frames (multi-scale enabled)")
def _load_from_state_dict (self ,state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs ):
"""Production-grade hook to handle dynamic frame counts when loading checkpoints.
Interpolates temporal embeddings if the checkpoint frames differ from max_frames.
"""
embed_key =prefix +'frame_pos_embed'
if embed_key in state_dict :
ckpt_embed =state_dict [embed_key ]
if ckpt_embed .shape !=self .frame_pos_embed .shape :
print (f" ⚠️ VideoEncoder: Interpolating {embed_key } from {ckpt_embed .shape [1 ]} to {self .max_frames } frames.")
ckpt_embed =ckpt_embed .transpose (1 ,2 )
resized =F .interpolate (ckpt_embed ,size =self .max_frames ,mode ='linear',align_corners =False )
state_dict [embed_key ]=resized .transpose (1 ,2 )
super ()._load_from_state_dict (state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs )
def _extract_frame_features(self, frames: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
"""Extract per-frame features with the vision encoder; returns (features, batch_size, num_frames)."""
batch_size ,num_frames =frames .shape [:2 ]
frames_flat =frames .view (-1 ,*frames .shape [2 :])
if frames_flat .shape [-1 ]!=self .image_size or frames_flat .shape [-2 ]!=self .image_size :
frames_flat =F .interpolate (
frames_flat ,
size =(self .image_size ,self .image_size ),
mode ='bilinear',
align_corners =False
)
if not any (p .requires_grad for p in self .vision_encoder .parameters ()):
with torch .no_grad ():
frame_features =self .vision_encoder (frames_flat ,return_titok =False )
else :
frame_features =self .vision_encoder (frames_flat ,return_titok =False )
return frame_features ,batch_size ,num_frames
def forward (
self ,
frames :torch .Tensor ,
return_all_frames :bool =False ,
causal :bool =False ,
return_tokens :bool =False ,
)->torch .Tensor :
"""
Process video frames with 3D-RoPE and Causal Attention.
Args:
frames: [B, T, C, H, W] tensor of video frames
return_all_frames: If True, return all frame features; else return pooled
causal: If True, use causal attention (for autoregressive)
return_tokens: If True, return VideoTokenizer compressed tokens
Returns:
If return_tokens: [B, num_tokens, hidden_size] video tokens
If return_all_frames: [B, T, hidden_size] per-frame features
Else: [B, hidden_size] pooled video representation
"""
frame_features ,batch_size ,num_frames =self ._extract_frame_features (frames )
_ ,num_patches ,hidden_size =frame_features .shape
height =width =int (math .sqrt (num_patches ))
frame_features =frame_features .view (batch_size ,num_frames ,num_patches ,hidden_size )
frame_features =frame_features +self .frame_pos_embed [:,:num_frames ].unsqueeze (2 )
x =frame_features .view (batch_size ,num_frames *num_patches ,hidden_size )
for block in self .encoder_blocks :
x =block (x ,height ,width ,num_frames ,causal =causal )
if return_tokens and self .vidtok is not None :
return self .vidtok (x ,num_frames )
if return_all_frames :
x =x .view (batch_size ,num_frames ,num_patches ,hidden_size )
return x .mean (dim =2 )
else :
query =self .temporal_pool_query .expand (batch_size ,-1 ,-1 )
pooled ,_ =self .temporal_pool_attn (query ,x ,x )
pooled =self .temporal_pool_norm (query +pooled )
return pooled .squeeze (1 )
def encode_frames_separately (self ,frames :torch .Tensor )->torch .Tensor :
"""
Encode frames without temporal attention (for generation conditioning).
Args:
frames: [B, T, C, H, W] tensor of video frames
Returns:
[B, T, hidden_size] tensor of frame features
"""
frame_features ,batch_size ,num_frames =self ._extract_frame_features (frames )
frame_features =frame_features .mean (dim =1 )
return frame_features .view (batch_size ,num_frames ,-1 )
def encode_with_spatial (self ,frames :torch .Tensor )->torch .Tensor :
"""
Encode frames preserving spatial structure (for video generation).
Args:
frames: [B, T, C, H, W] tensor of video frames
Returns:
[B, T, H, W, hidden_size] tensor of spatio-temporal features
"""
frame_features ,batch_size ,num_frames =self ._extract_frame_features (frames )
_ ,num_patches ,hidden_size =frame_features .shape
height =width =int (math .sqrt (num_patches ))
frame_features =frame_features .view (batch_size ,num_frames ,height ,width ,hidden_size )
return frame_features
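# Usage sketch (illustrative only, not executed at import): encoding a short clip
# with VideoEncoder. The `vision_encoder` instance and all shapes below are
# assumptions for the example, not values taken from a real checkpoint.
#
#   video_encoder = VideoEncoder(vision_encoder, max_frames=16, num_video_tokens=64)
#   frames = torch.randn(2, 16, 3, 384, 384)                    # [B, T, C, H, W]
#   pooled = video_encoder(frames)                              # [B, hidden_size]
#   per_frame = video_encoder(frames, return_all_frames=True)   # [B, T, hidden_size]
#   tokens = video_encoder(frames, return_tokens=True)          # [B, 64, hidden_size]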
# ==============================================================================
# MODELS.ENCODERS.AUDIO
# ==============================================================================
EPS =1e-5
class RawWaveformTokenizer (nn .Module ):
"""
Raw Waveform Tokenizer - directly tokenizes audio waveforms without mel spectrograms.
Uses multi-scale 1D convolutions to extract features at different temporal resolutions,
then combines them into a unified representation.
"""
def __init__ (
self ,
hidden_size :int =1024 ,
num_codebooks :int =8 ,
codebook_size :int =1024 ,
sample_rate :int =16000 ,
hop_length :int =320 ,
num_conv_layers :int =6 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_codebooks =num_codebooks
self .codebook_size =codebook_size
self .sample_rate =sample_rate
self .hop_length =hop_length
self .conv_layers =nn .ModuleList ()
in_channels =1
channels =[32 ,64 ,128 ,256 ,512 ,hidden_size ]
kernel_sizes =[7 ,5 ,5 ,3 ,3 ,3 ]
strides =[2 ,2 ,2 ,2 ,2 ,2 ]
for i in range (num_conv_layers ):
out_channels =channels [i ]if i <len (channels )else hidden_size
kernel_size =kernel_sizes [i ]if i <len (kernel_sizes )else 3
stride =strides [i ]if i <len (strides )else 2
self .conv_layers .append (nn .Sequential (
nn .Conv1d (in_channels ,out_channels ,kernel_size ,stride ,kernel_size //2 ),
nn .GroupNorm (8 if out_channels >=8 else 1 ,out_channels ),
nn .SiLU (),
))
in_channels =out_channels
self .codebooks =nn .ModuleList ([
nn .Embedding (codebook_size ,hidden_size )
for _ in range (num_codebooks )
])
self .commitment_weight =0.25
self .output_proj =nn .Linear (hidden_size ,hidden_size )
print (f" 🎵 RawWaveformTokenizer: {num_codebooks } codebooks x {codebook_size } codes")
def encode (self ,waveform :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]:
"""
Encode waveform to continuous features.
Args:
waveform: [B, T] or [B, 1, T] raw audio waveform
Returns:
features: [B, T', hidden_size] encoded features
indices: always None here; call quantize() on the returned features to obtain codebook indices
"""
if waveform .dim ()==2 :
waveform =waveform .unsqueeze (1 )
x =waveform
for conv in self .conv_layers :
x =conv (x )
x =x .transpose (1 ,2 )
return x ,None
def quantize (self ,features :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]:
"""
Residual Vector Quantization.
Args:
features: [B, T, hidden_size] continuous features
Returns:
quantized: [B, T, hidden_size] quantized features
indices: [B, T, num_codebooks] codebook indices
commitment_loss: scalar commitment loss
"""
batch_size ,seq_len ,_ =features .shape
residual =features
quantized =torch .zeros_like (features )
all_indices =[]
total_commitment_loss =0.0
for codebook in self .codebooks :
distances =torch .cdist (residual ,codebook .weight )
indices =distances .argmin (dim =-1 )
all_indices .append (indices )
quantized_step =codebook (indices )
quantized =quantized +residual +(quantized_step -residual ).detach ()
commitment_loss =F .mse_loss (residual .detach (),quantized_step )
total_commitment_loss =total_commitment_loss +commitment_loss
residual =residual -quantized_step .detach ()
indices =torch .stack (all_indices ,dim =-1 )
commitment_loss =total_commitment_loss *self .commitment_weight
return quantized ,indices ,commitment_loss
def forward (self ,waveform :torch .Tensor ,quantize :bool =False )->Tuple [torch .Tensor ,Optional [torch .Tensor ]]:
"""
Forward pass.
Args:
waveform: [B, T] or [B, 1, T] raw audio
quantize: Whether to apply vector quantization
Returns:
features: [B, T', hidden_size] encoded features
commitment_loss: Optional commitment loss if quantize=True
"""
features ,_ =self .encode (waveform )
if quantize :
features ,indices ,commitment_loss =self .quantize (features )
features =self .output_proj (features )
return features ,commitment_loss
features =self .output_proj (features )
return features ,None
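# Usage sketch (illustrative only): with the default six stride-2 conv layers the
# tokenizer downsamples the waveform by 2**6 = 64, so 1 s of 16 kHz audio
# (16000 samples) maps to roughly 250 feature frames. The shapes below are assumptions.
#
#   tok = RawWaveformTokenizer(hidden_size=1024)
#   wav = torch.randn(2, 16000)                          # [B, T] raw audio
#   feats, _ = tok(wav)                                  # [B, ~250, 1024]
#   q, idx, c_loss = tok.quantize(feats)                 # RVQ: idx is [B, ~250, 8]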
class SnakeActivation (nn .Module ):
"""
Snake activation function from BigVGAN.
x + (1/a) * sin^2(a * x)
Better than ReLU/SiLU for audio generation - preserves periodicity.
"""
def __init__ (self ,channels :int ,alpha :float =1.0 ):
super ().__init__ ()
self .alpha =nn .Parameter (torch .ones (1 ,channels ,1 )*alpha )
def forward (self ,x :torch .Tensor )->torch .Tensor :
return x +(1.0 /(self .alpha +1e-6 ))*torch .sin (self .alpha *x )**2
class ResidualBlock1D (nn .Module ):
"""Residual block with dilated convolutions for multi-receptive field."""
def __init__ (self ,channels :int ,kernel_size :int =3 ,dilation :int =1 ):
super ().__init__ ()
padding =(kernel_size *dilation -dilation )//2
self .conv1 =nn .utils .parametrizations .weight_norm (
nn .Conv1d (channels ,channels ,kernel_size ,padding =padding ,dilation =dilation )
)
self .conv2 =nn .utils .parametrizations .weight_norm (
nn .Conv1d (channels ,channels ,kernel_size ,padding =kernel_size //2 )
)
self .activation =SnakeActivation (channels )
def forward (self ,x :torch .Tensor )->torch .Tensor :
residual =x
x =self .activation (self .conv1 (x ))
x =self .activation (self .conv2 (x ))
return x +residual
class MultiReceptiveFieldFusion (nn .Module ):
"""
Multi-Receptive Field Fusion (MRF) from HiFi-GAN.
Processes input through multiple parallel residual stacks with different
kernel sizes and dilations, then sums results.
"""
def __init__ (self ,channels :int ,kernel_sizes :List [int ]=[3 ,7 ,11 ],
dilations :List [List [int ]]=[[1 ,3 ,5 ],[1 ,3 ,5 ],[1 ,3 ,5 ]]):
super ().__init__ ()
self .num_kernels =len (kernel_sizes )
self .resblocks =nn .ModuleList ()
for k ,d_list in zip (kernel_sizes ,dilations ):
blocks =nn .ModuleList ([
ResidualBlock1D (channels ,k ,d )for d in d_list
])
self .resblocks .append (blocks )
def forward (self ,x :torch .Tensor )->torch .Tensor :
out =None
for blocks in self .resblocks :
h =x
for block in blocks :
h =block (h )
out =h if out is None else out +h
return out /self .num_kernels
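# Usage sketch (illustrative only): the MRF block preserves both channel count and
# temporal length; the kernel sizes/dilations defaulted above follow HiFi-GAN.
#
#   mrf = MultiReceptiveFieldFusion(channels=256)
#   y = mrf(torch.randn(1, 256, 400))      # -> [1, 256, 400], averaged over 3 stacks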
class RawWaveformDecoder (nn .Module ):
"""
SOTA Raw Waveform Decoder - BigVGAN/HiFi-GAN style architecture.
Converts features directly to playable audio waveform without external vocoder.
SOTA Features:
- Snake activation (BigVGAN) - preserves audio periodicity
- Multi-Receptive Field Fusion (HiFi-GAN) - captures patterns at multiple scales
- Weight normalization - stable training
- Efficient upsampling with careful kernel/stride ratios
- Anti-aliased resampling
- Streaming-capable architecture
Speed optimizations:
- Fewer layers with smarter architecture
- Fused operations where possible
- Efficient 256x total upsampling (vs 64x before)
"""
def __init__ (
self ,
hidden_size :int =1024 ,
sample_rate :int =16000 ,
upsample_rates :List [int ]=[8 ,8 ,2 ,2 ],
upsample_kernel_sizes :List [int ]=[16 ,16 ,4 ,4 ],
resblock_kernel_sizes :List [int ]=[3 ,7 ,11 ],
resblock_dilations :List [List [int ]]=[[1 ,3 ,5 ],[1 ,3 ,5 ],[1 ,3 ,5 ]],
initial_channels :int =512 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .sample_rate =sample_rate
self .num_upsamples =len (upsample_rates )
self .input_proj =nn .utils .parametrizations .weight_norm (
nn .Conv1d (hidden_size ,initial_channels ,kernel_size =7 ,padding =3 )
)
self .upsamplers =nn .ModuleList ()
self .mrf_blocks =nn .ModuleList ()
channels =initial_channels
for i ,(rate ,kernel )in enumerate (zip (upsample_rates ,upsample_kernel_sizes )):
self .upsamplers .append (
nn .utils .parametrizations .weight_norm (
nn .ConvTranspose1d (
channels ,channels //2 ,
kernel_size =kernel ,stride =rate ,
padding =(kernel -rate )//2
)
)
)
channels =channels //2
self .mrf_blocks .append (
MultiReceptiveFieldFusion (channels ,resblock_kernel_sizes ,resblock_dilations )
)
self .final_activation =SnakeActivation (channels )
self .output_conv =nn .utils .parametrizations .weight_norm (
nn .Conv1d (channels ,1 ,kernel_size =7 ,padding =3 )
)
self .upsample_factor =1
for rate in upsample_rates :
self .upsample_factor *=rate
print (f" 🔊 RawWaveformDecoder (SOTA BigVGAN-style):")
print (f" - Snake activation for audio periodicity")
print (f" - Multi-Receptive Field Fusion")
print (f" - {self .upsample_factor }x upsampling")
print (f" - Weight normalized layers")
def forward (
self ,
features :torch .Tensor ,
target_length :Optional [int ]=None ,
)->torch .Tensor :
"""
Decode features to raw waveform.
Args:
features: [B, T, hidden_size] encoded features
target_length: Optional target waveform length (for matching input length)
Returns:
waveform: [B, T_audio] raw audio waveform in [-1, 1]
"""
x =features .transpose (1 ,2 )
x =self .input_proj (x )
for upsample ,mrf in zip (self .upsamplers ,self .mrf_blocks ):
x =upsample (x )
x =mrf (x )
x =self .final_activation (x )
waveform =self .output_conv (x )
waveform =torch .tanh (waveform )
waveform =waveform .squeeze (1 )
if target_length is not None and waveform .shape [-1 ]!=target_length :
waveform =F .interpolate (
waveform .unsqueeze (1 ),
size =target_length ,
mode ='linear',
align_corners =False
).squeeze (1 )
return waveform
def decode_from_codes (
self ,
codes :torch .Tensor ,
codebooks :nn .ModuleList ,
target_length :Optional [int ]=None ,
)->torch .Tensor :
"""
Decode directly from codebook indices.
Args:
codes: [B, T, num_codebooks] codebook indices
codebooks: List of nn.Embedding codebooks from encoder
target_length: Optional target waveform length
Returns:
waveform: [B, T_audio] raw audio waveform
"""
features =torch .zeros (
codes .shape [0 ],codes .shape [1 ],codebooks [0 ].embedding_dim ,
device =codes .device ,dtype =codebooks [0 ].weight .dtype
)
for i ,codebook in enumerate (codebooks ):
features =features +codebook (codes [:,:,i ])
return self .forward (features ,target_length )
@torch .no_grad ()
def stream_decode (
self ,
features :torch .Tensor ,
chunk_size :int =10 ,
)->torch .Tensor :
"""
Streaming decode for real-time speech synthesis.
Processes features in chunks for low-latency output.
Args:
features: [B, T, hidden_size] encoded features
chunk_size: Number of feature frames per chunk
Returns:
waveform: [B, T_audio] full waveform, decoded chunk by chunk and concatenated
"""
batch_size ,seq_len ,_ =features .shape
audio_chunks =[]
for start in range (0 ,seq_len ,chunk_size ):
end =min (start +chunk_size ,seq_len )
chunk =features [:,start :end ,:]
audio_chunk =self .forward (chunk )
audio_chunks .append (audio_chunk )
return torch .cat (audio_chunks ,dim =-1 )
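# Usage sketch (illustrative only): with the default upsample_rates [8, 8, 2, 2]
# each feature frame expands into 8*8*2*2 = 256 audio samples. The feature tensor
# below is a hypothetical encoder output, not part of any real pipeline.
#
#   dec = RawWaveformDecoder(hidden_size=1024, sample_rate=16000)
#   feats = torch.randn(1, 100, 1024)                    # [B, T, hidden_size]
#   wav = dec(feats)                                      # [1, 25600] samples in [-1, 1]
#   wav_stream = dec.stream_decode(feats, chunk_size=10)  # chunked decode, then concatenated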
class SpeakerEncoder (nn .Module ):
"""
Zero-Shot Speaker Encoder for speaker cloning.
Extracts speaker embeddings from reference audio that can be used
to clone the speaker's voice characteristics.
"""
def __init__ (
self ,
hidden_size :int =256 ,
output_size :int =256 ,
num_layers :int =3 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .output_size =output_size
self .frame_encoder =nn .Sequential (
nn .Conv1d (80 ,hidden_size ,5 ,1 ,2 ),
nn .ReLU (),
nn .GroupNorm (1 ,hidden_size ),
nn .Conv1d (hidden_size ,hidden_size ,5 ,1 ,2 ),
nn .ReLU (),
nn .GroupNorm (1 ,hidden_size ),
nn .Conv1d (hidden_size ,hidden_size ,5 ,1 ,2 ),
nn .ReLU (),
nn .GroupNorm (1 ,hidden_size ),
)
self .lstm =nn .LSTM (
hidden_size ,hidden_size ,
num_layers =num_layers ,
batch_first =True ,
bidirectional =True ,
)
self .attention =nn .Sequential (
nn .Linear (hidden_size *2 ,hidden_size ),
nn .Tanh (),
nn .Linear (hidden_size ,1 ),
)
self .output_proj =nn .Linear (hidden_size *2 ,output_size )
print (f" 👤 SpeakerEncoder: {hidden_size }d -> {output_size }d speaker embedding")
def forward (self ,mel_spectrogram :torch .Tensor )->torch .Tensor :
"""
Extract speaker embedding from mel spectrogram.
Args:
mel_spectrogram: [B, n_mels, T] mel spectrogram
Returns:
speaker_embedding: [B, output_size] speaker embedding
"""
x =self .frame_encoder (mel_spectrogram )
x =x .transpose (1 ,2 )
x ,_ =self .lstm (x )
attn_weights =self .attention (x )
attn_weights =F .softmax (attn_weights ,dim =1 )
x =(x *attn_weights ).sum (dim =1 )
speaker_embedding =self .output_proj (x )
speaker_embedding =F .normalize (speaker_embedding ,p =2 ,dim =-1 )
return speaker_embedding
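# Usage sketch (illustrative only): the encoder expects an 80-bin mel spectrogram
# and returns an L2-normalized speaker embedding. Shapes are assumptions.
#
#   spk_enc = SpeakerEncoder(hidden_size=256, output_size=256)
#   ref_mel = torch.randn(2, 80, 300)                    # [B, n_mels, T_ref]
#   spk = spk_enc(ref_mel)                               # [2, 256], unit norm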
class MonotonicAlignmentSearch (nn .Module ):
"""
Monotonic Alignment Search (MAS) for text-to-audio alignment.
Implements both:
1. Hard MAS for inference (dynamic programming)
2. Soft/Fluid MAS for training (differentiable)
"""
def __init__ (self ,hidden_size :int =1024 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .alignment_proj =nn .Sequential (
nn .Linear (hidden_size *2 ,hidden_size ),
nn .ReLU (),
nn .Linear (hidden_size ,1 ),
)
self .duration_predictor =nn .Sequential (
nn .Conv1d (hidden_size ,hidden_size ,3 ,padding =1 ),
nn .ReLU (),
nn .GroupNorm (1 ,hidden_size ),
nn .Conv1d (hidden_size ,hidden_size ,3 ,padding =1 ),
nn .ReLU (),
nn .GroupNorm (1 ,hidden_size ),
nn .Conv1d (hidden_size ,1 ,1 ),
)
@staticmethod
def hard_mas (log_probs :torch .Tensor )->torch .Tensor :
"""
Hard Monotonic Alignment Search using dynamic programming.
Args:
log_probs: [B, T_text, T_audio] log alignment probabilities
Returns:
alignment: [B, T_text, T_audio] hard alignment matrix
"""
batch_size ,text_len ,audio_len =log_probs .shape
device =log_probs .device
Q =torch .full ((batch_size ,text_len ,audio_len ),float ('-inf'),device =device )
Q [:,0 ,0 ]=log_probs [:,0 ,0 ]
for j in range (1 ,audio_len ):
Q [:,0 ,j ]=Q [:,0 ,j -1 ]+log_probs [:,0 ,j ]
for i in range (1 ,text_len ):
Q [:,i ,i ]=Q [:,i -1 ,i -1 ]+log_probs [:,i ,i ]
for j in range (i +1 ,audio_len ):
Q [:,i ,j ]=torch .max (Q [:,i -1 ,j -1 ],Q [:,i ,j -1 ])+log_probs [:,i ,j ]
alignment =torch .zeros_like (log_probs )
for b in range (batch_size ):
i ,j =text_len -1 ,audio_len -1
while i >=0 and j >=0 :
alignment [b ,i ,j ]=1
if i ==0 :
j -=1
elif j ==0 :
i -=1
elif Q [b ,i -1 ,j -1 ]>=Q [b ,i ,j -1 ]:
i -=1
j -=1
else :
j -=1
return alignment
def soft_mas (
self ,
text_hidden :torch .Tensor ,
audio_hidden :torch .Tensor ,
temperature :float =1.0 ,
)->torch .Tensor :
"""
Soft/Differentiable Monotonic Alignment Search.
Args:
text_hidden: [B, T_text, hidden_size] text features
audio_hidden: [B, T_audio, hidden_size] audio features
temperature: Softmax temperature
Returns:
soft_alignment: [B, T_text, T_audio] soft alignment matrix
"""
batch_size ,text_len ,_ =text_hidden .shape
audio_len =audio_hidden .shape [1 ]
text_expanded =text_hidden .unsqueeze (2 ).expand (-1 ,-1 ,audio_len ,-1 )
audio_expanded =audio_hidden .unsqueeze (1 ).expand (-1 ,text_len ,-1 ,-1 )
combined =torch .cat ([text_expanded ,audio_expanded ],dim =-1 )
logits =self .alignment_proj (combined ).squeeze (-1 )
logits =logits /temperature
position_bias =torch .arange (audio_len ,device =logits .device ).float ()
position_bias =position_bias .unsqueeze (0 ).unsqueeze (0 )
text_positions =torch .arange (text_len ,device =logits .device ).float ()
text_positions =text_positions .unsqueeze (0 ).unsqueeze (2 )
expected_pos =text_positions *(audio_len /text_len )
monotonic_bias =-0.1 *(position_bias -expected_pos ).abs ()
logits =logits +monotonic_bias
soft_alignment =F .softmax (logits ,dim =-1 )
return soft_alignment
def predict_durations (self ,text_hidden :torch .Tensor )->torch .Tensor :
"""
Predict durations for each text token.
Args:
text_hidden: [B, T_text, hidden_size] text features
Returns:
durations: [B, T_text] predicted durations
"""
x =text_hidden .transpose (1 ,2 )
durations =self .duration_predictor (x ).squeeze (1 )
durations =F .softplus (durations )
return durations
def forward (
self ,
text_hidden :torch .Tensor ,
audio_hidden :Optional [torch .Tensor ]=None ,
use_hard :bool =False ,
)->Tuple [torch .Tensor ,torch .Tensor ]:
"""
Compute alignment and durations.
Args:
text_hidden: [B, T_text, hidden_size] text features
audio_hidden: [B, T_audio, hidden_size] audio features (optional for inference)
use_hard: Use hard MAS instead of soft
Returns:
alignment: [B, T_text, T_audio] alignment matrix
durations: [B, T_text] predicted durations
"""
durations =self .predict_durations (text_hidden )
if audio_hidden is None :
return None ,durations
if use_hard :
text_norm =F .normalize (text_hidden ,dim =-1 )
audio_norm =F .normalize (audio_hidden ,dim =-1 )
log_probs =torch .bmm (text_norm ,audio_norm .transpose (1 ,2 ))
alignment =self .hard_mas (log_probs )
else :
alignment =self .soft_mas (text_hidden ,audio_hidden )
return alignment ,durations
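# Usage sketch (illustrative only): hard MAS on a toy pair of sequences. The
# text/audio lengths are arbitrary; during training the module would be called
# with use_hard=False to get the differentiable soft alignment instead.
#
#   mas = MonotonicAlignmentSearch(hidden_size=1024)
#   text_h = torch.randn(1, 5, 1024)                     # [B, T_text, H]
#   audio_h = torch.randn(1, 20, 1024)                   # [B, T_audio, H]
#   align, durs = mas(text_h, audio_h, use_hard=True)    # align: [1, 5, 20], durs: [1, 5]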
class RotaryMultiHeadLatentAttention (nn .Module ):
"""
Rotary Multi-Head Latent Attention (RMLA).
Combines:
- Multi-Head Latent Attention (MLA) for compressed KV cache
- Rotary Position Embeddings (RoPE) for position awareness
- Efficient attention computation
"""
def __init__ (
self ,
hidden_size :int =1024 ,
num_heads :int =16 ,
num_kv_heads :int =4 ,
head_dim :int =64 ,
kv_lora_rank :int =256 ,
max_position_embeddings :int =8192 ,
dropout :float =0.1 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_heads =num_heads
self .num_kv_heads =num_kv_heads
self .head_dim =head_dim
self .kv_lora_rank =kv_lora_rank
self .num_key_value_groups =num_heads //num_kv_heads
self .q_proj =nn .Linear (hidden_size ,num_heads *head_dim ,bias =False )
self .kv_a_proj =nn .Linear (hidden_size ,kv_lora_rank +head_dim ,bias =False )
self .kv_b_proj =nn .Linear (kv_lora_rank ,num_kv_heads *head_dim *2 ,bias =False )
self .kv_norm =nn .LayerNorm (kv_lora_rank )
self .o_proj =nn .Linear (num_heads *head_dim ,hidden_size ,bias =False )
self .rotary_emb =self ._create_rotary_embedding (head_dim ,max_position_embeddings )
self .dropout =nn .Dropout (dropout )
self .scale =head_dim **-0.5
def _create_rotary_embedding(self, dim: int, max_seq_len: int) -> None:
"""Precompute rotary cos/sin caches as registered buffers; returns None (no submodule is created)."""
inv_freq =1.0 /(10000 **(torch .arange (0 ,dim ,2 ).float ()/dim ))
self .register_buffer ('inv_freq',inv_freq )
t =torch .arange (max_seq_len ).float ()
freqs =torch .einsum ('i,j->ij',t ,inv_freq )
emb =torch .cat ([freqs ,freqs ],dim =-1 )
self .register_buffer ('cos_cached',emb .cos ())
self .register_buffer ('sin_cached',emb .sin ())
return None
def _apply_rotary (self ,x :torch .Tensor ,seq_len :int )->torch .Tensor :
"""Apply rotary position embeddings."""
cos =self .cos_cached [:seq_len ].unsqueeze (0 ).unsqueeze (0 )
sin =self .sin_cached [:seq_len ].unsqueeze (0 ).unsqueeze (0 )
x1 ,x2 =x [...,:x .shape [-1 ]//2 ],x [...,x .shape [-1 ]//2 :]
rotated =torch .cat ([-x2 ,x1 ],dim =-1 )
return x *cos .to (x .dtype )+rotated *sin .to (x .dtype )
def forward (
self ,
hidden_states :torch .Tensor ,
attention_mask :Optional [torch .Tensor ]=None ,
past_key_value :Optional [Tuple [torch .Tensor ,torch .Tensor ]]=None ,
use_cache :bool =False ,
)->Tuple [torch .Tensor ,Optional [Tuple [torch .Tensor ,torch .Tensor ]]]:
"""
Forward pass with RMLA.
Args:
hidden_states: [B, T, hidden_size]
attention_mask: Optional attention mask
past_key_value: Optional cached KV states
use_cache: Whether to return updated cache
Returns:
output: [B, T, hidden_size]
present_key_value: Optional updated cache
"""
batch_size ,seq_len ,_ =hidden_states .shape
query =self .q_proj (hidden_states )
query =query .view (batch_size ,seq_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 )
kv_compressed =self .kv_a_proj (hidden_states )
kv_latent ,k_pe =kv_compressed .split ([self .kv_lora_rank ,self .head_dim ],dim =-1 )
kv_latent =self .kv_norm (kv_latent )
kv =self .kv_b_proj (kv_latent )
key ,value =kv .split (self .num_kv_heads *self .head_dim ,dim =-1 )
key =key .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 )
value =value .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 )
query =self ._apply_rotary (query ,seq_len )
key =self ._apply_rotary (key ,seq_len )
if past_key_value is not None :
past_key ,past_value =past_key_value
key =torch .cat ([past_key ,key ],dim =2 )
value =torch .cat ([past_value ,value ],dim =2 )
present_key_value =(key ,value )if use_cache else None
qk_scale =self .head_dim **-0.25
kv_len =key .shape [2 ]
use_causal =(attention_mask is None and seq_len >1 and seq_len ==kv_len )
dropout_p =self .dropout .p if self .training else 0.0
output =F .scaled_dot_product_attention (
query *qk_scale ,
key *qk_scale ,
value ,
attn_mask =attention_mask ,
is_causal =use_causal ,
dropout_p =dropout_p ,
scale =1.0 ,
enable_gqa =(self .num_key_value_groups >1 ),
)
output =output .transpose (1 ,2 ).contiguous ().view (batch_size ,seq_len ,-1 )
output =self .o_proj (output )
return output ,present_key_value
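# Usage sketch (illustrative only): the KV projection first compresses the hidden
# state to kv_lora_rank (256 by default) before expanding to per-head keys/values,
# which is what shrinks the KV cache relative to standard multi-head attention.
# Note the SDPA call above passes enable_gqa, which assumes a recent PyTorch release.
#
#   attn = RotaryMultiHeadLatentAttention(hidden_size=1024, num_heads=16, num_kv_heads=4)
#   x = torch.randn(2, 64, 1024)
#   out, kv_cache = attn(x, use_cache=True)              # out: [2, 64, 1024], kv_cache: (key, value)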
class InContextAudioPrompting (nn .Module ):
"""
In-Context Audio Prompting for conditioning generation on reference audio.
Allows the model to use a reference audio clip to guide the style,
speaker characteristics, and prosody of generated audio.
"""
def __init__ (
self ,
hidden_size :int =1024 ,
num_prompt_tokens :int =32 ,
num_heads :int =8 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_prompt_tokens =num_prompt_tokens
self .prompt_tokens =nn .Parameter (torch .randn (1 ,num_prompt_tokens ,hidden_size )*0.02 )
self .cross_attn =nn .MultiheadAttention (
hidden_size ,num_heads ,
dropout =0.1 ,
batch_first =True ,
)
self .prompt_encoder =nn .Sequential (
nn .Linear (hidden_size ,hidden_size ),
nn .SiLU (),
nn .Linear (hidden_size ,hidden_size ),
)
self .gate =nn .Parameter (torch .zeros (1 ))
self .norm =nn .LayerNorm (hidden_size )
def encode_prompt (self ,audio_features :torch .Tensor )->torch .Tensor :
"""
Encode reference audio into prompt tokens.
Args:
audio_features: [B, T, hidden_size] reference audio features
Returns:
prompt: [B, num_prompt_tokens, hidden_size] encoded prompt
"""
batch_size =audio_features .shape [0 ]
prompt =self .prompt_tokens .expand (batch_size ,-1 ,-1 )
prompt ,_ =self .cross_attn (prompt ,audio_features ,audio_features )
prompt =self .prompt_encoder (prompt )
return prompt
def forward (
self ,
hidden_states :torch .Tensor ,
prompt_features :Optional [torch .Tensor ]=None ,
audio_prompt :Optional [torch .Tensor ]=None ,
)->torch .Tensor :
"""
Apply in-context audio prompting.
Args:
hidden_states: [B, T, hidden_size] input features
prompt_features: [B, num_prompt_tokens, hidden_size] pre-encoded prompt
audio_prompt: [B, T_prompt, hidden_size] raw audio features to encode
Returns:
output: [B, T, hidden_size] conditioned features
"""
if prompt_features is None and audio_prompt is not None :
prompt_features =self .encode_prompt (audio_prompt )
if prompt_features is None :
return hidden_states
attended ,_ =self .cross_attn (hidden_states ,prompt_features ,prompt_features )
gate =torch .sigmoid (self .gate )
output =hidden_states +gate *attended
output =self .norm (output )
return output
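# Usage sketch (illustrative only): conditioning a hidden sequence on a reference
# clip. A learned sigmoid gate mixes the prompt-attended features back into the
# input before the final LayerNorm. Shapes are assumptions.
#
#   prompting = InContextAudioPrompting(hidden_size=1024, num_prompt_tokens=32)
#   hidden = torch.randn(2, 120, 1024)
#   ref = torch.randn(2, 200, 1024)                      # reference audio features
#   conditioned = prompting(hidden, audio_prompt=ref)    # [2, 120, 1024]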
class ConvolutionModule (nn .Module ):
"""Conformer convolution module with gating."""
def __init__ (self ,channels :int ,kernel_size :int =31 ,dropout :float =0.1 ):
super ().__init__ ()
self .layer_norm =nn .LayerNorm (channels )
self .pointwise_conv1 =nn .Conv1d (channels ,2 *channels ,kernel_size =1 )
self .depthwise_conv =nn .Conv1d (
channels ,channels ,kernel_size =kernel_size ,
padding =(kernel_size -1 )//2 ,groups =channels
)
self .batch_norm =nn .GroupNorm (1 ,channels )
self .pointwise_conv2 =nn .Conv1d (channels ,channels ,kernel_size =1 )
self .dropout =nn .Dropout (dropout )
def forward (self ,x :torch .Tensor )->torch .Tensor :
"""x: [B, T, C]"""
x =self .layer_norm (x )
x =x .transpose (1 ,2 )
x =self .pointwise_conv1 (x )
x =F .glu (x ,dim =1 )
x =self .depthwise_conv (x )
x =self .batch_norm (x )
x =F .silu (x )
x =self .pointwise_conv2 (x )
x =self .dropout (x )
return x .transpose (1 ,2 )
class ConformerBlock (nn .Module ):
"""Single Conformer block with RMLA, feed-forward, and convolution."""
def __init__ (
self ,
d_model :int ,
num_heads :int =8 ,
ff_expansion :int =4 ,
conv_kernel_size :int =31 ,
dropout :float =0.1 ,
use_rmla :bool =True ,
):
super ().__init__ ()
self .use_rmla =use_rmla
self .ff1_norm =nn .LayerNorm (d_model )
self .ff1 =nn .Sequential (
nn .Linear (d_model ,d_model *ff_expansion ),
nn .SiLU (),
nn .Dropout (dropout ),
nn .Linear (d_model *ff_expansion ,d_model ),
nn .Dropout (dropout )
)
if use_rmla :
self .attn =RotaryMultiHeadLatentAttention (
hidden_size =d_model ,
num_heads =num_heads ,
num_kv_heads =max (1 ,num_heads //4 ),
head_dim =d_model //num_heads ,
kv_lora_rank =d_model //4 ,
dropout =dropout ,
)
else :
self .attn_norm =nn .LayerNorm (d_model )
self .attn =nn .MultiheadAttention (d_model ,num_heads ,dropout =dropout ,batch_first =True )
self .attn_dropout =nn .Dropout (dropout )
self .conv =ConvolutionModule (d_model ,conv_kernel_size ,dropout )
self .ff2_norm =nn .LayerNorm (d_model )
self .ff2 =nn .Sequential (
nn .Linear (d_model ,d_model *ff_expansion ),
nn .SiLU (),
nn .Dropout (dropout ),
nn .Linear (d_model *ff_expansion ,d_model ),
nn .Dropout (dropout )
)
self .final_norm =nn .LayerNorm (d_model )
def forward (
self ,
x :torch .Tensor ,
mask :Optional [torch .Tensor ]=None ,
past_key_value :Optional [Tuple ]=None ,
use_cache :bool =False ,
)->Tuple [torch .Tensor ,Optional [Tuple ]]:
x =x +0.5 *self .ff1 (self .ff1_norm (x ))
if self .use_rmla :
attn_mask =None
if mask is not None :
attn_mask =mask .unsqueeze (1 ).unsqueeze (2 )
attn_mask =attn_mask .to (dtype =x .dtype )
attn_mask =attn_mask .masked_fill (attn_mask .bool (),float ('-inf'))
attn_out ,present_kv =self .attn (x ,attention_mask =attn_mask ,past_key_value =past_key_value ,use_cache =use_cache )
else :
attn_out ,_ =self .attn (self .attn_norm (x ),self .attn_norm (x ),self .attn_norm (x ),key_padding_mask =mask )
present_kv =None
x =x +self .attn_dropout (attn_out )
x =x +self .conv (x )
x =x +0.5 *self .ff2 (self .ff2_norm (x ))
return self .final_norm (x ),present_kv
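# Usage sketch (illustrative only): a single Conformer block on the RMLA path.
# The half-step feed-forward residuals follow the Macaron structure of the
# original Conformer; present_kv is None unless use_cache=True is passed.
#
#   block = ConformerBlock(d_model=1024, num_heads=8)
#   x = torch.randn(2, 200, 1024)
#   y, present_kv = block(x)                             # y: [2, 200, 1024]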
class AudioEncoder (nn .Module ):
"""
SOTA Audio Encoder with Raw Waveform Tokenization, RMLA, and Voice Enhancement.
Features:
- Raw waveform tokenization (no mel spectrogram)
- Conformer blocks with RMLA
- Zero-shot speaker encoding
- In-context audio prompting
- Gradient checkpointing support for memory efficiency
Voice Enhancement Features (SOTA):
- Prosody-aware EoT Prediction (interruption detection)
- AVD Emotion Recognition (arousal/valence/dominance)
- Dynamic Latent Vocalizations (singing/rapping)
- Neural Sound Effects (beatboxing, breathing, expressions)
- Speculative Decoding (mid-stream token rewriting)
"""
def __init__ (
self ,
hidden_size :int =1024 ,
n_mels :int =80 ,
max_audio_length :int =3000 ,
num_layers :int =6 ,
num_heads :int =8 ,
dropout :float =0.1 ,
use_raw_waveform :bool =True ,
enable_eot :bool =True ,
enable_emotion :bool =True ,
enable_singing :bool =True ,
enable_effects :bool =True ,
enable_speculative :bool =True ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .max_audio_length =max_audio_length
self .use_raw_waveform =use_raw_waveform
self .gradient_checkpointing =False
self .enable_eot =enable_eot
self .enable_emotion =enable_emotion
self .enable_singing =enable_singing
self .enable_effects =enable_effects
self .enable_speculative =enable_speculative
if use_raw_waveform :
self .waveform_tokenizer =RawWaveformTokenizer (
hidden_size =hidden_size ,
num_codebooks =8 ,
codebook_size =1024 ,
)
else :
self .waveform_tokenizer =None
self .conv_subsample =nn .Sequential (
nn .Conv1d (n_mels ,hidden_size //2 ,kernel_size =3 ,stride =2 ,padding =1 ),
nn .GELU (),
nn .Conv1d (hidden_size //2 ,hidden_size ,kernel_size =3 ,stride =2 ,padding =1 ),
nn .GELU (),
)
self .speaker_encoder =SpeakerEncoder (
hidden_size =256 ,
output_size =hidden_size //4 ,
)
self .audio_prompting =InContextAudioPrompting (
hidden_size =hidden_size ,
num_prompt_tokens =32 ,
)
self .conformer_blocks =nn .ModuleList ([
ConformerBlock (
hidden_size ,num_heads ,
ff_expansion =4 ,
conv_kernel_size =31 ,
dropout =dropout ,
use_rmla =True ,
)
for _ in range (num_layers )
])
self .output_proj =nn .Linear (hidden_size ,hidden_size )
if enable_eot :
self .eot_predictor =ProsodyAwareEoTPredictor (hidden_size ,dropout =dropout )
else :
self .eot_predictor =None
if enable_emotion :
self .emotion_recognizer =AVDEmotionRecognizer (hidden_size ,dropout =dropout )
else :
self .emotion_recognizer =None
if enable_singing :
self .vocalizer =DynamicLatentVocalizer (hidden_size )
else :
self .vocalizer =None
if enable_effects :
self .effects_generator =NeuralSoundEffectGenerator (hidden_size )
else :
self .effects_generator =None
if enable_speculative :
self .speculative_decoder =SpeculativeAudioDecoder (hidden_size )
else :
self .speculative_decoder =None
print (f" 🎤 AudioEncoder (RMLA Conformer): {hidden_size }d, {num_layers } layers")
if use_raw_waveform :
print (f" - Raw Waveform Tokenizer enabled")
print (f" - Zero-Shot Speaker Encoder enabled")
print (f" - In-Context Audio Prompting enabled")
print (f" - EoT/Interruption Detection: {enable_eot }")
print (f" - Emotion Recognition (AVD): {enable_emotion }")
print (f" - Singing/Rapping (Vocalizer): {enable_singing }")
print (f" - Sound Effects Generator: {enable_effects }")
print (f" - Speculative Decoding: {enable_speculative }")
def gradient_checkpointing_enable (self ):
"""Enable gradient checkpointing to save memory during training."""
self .gradient_checkpointing =True
if hasattr (self ,'waveform_tokenizer')and self .waveform_tokenizer is not None :
if hasattr (self .waveform_tokenizer ,'gradient_checkpointing'):
self .waveform_tokenizer .gradient_checkpointing =True
if hasattr (self ,'speaker_encoder')and self .speaker_encoder is not None :
if hasattr (self .speaker_encoder ,'gradient_checkpointing'):
self .speaker_encoder .gradient_checkpointing =True
def gradient_checkpointing_disable (self ):
"""Disable gradient checkpointing."""
self .gradient_checkpointing =False
def forward (
self ,
audio_input :torch .Tensor ,
speaker_ref :Optional [torch .Tensor ]=None ,
audio_prompt :Optional [torch .Tensor ]=None ,
mask :Optional [torch .Tensor ]=None ,
return_eot :bool =False ,
return_emotion :bool =False ,
)->Tuple [torch .Tensor ,Optional [torch .Tensor ],Optional [dict ]]:
"""
Process audio to features with optional voice enhancement outputs.
Args:
audio_input: [B, T] raw waveform or [B, n_mels, T] mel spectrogram
speaker_ref: [B, n_mels, T_ref] reference audio for speaker cloning
audio_prompt: [B, T_prompt, hidden_size] audio prompt features
mask: Optional attention mask
return_eot: Whether to return EoT/interruption predictions
return_emotion: Whether to return emotion/AVD predictions
Returns:
features: [B, T', hidden_size] audio features
speaker_embedding: [B, hidden_size//4] speaker embedding (if speaker_ref provided)
extras: dict with EoT/emotion predictions (if requested)
"""
commitment_loss =None
if self .use_raw_waveform and self .waveform_tokenizer is not None :
if audio_input .dim ()==3 and audio_input .shape [1 ]==1 :
audio_input =audio_input .squeeze (1 )
elif audio_input .dim ()==3 :
audio_input =audio_input .mean (dim =1 )
x ,commitment_loss =self .waveform_tokenizer (audio_input )
elif hasattr (self ,'conv_subsample')and self .conv_subsample is not None :
if audio_input .dim ()==2 :
audio_input =audio_input .unsqueeze (1 )
x =self .conv_subsample (audio_input )
x =x .transpose (1 ,2 )
else :
raise RuntimeError (
f"AudioEncoder: Incompatible configuration. "
f"use_raw_waveform={self .use_raw_waveform }, "
f"waveform_tokenizer={self .waveform_tokenizer is not None }, "
f"conv_subsample={hasattr (self ,'conv_subsample')and self .conv_subsample is not None }"
)
speaker_embedding =None
if speaker_ref is not None :
speaker_embedding =self .speaker_encoder (speaker_ref )
if audio_prompt is not None :
x =self .audio_prompting (x ,audio_prompt =audio_prompt )
if self .gradient_checkpointing and self .training :
from torch .utils .checkpoint import checkpoint
for block in self .conformer_blocks :
def create_custom_forward (module ):
def custom_forward (*inputs ):
return module (*inputs )
return custom_forward
x ,_ =checkpoint (create_custom_forward (block ),x ,mask ,use_reentrant =False )
else :
for block in self .conformer_blocks :
x ,_ =block (x ,mask )
x =self .output_proj (x )
extras ={}
if return_eot and self .eot_predictor is not None :
extras ["eot"]=self .eot_predictor (x ,mask )
if return_emotion and self .emotion_recognizer is not None :
extras ["emotion"]=self .emotion_recognizer (x ,mask )
return x ,speaker_embedding ,extras if extras else None
def detect_interruption (
self ,
audio_features :torch .Tensor ,
attention_mask :Optional [torch .Tensor ]=None ,
)->Optional [dict ]:
"""
Detect interruptions, backchannels, and turn-taking events.
Args:
audio_features: [B, T, hidden_size] encoded audio
attention_mask: [B, T] optional mask
Returns:
dict with eot_logits, event_logits, vad_logits, backoff_prob
"""
if self .eot_predictor is None :
return None
return self .eot_predictor (audio_features ,attention_mask )
def recognize_emotion (
self ,
audio_features :torch .Tensor ,
attention_mask :Optional [torch .Tensor ]=None ,
)->Optional [dict ]:
"""
Recognize emotion with AVD (arousal/valence/dominance) values.
Args:
audio_features: [B, T, hidden_size] encoded audio
attention_mask: [B, T] optional mask
Returns:
dict with emotion_logits, arousal, valence, dominance, response_mode
"""
if self .emotion_recognizer is None :
return None
return self .emotion_recognizer (audio_features ,attention_mask )
def generate_vocals (
self ,
text_features :torch .Tensor ,
style_id :Optional [torch .Tensor ]=None ,
mode_id :Optional [torch .Tensor ]=None ,
target_pitch :Optional [torch .Tensor ]=None ,
tempo_bpm :Optional [torch .Tensor ]=None ,
)->Optional [dict ]:
"""
Generate singing/rapping vocals from text/lyrics.
Args:
text_features: [B, T, hidden_size] text embeddings
style_id: [B] style indices (pop, rock, jazz, etc.)
mode_id: [B] mode indices (speak, sing, rap, hum, etc.)
target_pitch: [B, T] pitch targets
tempo_bpm: [B] tempo in BPM
Returns:
dict with vocal_features, pitch_logits, alignment, durations
"""
if self .vocalizer is None :
return None
return self .vocalizer (text_features ,style_id ,mode_id ,target_pitch ,tempo_bpm )
def generate_effects (
self ,
effect_ids :torch .Tensor ,
context :Optional [torch .Tensor ]=None ,
intensity :Optional [torch .Tensor ]=None ,
)->Optional [dict ]:
"""
Generate sound effects (beatbox, clicks, breathing, etc.).
Args:
effect_ids: [B] or [B, N] effect type indices
context: [B, T, hidden_size] optional context
intensity: [B] intensity values
Returns:
dict with effect_features, waveform, duration, intensity
"""
if self .effects_generator is None :
return None
return self .effects_generator (effect_ids ,context ,intensity )
def speculative_generate (
self ,
context :torch .Tensor ,
generate_draft :bool =True ,
verify_with :Optional [torch .Tensor ]=None ,
)->Optional [dict ]:
"""
Generate speculative draft tokens for mid-stream rewriting.
Args:
context: [B, T, hidden_size] current context
generate_draft: whether to generate new draft
verify_with: [B, T', hidden_size] new context to verify against
Returns:
dict with checkpoint, draft_tokens, confidence, accept_prob
"""
if self .speculative_decoder is None :
return None
return self .speculative_decoder (context ,generate_draft ,verify_with )
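# Usage sketch (illustrative only): encoding one second of 16 kHz audio through the
# raw-waveform path, with an optional speaker-reference mel for cloning. Shapes are
# assumptions; the voice-enhancement heads are queried via the return_* flags.
#
#   audio_enc = AudioEncoder(hidden_size=1024, num_layers=6)
#   wav = torch.randn(2, 16000)                          # [B, T] raw audio
#   ref_mel = torch.randn(2, 80, 300)                    # [B, n_mels, T_ref]
#   feats, spk, extras = audio_enc(wav, speaker_ref=ref_mel, return_emotion=True)
#   # feats: [2, ~250, 1024], spk: [2, 256], extras["emotion"]["arousal"]: [2, 1]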
class VariancePredictor (nn .Module ):
"""Variance predictor for duration, pitch, and energy."""
def __init__ (self ,hidden_size :int ,kernel_size :int =3 ,dropout :float =0.1 ):
super ().__init__ ()
self .conv1 =nn .Conv1d (hidden_size ,hidden_size ,kernel_size ,padding =kernel_size //2 )
self .norm1 =nn .LayerNorm (hidden_size )
self .conv2 =nn .Conv1d (hidden_size ,hidden_size ,kernel_size ,padding =kernel_size //2 )
self .norm2 =nn .LayerNorm (hidden_size )
self .dropout =nn .Dropout (dropout )
self .linear =nn .Linear (hidden_size ,1 )
def forward (self ,x :torch .Tensor )->torch .Tensor :
"""x: [B, T, C] -> [B, T]"""
out =self .conv1 (x .transpose (1 ,2 )).transpose (1 ,2 )
out =F .relu (out )
out =self .norm1 (out )
out =self .dropout (out )
out =self .conv2 (out .transpose (1 ,2 )).transpose (1 ,2 )
out =F .relu (out )
out =self .norm2 (out )
out =self .dropout (out )
return self .linear (out ).squeeze (-1 )
class FFTBlock (nn .Module ):
"""FFT block for mel decoder."""
def __init__ (
self ,
hidden_size :int ,
num_heads :int =4 ,
ff_expansion :int =4 ,
kernel_size :int =9 ,
dropout :float =0.1 ,
):
super ().__init__ ()
self .attn =RotaryMultiHeadLatentAttention (
hidden_size =hidden_size ,
num_heads =num_heads ,
num_kv_heads =max (1 ,num_heads //2 ),
head_dim =hidden_size //num_heads ,
kv_lora_rank =hidden_size //4 ,
dropout =dropout ,
)
self .attn_norm =nn .LayerNorm (hidden_size )
self .attn_dropout =nn .Dropout (dropout )
self .ff_norm =nn .LayerNorm (hidden_size )
self .ff =nn .Sequential (
nn .Conv1d (hidden_size ,hidden_size *ff_expansion ,kernel_size ,padding =kernel_size //2 ),
nn .ReLU (),
nn .Conv1d (hidden_size *ff_expansion ,hidden_size ,kernel_size ,padding =kernel_size //2 ),
nn .Dropout (dropout )
)
def forward (self ,x :torch .Tensor )->torch .Tensor :
residual =x
x =self .attn_norm (x )
x ,_ =self .attn (x )
x =residual +self .attn_dropout (x )
residual =x
x =self .ff_norm (x )
x =self .ff (x .transpose (1 ,2 )).transpose (1 ,2 )
x =residual +x
return x
class AudioDecoder (nn .Module ):
"""
SOTA Audio Decoder with MAS, Zero-Shot Speaker Cloning, and Voice Enhancement Support.
Features:
- Monotonic Alignment Search for text-to-audio alignment
- Zero-shot speaker cloning via speaker embeddings
- In-context audio prompting
- Variance adaptor with duration, pitch, energy prediction
- RMLA-based FFT blocks
- Gradient checkpointing support for memory efficiency
Voice Enhancement Features (matching AudioEncoder):
- Emotion conditioning for emotional speech synthesis
- Singing/vocal style synthesis support
- Sound effect generation and integration
- Raw waveform output support (optional)
- Speculative decoding integration
"""
def __init__ (
self ,
hidden_size :int =1024 ,
n_mels :int =80 ,
max_audio_length :int =1000 ,
num_speakers :int =256 ,
num_decoder_layers :int =4 ,
dropout :float =0.1 ,
enable_emotion :bool =True ,
enable_singing :bool =True ,
enable_effects :bool =True ,
enable_raw_waveform :bool =True ,
enable_speculative :bool =True ,
num_emotions :int =10 ,
num_vocal_styles :int =8 ,
num_vocal_modes :int =6 ,
num_effect_types :int =20 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .n_mels =n_mels
self .max_audio_length =max_audio_length
self .gradient_checkpointing =False
self .enable_emotion =enable_emotion
self .enable_singing =enable_singing
self .enable_effects =enable_effects
self .enable_raw_waveform =enable_raw_waveform
self .enable_speculative =enable_speculative
self .mas =MonotonicAlignmentSearch (hidden_size )
self .speaker_embed =nn .Embedding (num_speakers ,hidden_size //4 )
self .speaker_proj =nn .Linear (hidden_size //4 ,hidden_size //4 )
self .audio_prompting =InContextAudioPrompting (
hidden_size =hidden_size ,
num_prompt_tokens =32 ,
)
if enable_emotion :
self .emotion_embed =nn .Embedding (num_emotions ,hidden_size //4 )
self .avd_proj =nn .Sequential (
nn .Linear (3 ,hidden_size //8 ),
nn .SiLU (),
nn .Linear (hidden_size //8 ,hidden_size //4 ),
)
self .emotion_cond_size =hidden_size //4
else :
self .emotion_embed =None
self .avd_proj =None
self .emotion_cond_size =0
if enable_singing :
self .vocal_style_embed =nn .Embedding (num_vocal_styles ,hidden_size //4 )
self .vocal_mode_embed =nn .Embedding (num_vocal_modes ,hidden_size //4 )
self .tempo_proj =nn .Sequential (
nn .Linear (1 ,hidden_size //8 ),
nn .SiLU (),
nn .Linear (hidden_size //8 ,hidden_size //4 ),
)
self .singing_cond_size =hidden_size //4
else :
self .vocal_style_embed =None
self .vocal_mode_embed =None
self .tempo_proj =None
self .singing_cond_size =0
if enable_effects :
self .effect_embed =nn .Embedding (num_effect_types ,hidden_size //4 )
self .effect_intensity_proj =nn .Sequential (
nn .Linear (1 ,hidden_size //8 ),
nn .SiLU (),
nn .Linear (hidden_size //8 ,hidden_size //4 ),
)
self .effect_cond_size =hidden_size //4
else :
self .effect_embed =None
self .effect_intensity_proj =None
self .effect_cond_size =0
total_cond_size =hidden_size //4
total_cond_size +=self .emotion_cond_size
total_cond_size +=self .singing_cond_size
total_cond_size +=self .effect_cond_size
self .input_proj =nn .Linear (hidden_size +total_cond_size ,hidden_size )
self .encoder_blocks =nn .ModuleList ([
FFTBlock (hidden_size ,num_heads =4 ,ff_expansion =4 ,dropout =dropout )
for _ in range (4 )
])
self .duration_predictor =VariancePredictor (hidden_size ,dropout =dropout )
self .pitch_predictor =VariancePredictor (hidden_size ,dropout =dropout )
self .energy_predictor =VariancePredictor (hidden_size ,dropout =dropout )
self .pitch_embed =nn .Conv1d (1 ,hidden_size ,kernel_size =9 ,padding =4 )
self .energy_embed =nn .Conv1d (1 ,hidden_size ,kernel_size =9 ,padding =4 )
self .decoder_blocks =nn .ModuleList ([
FFTBlock (hidden_size ,num_heads =4 ,ff_expansion =4 ,dropout =dropout )
for _ in range (num_decoder_layers )
])
self .mel_linear =nn .Linear (hidden_size ,n_mels )
self .postnet =nn .ModuleList ([
nn .Sequential (
nn .Conv1d (n_mels ,256 ,kernel_size =5 ,padding =2 ),
nn .GroupNorm (1 ,256 ),
nn .Tanh (),
),
nn .Sequential (
nn .Conv1d (256 ,256 ,kernel_size =5 ,padding =2 ),
nn .GroupNorm (1 ,256 ),
nn .Tanh (),
),
nn .Sequential (
nn .Conv1d (256 ,256 ,kernel_size =5 ,padding =2 ),
nn .GroupNorm (1 ,256 ),
nn .Tanh (),
),
nn .Sequential (
nn .Conv1d (256 ,256 ,kernel_size =5 ,padding =2 ),
nn .GroupNorm (1 ,256 ),
nn .Tanh (),
),
nn .Conv1d (256 ,n_mels ,kernel_size =5 ,padding =2 ),
])
if enable_raw_waveform :
self .waveform_decoder =RawWaveformDecoder (
hidden_size =hidden_size ,
sample_rate =16000 ,
)
else :
self .waveform_decoder =None
if enable_speculative :
self .speculative_decoder =SpeculativeAudioDecoder (
hidden_size =hidden_size ,
draft_length =10 ,
)
else :
self .speculative_decoder =None
print (f" 🔊 AudioDecoder (MAS + RMLA): {hidden_size }d -> {n_mels } mels")
print (f" - Monotonic Alignment Search enabled")
print (f" - Zero-Shot Speaker Cloning enabled")
print (f" - In-Context Audio Prompting enabled")
print (f" - Emotion Conditioning: {enable_emotion }")
print (f" - Singing/Vocal Styles: {enable_singing }")
print (f" - Sound Effects: {enable_effects }")
print (f" - Raw Waveform Output: {enable_raw_waveform }")
print (f" - Speculative Decoding: {enable_speculative }")
def gradient_checkpointing_enable (self ):
"""Enable gradient checkpointing to save memory during training."""
self .gradient_checkpointing =True
def gradient_checkpointing_disable (self ):
"""Disable gradient checkpointing."""
self .gradient_checkpointing =False
def forward (
self ,
text_embeds :torch .Tensor ,
target_length :Optional [int ]=None ,
speaker :Optional [torch .Tensor ]=None ,
speaker_embedding :Optional [torch .Tensor ]=None ,
audio_prompt :Optional [torch .Tensor ]=None ,
audio_features :Optional [torch .Tensor ]=None ,
duration_target :Optional [torch .Tensor ]=None ,
pitch_target :Optional [torch .Tensor ]=None ,
energy_target :Optional [torch .Tensor ]=None ,
use_mas :bool =True ,
emotion_id :Optional [torch .Tensor ]=None ,
avd_values :Optional [torch .Tensor ]=None ,
vocal_style_id :Optional [torch .Tensor ]=None ,
vocal_mode_id :Optional [torch .Tensor ]=None ,
tempo_bpm :Optional [torch .Tensor ]=None ,
effect_id :Optional [torch .Tensor ]=None ,
effect_intensity :Optional [torch .Tensor ]=None ,
output_waveform :bool =False ,
use_speculative :bool =False ,
)->Tuple [torch .Tensor ,torch .Tensor ,Optional [torch .Tensor ],Optional [dict ]]:
"""
Generate mel-spectrogram from text embeddings with voice enhancement support.
Args:
text_embeds: [B, T, hidden_size] text embeddings
target_length: target mel length (for training)
speaker: [B] speaker IDs (for multi-speaker)
speaker_embedding: [B, hidden_size//4] zero-shot speaker embedding
audio_prompt: [B, T_prompt, hidden_size] audio prompt features
audio_features: [B, T_audio, hidden_size] target audio features (for MAS training)
duration_target: [B, T] ground truth durations
pitch_target: [B, T'] ground truth pitch
energy_target: [B, T'] ground truth energy
use_mas: Whether to use MAS for alignment
Voice enhancement args:
emotion_id: [B] discrete emotion category (0-9)
avd_values: [B, 3] continuous arousal/valence/dominance values
vocal_style_id: [B] singing style (0-7: pop, rock, jazz, etc.)
vocal_mode_id: [B] vocal mode (0-5: speak, sing, rap, hum, whistle, chant)
tempo_bpm: [B] tempo in BPM for singing/rapping
effect_id: [B] sound effect type (0-19)
effect_intensity: [B] effect intensity (0-1)
output_waveform: Whether to also output raw waveform
use_speculative: Whether to use speculative decoding
Returns:
mel: [B, n_mels, T'] generated mel spectrogram
durations: [B, T] predicted durations
alignment: [B, T_text, T_audio] alignment matrix (if use_mas and audio_features provided)
extras: dict with optional outputs (waveform, speculative results)
"""
batch_size ,seq_len ,_ =text_embeds .shape
device =text_embeds .device
dtype =text_embeds .dtype
extras ={}
if speaker_embedding is not None :
spk_emb =self .speaker_proj (speaker_embedding )
elif speaker is not None :
spk_emb =self .speaker_embed (speaker )
else :
speaker =torch .zeros (batch_size ,dtype =torch .long ,device =device )
spk_emb =self .speaker_embed (speaker )
spk_emb =spk_emb .unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype )
cond_embeds =[spk_emb ]
if self .enable_emotion :
if emotion_id is not None :
emo_emb =self .emotion_embed (emotion_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype )
elif avd_values is not None :
emo_emb =self .avd_proj (avd_values .to (dtype )).unsqueeze (1 ).expand (-1 ,seq_len ,-1 )
else :
neutral =torch .full ((batch_size ,),6 ,dtype =torch .long ,device =device )
emo_emb =self .emotion_embed (neutral ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype )
cond_embeds .append (emo_emb )
if self .enable_singing :
if vocal_style_id is not None :
style_emb =self .vocal_style_embed (vocal_style_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype )
else :
default_style =torch .zeros (batch_size ,dtype =torch .long ,device =device )
style_emb =self .vocal_style_embed (default_style ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype )
if vocal_mode_id is not None :
mode_emb =self .vocal_mode_embed (vocal_mode_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype )
else :
default_mode =torch .zeros (batch_size ,dtype =torch .long ,device =device )
mode_emb =self .vocal_mode_embed (default_mode ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype )
if tempo_bpm is not None :
tempo_norm =(tempo_bpm .float ()-60 )/120
tempo_emb =self .tempo_proj (tempo_norm .unsqueeze (-1 ).to (dtype )).unsqueeze (1 ).expand (-1 ,seq_len ,-1 )
else :
tempo_emb =torch .zeros (batch_size ,seq_len ,self .hidden_size //4 ,device =device ,dtype =dtype )
singing_emb =style_emb +mode_emb +tempo_emb
cond_embeds .append (singing_emb )
if self .enable_effects :
if effect_id is not None :
eff_emb =self .effect_embed (effect_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ).to (dtype )
if effect_intensity is not None :
intensity_emb =self .effect_intensity_proj (effect_intensity .unsqueeze (-1 ).to (dtype ))
eff_emb =eff_emb *intensity_emb .unsqueeze (1 )
else :
eff_emb =torch .zeros (batch_size ,seq_len ,self .hidden_size //4 ,device =device ,dtype =dtype )
cond_embeds .append (eff_emb )
all_cond =torch .cat (cond_embeds ,dim =-1 )
x =torch .cat ([text_embeds ,all_cond ],dim =-1 )
x =self .input_proj (x )
if audio_prompt is not None :
x =self .audio_prompting (x ,audio_prompt =audio_prompt )
if self .gradient_checkpointing and self .training :
from torch .utils .checkpoint import checkpoint
for block in self .encoder_blocks :
def create_custom_forward (module ):
def custom_forward (*inputs ):
return module (*inputs )
return custom_forward
x =checkpoint (create_custom_forward (block ),x ,use_reentrant =False )
else :
for block in self .encoder_blocks :
x =block (x )
alignment =None
if use_mas and audio_features is not None :
alignment ,durations =self .mas (x ,audio_features ,use_hard =not self .training )
else :
_ ,durations =self .mas (x )
if duration_target is not None :
durations =duration_target
pitch_pred =self .pitch_predictor (x )
energy_pred =F .softplus (self .energy_predictor (x ))
MIN_MEL_LENGTH =1
if target_length is not None :
mel_length =max (MIN_MEL_LENGTH ,target_length )
else :
mel_length =int (durations .sum (dim =1 ).max ().item ())
mel_length =max (16 ,min (mel_length ,self .max_audio_length ))
x =F .interpolate (x .transpose (1 ,2 ),size =mel_length ,mode ='linear',align_corners =False ).transpose (1 ,2 )
pitch =pitch_target if pitch_target is not None else pitch_pred
energy =energy_target if energy_target is not None else energy_pred
pitch_up =F .interpolate (pitch .unsqueeze (1 ),size =mel_length ,mode ='linear',align_corners =False )
energy_up =F .interpolate (energy .unsqueeze (1 ),size =mel_length ,mode ='linear',align_corners =False )
pitch_emb =self .pitch_embed (pitch_up ).transpose (1 ,2 )
energy_emb =self .energy_embed (energy_up ).transpose (1 ,2 )
x =x +pitch_emb +energy_emb
if self .gradient_checkpointing and self .training :
from torch .utils .checkpoint import checkpoint
for block in self .decoder_blocks :
def create_custom_forward (module ):
def custom_forward (*inputs ):
return module (*inputs )
return custom_forward
x =checkpoint (create_custom_forward (block ),x ,use_reentrant =False )
else :
for block in self .decoder_blocks :
x =block (x )
mel =self .mel_linear (x ).transpose (1 ,2 )
mel_post =mel
for layer in self .postnet :
mel_post =layer (mel_post )
mel =mel +mel_post
if output_waveform and self .waveform_decoder is not None :
waveform =self .waveform_decoder (x )
extras ["waveform"]=waveform
if use_speculative and self .speculative_decoder is not None :
spec_results =self .speculative_decoder (x )
extras ["speculative"]=spec_results
return mel ,durations ,alignment ,extras if extras else None
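# Usage sketch (illustrative only): text-to-mel synthesis with a discrete speaker id
# and an emotion id. text_embeds is a hypothetical LLM hidden-state slice; without a
# target_length the decoder infers the mel length from the predicted durations.
#
#   audio_dec = AudioDecoder(hidden_size=1024, n_mels=80)
#   text_embeds = torch.randn(2, 40, 1024)               # [B, T_text, H]
#   speaker = torch.zeros(2, dtype=torch.long)
#   emotion = torch.full((2,), 3, dtype=torch.long)
#   mel, durs, align, extras = audio_dec(text_embeds, speaker=speaker, emotion_id=emotion)
#   # mel: [2, 80, T_mel], durs: [2, 40], align is None when no audio_features are given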
class ProsodyAwareEoTPredictor (nn .Module ):
"""
Prosody-aware End-of-Turn (EoT) Prediction for real-time interruption detection.
Detects when a speaker is about to finish their turn, allowing the model to:
- Detect user interruptions (coughs, laughs, "uh-huh", etc.)
- Yield the floor when appropriate
- Adjust response mid-stream
Uses prosodic features (pitch, energy, rhythm) combined with semantic features.
"""
def __init__ (
self ,
hidden_size :int =1024 ,
num_eot_classes :int =5 ,
prosody_dim :int =128 ,
num_heads :int =4 ,
dropout :float =0.1 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_eot_classes =num_eot_classes
self .pitch_conv =nn .Sequential (
nn .Conv1d (1 ,prosody_dim //2 ,kernel_size =5 ,padding =2 ),
nn .SiLU (),
nn .Conv1d (prosody_dim //2 ,prosody_dim ,kernel_size =3 ,padding =1 ),
)
self .energy_conv =nn .Sequential (
nn .Conv1d (1 ,prosody_dim //2 ,kernel_size =5 ,padding =2 ),
nn .SiLU (),
nn .Conv1d (prosody_dim //2 ,prosody_dim ,kernel_size =3 ,padding =1 ),
)
self .vad_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //2 ),
nn .SiLU (),
nn .Linear (hidden_size //2 ,2 ),
)
self .event_classifier =nn .Sequential (
nn .Linear (hidden_size +prosody_dim *2 ,hidden_size ),
nn .SiLU (),
nn .Dropout (dropout ),
nn .Linear (hidden_size ,hidden_size //2 ),
nn .SiLU (),
nn .Linear (hidden_size //2 ,8 ),
)
self .temporal_attn =nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =num_heads ,
dropout =dropout ,
batch_first =True ,
)
self .eot_head =nn .Sequential (
nn .Linear (hidden_size +prosody_dim *2 ,hidden_size ),
nn .SiLU (),
nn .Dropout (dropout ),
nn .Linear (hidden_size ,num_eot_classes ),
)
self .backoff_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //4 ),
nn .SiLU (),
nn .Linear (hidden_size //4 ,1 ),
nn .Sigmoid (),
)
print (f" 🎙️ ProsodyAwareEoTPredictor: {num_eot_classes } turn states, {prosody_dim }d prosody")
def extract_prosody (self ,audio_features :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]:
"""Extract pitch and energy prosodic features."""
batch_size ,seq_len ,hidden =audio_features .shape
x =audio_features .transpose (1 ,2 )
pitch_proxy =x [:,:1 ,:]
energy_proxy =x .pow (2 ).mean (dim =1 ,keepdim =True )
pitch_features =self .pitch_conv (pitch_proxy ).transpose (1 ,2 )
energy_features =self .energy_conv (energy_proxy ).transpose (1 ,2 )
return pitch_features ,energy_features
def forward (
self ,
audio_features :torch .Tensor ,
attention_mask :Optional [torch .Tensor ]=None ,
)->dict :
"""
Predict end-of-turn and interruption events.
Args:
audio_features: [B, T, hidden_size] encoded audio
attention_mask: [B, T] optional mask
Returns:
dict with:
- eot_logits: [B, T, num_eot_classes] turn state predictions
- event_logits: [B, T, 8] interruption event predictions
- vad_logits: [B, T, 2] voice activity predictions
- backoff_prob: [B, T, 1] backoff probability
"""
batch_size ,seq_len ,_ =audio_features .shape
pitch_features ,energy_features =self .extract_prosody (audio_features )
if attention_mask is not None :
key_padding_mask =~attention_mask .bool ()
else :
key_padding_mask =None
contextualized ,_ =self .temporal_attn (
audio_features ,audio_features ,audio_features ,
key_padding_mask =key_padding_mask ,
)
combined =torch .cat ([contextualized ,pitch_features ,energy_features ],dim =-1 )
eot_logits =self .eot_head (combined )
event_logits =self .event_classifier (combined )
vad_logits =self .vad_head (contextualized )
backoff_prob =self .backoff_head (contextualized )
return {
"eot_logits":eot_logits ,
"event_logits":event_logits ,
"vad_logits":vad_logits ,
"backoff_prob":backoff_prob ,
}
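# Illustrative usage sketch for ProsodyAwareEoTPredictor (not executed at import; the
# tensors below are random placeholders and shapes follow the docstring above):
#
#   predictor = ProsodyAwareEoTPredictor(hidden_size=1024, num_eot_classes=5)
#   audio_feats = torch.randn(2, 50, 1024)             # [B, T, hidden_size]
#   out = predictor(audio_feats)
#   out["eot_logits"].shape                             # torch.Size([2, 50, 5])
#   out["event_logits"].shape                           # torch.Size([2, 50, 8])
#   out["vad_logits"].shape                             # torch.Size([2, 50, 2])
#   out["backoff_prob"].shape                           # torch.Size([2, 50, 1])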
class AVDEmotionRecognizer (nn .Module ):
"""
Continuous AVD (Arousal/Valence/Dominance) Emotion Recognition.
Predicts both discrete emotion categories and continuous AVD values
for nuanced emotion understanding and response adaptation.
"""
def __init__ (
self ,
hidden_size :int =1024 ,
num_emotions :int =10 ,
num_layers :int =2 ,
dropout :float =0.1 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_emotions =num_emotions
self .emotion_query =nn .Parameter (torch .randn (1 ,1 ,hidden_size ))
self .emotion_attn =nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =8 ,
dropout =dropout ,
batch_first =True ,
)
self .temporal_conv =nn .Sequential (
nn .Conv1d (hidden_size ,hidden_size ,kernel_size =5 ,padding =2 ,groups =8 ),
nn .SiLU (),
nn .Conv1d (hidden_size ,hidden_size ,kernel_size =3 ,padding =1 ),
)
self .emotion_classifier =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //2 ),
nn .SiLU (),
nn .Dropout (dropout ),
nn .Linear (hidden_size //2 ,num_emotions ),
)
self .arousal_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //4 ),
nn .SiLU (),
nn .Linear (hidden_size //4 ,1 ),
nn .Sigmoid (),
)
self .valence_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //4 ),
nn .SiLU (),
nn .Linear (hidden_size //4 ,1 ),
nn .Tanh (),
)
self .dominance_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //4 ),
nn .SiLU (),
nn .Linear (hidden_size //4 ,1 ),
nn .Sigmoid (),
)
self .response_adaptation =nn .Sequential (
nn .Linear (hidden_size +3 ,hidden_size //2 ),
nn .SiLU (),
nn .Linear (hidden_size //2 ,4 ),
)
print (f" 😊 AVDEmotionRecognizer: {num_emotions } emotions + continuous AVD")
def forward (
self ,
audio_features :torch .Tensor ,
attention_mask :Optional [torch .Tensor ]=None ,
)->dict :
"""
Recognize emotion from audio features.
Args:
audio_features: [B, T, hidden_size] encoded audio
attention_mask: [B, T] optional mask
Returns:
dict with:
- emotion_logits: [B, num_emotions] discrete emotion
- arousal: [B, 1] arousal value (0-1)
- valence: [B, 1] valence value (-1 to 1)
- dominance: [B, 1] dominance value (0-1)
- response_mode: [B, 4] response adaptation logits
"""
batch_size ,seq_len ,_ =audio_features .shape
x_conv =self .temporal_conv (audio_features .transpose (1 ,2 )).transpose (1 ,2 )
x =audio_features +x_conv
query =self .emotion_query .expand (batch_size ,-1 ,-1 )
if attention_mask is not None :
key_padding_mask =~attention_mask .bool ()
else :
key_padding_mask =None
emotion_context ,_ =self .emotion_attn (
query ,x ,x ,
key_padding_mask =key_padding_mask ,
)
emotion_vec =emotion_context .squeeze (1 )
emotion_logits =self .emotion_classifier (emotion_vec )
arousal =self .arousal_head (emotion_vec )
valence =self .valence_head (emotion_vec )
dominance =self .dominance_head (emotion_vec )
avd_concat =torch .cat ([emotion_vec ,arousal ,valence ,dominance ],dim =-1 )
response_mode =self .response_adaptation (avd_concat )
return {
"emotion_logits":emotion_logits ,
"arousal":arousal ,
"valence":valence ,
"dominance":dominance ,
"response_mode":response_mode ,
}
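# Illustrative usage sketch for AVDEmotionRecognizer (random placeholder tensors;
# per-clip outputs are pooled over time by the learned emotion query):
#
#   recognizer = AVDEmotionRecognizer(hidden_size=1024, num_emotions=10)
#   audio_feats = torch.randn(2, 50, 1024)              # [B, T, hidden_size]
#   out = recognizer(audio_feats)
#   out["emotion_logits"].shape                          # torch.Size([2, 10])
#   out["arousal"], out["valence"], out["dominance"]     # each [2, 1]; sigmoid / tanh / sigmoid ranges
#   out["response_mode"].shape                           # torch.Size([2, 4])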
class DynamicLatentVocalizer (nn .Module ):
"""
Dynamic Latent Vocalizations for singing, rapping, humming, etc.
Extends speech synthesis to include:
- Singing with pitch control
- Rapping with rhythm control
- Humming, whistling, chanting
- Musical style transfer
"""
def __init__ (
self ,
hidden_size :int =1024 ,
num_styles :int =8 ,
num_vocal_modes :int =6 ,
pitch_bins :int =256 ,
tempo_range :Tuple [int ,int ]=(60 ,180 ),
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_styles =num_styles
self .num_vocal_modes =num_vocal_modes
self .pitch_bins =pitch_bins
self .tempo_range =tempo_range
self .style_embed =nn .Embedding (num_styles ,hidden_size //4 )
self .mode_embed =nn .Embedding (num_vocal_modes ,hidden_size //4 )
self .pitch_embed =nn .Embedding (pitch_bins ,hidden_size //4 )
self .pitch_predictor =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //2 ),
nn .SiLU (),
nn .Linear (hidden_size //2 ,pitch_bins ),
)
self .tempo_encoder =nn .Sequential (
nn .Linear (1 ,hidden_size //8 ),
nn .SiLU (),
nn .Linear (hidden_size //8 ,hidden_size //4 ),
)
self .rhythm_attn =nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =4 ,
dropout =0.1 ,
batch_first =True ,
)
self .style_transfer =nn .Sequential (
nn .Linear (hidden_size *2 ,hidden_size ),  # input = text features (hidden) + 4 condition embeddings (4 * hidden//4)
nn .SiLU (),
nn .Linear (hidden_size ,hidden_size ),
)
self .lyrics_aligner =MonotonicAlignmentSearch (hidden_size )
self .output_proj =nn .Linear (hidden_size ,hidden_size )
print (f" 🎵 DynamicLatentVocalizer: {num_styles } styles, {num_vocal_modes } modes")
def forward (
self ,
text_features :torch .Tensor ,
style_id :Optional [torch .Tensor ]=None ,
mode_id :Optional [torch .Tensor ]=None ,
target_pitch :Optional [torch .Tensor ]=None ,
tempo_bpm :Optional [torch .Tensor ]=None ,
)->dict :
"""
Generate vocalization features for singing/rapping/etc.
Args:
text_features: [B, T, hidden_size] text/lyrics embeddings
style_id: [B] style indices (0-7)
mode_id: [B] vocal mode indices (0-5)
target_pitch: [B, T] optional pitch targets
tempo_bpm: [B] optional tempo in BPM
Returns:
dict with:
- vocal_features: [B, T', hidden_size] vocalization features
- pitch_logits: [B, T, pitch_bins] predicted pitch
- alignment: [B, T, T'] text-to-audio alignment
"""
batch_size ,seq_len ,_ =text_features .shape
device =text_features .device
if style_id is None :
style_id =torch .zeros (batch_size ,dtype =torch .long ,device =device )
if mode_id is None :
mode_id =torch .zeros (batch_size ,dtype =torch .long ,device =device )
style_emb =self .style_embed (style_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 )
mode_emb =self .mode_embed (mode_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 )
if tempo_bpm is not None :
tempo_norm =(tempo_bpm .float ()-self .tempo_range [0 ])/(self .tempo_range [1 ]-self .tempo_range [0 ])
tempo_emb =self .tempo_encoder (tempo_norm .unsqueeze (-1 )).unsqueeze (1 ).expand (-1 ,seq_len ,-1 )
else :
tempo_emb =torch .zeros (batch_size ,seq_len ,self .hidden_size //4 ,device =device )
pitch_logits =self .pitch_predictor (text_features )
if target_pitch is not None :
pitch_emb =self .pitch_embed (target_pitch )
else :
pitch_idx =pitch_logits .argmax (dim =-1 )
pitch_emb =self .pitch_embed (pitch_idx )
conditions =torch .cat ([style_emb ,mode_emb ,tempo_emb ,pitch_emb ],dim =-1 )
combined =torch .cat ([text_features ,conditions ],dim =-1 )
vocal_features =self .style_transfer (combined )
vocal_features ,_ =self .rhythm_attn (vocal_features ,vocal_features ,vocal_features )
alignment ,durations =self .lyrics_aligner (text_features )
vocal_features =self .output_proj (vocal_features )
return {
"vocal_features":vocal_features ,
"pitch_logits":pitch_logits ,
"alignment":alignment ,
"durations":durations ,
}
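# Illustrative usage sketch for DynamicLatentVocalizer (random placeholder tensors;
# the style/mode indices and tempo below are arbitrary example values, not trained presets):
#
#   vocalizer = DynamicLatentVocalizer(hidden_size=1024, num_styles=8, num_vocal_modes=6)
#   lyrics = torch.randn(1, 32, 1024)                   # [B, T, hidden_size] lyric embeddings
#   style = torch.tensor([2])                           # example style index
#   mode = torch.tensor([1])                            # example vocal mode index
#   tempo = torch.tensor([120.0])                       # BPM within tempo_range
#   out = vocalizer(lyrics, style_id=style, mode_id=mode, tempo_bpm=tempo)
#   out["vocal_features"].shape                          # torch.Size([1, 32, 1024])
#   out["pitch_logits"].shape                            # torch.Size([1, 32, 256])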
class NeuralSoundEffectGenerator (nn .Module ):
"""
Neural Style Transfer for Sound Effects and Non-verbal Vocalizations.
Generates:
- Beatboxing (kicks, snares, hi-hats)
- Vocal clicks, pops, tongue sounds
- Breathing, sighing, gasping
- Non-verbal expressions (hmm, aha, wow, etc.)
- Polyphonic ad-libs and harmonies
"""
def __init__ (
self ,
hidden_size :int =1024 ,
num_effect_types :int =20 ,
num_layers :int =3 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_effect_types =num_effect_types
self .effect_embed =nn .Embedding (num_effect_types ,hidden_size )
self .generator =nn .Sequential (
nn .Linear (hidden_size ,hidden_size *4 ),
nn .SiLU (),
nn .Unflatten (1 ,(hidden_size ,4 )),
nn .ConvTranspose1d (hidden_size ,hidden_size //2 ,4 ,2 ,1 ),
nn .SiLU (),
nn .ConvTranspose1d (hidden_size //2 ,hidden_size //4 ,4 ,2 ,1 ),
nn .SiLU (),
nn .ConvTranspose1d (hidden_size //4 ,hidden_size //8 ,4 ,2 ,1 ),
nn .SiLU (),
nn .ConvTranspose1d (hidden_size //8 ,1 ,4 ,2 ,1 ),
nn .Tanh (),
)
self .duration_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //4 ),
nn .SiLU (),
nn .Linear (hidden_size //4 ,1 ),
nn .Softplus (),
)
self .intensity_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //4 ),
nn .SiLU (),
nn .Linear (hidden_size //4 ,1 ),
nn .Sigmoid (),
)
self .blend_attn =nn .MultiheadAttention (
embed_dim =hidden_size ,
num_heads =4 ,
batch_first =True ,
)
print (f" 🥁 NeuralSoundEffectGenerator: {num_effect_types } effect types")
def forward (
self ,
effect_ids :torch .Tensor ,
context :Optional [torch .Tensor ]=None ,
intensity :Optional [torch .Tensor ]=None ,
)->dict :
"""
Generate sound effect features.
Args:
effect_ids: [B] or [B, N] effect type indices
context: [B, T, hidden_size] optional context features
intensity: [B] or [B, N] optional intensity values
Returns:
dict with:
- effect_features: [B, T', hidden_size] generated features
- waveform: [B, 1, samples] raw waveform (if generating directly)
- duration: [B, 1] predicted duration
"""
if effect_ids .dim ()==1 :
effect_ids =effect_ids .unsqueeze (1 )
batch_size ,num_effects =effect_ids .shape
device =effect_ids .device
effect_emb =self .effect_embed (effect_ids )
if num_effects >1 :
effect_emb ,_ =self .blend_attn (effect_emb ,effect_emb ,effect_emb )
effect_vec =effect_emb .mean (dim =1 )
if context is not None :
context_vec =context .mean (dim =1 )
effect_vec =effect_vec +context_vec
duration =self .duration_head (effect_vec )
pred_intensity =self .intensity_head (effect_vec )
if intensity is not None :
pred_intensity =intensity .unsqueeze (-1 )if intensity .dim ()==1 else intensity
effect_vec =effect_vec *pred_intensity
waveform =self .generator (effect_vec )
return {
"effect_features":effect_emb ,
"waveform":waveform ,
"duration":duration ,
"intensity":pred_intensity ,
}
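# Illustrative usage sketch for NeuralSoundEffectGenerator (random effect indices; with
# the default ConvTranspose1d stack the decoded waveform has a fixed length of 64 samples):
#
#   sfx = NeuralSoundEffectGenerator(hidden_size=1024, num_effect_types=20)
#   effect_ids = torch.tensor([3, 7])                   # [B] effect type indices
#   out = sfx(effect_ids)
#   out["waveform"].shape                                # torch.Size([2, 1, 64])
#   out["duration"].shape                                # torch.Size([2, 1]) (Softplus, > 0)
#   out["intensity"].shape                               # torch.Size([2, 1]) (Sigmoid, 0-1)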
class SpeculativeAudioDecoder (nn .Module ):
"""
Mid-stream Token Rewriting support for Speculative Decoding in audio.
Allows the model to:
- Generate draft audio tokens speculatively
- Accept/reject based on user feedback or context change
- Rollback and regenerate from checkpoints
- Smooth transitions during rewrites
"""
def __init__ (
self ,
hidden_size :int =1024 ,
draft_length :int =10 ,
num_heads :int =8 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .draft_length =draft_length
self .draft_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size ),
nn .SiLU (),
nn .Linear (hidden_size ,hidden_size ),
)
self .verify_head =nn .Sequential (
nn .Linear (hidden_size *2 ,hidden_size ),
nn .SiLU (),
nn .Linear (hidden_size ,1 ),
nn .Sigmoid (),
)
self .checkpoint_encoder =nn .GRU (
input_size =hidden_size ,
hidden_size =hidden_size ,
num_layers =1 ,
batch_first =True ,
)
self .smoother =nn .Sequential (
nn .Linear (hidden_size *2 ,hidden_size ),
nn .SiLU (),
nn .Linear (hidden_size ,hidden_size ),
)
self .confidence_head =nn .Sequential (
nn .Linear (hidden_size ,hidden_size //4 ),
nn .SiLU (),
nn .Linear (hidden_size //4 ,1 ),
nn .Sigmoid (),
)
print (f" ⚡ SpeculativeAudioDecoder: draft_length={draft_length }")
def generate_draft (
self ,
context :torch .Tensor ,
num_tokens :int =None ,
)->Tuple [torch .Tensor ,torch .Tensor ]:
"""
Generate draft tokens speculatively.
Args:
context: [B, T, hidden_size] context features
num_tokens: number of draft tokens (default: self.draft_length)
Returns:
draft_tokens: [B, N, hidden_size] draft features
confidence: [B, N, 1] confidence per token
"""
if num_tokens is None :
num_tokens =self .draft_length
batch_size =context .shape [0 ]
device =context .device
seed =context [:,-1 :,:]
draft_tokens =[]
confidences =[]
current =seed
for _ in range (num_tokens ):
draft =self .draft_head (current )
conf =self .confidence_head (draft )
draft_tokens .append (draft )
confidences .append (conf )
current =draft
draft_tokens =torch .cat (draft_tokens ,dim =1 )
confidences =torch .cat (confidences ,dim =1 )
return draft_tokens ,confidences
def verify_draft (
self ,
draft_tokens :torch .Tensor ,
new_context :torch .Tensor ,
)->torch .Tensor :
"""
Verify if draft tokens should be accepted given new context.
Args:
draft_tokens: [B, N, hidden_size] draft features
new_context: [B, T, hidden_size] updated context
Returns:
accept_prob: [B, N, 1] probability to accept each token
"""
context_summary =new_context .mean (dim =1 ,keepdim =True ).expand (-1 ,draft_tokens .shape [1 ],-1 )
combined =torch .cat ([draft_tokens ,context_summary ],dim =-1 )
accept_prob =self .verify_head (combined )
return accept_prob
def create_checkpoint (self ,hidden_state :torch .Tensor )->torch .Tensor :
"""Save hidden state for potential rollback."""
_ ,checkpoint =self .checkpoint_encoder (hidden_state )
return checkpoint .squeeze (0 )
def smooth_transition (
self ,
old_features :torch .Tensor ,
new_features :torch .Tensor ,
)->torch .Tensor :
"""Create smooth transition between old and new features."""
combined =torch .cat ([old_features ,new_features ],dim =-1 )
return self .smoother (combined )
def forward (
self ,
context :torch .Tensor ,
generate_draft :bool =True ,
verify_with :Optional [torch .Tensor ]=None ,
)->dict :
"""
Full speculative decoding step.
Args:
context: [B, T, hidden_size] current context
generate_draft: whether to generate new draft
verify_with: [B, T', hidden_size] new context to verify against
Returns:
dict with draft tokens, confidence, verification results
"""
results ={}
results ["checkpoint"]=self .create_checkpoint (context )
if generate_draft :
draft ,confidence =self .generate_draft (context )
results ["draft_tokens"]=draft
results ["confidence"]=confidence
if verify_with is not None and "draft_tokens"in results :
accept_prob =self .verify_draft (results ["draft_tokens"],verify_with )
results ["accept_prob"]=accept_prob
return results
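# Illustrative draft/verify cycle for SpeculativeAudioDecoder (random placeholder context;
# the acceptance threshold below is an example policy, not part of the module):
#
#   spec = SpeculativeAudioDecoder(hidden_size=1024, draft_length=10)
#   context = torch.randn(1, 40, 1024)                  # [B, T, hidden_size]
#   step = spec(context, generate_draft=True)
#   step["draft_tokens"].shape                           # torch.Size([1, 10, 1024])
#   step["confidence"].shape                             # torch.Size([1, 10, 1])
#   new_context = torch.randn(1, 45, 1024)               # updated context after user feedback
#   accept = spec.verify_draft(step["draft_tokens"], new_context)   # [1, 10, 1]
#   keep = accept > 0.5                                   # example acceptance threshold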
# ==============================================================================
# MODELS.GENERATORS.IMAGE
# ==============================================================================
EPS =1e-5
class RoPE2D (nn .Module ):
"""
2D Rotary Position Embedding for flexible aspect ratios.
Encodes (x, y) spatial positions for patch-based DiT.
"""
def __init__ (self ,dim :int ,max_height :int =128 ,max_width :int =128 ,base :float =10000.0 ):
super ().__init__ ()
self .dim =dim
self .max_height =max_height
self .max_width =max_width
self .base =base
self .dim_x =dim //2
self .dim_y =dim -self .dim_x
inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x ))
inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y ))
self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False )
self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False )
def forward (self ,x :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]:
device =x .device
dtype =x .dtype
pos_x =torch .arange (width ,device =device ,dtype =torch .float32 )
pos_y =torch .arange (height ,device =device ,dtype =torch .float32 )
freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device ))
freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device ))
freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 )
freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 )
cos_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype )
sin_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype )
for y in range (height ):
for w in range (width ):
cos_2d [y ,w ,:self .dim_x ]=freqs_x [w ].cos ().to (dtype )
sin_2d [y ,w ,:self .dim_x ]=freqs_x [w ].sin ().to (dtype )
cos_2d [y ,w ,self .dim_x :]=freqs_y [y ].cos ().to (dtype )
sin_2d [y ,w ,self .dim_x :]=freqs_y [y ].sin ().to (dtype )
cos_2d =cos_2d .view (height *width ,self .dim )
sin_2d =sin_2d .view (height *width ,self .dim )
return cos_2d ,sin_2d
def apply_rope_2d (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor :
x1 =x [...,:x .shape [-1 ]//2 ]
x2 =x [...,x .shape [-1 ]//2 :]
rotated =torch .cat ((-x2 ,x1 ),dim =-1 )
return x *cos +rotated *sin
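# Illustrative shape sketch for RoPE2D / apply_rope_2d (random placeholder queries; the
# broadcast pattern mirrors how DualStreamSelfAttention below uses these helpers):
#
#   rope = RoPE2D(dim=64, max_height=32, max_width=32)
#   q = torch.randn(1, 8, 16 * 16, 64)                  # [B, heads, H*W, head_dim]
#   cos, sin = rope(q, height=16, width=16)             # each [H*W, head_dim] = [256, 64]
#   cos, sin = cos.unsqueeze(0).unsqueeze(1), sin.unsqueeze(0).unsqueeze(1)
#   q_rot = apply_rope_2d(q, cos, sin)                   # same shape as q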
class ImageExpert (nn .Module ):
"""Single expert for DiT with SwiGLU activation."""
def __init__ (self ,hidden_size :int ,intermediate_size :int ):
super ().__init__ ()
self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False )
self .act_fn =nn .SiLU ()
def forward (self ,x :torch .Tensor )->torch .Tensor :
return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x ))
class ImageMoERouter (nn .Module ):
"""Router for Image MoE with spatial awareness."""
def __init__ (self ,hidden_size :int ,num_experts :int =4 ,top_k :int =2 ):
super ().__init__ ()
self .num_experts =num_experts
self .top_k =top_k
self .norm =nn .LayerNorm (hidden_size )
self .gate =nn .Linear (hidden_size ,num_experts ,bias =False )
nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 )
def forward (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]:
x_norm =self .norm (x )
router_logits =self .gate (x_norm )
router_probs =F .softmax (router_logits ,dim =-1 ,dtype =x .dtype )
top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 )
top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS )
return top_k_probs ,top_k_indices
class ImageMoELayer (nn .Module ):
"""MoE Layer for DiT with shared expert."""
def __init__ (self ,hidden_size :int ,intermediate_size :int ,num_experts :int =4 ,top_k :int =2 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_experts =num_experts
self .top_k =top_k
self .router =ImageMoERouter (hidden_size ,num_experts ,top_k )
self .experts =nn .ModuleList ([
ImageExpert (hidden_size ,intermediate_size )
for _ in range (num_experts )
])
self .shared_expert =ImageExpert (hidden_size ,intermediate_size )
def forward (self ,x :torch .Tensor )->torch .Tensor :
batch_size ,seq_len ,hidden_size =x .shape
x_flat =x .view (-1 ,hidden_size )
top_k_probs ,top_k_indices =self .router (x_flat )
output =torch .zeros_like (x_flat )
for expert_idx in range (self .num_experts ):
expert =self .experts [expert_idx ]
for k in range (self .top_k ):
mask =(top_k_indices [:,k ]==expert_idx )
if mask .any ():
expert_input =x_flat [mask ]
expert_output =expert (expert_input )
weight =top_k_probs [mask ,k :k +1 ]
output [mask ]=output [mask ]+weight *expert_output
shared_output =self .shared_expert (x_flat )
output =output +shared_output
return output .view (batch_size ,seq_len ,hidden_size )
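# Illustrative routing sketch for ImageMoELayer (random patch tokens; each token is
# dispatched to its top-2 experts and the shared expert output is always added):
#
#   moe = ImageMoELayer(hidden_size=512, intermediate_size=2048, num_experts=4, top_k=2)
#   patches = torch.randn(2, 256, 512)                  # [B, num_patches, hidden]
#   out = moe(patches)                                   # [2, 256, 512], same shape as input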
class DualStreamSelfAttention (nn .Module ):
"""
Symmetric Dual-Stream Self-Attention (SD3/Flux-style).
Two parallel streams with cross-stream information exchange.
Uses PyTorch SDPA (dispatching to FlashAttention / memory-efficient kernels when available) for O(N) memory.
"""
def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_height :int =64 ,max_width :int =64 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_heads =num_heads
self .head_dim =hidden_size //num_heads
self .scale =self .head_dim **-0.5
self ._qk_scale =self .head_dim **-0.25
self .to_qkv_a =nn .Linear (hidden_size ,hidden_size *3 ,bias =False )
self .to_qkv_b =nn .Linear (hidden_size ,hidden_size *3 ,bias =False )
self .to_out_a =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .to_out_b =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .norm_a =nn .LayerNorm (hidden_size )
self .norm_b =nn .LayerNorm (hidden_size )
self .rope_2d =RoPE2D (self .head_dim ,max_height ,max_width )
def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]:
batch_size ,seq_len ,_ =x_a .shape
x_a =self .norm_a (x_a )
x_b =self .norm_b (x_b )
qkv_a =self .to_qkv_a (x_a ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim )
qkv_b =self .to_qkv_b (x_b ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim )
q_a ,k_a ,v_a =qkv_a .unbind (dim =2 )
q_b ,k_b ,v_b =qkv_b .unbind (dim =2 )
cos ,sin =self .rope_2d (x_a ,height ,width )
cos =cos .unsqueeze (0 ).unsqueeze (1 )
sin =sin .unsqueeze (0 ).unsqueeze (1 )
q_a =q_a .transpose (1 ,2 )
k_a =k_a .transpose (1 ,2 )
v_a =v_a .transpose (1 ,2 )
q_b =q_b .transpose (1 ,2 )
k_b =k_b .transpose (1 ,2 )
v_b =v_b .transpose (1 ,2 )
q_a =apply_rope_2d (q_a ,cos ,sin )
k_a =apply_rope_2d (k_a ,cos ,sin )
q_b =apply_rope_2d (q_b ,cos ,sin )
k_b =apply_rope_2d (k_b ,cos ,sin )
k_combined =torch .cat ([k_a ,k_b ],dim =2 )
v_combined =torch .cat ([v_a ,v_b ],dim =2 )
out_a =F .scaled_dot_product_attention (
q_a *self ._qk_scale ,k_combined *self ._qk_scale ,v_combined ,
is_causal =False ,scale =1.0 ,
)
out_b =F .scaled_dot_product_attention (
q_b *self ._qk_scale ,k_combined *self ._qk_scale ,v_combined ,
is_causal =False ,scale =1.0 ,
)
out_a =out_a .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size )
out_b =out_b .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size )
out_a =self .to_out_a (out_a )
out_b =self .to_out_b (out_b )
return out_a ,out_b
class CrossAttention (nn .Module ):
"""Cross-attention for text conditioning."""
def __init__ (self ,query_dim :int ,context_dim :int =None ,heads :int =8 ):
super ().__init__ ()
self .heads =heads
context_dim =context_dim or query_dim
self .head_dim =query_dim //heads
self .scale =self .head_dim **-0.5
self .norm =nn .LayerNorm (query_dim )
self .to_q =nn .Linear (query_dim ,query_dim ,bias =False )
self .to_k =nn .Linear (context_dim ,query_dim ,bias =False )
self .to_v =nn .Linear (context_dim ,query_dim ,bias =False )
self .to_out =nn .Linear (query_dim ,query_dim ,bias =False )
def forward (self ,x :torch .Tensor ,context :torch .Tensor )->torch .Tensor :
batch_size ,seq_len ,_ =x .shape
ctx_len =context .shape [1 ]
x =self .norm (x )
q =self .to_q (x ).reshape (batch_size ,seq_len ,self .heads ,self .head_dim ).transpose (1 ,2 )
k =self .to_k (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 )
v =self .to_v (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 )
qk_scale =self .head_dim **-0.25
out =F .scaled_dot_product_attention (
q *qk_scale ,k *qk_scale ,v ,
is_causal =False ,scale =1.0 ,
)
out =out .transpose (1 ,2 ).reshape (batch_size ,seq_len ,-1 )
out =self .to_out (out )
return out
class DiTBlock (nn .Module ):
"""
DiT Block with Dual-Stream Attention and MoE FFN.
"""
def __init__ (self ,hidden_size :int ,context_dim :int ,num_heads :int =8 ,num_experts :int =4 ,max_height :int =64 ,max_width :int =64 ):
super ().__init__ ()
self .dual_attn =DualStreamSelfAttention (hidden_size ,num_heads ,max_height ,max_width )
self .cross_attn_a =CrossAttention (hidden_size ,context_dim ,num_heads )
self .cross_attn_b =CrossAttention (hidden_size ,context_dim ,num_heads )
self .moe_a =ImageMoELayer (hidden_size ,hidden_size *4 ,num_experts )
self .moe_b =ImageMoELayer (hidden_size ,hidden_size *4 ,num_experts )
self .adaLN_a =nn .Sequential (
nn .SiLU (),
nn .Linear (hidden_size ,hidden_size *6 ),
)
self .adaLN_b =nn .Sequential (
nn .SiLU (),
nn .Linear (hidden_size ,hidden_size *6 ),
)
self .norm1_a =nn .LayerNorm (hidden_size ,elementwise_affine =False )
self .norm1_b =nn .LayerNorm (hidden_size ,elementwise_affine =False )
self .norm2_a =nn .LayerNorm (hidden_size ,elementwise_affine =False )
self .norm2_b =nn .LayerNorm (hidden_size ,elementwise_affine =False )
def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,context :torch .Tensor ,t_emb :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]:
shift_a ,scale_a ,gate_a ,shift2_a ,scale2_a ,gate2_a =self .adaLN_a (t_emb ).chunk (6 ,dim =-1 )
shift_b ,scale_b ,gate_b ,shift2_b ,scale2_b ,gate2_b =self .adaLN_b (t_emb ).chunk (6 ,dim =-1 )
shift_a =shift_a .unsqueeze (1 )
scale_a =scale_a .unsqueeze (1 )
gate_a =gate_a .unsqueeze (1 )
shift2_a =shift2_a .unsqueeze (1 )
scale2_a =scale2_a .unsqueeze (1 )
gate2_a =gate2_a .unsqueeze (1 )
shift_b =shift_b .unsqueeze (1 )
scale_b =scale_b .unsqueeze (1 )
gate_b =gate_b .unsqueeze (1 )
shift2_b =shift2_b .unsqueeze (1 )
scale2_b =scale2_b .unsqueeze (1 )
gate2_b =gate2_b .unsqueeze (1 )
x_a_norm =self .norm1_a (x_a )*(1 +scale_a )+shift_a
x_b_norm =self .norm1_b (x_b )*(1 +scale_b )+shift_b
attn_out_a ,attn_out_b =self .dual_attn (x_a_norm ,x_b_norm ,height ,width )
x_a =x_a +gate_a *attn_out_a
x_b =x_b +gate_b *attn_out_b
x_a =x_a +self .cross_attn_a (x_a ,context )
x_b =x_b +self .cross_attn_b (x_b ,context )
x_a_norm =self .norm2_a (x_a )*(1 +scale2_a )+shift2_a
x_b_norm =self .norm2_b (x_b )*(1 +scale2_b )+shift2_b
x_a =x_a +gate2_a *self .moe_a (x_a_norm )
x_b =x_b +gate2_b *self .moe_b (x_b_norm )
return x_a ,x_b
class FlowMatchingScheduler :
"""Flow Matching scheduler for image generation."""
def __init__ (self ,num_steps :int =50 ,sigma_min :float =0.002 ):
self .num_steps =num_steps
self .sigma_min =sigma_min
self .timesteps =torch .linspace (1 ,0 ,num_steps +1 )
def get_velocity (self ,x_t :torch .Tensor ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor :
return x_0 -x_t
def step (self ,model_output :torch .Tensor ,t :torch .Tensor ,t_prev :torch .Tensor ,x_t :torch .Tensor )->torch .Tensor :
dt =t -t_prev
x_prev =x_t +model_output *dt .view (-1 ,1 ,1 ,1 )
return x_prev
def add_noise (self ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor :
noise =torch .randn_like (x_0 )
t =t .to (x_0 .dtype ).view (-1 ,1 ,1 ,1 )
x_t =t *noise +(1 -t )*x_0
return x_t
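# Illustrative flow-matching sketch (random latents; shows the linear noise path used for
# training targets and one Euler step of the sampling ODE):
#
#   sched = FlowMatchingScheduler(num_steps=50)
#   x0 = torch.randn(2, 4, 32, 32)                       # clean latents
#   t = torch.rand(2)                                    # t=1 is pure noise, t=0 is data
#   x_t = sched.add_noise(x0, t)                          # t * noise + (1 - t) * x0
#   v_target = sched.get_velocity(x_t, x0, t)             # x0 - x_t (regression target)
#   # one sampling step from t to t_prev < t with a predicted velocity v_pred:
#   # x_prev = sched.step(v_pred, t, t_prev, x_t)          i.e. x_t + v_pred * (t - t_prev)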
class PatchEmbed (nn .Module ):
"""Patch embedding for DiT."""
def __init__ (self ,patch_size :int =2 ,in_channels :int =4 ,hidden_size :int =512 ):
super ().__init__ ()
self .patch_size =patch_size
self .proj =nn .Conv2d (in_channels ,hidden_size ,kernel_size =patch_size ,stride =patch_size )
def forward (self ,x :torch .Tensor )->torch .Tensor :
x =self .proj (x )
x =x .flatten (2 ).transpose (1 ,2 )
return x
class UnpatchEmbed (nn .Module ):
"""Unpatch embedding to reconstruct image from patches."""
def __init__ (self ,patch_size :int =2 ,out_channels :int =4 ,hidden_size :int =512 ):
super ().__init__ ()
self .patch_size =patch_size
self .out_channels =out_channels
self .proj =nn .Linear (hidden_size ,patch_size *patch_size *out_channels )
def forward (self ,x :torch .Tensor ,height :int ,width :int )->torch .Tensor :
x =self .proj (x )
batch_size =x .shape [0 ]
x =x .reshape (batch_size ,height ,width ,self .patch_size ,self .patch_size ,self .out_channels )
x =x .permute (0 ,5 ,1 ,3 ,2 ,4 ).reshape (batch_size ,self .out_channels ,height *self .patch_size ,width *self .patch_size )
return x
class MoEDiT (nn .Module ):
"""
MoE Diffusion Transformer with Dual-Stream Attention.
"""
def __init__ (
self ,
in_channels :int =4 ,
out_channels :int =4 ,
hidden_size :int =512 ,
context_dim :int =1024 ,
num_layers :int =8 ,
num_heads :int =8 ,
num_experts :int =4 ,
patch_size :int =2 ,
max_image_size :int =64 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .patch_size =patch_size
max_patches =max_image_size //patch_size
self .time_embed =nn .Sequential (
nn .Linear (hidden_size ,hidden_size *4 ),
nn .SiLU (),
nn .Linear (hidden_size *4 ,hidden_size ),
)
self .patch_embed =PatchEmbed (patch_size ,in_channels ,hidden_size )
self .context_proj =nn .Linear (context_dim ,hidden_size )
self .blocks =nn .ModuleList ([
DiTBlock (hidden_size ,hidden_size ,num_heads ,num_experts ,max_patches ,max_patches )
for _ in range (num_layers )
])
self .final_norm =nn .LayerNorm (hidden_size )
self .unpatch_embed =UnpatchEmbed (patch_size ,out_channels ,hidden_size )
self .gradient_checkpointing =False
self ._init_weights ()
def _init_weights (self ):
nn .init .zeros_ (self .unpatch_embed .proj .weight )
nn .init .zeros_ (self .unpatch_embed .proj .bias )
def enable_gradient_checkpointing (self ):
"""Enable gradient checkpointing for memory efficiency."""
self .gradient_checkpointing =True
def forward (self ,x :torch .Tensor ,timesteps :torch .Tensor ,context :torch .Tensor ,mask :Optional [torch .Tensor ]=None )->torch .Tensor :
batch_size ,channels ,height ,width =x .shape
patch_height =height //self .patch_size
patch_width =width //self .patch_size
half_dim =self .hidden_size //2
t_emb =math .log (10000 )/(half_dim -1 )
t_emb =torch .exp (torch .arange (half_dim ,device =x .device ,dtype =x .dtype )*-t_emb )
t_emb =timesteps [:,None ].to (x .dtype )*t_emb [None ,:]
t_emb =torch .cat ([torch .sin (t_emb ),torch .cos (t_emb )],dim =-1 )
t_emb =self .time_embed (t_emb )
x_patches =self .patch_embed (x )
context_proj =self .context_proj (context )
x_a =x_patches
x_b =x_patches .clone ()
for block in self .blocks :
if self .gradient_checkpointing and self .training :
x_a ,x_b =torch .utils .checkpoint .checkpoint (
block ,x_a ,x_b ,context_proj ,t_emb ,patch_height ,patch_width ,
use_reentrant =False
)
else :
x_a ,x_b =block (x_a ,x_b ,context_proj ,t_emb ,patch_height ,patch_width )
x_combined =(x_a +x_b )/2
x_combined =self .final_norm (x_combined )
velocity =self .unpatch_embed (x_combined ,patch_height ,patch_width )
return velocity
class ImageVAE (nn .Module ):
"""Lightweight VAE for image encoding/decoding."""
def __init__ (self ,in_channels :int =3 ,latent_channels :int =4 ,base_channels :int =64 ):
super ().__init__ ()
self .encoder =nn .Sequential (
nn .Conv2d (in_channels ,base_channels ,3 ,padding =1 ),
nn .SiLU (),
nn .Conv2d (base_channels ,base_channels *2 ,3 ,stride =2 ,padding =1 ),
nn .SiLU (),
nn .Conv2d (base_channels *2 ,base_channels *4 ,3 ,stride =2 ,padding =1 ),
nn .SiLU (),
nn .Conv2d (base_channels *4 ,latent_channels *2 ,3 ,padding =1 ),
)
self .decoder =nn .Sequential (
nn .Conv2d (latent_channels ,base_channels *4 ,3 ,padding =1 ),
nn .SiLU (),
nn .Upsample (scale_factor =2 ,mode ='bilinear',align_corners =False ),
nn .Conv2d (base_channels *4 ,base_channels *2 ,3 ,padding =1 ),
nn .SiLU (),
nn .Upsample (scale_factor =2 ,mode ='bilinear',align_corners =False ),
nn .Conv2d (base_channels *2 ,base_channels ,3 ,padding =1 ),
nn .SiLU (),
nn .Conv2d (base_channels ,in_channels ,3 ,padding =1 ),
)
def encode (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]:
h =self .encoder (x )
mean ,logvar =h .chunk (2 ,dim =1 )
logvar =torch .clamp (logvar ,-30 ,20 )
std =torch .exp (0.5 *logvar )
z =mean +std *torch .randn_like (std )
return z ,mean ,logvar
def decode (self ,z :torch .Tensor )->torch .Tensor :
return self .decoder (z )
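# Illustrative round-trip sketch for ImageVAE (random image batch; the encoder's two
# stride-2 convolutions give a 4x spatial compression into the latent space):
#
#   vae = ImageVAE(in_channels=3, latent_channels=4, base_channels=64)
#   imgs = torch.randn(2, 3, 256, 256)
#   z, mean, logvar = vae.encode(imgs)                   # z: [2, 4, 64, 64]
#   recon = vae.decode(z)                                 # [2, 3, 256, 256]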
class MobileDiffusionGenerator (nn .Module ):
"""
SOTA Image Diffusion with MoE-DiT, Flow Matching, 2D-RoPE, Dual-Stream.
Optimized for 2x T4 GPUs with FP16.
"""
def __init__ (
self ,
latent_channels :int =4 ,
base_channels :int =128 ,
context_dim :int =1024 ,
num_inference_steps :int =50 ,
image_size :int =256 ,
cfg_scale :float =7.5 ,
):
super ().__init__ ()
self .latent_channels =latent_channels
self .context_dim =context_dim
self .image_size =image_size
self .latent_size =image_size //4
self .num_inference_steps =num_inference_steps
self .cfg_scale =cfg_scale
self .vae_encoder =ImageVAE (3 ,latent_channels ,base_channels //2 )
self .vae_decoder =self .vae_encoder
self .unet =MoEDiT (
in_channels =latent_channels ,
out_channels =latent_channels ,
hidden_size =base_channels *4 ,
context_dim =context_dim ,
num_layers =8 ,
num_heads =8 ,
num_experts =4 ,
patch_size =2 ,
max_image_size =self .latent_size ,
)
self .scheduler =FlowMatchingScheduler (num_inference_steps )
def encode (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]:
return self .vae_encoder .encode (x )
def decode (self ,z :torch .Tensor )->torch .Tensor :
return self .vae_decoder .decode (z )
def training_step (self ,images :torch .Tensor ,context :torch .Tensor ,mask :Optional [torch .Tensor ]=None )->dict :
device =images .device
dtype =images .dtype
batch_size =images .shape [0 ]
z ,mean ,logvar =self .encode (images *2 -1 )
del images
t =torch .rand (batch_size ,device =device ,dtype =dtype )
x_t =self .scheduler .add_noise (z ,t )
target_velocity =self .scheduler .get_velocity (x_t ,z ,t )
if self .training :
drop_mask =torch .rand (batch_size ,device =device )<0.1
drop_mask_expanded =drop_mask .view (batch_size ,1 ,1 ).expand_as (context )
null_ctx =torch .zeros_like (context )
context =torch .where (drop_mask_expanded ,null_ctx ,context )
del drop_mask ,drop_mask_expanded ,null_ctx
pred_velocity =self .unet (x_t ,(t *1000 ).to (dtype ),context ,mask )
del x_t ,context
flow_loss =F .mse_loss (pred_velocity ,target_velocity )
del pred_velocity ,target_velocity
kl_loss =-0.5 *torch .mean (1 +logvar -mean .pow (2 )-logvar .exp ())
del z ,mean ,logvar
total_loss =flow_loss +0.0001 *kl_loss
return {
'flow_loss':flow_loss ,
'kl_loss':kl_loss ,
'total_loss':total_loss ,
}
@torch .no_grad ()
def generate (self ,context :torch .Tensor ,guidance_scale :float =None ,num_steps :int =None ,init_latents :Optional [torch .Tensor ]=None ,mask :Optional [torch .Tensor ]=None ,masked_image_latents :Optional [torch .Tensor ]=None )->torch .Tensor :
device =context .device
batch_size =context .shape [0 ]
seq_len =context .shape [1 ]
guidance_scale =guidance_scale or self .cfg_scale
num_steps =num_steps or self .num_inference_steps
if init_latents is not None :
latents =init_latents
else :
latents =torch .randn (batch_size ,self .latent_channels ,self .latent_size ,self .latent_size ,device =device )
timesteps =torch .linspace (1 ,0 ,num_steps +1 ,device =device )
if guidance_scale >1.0 :
null_ctx =torch .zeros (batch_size ,seq_len ,self .context_dim ,device =device ,dtype =context .dtype )
context =torch .cat ([null_ctx ,context ])
for i in range (num_steps ):
t =timesteps [i ]
t_prev =timesteps [i +1 ]
t_batch =t .expand (batch_size )*1000
if guidance_scale >1.0 :
latent_input =torch .cat ([latents ,latents ])
t_input =torch .cat ([t_batch ,t_batch ])
velocity_pred =self .unet (latent_input ,t_input ,context ,mask )
velocity_uncond ,velocity_cond =velocity_pred .chunk (2 )
velocity_pred =velocity_uncond +guidance_scale *(velocity_cond -velocity_uncond )
else :
velocity_pred =self .unet (latents ,t_batch ,context ,mask )
latents =self .scheduler .step (velocity_pred ,t ,t_prev ,latents )
if mask is not None and masked_image_latents is not None :
latents =masked_image_latents *mask +latents *(1 -mask )
images =self .decode (latents )
images =(images +1 )/2
return torch .clamp (images ,0 ,1 )
@torch .no_grad ()
def edit_image (self ,image :torch .Tensor ,context :torch .Tensor ,mask :torch .Tensor ,strength :float =0.8 ,guidance_scale :float =None )->torch .Tensor :
device =image .device
image_norm =image *2 -1
z ,_ ,_ =self .encode (image_norm )
mask_latent =F .interpolate (mask ,size =(self .latent_size ,self .latent_size ),mode ='nearest')
num_steps =int (self .num_inference_steps *strength )
t =torch .tensor ([strength ],device =device )
noisy_z =self .scheduler .add_noise (z ,t .expand (z .shape [0 ]))
return self .generate (
context ,
guidance_scale =guidance_scale ,
num_steps =num_steps ,
init_latents =noisy_z ,
mask =mask_latent ,
masked_image_latents =z ,
)
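# Illustrative text-to-image sketch for MobileDiffusionGenerator (random conditioning; the
# 77-token context length is only an example, any context length works):
#
#   gen = MobileDiffusionGenerator(context_dim=1024, image_size=256, num_inference_steps=50)
#   text_ctx = torch.randn(1, 77, 1024)                  # [B, ctx_len, context_dim]
#   images = gen.generate(text_ctx, guidance_scale=7.5)  # [1, 3, 256, 256], values in [0, 1]
#   # training instead computes a flow-matching + KL objective:
#   # losses = gen.training_step(torch.rand(1, 3, 256, 256), text_ctx)
#   # losses["total_loss"].backward()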
# ==============================================================================
# MODELS.GENERATORS.VIDEO
# ==============================================================================
EPS =1e-5
class InterleavedMRoPE (nn .Module ):
"""
Interleaved Multi-dimensional Rotary Position Embedding (MRoPE).
SOTA: allocates rotary frequencies jointly over time, height, and width for robust 3D positional encoding.
Unlike separate spatial and temporal RoPE, Interleaved-MRoPE allocates frequencies across
all three dimensions jointly, enhancing long-horizon video reasoning.
Key advantages:
- Better temporal-spatial correlation modeling
- More robust for variable aspect ratios and frame counts
- Improved long-range video understanding
"""
def __init__ (self ,dim :int ,max_height :int =64 ,max_width :int =64 ,max_frames :int =64 ,base :float =10000.0 ):
super ().__init__ ()
self .dim =dim
self .max_height =max_height
self .max_width =max_width
self .max_frames =max_frames
self .base =base
self .dim_t =dim //3
self .dim_y =dim //3
self .dim_x =dim -self .dim_t -self .dim_y
inv_freq_t =1.0 /(base **(torch .arange (0 ,self .dim_t ,2 ,dtype =torch .float32 )/self .dim_t ))
inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y ))
inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x ))
self .register_buffer ('inv_freq_t',inv_freq_t ,persistent =False )
self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False )
self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False )
def forward (self ,x :torch .Tensor ,height :int ,width :int ,num_frames :int )->Tuple [torch .Tensor ,torch .Tensor ]:
"""
Compute interleaved 3D positional embeddings.
Args:
x: Input tensor for device/dtype reference
height: Spatial height
width: Spatial width
num_frames: Temporal frames
Returns:
cos, sin: [T * H * W, dim] positional embeddings
"""
device =x .device
dtype =x .dtype
pos_t =torch .arange (num_frames ,device =device ,dtype =torch .float32 )
pos_y =torch .arange (height ,device =device ,dtype =torch .float32 )
pos_x =torch .arange (width ,device =device ,dtype =torch .float32 )
freqs_t =torch .outer (pos_t ,self .inv_freq_t .to (device ))
freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device ))
freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device ))
freqs_t =torch .cat ([freqs_t ,freqs_t ],dim =-1 )
freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 )
freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 )
seq_len =num_frames *height *width
cos_3d =torch .zeros (num_frames ,height ,width ,self .dim ,device =device ,dtype =dtype )
sin_3d =torch .zeros (num_frames ,height ,width ,self .dim ,device =device ,dtype =dtype )
for t in range (num_frames ):
for h in range (height ):
for w in range (width ):
cos_3d [t ,h ,w ,:self .dim_t ]=freqs_t [t ].cos ().to (dtype )
sin_3d [t ,h ,w ,:self .dim_t ]=freqs_t [t ].sin ().to (dtype )
cos_3d [t ,h ,w ,self .dim_t :self .dim_t +self .dim_y ]=freqs_y [h ].cos ().to (dtype )
sin_3d [t ,h ,w ,self .dim_t :self .dim_t +self .dim_y ]=freqs_y [h ].sin ().to (dtype )
cos_3d [t ,h ,w ,self .dim_t +self .dim_y :]=freqs_x [w ].cos ().to (dtype )
sin_3d [t ,h ,w ,self .dim_t +self .dim_y :]=freqs_x [w ].sin ().to (dtype )
cos_3d =cos_3d .view (seq_len ,self .dim )
sin_3d =sin_3d .view (seq_len ,self .dim )
return cos_3d ,sin_3d
class RoPE2D (nn .Module ):
"""
2D Rotary Position Embedding for spatial dimensions (memory efficient).
Used for spatial attention in factorized video attention.
"""
def __init__ (self ,dim :int ,max_height :int =64 ,max_width :int =64 ,base :float =10000.0 ):
super ().__init__ ()
self .dim =dim
self .dim_x =dim //2
self .dim_y =dim -self .dim_x
inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x ))
inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y ))
self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False )
self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False )
def forward (self ,x :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]:
device =x .device
dtype =x .dtype
pos_x =torch .arange (width ,device =device ,dtype =torch .float32 )
pos_y =torch .arange (height ,device =device ,dtype =torch .float32 )
freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device ))
freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device ))
cos_x =torch .cat ([freqs_x .cos (),freqs_x .cos ()],dim =-1 )
sin_x =torch .cat ([freqs_x .sin (),freqs_x .sin ()],dim =-1 )
cos_y =torch .cat ([freqs_y .cos (),freqs_y .cos ()],dim =-1 )
sin_y =torch .cat ([freqs_y .sin (),freqs_y .sin ()],dim =-1 )
cos_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype )
sin_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype )
cos_2d [:,:,:self .dim_x ]=cos_x .unsqueeze (0 ).expand (height ,-1 ,-1 )
sin_2d [:,:,:self .dim_x ]=sin_x .unsqueeze (0 ).expand (height ,-1 ,-1 )
cos_2d [:,:,self .dim_x :]=cos_y .unsqueeze (1 ).expand (-1 ,width ,-1 )
sin_2d [:,:,self .dim_x :]=sin_y .unsqueeze (1 ).expand (-1 ,width ,-1 )
return cos_2d .view (height *width ,self .dim ).to (dtype ),sin_2d .view (height *width ,self .dim ).to (dtype )
class RoPE1D (nn .Module ):
"""
1D Rotary Position Embedding for temporal dimension.
Used for temporal attention in factorized video attention.
"""
def __init__ (self ,dim :int ,max_len :int =64 ,base :float =10000.0 ):
super ().__init__ ()
self .dim =dim
inv_freq =1.0 /(base **(torch .arange (0 ,dim ,2 ,dtype =torch .float32 )/dim ))
self .register_buffer ('inv_freq',inv_freq ,persistent =False )
def forward (self ,x :torch .Tensor ,seq_len :int )->Tuple [torch .Tensor ,torch .Tensor ]:
device =x .device
dtype =x .dtype
pos =torch .arange (seq_len ,device =device ,dtype =torch .float32 )
freqs =torch .outer (pos ,self .inv_freq .to (device ))
freqs =torch .cat ([freqs ,freqs ],dim =-1 )
return freqs .cos ().to (dtype ),freqs .sin ().to (dtype )
def apply_rope (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor :
"""Apply rotary position embedding."""
x1 =x [...,:x .shape [-1 ]//2 ]
x2 =x [...,x .shape [-1 ]//2 :]
rotated =torch .cat ((-x2 ,x1 ),dim =-1 )
return x *cos +rotated *sin
class TemporalExpertRouter (nn .Module ):
"""
Temporal-Aware Expert Router for video generation.
Routes tokens based on temporal context and motion patterns.
"""
def __init__ (self ,hidden_size :int ,num_experts :int =4 ,top_k :int =2 ):
super ().__init__ ()
self .num_experts =num_experts
self .top_k =top_k
self .temporal_proj =nn .Linear (hidden_size ,hidden_size )
self .gate =nn .Linear (hidden_size ,num_experts ,bias =False )
nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 )
def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->Tuple [torch .Tensor ,torch .Tensor ]:
if temporal_context is not None :
x =x +self .temporal_proj (temporal_context )
router_logits =self .gate (x )
router_probs =F .softmax (router_logits ,dim =-1 ,dtype =x .dtype )
top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 )
top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS )
return top_k_probs ,top_k_indices
class VideoExpert (nn .Module ):
"""Single expert for video processing with SwiGLU."""
def __init__ (self ,hidden_size :int ,intermediate_size :int ):
super ().__init__ ()
self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False )
self .act_fn =nn .SiLU ()
def forward (self ,x :torch .Tensor )->torch .Tensor :
return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x ))
class TemporalMoELayer (nn .Module ):
"""
Temporal-Aware MoE Layer for video generation.
Uses motion-aware routing for expert selection.
"""
def __init__ (self ,hidden_size :int ,intermediate_size :int ,num_experts :int =4 ,top_k :int =2 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_experts =num_experts
self .top_k =top_k
self .router =TemporalExpertRouter (hidden_size ,num_experts ,top_k )
self .experts =nn .ModuleList ([
VideoExpert (hidden_size ,intermediate_size )
for _ in range (num_experts )
])
self .shared_expert =VideoExpert (hidden_size ,intermediate_size )
def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->torch .Tensor :
batch_size ,seq_len ,hidden_size =x .shape
x_flat =x .view (-1 ,hidden_size )
top_k_probs ,top_k_indices =self .router (x_flat ,temporal_context .view (-1 ,hidden_size )if temporal_context is not None else None )
output =torch .zeros_like (x_flat )
for expert_idx in range (self .num_experts ):
expert =self .experts [expert_idx ]
for k in range (self .top_k ):
mask =(top_k_indices [:,k ]==expert_idx )
if mask .any ():
expert_input =x_flat [mask ]
expert_output =expert (expert_input )
weight =top_k_probs [mask ,k :k +1 ]
output [mask ]=output [mask ]+weight *expert_output
shared_output =self .shared_expert (x_flat )
output =output +shared_output
return output .view (batch_size ,seq_len ,hidden_size )
class SpatialAttention (nn .Module ):
"""
Spatial self-attention: each frame attends only within itself.
Memory: O(T * (H*W)^2) instead of O((T*H*W)^2)
"""
def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_height :int =64 ,max_width :int =64 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_heads =num_heads
self .head_dim =hidden_size //num_heads
self .scale =self .head_dim **-0.5
self .to_qkv =nn .Linear (hidden_size ,hidden_size *3 ,bias =False )
self .to_out =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .rope_2d =RoPE2D (self .head_dim ,max_height ,max_width )
self .norm =nn .LayerNorm (hidden_size )
def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int )->torch .Tensor :
batch_size ,seq_len ,_ =x .shape
spatial_len =height *width
x =self .norm (x )
x =x .view (batch_size *frames ,spatial_len ,self .hidden_size )
qkv =self .to_qkv (x ).reshape (batch_size *frames ,spatial_len ,3 ,self .num_heads ,self .head_dim )
q ,k ,v =qkv .unbind (dim =2 )
cos ,sin =self .rope_2d (x ,height ,width )
cos =cos .unsqueeze (0 ).unsqueeze (1 )
sin =sin .unsqueeze (0 ).unsqueeze (1 )
q =q .transpose (1 ,2 )
k =k .transpose (1 ,2 )
v =v .transpose (1 ,2 )
q =apply_rope (q ,cos ,sin )
k =apply_rope (k ,cos ,sin )
qk_scale =self .head_dim **-0.25
out =F .scaled_dot_product_attention (
q *qk_scale ,k *qk_scale ,v ,
is_causal =False ,scale =1.0 ,
)
out =out .transpose (1 ,2 ).reshape (batch_size *frames ,spatial_len ,self .hidden_size )
out =self .to_out (out )
return out .view (batch_size ,seq_len ,self .hidden_size )
class TemporalAttention (nn .Module ):
"""
Temporal self-attention: each spatial position attends across time.
Memory: O(H*W * T^2) instead of O((T*H*W)^2)
"""
def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_frames :int =32 ):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_heads =num_heads
self .head_dim =hidden_size //num_heads
self .scale =self .head_dim **-0.5
self .to_qkv =nn .Linear (hidden_size ,hidden_size *3 ,bias =False )
self .to_out =nn .Linear (hidden_size ,hidden_size ,bias =False )
self .rope_1d =RoPE1D (self .head_dim ,max_frames )
self .norm =nn .LayerNorm (hidden_size )
def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =True )->torch .Tensor :
batch_size ,seq_len ,_ =x .shape
spatial_len =height *width
x =self .norm (x )
x =x .view (batch_size ,frames ,spatial_len ,self .hidden_size )
x =x .permute (0 ,2 ,1 ,3 ).reshape (batch_size *spatial_len ,frames ,self .hidden_size )
qkv =self .to_qkv (x ).reshape (batch_size *spatial_len ,frames ,3 ,self .num_heads ,self .head_dim )
q ,k ,v =qkv .unbind (dim =2 )
cos ,sin =self .rope_1d (x ,frames )
cos =cos .unsqueeze (0 ).unsqueeze (1 )
sin =sin .unsqueeze (0 ).unsqueeze (1 )
q =q .transpose (1 ,2 )
k =k .transpose (1 ,2 )
v =v .transpose (1 ,2 )
q =apply_rope (q ,cos ,sin )
k =apply_rope (k ,cos ,sin )
qk_scale =self .head_dim **-0.25
out =F .scaled_dot_product_attention (
q *qk_scale ,k *qk_scale ,v ,
is_causal =causal ,scale =1.0 ,
)
out =out .transpose (1 ,2 ).reshape (batch_size *spatial_len ,frames ,self .hidden_size )
out =out .view (batch_size ,spatial_len ,frames ,self .hidden_size )
out =out .permute (0 ,2 ,1 ,3 ).reshape (batch_size ,seq_len ,self .hidden_size )
out =self .to_out (out )
return out
class FactorizedSpatioTemporalAttention (nn .Module ):
"""
Factorized Spatial-Temporal Attention (like CogVideo, Open-Sora, SVD).
Instead of full 3D attention O((T*H*W)^2), uses:
1. Spatial attention per frame: O(T * (H*W)^2)
2. Temporal attention per position: O(H*W * T^2)
Total: O(T*(H*W)^2 + H*W*T^2) << O((T*H*W)^2)
For T=8, H=W=64:
- Full 3D: 32768^2 ≈ 1.07B attention scores
- Factorized: 8*4096^2 + 4096*8^2 ≈ 134.5M attention scores (~8x fewer)
"""
def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_frames :int =32 ,max_height :int =64 ,max_width :int =64 ):
super ().__init__ ()
self .spatial_attn =SpatialAttention (hidden_size ,num_heads ,max_height ,max_width )
self .temporal_attn =TemporalAttention (hidden_size ,num_heads ,max_frames )
def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =True )->torch .Tensor :
x =x +self .spatial_attn (x ,height ,width ,frames )
x =x +self .temporal_attn (x ,height ,width ,frames ,causal )
return x
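# Illustrative usage sketch for FactorizedSpatioTemporalAttention (random tokens; the video
# is flattened to [B, T*H*W, hidden] and frame/height/width are passed explicitly):
#
#   attn = FactorizedSpatioTemporalAttention(hidden_size=256, num_heads=8, max_frames=8)
#   frames, height, width = 4, 16, 16
#   x = torch.randn(1, frames * height * width, 256)
#   y = attn(x, height=height, width=width, frames=frames, causal=True)   # same shape as x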
class CrossAttention3D (nn .Module ):
"""Cross-attention for text-to-video conditioning."""
def __init__ (self ,query_dim :int ,context_dim :int =None ,heads :int =8 ):
super ().__init__ ()
self .heads =heads
context_dim =context_dim or query_dim
self .head_dim =query_dim //heads
self .scale =self .head_dim **-0.5
self .norm =nn .LayerNorm (query_dim )
self .to_q =nn .Linear (query_dim ,query_dim ,bias =False )
self .to_k =nn .Linear (context_dim ,query_dim ,bias =False )
self .to_v =nn .Linear (context_dim ,query_dim ,bias =False )
self .to_out =nn .Linear (query_dim ,query_dim ,bias =False )
def forward (self ,x :torch .Tensor ,context :torch .Tensor )->torch .Tensor :
batch_size ,seq_len ,_ =x .shape
ctx_len =context .shape [1 ]
x =self .norm (x )
q =self .to_q (x ).reshape (batch_size ,seq_len ,self .heads ,self .head_dim ).transpose (1 ,2 )
k =self .to_k (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 )
v =self .to_v (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 )
qk_scale =self .head_dim **-0.25
out =F .scaled_dot_product_attention (
q *qk_scale ,k *qk_scale ,v ,
is_causal =False ,scale =1.0 ,
)
out =out .transpose (1 ,2 ).reshape (batch_size ,seq_len ,-1 )
out =self .to_out (out )
return out
class Causal3DTransformerBlock (nn .Module ):
"""
3D Causal Transformer Block with Factorized Spatial-Temporal Attention.
Uses memory-efficient factorized attention instead of full 3D attention:
- Spatial: Each frame attends within itself O(T * (H*W)^2)
- Temporal: Each position attends across frames O(H*W * T^2)
This reduces memory from O((T*H*W)^2) to O(T*(H*W)^2 + H*W*T^2)
"""
def __init__ (self ,hidden_size :int ,context_dim :int ,num_heads :int =8 ,num_experts :int =4 ,max_frames :int =32 ,max_height :int =64 ,max_width :int =64 ):
super ().__init__ ()
self .self_attn =FactorizedSpatioTemporalAttention (hidden_size ,num_heads ,max_frames ,max_height ,max_width )
self .cross_attn =CrossAttention3D (hidden_size ,context_dim ,num_heads )
self .moe =TemporalMoELayer (hidden_size ,hidden_size *4 ,num_experts )
self .norm1 =nn .LayerNorm (hidden_size )
self .norm2 =nn .LayerNorm (hidden_size )
self .norm3 =nn .LayerNorm (hidden_size )
def forward (self ,x :torch .Tensor ,context :torch .Tensor ,height :int ,width :int ,frames :int ,temporal_context :Optional [torch .Tensor ]=None )->torch .Tensor :
x =self .self_attn (self .norm1 (x ),height ,width ,frames ,causal =True )
x =x +self .cross_attn (self .norm2 (x ),context )
x =x +self .moe (self .norm3 (x ),temporal_context )
return x
class FlowMatchingScheduler :
"""
Flow Matching scheduler for video generation.
Uses optimal transport paths for superior generation quality.
"""
def __init__ (self ,num_steps :int =50 ,sigma_min :float =0.002 ):
self .num_steps =num_steps
self .sigma_min =sigma_min
self .timesteps =torch .linspace (1 ,0 ,num_steps +1 )
def get_velocity (self ,x_t :torch .Tensor ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor :
"""Compute target velocity for flow matching."""
return x_0 -x_t
def step (self ,model_output :torch .Tensor ,t :torch .Tensor ,t_prev :torch .Tensor ,x_t :torch .Tensor )->torch .Tensor :
"""Single step of flow matching ODE."""
dt =t -t_prev
x_prev =x_t +model_output *dt .view (-1 ,1 ,1 ,1 ,1 )
return x_prev
def add_noise (self ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor :
"""Add noise for training (linear interpolation)."""
noise =torch .randn_like (x_0 )
t =t .to (x_0 .dtype ).view (-1 ,1 ,1 ,1 ,1 )
x_t =t *noise +(1 -t )*x_0
return x_t
class VideoUNet3D (nn .Module ):
"""
3D U-Net for video generation with Factorized Spatial-Temporal Attention.
Uses memory-efficient factorized attention that processes spatial and temporal
dimensions separately, reducing memory from O((T*H*W)^2) to O(T*(H*W)^2 + H*W*T^2).
"""
def __init__ (
self ,
in_channels :int =4 ,
out_channels :int =4 ,
hidden_size :int =512 ,
context_dim :int =1024 ,
num_layers :int =4 ,
num_heads :int =8 ,
num_experts :int =4 ,
num_frames :int =16 ,
max_height :int =64 ,
max_width :int =64 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_frames =num_frames
self .time_embed =nn .Sequential (
nn .Linear (hidden_size ,hidden_size *4 ),
nn .SiLU (),
nn .Linear (hidden_size *4 ,hidden_size ),
)
self .input_proj =nn .Conv3d (in_channels ,hidden_size ,kernel_size =3 ,padding =1 )
self .transformer_blocks =nn .ModuleList ([
Causal3DTransformerBlock (hidden_size ,context_dim ,num_heads ,num_experts ,num_frames ,max_height ,max_width )
for _ in range (num_layers )
])
self .output_proj =nn .Sequential (
nn .GroupNorm (32 ,hidden_size ),
nn .SiLU (),
nn .Conv3d (hidden_size ,out_channels ,kernel_size =3 ,padding =1 ),
)
nn .init .zeros_ (self .output_proj [-1 ].weight )
nn .init .zeros_ (self .output_proj [-1 ].bias )
self .gradient_checkpointing =False
def enable_gradient_checkpointing (self ):
"""Enable gradient checkpointing for memory efficiency."""
self .gradient_checkpointing =True
def forward (self ,x :torch .Tensor ,timesteps :torch .Tensor ,context :torch .Tensor ,first_frame_latent :Optional [torch .Tensor ]=None )->torch .Tensor :
batch_size ,channels ,frames ,height ,width =x .shape
half_dim =self .hidden_size //2
t_emb =math .log (10000 )/(half_dim -1 )
t_emb =torch .exp (torch .arange (half_dim ,device =x .device ,dtype =x .dtype )*-t_emb )
t_emb =timesteps [:,None ].to (x .dtype )*t_emb [None ,:]
t_emb =torch .cat ([torch .sin (t_emb ),torch .cos (t_emb )],dim =-1 )
t_emb =self .time_embed (t_emb )
h =self .input_proj (x )
h =h .permute (0 ,2 ,3 ,4 ,1 ).reshape (batch_size ,frames *height *width ,self .hidden_size )
temporal_context =t_emb .unsqueeze (1 ).expand (-1 ,frames *height *width ,-1 )
for block in self .transformer_blocks :
if self .gradient_checkpointing and self .training :
h =torch .utils .checkpoint .checkpoint (
block ,h ,context ,height ,width ,frames ,temporal_context ,
use_reentrant =False
)
else :
h =block (h ,context ,height ,width ,frames ,temporal_context )
h =h .reshape (batch_size ,frames ,height ,width ,self .hidden_size ).permute (0 ,4 ,1 ,2 ,3 )
velocity =self .output_proj (h )
return velocity
class VideoVAE3D (nn .Module ):
"""
3D VAE for video encoding/decoding using VidTok architecture.
This replaces the simple placeholder with proper temporal+spatial compression
following Microsoft's VidTok architecture for high-quality video tokenization.
Features:
- Proper temporal compression (4x default)
- Proper spatial compression (8x default, same as image VAE)
- AlphaBlender for temporal blending
- Causal mode support for streaming
- Both KL (continuous) and FSQ (discrete) tokenization
Compression: [B, C, T, H, W] -> [B, latent_ch, T/4, H/8, W/8]
"""
def __init__ (
self ,
in_channels :int =3 ,
latent_channels :int =4 ,
base_channels :int =64 ,
temporal_compression :int =4 ,
spatial_compression :int =8 ,
causal :bool =True ,
use_fsq :bool =False ,
):
super ().__init__ ()
self .in_channels =in_channels
self .latent_channels =latent_channels
self .temporal_compression =temporal_compression
self .spatial_compression =spatial_compression
self .causal =causal
self .use_fsq =use_fsq
self .temporal_stages =int (math .log2 (temporal_compression ))
self .spatial_stages =int (math .log2 (spatial_compression ))
encoder_layers =[]
ch_in =in_channels
ch_out =base_channels
encoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,padding =1 ))
encoder_layers .append (nn .SiLU ())
for i in range (self .spatial_stages -self .temporal_stages ):
ch_in =ch_out
ch_out =min (ch_out *2 ,base_channels *8 )
encoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,stride =(1 ,2 ,2 ),padding =1 ))
encoder_layers .append (nn .SiLU ())
for i in range (self .temporal_stages ):
ch_in =ch_out
ch_out =min (ch_out *2 ,base_channels *8 )
encoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,stride =(2 ,2 ,2 ),padding =1 ))
encoder_layers .append (nn .SiLU ())
out_ch =latent_channels *2 if not use_fsq else latent_channels
encoder_layers .append (nn .Conv3d (ch_out ,out_ch ,3 ,padding =1 ))
self .encoder =nn .Sequential (*encoder_layers )
decoder_layers =[]
ch_in =latent_channels
ch_out =base_channels *(2 **min (self .spatial_stages ,3 ))
decoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,padding =1 ))
decoder_layers .append (nn .SiLU ())
for i in range (self .temporal_stages ):
ch_in =ch_out
ch_out =max (ch_out //2 ,base_channels )
decoder_layers .append (nn .Upsample (scale_factor =(2 ,2 ,2 ),mode ='trilinear',align_corners =False ))
decoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,padding =1 ))
decoder_layers .append (nn .SiLU ())
for i in range (self .spatial_stages -self .temporal_stages ):
ch_in =ch_out
ch_out =max (ch_out //2 ,base_channels )
decoder_layers .append (nn .Upsample (scale_factor =(1 ,2 ,2 ),mode ='trilinear',align_corners =False ))
decoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,padding =1 ))
decoder_layers .append (nn .SiLU ())
decoder_layers .append (nn .Conv3d (ch_out ,in_channels ,3 ,padding =1 ))
self .decoder =nn .Sequential (*decoder_layers )
print (f" 🎬 VideoVAE3D (VidTok): {temporal_compression }x{spatial_compression }x{spatial_compression } compression")
print (f" Temporal stages: {self .temporal_stages }, Spatial stages: {self .spatial_stages }")
print (f" Mode: {'FSQ (discrete)'if use_fsq else 'KL (continuous)'}, Causal: {causal }")
def encode (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]:
"""
Encode video to latent space.
Args:
x: [B, C, T, H, W] video tensor, values in [0, 1] or [-1, 1]
Returns:
Tuple of (z, mean, logvar) where z is the sampled latent
"""
h =self .encoder (x )
if self .use_fsq :
z =self ._fsq_quantize (h )
return z ,z ,torch .zeros_like (z )
else :
mean ,logvar =h .chunk (2 ,dim =1 )
logvar =torch .clamp (logvar ,-30 ,20 )
std =torch .exp (0.5 *logvar )
z =mean +std *torch .randn_like (std )
return z ,mean ,logvar
def _fsq_quantize (self ,z :torch .Tensor ,levels :int =8 )->torch .Tensor :
"""Finite Scalar Quantization."""
z =torch .tanh (z )
z =torch .round ((z +1 )*(levels -1 )/2 )*2 /(levels -1 )-1
return z
def decode (self ,z :torch .Tensor )->torch .Tensor :
"""
Decode latent to video.
Args:
z: [B, latent_ch, t, h, w] latent tensor
Returns:
[B, C, T, H, W] reconstructed video
"""
return self .decoder (z )
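# Usage sketch (illustrative only): round-tripping a short clip through VideoVAE3D.
# Shapes assume the default 4x temporal / 8x spatial compression; the tensors below are
# placeholders for the example, not part of the original module.
#
#   vae = VideoVAE3D(in_channels=3, latent_channels=4, base_channels=64)
#   video = torch.rand(1, 3, 16, 256, 256)           # [B, C, T, H, W] in [0, 1]
#   z, mean, logvar = vae.encode(video * 2 - 1)      # z: [1, 4, 4, 32, 32]
#   recon = vae.decode(z)                            # recon: [1, 3, 16, 256, 256]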
class MobileVideoDiffusion (nn .Module ):
"""
SOTA Video Diffusion with Flow Matching, Factorized Attention, Temporal MoE.
Uses memory-efficient factorized spatial-temporal attention:
- Full 3D attention: O((T*H*W)^2) = 1B+ attention scores (OOM!)
- Factorized: O(T*(H*W)^2 + H*W*T^2) = ~134M scores (7.5x less memory)
Optimized for 2x T4 GPUs (15GB each) with FP16.
"""
def __init__ (
self ,
latent_channels :int =4 ,
base_channels :int =64 ,
context_dim :int =1024 ,
num_frames :int =16 ,
image_size :int =256 ,
num_inference_steps :int =50 ,
cfg_scale :float =7.5 ,
temporal_compression :int =4 ,
spatial_compression :int =8 ,
causal :bool =True ,
use_fsq :bool =False ,
):
super ().__init__ ()
self .latent_channels =latent_channels
self .context_dim =context_dim
self .num_frames =num_frames
self .image_size =image_size
self .temporal_compression =temporal_compression
self .spatial_compression =spatial_compression
self .latent_size =image_size //spatial_compression
self .latent_frames =num_frames //temporal_compression
self .num_inference_steps =num_inference_steps
self .cfg_scale =cfg_scale
self .vae =VideoVAE3D (
in_channels =3 ,
latent_channels =latent_channels ,
base_channels =base_channels ,
temporal_compression =temporal_compression ,
spatial_compression =spatial_compression ,
causal =causal ,
use_fsq =use_fsq ,
)
self .unet =VideoUNet3D (
in_channels =latent_channels ,
out_channels =latent_channels ,
hidden_size =base_channels *4 ,
context_dim =context_dim ,
num_layers =4 ,
num_heads =8 ,
num_experts =4 ,
num_frames =num_frames ,
max_height =self .latent_size ,
max_width =self .latent_size ,
)
self .scheduler =FlowMatchingScheduler (num_inference_steps )
def encode_video (self ,video :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]:
return self .vae .encode (video *2 -1 )
def decode_video (self ,z :torch .Tensor )->torch .Tensor :
return self .vae .decode (z )
def encode_image (self ,image :torch .Tensor )->torch .Tensor :
image_expanded =image .unsqueeze (2 )
z ,_ ,_ =self .vae .encode (image_expanded )
return z .squeeze (2 )
def training_step (self ,video :torch .Tensor ,context :torch .Tensor ,first_frame :Optional [torch .Tensor ]=None )->dict :
device =video .device
dtype =video .dtype
batch_size =video .shape [0 ]
z ,mean ,logvar =self .encode_video (video )
del video
t =torch .rand (batch_size ,device =device ,dtype =dtype )
x_t =self .scheduler .add_noise (z ,t )
target_velocity =self .scheduler .get_velocity (x_t ,z ,t )
if self .training :
drop_mask =torch .rand (batch_size ,device =device )<0.1
drop_mask_expanded =drop_mask .view (batch_size ,1 ,1 ).expand_as (context )
null_ctx =torch .zeros_like (context )
context =torch .where (drop_mask_expanded ,null_ctx ,context )
del drop_mask ,drop_mask_expanded ,null_ctx
pred_velocity =self .unet (x_t ,(t *1000 ).to (dtype ),context ,None )
del x_t ,context
flow_loss =F .mse_loss (pred_velocity ,target_velocity )
del pred_velocity ,target_velocity
kl_loss =-0.5 *torch .mean (1 +logvar -mean .pow (2 )-logvar .exp ())
temporal_loss =torch .tensor (0.0 ,device =device ,dtype =dtype )
if z .shape [2 ]>1 :
z_diff =z [:,:,1 :]-z [:,:,:-1 ]
temporal_loss =torch .mean (z_diff **2 )
del z_diff
del z ,mean ,logvar
total_loss =flow_loss +0.0001 *kl_loss +0.01 *temporal_loss
return {
'flow_loss':flow_loss ,
'kl_loss':kl_loss ,
'temporal_loss':temporal_loss ,
'total_loss':total_loss ,
}
@torch .no_grad ()
def generate_t2v (self ,context :torch .Tensor ,num_frames :int =None ,guidance_scale :float =None ,num_steps :int =None )->torch .Tensor :
device =context .device
batch_size =context .shape [0 ]
seq_len =context .shape [1 ]
num_frames =num_frames or self .num_frames
guidance_scale =guidance_scale or self .cfg_scale
num_steps =num_steps or self .num_inference_steps
latents =torch .randn (
batch_size ,self .latent_channels ,num_frames ,
self .latent_size ,self .latent_size ,device =device
)
timesteps =torch .linspace (1 ,0 ,num_steps +1 ,device =device )
if guidance_scale >1.0 :
null_ctx =torch .zeros (batch_size ,seq_len ,self .context_dim ,device =device ,dtype =context .dtype )
context =torch .cat ([null_ctx ,context ])
for i in range (num_steps ):
t =timesteps [i ]
t_prev =timesteps [i +1 ]
t_batch =t .expand (batch_size )*1000
if guidance_scale >1.0 :
latent_input =torch .cat ([latents ,latents ])
t_input =torch .cat ([t_batch ,t_batch ])
velocity_pred =self .unet (latent_input ,t_input ,context ,None )
velocity_uncond ,velocity_cond =velocity_pred .chunk (2 )
velocity_pred =velocity_uncond +guidance_scale *(velocity_cond -velocity_uncond )
else :
velocity_pred =self .unet (latents ,t_batch ,context ,None )
latents =self .scheduler .step (velocity_pred ,t ,t_prev ,latents )
video =self .decode_video (latents )
return torch .clamp ((video +1 )/2 ,0 ,1 )
@torch .no_grad ()
def generate_i2v (self ,first_frame :torch .Tensor ,context :Optional [torch .Tensor ]=None ,num_frames :int =None ,guidance_scale :float =None ,num_steps :int =None )->torch .Tensor :
device =first_frame .device
batch_size =first_frame .shape [0 ]
num_frames =num_frames or self .num_frames
guidance_scale =guidance_scale or self .cfg_scale
num_steps =num_steps or self .num_inference_steps
first_frame_latent =self .encode_image (first_frame *2 -1 )
latents =torch .randn (
batch_size ,self .latent_channels ,num_frames ,
self .latent_size ,self .latent_size ,device =device
)
latents [:,:,0 ]=first_frame_latent
if context is None :
context =torch .zeros (batch_size ,77 ,self .context_dim ,device =device )
seq_len =context .shape [1 ]
timesteps =torch .linspace (1 ,0 ,num_steps +1 ,device =device )
if guidance_scale >1.0 :
null_ctx =torch .zeros (batch_size ,seq_len ,self .context_dim ,device =device ,dtype =context .dtype )
context =torch .cat ([null_ctx ,context ])
for i in range (num_steps ):
t =timesteps [i ]
t_prev =timesteps [i +1 ]
t_batch =t .expand (batch_size )*1000
if guidance_scale >1.0 :
latent_input =torch .cat ([latents ,latents ])
t_input =torch .cat ([t_batch ,t_batch ])
velocity_pred =self .unet (latent_input ,t_input ,context ,None )
velocity_uncond ,velocity_cond =velocity_pred .chunk (2 )
velocity_pred =velocity_uncond +guidance_scale *(velocity_cond -velocity_uncond )
else :
velocity_pred =self .unet (latents ,t_batch ,context ,None )
latents =self .scheduler .step (velocity_pred ,t ,t_prev ,latents )
latents [:,:,0 ]=first_frame_latent
video =self .decode_video (latents )
return torch .clamp ((video +1 )/2 ,0 ,1 )
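# Usage sketch (illustrative only): text-to-video sampling with MobileVideoDiffusion.
# The context tensor stands in for text-encoder states; its length (77) and the guidance
# settings are assumptions for the example. Note that num_frames here sets the latent
# temporal dimension fed to the UNet; the VAE decoder then upsamples it temporally.
#
#   model = MobileVideoDiffusion(latent_channels=4, base_channels=64, context_dim=1024)
#   text_ctx = torch.randn(1, 77, 1024)              # [B, seq, context_dim]
#   video = model.generate_t2v(text_ctx, num_frames=4, guidance_scale=7.5, num_steps=20)
#   # video: [B, 3, T, H, W] with values clamped to [0, 1]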
# ==============================================================================
# MODELS.LLM.MOE_LLAMA
# ==============================================================================
EPS =1e-5
class YaRNRotaryEmbedding (nn .Module ):
"""
YaRN (Yet another RoPE extensioN) with LongRoPE-style improvements.
Supports up to 128K+ context with proper frequency scaling.
"""
def __init__ (
self ,
dim :int ,
max_position_embeddings :int =131072 ,
base :float =500000.0 ,
original_max_position_embeddings :int =8192 ,
beta_fast :float =32.0 ,
beta_slow :float =1.0 ,
mscale :float =1.0 ,
):
super ().__init__ ()
self .dim =dim
self .max_position_embeddings =max_position_embeddings
self .base =base
self .original_max_position =original_max_position_embeddings
self .beta_fast =beta_fast
self .beta_slow =beta_slow
self .mscale =mscale
self .scaling_factor =max_position_embeddings /original_max_position_embeddings
inv_freq =self ._compute_yarn_inv_freq ()
self .register_buffer ('inv_freq',inv_freq ,persistent =False )
def _compute_yarn_inv_freq (self )->torch .Tensor :
"""Compute YaRN-scaled inverse frequencies."""
pos_freqs =self .base **(torch .arange (0 ,self .dim ,2 ,dtype =torch .float32 )/self .dim )
inv_freq_extrapolation =1.0 /pos_freqs
inv_freq_interpolation =1.0 /(self .scaling_factor *pos_freqs )
low =max (math .floor (self .dim *math .log (self .original_max_position /(self .beta_fast *2 *math .pi ))/
(2 *math .log (self .base ))),0 )
high =min (math .ceil (self .dim *math .log (self .original_max_position /(self .beta_slow *2 *math .pi ))/
(2 *math .log (self .base ))),self .dim -1 )
inv_freq =torch .zeros (self .dim //2 ,dtype =torch .float32 )
for i in range (self .dim //2 ):
if i <low :
inv_freq [i ]=inv_freq_interpolation [i ]
elif i >high :
inv_freq [i ]=inv_freq_extrapolation [i ]
else :
smooth =(i -low )/max (high -low ,1 )
inv_freq [i ]=(1 -smooth )*inv_freq_interpolation [i ]+smooth *inv_freq_extrapolation [i ]
return inv_freq
def _get_mscale (self ,scale :float )->float :
"""Get attention scaling factor for YaRN."""
if scale <=1 :
return 1.0
return 0.1 *math .log (scale )+1.0
def forward (self ,x :torch .Tensor ,position_ids :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]:
device =x .device
inv_freq =self .inv_freq .to (device )
inv_freq_expanded =inv_freq [None ,:,None ].float ().expand (position_ids .shape [0 ],-1 ,1 )
position_ids_expanded =position_ids [:,None ,:].float ()
freqs =(inv_freq_expanded @position_ids_expanded ).transpose (1 ,2 )
emb =torch .cat ((freqs ,freqs ),dim =-1 )
mscale =self ._get_mscale (self .scaling_factor )*self .mscale
cos =emb .cos ().to (dtype =x .dtype )*mscale
sin =emb .sin ().to (dtype =x .dtype )*mscale
return cos ,sin
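# Usage sketch (illustrative only): the YaRN embedding returns per-position cos/sin tables
# already scaled by mscale. Shapes below assume head_dim=64 and a batch of 2.
#
#   rope = YaRNRotaryEmbedding(dim=64, max_position_embeddings=131072,
#                              original_max_position_embeddings=8192)
#   x = torch.randn(2, 8, 128, 64)                        # [B, heads, seq, head_dim]
#   pos = torch.arange(128).unsqueeze(0).expand(2, -1)    # [B, seq]
#   cos, sin = rope(x, pos)                               # each [B, seq, 64]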
LlamaRotaryEmbedding =YaRNRotaryEmbedding
def rotate_half (x :torch .Tensor )->torch .Tensor :
x1 =x [...,:x .shape [-1 ]//2 ]
x2 =x [...,x .shape [-1 ]//2 :]
return torch .cat ((-x2 ,x1 ),dim =-1 )
def apply_rotary_pos_emb (
q :torch .Tensor ,
k :torch .Tensor ,
cos :torch .Tensor ,
sin :torch .Tensor ,
position_ids :Optional [torch .Tensor ]=None ,
unsqueeze_dim :int =1 ,
)->Tuple [torch .Tensor ,torch .Tensor ]:
cos =cos .unsqueeze (unsqueeze_dim )
sin =sin .unsqueeze (unsqueeze_dim )
q_embed =(q *cos )+(rotate_half (q )*sin )
k_embed =(k *cos )+(rotate_half (k )*sin )
return q_embed ,k_embed
class KVCache :
"""Pre-allocated KV Cache — static buffer with index-based filling.
Eliminates VRAM fragmentation from torch.cat during autoregressive generation.
Buffer is allocated once at first use and reused via slice assignment.
"""
__slots__ =('key_cache','value_cache','seen_tokens','_max_len')
def __init__ (
self ,
key_cache :torch .Tensor =None ,
value_cache :torch .Tensor =None ,
seen_tokens :int =0 ,
max_seq_len :int =131072 ,
):
self .key_cache =key_cache
self .value_cache =value_cache
self .seen_tokens =seen_tokens
self ._max_len =max_seq_len
def _allocate (self ,batch :int ,heads :int ,head_dim :int ,device :torch .device ,dtype :torch .dtype ):
"""Allocate static buffer on first use."""
self .key_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype )
self .value_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype )
def update (
self ,
key_states :torch .Tensor ,
value_states :torch .Tensor ,
chunk_size :Optional [int ]=None ,
)->Tuple [torch .Tensor ,torch .Tensor ]:
batch ,heads ,new_len ,head_dim =key_states .shape
if self .key_cache is None :
self ._allocate (batch ,heads ,head_dim ,key_states .device ,key_states .dtype )
self .seen_tokens =0
if chunk_size is not None and self .seen_tokens +new_len >chunk_size *2 :
keep =chunk_size
if self .seen_tokens >keep :
self .key_cache [:,:,:keep ]=self .key_cache [:,:,self .seen_tokens -keep :self .seen_tokens ].clone ()
self .value_cache [:,:,:keep ]=self .value_cache [:,:,self .seen_tokens -keep :self .seen_tokens ].clone ()
self .seen_tokens =keep
if self .seen_tokens +new_len >self .key_cache .shape [2 ]:
new_max =max (self .key_cache .shape [2 ]*2 ,self .seen_tokens +new_len )
new_key =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype )
new_val =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype )
new_key [:,:,:self .seen_tokens ]=self .key_cache [:,:,:self .seen_tokens ]
new_val [:,:,:self .seen_tokens ]=self .value_cache [:,:,:self .seen_tokens ]
self .key_cache =new_key
self .value_cache =new_val
self .key_cache [:,:,self .seen_tokens :self .seen_tokens +new_len ]=key_states
self .value_cache [:,:,self .seen_tokens :self .seen_tokens +new_len ]=value_states
self .seen_tokens +=new_len
return self .key_cache [:,:,:self .seen_tokens ],self .value_cache [:,:,:self .seen_tokens ]
def reset (self ):
"""Reset cache position without deallocating the buffer."""
self .seen_tokens =0
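# Usage sketch (illustrative only): how an attention layer fills the pre-allocated cache
# during autoregressive decoding. Shapes are placeholders; the real callers pass the
# per-layer key/value projections.
#
#   cache = KVCache(max_seq_len=4096)
#   k0 = torch.randn(1, 8, 10, 64); v0 = torch.randn(1, 8, 10, 64)   # prefill: 10 tokens
#   k_all, v_all = cache.update(k0, v0)          # views of the first 10 slots
#   k1 = torch.randn(1, 8, 1, 64); v1 = torch.randn(1, 8, 1, 64)     # one decode step
#   k_all, v_all = cache.update(k1, v1)          # now length 11, no torch.cat involved
#   cache.reset()                                # reuse the same buffer for the next prompt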
def ring_attention (
query :torch .Tensor ,
key :torch .Tensor ,
value :torch .Tensor ,
chunk_size :int =4096 ,
causal :bool =True ,
)->torch .Tensor :
"""
Ring Attention for distributed long-context processing.
Processes sequence in chunks with online softmax accumulation.
Args:
query: [batch, heads, seq_len, head_dim]
key: [batch, heads, kv_len, head_dim]
value: [batch, heads, kv_len, head_dim]
chunk_size: Size of each attention chunk
causal: Whether to apply causal masking
Returns:
Output tensor [batch, heads, seq_len, head_dim]
"""
batch_size ,num_heads ,seq_len ,head_dim =query .shape
kv_len =key .shape [2 ]
if seq_len <=chunk_size and kv_len <=chunk_size :
qk_scale =head_dim **-0.25
use_causal =causal and seq_len ==kv_len and seq_len >1
if use_causal :
return F .scaled_dot_product_attention (
query *qk_scale ,key *qk_scale ,value ,
is_causal =True ,scale =1.0 ,
)
elif causal and kv_len >seq_len :
q_pos =torch .arange (seq_len ,device =query .device )+(kv_len -seq_len )
k_pos =torch .arange (kv_len ,device =query .device )
# Additive causal mask built from absolute positions; cast to the query dtype so SDPA receives a matching mask.
causal_mask =torch .where (k_pos .unsqueeze (0 )>q_pos .unsqueeze (1 ),float ('-inf'),0.0 ).to (query .dtype )
return F .scaled_dot_product_attention (
query *qk_scale ,key *qk_scale ,value ,
attn_mask =causal_mask ,scale =1.0 ,
)
else :
return F .scaled_dot_product_attention (
query *qk_scale ,key *qk_scale ,value ,
is_causal =False ,scale =1.0 ,
)
scale =head_dim **-0.5
output =torch .zeros_like (query )
max_logits =torch .full ((batch_size ,num_heads ,seq_len ,1 ),float ('-inf'),device =query .device ,dtype =query .dtype )
sum_exp =torch .zeros ((batch_size ,num_heads ,seq_len ,1 ),device =query .device ,dtype =query .dtype )
if causal :
q_positions =torch .arange (seq_len ,device =query .device )
if kv_len >seq_len :
q_positions =q_positions +(kv_len -seq_len )
num_kv_chunks =(kv_len +chunk_size -1 )//chunk_size
for kv_idx in range (num_kv_chunks ):
kv_start =kv_idx *chunk_size
kv_end =min ((kv_idx +1 )*chunk_size ,kv_len )
key_chunk =key [:,:,kv_start :kv_end ,:]
value_chunk =value [:,:,kv_start :kv_end ,:]
attn_chunk =torch .matmul (query ,key_chunk .transpose (-1 ,-2 ))*scale
if causal :
k_positions =torch .arange (kv_start ,kv_end ,device =query .device )
causal_mask =k_positions .unsqueeze (0 )>q_positions .unsqueeze (1 )
attn_chunk =attn_chunk .masked_fill (causal_mask .unsqueeze (0 ).unsqueeze (0 ),float ('-inf'))
chunk_max =attn_chunk .max (dim =-1 ,keepdim =True )[0 ]
new_max =torch .maximum (max_logits ,chunk_max )
exp_weights =torch .exp (attn_chunk -new_max )
exp_sum_chunk =exp_weights .sum (dim =-1 ,keepdim =True )
correction =torch .exp (max_logits -new_max )
output =output *correction +torch .matmul (exp_weights ,value_chunk )
sum_exp =sum_exp *correction +exp_sum_chunk
max_logits =new_max
output =output /(sum_exp +EPS )
return output
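# Sanity-check sketch (illustrative only): on a small example, ring_attention should match
# plain scaled_dot_product_attention, since the chunked online-softmax accumulation is
# mathematically equivalent to full softmax attention.
#
#   q = torch.randn(1, 4, 64, 32); k = torch.randn(1, 4, 64, 32); v = torch.randn(1, 4, 64, 32)
#   ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#   out = ring_attention(q, k, v, chunk_size=16, causal=True)   # forces the chunked path
#   # torch.allclose(out, ref, atol=1e-4) is expected to hold up to numerical tolerance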
class MultiHeadLatentAttention (nn .Module ):
"""
Multi-Head Latent Attention (MLA) from DeepSeek-V2.
Compresses KV cache using low-rank projections for memory efficiency.
"""
def __init__ (
self ,
hidden_size :int ,
num_heads :int ,
num_kv_heads :int =None ,
head_dim :int =None ,
kv_lora_rank :int =512 ,
q_lora_rank :int =0 ,
rope_theta :float =500000.0 ,
max_position_embeddings :int =131072 ,
use_ring_attention :bool =True ,
ring_chunk_size :int =4096 ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_heads =num_heads
self .num_kv_heads =num_kv_heads or num_heads
self .head_dim =head_dim or hidden_size //num_heads
self .kv_lora_rank =kv_lora_rank
self .q_lora_rank =q_lora_rank
self .use_ring_attention =use_ring_attention
self .ring_chunk_size =ring_chunk_size
self .num_key_value_groups =self .num_heads //self .num_kv_heads
self .scale =self .head_dim **-0.5
if q_lora_rank >0 :
self .q_a_proj =nn .Linear (hidden_size ,q_lora_rank ,bias =False )
self .q_b_proj =nn .Linear (q_lora_rank ,num_heads *self .head_dim ,bias =False )
self .q_a_layernorm =LlamaRMSNorm (q_lora_rank )
else :
self .q_proj =nn .Linear (hidden_size ,num_heads *self .head_dim ,bias =False )
self .kv_a_proj =nn .Linear (hidden_size ,kv_lora_rank +self .head_dim ,bias =False )
self .kv_b_proj =nn .Linear (kv_lora_rank ,self .num_kv_heads *self .head_dim *2 ,bias =False )
self .kv_a_layernorm =LlamaRMSNorm (kv_lora_rank )
self .o_proj =nn .Linear (num_heads *self .head_dim ,hidden_size ,bias =False )
self .rotary_emb =YaRNRotaryEmbedding (
dim =self .head_dim ,
max_position_embeddings =max_position_embeddings ,
base =rope_theta ,
)
self ._init_weights ()
def _init_weights (self ):
std =0.02
for name ,module in self .named_modules ():
if isinstance (module ,nn .Linear ):
nn .init .normal_ (module .weight ,mean =0.0 ,std =std )
def forward (
self ,
hidden_states :torch .Tensor ,
attention_mask :Optional [torch .Tensor ]=None ,
position_ids :Optional [torch .Tensor ]=None ,
past_key_value :Optional [KVCache ]=None ,
output_attentions :bool =False ,
use_cache :bool =False ,
)->Tuple [torch .Tensor ,Optional [torch .Tensor ],Optional [KVCache ]]:
batch_size ,seq_len ,_ =hidden_states .shape
if self .q_lora_rank >0 :
q_compressed =self .q_a_layernorm (self .q_a_proj (hidden_states ))
query_states =self .q_b_proj (q_compressed )
else :
query_states =self .q_proj (hidden_states )
kv_compressed =self .kv_a_proj (hidden_states )
kv_latent ,k_pe =kv_compressed .split ([self .kv_lora_rank ,self .head_dim ],dim =-1 )
kv_latent =self .kv_a_layernorm (kv_latent )
kv_states =self .kv_b_proj (kv_latent )
query_states =query_states .view (batch_size ,seq_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 )
key_states ,value_states =kv_states .split (self .num_kv_heads *self .head_dim ,dim =-1 )
key_states =key_states .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 )
value_states =value_states .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 )
if position_ids is None :
position_ids =torch .arange (seq_len ,device =hidden_states .device ).unsqueeze (0 ).expand (batch_size ,-1 )
if past_key_value is not None and past_key_value .seen_tokens >0 :
position_ids =position_ids +past_key_value .seen_tokens
cos ,sin =self .rotary_emb (hidden_states ,position_ids )
query_states ,key_states =apply_rotary_pos_emb (query_states ,key_states ,cos ,sin )
if past_key_value is not None :
key_states ,value_states =past_key_value .update (
key_states ,value_states ,
self .ring_chunk_size if self .use_ring_attention else None
)
if self .use_ring_attention :
if self .num_key_value_groups >1 :
key_expanded =key_states .repeat_interleave (self .num_key_value_groups ,dim =1 )
value_expanded =value_states .repeat_interleave (self .num_key_value_groups ,dim =1 )
else :
key_expanded =key_states
value_expanded =value_states
attn_output =ring_attention (
query_states ,key_expanded ,value_expanded ,
chunk_size =self .ring_chunk_size ,
causal =True ,
)
else :
qk_scale =self .head_dim **-0.25
kv_len =key_states .shape [2 ]
use_causal =(attention_mask is None and seq_len >1 and seq_len ==kv_len )
attn_output =F .scaled_dot_product_attention (
query_states *qk_scale ,
key_states *qk_scale ,
value_states ,
attn_mask =attention_mask ,
is_causal =use_causal ,
scale =1.0 ,
enable_gqa =(self .num_key_value_groups >1 ),
)
attn_output =attn_output .transpose (1 ,2 ).contiguous ().view (batch_size ,seq_len ,-1 )
attn_output =self .o_proj (attn_output )
return attn_output ,None ,past_key_value if use_cache else None
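# Usage sketch (illustrative only): MLA routes the KV projection through a
# kv_lora_rank-dimensional bottleneck, so the KV path scales with kv_lora_rank rather
# than hidden_size on the input side.
#
#   attn = MultiHeadLatentAttention(hidden_size=1024, num_heads=16, num_kv_heads=4,
#                                   kv_lora_rank=512, use_ring_attention=False)
#   x = torch.randn(2, 32, 1024)                        # [B, seq, hidden]
#   out, _, cache = attn(x, use_cache=True, past_key_value=KVCache(max_seq_len=1024))
#   # out: [2, 32, 1024]; cache.seen_tokens == 32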
class AuxLosslessMoERouter (nn .Module ):
"""
Aux-Lossless MoE Router with Shared Expert Isolation.
Eliminates auxiliary loss while maintaining load balance through architecture.
"""
def __init__ (
self ,
hidden_size :int ,
num_experts :int ,
top_k :int =2 ,
norm_topk_prob :bool =True ,
):
super ().__init__ ()
self .num_experts =num_experts
self .top_k =top_k
self .norm_topk_prob =norm_topk_prob
self .input_norm =LlamaRMSNorm (hidden_size )
self .gate =nn .Linear (hidden_size ,num_experts ,bias =False )
nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 )
self .expert_bias =nn .Parameter (torch .zeros (num_experts ))
# Deep experts gate (4 deep experts)
self .num_deep_experts = 4
self .deep_gate = nn .Linear (hidden_size , self .num_deep_experts , bias =False )
nn .init .normal_ (self .deep_gate .weight , mean =0.0 , std =0.01 )
self .deep_expert_bias = nn .Parameter (torch .zeros (self .num_deep_experts ))
def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]:
batch_size ,seq_len ,hidden_dim =hidden_states .shape
hidden_flat =hidden_states .view (-1 ,hidden_dim )
hidden_norm =self .input_norm (hidden_flat )
# Standard experts
router_logits_std =self .gate (hidden_norm )
biased_logits_std =router_logits_std +self .expert_bias
# Deep experts
router_logits_deep = self .deep_gate (hidden_norm )
biased_logits_deep = router_logits_deep + self .deep_expert_bias
# Concatenate: [batch*seq, num_experts + num_deep_experts]
router_logits = torch .cat ([biased_logits_std , biased_logits_deep ], dim =-1 )
router_probs =F .softmax (router_logits ,dim =-1 ,dtype =hidden_states .dtype )
top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 )
if self .norm_topk_prob :
top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS )
return top_k_probs ,top_k_indices ,router_logits
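# Routing sketch (illustrative only): the router scores the standard experts and the 4 deep
# experts jointly, so top-k indices live in the range [0, num_experts + 3].
#
#   router = AuxLosslessMoERouter(hidden_size=1024, num_experts=8, top_k=2)
#   h = torch.randn(2, 16, 1024)
#   probs, idx, logits = router(h)
#   # probs/idx: [32, 2] (tokens flattened), logits: [32, 12] over all 8 + 4 experts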
class MoEExpert (nn .Module ):
"""Single MoE Expert with SwiGLU activation."""
def __init__ (self ,hidden_size :int ,intermediate_size :int ):
super ().__init__ ()
self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False )
self .act_fn =nn .SiLU ()
self ._init_weights ()
def _init_weights (self ):
std =0.02
nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std )
nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std )
nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 )
def forward (self ,x :torch .Tensor )->torch .Tensor :
return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x ))
class DeepMoEExpert (nn .Module ):
"""Deep MoE Expert with multiple sequential SwiGLU transformations."""
def __init__ (self ,hidden_size :int ,intermediate_size :int ,depth :int =2 ):
super ().__init__ ()
self .depth = depth
self .gate_projs = nn .ModuleList ([nn .Linear (hidden_size if i == 0 else intermediate_size , intermediate_size , bias =False ) for i in range (depth )])
self .up_projs = nn .ModuleList ([nn .Linear (hidden_size if i == 0 else intermediate_size , intermediate_size , bias =False ) for i in range (depth )])
self .down_projs = nn .ModuleList ([nn .Linear (intermediate_size , intermediate_size if i < depth - 1 else hidden_size , bias =False ) for i in range (depth )])
self .act_fn = nn .SiLU ()
self ._init_weights ()
def _init_weights (self ):
std =0.02
for g , u , d in zip (self .gate_projs , self .up_projs , self .down_projs ):
nn .init .normal_ (g .weight ,mean =0.0 ,std =std )
nn .init .normal_ (u .weight ,mean =0.0 ,std =std )
nn .init .normal_ (d .weight ,mean =0.0 ,std =std *0.5 )
def forward (self ,x :torch .Tensor )->torch .Tensor :
for i in range (self .depth ):
# No internal residual connections; each stage is a plain SwiGLU block applied sequentially:
# input -> SwiGLU -> ... -> final down-projection back to hidden_size.
gate = self .act_fn (self .gate_projs [i ](x ))
up = self .up_projs [i ](x )
x = self .down_projs [i ](gate * up )
return x
class IsolatedSharedExpert (nn .Module ):
"""
Isolated Shared Expert that always processes all tokens.
Separate from routed experts to prevent competition.
"""
def __init__ (self ,hidden_size :int ,intermediate_size :int ):
super ().__init__ ()
self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False )
self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False )
self .act_fn =nn .SiLU ()
self ._init_weights ()
def _init_weights (self ):
std =0.02
nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std )
nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std )
nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 )
def forward (self ,x :torch .Tensor )->torch .Tensor :
return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x ))
class AuxLosslessMoELayer (nn .Module ):
"""
Aux-Lossless MoE Layer with Isolated Shared Expert.
No auxiliary loss needed - load balance maintained through isolation.
"""
def __init__ (
self ,
hidden_size :int ,
intermediate_size :int ,
num_experts :int =8 ,
num_experts_per_tok :int =2 ,
shared_expert_intermediate_size :int =None ,
):
super ().__init__ ()
self .hidden_size =hidden_size
self .num_experts =num_experts
self .num_experts_per_tok =num_experts_per_tok
self .router =AuxLosslessMoERouter (hidden_size ,num_experts ,num_experts_per_tok )
self .experts =nn .ModuleList ([
MoEExpert (hidden_size ,intermediate_size )
for _ in range (num_experts )
])
# Deep Experts: Depths 2, 3, 4, 5
self .num_deep_experts = 4
self .deep_experts = nn .ModuleList ([
DeepMoEExpert (hidden_size , intermediate_size , depth =d )
for d in range (2 , 6 )
])
shared_size =shared_expert_intermediate_size or intermediate_size
self .shared_expert =IsolatedSharedExpert (hidden_size ,shared_size )
def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]:
batch_size ,seq_len ,hidden_size =hidden_states .shape
original_dtype =hidden_states .dtype
hidden_flat =hidden_states .view (-1 ,hidden_size )
num_tokens =hidden_flat .shape [0 ]
top_k_probs ,top_k_indices ,router_logits =self .router (hidden_states )
if hasattr (self ,'_utilization_tracker'):
self ._utilization_tracker .record (top_k_indices )
final_output =torch .zeros_like (hidden_flat )
total_experts = self .num_experts + self .num_deep_experts
for expert_idx in range (total_experts ):
# Determine which expert list to use
if expert_idx < self .num_experts :
expert =self .experts [expert_idx ]
else :
expert =self .deep_experts [expert_idx - self .num_experts ]
for k in range (self .num_experts_per_tok ):
mask =(top_k_indices [:,k ]==expert_idx )
if mask .any ():
expert_input =hidden_flat [mask ]
expert_output =expert (expert_input )
weight =top_k_probs [mask ,k :k +1 ]
weighted_output =(weight *expert_output ).to (original_dtype )
final_output [mask ]=final_output [mask ]+weighted_output
shared_output =self .shared_expert (hidden_flat )
final_output =final_output +shared_output .to (original_dtype )
final_output =final_output .view (batch_size ,seq_len ,hidden_size )
aux_loss =self ._compute_aux_loss (router_logits ,top_k_indices ,num_tokens )
return final_output ,aux_loss
def _compute_aux_loss (
self ,
router_logits :torch .Tensor ,
top_k_indices :torch .Tensor ,
num_tokens :int ,
)->torch .Tensor :
"""
Aux-lossless auxiliary loss.
Uses z-loss to keep router logits from growing unboundedly (FP16 stability),
plus a soft utilization penalty that activates only when experts go completely
cold. The expert_bias parameter handles routine load balancing.
"""
z_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.0001
# Add penalty for choosing deep experts
# Depths are 2, 3, 4, 5 for indices (num_experts) to (num_experts + 3)
# Cost is roughly proportional to depth
deep_penalty = torch .tensor (0.0 , device =router_logits .device , dtype =router_logits .dtype )
# Calculate how often each deep expert was selected
# top_k_indices shape: [batch*seq, top_k]
for i in range (self .num_deep_experts ):
expert_idx = self .num_experts + i
depth = i + 2 # depths 2, 3, 4, 5
# Count how many times this deep expert was chosen in top-k
selection_count = (top_k_indices == expert_idx ).sum ()
# Simple penalty: deeper experts cost more
# Multiplied by a small scalar to act as a soft deterrent
# The model must truly need the depth to offset this loss increase
deep_penalty += selection_count .float () * depth * 0.00005
# Utilization penalty from the docstring: only kicks in when some experts (standard or deep) go completely cold.
expert_mask =F .one_hot (top_k_indices ,self .num_experts +self .num_deep_experts ).float ()
tokens_per_expert =expert_mask .sum (dim =(0 ,1 ))
fraction_used =(tokens_per_expert >0 ).float ().mean ()
utilization_loss =(1.0 -fraction_used )*0.01
return z_loss +deep_penalty +utilization_loss
MoELayer =AuxLosslessMoELayer
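# Usage sketch (illustrative only): one MoE layer pass. Every token is processed by the
# isolated shared expert, plus its top-k routed experts (standard or deep).
#
#   moe = AuxLosslessMoELayer(hidden_size=1024, intermediate_size=2816,
#                             num_experts=8, num_experts_per_tok=2)
#   h = torch.randn(2, 16, 1024)
#   out, aux_loss = moe(h)        # out: [2, 16, 1024]; aux_loss is a scalar tensor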
class MoELlamaDecoderLayer (nn .Module ):
"""Decoder layer with MLA and Aux-Lossless MoE."""
def __init__ (self ,config ,layer_idx :int ,moe_config :dict =None ):
super ().__init__ ()
self .hidden_size =config .hidden_size
self .layer_idx =layer_idx
use_ring =getattr (config ,'use_ring_attention',True )
ring_chunk =getattr (config ,'ring_attention_chunk_size',4096 )
num_kv_heads =getattr (config ,'num_key_value_heads',config .num_attention_heads //4 )
self .self_attn =MultiHeadLatentAttention (
hidden_size =config .hidden_size ,
num_heads =config .num_attention_heads ,
num_kv_heads =num_kv_heads ,
rope_theta =getattr (config ,'rope_theta',500000.0 ),
max_position_embeddings =config .max_position_embeddings ,
use_ring_attention =use_ring ,
ring_chunk_size =ring_chunk ,
)
self .input_layernorm =LlamaRMSNorm (config .hidden_size ,eps =config .rms_norm_eps )
self .post_attention_layernorm =LlamaRMSNorm (config .hidden_size ,eps =config .rms_norm_eps )
self .use_moe =moe_config and moe_config .get ('use_moe',False )
moe_freq =moe_config .get ('moe_layer_freq',2 )if moe_config else 2
if self .use_moe and layer_idx %moe_freq ==(moe_freq -1 ):
self .mlp =AuxLosslessMoELayer (
hidden_size =config .hidden_size ,
intermediate_size =moe_config .get ('intermediate_size',config .intermediate_size ),
num_experts =moe_config .get ('num_experts',8 ),
num_experts_per_tok =moe_config .get ('num_experts_per_tok',2 ),
)
self .is_moe_layer =True
else :
self .mlp =MoEExpert (config .hidden_size ,config .intermediate_size )
self .is_moe_layer =False
def forward (
self ,
hidden_states :torch .Tensor ,
attention_mask :Optional [torch .Tensor ]=None ,
position_ids :Optional [torch .Tensor ]=None ,
past_key_value :Optional [KVCache ]=None ,
output_attentions :bool =False ,
use_cache :bool =False ,
)->Tuple [torch .Tensor ,Optional [torch .Tensor ],Optional [KVCache ],Optional [torch .Tensor ]]:
residual =hidden_states
hidden_states =self .input_layernorm (hidden_states )
hidden_states ,_ ,present_key_value =self .self_attn (
hidden_states =hidden_states ,
attention_mask =attention_mask ,
position_ids =position_ids ,
past_key_value =past_key_value ,
output_attentions =output_attentions ,
use_cache =use_cache ,
)
hidden_states =residual +hidden_states
residual =hidden_states
hidden_states =self .post_attention_layernorm (hidden_states )
aux_loss =None
if self .is_moe_layer :
hidden_states ,aux_loss =self .mlp (hidden_states )
else :
hidden_states =self .mlp (hidden_states )
hidden_states =residual +hidden_states
return hidden_states ,None ,present_key_value ,aux_loss
@dataclass
class MoELlamaModelOutput :
last_hidden_state :torch .Tensor
past_key_values :Optional [List [KVCache ]]=None
hidden_states :Optional [Tuple [torch .Tensor ]]=None
attentions :Optional [Tuple [torch .Tensor ]]=None
aux_loss :Optional [torch .Tensor ]=None
class MoELlamaModel (nn .Module ):
"""MoE LLaMA Model with MLA and Ring Attention."""
def __init__ (self ,config ,moe_config :dict =None ):
super ().__init__ ()
self .config =config
self .moe_config =moe_config
self .gradient_checkpointing =False
self .embed_tokens =nn .Embedding (config .vocab_size ,config .hidden_size )
self .layers =nn .ModuleList ([
MoELlamaDecoderLayer (config ,layer_idx ,moe_config )
for layer_idx in range (config .num_hidden_layers )
])
self .norm =LlamaRMSNorm (config .hidden_size ,eps =config .rms_norm_eps )
self .num_moe_layers =sum (1 for layer in self .layers if layer .is_moe_layer )
# ── Coconut: Continuous Thought components ──
# Learned gate controls how much recurrent thought vs original input
# to retain at each thinking step. Sigmoid output in [0,1].
self .thought_gate = nn .Linear (config .hidden_size , 1 , bias =True )
nn .init .constant_ (self .thought_gate .bias , -2.0 ) # Initialize gate biased toward original (sigmoid(-2)≈0.12)
self .thought_layernorm = LlamaRMSNorm (config .hidden_size , eps =config .rms_norm_eps )
# Halt head: dynamically decides when to stop thinking
self .thought_halt_head = nn .Linear (config .hidden_size , 1 , bias =True )
nn .init .constant_ (self .thought_halt_head .bias , -2.0 ) # Biased toward continuing to think initially
# Fast ponder block: an attention-free path for latent reasoning steps.
# It runs a deep SwiGLU stack only, avoiding the O(N^2) self-attention cost of a full decoder pass.
self .fast_ponder_block = DeepMoEExpert (config .hidden_size , config .intermediate_size , depth =3 )
self ._init_weights ()
def _init_weights (self ):
nn .init .normal_ (self .embed_tokens .weight ,mean =0.0 ,std =0.02 )
def gradient_checkpointing_enable (self ):
"""Enable gradient checkpointing for memory efficiency."""
self .gradient_checkpointing =True
def gradient_checkpointing_disable (self ):
"""Disable gradient checkpointing."""
self .gradient_checkpointing =False
def forward (
self ,
input_ids :Optional [torch .Tensor ]=None ,
attention_mask :Optional [torch .Tensor ]=None ,
position_ids :Optional [torch .Tensor ]=None ,
inputs_embeds :Optional [torch .Tensor ]=None ,
past_key_values :Optional [List [KVCache ]]=None ,
use_cache :bool =False ,
output_attentions :bool =False ,
output_hidden_states :bool =False ,
return_dict :bool =True ,
cache_position :Optional [torch .Tensor ]=None ,
thinking_depth :int =0 ,
)->Union [Tuple ,MoELlamaModelOutput ]:
if inputs_embeds is None :
inputs_embeds =self .embed_tokens (input_ids )
hidden_states =inputs_embeds
batch_size ,seq_len =hidden_states .shape [:2 ]
if position_ids is None :
position_ids =torch .arange (seq_len ,device =hidden_states .device ).unsqueeze (0 ).expand (batch_size ,-1 )
if past_key_values is None :
past_key_values =[None ]*len (self .layers )
all_hidden_states =()if output_hidden_states else None
all_attentions =()if output_attentions else None
next_cache =[]if use_cache else None
total_aux_loss =torch .tensor (0.0 ,device =hidden_states .device ,dtype =hidden_states .dtype )
for idx ,layer in enumerate (self .layers ):
if output_hidden_states :
all_hidden_states =all_hidden_states +(hidden_states ,)
if self .gradient_checkpointing and self .training and not use_cache :
def create_custom_forward (module ):
def custom_forward (*inputs ):
return module (*inputs )
return custom_forward
layer_outputs =torch .utils .checkpoint .checkpoint (
create_custom_forward (layer ),
hidden_states ,
attention_mask ,
position_ids ,
past_key_values [idx ],
output_attentions ,
use_cache ,
use_reentrant =False ,
)
hidden_states ,attn_weights ,present_key_value ,aux_loss =layer_outputs
else :
hidden_states ,attn_weights ,present_key_value ,aux_loss =layer (
hidden_states =hidden_states ,
attention_mask =attention_mask ,
position_ids =position_ids ,
past_key_value =past_key_values [idx ],
output_attentions =output_attentions ,
use_cache =use_cache ,
)
if use_cache :
next_cache .append (present_key_value )
if aux_loss is not None :
total_aux_loss =total_aux_loss +aux_loss
if output_attentions and attn_weights is not None :
all_attentions =all_attentions +(attn_weights ,)
# ── Coconut: Continuous Thought Loop ──
# After the normal pass, loop hidden states back through the
# transformer layers for extra computation in latent space.
# No tokens are decoded — pure continuous reasoning.
if thinking_depth > 0 :
original_hidden = hidden_states .clone ()
thought_position_ids = torch .arange (
seq_len , device =hidden_states .device
).unsqueeze (0 ).expand (batch_size , -1 )
for thought_step in range (thinking_depth ):
# Check if we should halt thinking (only during inference or if forced)
# We evaluate the halt head on the *current* hidden state of the last token
halt_logits = self .thought_halt_head (hidden_states [:, -1:, :])
halt_prob = torch .sigmoid (halt_logits )
# If during generation we decide to stop, break early
if not self .training and (halt_prob > 0.5 ).all ():
break
# Normalize before processing
hidden_states = self .thought_layernorm (hidden_states )
# Run the hidden states through the attention-free fast ponder block only;
# skipping the O(N^2) self-attention stack makes each thought step far cheaper than a full decoder pass.
hidden_states = self .fast_ponder_block (hidden_states )
# Gated residual: blend thought with original
# gate ∈ [0,1], initialized small so early training
# stays close to original behavior
gate = torch .sigmoid (self .thought_gate (hidden_states ))
hidden_states = gate * hidden_states + (1.0 - gate ) * original_hidden
hidden_states =self .norm (hidden_states )
if output_hidden_states :
all_hidden_states =all_hidden_states +(hidden_states ,)
return MoELlamaModelOutput (
last_hidden_state =hidden_states ,
past_key_values =next_cache if use_cache else None ,
hidden_states =all_hidden_states ,
attentions =all_attentions ,
aux_loss =total_aux_loss ,
)
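# Usage sketch (illustrative only): thinking_depth > 0 triggers the Coconut-style latent loop
# after the decoder stack. The tiny config below is a stand-in for the example, not a real Xoron size.
#
#   cfg = LlamaConfig(vocab_size=1000, hidden_size=256, intermediate_size=512,
#                     num_hidden_layers=2, num_attention_heads=4, max_position_embeddings=2048)
#   core = MoELlamaModel(cfg, moe_config={'use_moe': True, 'num_experts': 4,
#                                         'num_experts_per_tok': 2, 'intermediate_size': 512})
#   ids = torch.randint(0, 1000, (1, 8))
#   out = core(input_ids=ids, thinking_depth=2)
#   # out.last_hidden_state: [1, 8, 256]; out.aux_loss aggregates the MoE z-loss terms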
@dataclass
class CausalLMOutput :
loss :Optional [torch .Tensor ]=None
logits :torch .Tensor =None
past_key_values :Optional [List [KVCache ]]=None
hidden_states :Optional [Tuple [torch .Tensor ]]=None
attentions :Optional [Tuple [torch .Tensor ]]=None
aux_loss :Optional [torch .Tensor ]=None
class MoELlamaForCausalLM (nn .Module ):
"""MoE LLaMA for Causal Language Modeling with MLA and Ring Attention."""
def __init__ (self ,config ,moe_config :dict =None ):
super ().__init__ ()
self .config =config
self .moe_config =moe_config
self .model =MoELlamaModel (config ,moe_config )
self .lm_head =nn .Linear (config .hidden_size ,config .vocab_size ,bias =False )
if getattr (config ,'tie_word_embeddings',True ):
self .lm_head .weight =self .model .embed_tokens .weight
self .apply (self ._init_weights )
def _init_weights (self ,module ):
std =0.02
if isinstance (module ,nn .Linear ):
nn .init .normal_ (module .weight ,mean =0.0 ,std =std )
if module .bias is not None :
nn .init .zeros_ (module .bias )
elif isinstance (module ,nn .Embedding ):
nn .init .normal_ (module .weight ,mean =0.0 ,std =std )
def get_input_embeddings (self )->nn .Embedding :
return self .model .embed_tokens
def set_input_embeddings (self ,value :nn .Embedding ):
self .model .embed_tokens =value
def get_output_embeddings (self )->nn .Linear :
return self .lm_head
def set_output_embeddings (self ,new_embeddings :nn .Linear ):
self .lm_head =new_embeddings
def gradient_checkpointing_enable (self ):
"""Enable gradient checkpointing for memory efficiency."""
self .model .gradient_checkpointing_enable ()
def gradient_checkpointing_disable (self ):
"""Disable gradient checkpointing."""
self .model .gradient_checkpointing_disable ()
def prepare_inputs_for_generation (
self ,
input_ids :torch .Tensor ,
past_key_values :Optional [List [KVCache ]]=None ,
attention_mask :Optional [torch .Tensor ]=None ,
inputs_embeds :Optional [torch .Tensor ]=None ,
**kwargs ,
)->dict :
if past_key_values is not None :
input_ids =input_ids [:,-1 :]
position_ids =kwargs .get ("position_ids",None )
if attention_mask is not None and position_ids is None :
position_ids =attention_mask .long ().cumsum (-1 )-1
position_ids .masked_fill_ (attention_mask ==0 ,1 )
if past_key_values is not None :
position_ids =position_ids [:,-1 :]
return {
"input_ids":input_ids ,
"past_key_values":past_key_values ,
"use_cache":kwargs .get ("use_cache",True ),
"position_ids":position_ids ,
"attention_mask":attention_mask ,
}
def forward (
self ,
input_ids :Optional [torch .Tensor ]=None ,
attention_mask :Optional [torch .Tensor ]=None ,
position_ids :Optional [torch .Tensor ]=None ,
inputs_embeds :Optional [torch .Tensor ]=None ,
labels :Optional [torch .Tensor ]=None ,
past_key_values :Optional [List [KVCache ]]=None ,
use_cache :bool =False ,
output_attentions :bool =False ,
output_hidden_states :bool =False ,
return_dict :bool =True ,
cache_position :Optional [torch .Tensor ]=None ,
thinking_depth :int =0 ,
**kwargs ,
)->Union [Tuple ,CausalLMOutput ]:
outputs =self .model (
input_ids =input_ids ,
attention_mask =attention_mask ,
position_ids =position_ids ,
inputs_embeds =inputs_embeds ,
past_key_values =past_key_values ,
use_cache =use_cache ,
output_attentions =output_attentions ,
output_hidden_states =output_hidden_states ,
return_dict =True ,
cache_position =cache_position ,
thinking_depth =thinking_depth ,
)
hidden_states =outputs .last_hidden_state
aux_loss =outputs .aux_loss
logits =self .lm_head (hidden_states )
loss =None
if labels is not None :
shift_logits =logits [...,:-1 ,:].contiguous ()
shift_labels =labels [...,1 :].contiguous ()
if shift_labels .dtype !=torch .long :
shift_labels =shift_labels .long ()
valid_mask =(shift_labels !=-100 )
num_valid =valid_mask .sum ().item ()
if num_valid >0 :
loss_fct =nn .CrossEntropyLoss (ignore_index =-100 )
loss =loss_fct (
shift_logits .view (-1 ,shift_logits .size (-1 )),
shift_labels .view (-1 )
)
loss =torch .clamp (loss ,min =0.0 ,max =100.0 )
else :
loss =torch .tensor (0.0 ,device =logits .device ,dtype =logits .dtype ,requires_grad =True )
return CausalLMOutput (
loss =loss ,
logits =logits ,
past_key_values =outputs .past_key_values ,
hidden_states =outputs .hidden_states ,
attentions =outputs .attentions ,
aux_loss =aux_loss ,
)
@torch .no_grad ()
def generate (
self ,
input_ids :torch .Tensor ,
max_new_tokens :int =100 ,
temperature :float =1.0 ,
top_k :int =50 ,
top_p :float =0.9 ,
do_sample :bool =True ,
pad_token_id :Optional [int ]=None ,
eos_token_id :Optional [int ]=None ,
attention_mask :Optional [torch .Tensor ]=None ,
thinking_depth :int =0 ,
**kwargs ,
)->torch .Tensor :
batch_size =input_ids .shape [0 ]
device =input_ids .device
past_key_values =None
is_prefill =True # Deep thinking only on first pass (full context)
if attention_mask is None :
attention_mask =torch .ones_like (input_ids )
for _ in range (max_new_tokens ):
model_inputs =self .prepare_inputs_for_generation (
input_ids ,
past_key_values =past_key_values ,
attention_mask =attention_mask ,
)
# Apply thinking depth only on prefill, not per-token steps
current_depth = thinking_depth if is_prefill else 0
outputs =self .forward (**model_inputs ,use_cache =True ,return_dict =True ,thinking_depth =current_depth )
is_prefill =False
next_token_logits =outputs .logits [:,-1 ,:]
if temperature !=1.0 :
next_token_logits =next_token_logits /temperature
if do_sample :
if top_k >0 :
indices_to_remove =next_token_logits <torch .topk (next_token_logits ,top_k )[0 ][...,-1 ,None ]
next_token_logits [indices_to_remove ]=float ('-inf')
if top_p <1.0 :
sorted_logits ,sorted_indices =torch .sort (next_token_logits ,descending =True )
cumulative_probs =torch .cumsum (F .softmax (sorted_logits ,dim =-1 ),dim =-1 )
sorted_indices_to_remove =cumulative_probs >top_p
sorted_indices_to_remove [...,1 :]=sorted_indices_to_remove [...,:-1 ].clone ()
sorted_indices_to_remove [...,0 ]=0
indices_to_remove =sorted_indices_to_remove .scatter (1 ,sorted_indices ,sorted_indices_to_remove )
next_token_logits [indices_to_remove ]=float ('-inf')
probs =F .softmax (next_token_logits ,dim =-1 )
next_tokens =torch .multinomial (probs ,num_samples =1 ).squeeze (-1 )
else :
next_tokens =torch .argmax (next_token_logits ,dim =-1 )
input_ids =torch .cat ([input_ids ,next_tokens .unsqueeze (-1 )],dim =-1 )
attention_mask =torch .cat ([attention_mask ,torch .ones ((batch_size ,1 ),device =device )],dim =-1 )
past_key_values =outputs .past_key_values
if eos_token_id is not None and (next_tokens ==eos_token_id ).all ():
break
return input_ids
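# Usage sketch (illustrative only): sampling with a few latent "thinking" iterations on the
# prefill pass. The prompt ids are placeholders; thinking_depth=0 recovers standard decoding.
#
#   lm = MoELlamaForCausalLM(cfg, moe_config)          # cfg/moe_config as in the sketch above
#   prompt = torch.randint(0, 1000, (1, 8))
#   out_ids = lm.generate(prompt, max_new_tokens=16, temperature=0.8,
#                         top_k=50, top_p=0.9, thinking_depth=2)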
# ==============================================================================
# MODELS.XORON
# ==============================================================================
logger =logging .getLogger (__name__ )
MAX_HIDDEN =10000.0
def safe_clamp_tensor (x :torch .Tensor ,max_val :float =MAX_HIDDEN )->torch .Tensor :
"""Clamp tensor values for FP16 safety, handling NaN/Inf properly.
WARNING: Only use for linear/hidden states, NOT for attention scores before softmax!
For attention scores, use a max of ~11.0 to prevent exp() overflow.
CRITICAL: torch.clamp does NOT fix NaN! clamp(nan, -10, 10) = nan
Must use nan_to_num first.
"""
if x is None or x .numel ()==0 :
return x
x =torch .nan_to_num (x ,nan =0.0 ,posinf =max_val ,neginf =-max_val )
return x .clamp (-max_val ,max_val )
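# Worked example (illustrative only): torch.clamp alone would propagate NaN, so the helper
# routes values through nan_to_num first.
#
#   x = torch.tensor([float('nan'), float('inf'), -1e9, 5.0])
#   safe_clamp_tensor(x, max_val=10000.0)   # -> tensor([0., 10000., -10000., 5.])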
COMPONENT_GROUPS ={
'vision':['vision_encoder','projector'],
'video':['video_encoder'],
'audio':['audio_encoder','audio_decoder','audio_projector','waveform_decoder'],
'speech':['waveform_decoder'],
'llm':['llm'],
'cross_attention':['cross_attention_layers'],
'image_generation':['generator'],
'video_generation':['video_generator'],
'modality_markers':['image_start','image_end','video_start','video_end','audio_start','audio_end'],
}
class MultimodalModelOutput (dict ):
"""Output class for multimodal model."""
def __getattr__ (self ,name ):
try :
return self [name ]
except KeyError :
raise AttributeError (f"'{type (self ).__name__ }' has no attribute '{name }'")
def __setattr__ (self ,name ,value ):
self [name ]=value
class XoronMultimodalModel (nn .Module ):
"""
Xoron-Dev: Complete multimodal model with:
- Image/video understanding (CLIP)
- Text generation (MoE LLM)
- Image/video generation (MobileDiffusion)
- Voice understanding and generation (ASR/TTS)
- Cross-attention for multimodal fusion
- LoRA support for efficient fine-tuning
- Flash Attention for faster training
- Model Parallelism support for multi-GPU training
"""
def __init__ (self ,config :XoronConfig ,device_map :Dict [str ,str ]=None ):
super ().__init__ ()
self .config =config
self .device_map =device_map
if device_map is not None :
device_values =[v for v in device_map .values ()if isinstance (v ,str )]
self ._model_parallel =len (set (device_values ))>1
else :
self ._model_parallel =False
logger .info ("Initializing Xoron-Dev Multimodal Model Build")
if self ._model_parallel :
logger .info (" ⚡ Model Parallelism: ENABLED")
self .vision_encoder =VisionEncoder (config .vision_model_name ,freeze =config .freeze_vision )
self .video_encoder =VideoEncoder (self .vision_encoder ,max_frames =config .video_max_frames )
logger .info ("Building SOTA Audio Encoder...")
self .audio_encoder =AudioEncoder (
hidden_size =config .hidden_size ,
n_mels =80 ,
max_audio_length =3000 ,
use_raw_waveform =getattr (config ,'use_raw_waveform',True ),
)
logger .info ("Building SOTA Audio Decoder...")
self .audio_decoder =AudioDecoder (
hidden_size =config .hidden_size ,
n_mels =80 ,
max_audio_length =1000 ,
)
logger .info ("Building Raw Waveform Decoder (Speech-to-Speech)...")
self .waveform_decoder =RawWaveformDecoder (
hidden_size =config .hidden_size ,
sample_rate =getattr (config ,'audio_sample_rate',16000 ),
)
llm_config =LlamaConfig (
vocab_size =config .vocab_size ,
hidden_size =config .hidden_size ,
intermediate_size =config .intermediate_size ,
num_hidden_layers =config .num_layers ,
num_attention_heads =config .num_heads ,
max_position_embeddings =config .max_position_embeddings ,
rms_norm_eps =1e-6 ,
tie_word_embeddings =getattr (config ,'tie_word_embeddings',True ),
pad_token_id =0 ,
)
llm_config .use_flash_attention =config .use_flash_attention
llm_config .use_ring_attention =getattr (config ,'use_ring_attention',True )
llm_config .ring_attention_chunk_size =getattr (config ,'ring_attention_chunk_size',4096 )
moe_config ={
'use_moe':config .use_moe ,
'num_experts':config .num_experts ,
'num_experts_per_tok':config .num_experts_per_tok ,
'moe_layer_freq':config .moe_layer_freq ,
'intermediate_size':config .intermediate_size ,
}
logger .info (f"Building LLM Core: {config .hidden_size }d, {config .num_layers }L")
logger .info (f" 📏 Context: {config .max_position_embeddings //1024 }K positions")
if config .use_ring_attention :
logger .info (f" 🔄 Ring Attention Enabled (chunk size: {config .ring_attention_chunk_size })")
logger .info (f" 🎯 MoE: {config .num_experts } experts, top-{config .num_experts_per_tok }")
self .llm =MoELlamaForCausalLM (llm_config ,moe_config )
logger .info (f" ✅ MoE layers initialized: {self .llm .model .num_moe_layers }/{config .num_layers }")
self .projector =MultimodalProjector (
self .vision_encoder .hidden_size ,
config .hidden_size ,
config .num_vision_tokens
)
logger .info (f" 🔗 Projector initialized: {self .vision_encoder .hidden_size } -> {config .hidden_size }")
self .audio_projector =nn .Linear (config .hidden_size ,config .hidden_size )
self .image_start =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 )
self .image_end =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 )
self .video_start =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 )
self .video_end =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 )
self .audio_start =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 )
self .audio_end =nn .Parameter (torch .randn (1 ,1 ,config .hidden_size )*0.02 )
self .cross_attention_layers =None
if config .use_cross_attention :
logger .info (f"Building Cross-Attention Fusion ({config .cross_attention_layers } layers)...")
self .cross_attention_layers =nn .ModuleList ([
MultimodalFusionLayer (
hidden_size =config .hidden_size ,
num_heads =config .cross_attention_heads ,
dropout =config .cross_attention_dropout ,
use_flash_attention =config .use_flash_attention ,
)
for _ in range (config .cross_attention_layers )
])
logger .info (f" ✅ Cross-attention: {config .cross_attention_layers } layers, {config .cross_attention_heads } heads")
self .generator =None
if config .enable_generation :
logger .info ("Building MobileDiffusion Generators (Image & Video)...")
self .generator =MobileDiffusionGenerator (
latent_channels =config .generation_latent_channels ,
base_channels =config .generation_base_channels ,
context_dim =config .hidden_size ,
num_inference_steps =config .generation_inference_steps ,
image_size =config .image_max_size ,
)
self .video_generator =None
if config .enable_generation :
self .video_generator =MobileVideoDiffusion (
latent_channels =config .generation_latent_channels ,
base_channels =config .generation_base_channels //2 ,
context_dim =config .hidden_size ,
num_frames =config .video_max_frames ,
image_size =config .video_max_size ,
num_inference_steps =config .generation_inference_steps ,
)
self .num_vision_tokens =config .num_vision_tokens
self .video_max_frames =config .video_max_frames
self .lora_applied =False
self ._print_stats ()
logger .info ("Xoron-Dev Multimodal Model Build Complete")
def apply_model_parallel (self ,device_map :Dict [str ,str ]):
"""Apply Model Parallelism by sharding components across devices.
Trained components get their layers split across all training GPUs.
Frozen components go to CPU. Small components (projectors, markers)
go to the primary GPU.
"""
self .device_map =device_map
training_gpus = device_map .get ('training_gpus', ['cuda:0'])
primary = device_map .get ('primary', 'cuda:0')
if len (training_gpus ) <= 1 and not any (v == 'cpu' for v in device_map .values () if isinstance (v, str)):
logger .info (" ℹ️ Single device - no model parallelism needed")
return self
self ._model_parallel = True
logger .info ("Applying Model Parallelism (layer sharding)...")
def _shard_module (module, name, gpus):
"""Shard a module's sub-layers across GPUs."""
# Find shardable sub-layers (nn.ModuleList children)
layer_lists = []
for attr_name in dir (module):
attr = getattr (module, attr_name, None)
if isinstance (attr, nn .ModuleList) and len (attr) > 0:
layer_lists .append ((attr_name, attr))
if layer_lists:
# Shard the largest ModuleList across GPUs
layer_lists .sort (key=lambda x: len (x[1]), reverse=True)
list_name, layers = layer_lists [0]
for i, layer in enumerate (layers):
target_gpu = gpus [i % len (gpus)]
layer .to (target_gpu)
# Put remaining params on primary GPU
for param_name, param in module .named_parameters ():
if f'{list_name}.' not in param_name:
param .data = param .data .to (gpus [0])
logger .info (f" ✅ {name}: {len(layers)} layers sharded across {gpus}")
else:
# No layers to shard — put whole module on first GPU
module .to (gpus [0])
logger .info (f" ✅ {name} -> {gpus[0]}")
# Map component names to actual attributes
component_attrs = {
'vision_encoder': 'vision_encoder',
'video_encoder': 'video_encoder',
'audio_encoder': 'audio_encoder',
'audio_decoder': 'audio_decoder',
'waveform_decoder': 'waveform_decoder',
'projector': 'projector',
'audio_projector': 'audio_projector',
'llm': 'llm',
'cross_attention': 'cross_attention_layers',
'generator': 'generator',
'video_generator': 'video_generator',
}
for comp_name, attr_name in component_attrs .items ():
comp = getattr (self, attr_name, None)
if comp is None:
continue
target = device_map .get (comp_name, 'cpu')
if target == 'cpu':
comp .to ('cpu')
logger .info (f" ❄️ {comp_name} -> cpu (frozen)")
else:
# Shard across all training GPUs
_shard_module (comp, comp_name, training_gpus)
# Modality markers → primary GPU
marker_device = device_map .get ('modality_markers', primary)
if marker_device != 'cpu':
marker_device = primary
for marker_name in ['image_start', 'image_end', 'video_start', 'video_end', 'audio_start', 'audio_end']:
marker = getattr (self, marker_name, None)
if marker is not None:
setattr (self, marker_name, nn .Parameter (marker .data .to (marker_device)))
logger .info (f" ✅ Modality markers -> {marker_device}")
logger .info ("Model Parallelism applied successfully!")
return self
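# Device-map sketch (illustrative only): the keys below mirror what apply_model_parallel and
# get_llm_device/get_encoder_device read. The exact keys and devices shown are assumptions for
# the example, not a required schema beyond what the methods above access.
#
#   device_map = {
#       'primary': 'cuda:0',
#       'training_gpus': ['cuda:0', 'cuda:1'],
#       'llm': 'cuda:0',               # read by get_llm_device()
#       'vision_encoder': 'cuda:0',    # read by get_encoder_device()
#       'audio_encoder': 'cpu',        # frozen components are parked on CPU
#       'modality_markers': 'cuda:0',
#   }
#   model = XoronMultimodalModel(config, device_map=device_map).apply_model_parallel(device_map)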
def get_llm_device (self ):
"""Get the device where LLM is located."""
if self .device_map is not None :
return torch .device (self .device_map ['llm'])
return next (self .llm .parameters ()).device
def generate (self ,*args ,**kwargs ):
"""
Delegates generation to the internal LLM.
This allows the model to be treated as a causal LM in many pipelines.
"""
return self .llm .generate (*args ,**kwargs )
def get_encoder_device (self ):
"""Get the device where encoders are located."""
if self .device_map is not None :
return torch .device (self .device_map ['vision_encoder'])
return next (self .vision_encoder .parameters ()).device
def apply_lora (self ):
"""
Apply LoRA to the LLM and optionally cross-attention layers.
MEMORY OPTIMIZATION:
- LoRA layers share base weights (no cloning)
- Base weights in LoRA layers are frozen (requires_grad=False)
- LoRA params (A, B, magnitude) are always trainable
NOTE: This does NOT freeze other components!
Component freezing is handled separately by freeze_components() based on
training mode (--text, --video, --image, --voice flags).
This allows PARALLEL FINE-TUNING:
- LoRA adapters on LLM for efficient adaptation
- Full weight training on active components (vision, audio, etc.)
"""
if self .lora_applied :
logger .warning ("LoRA already applied")
return
if not self .config .use_lora :
logger .info ("LoRA disabled in config")
return
lora_config =LoRAConfig (
r =self .config .lora_r ,
lora_alpha =self .config .lora_alpha ,
lora_dropout =self .config .lora_dropout ,
target_modules =list (self .config .lora_target_modules ),
enable_lora =True ,
)
logger .info ("Applying LoRA to LLM Core...")
self .llm =apply_lora_to_model (self .llm ,lora_config )
if self .cross_attention_layers is not None :
logger .info ("Applying LoRA to cross-attention layers...")
cross_attn_lora_config =LoRAConfig (
r =lora_config .r ,
lora_alpha =lora_config .lora_alpha ,
lora_dropout =lora_config .lora_dropout ,
target_modules =['q_proj','k_proj','v_proj','o_proj'],
enable_lora =True ,
)
for i ,layer in enumerate (self .cross_attention_layers ):
self .cross_attention_layers [i ]=apply_lora_to_model (layer ,cross_attn_lora_config )
self .lora_applied =True
self ._print_stats ()
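# Usage sketch (assumes config.use_lora=True with LoRA hyperparameters already set in config):
#   model.apply_lora()                    # wrap LLM / cross-attention linears in LoRALinear
#   model.freeze_components(['vision'])   # then freeze whatever modalities are inactive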
def get_trainable_params (self ):
"""
Get trainable parameters, respecting LoRA settings and component freezing.
If train_lora_only=True and LoRA is applied:
- Freezes all non-LoRA params
- Returns only LoRA params
Otherwise:
- Returns all params with requires_grad=True
- This includes both LoRA params AND unfrozen component weights
- Allows parallel fine-tuning: LoRA + full weights on active components
"""
if self .config .train_lora_only and self .lora_applied :
freeze_non_lora_params (self )
return get_lora_parameters (self )
return [p for p in self .parameters ()if p .requires_grad ]
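# Typical optimizer wiring (sketch; AdamW and the learning rate are arbitrary choices):
#   optimizer = torch.optim.AdamW(model.get_trainable_params(), lr=2e-4, weight_decay=0.01)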
def _print_stats (self ):
total =sum (p .numel ()for p in self .parameters ())
trainable =sum (p .numel ()for p in self .parameters ()if p .requires_grad )
logger .info ("Model Statistics:")
logger .info (f" Total parameters: {total /1e6 :.1f}M")
logger .info (f" Trainable parameters: {trainable /1e6 :.1f}M")
if self .lora_applied :
lora_params =sum (p .numel ()for n ,p in self .named_parameters ()if 'lora_'in n )
logger .info (f" LoRA parameters: {lora_params /1e6 :.2f}M")
def encode_image (self ,pixel_values :torch .Tensor )->torch .Tensor :
encoder_device =self .get_encoder_device ()
pixel_values =pixel_values .to (encoder_device )
vision_features =self .vision_encoder (pixel_values )
projected =self .projector (vision_features )
llm_device =self .get_llm_device ()
return projected .to (llm_device )
def encode_video (self ,video_frames :torch .Tensor )->torch .Tensor :
encoder_device =self .get_encoder_device ()
video_frames =video_frames .to (encoder_device )
video_features =self .video_encoder (video_frames )
projected =self .projector (video_features )
llm_device =self .get_llm_device ()
return projected .to (llm_device )
def encode_audio (self ,audio_features :torch .Tensor )->torch .Tensor :
encoder_device =self .get_encoder_device ()
audio_features =audio_features .to (encoder_device )
audio_embeds =self .audio_encoder (audio_features )
projected =self .audio_projector (audio_embeds )
llm_device =self .get_llm_device ()
return projected .to (llm_device )
def get_text_embeddings (self ,input_ids :torch .Tensor ,attention_mask :torch .Tensor =None )->torch .Tensor :
llm_device =self .get_llm_device ()
input_ids =input_ids .to (llm_device )
embeddings =self .llm .model .embed_tokens (input_ids )
return embeddings
def _apply_cross_attention (
self ,
text_embeds :torch .Tensor ,
image_embeds :torch .Tensor =None ,
video_embeds :torch .Tensor =None ,
audio_embeds :torch .Tensor =None ,
)->torch .Tensor :
if self .cross_attention_layers is None :
return text_embeds
for fusion_layer in self .cross_attention_layers :
text_embeds ,_ =fusion_layer (
text_hidden =text_embeds ,
image_hidden =image_embeds ,
video_hidden =video_embeds ,
audio_hidden =audio_embeds ,
use_cache =False ,
)
return text_embeds
def forward (
self ,
input_ids :torch .Tensor ,
attention_mask :torch .Tensor =None ,
pixel_values :torch .Tensor =None ,
video_frames :torch .Tensor =None ,
audio_features :torch .Tensor =None ,
labels :torch .Tensor =None ,
):
"""Forward pass - FP16 native."""
batch_size =input_ids .shape [0 ]
llm_device =self .get_llm_device ()
input_ids_llm =input_ids .to (llm_device )
text_embeds =self .llm .model .embed_tokens (input_ids_llm )
text_embeds =safe_clamp_tensor (text_embeds )
device =text_embeds .device
if attention_mask is not None :
attention_mask =attention_mask .to (device )
if labels is not None :
labels =labels .to (device )
image_embeds_for_cross =None
video_embeds_for_cross =None
audio_embeds_for_cross =None
def has_content (tensor ):
if tensor is None :
return False
if not isinstance (tensor ,torch .Tensor ):
return False
try :
if tensor .numel ()==0 :
return False
return bool (tensor .any ())
except Exception :
return False
if has_content (pixel_values ):
try :
image_embeds =self .encode_image (pixel_values )
image_embeds =safe_clamp_tensor (image_embeds )
image_embeds_for_cross =image_embeds
image_start =self .image_start .expand (batch_size ,-1 ,-1 )
image_end =self .image_end .expand (batch_size ,-1 ,-1 )
image_embeds =torch .cat ([image_start ,image_embeds ,image_end ],dim =1 )
text_embeds =torch .cat ([image_embeds ,text_embeds ],dim =1 )
text_embeds =safe_clamp_tensor (text_embeds )
if attention_mask is not None :
image_mask =torch .ones (batch_size ,image_embeds .shape [1 ],device =device )
attention_mask =torch .cat ([image_mask ,attention_mask ],dim =1 )
if labels is not None :
image_labels =torch .full ((batch_size ,image_embeds .shape [1 ]),-100 ,device =device ,dtype =labels .dtype )
labels =torch .cat ([image_labels ,labels ],dim =1 )
except Exception as e :
logger .debug (f"Image encoding skipped: {e }")
if has_content (video_frames ):
try :
video_embeds =self .encode_video (video_frames )
video_embeds =safe_clamp_tensor (video_embeds )
video_embeds_for_cross =video_embeds
video_start =self .video_start .expand (batch_size ,-1 ,-1 )
video_end =self .video_end .expand (batch_size ,-1 ,-1 )
video_embeds =torch .cat ([video_start ,video_embeds ,video_end ],dim =1 )
text_embeds =torch .cat ([video_embeds ,text_embeds ],dim =1 )
text_embeds =safe_clamp_tensor (text_embeds )
if attention_mask is not None :
video_mask =torch .ones (batch_size ,video_embeds .shape [1 ],device =device )
attention_mask =torch .cat ([video_mask ,attention_mask ],dim =1 )
if labels is not None :
video_labels =torch .full ((batch_size ,video_embeds .shape [1 ]),-100 ,device =device ,dtype =labels .dtype )
labels =torch .cat ([video_labels ,labels ],dim =1 )
except Exception as e :
logger .debug (f"Video encoding skipped: {e }")
if has_content (audio_features ):
try :
audio_embeds =self .encode_audio (audio_features )
audio_embeds =safe_clamp_tensor (audio_embeds )
audio_embeds_for_cross =audio_embeds
audio_start =self .audio_start .expand (batch_size ,-1 ,-1 )
audio_end =self .audio_end .expand (batch_size ,-1 ,-1 )
audio_embeds =torch .cat ([audio_start ,audio_embeds ,audio_end ],dim =1 )
text_embeds =torch .cat ([audio_embeds ,text_embeds ],dim =1 )
text_embeds =safe_clamp_tensor (text_embeds )
if attention_mask is not None :
audio_mask =torch .ones (batch_size ,audio_embeds .shape [1 ],device =device )
attention_mask =torch .cat ([audio_mask ,attention_mask ],dim =1 )
if labels is not None :
audio_labels =torch .full ((batch_size ,audio_embeds .shape [1 ]),-100 ,device =device ,dtype =labels .dtype )
labels =torch .cat ([audio_labels ,labels ],dim =1 )
except Exception as e :
logger .debug (f"Audio encoding skipped: {e }")
if self .cross_attention_layers is not None :
try :
text_embeds =self ._apply_cross_attention (
text_embeds ,
image_embeds =image_embeds_for_cross ,
video_embeds =video_embeds_for_cross ,
audio_embeds =audio_embeds_for_cross ,
)
text_embeds =safe_clamp_tensor (text_embeds )
except Exception as e :
logger .debug (f"Cross-attention skipped: {e }")
text_embeds =safe_clamp_tensor (text_embeds )
outputs =self .llm (inputs_embeds =text_embeds ,attention_mask =attention_mask ,labels =labels )
return MultimodalModelOutput (
loss =outputs .loss if hasattr (outputs ,'loss')else None ,
logits =outputs .logits if hasattr (outputs ,'logits')else None ,
aux_loss =outputs .aux_loss if hasattr (outputs ,'aux_loss')else None ,
)
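# Illustrative call mixing text and images (shapes are assumptions, e.g. 224x224 RGB;
# passing labels=input_ids yields a standard next-token loss on the text span):
#   out = model(
#       input_ids=input_ids,                          # [B, T]
#       attention_mask=attention_mask,                # [B, T]
#       pixel_values=torch.randn(B, 3, 224, 224),     # optional image batch
#       labels=input_ids,
#   )
#   loss, logits = out.loss, out.logits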
@torch .no_grad ()
def generate_image (self ,input_ids :torch .Tensor ,attention_mask :torch .Tensor =None ):
"""Generate image from text."""
if self .generator is None :
raise ValueError ("Image generator not enabled")
context =self .get_text_embeddings (input_ids ,attention_mask )
images =self .generator .generate (context )
return images
@torch .no_grad ()
def generate_video (self ,input_ids :torch .Tensor ,attention_mask :torch .Tensor =None ,
first_frame :torch .Tensor =None ,num_frames :int =None ):
"""Generate video from text (T2V) or from image (I2V)."""
if self .video_generator is None :
raise ValueError ("Video generator not enabled")
context =self .get_text_embeddings (input_ids ,attention_mask )
context =context .mean (dim =1 )
if first_frame is not None :
video =self .video_generator .generate_i2v (first_frame ,context ,num_frames )
else :
video =self .video_generator .generate_t2v (context ,num_frames )
return video
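# Sketch of the two modes (the 256x256 first-frame shape is an assumption):
#   t2v = model.generate_video(input_ids, attention_mask)                        # text-to-video
#   i2v = model.generate_video(input_ids, attention_mask,
#                              first_frame=torch.randn(1, 3, 256, 256))          # image-to-video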
@torch .no_grad ()
def generate_speech (self ,input_ids :torch .Tensor ,attention_mask :torch .Tensor =None ):
"""Generate speech (mel-spectrogram) from text (TTS)."""
text_embeds =self .get_text_embeddings (input_ids ,attention_mask )
mel ,durations ,_ ,_ =self .audio_decoder (text_embeds )
return mel ,durations
@torch .no_grad ()
def speak (
self ,
input_ids :torch .Tensor ,
attention_mask :torch .Tensor =None ,
speaker_embedding :torch .Tensor =None ,
return_mel :bool =False ,
)->torch .Tensor :
"""
Generate playable audio waveform from text (Speech-to-Speech TTS).
This is the main method for making the model talk. It converts text
directly to audio waveform without needing an external vocoder.
Args:
input_ids: [B, T] tokenized text input
attention_mask: [B, T] attention mask
speaker_embedding: [B, D] optional speaker embedding for voice cloning
return_mel: If True, also return intermediate mel spectrogram
Returns:
waveform: [B, T_audio] raw audio waveform in [-1, 1] range at 16kHz
Can be played directly or saved as WAV file
mel (optional): [B, 80, T_mel] mel spectrogram if return_mel=True
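Example (illustrative sketch; the tokenizer and the torchaudio call are
assumptions, not part of this class):
>>> ids = tokenizer("Hello, I am Xoron.", return_tensors="pt").input_ids
>>> wav = model.speak(ids)                       # [1, T_audio] in [-1, 1]
>>> torchaudio.save("hello.wav", wav.cpu(), 16000)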
"""
text_embeds =self .get_text_embeddings (input_ids ,attention_mask )
mel ,durations ,_ ,_ =self .audio_decoder (
text_embeds ,
speaker_embedding =speaker_embedding ,
)
mel_features =mel .transpose (1 ,2 )
if not hasattr (self ,'_mel_to_hidden'):
self ._mel_to_hidden =nn .Linear (80 ,self .config .hidden_size ).to (mel .device )
audio_features =self ._mel_to_hidden (mel_features )
waveform =self .waveform_decoder (audio_features )
if return_mel :
return waveform ,mel
return waveform
@torch .no_grad ()
def listen (self ,audio_waveform :torch .Tensor )->torch .Tensor :
"""
Transcribe audio to text embeddings (Speech-to-Speech ASR).
This is the listening component - converts speech to embeddings
that can be fed to the LLM for understanding.
Args:
audio_waveform: [B, T_audio] raw audio waveform
Returns:
audio_embeds: [B, T, hidden_size] encoded audio features
"""
return self .encode_audio (audio_waveform )
@torch .no_grad ()
def listen_and_respond (
self ,
audio_waveform :torch .Tensor ,
tokenizer =None ,
max_new_tokens :int =512 ,
speaker_embedding :torch .Tensor =None ,
temperature :float =0.7 ,
top_p :float =0.9 ,
tool_executor =None ,
available_tools :list =None ,
system_prompt :str =None ,
max_tool_calls :int =5 ,
) -> Dict [str ,Any ]:
"""
Agentic Speech-to-Speech: Listen, think, use tools, speak back.
This is the full agentic pipeline for live voice conversations.
The model can detect when the user is asking for actions (e.g.
"write me a Python script") and execute tools mid-generation.
Pipeline:
1. Encode input audio → audio embeddings (ASR)
2. Build context (system prompt with tools + audio embeddings)
3. Generate tokens, watching for <|tool_call|> sequences
4. When tool call detected: parse, execute, inject result, resume
5. Synthesize final spoken response from non-tool text
Args:
audio_waveform: [B, T_audio] input audio waveform
tokenizer: Tokenizer for decoding tokens to text (required for tools)
max_new_tokens: Maximum total tokens to generate
speaker_embedding: [B, D] optional speaker embedding for voice cloning
temperature: Sampling temperature
top_p: Nucleus sampling probability
tool_executor: Callable(tool_name, args_dict) -> str result.
If None, tool calls are detected but not executed.
available_tools: List of tool definition dicts for system prompt.
system_prompt: Optional system prompt override.
max_tool_calls: Maximum number of tool calls per response (safety limit).
Returns:
Dict with:
'waveform': [B, T_response] audio waveform tensor (in-memory, no file I/O)
'text': str full response text (excluding tool call markup)
'token_ids': [B, T_tokens] all generated token IDs
'mel': [B, 80, T_mel] intermediate mel spectrogram
'tool_calls': List[Dict] executed tool calls and their results
'speaking_text': str clean text that was spoken (no tool markup)
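Example tool_executor (hedged sketch; the "get_weather" tool and its JSON
payload are hypothetical, only the (name, args) -> str contract is real):
>>> def my_tool_executor(tool_name, args):
...     if tool_name == "get_weather":
...         return '{"temp_c": 21, "sky": "clear"}'
...     return f"[error]: unknown tool {tool_name}"
>>> out = model.listen_and_respond(waveform, tokenizer=tok, tool_executor=my_tool_executor)
>>> out["speaking_text"], out["tool_calls"]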
"""
import json as _json
device = audio_waveform .device
batch_size = audio_waveform .shape [0 ]
llm_device = self .get_llm_device ()
# ── 1. Listen: encode input audio ──
audio_embeds = self .encode_audio (audio_waveform )
# Wrap with start/end markers
audio_start = self .audio_start .expand (batch_size , -1 , -1 ).to (llm_device )
audio_end = self .audio_end .expand (batch_size , -1 , -1 ).to (llm_device )
audio_embeds = audio_embeds .to (llm_device )
# ── 2. Build context with system prompt + tools ──
context_parts = []
if tokenizer is not None and (system_prompt or tool_executor):
sys_text = system_prompt or "You are Xoron, an intelligent voice assistant. You can use tools to help the user."
if tool_executor and hasattr (tool_executor , 'get_tool_prompt' ):
sys_text = sys_text + "\n\n" + tool_executor .get_tool_prompt ()
elif available_tools :
from utils .tool_executor import format_tools_for_prompt
sys_text = sys_text + "\n\n" + format_tools_for_prompt (available_tools )
# Encode system prompt and prepend
sys_str = "<|system|>" + sys_text + "<|/system|>"
sys_token_ids = tokenizer .encode (sys_str , return_tensors ="pt" ).to (llm_device )
sys_embeds = self .llm .model .embed_tokens (sys_token_ids )
context_parts.append(sys_embeds if sys_embeds.dim() == 3 else sys_embeds.unsqueeze(0))  # keep [1, T, D] so it concatenates with the 3-D audio embeddings
# Audio context
context_parts .extend ([audio_start , audio_embeds , audio_end ])
# Assistant generation prompt
if tokenizer is not None :
asst_str = "<|assistant|>"
asst_ids = tokenizer .encode (asst_str , return_tensors ="pt" ).to (llm_device )
asst_embeds = self .llm .model .embed_tokens (asst_ids )
context_parts.append(asst_embeds if asst_embeds.dim() == 3 else asst_embeds.unsqueeze(0))  # keep [1, T, D] for concatenation along dim=1
input_embeds = torch .cat (context_parts , dim =1 )
# ── 3. Agentic generation loop with tool call detection ──
tool_call_start_token = "<|tool_call|>"
tool_call_end_token = "<|/tool_call|>"
fn_name_start = "<|function_name|>"
fn_name_end = "<|/function_name|>"
fn_args_start = "<|function_args|>"
fn_args_end = "<|/function_args|>"
tool_result_start = "<|tool_result|>"
tool_result_end = "<|/tool_result|>"
eos_token = "<|eos|>"
all_generated_ids = []
tool_calls_made = []
num_tool_calls = 0
generated_text = ""
total_tokens = 0
# Use standard generation if no tool executor
if tool_executor is None or tokenizer is None :
gen_kwargs = {
'inputs_embeds': input_embeds ,
'max_new_tokens': max_new_tokens ,
'do_sample': True ,
'temperature': temperature ,
'top_p': top_p ,
'use_cache': True ,
}
generated_ids = self .llm .generate (**gen_kwargs )
all_generated_ids = [generated_ids ]
if tokenizer is not None :
generated_text = tokenizer .batch_decode (generated_ids , skip_special_tokens =True )[0 ]
else :
# Token-by-token generation with tool call detection
current_embeds = input_embeds
past_key_values = None
in_tool_call = False
tool_call_buffer = ""
while total_tokens < max_new_tokens :
outputs = self .llm (
inputs_embeds =current_embeds ,
past_key_values =past_key_values ,
use_cache =True ,
)
past_key_values = outputs .past_key_values
logits = outputs .logits [:, -1 :, :]
# Sample next token
if temperature > 0 :
logits = logits / temperature
if top_p < 1.0 :
sorted_logits , sorted_indices = torch .sort (logits , descending =True , dim =-1 )
cumulative_probs = torch .cumsum (F .softmax (sorted_logits , dim =-1 ), dim =-1 )
sorted_mask = cumulative_probs - F .softmax (sorted_logits , dim =-1 ) >= top_p
sorted_logits [sorted_mask ] = float ('-inf' )
logits .scatter_ (-1 , sorted_indices , sorted_logits )
probs = F .softmax (logits , dim =-1 )
next_token = torch .multinomial (probs .squeeze (1 ), num_samples =1 )
else :
next_token = logits .argmax (dim =-1 )
total_tokens += 1
all_generated_ids .append (next_token )
# Decode the token
token_text = tokenizer .decode (next_token [0 ], skip_special_tokens =False )
generated_text = generated_text + token_text
# Check for EOS
if eos_token in token_text or next_token .item () == tokenizer .eos_token_id :
break
# ── Tool call detection ──
if tool_call_start_token in generated_text and not in_tool_call:
in_tool_call = True
# Buffer everything from the tool_call marker onward; generated_text already
# contains the current token_text, so it is not appended again below
tc_start_idx = generated_text.rfind(tool_call_start_token)
tool_call_buffer = generated_text[tc_start_idx:]
elif in_tool_call:
tool_call_buffer += token_text
if in_tool_call:
# Check if we have a complete tool call
if tool_call_end_token in tool_call_buffer :
in_tool_call = False
num_tool_calls += 1
# Parse the tool call
tool_name = ""
tool_args = {}
try:
# Extract function name; find() returns -1 when a marker is missing,
# so check before offsetting by the marker length
name_idx = tool_call_buffer.find(fn_name_start)
name_end = tool_call_buffer.find(fn_name_end)
if name_idx != -1 and name_end != -1:
tool_name = tool_call_buffer[name_idx + len(fn_name_start):name_end].strip()
# Extract arguments
args_idx = tool_call_buffer.find(fn_args_start)
args_end = tool_call_buffer.find(fn_args_end)
if args_idx != -1 and args_end != -1:
args_str = tool_call_buffer[args_idx + len(fn_args_start):args_end].strip()
try:
tool_args = _json.loads(args_str)
except Exception:
tool_args = {"raw": args_str}
except Exception :
pass
# Execute the tool
tool_result = "[error]: Failed to parse tool call"
if tool_name :
tool_result = tool_executor (tool_name , tool_args )
tool_calls_made .append ({
"name": tool_name ,
"arguments": tool_args ,
"result": tool_result ,
})
# Inject tool result back into generation context
result_str = tool_result_start + tool_result + tool_result_end
result_ids = tokenizer .encode (result_str , return_tensors ="pt" ).to (llm_device )
result_embeds = self .llm .model .embed_tokens (result_ids )
current_embeds = result_embeds
past_key_values = None # Reset KV cache to include result
all_generated_ids .append (result_ids .squeeze (0 ))
generated_text = generated_text + result_str
tool_call_buffer = ""
if num_tool_calls >= max_tool_calls :
break
continue
# Prepare next input
next_embeds = self .llm .model .embed_tokens (next_token )
current_embeds = next_embeds
# Combine all generated IDs
if all_generated_ids :
flat_ids = []
for t in all_generated_ids :
if t .dim () == 0 :
flat_ids .append (t .unsqueeze (0 ))
elif t .dim () == 1 :
flat_ids .append (t )
else :
flat_ids .append (t .view (-1 ))
generated_ids = torch .cat (flat_ids , dim =0 ).unsqueeze (0 )
else :
generated_ids = torch .tensor ([[]], dtype =torch .long , device =llm_device )
# ── 4. Extract speaking text (strip tool call/result markup) ──
speaking_text = generated_text
# Remove tool call blocks
while tool_call_start_token in speaking_text :
tc_s = speaking_text .find (tool_call_start_token )
tc_e = speaking_text .find (tool_call_end_token )
if tc_e > tc_s :
speaking_text = speaking_text [:tc_s ] + speaking_text [tc_e + len (tool_call_end_token ):]
else :
break
# Remove tool result blocks
while tool_result_start in speaking_text :
tr_s = speaking_text .find (tool_result_start )
tr_e = speaking_text .find (tool_result_end )
if tr_e > tr_s :
speaking_text = speaking_text [:tr_s ] + speaking_text [tr_e + len (tool_result_end ):]
else :
break
speaking_text = speaking_text .strip ()
# ── 5. Speak: encode → mel → stream_decode → waveform ──
response_embeds = self .llm .model .embed_tokens (generated_ids .to (llm_device ))
mel , durations , _ , _ = self .audio_decoder (
response_embeds ,
speaker_embedding =speaker_embedding ,
)
mel_features = mel .transpose (1 , 2 )
if not hasattr (self , '_mel_to_hidden' ):
self ._mel_to_hidden = nn .Linear (80 , self .config .hidden_size ).to (mel .device )
audio_features = self ._mel_to_hidden (mel_features )
waveform = self .waveform_decoder .stream_decode (audio_features )
return {
'waveform': waveform ,
'text': generated_text ,
'speaking_text': speaking_text ,
'token_ids': generated_ids ,
'mel': mel ,
'tool_calls': tool_calls_made ,
}
def merge_lora_weights (self ):
"""Merge LoRA weights into main weights for inference."""
if not self .lora_applied :
return
for module in self .modules ():
if isinstance (module ,LoRALinear ):
module .merge_lora_weights ()
logger .info ("LoRA weights merged into base model")
def unmerge_lora_weights (self ):
"""Unmerge LoRA weights for continued training."""
if not self .lora_applied :
return
for module in self .modules ():
if isinstance (module ,LoRALinear ):
module .unmerge_lora_weights ()
logger .info ("LoRA weights unmerged")
def save_pretrained (
self ,
path :str ,
optimizer =None ,
scheduler =None ,
global_step :int =0 ,
epoch :int =0 ,
best_loss :float =float ('inf'),
sharded :bool =False ,
max_shard_size :int =2 *1024 *1024 *1024 ,
save_separately :bool =True ,
):
"""
Save model and optionally training state for resuming.
Args:
path: Directory to save the model
optimizer: Optional optimizer to save state
scheduler: Optional scheduler to save state
global_step: Current training step
epoch: Current epoch
best_loss: Best loss achieved so far
sharded: If True, save model in multiple .safetensors files
max_shard_size: Maximum size per shard in bytes (default 2GB)
save_separately: If True, save each component as separate .safetensors files (default)
This avoids safetensors issues with shared storage in LSTM weights
"""
os .makedirs (path ,exist_ok =True )
if save_separately :
self ._save_components_safe (path )
elif sharded :
self ._save_sharded (path ,max_shard_size )
else :
self ._save_single_file_safe (path )
config_dict =self .config .to_dict ()
config_dict ['has_audio_encoder']=True
config_dict ['has_audio_decoder']=True
config_dict ['has_waveform_decoder']=hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None
config_dict ['has_vision_encoder']=hasattr (self ,'vision_encoder')and self .vision_encoder is not None
config_dict ['has_video_encoder']=hasattr (self ,'video_encoder')and self .video_encoder is not None
config_dict ['has_generator']=hasattr (self ,'generator')and self .generator is not None
config_dict ['has_video_generator']=hasattr (self ,'video_generator')and self .video_generator is not None
config_dict ['has_cross_attention']=hasattr (self ,'cross_attention_layers')and self .cross_attention_layers is not None
config_dict ['lora_applied']=self .lora_applied
config_dict ['architecture_version']=2
config_dict ['auto_map']={
'AutoConfig':'configuration_xoron.XoronConfig',
'AutoModel':'modeling_xoron.XoronModel',
'AutoModelForCausalLM':'modeling_xoron.XoronForCausalLM',
}
with open (os .path .join (path ,"config.json"),"w")as f :
json .dump (config_dict ,f ,indent =2 )
self ._copy_huggingface_files (path )
if optimizer is not None or scheduler is not None :
training_state ={
'global_step':global_step ,
'epoch':epoch ,
'best_loss':best_loss ,
}
if optimizer is not None :
training_state ['optimizer_state_dict']=optimizer .state_dict ()
if scheduler is not None :
training_state ['scheduler_state_dict']=scheduler .state_dict ()
torch .save (training_state ,os .path .join (path ,"training_state.pt"))
logger .info (f"Training state saved (step {global_step }, epoch {epoch })")
logger .info (f"Model saved to {path }")
def _copy_huggingface_files (self ,path :str ):
"""
Build and copy HuggingFace custom code files for trust_remote_code support.
This DYNAMICALLY BUILDS a self-contained modeling_xoron.py by combining
all model components, so users can load from HuggingFace Hub with:
model = AutoModel.from_pretrained("repo/model", trust_remote_code=True)
WITHOUT needing to install the full Xoron-Dev package.
Args:
path: Directory to save the files
"""
import shutil
current_dir =os .path .dirname (os .path .abspath (__file__ ))
project_root =os .path .dirname (current_dir )
config_src =os .path .join (project_root ,'configuration_xoron.py')
config_dst =os .path .join (path ,'configuration_xoron.py')
if os .path .exists (config_src ):
shutil .copy2 (config_src ,config_dst )
logger .info ("Copied configuration_xoron.py")
modeling_dst =os .path .join (path ,'modeling_xoron.py')
self ._build_self_contained_modeling_file (project_root ,modeling_dst )
logger .info ("HuggingFace custom code files ready")
def _build_self_contained_modeling_file (self ,project_root :str ,output_path :str ):
"""
Build a self-contained modeling_xoron.py by combining all model components.
This creates a single file with ALL model code embedded, removing internal
imports so it works standalone on HuggingFace without the full package.
"""
import re
component_files =[
"models/components/lora.py",
"models/components/attention.py",
"models/components/projectors.py",
"models/components/moe.py",
"models/encoders/vision.py",
"models/encoders/video.py",
"models/encoders/audio.py",
"models/generators/image.py",
"models/generators/video.py",
"models/llm/moe_llama.py",
"models/xoron.py",
]
internal_import_patterns =[
r"^from config import.*$",
r"^from config\..*import.*$",
r"^from models\..*import.*$",
r"^from models import.*$",
]
def is_internal_import (line ):
line =line .strip ()
for pattern in internal_import_patterns :
if re .match (pattern ,line ):
return True
return False
def is_module_level_import (line ):
"""Check if this is a module-level import (no indentation)."""
stripped =line .strip ()
if line and not line [0 ].isspace ():
return (stripped .startswith ("import ")or stripped .startswith ("from "))
return False
def extract_code_body (content ):
"""Extract code body, removing module docstring and module-level imports only."""
lines =content .split ('\n')
code_lines =[]
i =0
in_multiline_import =False
while i <len (lines )and not lines [i ].strip ():
i +=1
if i <len (lines ):
stripped =lines [i ].strip ()
if stripped .startswith ('"""')or stripped .startswith ("'''"):
docstring_char =stripped [:3 ]
if stripped .count (docstring_char )>=2 :
i +=1
else :
i +=1
while i <len (lines ):
if docstring_char in lines [i ]:
i +=1
break
i +=1
for line in lines [i :]:
stripped =line .strip ()
if not code_lines and not stripped :
continue
if in_multiline_import :
if ')'in stripped :
in_multiline_import =False
continue
if is_module_level_import (line ):
if '('in stripped and ')'not in stripped :
in_multiline_import =True
continue
if stripped .startswith ("logger = logging.getLogger")and not line [0 ].isspace ():
continue
code_lines .append (line )
while code_lines and not code_lines [-1 ].strip ():
code_lines .pop ()
return '\n'.join (code_lines )
header ='''"""
Xoron Model for HuggingFace Transformers - Self-Contained Implementation.
AUTO-GENERATED FILE - Do not edit directly!
This module provides a complete, self-contained HuggingFace-compatible model class
for the Xoron multimodal model. All components are embedded directly in this file
to enable loading via AutoModel with trust_remote_code=True WITHOUT requiring
the full Xoron-Dev package to be installed.
Usage:
from transformers import AutoModel, AutoConfig
config = AutoConfig.from_pretrained("your-repo/xoron-model", trust_remote_code=True)
model = AutoModel.from_pretrained("your-repo/xoron-model", trust_remote_code=True)
"""
try:
from safetensors.torch import save_file, load_file
except ImportError:
save_file, load_file = None, None
try:
from transformers.models.llama.modeling_llama import (
LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaMLP,
LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
)
except ImportError:
LlamaAttention = LlamaDecoderLayer = LlamaRMSNorm = LlamaMLP = None
LlamaRotaryEmbedding = apply_rotary_pos_emb = repeat_kv = None
try:
from .configuration_xoron import XoronConfig
except ImportError:
from configuration_xoron import XoronConfig
logger = logging.getLogger(__name__)
'''
all_code =[header ]
for filepath in component_files :
full_path =os .path .join (project_root ,filepath )
if not os .path .exists (full_path ):
logger .warning (f"Component not found: {filepath }")
continue
with open (full_path ,'r',encoding ='utf-8')as f :
content =f .read ()
code =extract_code_body (content )
if code .strip ():
section_name =filepath .replace ('/','.').replace ('.py','').upper ()
section_header = f"\n\n# {'=' * 78}\n# {section_name}\n# {'=' * 78}\n\n"
all_code .append (section_header +code )
hf_wrapper ='''
class XoronPreTrainedModel(PreTrainedModel):
"""Base class for Xoron models providing HuggingFace integration."""
config_class = XoronConfig
base_model_prefix = "xoron"
supports_gradient_checkpointing = True
_no_split_modules = ["XoronMultimodalModel"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = 0.02
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class XoronModel(XoronPreTrainedModel):
"""Xoron Multimodal Model for HuggingFace."""
def __init__(self, config: XoronConfig):
super().__init__(config)
self.config = config
self._internal_model = None
self._model_initialized = False
def _ensure_model_initialized(self):
"""Lazily initialize the internal model to avoid meta device conflicts."""
if not self._model_initialized:
self._internal_model = XoronMultimodalModel(self.config)
self._model_initialized = True
@property
def internal_model(self):
self._ensure_model_initialized()
return self._internal_model
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"""
Load pretrained Xoron model from HuggingFace Hub or local path.
This override ensures proper initialization without meta device conflicts.
"""
kwargs.pop('device_map', None)
config = kwargs.pop('config', None)
if config is None:
config = XoronConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
model = cls(config)
model._internal_model = XoronMultimodalModel(config)
model._model_initialized = True
import os
from safetensors import safe_open
if os.path.isdir(pretrained_model_name_or_path):
model_path = pretrained_model_name_or_path
else:
from huggingface_hub import snapshot_download
model_path = snapshot_download(repo_id=pretrained_model_name_or_path)
components_json = os.path.join(model_path, "components.json")
if os.path.exists(components_json):
with open(components_json, 'r') as f:
manifest = json.load(f)
component_map = {
'llm': model._internal_model.llm,
'vision_encoder': model._internal_model.vision_encoder,
'video_encoder': model._internal_model.video_encoder,
'audio_encoder': model._internal_model.audio_encoder,
'audio_decoder': model._internal_model.audio_decoder,
'projector': model._internal_model.projector,
'audio_projector': model._internal_model.audio_projector,
}
if model._internal_model.cross_attention_layers is not None:
component_map['cross_attention'] = model._internal_model.cross_attention_layers
if model._internal_model.generator is not None:
component_map['generator'] = model._internal_model.generator
if model._internal_model.video_generator is not None:
component_map['video_generator'] = model._internal_model.video_generator
if hasattr(model._internal_model, 'waveform_decoder') and model._internal_model.waveform_decoder is not None:
component_map['waveform_decoder'] = model._internal_model.waveform_decoder
for comp_name in manifest.get('components', []):
if comp_name == 'modality_markers':
continue
comp_path = os.path.join(model_path, f"{comp_name}.safetensors")
if os.path.exists(comp_path) and comp_name in component_map:
component = component_map[comp_name]
if component is not None:
with safe_open(comp_path, framework="pt") as f:
state_dict = {k: f.get_tensor(k) for k in f.keys()}
if comp_name == 'llm':
embed_key = 'model.embed_tokens.weight'
lm_head_key = 'lm_head.weight'
if embed_key in state_dict:
saved_vocab_size = state_dict[embed_key].shape[0]
hidden_size = state_dict[embed_key].shape[1]
current_vocab_size = component.model.embed_tokens.weight.shape[0]
if saved_vocab_size != current_vocab_size:
logger.info(f"Resizing embeddings: {current_vocab_size} -> {saved_vocab_size}")
new_embed = nn.Embedding(saved_vocab_size, hidden_size)
new_embed.weight.data = state_dict[embed_key]
component.model.embed_tokens = new_embed
if lm_head_key in state_dict:
new_lm_head = nn.Linear(hidden_size, saved_vocab_size, bias=False)
new_lm_head.weight.data = state_dict[lm_head_key]
component.lm_head = new_lm_head
del state_dict[embed_key]
if lm_head_key in state_dict:
del state_dict[lm_head_key]
component.load_state_dict(state_dict, strict=False)
logger.info(f"Loaded {comp_name}")
markers_path = os.path.join(model_path, "modality_markers.safetensors")
if os.path.exists(markers_path):
with safe_open(markers_path, framework="pt") as f:
model._internal_model.image_start.data = f.get_tensor('image_start')
model._internal_model.image_end.data = f.get_tensor('image_end')
model._internal_model.video_start.data = f.get_tensor('video_start')
model._internal_model.video_end.data = f.get_tensor('video_end')
model._internal_model.audio_start.data = f.get_tensor('audio_start')
model._internal_model.audio_end.data = f.get_tensor('audio_end')
logger.info("Loaded modality markers")
logger.info(f"Xoron model loaded from {pretrained_model_name_or_path}")
return model
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
video_frames: Optional[torch.Tensor] = None,
audio_features: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
self._ensure_model_initialized()
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self._internal_model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
video_frames=video_frames,
audio_features=audio_features,
labels=labels,
)
if return_dict:
return CausalLMOutputWithPast(
loss=outputs.get("loss"),
logits=outputs.get("logits"),
past_key_values=outputs.get("past_key_values"),
hidden_states=outputs.get("hidden_states"),
attentions=outputs.get("attentions"),
)
return (outputs.get("loss"), outputs.get("logits"))
def generate_image(self, prompt_embeds: torch.Tensor, **kwargs):
self._ensure_model_initialized()
return self._internal_model.generate_image(prompt_embeds, **kwargs)
def generate_video(self, prompt_embeds: torch.Tensor, **kwargs):
self._ensure_model_initialized()
return self._internal_model.generate_video(prompt_embeds, **kwargs)
def generate_speech(self, text_embeds: torch.Tensor, **kwargs):
self._ensure_model_initialized()
return self._internal_model.generate_speech(text_embeds, **kwargs)
class XoronForCausalLM(XoronModel):
"""Alias for XoronModel for compatibility."""
pass
XoronConfig.register_for_auto_class()
XoronModel.register_for_auto_class("AutoModel")
XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
'''
all_code .append (hf_wrapper )
final_content ='\n'.join (all_code )
with open (output_path ,'w',encoding ='utf-8')as f :
f .write (final_content )
line_count =final_content .count ('\n')
logger .info (f"Built self-contained modeling_xoron.py ({line_count :,} lines)")
def _save_single_file_safe (self ,path :str ):
"""
Save model as single safetensors file with cloned tensors.
Cloning breaks shared storage that causes safetensors errors.
Args:
path: Directory to save the model
"""
from safetensors .torch import save_file
state_dict =self .state_dict ()
safe_state_dict ={}
for key ,tensor in state_dict .items ():
safe_state_dict [key ]=tensor .clone ().contiguous ()
save_file (safe_state_dict ,os .path .join (path ,"model.safetensors"))
size_mb =sum (t .numel ()*t .element_size ()for t in safe_state_dict .values ())/(1024 *1024 )
logger .info (f"Saved model.safetensors ({size_mb :.1f} MB)")
def _save_components_safe (self ,path :str ):
"""
Save model components as separate .safetensors files with cloned tensors.
This is the default and most robust saving method that:
1. Handles LSTM weight sharing issues in safetensors
2. Allows surgical component loading/updates
3. Better for debugging and inspection
Args:
path: Directory to save component files
"""
from safetensors .torch import save_file
os .makedirs (path ,exist_ok =True )
component_map ={
'llm':self .llm ,
'vision_encoder':self .vision_encoder ,
'video_encoder':self .video_encoder ,
'audio_encoder':self .audio_encoder ,
'audio_decoder':self .audio_decoder ,
'projector':self .projector ,
'audio_projector':self .audio_projector ,
}
if self .cross_attention_layers is not None :
component_map ['cross_attention']=self .cross_attention_layers
if self .generator is not None :
component_map ['generator']=self .generator
if self .video_generator is not None :
component_map ['video_generator']=self .video_generator
if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None :
component_map ['waveform_decoder']=self .waveform_decoder
saved_files =[]
total_size =0
for comp_name ,component in component_map .items ():
if component is None :
continue
comp_state =component .state_dict ()
if not comp_state :
continue
safe_comp_state ={}
for key ,tensor in comp_state .items ():
safe_comp_state [key ]=tensor .clone ().contiguous ()
comp_path =os .path .join (path ,f"{comp_name }.safetensors")
save_file (safe_comp_state ,comp_path )
size_mb =sum (t .numel ()*t .element_size ()for t in safe_comp_state .values ())/(1024 *1024 )
total_size +=size_mb
logger .info (f"Saved {comp_name }: {size_mb :.1f} MB")
saved_files .append (comp_name )
markers ={
'image_start':self .image_start .data .clone ().contiguous (),
'image_end':self .image_end .data .clone ().contiguous (),
'video_start':self .video_start .data .clone ().contiguous (),
'video_end':self .video_end .data .clone ().contiguous (),
'audio_start':self .audio_start .data .clone ().contiguous (),
'audio_end':self .audio_end .data .clone ().contiguous (),
}
save_file (markers ,os .path .join (path ,"modality_markers.safetensors"))
logger .info ("Saved modality_markers")
manifest ={
"components":saved_files +["modality_markers"],
"save_format":"components",
}
with open (os .path .join (path ,"components.json"),"w")as f :
json .dump (manifest ,f ,indent =2 )
weight_map ={}
total_bytes =0
for comp_name ,component in component_map .items ():
if component is None :
continue
comp_state =component .state_dict ()
if not comp_state :
continue
safetensor_file =f"{comp_name }.safetensors"
for key in comp_state .keys ():
full_key =f"{comp_name }.{key }"
weight_map [full_key ]=safetensor_file
total_bytes +=comp_state [key ].numel ()*comp_state [key ].element_size ()
marker_names =['image_start','image_end','video_start','video_end','audio_start','audio_end']
for marker_name in marker_names :
weight_map [marker_name ]="modality_markers.safetensors"
marker_tensor =getattr (self ,marker_name )
total_bytes +=marker_tensor .numel ()*marker_tensor .element_size ()
index ={
"metadata":{
"total_size":total_bytes ,
"format":"components",
},
"weight_map":weight_map ,
}
index_path =os .path .join (path ,"model.safetensors.index.json")
with open (index_path ,"w")as f :
json .dump (index ,f ,indent =2 )
logger .info ("Saved model.safetensors.index.json for HuggingFace compatibility")
logger .info (f"Total size: {total_size :.1f} MB across {len (saved_files )} components")
def _save_sharded (self ,path :str ,max_shard_size :int ):
"""
Save model weights in sharded .safetensors files.
Components are surgically split across shards.
Args:
path: Directory to save shards
max_shard_size: Maximum bytes per shard
"""
from safetensors .torch import save_file
state_dict =self .state_dict ()
component_groups ={
'llm':{},
'vision_encoder':{},
'video_encoder':{},
'audio_encoder':{},
'audio_decoder':{},
'waveform_decoder':{},
'generator':{},
'video_generator':{},
'projector':{},
'audio_projector':{},
'cross_attention_layers':{},
'other':{},
}
for key ,tensor in state_dict .items ():
placed =False
for comp_name in component_groups .keys ():
if comp_name !='other'and key .startswith (comp_name ):
component_groups [comp_name ][key ]=tensor
placed =True
break
if not placed :
component_groups ['other'][key ]=tensor
shards =[]
current_shard ={}
current_size =0
shard_index_map ={}
for comp_name ,comp_tensors in component_groups .items ():
for key ,tensor in comp_tensors .items ():
tensor_size =tensor .numel ()*tensor .element_size ()
if current_size +tensor_size >max_shard_size and current_shard :
shards .append (current_shard )
current_shard ={}
current_size =0
current_shard [key ]=tensor
current_size +=tensor_size
if current_shard :
shards .append (current_shard )
total_shards =len (shards )
weight_map ={}
for i ,shard in enumerate (shards ):
shard_name =f"model-{i +1 :05d}-of-{total_shards :05d}.safetensors"
shard_path =os .path .join (path ,shard_name )
shard_contiguous ={k :v .clone ().contiguous ()for k ,v in shard .items ()}
save_file (shard_contiguous ,shard_path )
for key in shard .keys ():
weight_map [key ]=shard_name
shard_size_mb =sum (t .numel ()*t .element_size ()for t in shard .values ())/(1024 *1024 )
logger .info (f"Saved shard {i +1 }/{total_shards }: {shard_name } ({shard_size_mb :.1f} MB)")
index ={
"metadata":{
"total_size":sum (t .numel ()*t .element_size ()for t in state_dict .values ()),
"total_shards":total_shards ,
},
"weight_map":weight_map ,
}
index_path =os .path .join (path ,"model.safetensors.index.json")
with open (index_path ,"w")as f :
json .dump (index ,f ,indent =2 )
logger .info ("Saved index: model.safetensors.index.json")
def save_components_separately (self ,path :str ):
"""
Save model components as separate .safetensors files.
Useful for surgical component updates and debugging.
NOTE: This method now clones tensors to handle LSTM shared storage issues.
Args:
path: Directory to save component files
"""
from safetensors .torch import save_file
os .makedirs (path ,exist_ok =True )
component_map ={
'llm':self .llm ,
'vision_encoder':self .vision_encoder ,
'video_encoder':self .video_encoder ,
'audio_encoder':self .audio_encoder ,
'audio_decoder':self .audio_decoder ,
'projector':self .projector ,
'audio_projector':self .audio_projector ,
}
if self .cross_attention_layers is not None :
component_map ['cross_attention']=self .cross_attention_layers
if self .generator is not None :
component_map ['generator']=self .generator
if self .video_generator is not None :
component_map ['video_generator']=self .video_generator
if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None :
component_map ['waveform_decoder']=self .waveform_decoder
saved_files =[]
for comp_name ,component in component_map .items ():
if component is None :
continue
comp_state =component .state_dict ()
if not comp_state :
continue
comp_state ={k :v .clone ().contiguous ()for k ,v in comp_state .items ()}
comp_path =os .path .join (path ,f"{comp_name }.safetensors")
save_file (comp_state ,comp_path )
size_mb =sum (t .numel ()*t .element_size ()for t in comp_state .values ())/(1024 *1024 )
logger .info (f"Saved {comp_name }: {size_mb :.1f} MB")
saved_files .append (comp_name )
markers ={
'image_start':self .image_start .data .clone ().contiguous (),
'image_end':self .image_end .data .clone ().contiguous (),
'video_start':self .video_start .data .clone ().contiguous (),
'video_end':self .video_end .data .clone ().contiguous (),
'audio_start':self .audio_start .data .clone ().contiguous (),
'audio_end':self .audio_end .data .clone ().contiguous (),
}
save_file (markers ,os .path .join (path ,"modality_markers.safetensors"))
logger .info ("Saved modality_markers")
manifest ={
"components":saved_files +["modality_markers"],
"config":self .config .to_dict (),
"lora_applied":self .lora_applied ,
}
with open (os .path .join (path ,"components.json"),"w")as f :
json .dump (manifest ,f ,indent =2 )
weight_map ={}
total_bytes =0
for comp_name ,component in component_map .items ():
if component is None :
continue
comp_state =component .state_dict ()
if not comp_state :
continue
safetensor_file =f"{comp_name }.safetensors"
for key in comp_state .keys ():
full_key =f"{comp_name }.{key }"
weight_map [full_key ]=safetensor_file
total_bytes +=comp_state [key ].numel ()*comp_state [key ].element_size ()
marker_names =['image_start','image_end','video_start','video_end','audio_start','audio_end']
for marker_name in marker_names :
weight_map [marker_name ]="modality_markers.safetensors"
marker_tensor =getattr (self ,marker_name )
total_bytes +=marker_tensor .numel ()*marker_tensor .element_size ()
index ={
"metadata":{
"total_size":total_bytes ,
"format":"components",
},
"weight_map":weight_map ,
}
index_path =os .path .join (path ,"model.safetensors.index.json")
with open (index_path ,"w")as f :
json .dump (index ,f ,indent =2 )
logger .info ("Saved model.safetensors.index.json for HuggingFace compatibility")
logger .info (f"Components saved to {path }")
@classmethod
def from_pretrained (
cls ,
path :str ,
device :str =None ,
device_map :Dict [str ,str ]=None ,
apply_lora :bool =True ,
strict :bool =False ,
)->'XoronMultimodalModel':
"""
Load a pretrained Xoron model from a checkpoint or final model directory.
Args:
path: Path to the saved model directory
device: Device to load the model to (if not using device_map)
device_map: Device map for model parallelism
apply_lora: Whether to apply LoRA after loading
strict: If False, allows loading weights even if architecture changed
Returns:
Loaded XoronMultimodalModel instance
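Example (sketch; the checkpoint path and the two-GPU device_map are assumptions):
>>> model = XoronMultimodalModel.from_pretrained(
...     "checkpoints/step_1000",
...     device_map={'primary': 'cuda:0', 'training_gpus': ['cuda:0', 'cuda:1'], 'llm': 'cuda:0'},
... )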
"""
from safetensors import safe_open
logger .info (f"Loading model from {path }...")
config_path =os .path .join (path ,"config.json")
if not os .path .exists (config_path ):
raise FileNotFoundError (f"Config file not found at {config_path }")
with open (config_path ,'r')as f :
config_dict =json .load (f )
lora_was_applied =config_dict .pop ('lora_applied',False )
architecture_version =config_dict .pop ('architecture_version',1 )
has_waveform_decoder =config_dict .pop ('has_waveform_decoder',False )
has_vision_encoder =config_dict .pop ('has_vision_encoder',True )
has_video_encoder =config_dict .pop ('has_video_encoder',True )
has_generator =config_dict .pop ('has_generator',True )
has_video_generator =config_dict .pop ('has_video_generator',True )
has_cross_attention =config_dict .pop ('has_cross_attention',True )
config_dict .pop ('has_audio_encoder',None )
config_dict .pop ('has_audio_decoder',None )
logger .info (f"Saved model architecture (version {architecture_version }):")
logger .info (f" - Waveform Decoder: {'✅'if has_waveform_decoder else '❌ (will init randomly)'}")
logger .info (f" - Vision Encoder: {'✅'if has_vision_encoder else '❌'}")
logger .info (f" - Video Encoder: {'✅'if has_video_encoder else '❌'}")
logger .info (f" - Image Generator: {'✅'if has_generator else '❌'}")
logger .info (f" - Video Generator: {'✅'if has_video_generator else '❌'}")
logger .info (f" - Cross Attention: {'✅'if has_cross_attention else '❌'}")
logger .info (f" - LoRA Applied: {'✅'if lora_was_applied else '❌'}")
config =XoronConfig .from_dict (config_dict )
model =cls (config ,device_map =device_map )
if lora_was_applied:
logger .info ("Checkpoint has LoRA weights. Applying LoRA structure before loading...")
model .apply_lora ()
components_json =os .path .join (path ,"components.json")
model_path =os .path .join (path ,"model.safetensors")
if os .path .exists (components_json ):
logger .info ("Loading from component-based format...")
model ._load_components (path ,strict =strict )
model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights)
elif os .path .exists (model_path ):
logger .info ("Loading weights from safetensors...")
if strict :
load_model (model ,model_path )
else :
checkpoint_state_dict ={}
with safe_open (model_path ,framework ="pt",device ="cpu")as f :
for key in f .keys ():
checkpoint_state_dict [key ]=f .get_tensor (key )
model .load_state_dict (checkpoint_state_dict ,strict =False )
logger .info ("Loaded weights from checkpoint")
model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights)
else :
pytorch_path =os .path .join (path ,"pytorch_model.bin")
if os .path .exists (pytorch_path ):
logger .info ("Loading weights from pytorch_model.bin...")
checkpoint_state_dict =torch .load (pytorch_path ,map_location ='cpu')
model .load_state_dict (checkpoint_state_dict ,strict =False )
logger .info ("Loaded weights from checkpoint")
model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights)
else :
raise FileNotFoundError (f"No model weights found at {path }")
if apply_lora and config .use_lora and not model .lora_applied :
model .apply_lora ()
if device_map is not None :
model .apply_model_parallel (device_map )
elif device is not None :
model =model .to (device )
logger .info ("Model loaded successfully!")
model ._print_stats ()
return model
def _load_components (self ,path :str ,strict :bool =False ):
"""
Load model from component-based safetensors files.
Args:
path: Directory containing component files
strict: If True, require exact match; if False, allow partial loading
"""
from safetensors import safe_open
component_map ={
'llm':self .llm ,
'vision_encoder':self .vision_encoder ,
'video_encoder':self .video_encoder ,
'audio_encoder':self .audio_encoder ,
'audio_decoder':self .audio_decoder ,
'projector':self .projector ,
'audio_projector':self .audio_projector ,
}
if self .cross_attention_layers is not None :
component_map ['cross_attention']=self .cross_attention_layers
if self .generator is not None :
component_map ['generator']=self .generator
if self .video_generator is not None :
component_map ['video_generator']=self .video_generator
if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None :
component_map ['waveform_decoder']=self .waveform_decoder
for comp_name ,component in component_map .items ():
if component is None :
continue
comp_path =os .path .join (path ,f"{comp_name }.safetensors")
if not os .path .exists (comp_path ):
continue
try :
checkpoint_state ={}
with safe_open (comp_path ,framework ="pt",device ="cpu")as f :
for key in f .keys ():
checkpoint_state [key ]=f .get_tensor (key )
component .load_state_dict (checkpoint_state ,strict =strict )
size_mb =sum (t .numel ()*t .element_size ()for t in checkpoint_state .values ())/(1024 *1024 )
logger .info (f"Loaded {comp_name } ({size_mb :.1f} MB)")
except Exception as e :
logger .warning (f"Error loading {comp_name }: {e }")
markers_path =os .path .join (path ,"modality_markers.safetensors")
if os .path .exists (markers_path ):
try :
with safe_open (markers_path ,framework ="pt",device ="cpu")as f :
self .image_start .data =f .get_tensor ('image_start')
self .image_end .data =f .get_tensor ('image_end')
self .video_start .data =f .get_tensor ('video_start')
self .video_end .data =f .get_tensor ('video_end')
self .audio_start .data =f .get_tensor ('audio_start')
self .audio_end .data =f .get_tensor ('audio_end')
logger .info ("Loaded modality_markers")
except Exception as e :
logger .warning (f"Error loading modality_markers: {e }")
logger .info ("Components loaded successfully")
@staticmethod
def load_training_state (path :str )->Optional [Dict ]:
"""
Load training state from a checkpoint.
Args:
path: Path to the checkpoint directory
Returns:
Dictionary with training state or None if not found
"""
state_path =os .path .join (path ,"training_state.pt")
if os .path .exists (state_path ):
logger .info (f"Loading training state from {state_path }...")
return torch .load (state_path ,map_location ='cpu')
return None
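# Resume sketch (assumes save_pretrained() was called with an optimizer and scheduler):
#   state = XoronMultimodalModel.load_training_state("checkpoints/step_1000")
#   if state is not None:
#       optimizer.load_state_dict(state['optimizer_state_dict'])
#       scheduler.load_state_dict(state['scheduler_state_dict'])
#       start_step = state['global_step']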
def freeze_components (self ,components :List [str ],hard_freeze :bool =True ):
"""
Freeze specific components of the model.
IMPORTANT RULES:
1. LLM is NEVER frozen - it's trained from scratch and always needs full weight training
2. LoRA parameters are usually kept trainable, UNLESS hard_freeze=True
Args:
components: List of component group names to freeze.
Valid groups: 'vision', 'video', 'audio',
'cross_attention', 'image_generation', 'video_generation',
'modality_markers'
NOTE: 'llm' is NOT a valid group to freeze - will be ignored!
hard_freeze: If True, completely freezes the component including its LoRA adapters.
This prevents inactive components from updating via weight decay/momentum.
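Example (sketch: keep audio trainable while image/video paths stay frozen):
>>> model.freeze_components(['vision', 'video', 'image_generation', 'video_generation'])
>>> model.unfreeze_components(['audio'])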
"""
if 'llm'in components :
logger .warning ("Ignoring 'llm' in freeze list - LLM must always train (from scratch)")
components =[c for c in components if c !='llm']
logger .info (f"Freezing components: {components } (hard_freeze={hard_freeze })")
for group_name in components :
if group_name not in COMPONENT_GROUPS :
logger .warning (f" ⚠️ Unknown component group: {group_name }")
continue
for attr_name in COMPONENT_GROUPS [group_name ]:
if hasattr (self ,attr_name ):
component =getattr (self ,attr_name )
if component is not None :
if isinstance (component ,nn .Parameter ):
component .requires_grad =False
elif isinstance (component ,nn .Module ):
for name ,param in component .named_parameters ():
path_lora ='lora_A'in name or 'lora_B'in name or 'magnitude'in name
if hard_freeze or not path_lora :
param .requires_grad =False
logger .info (f"Frozen: {attr_name }")
if self .lora_applied and not hard_freeze:
enable_lora_training (self )
logger .info ("LoRA parameters remain trainable")
self ._print_stats ()
def unfreeze_components (self ,components :List [str ]):
"""
Unfreeze specific components of the model.
Args:
components: List of component group names to unfreeze.
"""
logger .info (f"Unfreezing components: {components }")
for group_name in components :
if group_name not in COMPONENT_GROUPS :
logger .warning (f" ⚠️ Unknown component group: {group_name }")
continue
for attr_name in COMPONENT_GROUPS [group_name ]:
if hasattr (self ,attr_name ):
component =getattr (self ,attr_name )
if component is not None :
if isinstance (component ,nn .Parameter ):
component .requires_grad =True
elif isinstance (component ,nn .Module ):
for param in component .parameters ():
param .requires_grad =True
logger .info (f"Unfrozen: {attr_name }")
self ._print_stats ()
def freeze_all_except (self ,components :List [str ],hard_freeze :bool =True ):
"""
Freeze all components except the specified ones.
NOTE: LLM is always kept trainable regardless of input - it's trained from scratch.
Args:
components: List of component group names to keep trainable.
"""
if 'llm'not in components :
components =components +['llm']
all_groups =list (COMPONENT_GROUPS .keys ())
groups_to_freeze =[g for g in all_groups if g not in components ]
self .freeze_components (groups_to_freeze ,hard_freeze =hard_freeze )
def get_trainable_component_names (self )->List [str ]:
"""Get list of component groups that have trainable parameters."""
trainable =[]
for group_name ,attr_names in COMPONENT_GROUPS .items ():
for attr_name in attr_names :
if hasattr (self ,attr_name ):
component =getattr (self ,attr_name )
if component is not None :
if isinstance (component ,nn .Parameter ):
if component .requires_grad :
trainable .append (group_name )
break
elif isinstance (component ,nn .Module ):
if any (p .requires_grad for p in component .parameters ()):
trainable .append (group_name )
break
return trainable
def get_frozen_component_names (self )->List [str ]:
"""Get list of component groups that are frozen (no trainable parameters)."""
        frozen = []
        for group_name, attr_names in COMPONENT_GROUPS.items():
            has_component = False
            is_trainable = False
            for attr_name in attr_names:
                if hasattr(self, attr_name):
                    component = getattr(self, attr_name)
                    if component is not None:
                        has_component = True
                        if isinstance(component, nn.Parameter):
                            if component.requires_grad:
                                is_trainable = True
                                break
                        elif isinstance(component, nn.Module):
                            if any(p.requires_grad for p in component.parameters()):
                                is_trainable = True
                                break
            if has_component and not is_trainable:
                frozen.append(group_name)
        return frozen
    def get_component_status(self) -> tuple:
"""
Get tuple of (trainable_components, frozen_components) for display.
Returns:
tuple: (list of trainable component names, list of frozen component names)
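        Example (illustrative; actual group names depend on the configuration):
            trainable, frozen = model.get_component_status()
            # e.g. trainable == ['vision', 'cross_attention'], frozen == ['audio', 'video_generation']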
"""
        trainable = self.get_trainable_component_names()
        frozen = self.get_frozen_component_names()
        return trainable, frozen
class XoronPreTrainedModel(PreTrainedModel):
"""Base class for Xoron models providing HuggingFace integration."""
config_class = XoronConfig
base_model_prefix = "xoron"
supports_gradient_checkpointing = True
_no_split_modules = ["XoronMultimodalModel"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = 0.02
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class XoronModel(XoronPreTrainedModel):
"""Xoron Multimodal Model for HuggingFace."""
def __init__(self, config: XoronConfig):
super().__init__(config)
self.config = config
self._internal_model = None
self._model_initialized = False
def _ensure_model_initialized(self):
"""Lazily initialize the internal model to avoid meta device conflicts."""
if not self._model_initialized:
self._internal_model = XoronMultimodalModel(self.config)
self._model_initialized = True
@property
def internal_model(self):
self._ensure_model_initialized()
return self._internal_model
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"""
Load pretrained Xoron model from HuggingFace Hub or local path.
This override ensures proper initialization without meta device conflicts.
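        Example (illustrative; the repo id is a placeholder):
            model = XoronModel.from_pretrained("your-repo/xoron-model")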
"""
kwargs.pop('device_map', None)
config = kwargs.pop('config', None)
if config is None:
config = XoronConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
model = cls(config)
model._internal_model = XoronMultimodalModel(config)
model._model_initialized = True
import os
from safetensors import safe_open
if os.path.isdir(pretrained_model_name_or_path):
model_path = pretrained_model_name_or_path
else:
from huggingface_hub import snapshot_download
model_path = snapshot_download(repo_id=pretrained_model_name_or_path)
components_json = os.path.join(model_path, "components.json")
if os.path.exists(components_json):
with open(components_json, 'r') as f:
manifest = json.load(f)
component_map = {
'llm': model._internal_model.llm,
'vision_encoder': model._internal_model.vision_encoder,
'video_encoder': model._internal_model.video_encoder,
'audio_encoder': model._internal_model.audio_encoder,
'audio_decoder': model._internal_model.audio_decoder,
'projector': model._internal_model.projector,
'audio_projector': model._internal_model.audio_projector,
}
if model._internal_model.cross_attention_layers is not None:
component_map['cross_attention'] = model._internal_model.cross_attention_layers
if model._internal_model.generator is not None:
component_map['generator'] = model._internal_model.generator
if model._internal_model.video_generator is not None:
component_map['video_generator'] = model._internal_model.video_generator
if hasattr(model._internal_model, 'waveform_decoder') and model._internal_model.waveform_decoder is not None:
component_map['waveform_decoder'] = model._internal_model.waveform_decoder
for comp_name in manifest.get('components', []):
if comp_name == 'modality_markers':
continue
comp_path = os.path.join(model_path, f"{comp_name}.safetensors")
if os.path.exists(comp_path) and comp_name in component_map:
component = component_map[comp_name]
if component is not None:
with safe_open(comp_path, framework="pt") as f:
state_dict = {k: f.get_tensor(k) for k in f.keys()}
if comp_name == 'llm':
embed_key = 'model.embed_tokens.weight'
lm_head_key = 'lm_head.weight'
if embed_key in state_dict:
saved_vocab_size = state_dict[embed_key].shape[0]
hidden_size = state_dict[embed_key].shape[1]
current_vocab_size = component.model.embed_tokens.weight.shape[0]
if saved_vocab_size != current_vocab_size:
logger.info(f"Resizing embeddings: {current_vocab_size} -> {saved_vocab_size}")
new_embed = nn.Embedding(saved_vocab_size, hidden_size)
new_embed.weight.data = state_dict[embed_key]
component.model.embed_tokens = new_embed
if lm_head_key in state_dict:
new_lm_head = nn.Linear(hidden_size, saved_vocab_size, bias=False)
new_lm_head.weight.data = state_dict[lm_head_key]
component.lm_head = new_lm_head
del state_dict[embed_key]
if lm_head_key in state_dict:
del state_dict[lm_head_key]
component.load_state_dict(state_dict, strict=False)
logger.info(f"Loaded {comp_name}")
markers_path = os.path.join(model_path, "modality_markers.safetensors")
if os.path.exists(markers_path):
with safe_open(markers_path, framework="pt") as f:
model._internal_model.image_start.data = f.get_tensor('image_start')
model._internal_model.image_end.data = f.get_tensor('image_end')
model._internal_model.video_start.data = f.get_tensor('video_start')
model._internal_model.video_end.data = f.get_tensor('video_end')
model._internal_model.audio_start.data = f.get_tensor('audio_start')
model._internal_model.audio_end.data = f.get_tensor('audio_end')
logger.info("Loaded modality markers")
logger.info(f"Xoron model loaded from {pretrained_model_name_or_path}")
return model
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
video_frames: Optional[torch.Tensor] = None,
audio_features: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
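        """
        HuggingFace-style forward pass. Multimodal inputs (``pixel_values``,
        ``video_frames``, ``audio_features``) are passed to the internal model
        as ``images``/``video``/``audio``; remaining ``kwargs`` are ignored.
        """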
self._ensure_model_initialized()
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self._internal_model(
input_ids=input_ids,
attention_mask=attention_mask,
images=pixel_values,
video=video_frames,
audio=audio_features,
labels=labels,
)
if return_dict:
return CausalLMOutputWithPast(
loss=outputs.get("loss"),
logits=outputs.get("logits"),
past_key_values=outputs.get("past_key_values"),
hidden_states=outputs.get("hidden_states"),
attentions=outputs.get("attentions"),
)
return (outputs.get("loss"), outputs.get("logits"))
def generate_image(self, prompt_embeds: torch.Tensor, **kwargs):
self._ensure_model_initialized()
return self._internal_model.generate_image(prompt_embeds, **kwargs)
def generate_video(self, prompt_embeds: torch.Tensor, **kwargs):
self._ensure_model_initialized()
return self._internal_model.generate_video(prompt_embeds, **kwargs)
def generate_speech(self, text_embeds: torch.Tensor, **kwargs):
self._ensure_model_initialized()
return self._internal_model.generate_speech(text_embeds, **kwargs)
class XoronForCausalLM(XoronModel):
"""Alias for XoronModel for compatibility."""
pass
XoronConfig.register_for_auto_class()
XoronModel.register_for_auto_class("AutoModel")
XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
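# Illustrative usage (assumes this file ships alongside configuration_xoron.py in the model repo):
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("your-repo/xoron-model", trust_remote_code=True)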