|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
from collections.abc import Callable |
|
from dataclasses import dataclass |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ...activations import ACT2FN |
|
from ...cache_utils import Cache, HybridCache, StaticCache |
|
from ...generation import GenerationMixin |
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs |
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput |
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS |
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
from ...processing_utils import Unpack |
|
from ...utils import ( |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
is_torchdynamo_compiling, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
from ...utils.deprecation import deprecate_kwarg |
|
from ..auto import AutoModel, AutoModelForCausalLM |
|
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig |
|
from .speech_conformer_encoder import ConformerEncoder |
|
|
|
# Module-level logger shared by the classes in this file (used for one-time warnings such as the
# SDPA/`output_attentions` fallback and the gradient-checkpointing/`use_cache` conflict).
logger = logging.get_logger(__name__)

# Name of the config class for documentation purposes; presumably consumed by `replace_return_docstrings`
# further down the file (not visible in this chunk) — TODO confirm.
_CONFIG_FOR_DOC = "Gemma3Config"
|
|
|
|
|
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
    """
    Base class for Gemma3 causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Cache` or `list(torch.FloatTensor)`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding. Either a [`~cache_utils.Cache`] instance or,
            in the legacy format, per-layer tensors of shape `(batch_size, num_heads, sequence_length,
            embed_size_per_head)`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
            image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
        audio_hidden_states (`torch.FloatTensor`, *optional*):
            Hidden states produced for the audio input, presumably by the speech encoder (`ConformerEncoder`)
            after projection. Shape is not established in this part of the file — TODO confirm.
    """

    # NOTE: field order defines the positional order used by `ModelOutput.to_tuple()`; do not reorder.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
    audio_hidden_states: Optional[torch.FloatTensor] = None
|
|
|
|
|
class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """
    An `nn.Embedding` whose looked-up vectors are multiplied by a constant `embed_scale`.

    Args:
        num_embeddings (`int`): Vocabulary size.
        embedding_dim (`int`): Size of each embedding vector.
        padding_idx (`int`): Row treated as padding by `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0): Factor applied to every embedding on lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Kept as a plain Python attribute: it is neither a parameter nor a buffer.
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        """Look up `input_ids` and return the embeddings multiplied by `embed_scale`."""
        looked_up = super().forward(input_ids)
        return looked_up * self.embed_scale
|
|
|
|
|
class Gemma3MLP(nn.Module):
    """
    Gated feed-forward block computing `down_proj(act(gate_proj(x)) * up_proj(x))`.

    All three projections are bias-free; the activation is looked up in `ACT2FN` by `config.hidden_activation`.
    """

    def __init__(self, config: Gemma3TextConfig):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Registration order (gate, up, down) is kept stable for parameter iteration order.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        """Apply the gated MLP to `x` (last dimension must be `hidden_size`)."""
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
|
class Gemma3RMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last dimension with an offset-by-one learned gain.

    `weight` is initialised to zeros and applied as `(1 + weight)`, so a freshly built module performs a plain
    RMS normalisation. The computation runs in float32 and the result is cast back to the input's dtype.

    Args:
        dim (`int`): Size of the normalised (last) dimension.
        eps (`float`, *optional*, defaults to 1e-6): Added to the mean square for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide `x` by the root of its mean square along the last axis."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normalised = self._norm(x.float())
        gain = 1.0 + self.weight.float()
        # The gain is applied while still in float32; only the final product is cast back.
        return (normalised * gain).type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
|
|
|
|
class Gemma3RotaryEmbedding(nn.Module):
    """
    Produces the rotary position embedding (RoPE) `cos`/`sin` tables for a batch of position ids.

    The inverse frequencies come from `ROPE_INIT_FUNCTIONS[rope_type]`. `rope_type` is read from
    `config.rope_scaling` (key "rope_type", falling back to the older key "type"), or is "default" when no
    scaling is configured. The frequencies are kept in the non-persistent buffer `inv_freq`, so they are not
    saved in the state dict.

    Args:
        config (`Gemma3TextConfig`): Configuration forwarded to the RoPE init function.
        device (*optional*): Device forwarded to the RoPE init function.
    """

    def __init__(self, config: Gemma3TextConfig, device=None):
        super().__init__()

        # Backward compatibility: older configs name the key "type" instead of "rope_type".
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        # The init function returns the inverse frequencies plus a factor that `forward` multiplies onto cos/sin.
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Plain attribute, not a buffer: it does not follow `module.to(device)`, which is why
        # `_dynamic_frequency_update` moves it explicitly before restoring it.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # Restore the frequencies computed at construction time (moved to the current device first,
            # since `original_inv_freq` is not a registered buffer).
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Args:
            x (`torch.Tensor`): Only its device and dtype are used.
            position_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`): Positions to embed.

        Returns:
            `(cos, sin)`, each of shape `(batch_size, seq_len, 2 * len(inv_freq))`, multiplied by
            `attention_scaling` and cast to `x.dtype`.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # (batch, F, 1) @ (batch, 1, seq) -> (batch, F, seq), where F = len(inv_freq).
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Compute the angles in float32 with autocast disabled. "mps" is mapped to "cpu" for the autocast context
        # (presumably because autocast does not accept "mps" as a device_type in supported torch versions).
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the frequencies so they line up with the two halves rotated by `rotate_half`.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Post-processing factor returned by the RoPE init function.
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
def rotate_half(x):
    """Rotate half the hidden dims: split the last axis into `[a, b]` and return `[-b, a]`."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with precomputed rotary position embeddings.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine table, e.g. of shape `[batch_size, seq_len, head_dim]`.
        sin (`torch.Tensor`): Sine table with the same shape as `cos`.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused; accepted only for backward compatibility.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into `cos` and `sin` so they broadcast against `q`
            and `k`. Use 1 when `q`/`k` are laid out as `[batch_size, heads, seq_len, head_dim]`, and 2 when
            they are `[batch_size, seq_len, heads, head_dim]`.

    Returns:
        `tuple(torch.Tensor)`: the rotated `(q, k)` pair.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # x * cos + rotate_half(x) * sin, applied identically to queries and keys.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along dim 1, like `torch.repeat_interleave(x, dim=1, repeats=n_rep)`.

    Input `(batch, num_key_value_heads, seqlen, head_dim)` becomes
    `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`. For `n_rep == 1` the input is returned unchanged.
    """
    # Unpacking first also validates that the input is 4D, even when no repetition is needed.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reference (pure PyTorch) scaled dot-product attention with optional logit soft-capping.

    Args:
        module: Attention module; reads `head_dim` (only when `scaling` is None), `num_key_value_groups`
            and `training`.
        query, key, value: Tensors laid out as `(batch, heads, seq, head_dim)`; `key`/`value` carry the
            (possibly fewer) key/value heads and are repeated to match `query`.
        attention_mask: Additive mask, sliced to the key length before use; may be None.
        dropout: Dropout probability on the attention probabilities (active only in training mode).
        scaling: Logit scale; defaults to `head_dim ** -0.5`.
        softcap: If set, logits become `tanh(logits / softcap) * softcap`.
        **kwargs: Ignored (accepted for interface compatibility with other backends).

    Returns:
        `(attn_output, attn_weights)`, with `attn_output` transposed to `(batch, seq, heads, head_dim)`.
    """
    scale = module.head_dim**-0.5 if scaling is None else scaling

    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scale
    if softcap is not None:
        scores = torch.tanh(scores / softcap) * softcap
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax in float32 for stability, then back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
|
class Gemma3Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper, as used by Gemma3.

    Compared with vanilla multi-head attention, this module:

    * supports grouped-query attention (`num_key_value_heads` may be smaller than `num_attention_heads`; each
      key/value head is shared by `num_key_value_groups` query heads);
    * RMS-normalises queries and keys per head (`q_norm` / `k_norm`) before applying rotary embeddings;
    * uses a sliding window (`config.sliding_window`) on every layer except those where
      `(layer_idx + 1) % config.sliding_window_pattern == 0`, which attend globally;
    * scales logits by `config.query_pre_attn_scalar ** -0.5`.

    Args:
        config (`Gemma3TextConfig`): Text model configuration.
        layer_idx (`int`): Position of this layer in the decoder stack; also the index passed to the KV cache.
    """

    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
        super().__init__()
        # Every `sliding_window_pattern`-th layer (1-based) is global; all others use the sliding window.
        self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # Stored for config parity; it is not forwarded to the attention backend by this module.
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.sliding_window = config.sliding_window if self.is_sliding else None

        # Use the resolved `self.head_dim` (with its fallback) so the norms always match the projections.
        self.q_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.Tensor`): Input of shape `(batch_size, seq_len, hidden_size)`.
            position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`): Rotary `(cos, sin)` tables.
            attention_mask (`torch.Tensor`, *optional*): The 4D additive mask for eager/SDPA/flex, or the
                2D padding mask for `flash_attention_2`. May be `None`.
            past_key_value (`Cache`, *optional*): KV cache; its `update` is called with this layer's keys/values.
            cache_position (`torch.LongTensor`, *optional*): Positions of the current tokens, passed to the cache.
            **kwargs: Forwarded to the attention backend (e.g. `output_attentions`, flash-attention kwargs).

        Returns:
            `(attn_output, attn_weights)`: the projected output of shape `(batch_size, seq_len, hidden_size)` and
            whatever weights the backend returns (`eager_attention_forward` returns the softmax probabilities).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position and sliding_window are needed by the
            # static/hybrid caches.
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
                "sliding_window": self.sliding_window,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Pre-allocated caches are longer than the processed sequence; FA2 receives the 2D mask unchanged,
        # so its length tells how many cached key/value positions are actually populated.
        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
            seq_len = attention_mask.shape[-1]
            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # The mask is Optional (FA2 passes the user's mask through, which may be None); only cast it when present.
        if attention_mask is not None:
            attention_mask = attention_mask.to(query_states)

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
|
class Gemma3DecoderLayer(nn.Module):
    """
    One Gemma3 transformer block. It applies self-attention, then a gated MLP. Each sub-block is wrapped in an
    input-side RMSNorm and an output-side RMSNorm, followed by a residual addition.

    Args:
        config (`Gemma3TextConfig`): Text model configuration.
        layer_idx (`int`): Index of this layer. `Gemma3Attention` uses it to choose between sliding-window
            (local) and global attention.
    """

    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma3MLP(config)
        self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        # Mirrors the attention module's local/global decision so `forward` can adapt the mask.
        self.is_sliding = self.self_attn.is_sliding
        # Stored for every layer, but only read in `forward` when `self.is_sliding` is True.
        self.sliding_window = config.sliding_window

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: int = 0,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states: Input of shape `(batch_size, seq_len, hidden_size)`.
            position_embeddings_global: `(cos, sin)` used when this layer attends globally.
            position_embeddings_local: `(cos, sin)` used when this layer is a sliding-window layer.
            attention_mask: 4D additive mask (eager/SDPA/flex) or 2D padding mask (`flash_attention_2`).
            position_ids, output_attentions, use_cache: Forwarded to `self_attn` as keyword arguments, which
                passes them on to the attention backend.
            past_key_value: KV cache shared by all layers.
            cache_position: Positions of the current tokens; its length is used to size the sliding mask.
            last_cache_position: Used to offset the sliding-window mask slice (see below).

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        if self.is_sliding and attention_mask is not None:
            # During prefill the processed chunk may be longer than the sliding window.
            effective_seq_len = max(cache_position.shape[0], self.sliding_window)

            if self.config._attn_implementation == "flash_attention_2":
                # FA2 receives the 2D padding mask (batch, processed_tokens): keep at most the last
                # `effective_seq_len` columns.
                attention_mask = attention_mask[:, -effective_seq_len:]

            else:
                # 4D additive mask: also mask every key column that is `sliding_window` or more columns to the
                # left of the query row (torch.tril with a negative diagonal).
                min_dtype = torch.finfo(attention_mask.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                )
                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)

                # Keep `effective_seq_len` key columns. When the sequence has moved past the window
                # (`last_cache_position > effective_seq_len`), start the slice at that offset so the columns
                # presumably line up with the sliding-window cache contents — TODO confirm against HybridCache.
                offset = last_cache_position - effective_seq_len

                # Never negative: before reaching the window the slice starts at column 0.
                offset = max(0, offset)
                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Local (sliding) layers use the RoPE table built from `rope_local_base_freq`; global layers use the
        # main one.
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        # Gemma normalises the sub-block output before adding it to the residual stream.
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
|
|
|
|
|
# Shared class-level documentation prefix, injected via `@add_start_docstrings(...)` (see
# `Gemma3PreTrainedModel` and `Gemma3TextModel`). The literal below is rendered verbatim.
GEMMA3_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads

    etc.)



    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.

    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage

    and behavior.



    Parameters:

        config ([`Gemma3Config`]):

            Model configuration class with all the parameters of the model. Initializing with a config file does not

            load the weights associated with the model, only the configuration. Check out the

            [`~PreTrainedModel.from_pretrained`] method to load the model weights.

"""
|
|
|
|
|
@add_start_docstrings(
    "The bare Gemma3 Model outputting raw hidden-states without any specific head on top.",
    GEMMA3_START_DOCSTRING,
)
class Gemma3PreTrainedModel(PreTrainedModel):
    # Base class for the Gemma3 models: declares capability flags for the `PreTrainedModel` machinery and
    # provides the weight-initialisation hook. (A comment instead of a docstring keeps the decorator-built
    # `__doc__` unchanged.)
    config_class = Gemma3Config
    base_model_prefix = "language_model"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "Gemma3DecoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipEncoderLayer",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        # Initialise Linear/Conv2d layers (normal weights, zero bias) and embeddings (normal weights, zero
        # padding row). Modules of any other type are left untouched by this method; e.g. `Gemma3RMSNorm`
        # keeps the zeros it was constructed with.
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            # Composite configs keep the value on their text sub-config.
            std = self.config.text_config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
|
|
|
|
|
# Forward-argument documentation injected into `Gemma3TextModel.forward` via
# `@add_start_docstrings_to_model_forward`.
GEMMA3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`. Defaults to `cache_position` when not provided.

            [What are position IDs?](../glossary#position-ids)
        past_key_values ([`~cache_utils.Cache`], *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up
            sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous
            stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            This model expects a [`~cache_utils.Cache`] instance (typically a [`~cache_utils.HybridCache`], see our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache)); the legacy tuple format is not
            converted. If no `past_key_values` are passed and `use_cache=True` outside of training, a
            [`~cache_utils.HybridCache`] sized to the current input is created and returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
        last_cache_position (`int`, *optional*):
            Total key/value length used to offset the sliding-window attention mask. When not provided it is taken
            from the last dimension of a 2D `attention_mask`, or from `cache_position[-1]` for a 4D mask, and is 0
            when no `attention_mask` is given.
"""
|
|
|
|
|
@add_start_docstrings( |
|
"The bare Gemma3Text Model outputting raw hidden-states without any specific head on top.", |
|
GEMMA3_START_DOCSTRING, |
|
) |
|
class Gemma3TextModel(Gemma3PreTrainedModel): |
|
""" |
|
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma3TextDecoderLayer`] |
|
|
|
Args: |
|
config: Gemma3TextConfig |
|
""" |
|
|
|
config_class = Gemma3TextConfig |
|
|
|
def __init__(self, config: Gemma3TextConfig): |
|
super().__init__(config) |
|
self.padding_idx = config.pad_token_id |
|
self.vocab_size = config.vocab_size |
|
|
|
|
|
self.embed_tokens = Gemma3TextScaledWordEmbedding( |
|
config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 |
|
) |
|
self.layers = nn.ModuleList( |
|
[Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
|
) |
|
self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
self.rotary_emb = Gemma3RotaryEmbedding(config=config) |
|
self.gradient_checkpointing = False |
|
|
|
|
|
|
|
config = copy.deepcopy(config) |
|
config.rope_theta = config.rope_local_base_freq |
|
config.rope_scaling = {"rope_type": "default"} |
|
self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.embed_tokens |
|
|
|
def set_input_embeddings(self, value): |
|
self.embed_tokens = value |
|
|
|
    @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: Optional[int] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Runs the text decoder stack:
        #   1. embed the tokens (unless `inputs_embeds` is given);
        #   2. build the causal mask and both RoPE tables;
        #   3. apply every decoder layer, then the final norm.
        # Returns a `BaseModelOutputWithPast`, or its tuple form when `return_dict` is falsy.
        # (Comments instead of a docstring keep the decorator-built `__doc__` unchanged.)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # True both when neither input is given and when both are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Outside training, lazily create a HybridCache sized exactly to this input.
        if use_cache and past_key_values is None and not self.training:
            batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = HybridCache(
                self.config,
                max_batch_size=batch_size,
                max_cache_len=seq_len,
                dtype=inputs_embeds.dtype,
            )

        # Default positions continue right after whatever the cache already holds.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # `last_cache_position` may already be supplied by the caller (e.g. by generation code). Otherwise derive
        # it from the mask: the length of a 2D mask, or `cache_position[-1]` for a 4D mask. The `.item()` call
        # forces a device-to-host sync, which is presumably why callers are expected to pre-compute it.
        if last_cache_position is None:
            last_cache_position = 0
            if attention_mask is not None:
                last_cache_position = (
                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
                )
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds

        # Two RoPE tables are computed once and shared by all layers: global layers use the first and
        # sliding-window layers the second (see `Gemma3DecoderLayer.forward`).
        position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
        position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional arguments must follow `Gemma3DecoderLayer.forward`'s parameter order.
                # NOTE: `flash_attn_kwargs` are not forwarded on this path.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    position_embeddings_global,
                    position_embeddings_local,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    last_cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    position_embeddings_global=position_embeddings_global,
                    position_embeddings_local=position_embeddings_local,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    last_cache_position=last_cache_position,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # The final entry of `all_hidden_states` is the normalised output.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        return output if return_dict else output.to_tuple()
|
|
|
@torch.no_grad() |
|
def _update_causal_mask( |
|
self, |
|
attention_mask: torch.Tensor, |
|
input_tensor: torch.Tensor, |
|
cache_position: torch.Tensor, |
|
past_key_values: HybridCache, |
|
output_attentions: bool, |
|
): |
|
|
|
|
|
|
|
|
|
if self.config._attn_implementation == "flash_attention_2": |
|
return attention_mask |
|
|
|
dtype, device = input_tensor.dtype, input_tensor.device |
|
sequence_length = input_tensor.shape[1] |
|
if isinstance(past_key_values, (HybridCache, StaticCache)): |
|
target_length = past_key_values.get_max_cache_shape() |
|
else: |
|
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] |
|
|
|
|
|
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
|
attention_mask, |
|
sequence_length=sequence_length, |
|
target_length=target_length, |
|
dtype=dtype, |
|
device=device, |
|
cache_position=cache_position, |
|
batch_size=input_tensor.shape[0], |
|
) |
|
return causal_mask |
|
|
|
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Turn a 2D padding mask into a 4D additive causal mask of shape `(batch_size, 1, sequence_length, target_length)`.

    A mask that is already 4D is returned as-is.

    Args:
        attention_mask (`torch.Tensor` or `None`):
            Either a 2D mask `(batch_size, key_value_length)` where 0 marks padding, or an already-built 4D mask
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            Number of query positions being processed.
        target_length (`int`):
            Number of key positions. With a static cache this is the full cache size, so the not-yet-filled slots
            are also covered by the mask.
        dtype (`torch.dtype`):
            Dtype of the produced mask. Masked entries hold `torch.finfo(dtype).min`; visible entries hold 0.
        device (`torch.device`):
            Device on which the mask is created.
        cache_position (`torch.Tensor`):
            Absolute position of each query token. A key at index `j` is hidden from query `i` when
            `j > cache_position[i]`.
        batch_size (`int`):
            Size of the leading batch dimension of the result.
        **kwargs:
            Accepted and ignored.
    """
    # An already-4D mask is assumed to be ready to use.
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    neg_fill = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), fill_value=neg_fill, dtype=dtype, device=device)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)
    # Keep the fill value only at key positions strictly after each query's absolute position.
    is_future = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    mask *= is_future
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is None:
        return mask

    # Copy before writing: the expanded view shares storage across the batch dimension.
    mask = mask.clone()
    mask_length = attention_mask.shape[-1]
    visible_region = mask[:, :, :, :mask_length]
    # A sum of exactly 0 means "causally visible AND padding" -> hide that key.
    is_padding = (visible_region + attention_mask[:, None, None, :].to(mask.device)) == 0
    mask[:, :, :, :mask_length] = visible_region.masked_fill(is_padding, neg_fill)
    return mask
