|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from math import sqrt |
|
from dataclasses import dataclass |
|
from typing import Callable, Optional, Union, Tuple |
|
|
|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.cache_utils import Cache, DynamicCache |
|
from transformers.generation import GenerationMixin |
|
from transformers.integrations.hub_kernels import use_kernel_forward_from_hub |
|
from transformers.masking_utils import create_causal_mask, create_chunked_causal_mask |
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
from transformers.modeling_layers import GradientCheckpointingLayer |
|
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput |
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
from transformers.processing_utils import Unpack |
|
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, is_torch_accelerator_available |
|
from typing import Any, Literal, Optional, TypedDict, Union |
|
|
|
|
|
from .configuration_step3 import Step3VLConfig,Step3TextConfig,Step3VisionEncoderConfig |
|
from .vision_encoder import StepCLIPVisionTransformer |
|
|
|
# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
|
import torch |
|
from typing import Optional |
|
|
|
class Step3VLImagePixelInputs(TypedDict):
    """Raw image tensors that still have to go through the vision tower.

    ``pixel_values`` carries one full view per image. ``patch_pixel_values``
    optionally carries extra patch images (presumably crops), and
    ``num_patches[i]`` says how many of those patch images belong to image ``i``.
    """

    # Discriminator key: always "pixel_values" for this variant.
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    patch_pixel_values: Optional[torch.Tensor]
    num_patches: list[int]


class Step3VLImageEmbeddingInputs(TypedDict):
    """Pre-computed image embeddings that skip the vision tower."""

    # Discriminator key: always "image_embeds" for this variant.
    type: Literal["image_embeds"]
    image_embeds: torch.Tensor


# Either image-input variant; consumers dispatch on the "type" key.
Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
|
|
|
def _flatten_embeddings(embeddings) -> torch.Tensor:
    """
    Collapse a tensor, or an arbitrarily nested sequence of tensors, into one
    2-D tensor.

    Every tensor keeps its last dimension and has all leading dimensions
    merged. Nested pieces are concatenated along the merged dimension in
    iteration order.
    """
    if not isinstance(embeddings, torch.Tensor):
        pieces = [_flatten_embeddings(item) for item in embeddings]
        return torch.cat(pieces)
    return embeddings.flatten(0, -2)
|
|
|
def _embedding_count_expression(embeddings) -> str:
    """
    Build a human-readable count of embedding rows for error messages.

    A tensor is rendered as its leading dimensions joined by " x " (the last,
    feature dimension is excluded). Nested sequences are rendered as their
    parts joined by " + ".
    """
    if not isinstance(embeddings, torch.Tensor):
        return " + ".join(map(_embedding_count_expression, embeddings))
    leading_dims = embeddings.shape[:-1]
    return " x ".join(str(size) for size in leading_dims)
|
|
|
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings,
) -> torch.Tensor:
    """
    Overwrite the positions of ``inputs_embeds`` selected by the boolean mask
    ``is_multimodal`` with the rows of ``multimodal_embeddings``.

    Args:
        inputs_embeds: Destination embeddings. Modified in place.
        is_multimodal: Boolean mask over the leading dims of ``inputs_embeds``
            marking placeholder positions.
        multimodal_embeddings: A tensor or nested sequence of tensors. It is
            flattened to ``(num_rows, dim)`` and written in order.

    Returns:
        ``inputs_embeds`` (the same object, updated in place).

    Raises:
        ValueError: If the number of flattened rows differs from the number of
            ``True`` entries in ``is_multimodal``.
    """
    num_expected_tokens = is_multimodal.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    is_multimodal = is_multimodal.to(inputs_embeds.device)
    # Masked assignment (index_put_) requires source and destination dtypes to
    # match, so align dtype as well as device.
    flattened = flattened.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
    inputs_embeds[is_multimodal] = flattened
    return inputs_embeds
