# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention"""
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import initialization as init
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import is_flash_attention_requested
from transformers.utils.output_capturing import OutputRecorder
from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionConfig
logger = logging.get_logger(__name__)
@dataclass
class MossVLModelOutputWithPast(ModelOutput):
"""
Output class for MossVL model with additional vision_token_info and rope_deltas fields.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`Cache`, *optional*):
Contains pre-computed hidden-states (key and values in the self-attention blocks and
cross-attention blocks) that can be used to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer).
attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for each layer) of attention weights.
vision_token_info (`List[dict]`, *optional*):
Information about vision tokens for each sample, used to correctly expand cross-attention masks.
This is cached during prefill and reused during decode to handle ViT padding correctly.
rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Position offset due to vision tokens. Used for fast position computation in decode stage.
rope_deltas = max_position - sequence_length
"""
last_hidden_state: Optional[torch.FloatTensor] = None
past_key_values: Optional[Cache] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
vision_token_info: Optional[List[dict]] = None
rope_deltas: Optional[torch.LongTensor] = None
@dataclass
class MossVLCausalLMOutputWithPast(ModelOutput):
"""
Output class for MossVL causal language model with additional vision_token_info and rope_deltas fields.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head.
past_key_values (`Cache`, *optional*):
            Contains pre-computed hidden states that can be used to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of hidden-states at each layer.
attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of attention weights.
vision_token_info (`List[dict]`, *optional*):
Information about vision tokens for each sample, cached for decode stage.
rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Position offset due to vision tokens. Used for fast position computation in decode stage.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Cache] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
vision_token_info: Optional[List[dict]] = None
rope_deltas: Optional[torch.LongTensor] = None
# ==================== Vision Components (from Qwen3VL) ====================
class MossVLVisionMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
class MossVLVisionPatchEmbed(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.patch_size = config.patch_size
self.temporal_patch_size = config.temporal_patch_size
self.in_channels = config.in_channels
self.embed_dim = config.hidden_size
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
target_dtype = self.proj.weight.dtype
hidden_states = hidden_states.view(
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
)
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
return hidden_states
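# Illustrative shape sketch (assumed values, not part of the model): with patch_size=16,
# temporal_patch_size=2 and in_channels=3, an image covering a 2x2 patch grid arrives packed as
#
#   pixel_values = torch.randn(2 * 2, 3 * 2 * 16 * 16)   # (num_patches, C * T_p * P * P)
#   tokens = patch_embed(pixel_values)                    # -> (4, config.hidden_size)
#
# where `patch_embed` is a MossVLVisionPatchEmbed built from the vision config.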
class MossVLVisionRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
# Keep dim / theta so that `_init_weights` can rebuild `inv_freq` after
# from_pretrained materializes the module (it is a non-persistent buffer
# and therefore never populated by the checkpoint).
self.dim = dim
self.theta = theta
inv_freq = self.compute_inv_freq()
self.register_buffer("inv_freq", inv_freq, persistent=False)
def compute_inv_freq(self) -> torch.Tensor:
return 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))
def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class MossVLVisionPatchMerger(nn.Module):
def __init__(self, config: MossVLVisionConfig, num_deepstack_features=0) -> None:
super().__init__()
        # Spatial merge: the hidden dimension grows by a factor of config.spatial_merge_size**2
base_hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        # Input dimension: spatially merged dimension * (1 + number of deepstack features)
self.input_hidden_size = base_hidden_size * (1 + num_deepstack_features)
# Use independent LayerNorms for each feature level
# Total features = 1 (last layer) + num_deepstack_features
num_features = 1 + num_deepstack_features
self.norms = nn.ModuleList([
nn.LayerNorm(config.hidden_size, eps=1e-6)
for _ in range(num_features)
])
self.hidden_size = config.hidden_size
self.linear_fc1 = nn.Linear(self.input_hidden_size, self.input_hidden_size)
self.act_fn = nn.GELU()
self.linear_fc2 = nn.Linear(self.input_hidden_size, config.out_hidden_size)
def forward(
self,
last_hidden_state: torch.Tensor,
deepstack_features: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
# 1. Collect all features: [last_hidden_state, deepstack_1, deepstack_2, ...]
# self.norms[0] corresponds to last_hidden_state
# self.norms[1:] corresponds to deepstack_features
if deepstack_features is None:
deepstack_features = []
all_inputs = [last_hidden_state] + deepstack_features
# 2. Apply Norm independently
outs = []
for i, feat in enumerate(all_inputs):
outs.append(self.norms[i](feat))
# 3. Concat once
x = torch.cat(outs, dim=-1)
        # Merge: hidden dim grows by config.spatial_merge_size**2 while sequence length shrinks by the same factor
x = x.view(-1, self.input_hidden_size)
x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
return x
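# Dimension sketch (assumed example values): with hidden_size=1152, spatial_merge_size=2 and
# 3 deepstack levels, the per-token concat of normed features has 1152 * (1 + 3) = 4608
# channels; viewing 2**2 = 4 spatially adjacent tokens together gives the fc1 input width of
# 1152 * 4 * 4 = 18432, and fc2 maps each merged token to config.out_hidden_size, so the
# sequence length shrinks by spatial_merge_size**2.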
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
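# Minimal usage sketch (illustrative; q/k are packed as (seq_len, num_heads, head_dim), and the
# frequency construction below is a plain 1-D stand-in for the 2-D table built in
# MossVLVisionModel.rot_pos_emb):
#
#   q, k = torch.randn(8, 4, 64), torch.randn(8, 4, 64)
#   freqs = torch.outer(torch.arange(8.0), 1.0 / 10000.0 ** (torch.arange(0.0, 32.0) / 32.0))  # (8, 32)
#   emb = torch.cat((freqs, freqs), dim=-1)                                                     # (8, 64)
#   q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, emb.cos(), emb.sin())                      # shapes preserved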
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
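# Shape sketch (illustrative): grouped-query attention with 2 key/value heads shared by 6 query heads.
#
#   kv = torch.randn(1, 2, 10, 64)   # (batch, num_key_value_heads, seq_len, head_dim)
#   repeat_kv(kv, 3).shape            # -> torch.Size([1, 6, 10, 64])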
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class MossVLVisionAttention(nn.Module):
def __init__(self, config: MossVLVisionConfig) -> None:
super().__init__()
self.dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
self.num_key_value_groups = 1
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
self.proj = nn.Linear(self.dim, self.dim)
self.scaling = self.head_dim**-0.5
self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)
if is_flash_attention_requested(self.config):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
cu_seq_lens_q=cu_seqlens,
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
else:
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]
attn_outputs = [
attention_interface(
self,
q,
k,
v,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
is_causal=False,
**kwargs,
)[0]
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
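# cu_seqlens sketch (illustrative): for two images of 16 and 64 patches the vision tower passes
# boundaries [0, 16, 80]. Flash attention consumes them directly as varlen offsets, whereas the
# SDPA / eager branch above splits q, k, v at those boundaries so each image attends only to its
# own patches.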
class MossVLVisionBlock(GradientCheckpointingLayer):
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
super().__init__()
self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
self.attn = MossVLVisionAttention(config=config)
self.mlp = MossVLVisionMLP(config=config)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
# ==================== Text Components (from Qwen3 + Cross Attention) ====================
class MossVLTextRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor
def __init__(self, config: MossVLTextConfig, device=None):
super().__init__()
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
rope_parameters = getattr(config, "rope_parameters", None)
if rope_parameters is None:
rope_parameters = getattr(config, "rope_scaling", None) or {"rope_type": "default"}
self.rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default"))
rope_init_fn: Callable = self.compute_default_rope_parameters
if self.rope_type != "default":
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
self.mrope_section = rope_parameters.get("mrope_section", [24, 20, 20])
@staticmethod
def compute_default_rope_parameters(
config: Optional[MossVLTextConfig] = None,
device: Optional[torch.device] = None,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, float]:
rope_parameters = getattr(config, "rope_parameters", None) or {}
base = rope_parameters.get("rope_theta", getattr(config, "rope_theta", 10000.0))
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
partial_rotary_factor = rope_parameters.get(
"partial_rotary_factor", getattr(config, "partial_rotary_factor", 1.0)
)
dim = int(head_dim * partial_rotary_factor)
attention_factor = 1.0
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, attention_factor
def apply_interleaved_mrope(self, freqs, mrope_section):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHW...TT], preserving frequency continuity.
        Args:
            freqs: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        Returns:
            freqs_t: (bs, seq_len, head_dim // 2)
"""
freqs_t = freqs[0] # just overwrite the first dimension T
for dim, offset in enumerate((1, 2), start=1): # H, W
length = mrope_section[dim] * 3
idx = slice(offset, length, 3)
freqs_t[..., idx] = freqs[dim, ..., idx]
return freqs_t
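    # Layout sketch (assumed values): with mrope_section = [24, 20, 20] and head_dim // 2 = 64,
    # the H component fills frequency indices 1, 4, ..., 58 (20 slots), the W component fills
    # 2, 5, ..., 59 (20 slots), and the T component keeps the remaining 24 indices
    # (0, 3, ..., 57 and 60, 61, 62, 63), leaving the lowest-frequency tail to the temporal axis.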
@torch.no_grad()
@dynamic_rope_update
def forward(self, x, position_ids):
if position_ids.ndim == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(2, 3)
freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@use_kernel_forward_from_hub("RMSNorm")
class MossVLTextRMSNorm(nn.Module):
def __init__(self, hidden_size, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# self attention rotary position embedding
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# cross attention rotary position embedding
def apply_rotary_pos_emb_cross_attention(states, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
states_embed = (states * cos) + (rotate_half(states) * sin)
return states_embed
class MossVLTextSelfAttention(nn.Module):
"""Self attention for text decoder"""
def __init__(self, config: MossVLTextConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.q_norm = MossVLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = MossVLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class MossVLTextCrossAttention(nn.Module):
"""Cross attention - for vision-text interaction"""
def __init__(self, config: MossVLTextConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = False
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
self.q_norm = MossVLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = MossVLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
query_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
batch_size, seq_length, _ = hidden_states.size()
# Query from text hidden states
query_states = self.q_proj(hidden_states)
query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
query_states = self.q_norm(query_states)
if cross_attention_states is not None:
# Key and Value from vision cross_attention_states
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
key_states = self.k_norm(key_states)
value_states = value_states.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Apply different RoPE for query (text position) and key (vision position)
if query_position_embeddings is not None:
cos, sin = query_position_embeddings
query_states = apply_rotary_pos_emb_cross_attention(query_states, cos, sin)
if vision_position_embeddings is not None:
vision_cos, vision_sin = vision_position_embeddings
key_states = apply_rotary_pos_emb_cross_attention(key_states, vision_cos, vision_sin)
if past_key_values is not None:
# if we have a new image + new tokens, we only computed key_states on that new image
# we still update the cross key states, past_image, new_image. And use it!
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_values.layers[self.layer_idx].keys,
past_key_values.layers[self.layer_idx].values,
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
if is_flash_attention_requested(self.config):
# Cross attention still relies on an explicit dense mask.
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
else:
attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class MossVLTextMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MossVLSelfAttentionDecoderLayer(GradientCheckpointingLayer):
"""Self-attention decoder layer"""
def __init__(self, config: MossVLTextConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.self_attn = MossVLTextSelfAttention(config=config, layer_idx=layer_idx)
self.mlp = MossVLTextMLP(config)
self.input_layernorm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
vision_position_ids: Optional[torch.LongTensor] = None,
vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: bool = False,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, ...]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer):
"""Cross-attention decoder layer with tanh-gated attention and MLP"""
def __init__(self, config: MossVLTextConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.cross_attn = MossVLTextCrossAttention(config=config, layer_idx=layer_idx)
self.mlp = MossVLTextMLP(config)
self.input_layernorm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Gates for cross attention (single scalar value).
# Gate scalar = tanh(gate[0]), initialized to zero so tanh(0)=0 at start.
self.cross_attn_attn_gate = nn.Parameter(torch.zeros(1))
self.cross_attn_mlp_gate = nn.Parameter(torch.zeros(1))
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
vision_position_ids: Optional[torch.LongTensor] = None,
vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: bool = False,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, ...]:
# Cross Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, attn_weights = self.cross_attn(
hidden_states=hidden_states,
cross_attention_states=cross_attention_states,
attention_mask=cross_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
query_position_embeddings=position_embeddings,
vision_position_embeddings=vision_position_embeddings,
)
if full_text_row_masked_out_mask is not None:
hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if full_text_row_masked_out_mask is not None:
hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
@auto_docstring
class MossVLPreTrainedModel(PreTrainedModel):
config: MossVLConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MossVLSelfAttentionDecoderLayer", "MossVLCrossAttentionDecoderLayer", "MossVLVisionBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_supports_sdpa = True
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": [MossVLSelfAttentionDecoderLayer, MossVLCrossAttentionDecoderLayer],
"attentions": [
OutputRecorder(MossVLTextSelfAttention, index=1, layer_name="self_attn"), # self-attention layers
OutputRecorder(MossVLTextCrossAttention, index=1, layer_name="cross_attn"), # cross-attention layers
],
}
def _init_weights(self, module):
"""Initialize the weights.
"""
super()._init_weights(module)
if isinstance(module, MossVLVisionRotaryEmbedding):
init.copy_(module.inv_freq, module.compute_inv_freq())
class MossVLVisionModel(MossVLPreTrainedModel):
config: MossVLVisionConfig
_no_split_modules = ["MossVLVisionBlock"]
def __init__(self, config, *inputs, **kwargs) -> None:
super().__init__(config, *inputs, **kwargs)
self.spatial_merge_size = config.spatial_merge_size
self.patch_size = config.patch_size
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
self.patch_embed = MossVLVisionPatchEmbed(config=config)
self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
self.num_grid_per_side = int(config.num_position_embeddings**0.5)
head_dim = config.hidden_size // config.num_heads
self.rotary_pos_emb = MossVLVisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList([MossVLVisionBlock(config) for _ in range(config.depth)])
        # DeepStack: layer indices whose intermediate features are collected
self.deepstack_visual_indexes = config.deepstack_visual_indexes
num_deepstack_features = len(self.deepstack_visual_indexes)
        # Merger: input dimension = hidden_size * (1 + num_deepstack_features)
self.merger = MossVLVisionPatchMerger(
config=config,
num_deepstack_features=num_deepstack_features
)
self.gradient_checkpointing = False
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
merge_size = self.spatial_merge_size
max_hw = int(grid_thw[:, 1:].max().item())
freq_table = self.rotary_pos_emb(max_hw)
device = freq_table.device
total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
offset = 0
for num_frames, height, width in grid_thw:
merged_h, merged_w = height // merge_size, width // merge_size
block_rows = torch.arange(merged_h, device=device)
block_cols = torch.arange(merged_w, device=device)
intra_row = torch.arange(merge_size, device=device)
intra_col = torch.arange(merge_size, device=device)
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
coords = torch.stack((row_idx, col_idx), dim=-1)
if num_frames > 1:
coords = coords.repeat(num_frames, 1)
num_tokens = coords.shape[0]
pos_ids[offset : offset + num_tokens] = coords
offset += num_tokens
embeddings = freq_table[pos_ids]
embeddings = embeddings.flatten(1)
return embeddings
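    # Position sketch (illustrative): for grid_thw = [[1, 4, 4]] and spatial_merge_size = 2 the
    # (row, col) ids are emitted merge-block by merge-block: (0,0), (0,1), (1,0), (1,1) for the
    # top-left 2x2 block, then (0,2), (0,3), (1,2), (1,3) for the block to its right, and so on.
    # Each pair indexes the frequency table, and the two halves are flattened into one rotary
    # embedding per patch.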
def fast_pos_embed_interpolate(self, grid_thw):
grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
device = self.pos_embed.weight.device
dtype = self.pos_embed.weight.dtype
idx_parts = [[] for _ in range(4)]
weight_parts = [[] for _ in range(4)]
for t, h, w in zip(grid_ts, grid_hs, grid_ws):
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
h_idxs_floor = h_idxs.int()
w_idxs_floor = w_idxs.int()
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor
base_h = h_idxs_floor * self.num_grid_per_side
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
indices = [
(base_h[None].T + w_idxs_floor[None]).flatten(),
(base_h[None].T + w_idxs_ceil[None]).flatten(),
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
]
weights = [
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
((1 - dh)[None].T * dw[None]).flatten(),
(dh[None].T * (1 - dw)[None]).flatten(),
(dh[None].T * dw[None]).flatten(),
]
for i in range(4):
idx_parts[i].append(indices[i])
weight_parts[i].append(weights[i])
idx_tensor = torch.stack([torch.cat(parts) for parts in idx_parts]).to(dtype=torch.long)
weight_tensor = torch.stack([torch.cat(parts) for parts in weight_parts]).to(dtype=dtype)
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
patch_pos_embeds_permute = []
merge_size = self.config.spatial_merge_size
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
pos_embed = pos_embed.repeat(t, 1)
pos_embed = (
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)
patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
return patch_pos_embeds
def forward(
self,
hidden_states: torch.Tensor,
grid_thw: torch.Tensor,
**kwargs
) -> torch.Tensor:
"""
Args:
hidden_states: input tensor
grid_thw: [num_images, 3] tensor with (t, h, w) for each image
Returns:
hidden_states: [num_tokens, out_hidden_size] - packed hidden states
"""
hidden_states = self.patch_embed(hidden_states)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
seq_len, _ = hidden_states.size()
hidden_states = hidden_states.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        # DeepStack: collect visual features from the selected layers
deepstack_features = []
for layer_idx, blk in enumerate(self.blocks):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
**kwargs,
)
            # Save the features if the current layer is one of the deepstack indexes
if layer_idx in self.deepstack_visual_indexes:
deepstack_features.append(hidden_states)
        # Merger: map from hidden_size * (1 + num_deepstack) to out_hidden_size
hidden_states = self.merger(hidden_states, deepstack_features)
return hidden_states
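# Minimal usage sketch (assumption: `vision_config` comes from MossVLConfig and real pixel values
# are produced by the processor; the names and random inputs here are illustrative only):
#
#   visual = MossVLVisionModel._from_config(vision_config)
#   grid_thw = torch.tensor([[1, 32, 32]])            # one image covering a 32x32 patch grid
#   pixel_values = torch.randn(
#       32 * 32,
#       vision_config.in_channels * vision_config.temporal_patch_size * vision_config.patch_size**2,
#   )
#   feats = visual(pixel_values, grid_thw)             # -> (32 * 32 // spatial_merge_size**2, out_hidden_size)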
@auto_docstring(
custom_intro="""
The MossVL Text Model with self-attention and cross-attention layers for vision-language interaction.
"""
)
class MossVLTextModel(MossVLPreTrainedModel):
config: MossVLTextConfig
_no_split_modules = ["MossVLSelfAttentionDecoderLayer", "MossVLCrossAttentionDecoderLayer"]
def __init__(self, config: MossVLTextConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
# Store cross_attention_layers for use in forward pass
self.cross_attention_layers = config.cross_attention_layers
# Create layers: self-attention or cross-attention at specified indices
self.layers = nn.ModuleList()
for layer_idx in range(config.num_hidden_layers):
if layer_idx in config.cross_attention_layers:
# Cross attention layer
self.layers.append(
MossVLCrossAttentionDecoderLayer(config, layer_idx)
)
else:
# Self attention layer
self.layers.append(
MossVLSelfAttentionDecoderLayer(config, layer_idx)
)
self.norm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = MossVLTextRotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.post_init()
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
vision_position_ids: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
"""
Args:
full_text_row_masked_out_mask (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
Mask for full text rows that should be masked out in attention computation.
cross_attention_states (`torch.Tensor`, *optional*):
Vision features to be used in cross-attention layers. Shape: `(batch_size, vision_seq_len, hidden_size)`.
cross_attention_mask (`torch.Tensor`, *optional*):
Attention mask for cross-attention between text and vision. Shape: `(batch_size, 1, text_seq_len, vision_seq_len)`.
vision_position_ids (`torch.LongTensor`, *optional*):
Position IDs for vision tokens used in cross-attention. Shape: `(batch_size, vision_seq_len)`.
cache_position (`torch.LongTensor`, *optional*):
Absolute cache positions for the current text tokens during incremental decoding.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache(config=self.config)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
attention_mask = create_causal_mask(
config=self.config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
# Compute text position embeddings (for self-attention and cross-attention query)
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# Compute vision position embeddings (for cross-attention key/value) if needed
vision_position_embeddings = None
if cross_attention_states is not None:
if vision_position_ids is not None:
vision_position_embeddings = self.rotary_emb(cross_attention_states, vision_position_ids)
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
if output_hidden_states:
all_hidden_states += (hidden_states,)
for idx, decoder_layer in enumerate(self.layers):
# For text-only path we should skip cross attention layers.
# Let's check if the layer is cross attention layer and if we have cross attention states
# or cached cross attention states.
is_cross_attention_layer = idx in self.cross_attention_layers
is_cross_attention_cache_empty = past_key_values is None or (
past_key_values is not None and past_key_values.get_seq_length(idx) == 0
)
if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty:
continue
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
vision_position_ids=vision_position_ids,
vision_position_embeddings=vision_position_embeddings,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states[:-1] + (hidden_states,)
if not return_dict:
outputs = (hidden_states, past_key_values)
if output_hidden_states:
outputs += (all_hidden_states,)
if output_attentions:
outputs += (all_attentions,)
return outputs
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
@auto_docstring(
custom_intro="""
The MossVL model which consists of a vision encoder (from Qwen3VL) and a language model with cross-attention layers.
"""
)
class MossVLModel(MossVLPreTrainedModel):
base_model_prefix = ""
config: MossVLConfig
_no_split_modules = ["MossVLSelfAttentionDecoderLayer", "MossVLCrossAttentionDecoderLayer", "MossVLVisionBlock"]
_checkpoint_conversion_mapping = {}
accepts_loss_kwargs = False
def __init__(self, config):
super().__init__(config)
self.visual = MossVLVisionModel._from_config(config.vision_config)
self.language_model = MossVLTextModel._from_config(config.text_config)
# Learnable Separator Token: inserted after each image/frame's vision tokens
# Initialized from LLM's separator_token_init_id embedding
self.separator_token = nn.Parameter(
torch.zeros(config.vision_config.out_hidden_size)
)
self.post_init()
def convert_packed_to_batch(
self,
hidden_states: torch.Tensor,
grid_thw: torch.Tensor,
media_nums_per_sample: Optional[List[int]],
) -> Tuple[torch.Tensor, List[dict]]:
"""
Convert packed vision tokens to batched format with separator tokens.
For each image: inserts 1 separator token after the vision tokens.
For each video: inserts 1 separator token after EACH frame's vision tokens.
Note: media_nums_per_sample counts each video as 1 media item,
but each frame in a video gets its own separator token.
"""
# Calculate number of tokens per media after spatial merge
tokens_per_media = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]) // (self.visual.spatial_merge_size ** 2)
hidden_size = hidden_states.shape[-1]
# If media_nums_per_sample is not provided, assume batch size = 1
if media_nums_per_sample is None:
batch_size = 1
media_nums_per_sample = [grid_thw.shape[0]]
else:
batch_size = len(media_nums_per_sample)
# Optimization for batch_size = 1 (common in inference)
if batch_size == 1:
# 1. Calculate total length (pure math, fast)
total_len = 0
for i in range(grid_thw.shape[0]):
num_tokens = tokens_per_media[i].item()
num_frames = grid_thw[i, 0].item()
total_len += num_tokens + num_frames # + separators
# 2. Handle Padding alignment
pad_multiple = self.config.vision_seq_pad_multiple
if total_len % pad_multiple != 0:
max_seq_len = (total_len + pad_multiple - 1) // pad_multiple * pad_multiple
else:
max_seq_len = total_len
# 3. Pre-allocate final tensor
batched_hidden_states = torch.zeros(
1, max_seq_len, hidden_size,
dtype=hidden_states.dtype,
device=hidden_states.device
)
# 4. Vectorized fill
sample_info = {
'medias': [],
'total_length': total_len,
'pad_start': total_len,
'pad_end': max_seq_len
}
token_offset = 0
current_seq_len = 0
separator_embedding = self.separator_token.to(hidden_states.dtype)
# Iterate through all medias in this single sample
for media_idx in range(grid_thw.shape[0]):
num_tokens = tokens_per_media[media_idx].item()
t, h, w = grid_thw[media_idx].tolist()
num_frames = t
tokens_per_frame = num_tokens // num_frames
# --- Vectorized processing start ---
# Extract vision tokens: (num_tokens, hidden)
media_vision_tokens = hidden_states[token_offset : token_offset + num_tokens]
# Reshape to (num_frames, tokens_per_frame, hidden)
media_vision_tokens = media_vision_tokens.view(num_frames, tokens_per_frame, hidden_size)
# Directly write to destination without creating intermediate large tensors
chunk_len = num_frames * (tokens_per_frame + 1)
# Get view of the target area: (num_frames, tokens_per_frame + 1, hidden)
target_view = batched_hidden_states[0, current_seq_len : current_seq_len + chunk_len]
target_view = target_view.view(num_frames, tokens_per_frame + 1, hidden_size)
# 1. Fill vision tokens
target_view[:, :tokens_per_frame].copy_(media_vision_tokens)
# 2. Fill separators (Broadcast assignment)
# separator_embedding is (hidden,), automatically broadcasts to (num_frames, hidden)
target_view[:, tokens_per_frame] = separator_embedding
# --- Vectorized processing end ---
sample_info['medias'].append({
'start': current_seq_len,
'end': current_seq_len + chunk_len,
'length': chunk_len,
'num_frames': num_frames,
'grid_h': h,
'grid_w': w,
'vision_tokens_per_frame': tokens_per_frame,
'has_separator': True,
})
current_seq_len += chunk_len
token_offset += num_tokens
vision_token_info = [sample_info]
return batched_hidden_states, vision_token_info
# Calculate tokens per sample including separator tokens
# For images: +1 separator per image
# For videos: +num_frames separators per video (one after each frame)
tokens_per_sample = []
media_idx = 0
for num_medias_in_sample in media_nums_per_sample:
sample_tokens = 0
for i in range(num_medias_in_sample):
num_tokens = tokens_per_media[media_idx + i].item()
num_frames = grid_thw[media_idx + i, 0].item()
sample_tokens += num_tokens + num_frames # +num_frames separator tokens
tokens_per_sample.append(sample_tokens)
media_idx += num_medias_in_sample
max_seq_len = max(tokens_per_sample)
pad_multiple = self.config.vision_seq_pad_multiple
if max_seq_len % pad_multiple != 0:
max_seq_len = (max_seq_len + pad_multiple - 1) // pad_multiple * pad_multiple
# Initialize batched output with zeros (for padding)
batched_hidden_states = torch.zeros(
batch_size, max_seq_len, hidden_size,
dtype=hidden_states.dtype,
device=hidden_states.device
)
# Get separator token (learnable parameter)
separator_embedding = self.separator_token.to(hidden_states.dtype)
# Track token positions for each sample
vision_token_info = []
# Split packed tensor and fill batched output
token_offset = 0
media_idx = 0
for sample_idx, num_medias_in_sample in enumerate(media_nums_per_sample):
sample_info = {
'medias': [], # List of dicts for each media in this sample
'total_length': tokens_per_sample[sample_idx],
'pad_start': tokens_per_sample[sample_idx],
'pad_end': max_seq_len
}
seq_offset = 0 # Offset within this sample's sequence
# Process each image/video in this sample
for _ in range(num_medias_in_sample):
num_tokens = tokens_per_media[media_idx].item()
t, h, w = grid_thw[media_idx].tolist()
num_frames = t
tokens_per_frame = num_tokens // num_frames
# Record start position for this media
media_start = seq_offset
# Vectorized handling of frames
# Extract vision tokens for this media: (num_tokens, hidden)
media_vision_tokens = hidden_states[token_offset : token_offset + num_tokens]
# Reshape to (num_frames, tokens_per_frame, hidden)
media_vision_tokens = media_vision_tokens.view(num_frames, tokens_per_frame, hidden_size)
# Create separators: (num_frames, 1, hidden)
separators = separator_embedding.view(1, 1, hidden_size).expand(num_frames, 1, hidden_size)
# Concatenate: (num_frames, tokens_per_frame + 1, hidden)
media_tokens_with_sep = torch.cat([media_vision_tokens, separators], dim=1)
# Flatten: (num_frames * (tokens_per_frame + 1), hidden)
media_tokens_with_sep = media_tokens_with_sep.view(-1, hidden_size)
# Assign to batched tensor
media_length_with_sep = media_tokens_with_sep.shape[0]
batched_hidden_states[sample_idx, seq_offset : seq_offset + media_length_with_sep] = media_tokens_with_sep
seq_offset += media_length_with_sep
# Total tokens for this media = vision_tokens + num_separators
media_length = num_tokens + num_frames
# Record this image/video's position within the sample
# Note: length now includes separator tokens
sample_info['medias'].append({
'start': media_start,
'end': media_start + media_length,
'length': media_length,
'num_frames': num_frames, # 1 for image, >1 for video
'grid_h': h,
'grid_w': w,
'vision_tokens_per_frame': tokens_per_frame, # Actual vision tokens per frame (excluding separator)
'has_separator': True, # Flag indicating separator tokens are included
})
token_offset += num_tokens
media_idx += 1
vision_token_info.append(sample_info)
return batched_hidden_states, vision_token_info
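    # Layout sketch (illustrative): a sample with a single image of grid_thw = [1, 4, 4] and
    # spatial_merge_size = 2 yields 4 merged vision tokens plus 1 separator per frame, i.e. 5
    # valid positions, zero-padded up to the next multiple of config.vision_seq_pad_multiple.
    # The matching vision_token_info entry records
    #   {'start': 0, 'end': 5, 'length': 5, 'num_frames': 1, 'vision_tokens_per_frame': 4, ...}
    # together with 'pad_start' / 'pad_end' so later stages can rebuild masks over the padding.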
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def set_decoder(self, decoder):
self.language_model = decoder
def get_decoder(self):
return self.language_model
def _expand_cross_attention_mask(
self,
cross_attention_mask: torch.Tensor,
vision_token_info: List[dict],
target_dtype: torch.dtype,
) -> torch.Tensor:
"""
Expand cross_attention_mask from (B, 1, T, N_frames) to (B, 1, T, N_tokens).
Args:
cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, num_frames)`):
Coarse attention mask where each frame corresponds to one column.
Can be bool (True=masked) or float (min_value=masked).
vision_token_info (`List[dict]`):
Precomputed token info that includes actual token counts after ViT padding.
Must be provided (either from prefill computation or from cache).
Each dict contains 'medias' list with 'length', 'num_frames', and 'vision_tokens_per_frame'.
target_dtype (`torch.dtype`):
Target dtype for the output mask (typically inputs_embeds.dtype).
Returns:
`torch.Tensor` of shape `(batch_size, 1, text_seq_len, total_vision_tokens)`:
Fine-grained attention mask where each vision token has its own column.
Masked positions have min_value, unmasked positions have 0.0.
Note:
- vision_token_info contains the actual token counts after ViT padding (pad to multiple of 8)
- Separator tokens are treated as part of the same frame, sharing the same mask
"""
if vision_token_info is None:
raise ValueError(
"vision_token_info must be provided to _expand_cross_attention_mask. "
"This should be cached from prefill stage or computed during current forward pass."
)
batch_size = cross_attention_mask.shape[0]
# Determine target vision length (should be consistent across batch, but take max to be safe)
max_vision_len = 0
if vision_token_info:
max_vision_len = max([info.get('pad_end', 0) for info in vision_token_info])
if max_vision_len == 0:
return None
# Convert bool mask to float mask if needed
if cross_attention_mask.dtype == torch.bool:
# True = masked, False = visible
# Convert to float: True -> min_value, False -> 0.0
min_value = torch.finfo(target_dtype).min
float_mask = torch.zeros_like(cross_attention_mask, dtype=target_dtype)
float_mask.masked_fill_(cross_attention_mask, min_value)
cross_attention_mask = float_mask
else:
# Already float, ensure it's the right dtype
cross_attention_mask = cross_attention_mask.to(dtype=target_dtype)
# Pre-allocate final mask with min_dtype (masked)
# This is memory efficient and handles padding automatically
min_dtype = torch.finfo(target_dtype).min
final_mask = torch.full(
(batch_size, 1, cross_attention_mask.shape[2], max_vision_len),
min_dtype,
dtype=target_dtype,
device=cross_attention_mask.device
)
for i in range(batch_size):
medias = vision_token_info[i]['medias']
if not medias:
continue
# Collect repetition counts for all frames in this sample
repeats_parts = []
for media in medias:
num_frames = media.get('num_frames', 1)
length = media['length']
has_separator = media.get('has_separator', False)
# Determine tokens per frame (including separator)
if has_separator:
vision_tokens_per_frame = media.get('vision_tokens_per_frame', (length // num_frames) - 1)
tokens_per_frame_with_sep = vision_tokens_per_frame + 1
else:
tokens_per_frame_with_sep = length // num_frames
# In convert_packed_to_batch we enforce strictly regular frames
# so we can assume all frames have the same number of tokens
repeats_parts.append(
torch.full(
(num_frames,),
tokens_per_frame_with_sep,
dtype=torch.long,
device=cross_attention_mask.device,
)
)
num_valid_frames = sum(part.numel() for part in repeats_parts)
if num_valid_frames == 0:
continue
# If cross_attention_mask has more frames (e.g. padded), slice it
# If it has fewer (shouldn't happen), slice repeats
valid_mask_frames = min(num_valid_frames, cross_attention_mask.shape[-1])
repeats_tensor = torch.cat(repeats_parts)
if valid_mask_frames < num_valid_frames:
repeats_tensor = repeats_tensor[:valid_mask_frames]
# Extract valid columns for this sample
# (1, text_len, valid_mask_frames)
source_mask = cross_attention_mask[i, :, :, :valid_mask_frames]
# Expand using repeat_interleave
# output shape: (1, text_len, sum(repeats))
expanded_mask = source_mask.repeat_interleave(repeats_tensor, dim=-1)
# Assign to final_mask
num_tokens = expanded_mask.shape[-1]
if num_tokens > max_vision_len:
num_tokens = max_vision_len
expanded_mask = expanded_mask[..., :num_tokens]
final_mask[i, :, :, :num_tokens] = expanded_mask
return final_mask
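    # Expansion sketch (illustrative): each coarse per-frame column is repeated once per vision
    # token of that frame, separator included. For one 2-frame video with 4 vision tokens per
    # frame and a 3-token text prefix, a (1, 1, 3, 2) mask expands to (1, 1, 3, pad_end): the
    # first 2 * (4 + 1) = 10 columns are the two frame columns repeated 5 times each, and every
    # padded column stays at the dtype minimum (fully masked).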
def compute_position_ids(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
rope_deltas: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
"""
Compute 3D position IDs for text tokens with special handling for image tokens.
Rules:
- Regular text tokens: increment position (x, x, x) -> (x+1, x+1, x+1)
- Image token: gets (t, t, t) where t = previous_text_position + 1
- After processing vision tokens, next text token starts at max(vision_bottom_right) + 1
In decode stage, uses cached rope_deltas to quickly compute new positions.
Args:
input_ids: (batch_size, seq_len)
attention_mask: (batch_size, seq_len), optional
past_key_values: cache object used to infer decode offset from the current text cache length
Returns:
position_ids: (3, batch_size, seq_len)
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
image_token_id = self.config.image_token_id
# Decode stage: always advance positions from the current text cache length.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
if past_seen_tokens > 0:
position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
position_ids = position_ids + past_seen_tokens
if rope_deltas is not None:
position_ids = position_ids + rope_deltas.unsqueeze(1)
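# Illustrative (hypothetical) decode step: 20 text tokens are cached and prefill reached a
# maximum multimodal position of 60, so rope_deltas = 60 + 1 - 20 = 41 and the newly
# generated token is placed at position 20 + 41 = 61.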
return position_ids.unsqueeze(0).expand(3, -1, -1)
# Prefill stage: compute full position_ids with image token awareness
# Vectorized implementation
# 1. Identify token types
is_image_token = (input_ids == image_token_id)
if attention_mask is not None:
is_padding = (attention_mask == 0)
else:
is_padding = torch.zeros_like(input_ids, dtype=torch.bool)
is_regular_token = ~(is_image_token | is_padding)
# 2. Calculate position increments
# Regular tokens increment position by 1
# Image tokens do not increment position (they reuse the "current" position counter)
# Padding tokens do not increment
# Cumulative sum of regular tokens gives the position index:
#   position[i] = number of regular tokens strictly before i
#               = cumsum(is_regular)[i] - is_regular[i]
# Regular token: cumsum is inclusive, so subtracting 1 yields 0, 1, 2, ...
# Image token: is_regular is 0, so cumsum already equals the count of preceding
#              regular tokens and is used unchanged.
# Example: is_regular = [1, 1, 0, 1] -> cumsum = [1, 2, 2, 3] -> positions = [0, 1, 2, 2]
cumulative_regular = is_regular_token.long().cumsum(dim=1)
base_position_ids = cumulative_regular - is_regular_token.long()
# Apply padding mask (set padding positions to 0)
base_position_ids = base_position_ids.masked_fill(is_padding, 0)
# Expand to 3D: (3, batch, seq_len)
position_ids = base_position_ids.unsqueeze(0).expand(3, -1, -1).clone()
return position_ids
def compute_vision_position_ids(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
vision_token_info: List[dict],
cross_attention_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute 3D position IDs for vision tokens (including separator tokens) and update text position_ids.
Vectorized implementation for improved efficiency.
Position encoding rules:
- For text: if not image token, increment position (t-1, t-1, t-1) -> (t, t, t) -> ...
- For vision: top-left is (t, t, t), increases towards bottom-right to (t, t+h-1, t+w-1)
- Separator Token after each frame: (x, x, x) where x = max(t+h-1, t+w-1) + 1 = max(t+h, t+w)
- Image token in text: also gets position (x, x, x) - same as separator
- Next text token after image: starts at (x+1, x+1, x+1)
Args:
input_ids: (batch_size, seq_len)
position_ids: (3, batch_size, seq_len) - will be updated in place
vision_token_info: metadata about vision tokens (now includes separator positions)
cross_attention_states: (batch_size, max_vision_seq_len, hidden_size)
attention_mask: (batch_size, seq_len), optional
Returns:
vision_pos_ids: (3, batch_size, max_vision_seq_len)
position_ids: (3, batch_size, seq_len) - updated
rope_deltas: (batch_size,) - position offset due to vision tokens
"""
batch_size, max_vision_seq_len, _ = cross_attention_states.shape
device = cross_attention_states.device
image_token_id = self.config.image_token_id
merge_size = self.visual.spatial_merge_size
# 1. Gather all frame metadata
# We need to flatten the nested vision_token_info structure to align with image tokens in input_ids
# Find all image tokens in text: (num_occurrences, 2) -> [batch_idx, seq_idx]
image_token_indices = (input_ids == image_token_id).nonzero()
# Flatten vision_token_info to parallel lists
# We assume the order of medias in vision_token_info matches the appearance of image tokens in input_ids
flat_eff_h_parts = []
flat_eff_w_parts = []
flat_vis_start_parts = []
# Processing metadata on CPU (fast enough for typical batch sizes)
for b_idx, info in enumerate(vision_token_info):
medias = info.get('medias', [])
for media in medias:
num_frames = media['num_frames']
h, w = media['grid_h'], media['grid_w']
eh, ew = h // merge_size, w // merge_size
start = media['start']
tok_per_frame = media['vision_tokens_per_frame']
stride = tok_per_frame + 1 # +1 for separator
frame_offsets = start + torch.arange(num_frames, device=device, dtype=torch.long) * stride
flat_vis_start_parts.append(frame_offsets)
flat_eff_h_parts.append(torch.full((num_frames,), eh, device=device, dtype=torch.long))
flat_eff_w_parts.append(torch.full((num_frames,), ew, device=device, dtype=torch.long))
# Pre-allocate output
vision_pos_ids = torch.zeros(
(3, batch_size, max_vision_seq_len),
dtype=torch.long,
device=device
)
# Handle case where no image tokens or info
if len(flat_eff_h_parts) == 0 or len(image_token_indices) == 0:
rope_deltas = position_ids.max(dim=0).values.max(dim=-1).values + 1 - input_ids.shape[1]
return vision_pos_ids, position_ids, rope_deltas
flat_eff_h = torch.cat(flat_eff_h_parts)
flat_eff_w = torch.cat(flat_eff_w_parts)
flat_vis_starts = torch.cat(flat_vis_start_parts)
# Align lengths (handle truncation if text has fewer tokens or vice versa)
num_matches = min(flat_eff_h.shape[0], image_token_indices.shape[0])
flat_eff_h = flat_eff_h[:num_matches]
flat_eff_w = flat_eff_w[:num_matches]
flat_vis_starts = flat_vis_starts[:num_matches]
# Get corresponding text positions
target_indices = image_token_indices[:num_matches]
batch_rows = target_indices[:, 0]
text_cols = target_indices[:, 1]
# 2. Compute Shifts and Update Position IDs
# Calculate max dimensions for each image token: separator_pos = t + max(h, w)
# Shift amount for subsequent tokens = max(h, w) + 1
max_hw = torch.maximum(flat_eff_h, flat_eff_w)
shifts = max_hw + 1
# Create a shift map to apply cumulative shifts
shift_map = torch.zeros((batch_size, input_ids.shape[1]), dtype=torch.long, device=device)
shift_map[batch_rows, text_cols] = shifts
# Calculate cumulative shifts along sequence
cum_shifts = shift_map.cumsum(dim=1)
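# Illustrative (hypothetical) example: two image tokens in one row with shifts 3 and 5 give
#   shift_map  = [0, 0, 3, 0, 5, 0]
#   cum_shifts = [0, 0, 3, 3, 8, 8]
# so tokens after the first image move by 3 and tokens after the second by 8.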
# Calculate t_vals (start position for each vision grid)
# t_val = original_pos + shifts_before_this_image
# cum_shifts at image index includes the image's own shift, so we subtract it
orig_pos = position_ids[0, batch_rows, text_cols]
shifts_before = cum_shifts[batch_rows, text_cols] - shifts
t_vals = orig_pos + shifts_before
# Update text position_ids
# All tokens get shifted by cum_shifts
# Image tokens specifically need to be at t_val + max_hw (which is t_val + shift - 1)
# Our cum_shift update gives: orig_pos + shifts_before + shift = t_val + shift
# So we subtract 1 from image tokens
# Apply global shift
# Note: position_ids is (3, B, L), cum_shifts is (B, L). Expand to match.
new_pos_ids = position_ids + cum_shifts.unsqueeze(0)
# Correct image tokens (subtract 1)
# We can use boolean mask for efficient update
img_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
img_token_mask[batch_rows, text_cols] = True
new_pos_ids[:, img_token_mask] -= 1
# Ensure padding positions remain 0 (if attention_mask provided)
if attention_mask is not None:
# Assuming padding is 0 in attention_mask
padding_mask = (attention_mask == 0).unsqueeze(0)
new_pos_ids.masked_fill_(padding_mask, 0)
# Update position_ids in-place
position_ids.copy_(new_pos_ids)
# 3. Populate Vision Pos IDs
# Group frames by size (eff_h, eff_w) to vectorize grid generation
# This is efficient because typically there are few distinct aspect ratios
unique_shapes = torch.unique(torch.stack([flat_eff_h, flat_eff_w], dim=1), dim=0)
for shape in unique_shapes:
eh, ew = shape[0].item(), shape[1].item()
# Mask for frames of this shape
mask = (flat_eff_h == eh) & (flat_eff_w == ew)
sub_t_vals = t_vals[mask]
sub_batch_rows = batch_rows[mask]
sub_vis_starts = flat_vis_starts[mask]
num_frames_sub = sub_t_vals.shape[0]
if num_frames_sub == 0: continue
# Generate grids: (num_frames, eh, ew)
# y ranges 0..eh-1, x ranges 0..ew-1
# positions: t + y, t + x
y_grid = torch.arange(eh, device=device).view(1, eh, 1).expand(num_frames_sub, -1, ew)
x_grid = torch.arange(ew, device=device).view(1, 1, ew).expand(num_frames_sub, eh, -1)
t_grid = sub_t_vals.view(-1, 1, 1).expand(-1, eh, ew)
h_grid = t_grid + y_grid
w_grid = t_grid + x_grid
# Flatten to assign
flat_t = t_grid.reshape(-1)
flat_h = h_grid.reshape(-1)
flat_w = w_grid.reshape(-1)
# Calculate destination indices in vision_pos_ids
# (batch, seq_pos)
tokens_per_frame = eh * ew
# Offsets for each token in the frame 0..N-1
seq_offsets = torch.arange(tokens_per_frame, device=device).unsqueeze(0)
# Add start index: (num_frames, 1) + (1, tokens) -> (num_frames, tokens)
abs_seq_offsets = seq_offsets + sub_vis_starts.unsqueeze(1)
flat_seq_inds = abs_seq_offsets.reshape(-1)
flat_batch_inds = sub_batch_rows.unsqueeze(1).expand(-1, tokens_per_frame).reshape(-1)
# Clip to max_vision_seq_len
valid_mask = flat_seq_inds < max_vision_seq_len
if valid_mask.any():
final_b = flat_batch_inds[valid_mask]
final_s = flat_seq_inds[valid_mask]
vision_pos_ids[0, final_b, final_s] = flat_t[valid_mask]
vision_pos_ids[1, final_b, final_s] = flat_h[valid_mask]
vision_pos_ids[2, final_b, final_s] = flat_w[valid_mask]
# 4. Handle Separator Tokens
# Position: t_val + max(eh, ew)
sep_vals = t_vals + max_hw
# Index: start + tokens_per_frame = start + eh*ew
sep_indices = flat_vis_starts + (flat_eff_h * flat_eff_w)
valid_sep_mask = sep_indices < max_vision_seq_len
if valid_sep_mask.any():
final_b = batch_rows[valid_sep_mask]
final_s = sep_indices[valid_sep_mask]
vals = sep_vals[valid_sep_mask]
vision_pos_ids[0, final_b, final_s] = vals
vision_pos_ids[1, final_b, final_s] = vals
vision_pos_ids[2, final_b, final_s] = vals
# 5. Compute Rope Deltas
# rope_deltas[batch_idx] = max_pos + 1 - seq_len
# Use updated position_ids
# Max pos in each batch - take max across all 3 position dimensions
# position_ids shape: (3, batch_size, seq_len)
# We need rope_deltas shape: (batch_size,)
max_pos = position_ids.max(dim=0).values.max(dim=-1).values # (batch_size,)
rope_deltas = max_pos + 1 - input_ids.shape[1] # (batch_size,)
return vision_pos_ids, position_ids, rope_deltas
def get_vision_features(
self,
pixel_values: torch.FloatTensor,
grid_thw: Optional[torch.LongTensor] = None,
media_nums_per_sample: Optional[List[int]] = None
):
"""
Args:
pixel_values: vision pixel values (images and videos merged)
grid_thw: [num_media, 3] tensor with (t, h, w) for each media item
media_nums_per_sample: List indicating how many media items each sample has
Returns:
vision_embeds: [batch_size, max_seq_len, hidden_size]
vision_token_info: List[Dict] with media positions and padding info for each sample
"""
pixel_values = pixel_values.type(self.visual.dtype)
hidden_states = self.visual(
pixel_values,
grid_thw=grid_thw
)
vision_embeds, vision_token_info = self.convert_packed_to_batch(
hidden_states,
grid_thw,
media_nums_per_sample
)
return vision_embeds, vision_token_info
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.Tensor] = None,
grid_thw: Optional[torch.LongTensor] = None,
media_nums_per_sample: Optional[List[int]] = None,
vision_position_ids: Optional[torch.LongTensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
vision_token_info: Optional[List[dict]] = None,
rope_deltas: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
"""
Args:
grid_thw (`torch.LongTensor` of shape `(num_media, 3)`, *optional*):
Grid size for each media item in (temporal, height, width) format. Each row contains `[t, h, w]`
representing the number of temporal, height, and width patches for a media item (image or video).
media_nums_per_sample (`List[int]`, *optional*):
List indicating how many media items each sample in the batch has. For example, `[2, 1, 3]` means
the first sample has 2 media items, the second has 1, and the third has 3.
vision_position_ids (`torch.LongTensor` of shape `(batch_size, vision_seq_len)`, *optional*):
Position IDs for vision tokens used in cross-attention. These are computed from text position IDs
based on the positions of image/video tokens in the input text.
cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
Attention mask for cross-attention between text and vision. Controls which vision tokens each text
token can attend to, enforcing causal visibility for video frames.
vision_token_info (`List[dict]`, *optional*):
Cached metadata describing how packed vision tokens were regrouped per sample. Reused in decode
to expand frame-level cross-attention masks to token-level masks without recomputing vision features.
rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Cached offsets between text sequence length and multimodal RoPE positions. Reused in decode to
reconstruct text position ids from the current cache length.
"""
cache_position = kwargs.pop("cache_position", None)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# Process vision features (images and videos are already merged by processor)
cross_attention_states = None
if pixel_values is not None:
# Determine batch size
batch_size = inputs_embeds.shape[0]
# Get default media_nums_per_sample if not provided
if media_nums_per_sample is None:
# With batch_size == 1 all media necessarily belong to that single sample; otherwise the split is ambiguous
if batch_size == 1:
media_nums_per_sample = [grid_thw.shape[0]]
else:
raise ValueError("media_nums_per_sample must be provided when batch_size > 1")
# Process all vision inputs together through VIT
# pixel_values and grid_thw are already ordered by appearance in text
vision_embeds, vision_token_info = self.get_vision_features(
pixel_values, grid_thw, media_nums_per_sample
)
# vision_embeds: [batch_size, max_seq_len, hidden_size]
cross_attention_states = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
# Generate 3D position IDs for text if not provided
if position_ids is None:
# Compute position IDs with image token awareness
# In decode stage, this uses cached rope_deltas for fast computation
position_ids = self.compute_position_ids(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
rope_deltas=rope_deltas,
)
# Compute cross_attention_mask, vision_position_ids, and full_text_row_masked_out_mask
full_text_row_masked_out_mask = None
if cross_attention_mask is not None:
# Expand mask from frame-level to token-level
# The processor outputs coarse masks (bool or float) where each frame has one column,
# we need to expand to fine-grained masks where each vision token has its own column
# This function also converts bool to float with correct min/max values
cross_attention_mask = self._expand_cross_attention_mask(
cross_attention_mask,
vision_token_info,
target_dtype=inputs_embeds.dtype
)
# Handle full_text_row_masked_out_mask logic
if cross_attention_mask is not None:
negative_inf_value = torch.finfo(cross_attention_mask.dtype).min
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
)
cross_attention_mask = cross_attention_mask * full_text_row_masked_out_mask
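# Text rows with no visible vision column (e.g. text before the first image) become all-zero
# here; full_text_row_masked_out_mask marks them so the language model can suppress their
# cross-attention output rather than softmax over an all -inf row.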
if vision_position_ids is None and cross_attention_states is not None and input_ids is not None:
vision_position_ids, position_ids, rope_deltas = self.compute_vision_position_ids(
input_ids,
position_ids,
vision_token_info,
cross_attention_states,
attention_mask
)
outputs = self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
vision_position_ids=vision_position_ids,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
if not return_dict:
last_hidden_state = outputs[0]
model_outputs = (
last_hidden_state,
outputs[1] if len(outputs) > 1 else past_key_values,
)
if output_hidden_states:
model_outputs += (outputs[2],)
if output_attentions:
attn_idx = 3 if output_hidden_states else 2
model_outputs += (outputs[attn_idx],)
model_outputs += (vision_token_info, rope_deltas)
return model_outputs
return MossVLModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
vision_token_info=vision_token_info,
rope_deltas=rope_deltas,
)
@auto_docstring(
custom_intro="""
The MossVL model with a language modeling head on top, for conditional generation tasks.
Combines Qwen3VL vision encoder with LLM via cross-attention layers.
"""
)
class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
# transformers 5.x expects a dict[target, source]; MossVL does not tie
# lm_head to the embeddings (config.tie_word_embeddings is False), so the
# mapping is empty. The legacy list format ["lm_head.weight"] breaks
# save_pretrained in transformers>=5.
_tied_weights_keys: dict[str, str] = {}
config: MossVLConfig
_checkpoint_conversion_mapping = {}
accepts_loss_kwargs = False
def __init__(self, config):
super().__init__(config)
self.model = MossVLModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def set_decoder(self, decoder):
self.model.set_decoder(decoder)
def get_decoder(self):
return self.model.get_decoder()
def get_vision_features(
self,
pixel_values: torch.FloatTensor,
grid_thw: Optional[torch.LongTensor] = None,
media_nums_per_sample: Optional[List[int]] = None
):
"""
Get vision features for images and videos (merged).
Args:
pixel_values: vision pixel values (images and videos merged)
grid_thw: [num_media, 3] tensor with (t, h, w) for each media item
media_nums_per_sample: List indicating how many media items each sample has
Returns:
vision_embeds: [batch_size, max_seq_len, hidden_size]
vision_token_info: List[Dict] with media positions and padding info for each sample
"""
return self.model.get_vision_features(pixel_values, grid_thw, media_nums_per_sample)
@property
def language_model(self):
return self.model.language_model
@property
def visual(self):
return self.model.visual
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.Tensor] = None,
grid_thw: Optional[torch.LongTensor] = None,
media_nums_per_sample: Optional[List[int]] = None,
vision_position_ids: Optional[torch.LongTensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
vision_token_info: Optional[List[dict]] = None,
rope_deltas: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, CausalLMOutputWithPast]:
"""
Args:
grid_thw (`torch.LongTensor` of shape `(num_media, 3)`, *optional*):
Grid size for each media item in (temporal, height, width) format. Each row contains `[t, h, w]`
representing the number of temporal, height, and width patches for a media item (image or video).
media_nums_per_sample (`List[int]`, *optional*):
List indicating how many media items each sample in the batch has. For example, `[2, 1, 3]` means
the first sample has 2 media items, the second has 1, and the third has 3.
vision_position_ids (`torch.LongTensor` of shape `(batch_size, vision_seq_len)`, *optional*):
Position IDs for vision tokens used in cross-attention. These are computed from text position IDs
based on the positions of image/video tokens in the input text.
cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
Attention mask for cross-attention between text and vision. Controls which vision tokens each text
token can attend to, enforcing causal visibility for video frames.
vision_token_info (`List[dict]`, *optional*):
Cached metadata describing how packed vision tokens were regrouped per sample. Reused across decode
steps to expand cross-attention masks without re-running the vision encoder.
rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Cached multimodal RoPE offsets returned by the base model during prefill and reused during decode.
"""
cache_position = kwargs.pop("cache_position", None)
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
grid_thw=grid_thw,
media_nums_per_sample=media_nums_per_sample,
position_ids=position_ids,
attention_mask=attention_mask,
vision_position_ids=vision_position_ids,
cross_attention_mask=cross_attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_token_info=vision_token_info,
rope_deltas=rope_deltas,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
if not return_dict:
output = (logits,)
output += outputs[1:]
return ((loss,) + output) if loss is not None else output
return MossVLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
vision_token_info=outputs.vision_token_info,
rope_deltas=outputs.rope_deltas,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
position_ids=None,
use_cache=True,
pixel_values=None,
grid_thw=None,
media_nums_per_sample=None, # One video counts as one media item.
vision_position_ids=None,
vision_token_info=None,
rope_deltas=None,
cross_attention_mask=None,
**kwargs,
):
"""
Prepare inputs for generation.
Note: Currently only supports offline visual understanding, meaning all multimodal
content must be provided before generation starts. We don't support adding new
images/videos during generation (streaming mode).
Args:
media_nums_per_sample: One video counts as one media item (regardless of frame count)
"""
kwargs.pop("cache_position", None)
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
pixel_values=pixel_values,
grid_thw=grid_thw,
media_nums_per_sample=media_nums_per_sample,
use_cache=use_cache,
**kwargs,
)
model_input = model_inputs.get("input_ids")
if model_input is None:
model_input = model_inputs.get("inputs_embeds")
current_length = model_input.shape[1]
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# Let the model recompute multimodal position ids from the current cache length.
model_inputs["position_ids"] = None
model_inputs["vision_token_info"] = vision_token_info
model_inputs["rope_deltas"] = rope_deltas
# Handle cross attention mask
if cross_attention_mask is not None:
# Slice to the current text slice on text dimension (dim=2).
# Shape: [batch, 1, text_len, vision_len] -> [batch, 1, current_len, vision_len]
cross_attention_mask = cross_attention_mask[:, :, -current_length:, :]
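# Illustrative (hypothetical) decode step: current_length == 1, so only the last text row of
# the prefill-time mask is kept; _update_model_kwargs_for_generation keeps the full mask in
# model_kwargs, so every new token reuses that same last row.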
model_inputs["cross_attention_mask"] = cross_attention_mask
# Vision inputs are only needed in prefill stage.
# In decode stage, vision features are retrieved from cross attention cache
if past_seen_tokens > 0:
model_inputs["pixel_values"] = None
model_inputs["grid_thw"] = None
model_inputs["media_nums_per_sample"] = None
model_inputs["vision_position_ids"] = None
else:
# In prefill stage, include all vision-related inputs
model_inputs["vision_position_ids"] = vision_position_ids
return model_inputs
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
"""
Update model kwargs for generation, extending cross_attention_mask for the newly generated token.
In offline mode (all multimodal content provided before generation):
- Each newly generated token should have the same cross_attention_mask pattern as the previous token
- This ensures all generated tokens can attend to all vision tokens that were visible before
"""
cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
if cross_attention_mask_prev is not None:
model_kwargs["cross_attention_mask"] = cross_attention_mask_prev
if getattr(outputs, "vision_token_info", None) is not None:
model_kwargs["vision_token_info"] = outputs.vision_token_info
if getattr(outputs, "rope_deltas", None) is not None:
model_kwargs["rope_deltas"] = outputs.rope_deltas
return model_kwargs
__all__ = [
"MossVLVisionModel",
"MossVLForConditionalGeneration",
"MossVLModel",
"MossVLPreTrainedModel",
"MossVLTextModel",
]