|
|
|
|
|
class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
    """Text-only Gemma3: a `Gemma3TextModel` decoder followed by a bias-free linear language-modeling head."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config_class = Gemma3TextConfig
    base_model_prefix = "language_model"

    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)
        self.model = Gemma3TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Run the PreTrainedModel post-construction hook.
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            logits_to_keep (`int` or `torch.Tensor`, *optional*):
                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
                This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma3ForCausalLM

        >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""

        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train Gemma3 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **loss_kwargs,
        )

        hidden_states = outputs[0]

        # An int keeps the trailing `logits_to_keep` positions (0 -> slice(0, None) -> all); a tensor indexes directly.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Optional soft-capping: cap * tanh(logits / cap) bounds logits to (-cap, cap).
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        """
        Assemble the model inputs for one generation step.

        Delegates to the parent implementation, then:
          - records `last_cache_position` (the last dimension of the 2D `attention_mask`, or 0 without a mask);
          - drops `logits_to_keep` when it is None;
          - with a `HybridCache`, a 2D mask and a non-flash backend, pre-expands the mask to 4D so its key
            dimension spans the whole cache.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        # Forwarded to the text model, which passes `last_cache_position` on to its decoder layers.
        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
        if logits_to_keep is None:
            _ = model_inputs.pop("logits_to_keep", None)

        # The `is not None` guard prevents an AttributeError when generating with a HybridCache but no mask.
        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask is not None
            and attention_mask.ndim == 2
            and self.config._attn_implementation != "flash_attention_2"
        ):
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.lm_head.weight.dtype,
                device=device,
                cache_position=cache_position,
                batch_size=batch_size,
            )
            model_inputs["attention_mask"] = attention_mask

        return model_inputs