|
|
|
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings,
    placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
    """
    Write ``multimodal_embeddings`` into ``inputs_embeds`` at every position
    where ``input_ids`` holds a placeholder token.

    ``placeholder_token_id`` may be a single id or a list of ids, for example
    the ids of image-start, image-break and image-end tokens. With a list, the
    matching positions are filled in one slice-merge, not scattered per id. The
    embeddings must therefore appear in exactly the order those tokens occur in
    ``input_ids``.

    Example: for ``input_ids`` spelling "TTTTTSIIIBIIIBIIIETTT", where
      - T is a text token,
      - S an image start token,
      - I an image embedding token,
      - B an image break token,
      - E an image end token,
    the vision-encoder embeddings for the I's must be interleaved with the
    embeddings of S, B and E in that same order.

    Note:
        ``inputs_embeds`` is updated in place and also returned.
    """
    if isinstance(placeholder_token_id, list):
        candidate_ids = torch.tensor(placeholder_token_id, device=input_ids.device)
        placeholder_mask = torch.isin(input_ids, candidate_ids)
    else:
        placeholder_mask = input_ids == placeholder_token_id

    return _merge_multimodal_embeddings(
        inputs_embeds,
        placeholder_mask,
        multimodal_embeddings,
    )
|
|
|
def rotate_half(x):
    """
    Swap the two halves of the last dimension and negate the new first half.

    ``[x1, x2] -> [-x2, x1]``. For an odd size, the first half is the shorter
    one (``size // 2`` elements).
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with precomputed rotary (RoPE) tables.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine table.
        sin (`torch.Tensor`): Sine table.
        position_ids (`torch.Tensor`, *optional*):
            Ignored; kept only for call-site compatibility.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into ``cos`` and
            ``sin`` so they broadcast against ``q`` and ``k``. Use 1 when
            q/k are laid out as [batch, heads, seq, head_dim] and the tables as
            [batch, seq, head_dim]. Use 2 for [batch, seq, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along the head axis.

    Same result as ``torch.repeat_interleave(x, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With
    ``n_rep == 1`` the input object is returned unchanged.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast into it, then
    # fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Reference (non-fused) attention: ``softmax(Q K^T * scaling + mask) V``.

    Args:
        module: The calling attention layer. ``module.num_key_value_groups``
            sets how often each key/value head is repeated to match the query
            heads. ``module.training`` gates dropout.
        query: Query tensor. The ``transpose(2, 3)`` on keys below implies a
            4-D layout, presumably ``(batch, heads, q_len, head_dim)``.
        key, value: Key/value tensors as ``repeat_kv`` expects them,
            ``(batch, num_kv_heads, kv_len, head_dim)``.
        attention_mask: Optional additive mask. Its last axis is sliced to the
            key length.
        scaling: Multiplier applied to the raw scores.
        dropout: Dropout probability on the attention weights (training only).
        **kwargs: Accepted for interface parity with other backends; unused.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` is
        ``attn_weights @ value`` with dims 1 and 2 swapped, made contiguous.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Run the softmax in float32 for stability under fp16/bf16, then cast back
    # so the value matmul stays in the activation dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
|
|
|
@dataclass
class Step3vCausalLMOutputWithPast(ModelOutput):
    r"""
    Container for the outputs of a Step3 vision-language causal LM forward pass.

    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Next-token language-modeling loss, present when labels were supplied.
    last_hidden_state (`torch.FloatTensor`, *optional*):
        Final-layer hidden states, when the producer fills them in.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Pre-softmax scores over the vocabulary for each position.
    past_key_values (`Cache` or list of tensors, *optional*):
        Cached attention keys/values that can be passed back in to speed up
        sequential decoding.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Per-layer hidden states, when requested.
    attentions (`tuple(torch.FloatTensor)`, *optional*):
        Per-layer attention weights, when requested.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        Image features of size `(batch_size, num_images, sequence_length, hidden_size)`,
        taken from the vision encoder after projecting its last hidden state.
    """

    # Field order is part of ModelOutput's tuple conversion; do not reorder.
    loss: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
|
|
|
class Step3vRMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    Equivalent to T5LayerNorm. The normalization runs in float32 and is cast
    back to the input dtype before the scale is applied.
    """

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"
|
|
|
class Step3vRotaryEmbedding(nn.Module):
    """Produce rotary-embedding ``(cos, sin)`` tables for the given positions.

    The inverse-frequency vector comes from the transformers RoPE init function
    selected by ``config.rope_scaling`` ("default" when no scaling is set).
    ``dynamic_rope_update`` may refresh it during ``forward``.
    """

    def __init__(self, config: Step3VLConfig, device=None):
        super().__init__()

        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        else:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        self.max_seq_len_cached = config.max_position_embedding
        self.original_max_seq_len = config.max_position_embedding

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: recomputed from config, never saved in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        # Autocast is disabled so the angle computation stays in float32.
        # "mps" (and any non-string device type) falls back to "cpu" for autocast.
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
class Step3vMLP(nn.Module):
    """Gated SiLU feed-forward block: ``down(silu(gate(x)) * up(x))``.

    ``intermediate_size`` overrides ``config.intermediate_size`` when given.
    The decoder uses this for the shared expert.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        if intermediate_size is None:
            intermediate_size = config.intermediate_size
        self.intermediate_size = intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
class MoELinear(nn.Module):
    """Bank of bias-free linear layers, one per expert, in one stacked weight.

    ``weight`` has shape ``(num_experts, out_features, in_features)``. It is
    allocated with ``torch.empty`` (uninitialized), so it is presumably filled
    from a checkpoint before use.
    """

    def __init__(self, num_experts, in_features, out_features):
        super().__init__()
        self.num_experts = num_experts
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features))

    def forward(self, x, expert_id):
        """Apply the ``expert_id``-th expert's projection to ``x``."""
        expert_weight = self.weight[expert_id]
        return F.linear(x, expert_weight)
