Transformers documentation

Gemma3n

You are viewing v4.53.1 version. A newer version v4.57.1 is available.

Overview

Gemma 3n is a multimodal model with pretrained and instruction-tuned variants, available in E4B and E2B sizes. While large portions of the language model architecture are shared with prior Gemma releases, this model adds several new components, including Alternating Updates (AltUp), Learned Augmented Residual Layer (LAuReL), MatFormer, Per-Layer Embeddings (PLE), activation sparsity, and KV cache sharing. The language model uses an attention pattern similar to Gemma 3, alternating four local sliding window self-attention layers with one global self-attention layer, and supports a maximum context length of 32k tokens. Gemma 3n introduces MobileNet v5 as the vision encoder, using a default resolution of 768x768 pixels, and adds a Universal Speech Model (USM) as the audio encoder.

The instruction-tuned variant was post-trained with knowledge distillation and reinforcement learning.

You can find all the original Gemma 3n checkpoints under the Gemma 3n release.

Click on the Gemma 3n models in the right sidebar for more examples of how to apply Gemma to different vision, audio, and language tasks.

The example below demonstrates how to generate text based on an image with Pipeline.

import torch
from transformers import pipeline

# Use a distinct variable name so the imported pipeline() function is not shadowed.
pipe = pipeline(
    task="image-text-to-text",
    model="google/gemma-3n-e4b",
    device=0,  # index of the GPU to run on
    torch_dtype=torch.bfloat16,
)

# Pass the image URL first, then the text prompt.
pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    text="<start_of_image> What is shown in this image?",
)

Notes

  • Use Gemma3nForConditionalGeneration for image-audio-and-text, image-and-text, image-and-audio, audio-and-text, image-only and audio-only inputs.

  • Gemma 3n supports multiple images per input, but make sure the images are correctly batched before passing them to the processor. Each batch should be a list of one or more images.

    url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
    url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
    
    messages = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are a helpful assistant."}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "url": url_cow},
                {"type": "image", "url": url_cat},
                {"type": "text", "text": "Which image is cuter?"},
            ]
        },
    ]
  • Text passed to the processor should have an <image_soft_token> token wherever an image should be inserted.

  • Gemma 3n accepts at most one target audio clip per input, though multiple audio clips can be provided, for example in few-shot prompts.

  • Text passed to the processor should have an <audio_soft_token> token wherever an audio clip should be inserted.

  • The processor has its own apply_chat_template() method to convert chat messages to model inputs.
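
As a rough illustration of the batching note above, the sketch below builds a chat-style message list for one or more images. The helper name build_messages and the example.com URLs are hypothetical and not part of Transformers; only the message layout is taken from the example above.

```python
def build_messages(image_urls, question, system_prompt="You are a helpful assistant."):
    """Build a chat-style message list with one or more images and a text question."""
    if not image_urls:
        raise ValueError("expected at least one image URL")
    # One {"type": "image"} entry per image, followed by the text question.
    user_content = [{"type": "image", "url": url} for url in image_urls]
    user_content.append({"type": "text", "text": question})
    return [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": user_content},
    ]


messages = build_messages(
    ["https://example.com/cow.jpg", "https://example.com/cat.jpg"],  # placeholder URLs
    "Which image is cuter?",
)
print(len(messages[1]["content"]))  # two image entries plus one text entry
```

A list built this way has the same shape as the messages example above, so it could be handed to the processor's apply_chat_template() method in the same manner.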

Gemma3nAudioFeatureExtractor

class transformers.Gemma3nAudioFeatureExtractor

( feature_size: int = 128 sampling_rate: int = 16000 padding_value: float = 0.0 return_attention_mask: bool = True frame_length_ms: float = 32.0 hop_length_ms: float = 10.0 min_frequency: float = 125.0 max_frequency: float = 7600.0 preemphasis: float = 0.97 preemphasis_htk_flavor: bool = True fft_overdrive: bool = True dither: float = 0.0 input_scale_factor: float = 1.0 mel_floor: float = 1e-05 per_bin_mean: typing.Optional[collections.abc.Sequence[float]] = None per_bin_stddev: typing.Optional[collections.abc.Sequence[float]] = None **kwargs )

Parameters

  • feature_size (int, optional, defaults to 128) — The feature dimension of the extracted features.
  • sampling_rate (int, optional, defaults to 16000) — The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
  • padding_value (float, optional, defaults to 0.0) — Padding value used to pad the audio. Should correspond to silences.
  • return_attention_mask (bool, optional, defaults to True) — Whether to return the attention mask for the generated MEL spectrograms.
  • frame_length_ms (float, optional, defaults to 32.0) — The length of a frame in milliseconds.
  • hop_length_ms (float, optional, defaults to 10.0) — The hop length in milliseconds, i.e. the step between successive overlapping windows of the STFT used to obtain the Mel frequency coefficients.
  • min_frequency (float, optional, defaults to 125.0) — The minimum frequency (in Hz) for the Mel filterbank.
  • max_frequency (float, optional, defaults to 7600.0) — The maximum frequency (in Hz) for the Mel filterbank.
  • preemphasis (float, optional, defaults to 0.97) — The preemphasis coefficient.
  • preemphasis_htk_flavor (bool, optional, defaults to True) — Whether to use HTK-style preemphasis.
  • fft_overdrive (bool, optional, defaults to True) — Whether to use FFT overdrive.
  • dither (float, optional, defaults to 0.0) — Adds dithering. In other words, adds a small Gaussian noise to each frame. E.g. use 0.0001 to add dithering with a normal distribution centered around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech). The value 0.0 means no dithering. Dithering has a similar effect to spectrogram(mel_floor=...). It reduces the high log_mel_fbank values for signals with hard-zero sections, when a VAD cutoff is present in the signal.
  • input_scale_factor (float, optional, defaults to 1.0) — Scaling factor applied to the input waveform.
  • mel_floor (float, optional, defaults to 1e-05) — Minimum value for Mel spectrograms to avoid log(0).
  • per_bin_mean (Optional[Sequence[float]], optional) — Mean values for per-bin normalization.
  • per_bin_stddev (Optional[Sequence[float]], optional) — Standard deviation values for per-bin normalization.

An audio feature extractor for Universal Speech Models (https://arxiv.org/abs/2303.01037).
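
To make the millisecond-based defaults more concrete, the following sketch converts frame_length_ms and hop_length_ms into sample counts at the default sampling rate. The helper names are hypothetical, and the frame-count formula is the common no-padding convention, assumed here for illustration; the library's exact framing may differ.

```python
def ms_to_samples(duration_ms: float, sampling_rate: int = 16000) -> int:
    """Convert a duration in milliseconds to a whole number of samples."""
    return int(round(sampling_rate * duration_ms / 1000.0))


frame_length = ms_to_samples(32.0)  # default frame_length_ms
hop_length = ms_to_samples(10.0)    # default hop_length_ms


def num_frames(num_samples: int) -> int:
    """Frames that fit without padding (assumed convention, not the library's exact rule)."""
    if num_samples < frame_length:
        return 0
    return 1 + (num_samples - frame_length) // hop_length


print(frame_length, hop_length, num_frames(16000))  # 512 160 97
```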

Gemma3nProcessor

class transformers.Gemma3nProcessor

( feature_extractor image_processor tokenizer chat_template = None audio_seq_length: int = 188 image_seq_length: int = 256 **kwargs )

Parameters

  • feature_extractor (Gemma3nAudioFeatureExtractor) — Feature extractor that converts raw audio waveforms into MEL spectrograms for the audio encoder. This should return a BatchFeature with input_features and input_features_mask features.
  • image_processor (SiglipImageProcessorFast) — Image processor that prepares batches of images for the vision encoder. This should return a BatchFeature with a pixel_values feature.
  • tokenizer (GemmaTokenizerFast) — The text tokenizer for the model.
  • chat_template (string, optional) — A Jinja template for generating text prompts from a set of messages.
  • audio_seq_length (int, optional, defaults to 188) — The number of audio soft tokens that will be added to the text prompt.
  • image_seq_length (int, optional, defaults to 256) — The number of image soft tokens that should be added to the text prompt.

A processor for Gemma 3n, wrapping the full capabilities of a feature extractor, image processor, and tokenizer into a single processor.
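
The role of audio_seq_length and image_seq_length can be pictured as reserving a fixed number of prompt positions per placeholder. The sketch below is a simplified illustration under that assumption: the function expand_placeholders is hypothetical, and the real processor may also add begin and end markers or other formatting that is not shown here.

```python
IMAGE_PLACEHOLDER = "<image_soft_token>"
AUDIO_PLACEHOLDER = "<audio_soft_token>"


def expand_placeholders(text: str, image_seq_length: int = 256, audio_seq_length: int = 188) -> str:
    """Repeat each placeholder so the prompt reserves one position per soft token."""
    # str.replace makes a single pass, so the inserted copies are not expanded again.
    text = text.replace(IMAGE_PLACEHOLDER, IMAGE_PLACEHOLDER * image_seq_length)
    text = text.replace(AUDIO_PLACEHOLDER, AUDIO_PLACEHOLDER * audio_seq_length)
    return text


expanded = expand_placeholders("Describe <image_soft_token> and transcribe <audio_soft_token>")
print(expanded.count(IMAGE_PLACEHOLDER), expanded.count(AUDIO_PLACEHOLDER))  # 256 188
```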

batch_decode

( *args **kwargs )

This method forwards all its arguments to GemmaTokenizerFast’s batch_decode(). Please refer to the docstring of this method for more information.

decode

( *args **kwargs )

This method forwards all its arguments to GemmaTokenizerFast’s decode(). Please refer to the docstring of this method for more information.

Gemma3nTextConfig

class transformers.Gemma3nTextConfig

( vocab_size: int = 262400 vocab_size_per_layer_input: int = 262144 hidden_size: int = 2048 hidden_size_per_layer_input: int = 256 intermediate_size: typing.Union[int, collections.abc.Sequence[int]] = 16384 num_hidden_layers: int = 35 num_attention_heads: int = 8 num_key_value_heads: int = 2 head_dim: int = 256 hidden_activation: str = 'gelu_pytorch_tanh' max_position_embeddings: int = 32768 initializer_range: float = 0.02 rms_norm_eps: float = 1e-06 use_cache: bool = True pad_token_id: int = 0 eos_token_id: int = 1 bos_token_id: int = 2 rope_theta: float = 1000000.0 rope_scaling: typing.Optional[dict[str, typing.Any]] = None rope_local_base_freq: float = 10000.0 attention_bias: bool = False attention_dropout: float = 0.0 sliding_window: int = 512 layer_types: typing.Optional[collections.abc.Sequence[str]] = None final_logit_softcapping: float = 30.0 altup_active_idx: int = 0 altup_coef_clip: float = 120.0 altup_correct_scale: bool = True altup_num_inputs: int = 4 num_kv_shared_layers: int = 15 laurel_rank: int = 64 activation_sparsity_pattern: typing.Union[float, collections.abc.Sequence[float], NoneType] = (0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) **kwargs )

Parameters

  • vocab_size (int, optional, defaults to 262400) — Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented by the input_ids passed when calling Gemma3nTextModel.
  • vocab_size_per_layer_input (int, optional, defaults to 262144) — Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
  • hidden_size (int, optional, defaults to 2048) — Dimension of the hidden representations.
  • hidden_size_per_layer_input (int, optional, defaults to 256) — Dimension of the hidden representations for per-layer embeddings.
  • intermediate_size (int or Sequence[int], optional, defaults to 16384) — Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers to account for variable intermediate_size values across layers. In such cases, len(intermediate_size) == num_hidden_layers.
  • num_hidden_layers (int, optional, defaults to 35) — Number of hidden layers in the Transformer decoder.
  • num_attention_heads (int, optional, defaults to 8) — Number of attention heads for each attention layer in the Transformer decoder.
  • num_key_value_heads (int, optional, defaults to 2) — The number of key-value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, it uses Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group. For more details, see the GQA paper (https://arxiv.org/abs/2305.13245). If not specified, will default to num_attention_heads.
  • head_dim (int, optional, defaults to 256) — The attention head dimension.
  • hidden_activation (str or function, optional, defaults to "gelu_pytorch_tanh") — The non-linear activation function (function or string) in the decoder. Will default to "gelu_pytorch_tanh" if not specified. "gelu_pytorch_tanh" uses an approximation of the "gelu" activation function.
  • max_position_embeddings (int, optional, defaults to 32768) — The maximum sequence length that this model might ever be used with.
  • initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • rms_norm_eps (float, optional, defaults to 1e-06) — The epsilon used by the rms normalization layers.
  • use_cache (bool, optional, defaults to True) — Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
  • pad_token_id (int, optional, defaults to 0) — Padding token id.
  • eos_token_id (int, optional, defaults to 1) — End of stream token id.
  • bos_token_id (int, optional, defaults to 2) — Beginning of stream token id.
  • rope_theta (float, optional, defaults to 1000000.0) — The base period of the RoPE embeddings.
  • rope_scaling (Dict, optional) — Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply a new rope type and you expect the model to work on a longer max_position_embeddings, we recommend you update this value accordingly. Expected contents:

    • rope_type (str): The sub-variant of RoPE to use. Can be one of [‘default’, ‘linear’, ‘dynamic’, ‘yarn’, ‘longrope’, ‘llama3’], with ‘default’ being the original RoPE implementation.
    • factor (float, optional): Used with all rope types except ‘default’. The scaling factor to apply to the RoPE embeddings. In most scaling types, a factor of x will enable the model to handle sequences of length x * original maximum pre-trained length.
    • original_max_position_embeddings (int, optional): Used with ‘dynamic’, ‘longrope’ and ‘llama3’. The original max position embeddings used during pretraining.
    • attention_factor (float, optional): Used with ‘yarn’ and ‘longrope’. The scaling factor to be applied on the attention computation. If unspecified, it defaults to the value recommended by the implementation, using the factor field to infer the suggested value.
    • beta_fast (float, optional): Only used with ‘yarn’. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32.
    • beta_slow (float, optional): Only used with ‘yarn’. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1.
    • short_factor (List[float], optional): Only used with ‘longrope’. The scaling factor to be applied to short contexts (< original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
    • long_factor (List[float], optional): Only used with ‘longrope’. The scaling factor to be applied to long contexts (> original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
    • low_freq_factor (float, optional): Only used with ‘llama3’. Scaling factor applied to low frequency components of the RoPE.
    • high_freq_factor (float, optional): Only used with ‘llama3’. Scaling factor applied to high frequency components of the RoPE.

  • rope_local_base_freq (float, optional, defaults to 10000.0) — The base period of the RoPE embeddings for local attention.
  • attention_bias (bool, optional, defaults to False) — Whether to use a bias in the query, key, value and output projection layers during self-attention.
  • attention_dropout (float, optional, defaults to 0.0) — The dropout ratio for the attention probabilities.
  • sliding_window (int, optional, defaults to 512) — This is the size of the sliding window used by local attention layers.
  • layer_types (Sequence[str], optional) — A sequence of strings defining the attention type for each layer as either “sliding_attention” or “full_attention”. If not provided, layer_types will be inferred from num_hidden_layers using a pattern of four “sliding_attention” layers followed by one “full_attention” layer. The last layer in the model should always be a “full_attention” layer.
  • final_logit_softcapping (float, optional, defaults to 30.0) — Scaling factor when applying tanh softcapping on the logits.
  • altup_active_idx (int, optional, defaults to 0) — The index of the prediction from which AltUp will compute additional predictions or correct
  • altup_coef_clip (float, optional, defaults to 120.0) — The maximum amplitude of an AltUp prediction or correction coefficient weight.
  • altup_correct_scale (bool, optional, defaults to True) — If True, apply the AltUp.correct_output_scale to the corrected prediction at altup_active_idx.
  • altup_num_inputs (int, optional, defaults to 4) — The number of predictions that AltUp should make given the input sequence.
  • num_kv_shared_layers (int, optional, defaults to 15) — The number of layers that share KV cache values. During the forward pass, the last num_kv_shared_layers layers in the model “share” the KV values in that each local and global layer in this range uses the KV cache values computed for the last local or global layer, respectively, before entering this range. The value of num_kv_shared_layers should be a multiple of sliding_window_pattern.
  • laurel_rank (int, optional, defaults to 64) — The intermediate size for the linear projections in the Learned Augmented Residual Layer.
  • activation_sparsity_pattern (Sequence[float], optional, defaults to (0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) — The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must explicitly provide a sparsity value for each layer in the model.
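
The default layer_types pattern described above (four sliding layers followed by one full layer) can be sketched as follows. This is an illustrative reconstruction from the parameter description, not the library's code; the function name infer_layer_types and its pattern argument are hypothetical.

```python
def infer_layer_types(num_hidden_layers: int = 35, pattern: int = 5) -> list[str]:
    """Every `pattern`-th layer uses full attention; all others use sliding attention."""
    return [
        "full_attention" if (i + 1) % pattern == 0 else "sliding_attention"
        for i in range(num_hidden_layers)
    ]


layer_types = infer_layer_types()
print(layer_types[:5])
print(layer_types.count("full_attention"), layer_types[-1])  # 7 full_attention
```

With the default of 35 layers, the last layer falls on a multiple of five, which is consistent with the requirement that the final layer be a “full_attention” layer.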

This is the configuration class to store the configuration of a Gemma3nTextModel. It is used to instantiate a Gemma3nTextModel model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g. google/gemma-3n-E4B.
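
For readers unfamiliar with tanh softcapping, the sketch below shows the usual form of the operation with the default final_logit_softcapping of 30.0: logits are squashed smoothly into the open interval (-cap, cap). It operates on plain floats for illustration and assumes the standard cap * tanh(x / cap) formulation; the function name softcap is hypothetical.

```python
import math


def softcap(logit: float, cap: float = 30.0) -> float:
    """Smoothly bound a logit to (-cap, cap) with cap * tanh(logit / cap)."""
    return cap * math.tanh(logit / cap)


for raw in (1.0, 30.0, 300.0):
    # Small logits pass through almost unchanged; large ones saturate near the cap.
    print(raw, round(softcap(raw), 3))
```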

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
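
The KV cache sharing rule in the num_kv_shared_layers description can be sketched as a mapping from each shared layer to the earlier layer whose KV values it reuses. This is an illustrative reading of that description, assuming the default four-sliding-then-one-full layer pattern; the function name kv_source_layers is hypothetical.

```python
def kv_source_layers(layer_types, num_kv_shared_layers=15):
    """Map each shared layer index to the earlier layer whose KV values it reuses."""
    first_shared = len(layer_types) - num_kv_shared_layers
    # Track the last layer of each attention type before the shared range begins.
    last_of_type = {}
    for idx in range(first_shared):
        last_of_type[layer_types[idx]] = idx
    return {
        idx: last_of_type[layer_types[idx]]
        for idx in range(first_shared, len(layer_types))
    }


# Assumed default pattern: four sliding layers followed by one full layer, 35 layers total.
layer_types = [
    "full_attention" if (i + 1) % 5 == 0 else "sliding_attention" for i in range(35)
]
sources = kv_source_layers(layer_types)
print(sources[20], sources[24])  # 18 19
```

Under these assumptions, layers 20 through 34 compute no KV values of their own: sliding layers reuse layer 18 and full layers reuse layer 19.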

>>> from transformers import Gemma3nTextModel, Gemma3nTextConfig

>>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
>>> configuration = Gemma3nTextConfig()

>>> # Initializing a model from the gemma3n_text-E4B style configuration
>>> model = Gemma3nTextModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
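
The activation_sparsity_pattern values can be read as the fraction of activations to suppress in each layer. One way to approximate a top-k selection without sorting is a Gaussian-quantile cutoff, sketched below on plain Python lists. Treat this as an assumption made for illustration rather than a description of the library's implementation; the function name sparsify is hypothetical.

```python
from statistics import NormalDist, fmean, pstdev


def sparsify(activations, target_sparsity):
    """Zero out roughly `target_sparsity` of the values using a Gaussian-quantile cutoff."""
    if target_sparsity <= 0.0:
        return list(activations)  # a sparsity of 0.0 leaves the layer dense
    # z-score below which `target_sparsity` of a normal distribution falls
    z = NormalDist().inv_cdf(target_sparsity)
    cutoff = fmean(activations) + z * pstdev(activations)
    # Shift by the cutoff and clamp at zero, like a ReLU applied to (x - cutoff).
    return [max(a - cutoff, 0.0) for a in activations]


values = [float(i) for i in range(100)]
kept = sum(1 for v in sparsify(values, 0.95) if v > 0.0)
print(kept)  # only a handful of the largest values stay non-zero
```

With the default pattern, only the first ten layers use a sparsity of 0.95; the remaining layers use 0.0 and are left dense.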

Gemma3nVisionConfig

class transformers.Gemma3nVisionConfig

( initializer_range: float = 0.02 do_pooling: bool = False architecture: str = 'mobilenetv5_300m_enc' hidden_size: int = 2048 vocab_size: int = 128 vocab_offset: int = 262144 rms_norm_eps: float = 1e-06 model_args: typing.Optional[dict] = None **kwargs )

Parameters

  • initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • do_pooling (bool, optional, defaults to False) — Whether to do pooling for the last_hidden_state in TimmWrapper or not.
  • architecture (str, optional, defaults to "mobilenetv5_300m_enc") — Determines vision architecture for TimmWrapper.
  • hidden_size (int, optional, defaults to 2048) — Dimension of the hidden representations.
  • vocab_size (int, optional, defaults to 128) — Vocabulary size of the additional hard-token embeddings for vision model.
  • vocab_offset (int, optional, defaults to 262144) — Offset between the tokenizer vocab index for the token ids embedded by Gemma3nMultimodalEmbedder and the 0-indexed Gemma3nMultimodalEmbedder.embedding table.
  • rms_norm_eps (float, optional, defaults to 1e-06) — The epsilon used by the rms normalization layers.

This is the configuration class to store the configuration for a timm backbone TimmWrapper. It is used to instantiate a timm model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B vision tower, e.g. google/gemma-3n-E4B.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

The config loads ImageNet label descriptions and stores them in the id2label attribute. The label2id attribute for default ImageNet models is set to None due to occlusions in the label descriptions.

Example:

>>> from transformers import Gemma3nVisionConfig, TimmWrapperModel

>>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
>>> configuration = Gemma3nVisionConfig()

>>> # Initializing a gemma3n_vision-E4B-style TimmWrapperModel from the configuration
>>> model = TimmWrapperModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

Gemma3nAudioConfig

class transformers.Gemma3nAudioConfig

( vocab_size: int = 128 vocab_offset: int = 262272 input_feat_size: int = 128 hidden_size: int = 1536 rms_norm_eps: float = 1e-06 gradient_clipping: float = 10000000000.0 conf_attention_chunk_size: int = 12 conf_attention_context_left: int = 13 conf_attention_context_right: int = 0 conf_attention_logit_cap: float = 50.0 conf_num_attention_heads: int = 8 conf_num_hidden_layers: int = 12 conf_conv_kernel_size: int = 5 conf_reduction_factor: int = 4 conf_residual_weight: float = 0.5 sscp_conv_channel_size: tuple = (128, 32) sscp_conv_group_norm_eps: float = 0.001 sscp_conv_kernel_size: tuple = ((3, 3), (3, 3)) sscp_conv_stride_size: tuple = ((2, 2), (2, 2)) **kwargs )

Parameters

  • vocab_size (int, optional, defaults to 128) — Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings included in the Gemma3nTextModel to provide, e.g., the end of audio and audio soft token placeholder tokens when converting input_ids to embeddings in the Gemma3nForConditionalGeneration model.
  • vocab_offset (int, optional, defaults to 262272) — Offset between the tokenizer vocab index for the token ids embedded by Gemma3nMultimodalEmbedder and the 0-indexed Gemma3nMultimodalEmbedder.embedding table.
  • input_feat_size (int, optional, defaults to 128) — The number of channels in each mel-spectrogram frame.
  • hidden_size (int, optional, defaults to 1536) — Dimension of the hidden representations.
  • rms_norm_eps (float, optional, defaults to 1e-06) — The epsilon used by the rms normalization layers.
  • gradient_clipping (float, optional, defaults to 10000000000.0) — Clipping value used to stabilize extremely large gradient values.
  • conf_attention_chunk_size (int, optional, defaults to 12) — The sub-sequence size for local attention processing inside the Conformer (“conf”) section of the Universal Speech Model.
  • conf_attention_context_left (int, optional, defaults to 13) — The left context size of the local attention inside the Conformer (“conf”) section of the Universal Speech Model.
  • conf_attention_context_right (int, optional, defaults to 0) — The right context size of the local attention inside the Conformer (“conf”) section of the Universal Speech Model.
  • conf_attention_logit_cap (float, optional, defaults to 50.0) — Logit cap applied during local attention inside the Conformer (“conf”) section of the Universal Speech Model.
  • conf_num_attention_heads (int, optional, defaults to 8) — The number of attention heads in local attention inside the Conformer (“conf”) section of the Universal Speech Model.
  • conf_num_hidden_layers (int, optional, defaults to 12) — The number of layers that use local attention inside the Conformer (“conf”) section of the Universal Speech Model.
  • conf_conv_kernel_size (int, optional, defaults to 5) — Convolution kernel size for the conformer block inside the Conformer (“conf”) section of the Universal Speech Model.
  • conf_reduction_factor (int, optional, defaults to 4) — Reduction factor used in the conformer block inside the Conformer (“conf”) section of the Universal Speech Model.
  • conf_residual_weight (float, optional, defaults to 0.5) — Residual connection weight inside the Conformer (“conf”) section of the Universal Speech Model.
  • sscp_conv_channel_size (tuple(int, int), optional, defaults to (128, 32)) — The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection (“sscp”) section of the Universal Speech Model.
  • sscp_conv_group_norm_eps (float, optional, defaults to 0.001) — Epsilon used in group normalization in the Sub-sample Convolution Projection (“sscp”) section of the Universal Speech Model.
  • sscp_conv_kernel_size (tuple(tuple(int, int), tuple(int, int)), optional, defaults to ((3, 3), (3, 3))) — Kernel sizes of the two convolutional layers in the Sub-sample Convolution Projection (“sscp”) section of the Universal Speech Model. The kernel sizes are specified as a tuple of height and width for each layer, where the height corresponds to the time dimension and the width corresponds to the frequency dimension.
  • sscp_conv_stride_size (tuple(tuple(int, int), tuple(int, int)), optional, defaults to ((2, 2), (2, 2))) — Stride sizes of the two convolutional layers in the Sub-sample Convolution Projection (“sscp”) section of the Universal Speech Model. The stride sizes are specified as a tuple of height and width for each layer, where the height corresponds to the time dimension and the width corresponds to the frequency dimension.

This is the configuration class to store the configuration of a Gemma3nAudioEncoder, based on Google’s Universal Speech Model. It is used to instantiate a Gemma3nAudioEncoder model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g. google/gemma-3n-E4B.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
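
As a back-of-the-envelope check on how these defaults relate to the number of audio soft tokens, the sketch below multiplies the feature extractor's hop length by the time strides and the Conformer reduction factor. It assumes that each of these factors shrinks the time axis and ignores padding and edge effects, so the numbers are nominal; the variable and function names are hypothetical.

```python
import math

hop_length_ms = 10.0           # Gemma3nAudioFeatureExtractor default
sscp_time_strides = (2, 2)     # time component of each sscp_conv_stride_size entry
conf_reduction_factor = 4      # Gemma3nAudioConfig default

# Nominal audio duration covered by one soft token, under the assumptions above.
ms_per_token = hop_length_ms * math.prod(sscp_time_strides) * conf_reduction_factor


def approx_audio_tokens(duration_s: float) -> int:
    """Approximate number of soft tokens needed for a clip of the given length."""
    return math.ceil(duration_s * 1000.0 / ms_per_token)


print(ms_per_token, approx_audio_tokens(30.0))  # 160.0 188
```

For a hypothetical 30-second clip this arithmetic gives 188, the same value as the default audio_seq_length of the processor.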

Example:

>>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder

>>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
>>> configuration = Gemma3nAudioConfig()

>>> # Initializing a model from the gemma3n_audio-E4B style configuration
>>> model = Gemma3nAudioEncoder(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

Gemma3nConfig

class transformers.Gemma3nConfig

( text_config: typing.Union[transformers.models.gemma3n.configuration_gemma3n.Gemma3nTextConfig, dict[str, typing.Any], NoneType] = None vision_config: typing.Union[transformers.models.gemma3n.configuration_gemma3n.Gemma3nVisionConfig, dict[str, typing.Any], NoneType] = None audio_config: typing.Union[transformers.models.gemma3n.configuration_gemma3n.Gemma3nAudioConfig, dict[str, typing.Any], NoneType] = None audio_soft_tokens_per_image: int = 188 vision_soft_tokens_per_image: int = 256 boi_token_id: int = 255999 eoi_token_id: int = 262144 image_token_id: int = 262145 boa_token_id: int = 256000 eoa_token_id: int = 262272 audio_token_id: int = 262273 initializer_range: float = 0.02 **kwargs )

Parameters

  • text_config (Union[Gemma3nTextConfig, dict], optional) — The config object of the text backbone.
  • vision_config (Union[AutoConfig, dict], optional) — Custom vision config or dict.
  • audio_config (Union[AutoConfig, dict], optional) — Custom audio config or dict.
  • audio_soft_tokens_per_image (int, optional, defaults to 188) — The number of soft tokens per audio clip.
  • vision_soft_tokens_per_image (int, optional, defaults to 256) — The number of soft tokens per image.
  • boi_token_id (int, optional, defaults to 255999) — The begin-of-image token index to wrap the image prompt.
  • eoi_token_id (int, optional, defaults to 262144) — The end-of-image token index to wrap the image prompt.
  • image_token_id (int, optional, defaults to 262145) — The image token index to encode the image prompt.
  • boa_token_id (int, optional, defaults to 256000) — The begin-of-audio token index to wrap the audio prompt.
  • eoa_token_id (int, optional, defaults to 262272) — The end-of-audio token index to wrap the audio prompt.
  • audio_token_id (int, optional, defaults to 262273) — The audio token index to encode the audio prompt.
  • initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

This is the configuration class to store the configuration of a Gemma3nForConditionalGeneration. It is used to instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of Gemma3n-E4B, e.g. google/gemma-3n-E4B.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
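
The default token ids and the vocab_offset values of the vision and audio configs line up into contiguous id ranges. The sketch below illustrates how a tokenizer id could be mapped to an embedding table and a 0-indexed row under that reading. It is an illustrative interpretation of the vocab_offset descriptions, not the actual embedder code, and the function name route_token is hypothetical.

```python
VISION_OFFSET, VISION_VOCAB = 262144, 128   # Gemma3nVisionConfig defaults
AUDIO_OFFSET, AUDIO_VOCAB = 262272, 128     # Gemma3nAudioConfig defaults


def route_token(token_id: int):
    """Return (table_name, row_index) for a tokenizer id, assuming contiguous ranges."""
    if AUDIO_OFFSET <= token_id < AUDIO_OFFSET + AUDIO_VOCAB:
        return "audio", token_id - AUDIO_OFFSET
    if VISION_OFFSET <= token_id < VISION_OFFSET + VISION_VOCAB:
        return "vision", token_id - VISION_OFFSET
    if 0 <= token_id < VISION_OFFSET:
        return "text", token_id
    raise ValueError(f"token id {token_id} is outside the known ranges")


print(route_token(262145))  # image_token_id -> ('vision', 1)
print(route_token(262273))  # audio_token_id -> ('audio', 1)
print(route_token(255999))  # boi_token_id   -> ('text', 255999)
```

The upper end of the audio range, 262272 + 128 = 262400, matches the default vocab_size of Gemma3nTextConfig.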

Example:

>>> from transformers import Gemma3nForConditionalGeneration, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig, Gemma3nAudioConfig

>>> # Initializing a MobileNet vision config, which is loaded from TIMM
>>> vision_config = Gemma3nVisionConfig()

>>> # Initializing a Gemma3n Audio config
>>> audio_config = Gemma3nAudioConfig()

>>> # Initializing a Gemma3n Text config
>>> text_config = Gemma3nTextConfig()

>>> # Initializing a Gemma3n gemma-3n-E4B style configuration
>>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)

>>> # Initializing a model from the gemma-3n-E4B style configuration
>>> model = Gemma3nForConditionalGeneration(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

Gemma3nTextModel

class transformers.Gemma3nTextModel

( config: Gemma3nTextConfig )

Parameters

  • config (Gemma3nTextConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The base Gemma 3n language model without a language modeling head.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None per_layer_inputs: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None **flash_attn_kwargs: typing_extensions.Unpack[transformers.modeling_flash_attention_utils.FlashAttentionKwargs] ) → transformers.modeling_outputs.BaseModelOutputWithPast or tuple(torch.FloatTensor)

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.

    Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.

    What are input IDs?

  • per_layer_inputs (torch.Tensor, optional, defaults to None) — Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
  • attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:

    • 1 for tokens that are not masked,
    • 0 for tokens that are masked.

    What are attention masks?

  • position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of the position of each input sequence token in the position embeddings. Selected in the range [0, config.n_positions - 1].

    What are position IDs?

  • past_key_values (~cache_utils.Cache, optional) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True.

    Two formats are allowed:

    • a Cache instance, see our kv cache guide;
    • Tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head). This is also known as the legacy cache format.

    The model will output the same cache format that is fed as input. If no past_key_values are passed, the legacy cache format will be returned.

    If past_key_values are used, the user can optionally input only the last input_ids (those that don’t have their past key value states given to this model) of shape (batch_size, 1) instead of all input_ids of shape (batch_size, sequence_length).

  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
  • use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Unlike position_ids, this tensor is not affected by padding. It is used to update the cache at the correct position and to infer the complete sequence length.
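
To make the relationship between attention_mask, position_ids and cache_position more concrete, here is a minimal sketch that uses plain Python lists in place of tensors. The helper names (build_attention_mask, position_ids_from_mask, cache_position_for) and the padding id 0 are assumptions made for this illustration and are not part of the library. Counting real tokens from the left is one common convention for padded batches and is not necessarily what this model does internally.

```python
PAD_ID = 0  # hypothetical padding id, chosen only for this sketch


def build_attention_mask(batch):
    """Return 1 for real tokens and 0 for padding tokens."""
    return [[0 if tok == PAD_ID else 1 for tok in row] for row in batch]


def position_ids_from_mask(mask):
    """Number the real tokens 0, 1, 2, ...; padded slots get a dummy 0."""
    result = []
    for row in mask:
        positions, seen = [], 0
        for m in row:
            positions.append(seen if m else 0)
            seen += m
        result.append(positions)
    return result


def cache_position_for(past_length, num_new_tokens):
    """Slots the new tokens occupy in the cache; padding does not shift them."""
    return list(range(past_length, past_length + num_new_tokens))


# A left-padded batch of two sequences.
input_ids = [[0, 0, 5, 6, 7], [3, 4, 5, 6, 7]]
attention_mask = build_attention_mask(input_ids)
position_ids = position_ids_from_mask(attention_mask)
prefill_positions = cache_position_for(0, 5)
decode_positions = cache_position_for(5, 1)
```

In this sketch the padded row gets positions that start at its first real token, while the cache slots are the same for both rows because they only depend on how many columns were already processed.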

Returns

transformers.modeling_outputs.BaseModelOutputWithPast or tuple(torch.FloatTensor)

A transformers.modeling_outputs.BaseModelOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Gemma3nConfig) and inputs.

  • last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.

    If past_key_values is used only the last hidden-state of the sequences of shape (batch_size, 1, hidden_size) is output.

  • past_key_values (Cache, optional, returned when use_cache=True is passed or when config.use_cache=True) — It is a Cache instance. For more details, see our kv cache guide.

    Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if config.is_encoder_decoder=True in the cross-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

The Gemma3nTextModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, call the Module instance rather than forward directly: calling the instance takes care of running the pre- and post-processing steps, while calling forward silently ignores them.
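
The difference between calling the instance and calling forward can be illustrated with a toy class. TinyModule and its hook lists are hypothetical names written for this sketch; this is a simplified stand-in, not the torch.nn.Module implementation, which has considerably more machinery.

```python
class TinyModule:
    """Toy stand-in for a module: __call__ wraps forward with hooks."""

    def __init__(self):
        self.pre_hooks = []   # run on the input before forward
        self.post_hooks = []  # run on the output after forward

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        for hook in self.pre_hooks:
            x = hook(x)
        out = self.forward(x)
        for hook in self.post_hooks:
            out = hook(out)
        return out


module = TinyModule()
module.pre_hooks.append(lambda x: x + 1)
module.post_hooks.append(lambda y: y - 3)

via_call = module(10)             # hooks run: (10 + 1) * 2 - 3
via_forward = module.forward(10)  # hooks skipped: 10 * 2
```

Here the two calls return different values because only the instance call runs the registered hooks.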

Gemma3nModel

class transformers.Gemma3nModel

( config: Gemma3nConfig )

Parameters

  • config (Gemma3nConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a language modeling head.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None pixel_values: typing.Optional[torch.FloatTensor] = None input_features: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.Tensor] = None input_features_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Union[list[torch.FloatTensor], transformers.cache_utils.Cache, NoneType] = None token_type_ids: typing.Optional[torch.LongTensor] = None cache_position: typing.Optional[torch.LongTensor] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None **lm_kwargs )

Parameters

  • labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for computing the masked language modeling loss. Indices should either be in [0, ..., config.text_config.vocab_size] or -100 (see input_ids docstring). Tokens with indices set to -100 are ignored (masked), the loss is only computed for the tokens with labels in [0, ..., config.text_config.vocab_size].

Example:

>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration

>>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-e4b")
>>> processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b")

>>> prompt = "Where is the cat standing?"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = processor(images=image, text=prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(**inputs)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Where is the cat standing?\nsnow"

Gemma3nForCausalLM

class transformers.Gemma3nForCausalLM

( config: Gemma3nTextConfig )

Parameters

  • config (Gemma3nTextConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The base Gemma 3n language model with a language modeling head.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None logits_to_keep: typing.Union[int, torch.Tensor] = 0 **loss_kwargs ) transformers.modeling_outputs.CausalLMOutputWithPast or tuple(torch.FloatTensor)

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.

    Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.

    What are input IDs?

  • attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:

    • 1 for tokens that are not masked,
    • 0 for tokens that are masked.

    What are attention masks?

  • position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of the position of each input sequence token in the position embeddings. Selected in the range [0, config.n_positions - 1].

    What are position IDs?

  • past_key_values (~cache_utils.Cache, optional) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True.

    Two formats are allowed:

    • a Cache instance, see our kv cache guide;
    • Tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head). This is also known as the legacy cache format.

    The model will output the same cache format that is fed as input. If no past_key_values are passed, the legacy cache format will be returned.

    If past_key_values are used, the user can optionally input only the last input_ids (those that don’t have their past key value states given to this model) of shape (batch_size, 1) instead of all input_ids of shape (batch_size, sequence_length).

  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
  • labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for computing the masked language modeling loss. Indices should either be in [0, ..., config.vocab_size] or -100 (see input_ids docstring). Tokens with indices set to -100 are ignored (masked), the loss is only computed for the tokens with labels in [0, ..., config.vocab_size].
  • use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Unlike position_ids, this tensor is not affected by padding. It is used to update the cache at the correct position and to infer the complete sequence length.
  • logits_to_keep (Union[int, torch.Tensor], defaults to 0) — If an int, compute logits for the last logits_to_keep tokens. If 0, calculate logits for all input_ids (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a torch.Tensor, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length).
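
The logits_to_keep semantics can be sketched with plain Python lists: an int n keeps the last n positions, 0 keeps all of them, and a 1D sequence of indices keeps exactly those positions. The helper select_positions is hypothetical, and a list of indices stands in for a 1D torch.Tensor. The slice(-n, None) trick is simply one compact way to get the "0 means everything" behaviour in this sketch.

```python
def select_positions(hidden_states, logits_to_keep=0):
    """Pick the sequence positions for which logits would be computed."""
    if isinstance(logits_to_keep, int):
        # slice(-0, None) equals slice(0, None), so 0 keeps every position.
        return hidden_states[slice(-logits_to_keep, None)]
    # Otherwise treat the argument as explicit indices along the sequence.
    return [hidden_states[i] for i in logits_to_keep]


sequence = ["h0", "h1", "h2", "h3"]  # one entry per token position

keep_all = select_positions(sequence, 0)
keep_last = select_positions(sequence, 1)
keep_last_two = select_positions(sequence, 2)
keep_indexed = select_positions(sequence, [0, 3])
```

During generation only the last position is needed to pick the next token, which is why keeping a single position avoids projecting every hidden state onto the vocabulary.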

Returns

transformers.modeling_outputs.CausalLMOutputWithPast or tuple(torch.FloatTensor)

A transformers.modeling_outputs.CausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Gemma3nConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss (for next-token prediction).

  • logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • past_key_values (Cache, optional, returned when use_cache=True is passed or when config.use_cache=True) — It is a Cache instance. For more details, see our kv cache guide.

    Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
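
As a rough illustration of how labels and the -100 ignore index interact with a next-token loss, the sketch below computes a shifted cross-entropy over plain Python lists. The function name next_token_loss and the toy logits are made up for this example; the library computes the loss on tensors and its exact reduction may differ.

```python
import math

IGNORE_INDEX = -100


def next_token_loss(logits, labels):
    """Mean cross-entropy where logits[t] is scored against labels[t + 1]."""
    total, count = 0.0, 0
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue  # ignored positions contribute nothing to the loss
        # Log of the softmax normaliser, stabilised by subtracting the max.
        peak = max(step_logits)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in step_logits))
        total += log_norm - step_logits[target]
        count += 1
    return total / count if count else 0.0


# Three positions, vocabulary of size 3.
logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]

loss_full = next_token_loss(logits, [0, 2, 1])          # both targets scored
loss_masked = next_token_loss(logits, [-100, -100, 1])  # only the last target scored
loss_none = next_token_loss(logits, [-100, -100, -100])
```

Masking the prompt positions with -100, as in the second call, is a common way to compute the loss only on the continuation.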

The Gemma3nForCausalLM forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, call the Module instance rather than forward directly: calling the instance takes care of running the pre- and post-processing steps, while calling forward silently ignores them.

Example:

>>> from transformers import AutoTokenizer, Gemma3nForCausalLM

>>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-3n-e4b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-3n-e4b")

>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"

Gemma3nForConditionalGeneration

class transformers.Gemma3nForConditionalGeneration

( config: Gemma3nConfig )

Parameters

  • config (Gemma3nConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling head.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None pixel_values: typing.Optional[torch.FloatTensor] = None input_features: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.Tensor] = None input_features_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Union[list[torch.FloatTensor], transformers.cache_utils.Cache, NoneType] = None token_type_ids: typing.Optional[torch.LongTensor] = None cache_position: typing.Optional[torch.LongTensor] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None logits_to_keep: typing.Union[int, torch.Tensor] = 0 **lm_kwargs ) transformers.models.gemma3n.modeling_gemma3n.Gemma3nCausalLMOutputWithPast or tuple(torch.FloatTensor)

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.

    Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.

    What are input IDs?

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional) — The tensors corresponding to the input images. Pixel values can be obtained using the image processor bundled with Gemma3nProcessor. See the image processor's __call__ method for details (Gemma3nProcessor uses it for processing images).
  • input_features (torch.Tensor, optional, defaults to None) — The audio inputs to be encoded.
  • attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:

    • 1 for tokens that are not masked,
    • 0 for tokens that are masked.

    What are attention masks?

  • input_features_mask (torch.Tensor, optional, defaults to None) — The attention mask for the input audio.
  • position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of the position of each input sequence token in the position embeddings. Selected in the range [0, config.n_positions - 1].

    What are position IDs?

  • past_key_values (Union[list[torch.FloatTensor], ~cache_utils.Cache, NoneType]) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True.

    Two formats are allowed:

    • a Cache instance, see our kv cache guide;
    • Tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head). This is also known as the legacy cache format.

    The model will output the same cache format that is fed as input. If no past_key_values are passed, the legacy cache format will be returned.

    If past_key_values are used, the user can optionally input only the last input_ids (those that don’t have their past key value states given to this model) of shape (batch_size, 1) instead of all input_ids of shape (batch_size, sequence_length).

  • token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]:

    • 0 corresponds to a sentence A token,
    • 1 corresponds to a sentence B token.

    What are token type IDs?

  • cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Unlike position_ids, this tensor is not affected by padding. It is used to update the cache at the correct position and to infer the complete sequence length.
  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
  • labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for computing the masked language modeling loss. Indices should either be in [0, ..., config.text_config.vocab_size] or -100 (see input_ids docstring). Tokens with indices set to -100 are ignored (masked), the loss is only computed for the tokens with labels in [0, ..., config.text_config.vocab_size].
  • use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • logits_to_keep (Union[int, torch.Tensor], defaults to 0) — If an int, compute logits for the last logits_to_keep tokens. If 0, calculate logits for all input_ids (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a torch.Tensor, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length).
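
The caching contract described for past_key_values (once a cache exists, only the new token needs to be fed) can be sketched with a toy cache. ToyCache and toy_step are hypothetical names for this illustration; strings stand in for key/value tensors, and nothing here reflects the API of the library's Cache class.

```python
class ToyCache:
    """Keeps one cached entry per processed token position."""

    def __init__(self):
        self.keys = []

    def get_seq_length(self):
        return len(self.keys)

    def update(self, new_keys, cache_position):
        for pos, key in zip(cache_position, new_keys):
            # New entries must extend the cache contiguously.
            assert pos == len(self.keys), "unexpected cache position"
            self.keys.append(key)
        return self.keys


def toy_step(input_ids, cache):
    """Process only the tokens passed in, reusing everything already cached."""
    past_length = cache.get_seq_length()
    cache_position = list(range(past_length, past_length + len(input_ids)))
    keys = cache.update([f"k{tok}" for tok in input_ids], cache_position)
    # One output per new token; each sees all cached keys up to its own position.
    return [keys[: pos + 1] for pos in cache_position]


cache = ToyCache()
prefill_outputs = toy_step([11, 12, 13], cache)  # full prompt on the first call
decode_outputs = toy_step([14], cache)           # afterwards only the newest token
```

The second call passes a single token yet still has access to the three cached entries, which mirrors passing input_ids of shape (batch_size, 1) together with past_key_values.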

Returns

transformers.models.gemma3n.modeling_gemma3n.Gemma3nCausalLMOutputWithPast or tuple(torch.FloatTensor)

A transformers.models.gemma3n.modeling_gemma3n.Gemma3nCausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Gemma3nConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss (for next-token prediction).

  • logits (torch.FloatTensor of shape (batch_size, sequence_length, config.text_config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • past_key_values (tuple(tuple(torch.FloatTensor)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • hidden_states (tuple[torch.FloatTensor], optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple[torch.FloatTensor], optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • image_hidden_states (torch.FloatTensor, optional) — A torch.FloatTensor of size (batch_size, num_images, sequence_length, hidden_size). image_hidden_states of the model produced by the vision encoder after projecting last hidden state.

  • audio_hidden_states (torch.FloatTensor, optional) — A torch.FloatTensor of size (batch_size, num_images, sequence_length, hidden_size). audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.

The Gemma3nForConditionalGeneration forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, call the Module instance rather than forward directly: calling the instance takes care of running the pre- and post-processing steps, while calling forward silently ignores them.

Example:

>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration

>>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-e4b-it")
>>> processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it")

>>> messages = [
...     {
...         "role": "system",
...         "content": [
...             {"type": "text", "text": "You are a helpful assistant."}
...         ]
...     },
...     {
...         "role": "user", "content": [
...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
...             {"type": "text", "text": "Where is the cat standing?"},
...         ]
...     },
... ]

>>> inputs = processor.apply_chat_template(
...     messages,
...     tokenize=True,
...     return_dict=True,
...     return_tensors="pt",
...     add_generation_prompt=True
... )
>>> # Generate
>>> generate_ids = model.generate(**inputs)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"