|
|
|
|
|
class Gemma3MultiModalProjector(nn.Module):
    """Average-pool the vision tower's patch grid, RMS-normalize it, and project it to the text hidden size."""

    def __init__(self, config: Gemma3Config):
        super().__init__()

        vision_cfg = config.vision_config
        text_hidden = config.text_config.hidden_size

        # (vision_hidden, text_hidden) projection matrix, zero-initialized here.
        self.mm_input_projection_weight = nn.Parameter(torch.zeros(vision_cfg.hidden_size, text_hidden))

        self.mm_soft_emb_norm = Gemma3RMSNorm(vision_cfg.hidden_size, eps=vision_cfg.layer_norm_eps)

        # Side of the patch grid, side of the target token grid, and the pooling window mapping one onto the other.
        self.patches_per_image = int(vision_cfg.image_size // vision_cfg.patch_size)
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        """
        Args:
            vision_outputs (`torch.Tensor`):
                3D tensor `(batch_size, num_patches, vision_hidden)`. The reshape below requires
                `num_patches == patches_per_image ** 2`.

        Returns:
            `torch.Tensor` of shape `(batch_size, pooled_tokens, text_hidden)`, cast to the dtype of
            `vision_outputs`.
        """
        batch_size, _, channels = vision_outputs.shape
        side = self.patches_per_image

        # (B, N, C) -> (B, C, side, side) so the patch grid can be pooled spatially.
        patch_grid = vision_outputs.transpose(1, 2).reshape(batch_size, channels, side, side).contiguous()

        # Pool, then (B, C, h, w) -> (B, C, h*w) -> (B, h*w, C).
        pooled_tokens = self.avg_pool(patch_grid).flatten(2).transpose(1, 2)

        normalized_tokens = self.mm_soft_emb_norm(pooled_tokens)
        projected_tokens = torch.matmul(normalized_tokens, self.mm_input_projection_weight)
        return projected_tokens.type_as(vision_outputs)
|
|
|
|
|
|
|
@add_start_docstrings(
    """The GEMMA3 model which consists of a vision backbone and a language model.""",
    GEMMA3_START_DOCSTRING,
)
class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
    def __init__(self, config: Gemma3Config):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)

        # ConformerEncoder is built from keyword arguments. Strip serialization-only keys so they are not
        # forwarded to its constructor.
        audio_config = config.audio_config.to_diff_dict()
        for item in ("transformers_version", "model_type", "torch_dtype"):
            audio_config.pop(item, None)
        self.audio_tower = ConformerEncoder(**audio_config)
        self.audio_tower.post_init({})
        # Two-layer MLP mapping conformer outputs (attention_dim) to the text model's hidden size.
        self.audio_projector = nn.Sequential(
            nn.Linear(
                in_features=config.audio_config.attention_dim,
                out_features=config.text_config.hidden_size,
                bias=True,
            ),
            nn.GELU(approximate="none"),
            nn.Linear(
                in_features=config.text_config.hidden_size,
                out_features=config.text_config.hidden_size,
                bias=True,
            ),
        ).to(dtype=self.dtype)

        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModelForCausalLM.from_config(config=config.text_config)

        # Re-export the language model's tied weights under this model's attribute prefix.
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
        self.language_model = language_model

        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _update_causal_mask(
        self,
        attention_mask,
        token_type_ids,
        past_key_values,
        cache_position,
        input_tensor,
        is_training: bool = False,
    ):
        """
        Build the 4D additive attention mask for the multimodal sequence.

        `attention_mask` is returned unchanged with flash-attention-2 or when it is already 4D. Otherwise the
        function builds a causal mask `(batch, 1, sequence_length, target_length)` holding 0 (visible) or the dtype
        minimum (hidden). It then applies two adjustments:
          - if `token_type_ids` is given and more than one token is processed, query/key pairs sharing the same
            non-zero token type are made mutually visible;
          - keys where the 2D `attention_mask` is 0 are hidden.

        `is_training` is accepted for interface compatibility and is not read.
        """
        if self.config.text_config._attn_implementation == "flash_attention_2":
            return attention_mask

        # A caller-provided 4D mask is assumed to be complete.
        if attention_mask is not None and attention_mask.dim() == 4:
            return attention_mask

        min_dtype = torch.finfo(self.dtype).min
        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
        if isinstance(past_key_values, (StaticCache, HybridCache)):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else cache_position[0] + sequence_length + 1
            )

        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
        )

        # Hide every key positioned after the query's absolute cache position.
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)

        # Bidirectional attention among tokens that share a non-zero token type.
        if token_type_ids is not None and sequence_length != 1:
            token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
            # Query rows for text tokens (type 0) keep the plain causal pattern.
            token_type_mask[token_type_ids == 0] = False
            token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
            causal_mask = causal_mask.clone()
            causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
                token_type_mask, 0.0
            )

        if attention_mask is not None:
            # Copy before writing: the expanded view shares storage across the batch dimension.
            causal_mask = causal_mask.clone()
            mask_length = attention_mask.shape[-1]
            # A sum of exactly 0 means "otherwise visible AND padding" -> hide that key.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        return causal_mask

    def get_image_features(self, pixel_values: torch.Tensor):
        """
        Projects the last hidden state from the vision model into language model space.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    def get_audio_features(
        self,
        input_audio_embeds: torch.FloatTensor,
        audio_attention_mask: torch.FloatTensor,
        audio_embed_sizes: torch.FloatTensor,
    ):
        """
        Encode audio with the conformer tower and project the result into language model space.

        Args:
            input_audio_embeds (`torch.FloatTensor`):
                Audio input features passed to `self.audio_tower`; presumably `(batch_size, frames, feature_dim)`.
            audio_attention_mask (`torch.FloatTensor`):
                Mask passed to `self.audio_tower` together with the features.
            audio_embed_sizes (`torch.FloatTensor`):
                Not used here; `forward` uses it afterwards to trim per-sample padding from the returned features.

        Returns:
            `torch.Tensor`: `self.audio_projector` applied to the first output of the audio tower; the last
            dimension equals `config.text_config.hidden_size`.
        """
        audio_features, _ = self.audio_tower(input_audio_embeds, audio_attention_mask)
        return self.audio_projector(audio_features)

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        input_audio_embeds: torch.FloatTensor = None,
        audio_embed_sizes: torch.FloatTensor = None,
        audio_attention_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
        r"""
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

            logits_to_keep (`int` or `torch.Tensor`, *optional*):
                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
                This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
        >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")

        >>> prompt = "answer en Where is the cow standing?"
        >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "answer en Where is the cow standing?\nbeach"
        ```"""

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        is_training = token_type_ids is not None and labels is not None

        # Multimodal placeholder ids outside the text vocabulary cannot be embedded. Replace them with 0 before the
        # lookup; those positions are overwritten below by the image/audio features. The parentheses matter: without
        # them `input_ids is None` combined with an out-of-vocab audio id would reach `None.clone()`.
        if input_ids is not None and (
            self.config.image_token_index >= self.vocab_size or self.config.audio_token_index >= self.vocab_size
        ):
            llm_input_ids = input_ids.clone()
            llm_input_ids[input_ids == self.config.image_token_index] = 0
            llm_input_ids[input_ids == self.config.audio_token_index] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # Position ids are offset by one relative to cache_position (prepare_inputs_for_generation applies the same +1).
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0) + 1

        # Scatter projected image features into the image-token positions.
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)

            if input_ids is None:
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
                )
            else:
                special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # Scatter projected audio features into the audio-token positions.
        if input_audio_embeds is not None:
            audio_features = self.get_audio_features(input_audio_embeds, audio_attention_mask, audio_embed_sizes)
            if input_ids is None:
                special_audio_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.audio_token_index, dtype=torch.long, device=inputs_embeds.device)
                )
            else:
                special_audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1)
                special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            # Keep only the first `size` frames of each sample, then flatten to (total_audio_tokens, hidden).
            masked_audio_features = torch.cat(
                [audio_features[i, :size, :] for i, size in enumerate(audio_embed_sizes)], dim=0
            )

            if (
                not is_torchdynamo_compiling()
                and inputs_embeds[special_audio_mask].numel() != masked_audio_features.numel()
            ):
                audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
                # Report the row count of the trimmed features. The former `audio_embed_sizes.sum()[0]` indexed a
                # 0-dim tensor and raised IndexError instead of this message.
                raise ValueError(
                    "Number of audio features does not match number of special audio tokens in the input text. "
                    f"Got {audio_tokens_in_text} audio tokens in the text but {masked_audio_features.shape[0]} "
                    "tokens from audio embeddings."
                )
            masked_audio_features = masked_audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, masked_audio_features)

        if labels is not None and self.pad_token_id in labels:
            logger.warning_once(
                "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
                "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
            )
            labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)

        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
        )
        outputs = self.language_model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **lm_kwargs,
        )

        logits = outputs.logits
        loss = None
        if labels is not None:
            # Upcast to float before computing the loss.
            logits = logits.float()
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # Align the mask with the shifted logits (trailing positions) and keep only non-padding positions.
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()

            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Gemma3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
            audio_hidden_states=audio_features if input_audio_embeds is not None else None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        input_audio_embeds=None,
        audio_embed_sizes=None,
        audio_attention_mask=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
        """
        Assemble per-step generation inputs.

        Text inputs come from the language model's `prepare_inputs_for_generation`, with `position_ids` shifted by
        one. Image/audio inputs are attached only on the first step (`cache_position[0] == 0`). On that step, with a
        `HybridCache`, the multimodal 4D mask is also built here.
        """
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )

        # Same +1 position offset as in `forward`.
        if model_inputs.get("position_ids") is not None:
            model_inputs["position_ids"] += 1

        # Multimodal features are only scattered in on the prefill step; later steps reuse the cache.
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["input_audio_embeds"] = input_audio_embeds
            model_inputs["audio_embed_sizes"] = audio_embed_sizes
            model_inputs["audio_attention_mask"] = audio_attention_mask
        is_training = token_type_ids is not None and labels is not None
        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
            causal_mask = self._update_causal_mask(
                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
            )
            model_inputs["attention_mask"] = causal_mask

        return model_inputs

    def tie_weights(self):
        return self.language_model.tie_weights()
|
|
|
|
|
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["Gemma3PreTrainedModel", "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"]
|
|