|
|
|
class Step3vMoEMLP(nn.Module):
    """Top-k routed mixture of gated-SiLU experts (without the shared expert).

    A bias-free linear router scores every expert per token. The softmax
    probabilities of the ``top_k`` best experts weight their outputs, and the
    results are summed per token.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.moe_num_experts
        self.top_k = config.moe_top_k
        self.hidden_size = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size

        # Router: one logit per expert for each token.
        self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
        # Stacked per-expert projections, indexed by expert id at call time.
        self.up_proj = MoELinear(self.num_experts, self.hidden_size, self.moe_intermediate_size)
        self.gate_proj = MoELinear(self.num_experts, self.hidden_size, self.moe_intermediate_size)
        self.down_proj = MoELinear(self.num_experts, self.moe_intermediate_size, self.hidden_size)

        self.act_fn = ACT2FN["silu"]

    def get_expert_output(self, inputs: torch.Tensor, expert_id):
        """Run expert ``expert_id``'s gated MLP on ``inputs``."""
        gated = self.act_fn(self.gate_proj(inputs, expert_id))
        up = self.up_proj(inputs, expert_id)
        return self.down_proj(gated * up, expert_id)

    def forward(self, hidden_states):
        """Route each token of a ``(batch, seq, hidden)`` input to its top-k experts."""
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        tokens = hidden_states.view(-1, hidden_dim)

        router_logits = self.gate(tokens)
        probs = F.softmax(router_logits, dim=1, dtype=torch.float)
        # NOTE(review): the top-k weights are used as-is (not renormalized to
        # sum to 1); confirm this matches the checkpoint's routing.
        topk_weights, topk_experts = torch.topk(probs, self.top_k, dim=-1)
        topk_weights = topk_weights.to(tokens.dtype)

        output = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=tokens.dtype, device=tokens.device
        )

        # assignment[e, slot, token] is 1 when `token` picked expert `e` in `slot`.
        assignment = F.one_hot(topk_experts, num_classes=self.num_experts).permute(2, 1, 0)

        for expert in range(self.num_experts):
            slot, token_idx = torch.where(assignment[expert])
            expert_in = tokens[None, token_idx].reshape(-1, hidden_dim)
            weighted = self.get_expert_output(expert_in, expert) * topk_weights[token_idx, slot, None]
            # Accumulate in place; a token receives one contribution per selected expert.
            output.index_add_(0, token_idx, weighted.to(tokens.dtype))

        return output.reshape(batch_size, sequence_length, hidden_dim)
