# step3 / modeling_step3.py
# buyun's picture
# Add files using upload-large-folder tool
# 4ad74b5 verified
# coding=utf-8
# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import sqrt
from dataclasses import dataclass
from typing import Callable, Optional, Union, Tuple
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask, create_chunked_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, is_torch_accelerator_available
from typing import Any, Literal, Optional, TypedDict, Union
from .configuration_step3 import Step3VLConfig,Step3TextConfig,Step3VisionEncoderConfig
from .vision_encoder import StepCLIPVisionTransformer
logger = logging.get_logger(__name__)
import torch
from typing import Optional
class Step3VLImagePixelInputs(TypedDict):
    """Image input given as raw pixel tensors (the ``"pixel_values"`` variant)."""

    # Discriminator tag; always the literal "pixel_values" for this variant.
    type: Literal["pixel_values"]
    # Pixel tensor for the images themselves.
    pixel_values: torch.Tensor
    # Pixel tensor for the additional patches, or None when there are none.
    patch_pixel_values: Optional[torch.Tensor]
    # Patch count per image; `_process_image_input` walks this list in
    # lockstep with the per-image features.
    num_patches: list[int]
class Step3VLImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings (the ``"image_embeds"`` variant)."""

    # Discriminator tag; always the literal "image_embeds" for this variant.
    type: Literal["image_embeds"]
    # Precomputed image embedding tensor.
    image_embeds: torch.Tensor
# Tagged union of the two image-input variants; the "type" key tells them apart.
Step3VLImageInputs = Union[Step3VLImagePixelInputs,
                           Step3VLImageEmbeddingInputs]
def _flatten_embeddings(embeddings) -> torch.Tensor:
"""
Recursively flattens and concatenates NestedTensors on all but the last
dimension.
"""
if isinstance(embeddings, torch.Tensor):
# Flatten all but the last dimension.
return embeddings.flatten(0, -2)
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
def _embedding_count_expression(embeddings) -> str:
"""
Constructs a debugging representation of the number of embeddings in the
NestedTensors.
"""
if isinstance(embeddings, torch.Tensor):
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
return " + ".join(
_embedding_count_expression(inner) for inner in embeddings)
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings,
) -> torch.Tensor:
    """
    Overwrite the rows of ``inputs_embeds`` selected by ``is_multimodal`` with
    the flattened ``multimodal_embeddings``.

    Args:
        inputs_embeds: Embedding tensor to update.
        is_multimodal: Mask over the token positions of ``inputs_embeds``
            marking the placeholder positions to overwrite.
        multimodal_embeddings: A tensor, or a (nested) collection of tensors,
            holding one embedding per selected position, in order.

    Returns:
        ``inputs_embeds`` itself, after the update.

    Raises:
        TypeError: If the number of selected positions is not an integer
            (i.e. the mask has a non-integral dtype).
        ValueError: If the number of supplied embeddings differs from the
            number of selected positions.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    num_expected_tokens = is_multimodal.sum().item()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(num_expected_tokens, int):
        raise TypeError(
            "is_multimodal must be a boolean/integer mask, got dtype "
            f"{is_multimodal.dtype}")

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    is_multimodal = is_multimodal.to(inputs_embeds.device)
    # Masked assignment requires source and destination dtypes to match, so
    # cast to the destination dtype as well as moving to its device.
    flattened = flattened.to(device=inputs_embeds.device,
                             dtype=inputs_embeds.dtype)
    inputs_embeds[is_multimodal] = flattened
    return inputs_embeds
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings,
    placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
    """
    Write ``multimodal_embeddings`` into ``inputs_embeds`` at every position
    whose entry in ``input_ids`` is a placeholder token.

    ``placeholder_token_id`` is either a single token id or a list of ids
    (for example the ids of the image-start, image-break and image-end
    tokens). With a list, the embeddings are merged as one ordered slice
    rather than scattered per token, so the order of the placeholder tokens
    in ``input_ids`` MUST MATCH the order of their embeddings in
    ``multimodal_embeddings``.

    Example: for input_ids "TTTTTSIIIBIIIBIIIETTT", where

    - T is a text token
    - S is the image start token
    - I is an image embedding token
    - B is the image break token
    - E is the image end token,

    the image embeddings coming from the vision encoder (the I's) have to be
    interleaved with the embeddings of S, B and E in exactly the order they
    occur in input_ids for the merge to be correct.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if isinstance(placeholder_token_id, list):
        candidate_ids = torch.tensor(placeholder_token_id,
                                     device=input_ids.device)
        placeholder_mask = torch.isin(input_ids, candidate_ids)
    else:
        placeholder_mask = input_ids == placeholder_token_id

    return _merge_multimodal_embeddings(
        inputs_embeds,
        placeholder_mask,
        multimodal_embeddings,
    )
def rotate_half(x):
    """Split the last dimension in two halves and return ``cat(-second, first)``."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with Rotary Position Embedding.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so that they broadcast
            against `q` and `k`. With `cos`/`sin` shaped
            [batch_size, seq_len, head_dim], use 1 when `q` and `k` are
            [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)` holding the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for queries and keys.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head dimension.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the
    input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). For ``n_rep == 1``
    the input tensor is returned unchanged.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis next to the head axis, broadcast it, then fold it
    # into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain matmul/softmax attention.

    Keys and values are first repeated ``module.num_key_value_groups`` times
    along the head axis. The scaled scores get the (sliced) additive
    ``attention_mask`` if one is given, then softmax and dropout (active only
    when ``module.training``).

    Returns:
        ``(attn_output, attn_weights)``, where ``attn_output`` has its head
        and sequence axes swapped relative to the query layout and is
        contiguous.
    """
    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Trim the mask's last axis to the number of key positions.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    attn_weights = F.softmax(scores, dim=-1)
    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, values).transpose(1, 2).contiguous()
    return attn_output, attn_weights
@dataclass
class Step3vCausalLMOutputWithPast(ModelOutput):
    r"""
    Output container for the Step3 vision-language causal LM.

    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    last_hidden_state (`torch.FloatTensor`, *optional*):
        NOTE(review): presumably the final hidden states of the language model; the code that fills this field
        is not part of this class - confirm against the caller.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        NOTE(review): presumably the per-layer hidden states - confirm against the caller.
    attentions (`tuple(torch.FloatTensor)`, *optional*):
        NOTE(review): presumably the per-layer attention weights - confirm against the caller.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """
    loss: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
class Step3vRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-feature scale."""

    def __init__(self, hidden_size, eps=1e-5):
        """
        Step3vRMSNorm is equivalent to T5LayerNorm.

        Args:
            hidden_size: Size of the normalised (last) dimension.
            eps: Constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        # Divide by the RMS over the last dimension.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then scale.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"
class Step3vRotaryEmbedding(nn.Module):
    """Produces the rotary cos/sin tables for a batch of position ids."""

    def __init__(self, config: Step3VLConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        else:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        self.max_seq_len_cached = config.max_position_embedding
        self.original_max_seq_len = config.max_position_embedding

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer is not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        ``x`` is only consulted for its device and dtype.
        """
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq @ positions).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Step3vMLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``, no biases."""

    def __init__(self, config, intermediate_size=None):
        """
        Args:
            config: Provides ``hidden_size`` and the default
                ``intermediate_size``.
            intermediate_size: Overrides ``config.intermediate_size`` when
                not None.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        if intermediate_size is None:
            intermediate_size = config.intermediate_size
        self.intermediate_size = intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class MoELinear(nn.Module):
    """Bank of bias-free linear layers, one weight matrix per expert.

    ``weight`` has shape ``(num_experts, out_features, in_features)`` and is
    allocated with ``torch.empty`` (its values are not initialised here).
    """

    def __init__(self, num_experts, in_features, out_features):
        super().__init__()
        self.num_experts = num_experts
        self.in_features = in_features
        self.out_features = out_features
        weight_shape = (num_experts, out_features, in_features)
        self.weight = nn.Parameter(torch.empty(*weight_shape))

    def forward(self, x, expert_id):
        """Apply expert ``expert_id``'s weight matrix to ``x``."""
        return F.linear(x, self.weight[expert_id])
class Step3vMoEMLP(nn.Module):
    """Sparse mixture-of-experts feed-forward block with top-k routing.

    A linear router scores every token against every expert. The ``top_k``
    highest softmax probabilities pick the experts and serve, unchanged, as
    the mixing weights for their outputs.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.moe_num_experts
        self.top_k = config.moe_top_k
        self.hidden_size = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size

        # gating
        self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)

        self.up_proj = MoELinear(self.num_experts, self.hidden_size, self.moe_intermediate_size)
        self.gate_proj = MoELinear(self.num_experts, self.hidden_size, self.moe_intermediate_size)
        self.down_proj = MoELinear(self.num_experts, self.moe_intermediate_size, self.hidden_size)

        self.act_fn = ACT2FN["silu"]

    def get_expert_output(self, inputs: torch.Tensor, expert_id):
        """Run the gated MLP of expert ``expert_id`` on ``inputs``."""
        gated = self.act_fn(self.gate_proj(inputs, expert_id))
        lifted = self.up_proj(inputs, expert_id)
        return self.down_proj(gated * lifted, expert_id)

    def forward(self, hidden_states):
        """Route each token to its top-k experts and sum the weighted outputs.

        ``hidden_states`` is (batch_size, sequence_length, hidden_dim); the
        result has the same shape.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        tokens = hidden_states.view(-1, hidden_dim)
        num_tokens = batch_size * sequence_length

        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(tokens)
        router_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(router_probs, self.top_k, dim=-1)
        # The softmax ran in float32; return to the activation dtype.
        routing_weights = routing_weights.to(tokens.dtype)

        mixed = torch.zeros((num_tokens, hidden_dim), dtype=tokens.dtype, device=tokens.device)

        # One-hot the selections and reorder to (expert, top-k slot, token)
        # so each expert can look up the tokens routed to it.
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            slot_idx, token_idx = torch.where(expert_mask[expert_idx])
            expert_input = tokens[None, token_idx].reshape(-1, hidden_dim)
            # Weight each output by the routing weight of its (token, slot).
            expert_output = (
                self.get_expert_output(expert_input, expert_idx)
                * routing_weights[token_idx, slot_idx, None]
            )
            # Accumulate into the rows of the tokens this expert served.
            mixed.index_add_(0, token_idx, expert_output.to(tokens.dtype))

        return mixed.reshape(batch_size, sequence_length, hidden_dim)
class Step3vAttention(nn.Module):
    """Causal self-attention with one shared key/value head.

    Queries are built in two steps: ``q_proj`` maps the hidden state to a
    ``q_size``-wide vector, which is RMS-normalised (``inter_norm``) and then
    expanded by ``wq`` to ``num_attention_heads * head_dim``. Keys and values
    are each projected to a single head of width ``head_dim``.
    """

    def __init__(self, config: Step3VLConfig, layer_idx):
        """
        Args:
            config: Model configuration. Reads ``hidden_size``,
                ``num_attention_heads``, ``rms_norm_eps`` and, when present,
                ``head_dim``, ``share_q_dim`` and ``attention_dropout``.
            layer_idx: Index of the owning decoder layer; used as the key
                into the KV cache.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # A single key/value head is shared by all query heads.
        self.num_key_value_heads = 1
        self.total_num_kv_heads = self.num_key_value_heads
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // self.num_key_value_heads
        self.q_size = getattr(config, "share_q_dim", self.head_dim)
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.is_causal = True
        # forward() reads this attribute in training mode; without it the
        # module raised AttributeError as soon as it was trained. Default to
        # no dropout when the config does not define a rate.
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)

        self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

        self.inter_norm = Step3vRMSNorm(self.q_size, eps=config.rms_norm_eps)
        self.wq = nn.Linear(self.q_size, self.head_dim * self.num_attention_heads, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend over ``hidden_states``.

        Args:
            hidden_states: Input whose last dimension is the model width.
            position_embeddings: ``(cos, sin)`` rotary tables.
            attention_mask: Additive mask for the eager path. Must be None
                for any non-eager attention implementation.
            past_key_value: Optional cache, updated in place with this
                layer's keys and values.
            cache_position: Forwarded to the cache update.
            **kwargs: Passed through to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)`` as returned by the attention
            implementation, with ``attn_output`` passed through ``o_proj``.
        """
        input_shape = hidden_states.shape[:-1]

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states).view((*input_shape, -1, self.head_dim)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view((*input_shape, -1, self.head_dim)).transpose(1, 2)

        # Normalise the compressed query before expanding it to all heads.
        query_states = self.inter_norm(query_states)
        query_states = self.wq(query_states).view((*input_shape, -1, self.head_dim)).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        # TODO: considering FP8;
        # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
        # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
            # Explicit masks are not supported on the non-eager backends.
            assert(attention_mask is None)

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Step3vDecoderLayer(GradientCheckpointingLayer):
    """One pre-norm decoder block: self-attention, then a dense or MoE MLP.

    Which layers use the mixture-of-experts MLP is taken from
    ``config.moe_layers_enum``; when that is missing or None, every layer
    except layer 0 is an MoE layer.
    """

    def __init__(self, config, layer_idx):
        """
        Args:
            config: Text model configuration.
            layer_idx: Position of this layer in the decoder stack.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Step3vAttention(config, layer_idx)
        self.attention_type = "full_attention"

        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is None:
            moe_layers_idx = set(range(1, config.num_hidden_layers))
        elif isinstance(moe_layers_enum, str):
            # Comma-separated string form, e.g. "4,5,6".
            moe_layers_idx = {int(i) for i in moe_layers_enum.strip().split(',')}
        else:
            # Also accept a list/tuple of layer indices; the string-only
            # parsing crashed on `.strip()` for these.
            moe_layers_idx = {int(i) for i in moe_layers_enum}
        self.is_moe_layer = layer_idx in moe_layers_idx
        self.use_moe = False

        if self.is_moe_layer:
            self.moe = Step3vMoEMLP(config)
            self.share_expert = Step3vMLP(config, intermediate_size=config.share_expert_dim)
            self.use_moe = True
        else:
            self.mlp = Step3vMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Step3vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Step3vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        """Apply the block and return the updated hidden states.

        The attention weights produced by the attention module are discarded.
        ``position_ids`` and any extra ``kwargs`` are forwarded to the
        attention module as keyword arguments.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.use_moe:
            # Routed experts plus the always-on shared expert.
            moe_output = self.moe(hidden_states)
            share_output = self.share_expert(hidden_states)
            hidden_states = moe_output + share_output
        else:
            hidden_states = self.mlp(hidden_states)

        # Tolerate MLP variants that return (hidden_states, aux).
        if isinstance(hidden_states, tuple):
            hidden_states, _ = hidden_states
        hidden_states = residual + hidden_states
        return hidden_states
class Step3vPreTrainedModel(PreTrainedModel):
    """Common base class of the Step3 models in this file.

    It defines no behaviour of its own, only class-level flags.
    NOTE(review): the meaning of each flag is defined by ``PreTrainedModel``
    in the installed transformers version - confirm there.
    """

    supports_gradient_checkpointing = True
    # Keys listed here are presumably left alone by automatic device placement.
    _skip_keys_device_placement = ["past_key_values"]
    # Attention backends / cache types this model declares support for.
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_static_cache = True
    _supports_attention_backend = True
class Step3Model(Step3vPreTrainedModel, GenerationMixin):
    """Step3 text decoder stack: token embedding, decoder layers, final norm."""

    _no_split_modules = ["Step3vDecoderLayer"]
    base_model_prefix = "model"
    _tied_weights_keys = ["lm_head.weight"]
    config: Step3TextConfig

    def __init__(self, config: Step3TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Step3vDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Step3vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Step3vRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self, input_ids):
        """Embed ``input_ids`` with the token embedding table."""
        return self.embed_tokens(input_ids)

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.

        Args:
            input_ids: Token ids; embedded with ``embed_tokens`` when
                ``inputs_embeds`` is not supplied.
            attention_mask: Either a mask handed to ``create_causal_mask``,
                or an already-prepared dict keyed by attention type.
            position_ids: Defaults to ``cache_position.unsqueeze(0)``.
            past_key_values: KV cache; a ``DynamicCache`` is created when
                ``use_cache`` is set and none is given.
            inputs_embeds: Precomputed input embeddings.
            use_cache: Defaults to ``config.use_cache``; forced off when
                gradient checkpointing is active during training.
            output_attentions: Forwarded to the layers. The layers do not
                return attention weights, so ``attentions`` in the output is
                an empty tuple when requested and None otherwise.
            output_hidden_states: When true, collect the input of every
                decoder layer.
            return_dict: Accepted for signature compatibility; not read by
                this method.
            cache_position: Defaults to a range starting at the number of
                tokens already held in the cache.

        Returns:
            ``BaseModelOutputWithPast`` with the normalised final hidden
            states.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        hidden_states = inputs_embeds

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }

        # create position embeddings to be shared across the decoder layers
        freq_cis = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=freq_cis,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class Step3vModel(Step3vPreTrainedModel):
# _no_split_modules = ["Step3vDecoderLayer"]
# base_model_prefix = ""
# _tied_weights_keys = ["lm_head.weight"]
# _tp_plan = {"lm_head": "colwise_rep"}
config: Step3VLConfig
base_model_prefix = ""
_checkpoint_conversion_mapping = {"^model": "language_model"}
def __init__(self, config: Step3VLConfig):
    """Build the vision tower, the text decoder and the projector between them.

    The projector is two strided convolutions followed by a linear layer
    that maps to ``config.hidden_size``.
    """
    super().__init__(config)
    vision_cfg = config.vision_config
    text_cfg = config.text_config

    self.vision_model = StepCLIPVisionTransformer(vision_cfg)
    self.language_model = Step3Model(text_cfg)
    self.vocab_size = text_cfg.vocab_size

    vision_width = vision_cfg.hidden_size
    projector_width = vision_cfg.output_hidden_size

    self.vit_downsampler = nn.Conv2d(
        vision_width,
        projector_width,
        kernel_size=2,
        stride=config.understand_projector_stride,
    )
    self.vit_downsampler2 = nn.Conv2d(
        projector_width,
        projector_width * 2,
        kernel_size=3,
        stride=2,
        padding=1,
    )
    self.vit_large_projector = nn.Linear(
        projector_width * 2,
        config.hidden_size,
        bias=config.projector_bias,
    )
    self.image_placeholder_token_id = config.image_token_id

    # Initialize weights and apply final processing
    self.post_init()
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings=None,
) -> torch.Tensor:
    """Embed ``input_ids``, substituting image embeddings at placeholder tokens.

    Args:
        input_ids: Token ids, either ``(seq_len,)`` or
            ``(batch_size, seq_len)``. A 1-D input is treated as a batch of
            one.
        multimodal_embeddings: Optional tensor or (nested) collection of
            tensors with one embedding per ``config.image_token_id``
            position, in row-major order over ``input_ids``. When None, all
            tokens are embedded as text.

    Returns:
        Embeddings of shape ``(batch_size, seq_len, hidden)``.

    Raises:
        ValueError: If the number of supplied embeddings differs from the
            number of placeholder tokens.
    """
    # The previous squeeze(0)/unsqueeze(0) round trip was only correct for a
    # batch of one; work on an explicit 2-D layout instead.
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)

    if multimodal_embeddings is None:
        return self.language_model.get_input_embeddings(input_ids)

    image_token_id = self.config.image_token_id
    # Only text tokens go through the embedding table; placeholder positions
    # are filled from the multimodal embeddings below.
    is_text = input_ids != image_token_id
    text_embeds = self.language_model.get_input_embeddings(input_ids[is_text])

    inputs_embeds = torch.empty(*input_ids.shape,
                                text_embeds.shape[-1],
                                dtype=text_embeds.dtype,
                                device=text_embeds.device)
    inputs_embeds[is_text] = text_embeds
    return merge_multimodal_embeddings(
        input_ids, inputs_embeds, multimodal_embeddings, image_token_id)
def set_input_embeddings(self, value):
    """Delegate to ``self.language_model.set_input_embeddings(value)``."""
    return self.language_model.set_input_embeddings(value)
def set_decoder(self, decoder):
    """Replace the text decoder (``self.language_model``) with ``decoder``."""
    self.language_model = decoder
def get_decoder(self):
    """Return the text decoder (``self.language_model``)."""
    return self.language_model
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[Step3VLImageInputs]:
    """Extract the image-related keyword arguments into a typed input dict.

    Reads ``pixel_values``, ``patch_pixel_values``, ``num_patches`` and
    ``image_embeds`` from ``kwargs``. Pixel input takes precedence over
    precomputed embeddings.

    Returns:
        A ``Step3VLImagePixelInputs`` or ``Step3VLImageEmbeddingInputs`` with
        tensors moved to the model's dtype and device, or None when neither
        ``pixel_values`` nor ``image_embeds`` is given.

    Raises:
        ValueError: If ``image_embeds`` has fewer than two dimensions.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    patch_pixel_values = kwargs.pop("patch_pixel_values", None)
    num_patches = kwargs.pop("num_patches", None)
    image_embeds = kwargs.pop("image_embeds", None)

    if pixel_values is not None:
        # Fold any leading batch dimensions into one, keeping the last three.
        if pixel_values.dim() >= 3:
            pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])

        if patch_pixel_values is not None:
            patch_pixel_values = patch_pixel_values.view(
                -1, *patch_pixel_values.shape[-3:])
            # Handle empty patch_pixel_values by setting to None
            if patch_pixel_values.shape[0] == 0:
                patch_pixel_values = None

        pixel_values = pixel_values.to(self.dtype).to(self.device)
        if patch_pixel_values is not None:
            patch_pixel_values = patch_pixel_values.to(self.dtype).to(
                self.device)

        return Step3VLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            patch_pixel_values=patch_pixel_values,
            num_patches=num_patches,
        )

    if image_embeds is None:
        return None

    if image_embeds.dim() < 2:
        raise ValueError(
            f"Unexpected shape for image_embeds: {image_embeds.shape}")

    # Flatten to one row per embedding.
    image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
    return Step3VLImageEmbeddingInputs(
        type="image_embeds",
        image_embeds=image_embeds.to(self.dtype).to(self.device),
    )
def _process_image_features(self,
image_features: torch.Tensor) -> torch.Tensor:
B, P = image_features.shape[:2]
HW = int(sqrt(P))
image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
image_features = self.vit_downsampler(image_features)
image_features = self.vit_downsampler2(image_features)
n_dim = image_features.size(1)
image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
image_features = self.vit_large_projector(image_features)
return image_features
def _get_vision_model_output(self,
                             input_tensor: torch.Tensor) -> torch.Tensor:
    """Run ``self.vision_model`` and drop the first four entries along
    dimension 1 of its output.

    NOTE(review): the four leading positions are presumably non-patch tokens
    (e.g. class/register tokens) - confirm against the vision encoder.
    """
    return self.vision_model(input_tensor)[:, 4:]
def _process_image_input(
self, image_input: Step3VLImageInputs) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
image_features = image_input["image_embeds"]
else:
image_features = self._get_vision_model_output(
image_input["pixel_values"])
patch_image_features = self._get_vision_model_output(
image_input["patch_pixel_values"]
) if image_input["patch_pixel_values"] is not None else None
num_patches = image_input["num_patches"]
image_features = self._process_image_features(image_features)
patch_image_features = self._process_image_features(
patch_image_features) if patch_image_features is not None else None
merged_image_features = []
cur_patch_idx = 0
for i, num_patch in enumerate(num_patches):
cur_feature = []
if num_patch > 0:
patch_slice = patch_image_features[
cur_patch_idx:cur_patch_idx + num_patch]
cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
cur_feature.append(image_features[i].view(
-1, image_features.shape[-1]))
cur_patch_idx += num_patch
merged_image_features.append(
torch.cat(cur_feature) if len(cur_feature) >
1 else cur_feature[0])
return merged_image_features
def get_multimodal_embeddings(self, **kwargs):
    """Compute image embeddings from the image-related keyword arguments.

    Returns None when ``kwargs`` carries no image input; otherwise returns
    whatever ``_process_image_input`` produces for the parsed input.
    """
    parsed = self._parse_and_validate_image_input(**kwargs)
    return None if parsed is None else self._process_image_input(parsed)
@can_return_tuple
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    images: Optional[list[Image.Image]] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, CausalLMOutputWithPast]:
    r"""
    Build the input embeddings (merging image features when needed) and run the language model.

    When `inputs_embeds` is not given, image inputs are looked up in `kwargs` via
    `get_multimodal_embeddings` and combined with `input_ids` by `get_input_embeddings`. The language
    model is always called with embeddings, never with token ids.

    Args:
        input_ids: Token ids; only read when `inputs_embeds` is `None`.
        attention_mask, position_ids, past_key_values, use_cache, cache_position:
            Forwarded unchanged to the language model.
        inputs_embeds: Precomputed embeddings; when given, the image/text merging step is skipped.
        output_attentions, output_hidden_states, return_dict:
            Fall back to the corresponding `self.config` values when `None`.
        labels, logits_to_keep, images: Accepted for signature compatibility; not read by this method.
        **kwargs: Passed both to `get_multimodal_embeddings` and to the language model.

    Returns:
        A `Step3vCausalLMOutputWithPast` carrying `last_hidden_state`, `past_key_values` and `attentions`
        from the language model, or its tuple form when `return_dict` is false.
    """
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    if inputs_embeds is None:
        image_embeds = self.get_multimodal_embeddings(**kwargs)
        inputs_embeds = self.get_input_embeddings(input_ids, image_embeds)

    # The decoder only ever sees embeddings, so token ids are deliberately withheld.
    decoder_outputs = self.language_model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    result = Step3vCausalLMOutputWithPast(
        last_hidden_state=decoder_outputs.last_hidden_state,
        past_key_values=decoder_outputs.past_key_values,
        attentions=decoder_outputs.attentions,
    )
    if return_dict:
        return result
    return result.to_tuple()
class Step3vForConditionalGeneration(Step3vPreTrainedModel, GenerationMixin):
    """Step3 vision-language model with a language-modeling head.

    Wraps a `Step3vModel` (`self.model`) and adds a bias-free linear `lm_head` that maps the model's
    hidden states to vocabulary logits.
    """

    # Patterns mapping checkpoint key prefixes onto this class's `model.vision_model` /
    # `model.language_model` attribute layout (consumed by the Transformers loading machinery).
    _checkpoint_conversion_mapping = {
        "^vision_model": "model.vision_model",
        r"^model(?!\.(language_model|vision_model))": "model.language_model",
    }
    _tied_weights_keys = ["lm_head.weight"]
    config: Step3VLConfig

    def __init__(self, config: Step3VLConfig):
        """Create the wrapped `Step3vModel` and the output projection."""
        super().__init__(config)
        self.model = Step3vModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embeddings of the wrapped model."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the output embeddings as reported by the wrapped model."""
        # NOTE(review): `lm_head` lives on this class, yet the lookup is delegated to `self.model`;
        # confirm whether this should return `self.lm_head` instead.
        return self.model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Forward new output embeddings to the wrapped model."""
        self.model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Replace the decoder of the wrapped model."""
        self.model.set_decoder(decoder)

    def get_decoder(self):
        """Return the decoder of the wrapped model."""
        return self.model.get_decoder()

    @property
    def language_model(self):
        """The language model owned by the wrapped `Step3vModel`."""
        return self.model.language_model

    @property
    def visual(self):
        """The `visual` attribute of the wrapped model."""
        # NOTE(review): the checkpoint mapping above uses the name `vision_model`;
        # verify that the wrapped model really exposes `visual`.
        return self.model.visual

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        num_patches=None,
        patch_pixel_values=None,
        patch_newline_mask=None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Step3vCausalLMOutputWithPast]:
        r"""
        Run the wrapped model, project its hidden states to logits and optionally compute the loss.

        Args:
            input_ids: Token ids, forwarded to the wrapped model.
            num_patches, patch_pixel_values, patch_newline_mask:
                Image-related inputs, forwarded unchanged to the wrapped model.
            attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position:
                Forwarded unchanged to the wrapped model.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. Indices should either be in `[0, ...,
                config.text_config.vocab_size]` or -100. Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in the vocabulary range.
            output_attentions, output_hidden_states, return_dict:
                Fall back to the corresponding `self.config` values when `None`.
            **kwargs: Extra keyword arguments (e.g. `pixel_values`) forwarded to the wrapped model.

        Returns:
            A `Step3vCausalLMOutputWithPast` with `loss` (only when `labels` is given), `logits`,
            `past_key_values` and `attentions`, or its tuple form when `return_dict` is false.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            num_patches=num_patches,
            patch_pixel_values=patch_pixel_values,
            patch_newline_mask=patch_newline_mask,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Attribute access is needed below, so always request the structured output here and
            # apply the caller's `return_dict` preference only to the final result.
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Use the same vocabulary size the head was built with so it matches the logits' last dim.
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size
            )

        # Surface the loss and the decoder state instead of returning logits alone.
        output = Step3vCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            attentions=outputs.attentions,
        )
        return output if return_dict else output.to_tuple()

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        """Build the per-step model inputs, attaching `pixel_values` only on the first step.

        Everything except `pixel_values` is delegated to the parent implementation.
        """
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        if cache_position is not None:
            is_first_step = cache_position[0] == 0
        else:
            # `cache_position` defaults to None; without it, a missing cache marks the first step.
            is_first_step = past_key_values is None

        if is_first_step:
            # If we're in cached decoding stage, pixel values should be None because input ids do not
            # contain special image token anymore. Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values"] = pixel_values

        return model_inputs

    def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
        """Strip a leading ``language_model.`` from a checkpoint key.

        Returns:
            ``(new_key, changed)`` where ``changed`` is True only when the prefix was removed.
        """
        prefix = "language_model."
        if key.startswith(prefix):
            return key[len(prefix):], True
        return key, False