| """PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention""" |
|
|
| from dataclasses import dataclass |
| from typing import Any, Callable, Optional, Union, Tuple, List |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import initialization as init |
|
|
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.generation import GenerationMixin |
| from transformers.integrations import use_kernel_forward_from_hub |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging |
| from transformers.utils.deprecation import deprecate_kwarg |
| from transformers.utils.generic import is_flash_attention_requested |
| from transformers.utils.output_capturing import OutputRecorder |
|
|
| from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionConfig |
|
|
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| @dataclass |
| class MossVLModelOutputWithPast(ModelOutput): |
| """ |
| Output class for MossVL model with additional vision_token_info and rope_deltas fields. |
| |
| Args: |
| last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): |
| Sequence of hidden-states at the output of the last layer of the model. |
| past_key_values (`Cache`, *optional*): |
| Contains pre-computed hidden-states (key and values in the self-attention blocks and |
| cross-attention blocks) that can be used to speed up sequential decoding. |
| hidden_states (`tuple(torch.FloatTensor)`, *optional*): |
| Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer). |
| attentions (`tuple(torch.FloatTensor)`, *optional*): |
| Tuple of `torch.FloatTensor` (one for each layer) of attention weights. |
| vision_token_info (`List[dict]`, *optional*): |
| Information about vision tokens for each sample, used to correctly expand cross-attention masks. |
| This is cached during prefill and reused during decode to handle ViT padding correctly. |
| rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Position offset due to vision tokens. Used for fast position computation in decode stage. |
| rope_deltas = max_position - sequence_length |
| """ |
| |
| last_hidden_state: Optional[torch.FloatTensor] = None |
| past_key_values: Optional[Cache] = None |
| hidden_states: Optional[tuple[torch.FloatTensor]] = None |
| attentions: Optional[tuple[torch.FloatTensor]] = None |
| vision_token_info: Optional[List[dict]] = None |
| rope_deltas: Optional[torch.LongTensor] = None |
|
|
|
|
| @dataclass |
| class MossVLCausalLMOutputWithPast(ModelOutput): |
| """ |
| Output class for MossVL causal language model with additional vision_token_info and rope_deltas fields. |
| |
| Args: |
| loss (`torch.FloatTensor` of shape `(1,)`, *optional*): |
| Language modeling loss (for next-token prediction). |
| logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): |
| Prediction scores of the language modeling head. |
| past_key_values (`Cache`, *optional*): |
| Contains pre-computed hidden-states that can be used to speed up sequential decoding. |
| hidden_states (`tuple(torch.FloatTensor)`, *optional*): |
| Tuple of hidden-states at each layer. |
| attentions (`tuple(torch.FloatTensor)`, *optional*): |
| Tuple of attention weights. |
| vision_token_info (`List[dict]`, *optional*): |
| Information about vision tokens for each sample, cached for decode stage. |
| rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Position offset due to vision tokens. Used for fast position computation in decode stage. |
| """ |
| |
| loss: Optional[torch.FloatTensor] = None |
| logits: Optional[torch.FloatTensor] = None |
| past_key_values: Optional[Cache] = None |
| hidden_states: Optional[tuple[torch.FloatTensor]] = None |
| attentions: Optional[tuple[torch.FloatTensor]] = None |
| vision_token_info: Optional[List[dict]] = None |
| rope_deltas: Optional[torch.LongTensor] = None |
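| # rope_deltas sketch (illustrative numbers): for a 100-token prompt whose largest position |
| # index is 131 because an image's 2D span advances positions faster than text, |
| # rope_deltas = 131 - 100 = 31; decode steps can then compute their position as |
| # past_seen_tokens + rope_deltas without re-deriving the multimodal position grid. |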
|
|
|
|
| |
|
|
| class MossVLVisionMLP(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
| self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) |
| self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) |
| self.act_fn = ACT2FN[config.hidden_act] |
|
|
| def forward(self, hidden_state): |
| return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) |
|
|
|
|
| class MossVLVisionPatchEmbed(nn.Module): |
| def __init__(self, config) -> None: |
| super().__init__() |
| self.patch_size = config.patch_size |
| self.temporal_patch_size = config.temporal_patch_size |
| self.in_channels = config.in_channels |
| self.embed_dim = config.hidden_size |
|
|
| kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] |
| self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) |
|
|
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| target_dtype = self.proj.weight.dtype |
| hidden_states = hidden_states.view( |
| -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size |
| ) |
| hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) |
| return hidden_states |
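| # Shape sketch: pixel inputs arrive packed as (num_patches, in_channels * temporal_patch_size |
| # * patch_size**2); since the Conv3d kernel equals its stride, each patch maps to exactly one |
| # embedding, giving an output of shape (num_patches, hidden_size). |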
|
|
|
|
| class MossVLVisionRotaryEmbedding(nn.Module): |
| inv_freq: torch.Tensor |
|
|
| def __init__(self, dim: int, theta: float = 10000.0) -> None: |
| super().__init__() |
| |
| |
| |
| self.dim = dim |
| self.theta = theta |
| inv_freq = self.compute_inv_freq() |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
| def compute_inv_freq(self) -> torch.Tensor: |
| return 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)) |
|
|
| def forward(self, seqlen: int) -> torch.Tensor: |
| seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) |
| freqs = torch.outer(seq, self.inv_freq) |
| return freqs |
|
|
|
|
| class MossVLVisionPatchMerger(nn.Module): |
| def __init__(self, config: MossVLVisionConfig, num_deepstack_features=0) -> None: |
| super().__init__() |
| |
| base_hidden_size = config.hidden_size * (config.spatial_merge_size**2) |
| |
| self.input_hidden_size = base_hidden_size * (1 + num_deepstack_features) |
| |
| |
| |
| num_features = 1 + num_deepstack_features |
| self.norms = nn.ModuleList([ |
| nn.LayerNorm(config.hidden_size, eps=1e-6) |
| for _ in range(num_features) |
| ]) |
| |
| self.hidden_size = config.hidden_size |
| |
| self.linear_fc1 = nn.Linear(self.input_hidden_size, self.input_hidden_size) |
| self.act_fn = nn.GELU() |
| self.linear_fc2 = nn.Linear(self.input_hidden_size, config.out_hidden_size) |
|
|
| def forward( |
| self, |
| last_hidden_state: torch.Tensor, |
| deepstack_features: Optional[List[torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| |
| |
| |
| if deepstack_features is None: |
| deepstack_features = [] |
| all_inputs = [last_hidden_state] + deepstack_features |
| |
| |
| outs = [] |
| for i, feat in enumerate(all_inputs): |
| outs.append(self.norms[i](feat)) |
| |
| |
| x = torch.cat(outs, dim=-1) |
|
|
| |
| x = x.view(-1, self.input_hidden_size) |
| x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) |
| return x |
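| # Width sketch (illustrative config, not the checkpoint's actual sizes): with hidden_size=1152, |
| # spatial_merge_size=2 and num_deepstack_features=3, input_hidden_size = 1152 * 2**2 * (1 + 3) |
| # = 18432, i.e. the MLP consumes merge_size**2 spatial patches from the final layer plus each |
| # deepstack layer at once. |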
|
|
|
|
| def rotate_half(x): |
| """Rotates half the hidden dims of the input.""" |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
| def apply_rotary_pos_emb_vision( |
| q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| orig_q_dtype = q.dtype |
| orig_k_dtype = k.dtype |
| q, k = q.float(), k.float() |
| cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() |
| q_embed = (q * cos) + (rotate_half(q) * sin) |
| k_embed = (k * cos) + (rotate_half(k) * sin) |
| q_embed = q_embed.to(orig_q_dtype) |
| k_embed = k_embed.to(orig_k_dtype) |
| return q_embed, k_embed |
|
|
|
|
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| """ |
| This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
| """ |
| batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
| if n_rep == 1: |
| return hidden_states |
| hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
| return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
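| # Example (illustrative shapes): |
| # >>> kv = torch.randn(2, 4, 10, 64)  # (batch, num_key_value_heads, seq_len, head_dim) |
| # >>> repeat_kv(kv, 8).shape |
| # torch.Size([2, 32, 10, 64]) |
| # >>> torch.equal(repeat_kv(kv, 8), torch.repeat_interleave(kv, repeats=8, dim=1)) |
| # True |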
|
|
|
|
| def eager_attention_forward( |
| module: nn.Module, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| scaling: float, |
| dropout: float = 0.0, |
| **kwargs: Unpack[TransformersKwargs], |
| ): |
| key_states = repeat_kv(key, module.num_key_value_groups) |
| value_states = repeat_kv(value, module.num_key_value_groups) |
|
|
| attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling |
| if attention_mask is not None: |
| causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
| attn_weights = attn_weights + causal_mask |
|
|
| attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) |
| attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) |
| attn_output = torch.matmul(attn_weights, value_states) |
| attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
| return attn_output, attn_weights |
|
|
|
|
| class MossVLVisionAttention(nn.Module): |
| def __init__(self, config: MossVLVisionConfig) -> None: |
| super().__init__() |
| self.dim = config.hidden_size |
| self.num_heads = config.num_heads |
| self.head_dim = self.dim // self.num_heads |
| self.num_key_value_groups = 1 |
| self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) |
| self.proj = nn.Linear(self.dim, self.dim) |
| self.scaling = self.head_dim**-0.5 |
| self.config = config |
| self.attention_dropout = 0.0 |
| self.is_causal = False |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| cu_seqlens: torch.Tensor, |
| rotary_pos_emb: Optional[torch.Tensor] = None, |
| position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| **kwargs, |
| ) -> torch.Tensor: |
| seq_length = hidden_states.shape[0] |
| query_states, key_states, value_states = ( |
| self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) |
| ) |
| cos, sin = position_embeddings |
| query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) |
|
|
| query_states = query_states.transpose(0, 1).unsqueeze(0) |
| key_states = key_states.transpose(0, 1).unsqueeze(0) |
| value_states = value_states.transpose(0, 1).unsqueeze(0) |
|
|
| attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( |
| self.config._attn_implementation, eager_attention_forward |
| ) |
|
|
| if is_flash_attention_requested(self.config): |
| max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() |
| attn_output, _ = attention_interface( |
| self, |
| query_states, |
| key_states, |
| value_states, |
| attention_mask=None, |
| scaling=self.scaling, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| cu_seq_lens_q=cu_seqlens, |
| cu_seq_lens_k=cu_seqlens, |
| max_length_q=max_seqlen, |
| max_length_k=max_seqlen, |
| is_causal=False, |
| **kwargs, |
| ) |
| else: |
| lengths = cu_seqlens[1:] - cu_seqlens[:-1] |
| splits = [ |
| torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) |
| ] |
|
|
| attn_outputs = [ |
| attention_interface( |
| self, |
| q, |
| k, |
| v, |
| attention_mask=None, |
| scaling=self.scaling, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| is_causal=False, |
| **kwargs, |
| )[0] |
| for q, k, v in zip(*splits) |
| ] |
| attn_output = torch.cat(attn_outputs, dim=1) |
|
|
| attn_output = attn_output.reshape(seq_length, -1).contiguous() |
| attn_output = self.proj(attn_output) |
| return attn_output |
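| # cu_seqlens sketch: for two packed images of 64 and 256 patches, cu_seqlens = [0, 64, 320]; |
| # the flash path passes these boundaries directly, while the fallback path splits the packed |
| # sequence at them so attention never crosses image boundaries. |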
|
|
|
|
| class MossVLVisionBlock(GradientCheckpointingLayer): |
| def __init__(self, config, attn_implementation: str = "sdpa") -> None: |
| super().__init__() |
| self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) |
| self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) |
| self.attn = MossVLVisionAttention(config=config) |
| self.mlp = MossVLVisionMLP(config=config) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| cu_seqlens: torch.Tensor, |
| rotary_pos_emb: Optional[torch.Tensor] = None, |
| position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| **kwargs, |
| ) -> torch.Tensor: |
| hidden_states = hidden_states + self.attn( |
| self.norm1(hidden_states), |
| cu_seqlens=cu_seqlens, |
| rotary_pos_emb=rotary_pos_emb, |
| position_embeddings=position_embeddings, |
| **kwargs, |
| ) |
| hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) |
| return hidden_states |
|
|
|
|
|
|
| |
|
|
| class MossVLTextRotaryEmbedding(nn.Module): |
| inv_freq: torch.Tensor |
|
|
| def __init__(self, config: MossVLTextConfig, device=None): |
| super().__init__() |
| self.max_seq_len_cached = config.max_position_embeddings |
| self.original_max_seq_len = config.max_position_embeddings |
|
|
| self.config = config |
| rope_parameters = getattr(config, "rope_parameters", None) |
| if rope_parameters is None: |
| rope_parameters = getattr(config, "rope_scaling", None) or {"rope_type": "default"} |
|
|
| self.rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default")) |
| rope_init_fn: Callable = self.compute_default_rope_parameters |
| if self.rope_type != "default": |
| rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
|
|
| inv_freq, self.attention_scaling = rope_init_fn(self.config, device) |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
| self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) |
|
|
| self.mrope_section = rope_parameters.get("mrope_section", [24, 20, 20]) |
|
|
| @staticmethod |
| def compute_default_rope_parameters( |
| config: Optional[MossVLTextConfig] = None, |
| device: Optional[torch.device] = None, |
| seq_len: Optional[int] = None, |
| ) -> tuple[torch.Tensor, float]: |
| rope_parameters = getattr(config, "rope_parameters", None) or {} |
| base = rope_parameters.get("rope_theta", getattr(config, "rope_theta", 10000.0)) |
| head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads |
| partial_rotary_factor = rope_parameters.get( |
| "partial_rotary_factor", getattr(config, "partial_rotary_factor", 1.0) |
| ) |
| dim = int(head_dim * partial_rotary_factor) |
|
|
| attention_factor = 1.0 |
| inv_freq = 1.0 / ( |
| base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) |
| ) |
| return inv_freq, attention_factor |
|
|
| def apply_interleaved_mrope(self, freqs, mrope_section): |
| """Apply interleaved MRoPE to 3D rotary embeddings. |
| Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to |
| interleaved [THWTHW...TT], preserving frequency continuity. |
| args: |
| freqs: (3, bs, seq_len, head_dim // 2) |
| mrope_section: (3,) |
| returns: |
| freqs_t: (bs, seq_len, head_dim // 2) |
| """ |
| freqs_t = freqs[0] |
| for dim, offset in enumerate((1, 2), start=1): |
| length = mrope_section[dim] * 3 |
| idx = slice(offset, length, 3) |
| freqs_t[..., idx] = freqs[dim, ..., idx] |
| return freqs_t |
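| # Interleaving sketch (assuming mrope_section = [24, 20, 20], i.e. head_dim // 2 = 64): the |
| # temporal plane is the base; indices 1, 4, ..., 58 are overwritten with H frequencies and |
| # indices 2, 5, ..., 59 with W frequencies, so the layout reads T H W T H W ... followed by |
| # four trailing T entries. |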
|
|
| @torch.no_grad() |
| @dynamic_rope_update |
| def forward(self, x, position_ids): |
| if position_ids.ndim == 2: |
| position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) |
| |
| inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) |
| position_ids_expanded = position_ids[:, :, None, :].float() |
|
|
| device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
| with torch.autocast(device_type=device_type, enabled=False): |
| freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(2, 3) |
| freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| cos = emb.cos() * self.attention_scaling |
| sin = emb.sin() * self.attention_scaling |
|
|
| return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
|
| @use_kernel_forward_from_hub("RMSNorm") |
| class MossVLTextRMSNorm(nn.Module): |
| def __init__(self, hidden_size, eps: float = 1e-6) -> None: |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(hidden_size)) |
| self.variance_epsilon = eps |
|
|
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| input_dtype = hidden_states.dtype |
| hidden_states = hidden_states.to(torch.float32) |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
| return self.weight * hidden_states.to(input_dtype) |
|
|
| def extra_repr(self): |
| return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" |
|
|
| |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| cos = cos.unsqueeze(unsqueeze_dim) |
| sin = sin.unsqueeze(unsqueeze_dim) |
| q_embed = (q * cos) + (rotate_half(q) * sin) |
| k_embed = (k * cos) + (rotate_half(k) * sin) |
| return q_embed, k_embed |
|
|
| |
| def apply_rotary_pos_emb_cross_attention(states, cos, sin, position_ids=None, unsqueeze_dim=1): |
| cos = cos.unsqueeze(unsqueeze_dim) |
| sin = sin.unsqueeze(unsqueeze_dim) |
| states_embed = (states * cos) + (rotate_half(states) * sin) |
| return states_embed |
|
|
|
|
| class MossVLTextSelfAttention(nn.Module): |
| """Self attention for text decoder""" |
|
|
| def __init__(self, config: MossVLTextConfig, layer_idx: int): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) |
| self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads |
| self.scaling = self.head_dim**-0.5 |
| self.attention_dropout = config.attention_dropout |
| self.is_causal = True |
|
|
| self.q_proj = nn.Linear( |
| config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.k_proj = nn.Linear( |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.v_proj = nn.Linear( |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| self.q_norm = MossVLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| self.k_norm = MossVLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor], |
| past_key_values: Optional[Cache] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = False, |
| **kwargs: Unpack[FlashAttentionKwargs], |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, self.head_dim) |
|
|
| query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) |
| key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) |
| value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
| cos, sin = position_embeddings |
| query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
| if past_key_values is not None: |
| key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) |
|
|
| attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( |
| self.config._attn_implementation, eager_attention_forward |
| ) |
|
|
| attn_output, attn_weights = attention_interface( |
| self, |
| query_states, |
| key_states, |
| value_states, |
| attention_mask, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| scaling=self.scaling, |
| **kwargs, |
| ) |
|
|
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| attn_output = self.o_proj(attn_output) |
| return attn_output, attn_weights |
|
|
|
|
| class MossVLTextCrossAttention(nn.Module): |
| """Cross attention - for vision-text interaction""" |
|
|
| def __init__(self, config: MossVLTextConfig, layer_idx: int): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) |
| self.num_heads = config.num_attention_heads |
| self.num_key_value_heads = config.num_key_value_heads |
| self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
| self.scaling = self.head_dim**-0.5 |
| self.attention_dropout = config.attention_dropout |
| self.is_causal = False |
|
|
| self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) |
| self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) |
| self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) |
| self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) |
|
|
| self.q_norm = MossVLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| self.k_norm = MossVLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| cross_attention_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| past_key_values: Optional[Cache] = None, |
| use_cache: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| query_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| **kwargs, |
| ) -> torch.Tensor: |
| batch_size, seq_length, _ = hidden_states.size() |
| |
| |
| query_states = self.q_proj(hidden_states) |
| query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) |
| query_states = self.q_norm(query_states) |
|
|
| if query_position_embeddings is not None: |
| cos, sin = query_position_embeddings |
| query_states = apply_rotary_pos_emb_cross_attention(query_states, cos, sin) |
| |
| if cross_attention_states is not None: |
| # Fresh vision states: project, normalize and rotate the cross-attention key/values, |
| # then (when a cache is passed) store them for reuse during decode. |
| key_states = self.k_proj(cross_attention_states) |
| value_states = self.v_proj(cross_attention_states) |
| key_states = key_states.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| key_states = self.k_norm(key_states) |
| value_states = value_states.view(batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| if vision_position_embeddings is not None: |
| vision_cos, vision_sin = vision_position_embeddings |
| key_states = apply_rotary_pos_emb_cross_attention(key_states, vision_cos, vision_sin) |
| if past_key_values is not None: |
| key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) |
| elif past_key_values is not None and cache_position is not None and cache_position[0] != 0: |
| # Decode stage without fresh vision states: reuse the key/values cached during prefill. |
| key_states, value_states = ( |
| past_key_values.layers[self.layer_idx].keys, |
| past_key_values.layers[self.layer_idx].values, |
| ) |
| else: |
| raise ValueError( |
| "Cross attention layer can find neither `cross_attention_states` nor cached key/value states!" |
| ) |
|
|
| if is_flash_attention_requested(self.config): |
| # The 4D float cross-attention mask used here is not supported by the flash path, so fall |
| # back to SDPA even when flash attention is requested for the rest of the model. |
| attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"] |
| else: |
| attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface( |
| self.config._attn_implementation, eager_attention_forward |
| ) |
|
|
| attn_output, attn_weights = attention_interface( |
| self, |
| query_states, |
| key_states, |
| value_states, |
| attention_mask, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| scaling=self.scaling, |
| **kwargs, |
| ) |
|
|
| attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() |
| attn_output = self.o_proj(attn_output) |
| return attn_output, attn_weights |
|
|
|
|
| class MossVLTextMLP(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
| self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
| self.act_fn = ACT2FN[config.hidden_act] |
|
|
| def forward(self, x): |
| down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
| return down_proj |
|
|
|
|
| class MossVLSelfAttentionDecoderLayer(GradientCheckpointingLayer): |
| """Self-attention decoder layer""" |
|
|
| def __init__(self, config: MossVLTextConfig, layer_idx: int): |
| super().__init__() |
| self.hidden_size = config.hidden_size |
| self.layer_idx = layer_idx |
|
|
| self.self_attn = MossVLTextSelfAttention(config=config, layer_idx=layer_idx) |
| self.mlp = MossVLTextMLP(config) |
| self.input_layernorm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.post_attention_layernorm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor] = None, |
| cross_attention_states: Optional[torch.Tensor] = None, |
| cross_attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| past_key_values: Optional[Cache] = None, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| vision_position_ids: Optional[torch.LongTensor] = None, |
| vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| output_attentions: bool = False, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> tuple[torch.Tensor, ...]: |
| |
| residual = hidden_states |
| hidden_states = self.input_layernorm(hidden_states) |
| hidden_states, attn_weights = self.self_attn( |
| hidden_states=hidden_states, |
| attention_mask=attention_mask, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| hidden_states = residual + hidden_states |
|
|
| |
| residual = hidden_states |
| hidden_states = self.post_attention_layernorm(hidden_states) |
| hidden_states = self.mlp(hidden_states) |
| hidden_states = residual + hidden_states |
|
|
| outputs = (hidden_states,) |
| if output_attentions: |
| outputs += (attn_weights,) |
| return outputs |
|
|
|
|
| class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer): |
| """Cross-attention decoder layer with tanh-gated attention and MLP""" |
|
|
| def __init__(self, config: MossVLTextConfig, layer_idx: int): |
| super().__init__() |
| self.hidden_size = config.hidden_size |
| self.layer_idx = layer_idx |
|
|
| self.cross_attn = MossVLTextCrossAttention(config=config, layer_idx=layer_idx) |
| self.mlp = MossVLTextMLP(config) |
| |
| self.input_layernorm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.post_attention_layernorm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| |
| |
| |
| self.cross_attn_attn_gate = nn.Parameter(torch.zeros(1)) |
| self.cross_attn_mlp_gate = nn.Parameter(torch.zeros(1)) |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor] = None, |
| cross_attention_states: Optional[torch.Tensor] = None, |
| cross_attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| past_key_values: Optional[Cache] = None, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| vision_position_ids: Optional[torch.LongTensor] = None, |
| vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| output_attentions: bool = False, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> tuple[torch.Tensor, ...]: |
| |
| residual = hidden_states |
| hidden_states = self.input_layernorm(hidden_states) |
| |
| hidden_states, attn_weights = self.cross_attn( |
| hidden_states=hidden_states, |
| cross_attention_states=cross_attention_states, |
| attention_mask=cross_attention_mask, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| query_position_embeddings=position_embeddings, |
| vision_position_embeddings=vision_position_embeddings, |
| ) |
| if full_text_row_masked_out_mask is not None: |
| hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states |
|
|
| hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states |
|
|
| |
| residual = hidden_states |
| hidden_states = self.post_attention_layernorm(hidden_states) |
| hidden_states = self.mlp(hidden_states) |
| if full_text_row_masked_out_mask is not None: |
| hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states |
|
|
| hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states |
|
|
| outputs = (hidden_states,) |
| if output_attentions: |
| outputs += (attn_weights,) |
| return outputs |
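| # Gating sketch: both gates are initialized to zero and tanh(0) == 0, so an untrained |
| # cross-attention layer is an identity over the text stream, e.g.: |
| # >>> residual, update = torch.randn(4), torch.randn(4) |
| # >>> torch.allclose(residual + torch.zeros(1).tanh() * update, residual) |
| # True |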
|
|
|
|
|
|
|
|
| @auto_docstring |
| class MossVLPreTrainedModel(PreTrainedModel): |
| config: MossVLConfig |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["MossVLSelfAttentionDecoderLayer", "MossVLCrossAttentionDecoderLayer", "MossVLVisionBlock"] |
| _skip_keys_device_placement = "past_key_values" |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": [MossVLSelfAttentionDecoderLayer, MossVLCrossAttentionDecoderLayer], |
| "attentions": [ |
| OutputRecorder(MossVLTextSelfAttention, index=1, layer_name="self_attn"), |
| OutputRecorder(MossVLTextCrossAttention, index=1, layer_name="cross_attn"), |
| ], |
| } |
|
|
| def _init_weights(self, module): |
| """Initialize the weights. |
| """ |
| super()._init_weights(module) |
| if isinstance(module, MossVLVisionRotaryEmbedding): |
| init.copy_(module.inv_freq, module.compute_inv_freq()) |
|
|
|
|
|
|
|
|
|
|
| class MossVLVisionModel(MossVLPreTrainedModel): |
| config: MossVLVisionConfig |
| _no_split_modules = ["MossVLVisionBlock"] |
|
|
| def __init__(self, config, *inputs, **kwargs) -> None: |
| super().__init__(config, *inputs, **kwargs) |
| self.spatial_merge_size = config.spatial_merge_size |
| self.patch_size = config.patch_size |
| self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size |
|
|
| self.patch_embed = MossVLVisionPatchEmbed(config=config) |
| self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) |
| self.num_grid_per_side = int(config.num_position_embeddings**0.5) |
|
|
| head_dim = config.hidden_size // config.num_heads |
| self.rotary_pos_emb = MossVLVisionRotaryEmbedding(head_dim // 2) |
|
|
| self.blocks = nn.ModuleList([MossVLVisionBlock(config) for _ in range(config.depth)]) |
| |
| |
| self.deepstack_visual_indexes = config.deepstack_visual_indexes |
| num_deepstack_features = len(self.deepstack_visual_indexes) |
| |
| |
| self.merger = MossVLVisionPatchMerger( |
| config=config, |
| num_deepstack_features=num_deepstack_features |
| ) |
|
|
| self.gradient_checkpointing = False |
|
|
| def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: |
| merge_size = self.spatial_merge_size |
| max_hw = int(grid_thw[:, 1:].max().item()) |
| freq_table = self.rotary_pos_emb(max_hw) |
| device = freq_table.device |
|
|
| total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) |
| pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) |
|
|
| offset = 0 |
| for num_frames, height, width in grid_thw: |
| merged_h, merged_w = height // merge_size, width // merge_size |
|
|
| block_rows = torch.arange(merged_h, device=device) |
| block_cols = torch.arange(merged_w, device=device) |
| intra_row = torch.arange(merge_size, device=device) |
| intra_col = torch.arange(merge_size, device=device) |
|
|
| row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] |
| col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] |
|
|
| row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) |
| col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) |
|
|
| coords = torch.stack((row_idx, col_idx), dim=-1) |
|
|
| if num_frames > 1: |
| coords = coords.repeat(num_frames, 1) |
|
|
| num_tokens = coords.shape[0] |
| pos_ids[offset : offset + num_tokens] = coords |
| offset += num_tokens |
|
|
| embeddings = freq_table[pos_ids] |
| embeddings = embeddings.flatten(1) |
| return embeddings |
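| # Ordering sketch: for a 4x4 grid with merge_size=2, (row, col) ids are emitted merge block by |
| # merge block: (0,0) (0,1) (1,0) (1,1), then (0,2) (0,3) (1,2) (1,3), ... so patches that will |
| # be merged together stay contiguous in the packed sequence. |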
|
|
| def fast_pos_embed_interpolate(self, grid_thw): |
| grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] |
| device = self.pos_embed.weight.device |
| dtype = self.pos_embed.weight.dtype |
|
|
| idx_parts = [[] for _ in range(4)] |
| weight_parts = [[] for _ in range(4)] |
|
|
| for t, h, w in zip(grid_ts, grid_hs, grid_ws): |
| h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device) |
| w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device) |
|
|
| h_idxs_floor = h_idxs.int() |
| w_idxs_floor = w_idxs.int() |
| h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) |
| w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) |
|
|
| dh = h_idxs - h_idxs_floor |
| dw = w_idxs - w_idxs_floor |
|
|
| base_h = h_idxs_floor * self.num_grid_per_side |
| base_h_ceil = h_idxs_ceil * self.num_grid_per_side |
|
|
| indices = [ |
| (base_h[None].T + w_idxs_floor[None]).flatten(), |
| (base_h[None].T + w_idxs_ceil[None]).flatten(), |
| (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), |
| (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), |
| ] |
|
|
| weights = [ |
| ((1 - dh)[None].T * (1 - dw)[None]).flatten(), |
| ((1 - dh)[None].T * dw[None]).flatten(), |
| (dh[None].T * (1 - dw)[None]).flatten(), |
| (dh[None].T * dw[None]).flatten(), |
| ] |
|
|
| for i in range(4): |
| idx_parts[i].append(indices[i]) |
| weight_parts[i].append(weights[i]) |
|
|
| idx_tensor = torch.stack([torch.cat(parts) for parts in idx_parts]).to(dtype=torch.long) |
| weight_tensor = torch.stack([torch.cat(parts) for parts in weight_parts]).to(dtype=dtype) |
| pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] |
| patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] |
|
|
| patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) |
|
|
| patch_pos_embeds_permute = [] |
| merge_size = self.config.spatial_merge_size |
| for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): |
| pos_embed = pos_embed.repeat(t, 1) |
| pos_embed = ( |
| pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) |
| .permute(0, 1, 3, 2, 4, 5) |
| .flatten(0, 4) |
| ) |
| patch_pos_embeds_permute.append(pos_embed) |
| patch_pos_embeds = torch.cat(patch_pos_embeds_permute) |
| return patch_pos_embeds |
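| # Interpolation sketch: each target grid point mixes its four nearest learned embeddings with |
| # bilinear weights; e.g. fractional offsets (dh, dw) = (0.25, 0.5) give weights |
| # 0.375, 0.375, 0.125, 0.125 for the floor/floor, floor/ceil, ceil/floor and ceil/ceil corners. |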
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| grid_thw: torch.Tensor, |
| **kwargs |
| ) -> torch.Tensor: |
| """ |
| Args: |
| hidden_states: input tensor |
| grid_thw: [num_images, 3] tensor with (t, h, w) for each image |
| Returns: |
| hidden_states: [num_tokens, out_hidden_size] - packed hidden states |
| """ |
| hidden_states = self.patch_embed(hidden_states) |
|
|
| pos_embeds = self.fast_pos_embed_interpolate(grid_thw) |
| hidden_states = hidden_states + pos_embeds |
|
|
| rotary_pos_emb = self.rot_pos_emb(grid_thw) |
|
|
| seq_len, _ = hidden_states.size() |
| hidden_states = hidden_states.reshape(seq_len, -1) |
| rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) |
| emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| position_embeddings = (emb.cos(), emb.sin()) |
|
|
| cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( |
| dim=0, |
| dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, |
| ) |
| cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) |
|
|
| |
| deepstack_features = [] |
| for layer_idx, blk in enumerate(self.blocks): |
| hidden_states = blk( |
| hidden_states, |
| cu_seqlens=cu_seqlens, |
| position_embeddings=position_embeddings, |
| **kwargs, |
| ) |
| |
| if layer_idx in self.deepstack_visual_indexes: |
| deepstack_features.append(hidden_states) |
|
|
| |
| hidden_states = self.merger(hidden_states, deepstack_features) |
|
|
| return hidden_states |
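| # Usage sketch (hypothetical 448x448 image, patch_size=14, spatial_merge_size=2): |
| # >>> grid_thw = torch.tensor([[1, 32, 32]])  # 1 frame, 32x32 patches |
| # >>> out = visual(pixel_values, grid_thw)  # pixel_values packed as (1024, C * T * 14 * 14) |
| # >>> out.shape  # (1024 / 2**2, out_hidden_size) = (256, out_hidden_size) |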
|
|
|
|
|
|
|
|
|
|
|
|
| @auto_docstring( |
| custom_intro=""" |
| The MossVL Text Model with self-attention and cross-attention layers for vision-language interaction. |
| """ |
| ) |
| class MossVLTextModel(MossVLPreTrainedModel): |
| config: MossVLTextConfig |
| _no_split_modules = ["MossVLSelfAttentionDecoderLayer", "MossVLCrossAttentionDecoderLayer"] |
|
|
|
|
| def __init__(self, config: MossVLTextConfig): |
| super().__init__(config) |
| self.padding_idx = config.pad_token_id |
| self.vocab_size = config.vocab_size |
|
|
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
| |
| |
| self.cross_attention_layers = config.cross_attention_layers |
| |
| |
| self.layers = nn.ModuleList() |
| for layer_idx in range(config.num_hidden_layers): |
| if layer_idx in config.cross_attention_layers: |
| |
| self.layers.append( |
| MossVLCrossAttentionDecoderLayer(config, layer_idx) |
| ) |
| else: |
| |
| self.layers.append( |
| MossVLSelfAttentionDecoderLayer(config, layer_idx) |
| ) |
| |
| self.norm = MossVLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.rotary_emb = MossVLTextRotaryEmbedding(config=config) |
| self.gradient_checkpointing = False |
|
|
| self.post_init() |
|
|
|
|
| @auto_docstring |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| cross_attention_states: Optional[torch.Tensor] = None, |
| cross_attention_mask: Optional[torch.Tensor] = None, |
| vision_position_ids: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| **kwargs: Unpack[FlashAttentionKwargs], |
| ) -> Union[tuple, BaseModelOutputWithPast]: |
| """ |
| Args: |
| full_text_row_masked_out_mask (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): |
| Mask for full text rows that should be masked out in attention computation. |
| cross_attention_states (`torch.Tensor`, *optional*): |
| Vision features to be used in cross-attention layers. Shape: `(batch_size, vision_seq_len, hidden_size)`. |
| cross_attention_mask (`torch.Tensor`, *optional*): |
| Attention mask for cross-attention between text and vision. Shape: `(batch_size, 1, text_seq_len, vision_seq_len)`. |
| vision_position_ids (`torch.LongTensor`, *optional*): |
| Position IDs for vision tokens used in cross-attention. Shape: `(batch_size, vision_seq_len)`. |
| cache_position (`torch.LongTensor`, *optional*): |
| Absolute cache positions for the current text tokens during incremental decoding. |
| """ |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
| if use_cache and past_key_values is None and not torch.jit.is_tracing(): |
| past_key_values = DynamicCache(config=self.config) |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids) |
|
|
| if cache_position is None: |
| past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| cache_position = torch.arange( |
| past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
| ) |
|
|
|
|
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
|
|
| attention_mask = create_causal_mask( |
| config=self.config, |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=past_key_values, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
|
|
| |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
| |
| vision_position_embeddings = None |
|
|
| if cross_attention_states is not None: |
| if vision_position_ids is not None: |
| vision_position_embeddings = self.rotary_emb(cross_attention_states, vision_position_ids) |
|
|
| all_hidden_states = () if output_hidden_states else None |
| all_attentions = () if output_attentions else None |
|
|
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| for idx, decoder_layer in enumerate(self.layers): |
| |
| |
| |
| is_cross_attention_layer = idx in self.cross_attention_layers |
| is_cross_attention_cache_empty = past_key_values is None or ( |
| past_key_values is not None and past_key_values.get_seq_length(idx) == 0 |
| ) |
|
|
| if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: |
| continue |
|
|
| layer_outputs = decoder_layer( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| full_text_row_masked_out_mask=full_text_row_masked_out_mask, |
| past_key_values=past_key_values, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| cross_attention_states=cross_attention_states, |
| cross_attention_mask=cross_attention_mask, |
| vision_position_ids=vision_position_ids, |
| vision_position_embeddings=vision_position_embeddings, |
| output_attentions=output_attentions, |
| **kwargs, |
| ) |
| hidden_states = layer_outputs[0] |
|
|
| if output_attentions: |
| all_attentions += (layer_outputs[1],) |
|
|
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| hidden_states = self.norm(hidden_states) |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states[:-1] + (hidden_states,) |
|
|
| if not return_dict: |
| outputs = (hidden_states, past_key_values) |
| if output_hidden_states: |
| outputs += (all_hidden_states,) |
| if output_attentions: |
| outputs += (all_attentions,) |
| return outputs |
|
|
| return BaseModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| past_key_values=past_key_values, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| ) |
|
|
|
|
| @auto_docstring( |
| custom_intro=""" |
| The MossVL model which consists of a vision encoder (from Qwen3VL) and a language model with cross-attention layers. |
| """ |
| ) |
| class MossVLModel(MossVLPreTrainedModel): |
| base_model_prefix = "" |
| config: MossVLConfig |
| _no_split_modules = ["MossVLSelfAttentionDecoderLayer", "MossVLCrossAttentionDecoderLayer", "MossVLVisionBlock"] |
| _checkpoint_conversion_mapping = {} |
| accepts_loss_kwargs = False |
| def __init__(self, config): |
| super().__init__(config) |
| self.visual = MossVLVisionModel._from_config(config.vision_config) |
| self.language_model = MossVLTextModel._from_config(config.text_config) |
|
|
| |
| |
| self.separator_token = nn.Parameter( |
| torch.zeros(config.vision_config.out_hidden_size) |
| ) |
|
|
| self.post_init() |
| |
|
|
|
|
| def convert_packed_to_batch( |
| self, |
| hidden_states: torch.Tensor, |
| grid_thw: torch.Tensor, |
| media_nums_per_sample: Optional[List[int]], |
| ) -> Tuple[torch.Tensor, List[dict]]: |
| """ |
| Convert packed vision tokens to batched format with separator tokens. |
| |
| For each image: inserts 1 separator token after the vision tokens. |
| For each video: inserts 1 separator token after EACH frame's vision tokens. |
| |
| Note: media_nums_per_sample counts each video as 1 media item, |
| but each frame in a video gets its own separator token. |
| """ |
| |
| |
| tokens_per_media = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]) // (self.visual.spatial_merge_size ** 2) |
| hidden_size = hidden_states.shape[-1] |
| |
| |
| if media_nums_per_sample is None: |
| batch_size = 1 |
| media_nums_per_sample = [grid_thw.shape[0]] |
| else: |
| batch_size = len(media_nums_per_sample) |
| |
| |
| if batch_size == 1: |
| |
| total_len = 0 |
| for i in range(grid_thw.shape[0]): |
| num_tokens = tokens_per_media[i].item() |
| num_frames = grid_thw[i, 0].item() |
| total_len += num_tokens + num_frames |
| |
| |
| pad_multiple = self.config.vision_seq_pad_multiple |
| if total_len % pad_multiple != 0: |
| max_seq_len = (total_len + pad_multiple - 1) // pad_multiple * pad_multiple |
| else: |
| max_seq_len = total_len |
| |
| |
| batched_hidden_states = torch.zeros( |
| 1, max_seq_len, hidden_size, |
| dtype=hidden_states.dtype, |
| device=hidden_states.device |
| ) |
| |
| |
| sample_info = { |
| 'medias': [], |
| 'total_length': total_len, |
| 'pad_start': total_len, |
| 'pad_end': max_seq_len |
| } |
| |
| token_offset = 0 |
| current_seq_len = 0 |
| separator_embedding = self.separator_token.to(hidden_states.dtype) |
| |
| |
| for media_idx in range(grid_thw.shape[0]): |
| num_tokens = tokens_per_media[media_idx].item() |
| t, h, w = grid_thw[media_idx].tolist() |
| num_frames = t |
| tokens_per_frame = num_tokens // num_frames |
| |
| |
| |
| media_vision_tokens = hidden_states[token_offset : token_offset + num_tokens] |
| |
| |
| media_vision_tokens = media_vision_tokens.view(num_frames, tokens_per_frame, hidden_size) |
| |
| |
| chunk_len = num_frames * (tokens_per_frame + 1) |
| |
| |
| target_view = batched_hidden_states[0, current_seq_len : current_seq_len + chunk_len] |
| target_view = target_view.view(num_frames, tokens_per_frame + 1, hidden_size) |
| |
| |
| target_view[:, :tokens_per_frame].copy_(media_vision_tokens) |
| |
| |
| |
| target_view[:, tokens_per_frame] = separator_embedding |
| |
| |
| |
| sample_info['medias'].append({ |
| 'start': current_seq_len, |
| 'end': current_seq_len + chunk_len, |
| 'length': chunk_len, |
| 'num_frames': num_frames, |
| 'grid_h': h, |
| 'grid_w': w, |
| 'vision_tokens_per_frame': tokens_per_frame, |
| 'has_separator': True, |
| }) |
| |
| current_seq_len += chunk_len |
| token_offset += num_tokens |
| |
| vision_token_info = [sample_info] |
| |
| return batched_hidden_states, vision_token_info |
|
|
| |
| |
| |
| tokens_per_sample = [] |
| media_idx = 0 |
| for num_medias_in_sample in media_nums_per_sample: |
| sample_tokens = 0 |
| for i in range(num_medias_in_sample): |
| num_tokens = tokens_per_media[media_idx + i].item() |
| num_frames = grid_thw[media_idx + i, 0].item() |
| sample_tokens += num_tokens + num_frames |
| tokens_per_sample.append(sample_tokens) |
| media_idx += num_medias_in_sample |
| |
| max_seq_len = max(tokens_per_sample) |
| pad_multiple = self.config.vision_seq_pad_multiple |
| if max_seq_len % pad_multiple != 0: |
| max_seq_len = (max_seq_len + pad_multiple - 1) // pad_multiple * pad_multiple |
| |
| |
| batched_hidden_states = torch.zeros( |
| batch_size, max_seq_len, hidden_size, |
| dtype=hidden_states.dtype, |
| device=hidden_states.device |
| ) |
| |
| |
| separator_embedding = self.separator_token.to(hidden_states.dtype) |
| |
| |
| vision_token_info = [] |
| |
| |
| token_offset = 0 |
| media_idx = 0 |
| |
| for sample_idx, num_medias_in_sample in enumerate(media_nums_per_sample): |
| sample_info = { |
| 'medias': [], |
| 'total_length': tokens_per_sample[sample_idx], |
| 'pad_start': tokens_per_sample[sample_idx], |
| 'pad_end': max_seq_len |
| } |
| |
| seq_offset = 0 |
| |
| |
| for _ in range(num_medias_in_sample): |
| num_tokens = tokens_per_media[media_idx].item() |
| |
| t, h, w = grid_thw[media_idx].tolist() |
| num_frames = t |
| tokens_per_frame = num_tokens // num_frames |
| |
| |
| media_start = seq_offset |
| |
| |
| |
| media_vision_tokens = hidden_states[token_offset : token_offset + num_tokens] |
| |
| |
| media_vision_tokens = media_vision_tokens.view(num_frames, tokens_per_frame, hidden_size) |
| |
| |
| separators = separator_embedding.view(1, 1, hidden_size).expand(num_frames, 1, hidden_size) |
| |
| |
| media_tokens_with_sep = torch.cat([media_vision_tokens, separators], dim=1) |
| |
| |
| media_tokens_with_sep = media_tokens_with_sep.view(-1, hidden_size) |
| |
| |
| media_length_with_sep = media_tokens_with_sep.shape[0] |
| batched_hidden_states[sample_idx, seq_offset : seq_offset + media_length_with_sep] = media_tokens_with_sep |
| |
| seq_offset += media_length_with_sep |
| |
| |
| media_length = num_tokens + num_frames |
| |
| |
| |
| sample_info['medias'].append({ |
| 'start': media_start, |
| 'end': media_start + media_length, |
| 'length': media_length, |
| 'num_frames': num_frames, |
| 'grid_h': h, |
| 'grid_w': w, |
| 'vision_tokens_per_frame': tokens_per_frame, |
| 'has_separator': True, |
| }) |
| |
| token_offset += num_tokens |
| media_idx += 1 |
| |
| vision_token_info.append(sample_info) |
| |
| return batched_hidden_states, vision_token_info |
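| # Layout sketch: for one video with grid_thw = [[4, 16, 16]] and spatial_merge_size = 2, each |
| # frame yields 16 * 16 / 4 = 64 merged tokens plus one separator, so the sample occupies |
| # 4 * 65 = 260 positions before zero-padding up to a multiple of config.vision_seq_pad_multiple. |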
|
|
| def get_input_embeddings(self): |
| return self.language_model.get_input_embeddings() |
|
|
| def set_input_embeddings(self, value): |
| self.language_model.set_input_embeddings(value) |
|
|
| def set_decoder(self, decoder): |
| self.language_model = decoder |
|
|
| def get_decoder(self): |
| return self.language_model |
|
|
| def _expand_cross_attention_mask( |
| self, |
| cross_attention_mask: torch.Tensor, |
| vision_token_info: List[dict], |
| target_dtype: torch.dtype, |
| ) -> torch.Tensor: |
| """ |
| Expand cross_attention_mask from (B, 1, T, N_frames) to (B, 1, T, N_tokens). |
| |
| Args: |
| cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, num_frames)`): |
| Coarse attention mask where each frame corresponds to one column. |
| Can be bool (True=masked) or float (min_value=masked). |
| vision_token_info (`List[dict]`): |
| Precomputed token info that includes actual token counts after ViT padding. |
| Must be provided (either from prefill computation or from cache). |
| Each dict contains 'medias' list with 'length', 'num_frames', and 'vision_tokens_per_frame'. |
| target_dtype (`torch.dtype`): |
| Target dtype for the output mask (typically inputs_embeds.dtype). |
| |
| Returns: |
| `torch.Tensor` of shape `(batch_size, 1, text_seq_len, total_vision_tokens)`: |
| Fine-grained attention mask where each vision token has its own column. |
| Masked positions have min_value, unmasked positions have 0.0. |
| |
| Note: |
| - vision_token_info contains the actual token counts after ViT padding (pad to multiple of 8) |
| - Separator tokens are treated as part of the same frame, sharing the same mask |
| """ |
| if vision_token_info is None: |
| raise ValueError( |
| "vision_token_info must be provided to _expand_cross_attention_mask. " |
| "This should be cached from prefill stage or computed during current forward pass." |
| ) |
| |
| batch_size = cross_attention_mask.shape[0] |
| |
| |
| max_vision_len = 0 |
| if vision_token_info: |
| max_vision_len = max([info.get('pad_end', 0) for info in vision_token_info]) |
|
|
| if max_vision_len == 0: |
| return None |
|
|
| |
| if cross_attention_mask.dtype == torch.bool: |
| |
| |
| min_value = torch.finfo(target_dtype).min |
| float_mask = torch.zeros_like(cross_attention_mask, dtype=target_dtype) |
| float_mask.masked_fill_(cross_attention_mask, min_value) |
| cross_attention_mask = float_mask |
| else: |
| |
| cross_attention_mask = cross_attention_mask.to(dtype=target_dtype) |
|
|
| |
| |
| min_dtype = torch.finfo(target_dtype).min |
| final_mask = torch.full( |
| (batch_size, 1, cross_attention_mask.shape[2], max_vision_len), |
| min_dtype, |
| dtype=target_dtype, |
| device=cross_attention_mask.device |
| ) |
| |
| for i in range(batch_size): |
| medias = vision_token_info[i]['medias'] |
| if not medias: |
| continue |
| |
| |
| repeats_parts = [] |
| for media in medias: |
| num_frames = media.get('num_frames', 1) |
| length = media['length'] |
| has_separator = media.get('has_separator', False) |
| |
| |
| if has_separator: |
| vision_tokens_per_frame = media.get('vision_tokens_per_frame', (length // num_frames) - 1) |
| tokens_per_frame_with_sep = vision_tokens_per_frame + 1 |
| else: |
| tokens_per_frame_with_sep = length // num_frames |
| |
| |
| |
| repeats_parts.append( |
| torch.full( |
| (num_frames,), |
| tokens_per_frame_with_sep, |
| dtype=torch.long, |
| device=cross_attention_mask.device, |
| ) |
| ) |
| |
| num_valid_frames = sum(part.numel() for part in repeats_parts) |
| if num_valid_frames == 0: |
| continue |
| |
| |
| |
| valid_mask_frames = min(num_valid_frames, cross_attention_mask.shape[-1]) |
| repeats_tensor = torch.cat(repeats_parts) |
| if valid_mask_frames < num_valid_frames: |
| repeats_tensor = repeats_tensor[:valid_mask_frames] |
| |
| |
| |
| source_mask = cross_attention_mask[i, :, :, :valid_mask_frames] |
| |
| |
| |
| expanded_mask = source_mask.repeat_interleave(repeats_tensor, dim=-1) |
| |
| |
| num_tokens = expanded_mask.shape[-1] |
| if num_tokens > max_vision_len: |
| num_tokens = max_vision_len |
| expanded_mask = expanded_mask[..., :num_tokens] |
| |
| final_mask[i, :, :, :num_tokens] = expanded_mask |
| |
| return final_mask |
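| # Expansion sketch: with two frames of 64 vision tokens + 1 separator each, every per-frame |
| # mask column is repeat-interleaved 65 times, so a coarse (B, 1, T, 2) mask becomes |
| # (B, 1, T, 130) before being written into the (B, 1, T, pad_end) output; padded tail |
| # positions stay at the dtype minimum, i.e. fully masked. |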
|
|
| def compute_position_ids( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| past_key_values: Optional[Cache] = None, |
| rope_deltas: Optional[torch.LongTensor] = None, |
| ) -> torch.Tensor: |
| """ |
| Compute 3D position IDs for text tokens with special handling for image tokens. |
| |
| Rules: |
| - Regular text tokens: increment position (x, x, x) -> (x+1, x+1, x+1) |
| - Image token: gets (t, t, t) where t = previous_text_position + 1 |
| - After processing vision tokens, the next text token starts at max(vision_bottom_right) + 1 |
| |
| In decode stage, uses cached rope_deltas to quickly compute new positions. |
| |
| Args: |
| input_ids: (batch_size, seq_len) |
| attention_mask: (batch_size, seq_len), optional |
| past_key_values: cache object used to infer decode offset from the current text cache length |
| |
| Returns: |
| position_ids: (3, batch_size, seq_len) |
| """ |
| batch_size, seq_len = input_ids.shape |
| device = input_ids.device |
| image_token_id = self.config.image_token_id |
| |
| |
| # Decode stage: positions continue from the cache length, and the cached |
| # rope_deltas restores the vision-induced offset without re-scanning the prompt. |
| past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| if past_seen_tokens > 0: |
| position_ids = torch.arange(seq_len, device=device, dtype=torch.long) |
| position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) |
| position_ids = position_ids + past_seen_tokens |
| if rope_deltas is not None: |
| position_ids = position_ids + rope_deltas.unsqueeze(1) |
| # Text tokens share the same value on all three RoPE axes (t, h, w). |
| return position_ids.unsqueeze(0).expand(3, -1, -1) |
| |
| |
| |
| |
| |
| is_image_token = (input_ids == image_token_id) |
| if attention_mask is not None: |
| is_padding = (attention_mask == 0) |
| else: |
| is_padding = torch.zeros_like(input_ids, dtype=torch.bool) |
| |
| is_regular_token = ~(is_image_token | is_padding) |
|
| # Prefill stage: positions come from counting regular (non-image, non-padding) |
| # tokens. The cumulative sum counts regular tokens up to and including each slot. |
| cumulative_regular = is_regular_token.long().cumsum(dim=1) |
|
| # Subtracting the indicator turns this into the count of regular tokens strictly |
| # before each slot, so an image token keeps the running count, i.e. |
| # previous_text_position + 1, matching the rule documented above. |
| base_position_ids = cumulative_regular - is_regular_token.long() |
|
| # Padding slots never participate in attention; zero them for cleanliness. |
| base_position_ids = base_position_ids.masked_fill(is_padding, 0) |
|
| # Broadcast to the three RoPE axes; clone so later in-place updates are safe. |
| position_ids = base_position_ids.unsqueeze(0).expand(3, -1, -1).clone() |
| |
| return position_ids |
|
|
| def compute_vision_position_ids( |
| self, |
| input_ids: torch.Tensor, |
| position_ids: torch.Tensor, |
| vision_token_info: List[dict], |
| cross_attention_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| """ |
| Compute 3D position IDs for vision tokens (including separator tokens) and update text position_ids. |
| Vectorized implementation for improved efficiency. |
| |
| Position encoding rules: |
| - For text: if not image token, increment position (t-1, t-1, t-1) -> (t, t, t) -> ... |
| - For vision: top-left is (t, t, t), increases towards bottom-right to (t, t+h-1, t+w-1) |
| - Separator Token after each frame: (x, x, x) where x = max(t+h-1, t+w-1) + 1 = max(t+h, t+w) |
| - Image token in text: also gets position (x, x, x) - same as separator |
| - Next text token after image: starts at (x+1, x+1, x+1) |
| |
| Args: |
| input_ids: (batch_size, seq_len) |
| position_ids: (3, batch_size, seq_len) - will be updated in place |
| vision_token_info: metadata about vision tokens (now includes separator positions) |
| cross_attention_states: (batch_size, max_vision_seq_len, hidden_size) |
| attention_mask: (batch_size, seq_len), optional |
| |
| Returns: |
| vision_pos_ids: (3, batch_size, max_vision_seq_len) |
| position_ids: (3, batch_size, seq_len) - updated |
| rope_deltas: (batch_size,) - position offset due to vision tokens |
| """ |
| batch_size, max_vision_seq_len, _ = cross_attention_states.shape |
| device = cross_attention_states.device |
| image_token_id = self.config.image_token_id |
| merge_size = self.visual.spatial_merge_size |
|
|
| |
| |
| |
| |
| # Locate every image placeholder token; nonzero() yields rows of (batch_idx, col_idx). |
| image_token_indices = (input_ids == image_token_id).nonzero() |
| |
| |
| |
| # Flatten per-media metadata into per-frame tensors: effective grid height and |
| # width (after spatial merging) plus each frame's start offset in the padded |
| # vision sequence. |
| flat_eff_h_parts = [] |
| flat_eff_w_parts = [] |
| flat_vis_start_parts = [] |
|
| for b_idx, info in enumerate(vision_token_info): |
| medias = info.get('medias', []) |
| for media in medias: |
| num_frames = media['num_frames'] |
| h, w = media['grid_h'], media['grid_w'] |
| eh, ew = h // merge_size, w // merge_size |
| start = media['start'] |
| tok_per_frame = media['vision_tokens_per_frame'] |
| # Each frame occupies its vision tokens plus one separator token. |
| stride = tok_per_frame + 1 |
|
| frame_offsets = start + torch.arange(num_frames, device=device, dtype=torch.long) * stride |
| flat_vis_start_parts.append(frame_offsets) |
| flat_eff_h_parts.append(torch.full((num_frames,), eh, device=device, dtype=torch.long)) |
| flat_eff_w_parts.append(torch.full((num_frames,), ew, device=device, dtype=torch.long)) |
|
|
| |
| vision_pos_ids = torch.zeros( |
| (3, batch_size, max_vision_seq_len), |
| dtype=torch.long, |
| device=device, |
| ) |
|
| # No vision frames or no image placeholders: positions stay purely textual. |
| if len(flat_eff_h_parts) == 0 or len(image_token_indices) == 0: |
| rope_deltas = position_ids.max(dim=0).values.max(dim=-1).values + 1 - input_ids.shape[1] |
| return vision_pos_ids, position_ids, rope_deltas |
|
|
| flat_eff_h = torch.cat(flat_eff_h_parts) |
| flat_eff_w = torch.cat(flat_eff_w_parts) |
| flat_vis_starts = torch.cat(flat_vis_start_parts) |
|
| # Pair frames with image placeholder tokens one-to-one, truncating to the |
| # shorter side if the counts ever disagree. |
| num_matches = min(flat_eff_h.shape[0], image_token_indices.shape[0]) |
| flat_eff_h = flat_eff_h[:num_matches] |
| flat_eff_w = flat_eff_w[:num_matches] |
| flat_vis_starts = flat_vis_starts[:num_matches] |
|
| target_indices = image_token_indices[:num_matches] |
| batch_rows = target_indices[:, 0] |
| text_cols = target_indices[:, 1] |
|
| # Each frame advances the text position counter by max(eh, ew) + 1: the grid |
| # occupies positions t .. t + max(eh, ew) - 1 and the separator takes |
| # t + max(eh, ew). |
| max_hw = torch.maximum(flat_eff_h, flat_eff_w) |
| shifts = max_hw + 1 |
|
| # Scatter the per-frame shifts onto their text columns, then prefix-sum so |
| # every token knows the total shift accumulated by earlier frames. |
| shift_map = torch.zeros((batch_size, input_ids.shape[1]), dtype=torch.long, device=device) |
| shift_map[batch_rows, text_cols] = shifts |
| cum_shifts = shift_map.cumsum(dim=1) |
|
| # t for each frame = original text position + shifts from the frames before it. |
| orig_pos = position_ids[0, batch_rows, text_cols] |
| shifts_before = cum_shifts[batch_rows, text_cols] - shifts |
| t_vals = orig_pos + shifts_before |
| |
| # Shift all text positions by the accumulated vision offsets. |
| new_pos_ids = position_ids + cum_shifts.unsqueeze(0) |
|
| # An image token's own cumulative shift already includes its frame's |
| # max(eh, ew) + 1; subtracting 1 places it at x = t + max(eh, ew), the same |
| # position as the frame's separator token. |
| img_token_mask = torch.zeros_like(input_ids, dtype=torch.bool) |
| img_token_mask[batch_rows, text_cols] = True |
| new_pos_ids[:, img_token_mask] -= 1 |
|
| if attention_mask is not None: |
| padding_mask = (attention_mask == 0).unsqueeze(0) |
| new_pos_ids.masked_fill_(padding_mask, 0) |
|
| # Write back in place so the caller's position_ids tensor is updated. |
| position_ids.copy_(new_pos_ids) |
| |
| |
| |
| |
| # Fill in the vision grid positions, grouping frames that share the same |
| # (eh, ew) shape so each group is handled with one broadcasted grid. |
| unique_shapes = torch.unique(torch.stack([flat_eff_h, flat_eff_w], dim=1), dim=0) |
|
| for shape in unique_shapes: |
| eh, ew = shape[0].item(), shape[1].item() |
| mask = (flat_eff_h == eh) & (flat_eff_w == ew) |
|
| sub_t_vals = t_vals[mask] |
| sub_batch_rows = batch_rows[mask] |
| sub_vis_starts = flat_vis_starts[mask] |
|
| num_frames_sub = sub_t_vals.shape[0] |
| if num_frames_sub == 0: |
| continue |
| |
| # Build (t, h, w) grids for every frame of this shape at once: the temporal |
| # value is constant within a frame while h and w grow toward the bottom-right. |
| y_grid = torch.arange(eh, device=device).view(1, eh, 1).expand(num_frames_sub, -1, ew) |
| x_grid = torch.arange(ew, device=device).view(1, 1, ew).expand(num_frames_sub, eh, -1) |
| t_grid = sub_t_vals.view(-1, 1, 1).expand(-1, eh, ew) |
|
| h_grid = t_grid + y_grid |
| w_grid = t_grid + x_grid |
|
| flat_t = t_grid.reshape(-1) |
| flat_h = h_grid.reshape(-1) |
| flat_w = w_grid.reshape(-1) |
|
| # Map every grid cell to its absolute slot in the padded vision sequence. |
| tokens_per_frame = eh * ew |
| seq_offsets = torch.arange(tokens_per_frame, device=device).unsqueeze(0) |
| abs_seq_offsets = seq_offsets + sub_vis_starts.unsqueeze(1) |
|
| flat_seq_inds = abs_seq_offsets.reshape(-1) |
| flat_batch_inds = sub_batch_rows.unsqueeze(1).expand(-1, tokens_per_frame).reshape(-1) |
|
| # Drop indices that fall beyond the padded length before scattering. |
| valid_mask = flat_seq_inds < max_vision_seq_len |
|
| if valid_mask.any(): |
| final_b = flat_batch_inds[valid_mask] |
| final_s = flat_seq_inds[valid_mask] |
|
| vision_pos_ids[0, final_b, final_s] = flat_t[valid_mask] |
| vision_pos_ids[1, final_b, final_s] = flat_h[valid_mask] |
| vision_pos_ids[2, final_b, final_s] = flat_w[valid_mask] |
| |
| |
| |
| # Separator tokens sit immediately after each frame's grid and take the |
| # scalar position x = t + max(eh, ew) on all three axes. |
| sep_vals = t_vals + max_hw |
| sep_indices = flat_vis_starts + (flat_eff_h * flat_eff_w) |
| valid_sep_mask = sep_indices < max_vision_seq_len |
|
| if valid_sep_mask.any(): |
| final_b = batch_rows[valid_sep_mask] |
| final_s = sep_indices[valid_sep_mask] |
| vals = sep_vals[valid_sep_mask] |
|
| vision_pos_ids[0, final_b, final_s] = vals |
| vision_pos_ids[1, final_b, final_s] = vals |
| vision_pos_ids[2, final_b, final_s] = vals |
| |
|
| # rope_deltas lets the decode stage rebuild positions from the cache length |
| # alone: next_position = cache_length + rope_deltas = max_position + 1. |
| max_pos = position_ids.max(dim=0).values.max(dim=-1).values |
| rope_deltas = max_pos + 1 - input_ids.shape[1] |
| |
| return vision_pos_ids, position_ids, rope_deltas |
|
|
| def get_vision_features( |
| self, |
| pixel_values: torch.FloatTensor, |
| grid_thw: Optional[torch.LongTensor] = None, |
| media_nums_per_sample: Optional[List[int]] = None |
| ): |
| """ |
| Args: |
| pixel_values: vision pixel values (images and videos merged) |
| grid_thw: [num_media, 3] tensor with (t, h, w) for each media item |
| media_nums_per_sample: List indicating how many media items each sample has |
| Returns: |
| vision_embeds: [batch_size, max_seq_len, hidden_size] |
| vision_token_info: List[Dict] with media positions and padding info for each sample |
| """ |
| pixel_values = pixel_values.type(self.visual.dtype) |
| hidden_states = self.visual( |
| pixel_values, |
| grid_thw=grid_thw |
| ) |
| vision_embeds, vision_token_info = self.convert_packed_to_batch( |
| hidden_states, |
| grid_thw, |
| media_nums_per_sample |
| ) |
| return vision_embeds, vision_token_info |
|
|
|
|
|
|
| @auto_docstring |
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| grid_thw: Optional[torch.LongTensor] = None, |
| media_nums_per_sample: Optional[List[int]] = None, |
| vision_position_ids: Optional[torch.LongTensor] = None, |
| cross_attention_mask: Optional[torch.Tensor] = None, |
| vision_token_info: Optional[List[dict]] = None, |
| rope_deltas: Optional[torch.LongTensor] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> Union[tuple, MossVLModelOutputWithPast]: |
| """ |
| Args: |
| grid_thw (`torch.LongTensor` of shape `(num_media, 3)`, *optional*): |
| Grid size for each media item in (temporal, height, width) format. Each row contains `[t, h, w]` |
| representing the number of temporal, height, and width patches for a media item (image or video). |
| media_nums_per_sample (`List[int]`, *optional*): |
| List indicating how many media items each sample in the batch has. For example, `[2, 1, 3]` means |
| the first sample has 2 media items, the second has 1, and the third has 3. |
| vision_position_ids (`torch.LongTensor` of shape `(3, batch_size, vision_seq_len)`, *optional*): |
| Position IDs for vision tokens used in cross-attention. These are computed from text position IDs |
| based on the positions of image/video tokens in the input text. |
| cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*): |
| Attention mask for cross-attention between text and vision. Controls which vision tokens each text |
| token can attend to, enforcing causal visibility for video frames. |
| vision_token_info (`List[dict]`, *optional*): |
| Cached metadata describing how packed vision tokens were regrouped per sample. Reused in decode |
| to expand frame-level cross-attention masks to token-level masks without recomputing vision features. |
| rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Cached offsets between text sequence length and multimodal RoPE positions. Reused in decode to |
| reconstruct text position ids from the current cache length. |
| """ |
| cache_position = kwargs.pop("cache_position", None) |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.get_input_embeddings()(input_ids) |
|
|
| |
| cross_attention_states = None |
|
|
| if pixel_values is not None: |
| |
| batch_size = inputs_embeds.shape[0] |
| |
| |
| if media_nums_per_sample is None: |
| # With a single sample, every media item necessarily belongs to it. |
| if batch_size == 1: |
| media_nums_per_sample = [grid_thw.shape[0]] |
| else: |
| raise ValueError("media_nums_per_sample must be provided when batch_size > 1") |
| |
| |
| |
| vision_embeds, vision_token_info = self.get_vision_features( |
| pixel_values, grid_thw, media_nums_per_sample |
| ) |
| |
| |
| cross_attention_states = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) |
|
|
| |
| if position_ids is None: |
| |
| |
| position_ids = self.compute_position_ids( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| past_key_values=past_key_values, |
| rope_deltas=rope_deltas, |
| ) |
|
|
| |
| full_text_row_masked_out_mask = None |
|
| if cross_attention_mask is not None: |
| # Expand the frame-level mask to token level using the cached vision metadata. |
| cross_attention_mask = self._expand_cross_attention_mask( |
| cross_attention_mask, |
| vision_token_info, |
| target_dtype=inputs_embeds.dtype, |
| ) |
|
| # Text rows that cannot attend to any vision token are fully masked out; |
| # record them so cross-attention outputs can be zeroed for those rows. |
| if cross_attention_mask is not None: |
| negative_inf_value = torch.finfo(cross_attention_mask.dtype).min |
| full_text_row_masked_out_mask = ( |
| (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] |
| ) |
| cross_attention_mask = cross_attention_mask * full_text_row_masked_out_mask |
|
|
| if vision_position_ids is None and cross_attention_states is not None and input_ids is not None: |
| vision_position_ids, position_ids, rope_deltas = self.compute_vision_position_ids( |
| input_ids, |
| position_ids, |
| vision_token_info, |
| cross_attention_states, |
| attention_mask |
| ) |
|
|
| outputs = self.language_model( |
| input_ids=None, |
| position_ids=position_ids, |
| attention_mask=attention_mask, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| cache_position=cache_position, |
| cross_attention_states=cross_attention_states, |
| cross_attention_mask=cross_attention_mask, |
| vision_position_ids=vision_position_ids, |
| full_text_row_masked_out_mask=full_text_row_masked_out_mask, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| **kwargs, |
| ) |
|
|
| if not return_dict: |
| # Tuple layout: (last_hidden_state, past_kv[, hidden_states][, attentions], |
| # vision_token_info, rope_deltas). |
| last_hidden_state = outputs[0] |
| model_outputs = ( |
| last_hidden_state, |
| outputs[1] if len(outputs) > 1 else past_key_values, |
| ) |
| if output_hidden_states: |
| model_outputs += (outputs[2],) |
| if output_attentions: |
| attn_idx = 3 if output_hidden_states else 2 |
| model_outputs += (outputs[attn_idx],) |
| model_outputs += (vision_token_info, rope_deltas) |
| return model_outputs |
|
|
| return MossVLModelOutputWithPast( |
| last_hidden_state=outputs.last_hidden_state, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| vision_token_info=vision_token_info, |
| rope_deltas=rope_deltas, |
| ) |
|
|
|
|
|
|
|
|
|
|
| @auto_docstring( |
| custom_intro=""" |
| The MossVL model with a language modeling head on top, for conditional generation tasks. |
| Combines Qwen3VL vision encoder with LLM via cross-attention layers. |
| """ |
| ) |
| class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin): |
| |
| |
| |
| |
| _tied_weights_keys: dict[str, str] = {} |
| config: MossVLConfig |
| _checkpoint_conversion_mapping = {} |
| accepts_loss_kwargs = False |
|
| def __init__(self, config): |
| super().__init__(config) |
| self.model = MossVLModel(config) |
| self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) |
|
|
| self.post_init() |
|
|
| def get_input_embeddings(self): |
| return self.model.get_input_embeddings() |
|
|
| def set_input_embeddings(self, value): |
| self.model.set_input_embeddings(value) |
|
|
| def set_decoder(self, decoder): |
| self.model.set_decoder(decoder) |
|
|
| def get_decoder(self): |
| return self.model.get_decoder() |
|
|
|
|
| def get_vision_features( |
| self, |
| pixel_values: torch.FloatTensor, |
| grid_thw: Optional[torch.LongTensor] = None, |
| media_nums_per_sample: Optional[List[int]] = None |
| ): |
| """ |
| Get vision features for images and videos (merged). |
| |
| Args: |
| pixel_values: vision pixel values (images and videos merged) |
| grid_thw: [num_media, 3] tensor with (t, h, w) for each media item |
| media_nums_per_sample: List indicating how many media items each sample has |
| Returns: |
| vision_embeds: [batch_size, max_seq_len, hidden_size] |
| vision_token_info: List[Dict] with media positions and padding info for each sample |
| """ |
| return self.model.get_vision_features(pixel_values, grid_thw, media_nums_per_sample) |
|
|
| @property |
| def language_model(self): |
| return self.model.language_model |
|
|
| @property |
| def visual(self): |
| return self.model.visual |
|
|
| @auto_docstring |
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| grid_thw: Optional[torch.LongTensor] = None, |
| media_nums_per_sample: Optional[List[int]] = None, |
| vision_position_ids: Optional[torch.LongTensor] = None, |
| cross_attention_mask: Optional[torch.Tensor] = None, |
| vision_token_info: Optional[List[dict]] = None, |
| rope_deltas: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> Union[tuple, MossVLCausalLMOutputWithPast]: |
| """ |
| Args: |
| grid_thw (`torch.LongTensor` of shape `(num_media, 3)`, *optional*): |
| Grid size for each media item in (temporal, height, width) format. Each row contains `[t, h, w]` |
| representing the number of temporal, height, and width patches for a media item (image or video). |
| media_nums_per_sample (`List[int]`, *optional*): |
| List indicating how many media items each sample in the batch has. For example, `[2, 1, 3]` means |
| the first sample has 2 media items, the second has 1, and the third has 3. |
| vision_position_ids (`torch.LongTensor` of shape `(3, batch_size, vision_seq_len)`, *optional*): |
| Position IDs for vision tokens used in cross-attention. These are computed from text position IDs |
| based on the positions of image/video tokens in the input text. |
| cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*): |
| Attention mask for cross-attention between text and vision. Controls which vision tokens each text |
| token can attend to, enforcing causal visibility for video frames. |
| vision_token_info (`List[dict]`, *optional*): |
| Cached metadata describing how packed vision tokens were regrouped per sample. Reused across decode |
| steps to expand cross-attention masks without re-running the vision encoder. |
| rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Cached multimodal RoPE offsets returned by the base model during prefill and reused during decode. |
| """ |
| cache_position = kwargs.pop("cache_position", None) |
| outputs = self.model( |
| input_ids=input_ids, |
| pixel_values=pixel_values, |
| grid_thw=grid_thw, |
| media_nums_per_sample=media_nums_per_sample, |
| position_ids=position_ids, |
| attention_mask=attention_mask, |
| vision_position_ids=vision_position_ids, |
| cross_attention_mask=cross_attention_mask, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| vision_token_info=vision_token_info, |
| rope_deltas=rope_deltas, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
|
|
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state |
|
|
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
|
|
| loss = None |
| if labels is not None: |
| loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) |
|
|
| if not return_dict: |
| output = (logits,) |
| output += outputs[1:] |
| return ((loss,) + output) if loss is not None else output |
|
|
| return MossVLCausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| vision_token_info=outputs.vision_token_info, |
| rope_deltas=outputs.rope_deltas, |
| ) |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids, |
| past_key_values=None, |
| attention_mask=None, |
| inputs_embeds=None, |
| position_ids=None, |
| use_cache=True, |
| pixel_values=None, |
| grid_thw=None, |
| media_nums_per_sample=None, |
| vision_position_ids=None, |
| vision_token_info=None, |
| rope_deltas=None, |
| cross_attention_mask=None, |
| **kwargs, |
| ): |
| """ |
| Prepare inputs for generation. |
| |
| Note: Currently only supports offline visual understanding, meaning all multimodal |
| content must be provided before generation starts. We don't support adding new |
| images/videos during generation (streaming mode). |
| |
| Args: |
| media_nums_per_sample: One video counts as one media item (regardless of frame count) |
| """ |
| kwargs.pop("cache_position", None) |
| model_inputs = super().prepare_inputs_for_generation( |
| input_ids, |
| past_key_values=past_key_values, |
| attention_mask=attention_mask, |
| inputs_embeds=inputs_embeds, |
| position_ids=position_ids, |
| pixel_values=pixel_values, |
| grid_thw=grid_thw, |
| media_nums_per_sample=media_nums_per_sample, |
| use_cache=use_cache, |
| **kwargs, |
| ) |
|
|
| # Length of the chunk fed this step: the full prompt at prefill, one token at decode. |
| model_input = model_inputs.get("input_ids") |
| if model_input is None: |
| model_input = model_inputs.get("inputs_embeds") |
| current_length = model_input.shape[1] |
| past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
|
| |
| model_inputs["position_ids"] = None |
| model_inputs["vision_token_info"] = vision_token_info |
| model_inputs["rope_deltas"] = rope_deltas |
| |
| |
| if cross_attention_mask is not None: |
| |
| |
| cross_attention_mask = cross_attention_mask[:, :, -current_length:, :] |
| model_inputs["cross_attention_mask"] = cross_attention_mask |
|
|
| |
| |
| # After prefill the vision encoder never runs again (offline mode), so decode |
| # steps drop all vision inputs and skip feature extraction entirely. |
| if past_seen_tokens > 0: |
| model_inputs["pixel_values"] = None |
| model_inputs["grid_thw"] = None |
| model_inputs["media_nums_per_sample"] = None |
| model_inputs["vision_position_ids"] = None |
| else: |
| model_inputs["vision_position_ids"] = vision_position_ids |
|
|
| return model_inputs |
|
|
| def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): |
| """ |
| Update model kwargs for generation, preserving cross_attention_mask so it can be reused for the newly generated token. |
| |
| In offline mode (all multimodal content provided before generation): |
| - Each newly generated token should have the same cross_attention_mask pattern as the previous token |
| - This ensures all generated tokens can attend to all vision tokens that were visible before |
| """ |
| cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) |
| |
| model_kwargs = super()._update_model_kwargs_for_generation( |
| outputs=outputs, |
| model_kwargs=model_kwargs, |
| is_encoder_decoder=is_encoder_decoder, |
| **kwargs, |
| ) |
| |
| # Restore the mask the parent implementation may have dropped; |
| # prepare_inputs_for_generation re-slices it to the current step length. |
| if cross_attention_mask_prev is not None: |
| model_kwargs["cross_attention_mask"] = cross_attention_mask_prev |
|
| # Propagate cached multimodal metadata to the next generation step. |
| if getattr(outputs, "vision_token_info", None) is not None: |
| model_kwargs["vision_token_info"] = outputs.vision_token_info |
| if getattr(outputs, "rope_deltas", None) is not None: |
| model_kwargs["rope_deltas"] = outputs.rope_deltas |
| |
| return model_kwargs |
|
|
|
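| # Usage sketch (illustrative; the checkpoint name and preprocessed inputs are |
| # assumptions - they would normally come from the matching processor): |
| # |
| # model = MossVLForConditionalGeneration.from_pretrained("org/moss-vl") |
| # out = model.generate( |
| # input_ids=input_ids,  # prompt containing image placeholder tokens |
| # attention_mask=attention_mask, |
| # pixel_values=pixel_values,  # packed image/video patches |
| # grid_thw=grid_thw,  # (num_media, 3) patch grids |
| # media_nums_per_sample=[1],  # one media item for the single sample |
| # max_new_tokens=64, |
| # ) |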
|
| __all__ = [ |
| "MossVLVisionModel", |
| "MossVLForConditionalGeneration", |
| "MossVLModel", |
| "MossVLPreTrainedModel", |
| "MossVLTextModel", |
| ] |
|
|