|
|
|
|
|
class Step3vAttention(nn.Module):
    """
    Self-attention with a single shared key/value head and a two-stage query path.

    Queries go ``q_proj`` (hidden -> ``q_size``) -> ``inter_norm`` (RMSNorm) ->
    ``wq`` (``q_size`` -> ``num_attention_heads * head_dim``). Keys and values
    each use one head of width ``head_dim``. That head is shared by every query
    head (``num_key_value_groups == num_attention_heads``).
    """

    def __init__(self, config: Step3VLConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_heads = 1
        self.total_num_kv_heads = self.num_key_value_heads
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // self.num_key_value_heads
        self.q_size = getattr(config, "share_q_dim", self.head_dim)
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.is_causal = True
        # forward() reads this whenever the module is in training mode; it was
        # never set before, so training raised AttributeError.
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)

        self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

        # Normalizes the low-rank query before it is expanded to all heads.
        self.inter_norm = Step3vRMSNorm(self.q_size, eps=config.rms_norm_eps)

        self.wq = nn.Linear(self.q_size, self.head_dim * self.num_attention_heads, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Attend over ``hidden_states``.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has the same
            leading shape as ``hidden_states`` and width ``hidden_size``.
            ``attn_weights`` is whatever the selected attention backend returns.
        """
        input_shape = hidden_states.shape[:-1]

        query_states = self.q_proj(hidden_states)
        # (..., kv_heads, head_dim), then heads moved before the sequence axis.
        key_states = self.k_proj(hidden_states).view((*input_shape, -1, self.head_dim)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view((*input_shape, -1, self.head_dim)).transpose(1, 2)

        query_states = self.inter_norm(query_states)
        query_states = self.wq(query_states).view((*input_shape, -1, self.head_dim)).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # NOTE(review): this rejects any materialized mask, including one that
        # create_causal_mask may build for eager attention or padded batches.
        # Confirm callers always pass None before relaxing it.
        assert attention_mask is None

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
class Step3vDecoderLayer(GradientCheckpointingLayer):
    """
    Pre-norm decoder block: self-attention followed by either a dense MLP or a
    routed MoE plus an always-on shared expert, each with a residual connection.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Step3vAttention(config, layer_idx)
        self.attention_type = "full_attention"

        self.is_moe_layer = layer_idx in self._moe_layer_indices(config)
        self.use_moe = False

        if self.is_moe_layer:
            self.moe = Step3vMoEMLP(config)
            self.share_expert = Step3vMLP(config, intermediate_size=config.share_expert_dim)
            self.use_moe = True
        else:
            self.mlp = Step3vMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Step3vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Step3vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @staticmethod
    def _moe_layer_indices(config):
        """
        Return the list of layer indices that use the MoE feed-forward.

        ``config.moe_layers_enum`` may be a comma-separated string (e.g.
        ``"1,2,3"``; blank entries are ignored) or any iterable of ints. When
        it is absent or None, every layer except layer 0 is an MoE layer.
        """
        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is None:
            return list(range(1, config.num_hidden_layers))
        if isinstance(moe_layers_enum, str):
            # Previously only strings were accepted; a tuple/list crashed on .strip().
            return [int(part) for part in moe_layers_enum.strip().split(",") if part.strip()]
        return [int(idx) for idx in moe_layers_enum]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        """Apply attention and the feed-forward, each as ``x + f(norm(x))``."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        if self.use_moe:
            # Routed experts and the shared expert both see the same normed input.
            moe_output = self.moe(hidden_states)
            share_output = self.share_expert(hidden_states)
            hidden_states = moe_output + share_output
        else:
            hidden_states = self.mlp(hidden_states)

        # Defensive: tolerate a feed-forward that returns (output, aux).
        if isinstance(hidden_states, tuple):
            hidden_states, _ = hidden_states

        hidden_states = residual + hidden_states
        return hidden_states
|
|
|
class Step3vPreTrainedModel(PreTrainedModel):
    """Common base for Step3 models; declares capability flags for transformers."""

    # Attention backends this architecture can use.
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    # Training and cache features.
    supports_gradient_checkpointing = True
    _supports_static_cache = True
    _skip_keys_device_placement = ["past_key_values"]
|
|
|
class Step3Model(Step3vPreTrainedModel, GenerationMixin):
    """Step3 text decoder: token embeddings, a stack of decoder layers, and a final RMSNorm."""

    _no_split_modules = ["Step3vDecoderLayer"]
    base_model_prefix = "model"
    _tied_weights_keys = ["lm_head.weight"]
    config: Step3TextConfig

    def __init__(self, config: Step3TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Step3vDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Step3vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Step3vRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self, input_ids: Optional[torch.LongTensor] = None):
        """
        Return the token-embedding module, or the embeddings of ``input_ids``.

        transformers' generic utilities call ``get_input_embeddings()`` with no
        arguments and expect the ``nn.Embedding`` module. Passing ``input_ids``
        embeds them directly, which is how this model's own callers use it.
        """
        if input_ids is None:
            return self.embed_tokens
        return self.embed_tokens(input_ids)

    def set_input_embeddings(self, value):
        """Replace the token-embedding module."""
        self.embed_tokens = value

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``return_dict`` is accepted for signature compatibility but not read in
        this body. When ``output_hidden_states`` is set, the returned tuple
        holds the embedding output followed by each layer's output; the last
        entry is the final normed state.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        hidden_states = inputs_embeds

        # Callers may pre-build a {attention_type: mask} mapping; otherwise build it here.
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }

        # Rotary (cos, sin) tables are shared by every layer.
        freq_cis = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        # NOTE(review): decoder layers do not return attention weights, so this
        # stays empty even when output_attentions is requested.
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=freq_cis,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        # Without this, the last layer's output was missing from all_hidden_states.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
class Step3vModel(Step3vPreTrainedModel): |
|
|
|
|
|
|
|
|
|
config: Step3VLConfig |
|
base_model_prefix = "" |
|
_checkpoint_conversion_mapping = {"^model": "language_model"} |
|
|
|
def __init__(self, config: Step3VLConfig):
    """
    Build the vision tower, the text decoder, and the vision-to-text projector.

    The projector is two strided convolutions over the ViT patch grid followed
    by a linear map into ``config.hidden_size``.
    """
    super().__init__(config)
    vision_cfg = config.vision_config
    self.vision_model = StepCLIPVisionTransformer(vision_cfg)
    self.language_model = Step3Model(config.text_config)
    self.vocab_size = config.text_config.vocab_size

    self.vit_downsampler = nn.Conv2d(
        vision_cfg.hidden_size,
        vision_cfg.output_hidden_size,
        kernel_size=2,
        stride=config.understand_projector_stride,
    )
    # Second downsampling stage: halves the spatial grid, doubles the channels.
    self.vit_downsampler2 = nn.Conv2d(
        vision_cfg.output_hidden_size,
        vision_cfg.output_hidden_size * 2,
        kernel_size=3,
        stride=2,
        padding=1,
    )
    self.vit_large_projector = nn.Linear(
        vision_cfg.output_hidden_size * 2,
        config.hidden_size,
        bias=config.projector_bias,
    )

    self.image_placeholder_token_id = config.image_token_id

    self.post_init()
|
|
|
def get_input_embeddings(
    self,
    input_ids: Optional[torch.Tensor] = None,
    multimodal_embeddings=None,
) -> torch.Tensor:
    """
    Embed ``input_ids``, splicing image embeddings in at image-token positions.

    Args:
        input_ids: Token ids of shape ``(seq,)`` or ``(batch, seq)``. When
            omitted, the language model's token-embedding module is returned.
            transformers' generic utilities expect that no-argument form.
        multimodal_embeddings: Optional image embeddings, written (in order)
            over the positions whose id equals ``config.image_token_id``.

    Returns:
        Embeddings of shape ``(batch, seq, hidden)``. A 1-D input is treated as
        a batch of one.
    """
    if input_ids is None:
        return self.language_model.embed_tokens

    # Work on a 2-D (batch, seq) view. The previous squeeze(0)/unsqueeze(0)
    # only handled batch size 1: batched multimodal input crashed and batched
    # text input came back 4-D.
    batched_ids = input_ids if input_ids.dim() > 1 else input_ids.unsqueeze(0)

    if multimodal_embeddings is None:
        return self.language_model.get_input_embeddings(batched_ids)

    # Image-token ids need not be valid vocabulary rows, so embed text tokens
    # only and leave image slots to be filled by the merge below.
    is_text = batched_ids != self.config.image_token_id
    text_embeds = self.language_model.get_input_embeddings(batched_ids[is_text])
    inputs_embeds = torch.empty(
        *batched_ids.shape,
        text_embeds.shape[-1],
        dtype=text_embeds.dtype,
        device=text_embeds.device,
    )
    inputs_embeds[is_text] = text_embeds
    return merge_multimodal_embeddings(
        batched_ids, inputs_embeds, multimodal_embeddings,
        self.config.image_token_id)
|
|
|
|
|
def set_input_embeddings(self, value):
    """Forward the new token-embedding module to the wrapped language model."""
    decoder = self.language_model
    return decoder.set_input_embeddings(value)
|
|
|
def set_decoder(self, decoder):
    """Replace the wrapped text decoder (stored as ``language_model``)."""
    setattr(self, "language_model", decoder)
|
|
|
def get_decoder(self):
    """Return the wrapped text decoder (``language_model``)."""
    decoder = self.language_model
    return decoder
|
|
|
def _parse_and_validate_image_input(
        self, **kwargs: object) -> "Optional[Step3VLImageInputs]":
    """
    Pick image inputs out of ``kwargs`` and normalize their shapes.

    Recognized keys are ``pixel_values``, ``patch_pixel_values``,
    ``num_patches`` and ``image_embeds``. Pixel inputs take precedence.
    Their leading dims are flattened to ``(-1, C, H, W)``, and an empty
    ``patch_pixel_values`` becomes None. ``image_embeds`` are flattened to
    ``(-1, dim)``. Returned tensors are cast to ``self.dtype`` and moved to
    ``self.device``.

    Returns:
        A pixel-input or embedding-input dict, or None when neither
        ``pixel_values`` nor ``image_embeds`` is supplied.

    Raises:
        ValueError: if ``image_embeds`` has fewer than 2 dimensions.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    patch_pixel_values = kwargs.pop("patch_pixel_values", None)
    num_patches = kwargs.pop("num_patches", None)
    image_embeds = kwargs.pop("image_embeds", None)

    if pixel_values is not None:
        if pixel_values.dim() >= 3:
            pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
        if patch_pixel_values is not None:
            patch_pixel_values = patch_pixel_values.view(
                -1, *patch_pixel_values.shape[-3:])
            if patch_pixel_values.shape[0] == 0:
                patch_pixel_values = None

        pixels = pixel_values.to(self.dtype).to(self.device)
        patches = None
        if patch_pixel_values is not None:
            patches = patch_pixel_values.to(self.dtype).to(self.device)
        return Step3VLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixels,
            patch_pixel_values=patches,
            num_patches=num_patches,
        )

    if image_embeds is None:
        return None

    if image_embeds.dim() < 2:
        raise ValueError(
            f"Unexpected shape for image_embeds: {image_embeds.shape}")
    image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
    return Step3VLImageEmbeddingInputs(
        type="image_embeds",
        image_embeds=image_embeds.to(self.dtype).to(self.device),
    )
|
|
|
def _process_image_features(self,
                            image_features: torch.Tensor) -> torch.Tensor:
    """
    Downsample ViT patch tokens and project them into the LM hidden size.

    ``image_features`` is indexed as ``(batch, num_tokens, channels)``, and
    ``num_tokens`` is assumed to be a perfect square (a side x side grid). The
    grid passes through both downsampling convolutions and the linear projector.

    Returns:
        ``(batch, num_downsampled_tokens, hidden)`` features.
    """
    batch, num_tokens = image_features.shape[:2]
    side = int(sqrt(num_tokens))
    # (B, P, C) -> (B, C, side, side) so the convolutions see a 2-D grid.
    grid = image_features.permute(0, 2, 1).view(batch, -1, side, side)
    grid = self.vit_downsampler2(self.vit_downsampler(grid))
    channels = grid.size(1)
    # Back to a token sequence: (B, C', H', W') -> (B, H'*W', C').
    tokens = grid.view(batch, channels, -1).permute(0, 2, 1)
    return self.vit_large_projector(tokens)
|
|
|
def _get_vision_model_output(self,
                             input_tensor: torch.Tensor) -> torch.Tensor:
    """
    Run the vision tower and drop its first four output tokens.

    NOTE(review): the four dropped tokens are presumably non-patch tokens
    (e.g. class/register tokens); confirm against StepCLIPVisionTransformer.
    """
    num_prefix_tokens = 4
    features = self.vision_model(input_tensor)
    return features[:, num_prefix_tokens:]
|
|
|
def _process_image_input(
        self, image_input: "Step3VLImageInputs") -> list[torch.Tensor]:
    """
    Turn parsed image inputs into per-image token embeddings for the LM.

    For ``"pixel_values"`` inputs, full images and optional patch images both
    go through the vision tower and projector. For image ``i``, the
    ``num_patches[i]`` patch embeddings (flattened) come first, followed by the
    image's own embeddings.

    For ``"image_embeds"`` inputs, the tensor was already flattened to
    ``(num_tokens, dim)`` by the parser, so there is no grid to downsample. It
    is returned as a single-element list.
    NOTE(review): this assumes such embeddings are already in the LM hidden
    size. Previously this path crashed on unbound ``patch_image_features``.

    Returns:
        A list of 2-D ``(tokens, hidden)`` tensors.

    Raises:
        ValueError: if patch images are given without ``num_patches``, or if
            ``num_patches`` asks for patches that were not provided.
    """
    if image_input["type"] == "image_embeds":
        return [image_input["image_embeds"]]

    pixel_values = image_input["pixel_values"]
    patch_pixel_values = image_input["patch_pixel_values"]
    num_patches = image_input["num_patches"]
    if num_patches is None:
        if patch_pixel_values is not None:
            raise ValueError(
                "num_patches is required when patch_pixel_values are provided")
        # No patch bookkeeping: each image contributes only its own features.
        num_patches = [0] * pixel_values.shape[0]

    image_features = self._get_vision_model_output(pixel_values)
    patch_image_features = (
        self._get_vision_model_output(patch_pixel_values)
        if patch_pixel_values is not None else None)

    image_features = self._process_image_features(image_features)
    if patch_image_features is not None:
        patch_image_features = self._process_image_features(patch_image_features)

    merged_image_features = []
    cur_patch_idx = 0
    for i, num_patch in enumerate(num_patches):
        own_feature = image_features[i].view(-1, image_features.shape[-1])
        if num_patch > 0:
            if patch_image_features is None:
                raise ValueError(
                    f"Image {i} expects {num_patch} patches but no "
                    "patch_pixel_values were provided")
            patch_slice = patch_image_features[
                cur_patch_idx:cur_patch_idx + num_patch]
            merged = torch.cat(
                [patch_slice.view(-1, patch_slice.shape[-1]), own_feature])
        else:
            merged = own_feature
        cur_patch_idx += num_patch
        merged_image_features.append(merged)
    return merged_image_features
|
|
|
def get_multimodal_embeddings(self, **kwargs):
    """
    Parse image-related ``kwargs`` and return the processed image embeddings.

    Returns None when ``kwargs`` carry no image input.
    """
    parsed = self._parse_and_validate_image_input(**kwargs)
    return None if parsed is None else self._process_image_input(parsed)
|
|
|
@can_return_tuple
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    images: Optional[list[Image.Image]] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, CausalLMOutputWithPast]:
    r"""
    Embed the text (plus any image features) and run the language-model decoder.

    When ``inputs_embeds`` is not supplied, image inputs are read from
    ``kwargs`` through ``self.get_multimodal_embeddings`` and merged with the
    token embeddings of ``input_ids`` by ``self.get_input_embeddings``. The
    decoder is always called with ``inputs_embeds`` (``input_ids=None``).

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Accepted for signature compatibility only; this method computes no
        loss and never reads it. ``logits_to_keep`` and ``images`` are
        likewise accepted but unused here.

    Returns:
        ``Step3vCausalLMOutputWithPast`` holding the decoder's
        ``last_hidden_state``, ``past_key_values`` and ``attentions``, or its
        ``to_tuple()`` form when ``return_dict`` resolves to false.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.use_return_dict

    if inputs_embeds is None:
        # Image inputs (presumably pixel_values / patch_pixel_values /
        # num_patches from the wrapping model) travel in **kwargs; the same
        # kwargs are also forwarded to the decoder below.
        image_embeds = self.get_multimodal_embeddings(**kwargs)
        inputs_embeds = self.get_input_embeddings(input_ids, image_embeds)

    decoder_outputs = self.language_model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    # NOTE(review): hidden states are requested from the decoder but not
    # copied into the result — confirm whether they should be.
    result = Step3vCausalLMOutputWithPast(
        last_hidden_state=decoder_outputs.last_hidden_state,
        past_key_values=decoder_outputs.past_key_values,
        attentions=decoder_outputs.attentions,
    )
    if not return_dict:
        return result.to_tuple()
    return result
|
|
|
|
|
class Step3vForConditionalGeneration(Step3vPreTrainedModel, GenerationMixin): |
|
_checkpoint_conversion_mapping = { |
|
"^vision_model": "model.vision_model", |
|
r"^model(?!\.(language_model|vision_model))": "model.language_model", |
|
} |
|
_tied_weights_keys = ["lm_head.weight"] |
|
config: Step3VLConfig |
|
|
|
def __init__(self, config: Step3VLConfig):
    """Build the multimodal backbone and the vocabulary projection head.

    Args:
        config: Composite config. It is passed whole to ``Step3vModel``;
            ``config.hidden_size`` and ``config.text_config.vocab_size`` size
            the bias-free ``lm_head``.
    """
    super().__init__(config)
    self.model = Step3vModel(config)
    in_features = config.hidden_size
    out_features = config.text_config.vocab_size
    self.lm_head = nn.Linear(in_features, out_features, bias=False)
    # Parent-class post-construction hook.
    self.post_init()
|
|
|
def get_input_embeddings(self):
    """Return the input embedding module reported by the wrapped ``self.model``."""
    backbone = self.model
    return backbone.get_input_embeddings()
|
|
|
def set_input_embeddings(self, value):
    """Install ``value`` as the input embedding module of the wrapped ``self.model``."""
    backbone = self.model
    backbone.set_input_embeddings(value)
|
|
|
def get_output_embeddings(self):
    """Return the output projection, ``self.lm_head``.

    The LM head is created by this class's ``__init__``, not by the wrapped
    ``self.model``. Delegating to the backbone therefore could not return it,
    and library utilities that look up output embeddings (presumably weight
    tying and token-embedding resizing) would miss it.
    """
    return self.lm_head
|
|
|
def set_output_embeddings(self, new_embeddings):
    """Replace the output projection (``self.lm_head``) with ``new_embeddings``.

    The LM head lives on this class, so it is replaced here directly rather
    than through the wrapped ``self.model``.
    """
    self.lm_head = new_embeddings
|
|
|
def set_decoder(self, decoder):
    """Forward ``decoder`` to ``self.model.set_decoder``."""
    backbone = self.model
    backbone.set_decoder(decoder)
|
|
|
def get_decoder(self):
    """Return whatever the wrapped ``self.model`` reports as its decoder."""
    backbone = self.model
    return backbone.get_decoder()
|
|
|
@property
def language_model(self):
    """Text decoder held by the wrapped ``self.model`` (read-only alias)."""
    backbone = self.model
    return backbone.language_model
|
|
|
@property
def visual(self):
    """Vision encoder of the wrapped model, ``self.model.vision_model``.

    The backbone stores its encoder as ``vision_model``: this class's
    ``_checkpoint_conversion_mapping`` routes weights to
    ``model.vision_model``. Reading ``self.model.visual`` (as before) would
    raise AttributeError.
    """
    return self.model.vision_model
|
|
|
def forward(
    self,
    input_ids: torch.LongTensor = None,
    num_patches=None,
    patch_pixel_values=None,
    patch_newline_mask=None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Step3vCausalLMOutputWithPast]:
    r"""
    Run the multimodal backbone and project its hidden states to vocabulary logits.

    Image inputs (``num_patches``, ``patch_pixel_values``,
    ``patch_newline_mask`` and anything else in ``kwargs``, such as
    ``pixel_values``) are forwarded unchanged to ``self.model``.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the language modeling loss. Indices should either be in `[0, ...,
        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100`
        are ignored (masked); the loss is only computed for tokens with labels in
        `[0, ..., config.text_config.vocab_size]`.

    Returns:
        ``Step3vCausalLMOutputWithPast`` with ``loss`` (``None`` unless
        ``labels`` is given), ``logits``, ``past_key_values`` and
        ``attentions``; or its ``to_tuple()`` form when ``return_dict``
        resolves to false.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.model(
        input_ids=input_ids,
        num_patches=num_patches,
        patch_pixel_values=patch_pixel_values,
        patch_newline_mask=patch_newline_mask,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        # Always request a ModelOutput so the attribute access below works;
        # the caller's return_dict choice is applied to the final result.
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    logits = self.lm_head(hidden_states)

    loss = None  # previously misspelled "los", leaving `loss` unbound
    if labels is not None:
        # lm_head is built with text_config.vocab_size outputs, so the loss
        # must use that same vocabulary size.
        loss = self.loss_function(
            logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size
        )

    # NOTE(review): the backbone output does not expose hidden_states, so
    # they are not propagated here either.
    output = Step3vCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        attentions=outputs.attentions,
    )
    return output if return_dict else output.to_tuple()
|
|
|
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    pixel_values=None,
    attention_mask=None,
    cache_position=None,
    logits_to_keep=None,
    **kwargs,
):
    """Build per-step generation inputs, attaching ``pixel_values`` only at position 0.

    Everything except ``pixel_values`` is delegated to the parent
    implementation. ``pixel_values`` is added to the returned dict only when
    ``cache_position[0] == 0``; on later steps it is left out.
    """
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        **kwargs,
    )

    # NOTE(review): indexing assumes cache_position is not None. Other image
    # kwargs (e.g. patch_pixel_values, num_patches) still reach the parent via
    # **kwargs on every step — confirm they are ignored once pixel_values is absent.
    first_step = cache_position[0] == 0
    if first_step:
        model_inputs["pixel_values"] = pixel_values
    return model_inputs
|
|
|
def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]: |
|
if key.startswith("language_model."): |
|
return key[len("language_model."):], True |
|
|
|
return key, False |
|
|